<a href="https://colab.research.google.com/github/avva14/image_generators/blob/main/conv_regression_single.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
!pip install pillow --upgrade

In [None]:
!git clone https://github.com/avva14/common_utils.git

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Google Drive locations; they resolve only after the drive is mounted at /content/gdrive.
PATH_TO_TFDS = '/content/gdrive/MyDrive/tensorflow_datasets'  # data_dir handed to tfds.load
PATH_TO_MODELS = '/content/gdrive/MyDrive/models/moire'  # directory the trained model is saved into
PATH_TO_MOIRE = '/content/gdrive/MyDrive/Patterns/moiredata'  # directory listed to build `moirefiles`

In [None]:
import os
import numpy as np
import cv2 as cv

In [None]:
import zipfile
from zipfile import ZipFile

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from math import ceil

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Parameters

In [None]:
# Load the MNIST train/test splits from the existing TFDS directory (download=False,
# so the data must already be present under PATH_TO_TFDS).
ds = tfds.load('mnist', data_dir=PATH_TO_TFDS, download=False, split=['train', 'test'], shuffle_files=True)
# Both splits are cached and repeated indefinitely, then exposed as NumPy iterators.
# Only the training stream is shuffled (buffer of 1024 elements).
train_set = ds[0].cache().shuffle(1024).repeat().as_numpy_iterator()
test_set = ds[1].cache().repeat().as_numpy_iterator()

In [None]:
# Seeded NumPy random state (seed 1); this single instance is handed to both sample generators.
rng = np.random.RandomState(1)

In [None]:
SIZE = 208  # side length of the square, single-channel model input; also the period of the periodic loss
MNSZ = 28  # passed to the generators -- presumably the MNIST digit side length (TODO confirm)
MAX_NOISE = 0.5  # passed to the generators -- presumably an upper bound on added noise (TODO confirm)
num_classes = 10  # number of units in the classification head built by odmodel()

In [None]:
# Full paths of every entry in the moire-pattern directory.
# os.listdir() returns entries in arbitrary order, so sort them: anything that
# indexes into this list (e.g. via the seeded `rng`) then sees the same order on every run.
moirefiles = [os.path.join(PATH_TO_MOIRE, f) for f in sorted(os.listdir(PATH_TO_MOIRE))]

In [None]:
from common_utils.single_generators import SingleTestGenerator, SingleTrainGenerator

## TF datasets

In [None]:
def single_test():
    """Create a new ``SingleTestGenerator`` over the shared MNIST test iterator.

    Zero-argument factory suitable for ``tf.data.Dataset.from_generator``.
    """
    generator = SingleTestGenerator(test_set, rng, MAX_NOISE, MNSZ, SIZE)
    return generator

def single_train():
    """Create a new ``SingleTrainGenerator`` over the shared MNIST train iterator.

    Zero-argument factory suitable for ``tf.data.Dataset.from_generator``;
    unlike the test factory it also receives the list of moire files.
    """
    generator = SingleTrainGenerator(train_set, rng, moirefiles, MAX_NOISE, MNSZ, SIZE)
    return generator

In [None]:
# Element structure shared by both datasets: (image, (label, box)).
# Defined once so the train and test pipelines cannot drift apart, and with the
# label shape spelled as a real tuple -- `(1)` is just the integer 1.
SAMPLE_SIGNATURE = (
    tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=tf.float32),
    (
        tf.TensorSpec(shape=(1,), dtype=tf.float32),
        tf.TensorSpec(shape=(1, 4), dtype=tf.float32),
    ),
)

# from_generator calls the factory each time the dataset is iterated.
dataset_test = tf.data.Dataset.from_generator(single_test, output_signature=SAMPLE_SIGNATURE)
dataset_train = tf.data.Dataset.from_generator(single_train, output_signature=SAMPLE_SIGNATURE)

In [None]:
# Number of samples per batch, used for both the train and the test pipeline.
BATCHSIZE = 128

In [None]:
# Group the per-sample streams into batches of BATCHSIZE elements.
batched_test = dataset_test.batch(BATCHSIZE)
batched_train = dataset_train.batch(BATCHSIZE)

In [None]:
# NumPy-yielding iterators over the batched datasets, for pulling individual batches by hand.
batched_test_iterator = batched_test.as_numpy_iterator()
batched_train_iterator = batched_train.as_numpy_iterator()

## Model

In [None]:
from keras.models import Model, load_model

In [None]:
from keras.layers import Layer, Conv2D, Flatten, Reshape
from keras.layers import Input, Dropout, MaxPooling2D, GlobalAveragePooling2D
from keras.layers import ReLU, Dense, Lambda, Concatenate, Cropping2D

In [None]:
class Conv2DPeriodic(Layer):
    """2-D convolution with periodic (wrap-around) padding along axis 1.

    Before convolving, the last ``margin`` rows of the input are prepended and
    the first ``margin`` rows appended along axis 1, where
    ``margin = (kernel_height - 1) // 2``. The extended tensor is convolved with
    ``padding='same'`` and the extra rows are cropped off again, so the output
    has the same spatial size as the input. Axis 2 only gets the convolution's
    ordinary 'same' padding.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: ``(height, width)`` pair, or a single int used for both.
        **kwargs: standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).

    NOTE(review): the margin formula looks designed for odd kernel heights;
    behaviour for even heights should be confirmed before relying on it.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        super(Conv2DPeriodic, self).__init__(**kwargs)
        # Accept the same int shorthand that Conv2D itself accepts.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.filters = filters
        self.kernel_size = tuple(kernel_size)
        self.margin = (self.kernel_size[0] - 1) // 2
        self.conv = Conv2D(filters=filters, kernel_size=self.kernel_size, padding='same')
        self.merge = Concatenate(axis=1)
        self.crop = Cropping2D((self.margin, 0))

    def call(self, x):
        margin = self.margin
        if margin == 0:
            # With margin 0, `x[:, -0:]` would select the WHOLE tensor and the
            # output would come out twice as tall; no wrap padding is needed.
            return self.conv(x)
        xt = x[:, -margin:, :, :]  # trailing rows, wrapped around to the front
        xb = x[:, :margin, :, :]   # leading rows, wrapped around to the back
        xe = self.merge([xt, x, xb])
        x = self.conv(xe)
        # Remove the `margin` rows added on each side of axis 1.
        x = self.crop(x)
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(Conv2DPeriodic, self).get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config

In [None]:
class ContractingBlock(Layer):
    """Down-sampling stage: two periodic 5x5 convolutions, each followed by ReLU,
    then 2x2 max-pooling and dropout with rate 0.1.

    Args:
        num_channles: number of filters used by both convolutions.
    """

    def __init__(self, num_channles):
        super(ContractingBlock, self).__init__()
        kernel = (5, 5)
        self.con1 = Conv2DPeriodic(num_channles, kernel_size=kernel)
        self.con2 = Conv2DPeriodic(num_channles, kernel_size=kernel)
        self.relu = ReLU()
        self.pool = MaxPooling2D((2, 2))
        self.drop = Dropout(0.1)

    def call(self, x):
        # The single ReLU instance is applied after each convolution.
        stages = (self.con1, self.relu, self.con2, self.relu, self.pool, self.drop)
        for stage in stages:
            x = stage(x)
        return x

In [None]:
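# NOTE: unused alternative ContractingBlock, kept for reference only. It uses plain
# Conv2D layers with zero 'same' padding and built-in ReLU instead of Conv2DPeriodic.
# class ContractingBlock(Layer):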
#     def __init__(self, num_channles):
#         super(ContractingBlock, self).__init__()
#         self.con1 = Conv2D(filters=num_channles, kernel_size=(5,5), padding='same', activation='relu')
#         self.con2 = Conv2D(filters=num_channles, kernel_size=(5,5), padding='same', activation='relu')
#         self.pool = MaxPooling2D((2,2))
#         self.drop = Dropout(0.1)
#     def call(self, x):
#         x = self.con1(x)
#         x = self.con2(x)
#         x = self.pool(x)
#         x = self.drop(x)
#         return x

In [None]:
def odmodel(numclass):
    """Build the two-headed convolutional model.

    A stack of five ContractingBlocks (8, 16, 32, 64, 128 channels) feeds two heads:
      * classification: global average pooling -> Dense(128, relu) -> Dense(numclass),
        with no activation on the final layer;
      * box regression: flatten -> Dense(16, relu) -> Dense(4), reshaped to (1, 4).

    Args:
        numclass: number of units in the classification output.

    Returns:
        A Keras ``Model`` named 'test' taking (SIZE, SIZE, 1) inputs and returning
        ``[class_output, box_output]``.
    """
    inputs = Input((SIZE, SIZE, 1))

    features = inputs
    for channels in (8, 16, 32, 64, 128):
        features = ContractingBlock(channels)(features)

    # Layers are created in the same order as before so auto-generated layer
    # names and weight ordering stay unchanged.
    flat = Flatten()(features)
    pooled = GlobalAveragePooling2D(name='pooling')(features)
    hidden = Dense(128, activation='relu')(pooled)
    outclass = Dense(numclass)(hidden)
    box = Dense(16, activation='relu')(flat)
    box = Dense(4)(box)
    outshape = Reshape((1, -1))(box)

    return Model(inputs=inputs, outputs=[outclass, outshape], name='test')

In [None]:
# Instantiate the two-headed network and print its layer/parameter summary.
md = odmodel(num_classes)
md.summary()

In [None]:
# Pull one training batch: `aa` is the image batch, `llbb` the (labels, boxes) pair.
aa, llbb = next(batched_train_iterator)
aa.shape

In [None]:
# Split the targets into class labels (`ll`) and boxes (`bb`) and show their shapes.
ll, bb = llbb
ll.shape, bb.shape

In [None]:
# Sanity check: run the model on the batch in inference mode and show both output shapes.
resclassaa, resposaa = md(aa, training=False)
resclassaa.shape, resposaa.shape

## Training

In [None]:
from keras.losses import Loss, SparseCategoricalCrossentropy, MeanSquaredError
from keras.metrics import SparseCategoricalAccuracy, MeanAbsoluteError
from keras.optimizers import Adam

In [None]:
class MeanSquaredErrorPeriodic(Loss):
    """Mean squared error for 4-component targets whose component 1 is periodic.

    For every component the squared error is the minimum over the raw difference
    and the difference shifted by ``+adding`` and ``-adding``, where
    ``adding = [0, size, 0, 0]``. Only component 1 (the ``y`` of an
    ``[x, y, w, h]`` box) is therefore wrapped with period ``size``; the other
    components reduce to the ordinary squared error.

    Args:
        size: period of the wrapped component.
        name: name of the loss instance.
        **kwargs: forwarded to ``keras.losses.Loss`` (e.g. ``reduction``).
    """

    def __init__(self, size, name="mse_periodic", **kwargs):
        super().__init__(name=name, **kwargs)
        self.size = size
        self.adding = tf.constant([0, size, 0, 0], dtype=tf.float32)

    def call(self, y_true, y_pred):
        """Return the periodic squared error averaged over the last axis only.

        The remaining axes are left for the base class to reduce, so the
        configured ``reduction`` and any ``sample_weight`` are honoured. With the
        default reduction the final value is the same overall mean as before.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        # Match dtypes the way the built-in Keras losses do, so integer or
        # float64 targets do not fail on the subtraction.
        y_true = tf.cast(y_true, y_pred.dtype)
        shift = tf.cast(self.adding, y_pred.dtype)
        ydiff = y_pred - y_true
        stacked = tf.stack([ydiff, ydiff + shift, ydiff - shift])
        # Element-wise smallest squared error among the three candidates.
        miniz = tf.reduce_min(tf.square(stacked), axis=0)
        return tf.reduce_mean(miniz, axis=-1)

    def get_config(self):
        """Include ``size`` so the loss can be re-created from its config."""
        config = super().get_config()
        config["size"] = self.size
        return config

In [None]:
# Box-regression losses: plain MSE and the variant that wraps the y component with period SIZE.
mse_loss = MeanSquaredError()
mse_loss_periodic = MeanSquaredErrorPeriodic(SIZE)
# Classification loss on raw logits (the class head has no softmax).
sprs_loss = SparseCategoricalCrossentropy(from_logits=True)

In [None]:
# NOTE: despite the `mse_` prefix, the box metric is a mean ABSOLUTE error.
mse_metr = MeanAbsoluteError()
sprs_metr = SparseCategoricalAccuracy()

In [None]:
# Evaluate the three losses on the sample batch: classification, plain MSE, periodic MSE.
sprs_loss(ll, resclassaa), mse_loss(bb, resposaa), mse_loss_periodic(bb, resposaa)

In [None]:
class ModelTrain(Model):
    """Training wrapper around a model that returns ``(class_output, box_output)``.

    The wrapped model is optimised on ``class_loss + weight * box_loss`` through
    custom ``train_step`` / ``test_step`` implementations, so it can be driven by
    the regular ``fit`` / ``evaluate`` loops.

    Args:
        mdl: the model to train; called as ``mdl(X)`` and expected to return two outputs.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, mdl, **kwargs):
        super().__init__(**kwargs)
        self.model = mdl

    def compile(self, optimizer, lossclass, losspos, metricclass, metricpos, weight, **kwargs):
        """Store the training configuration.

        Args:
            optimizer: optimizer applied to the wrapped model's trainable variables.
            lossclass: loss for the first (class) output.
            losspos: loss for the second (box) output.
            metricclass: metric for the first output.
            metricpos: metric for the second output.
            weight: multiplier of the box loss in the total loss.
            **kwargs: forwarded to ``keras.Model.compile``.
        """
        super().compile(**kwargs)
        self.opt = optimizer
        self.loss_class = lossclass
        self.loss_posit = losspos
        self.metr_class = metricclass
        self.metr_posit = metricpos
        self.w = weight

    def train_step(self, batch, **kwargs):
        """Run one optimisation step on ``batch = (X, (label, box))``."""
        X, y = batch
        label, box = y
        with tf.GradientTape() as tape:
            # training=True must be passed explicitly: train_step is not itself a
            # layer call, so without it the wrapped model runs in inference mode
            # and its Dropout layers are inactive while fitting.
            p = self.model(X, training=True)
            loss1 = self.loss_class(label, p[0])
            loss2 = self.loss_posit(box, p[1])
            loss = loss1 + self.w * loss2
        # Differentiate after leaving the tape context so the gradient
        # computation is not itself recorded by the tape.
        grad = tape.gradient(loss, self.model.trainable_variables)
        self.opt.apply_gradients(zip(grad, self.model.trainable_variables))
        # Calling a metric updates its state and returns the current result.
        acc1 = self.metr_class(label, p[0])
        acc2 = self.metr_posit(box, p[1])
        return {"loss":loss, "classloss":loss1, "posloss":loss2, "classacc":acc1, "posacc":acc2}

    def test_step(self, batch, **kwargs):
        """Evaluate losses and metrics on ``batch = (X, (label, box))`` in inference mode."""
        X, y = batch
        label, box = y
        p = self.model(X, training=False)
        loss1 = self.loss_class(label, p[0])
        loss2 = self.loss_posit(box, p[1])
        loss = loss1 + self.w * loss2
        acc1 = self.metr_class(label, p[0])
        acc2 = self.metr_posit(box, p[1])
        return {"loss":loss, "classloss":loss1, "posloss":loss2, "classacc":acc1, "posacc":acc2}

    def call(self, X, **kwargs):
        """Forward pass: delegate to the wrapped model."""
        return self.model(X, **kwargs)

In [None]:
# Wrap the network in the custom training model; `md` itself remains the object that holds the weights.
md_train = ModelTrain(md)

In [None]:
# Configuration: plain MSE box loss, box-loss weight 1e-5, fresh Adam optimizer (lr 1e-3).
md_train.compile(Adam(learning_rate=0.001), sprs_loss, mse_loss, sprs_metr, mse_metr, 0.00001)

In [None]:
# Configuration: periodic MSE box loss, box-loss weight 1e-5, fresh Adam optimizer (lr 1e-3).
md_train.compile(Adam(learning_rate=0.001), sprs_loss, mse_loss_periodic, sprs_metr, mse_metr, 0.00001)

In [None]:
# Configuration: periodic MSE box loss, box-loss weight 1e-4, fresh Adam optimizer (lr 1e-3).
md_train.compile(Adam(learning_rate=0.001), sprs_loss, mse_loss_periodic, sprs_metr, mse_metr, 0.0001)

In [None]:
# Configuration: periodic MSE box loss, box-loss weight 1e-3, fresh Adam optimizer (lr 1e-3).
# compile() overwrites the attributes set by any earlier compile call, so when the
# notebook is run top to bottom this is the configuration evaluate()/fit() actually use.
md_train.compile(Adam(learning_rate=0.001), sprs_loss, mse_loss_periodic, sprs_metr, mse_metr, 0.001)

In [None]:
# Baseline check before fitting: evaluate on 5 test batches.
md_train.evaluate(batched_test, steps=5)

In [None]:
# Train for 20 epochs of 120 batches each. The datasets repeat indefinitely, so the
# epoch length is fixed by steps_per_epoch; validation uses a single test batch per epoch.
history = md_train.fit(
    batched_train, steps_per_epoch=120, epochs=20, validation_data=batched_test, validation_steps=1)

In [None]:
# Save the wrapped network (not the training wrapper) in TensorFlow SavedModel format.
# `save_format` has to be passed by keyword: the second positional parameter of
# Model.save() is `overwrite`, so a positional "tf" never reached `save_format`.
md.save(os.path.join(PATH_TO_MODELS, "conv_regression_v01"), save_format="tf")

In [None]:
# Re-run the model on the sample batch in inference mode; `resllbb` keeps the
# (class output, box output) pair in the layout display_batch() expects.
resllbb = md(aa, training=False)
resclassaa, resposaa = resllbb
resclassaa.shape, resposaa.shape

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from math import ceil

In [None]:
def rectangperiodic(bx, edgecolor='r', size=None):
    '''
    Build the matplotlib rectangle(s) outlining one box, wrapped periodically along y.

    bx -- a single box [x, y, w, h]
    edgecolor -- outline colour of the rectangles
    size -- period along y; defaults to the module-level SIZE

    Returns a list of unfilled Rectangle patches: one when the box fits below
    y = size, two when it crosses that boundary (the part up to the boundary
    plus the remainder starting again at y = 0).
    '''
    if size is None:
        size = SIZE
    ypos = bx[1] % size
    # Height of the part that sticks out past the boundary. A box ending exactly
    # on the boundary (overflow == 0) is NOT split, which avoids emitting a
    # degenerate zero-height rectangle.
    overflow = ypos + bx[3] - size
    if overflow <= 0:
        return [Rectangle((bx[0], ypos), bx[2], bx[3], edgecolor=edgecolor, fill=False)]
    return [
        Rectangle((bx[0], ypos), bx[2], size - ypos, edgecolor=edgecolor, fill=False),
        Rectangle((bx[0], 0), bx[2], overflow, edgecolor=edgecolor, fill=False),
    ]
def _draw_sample(ax, image, label, boxes, NDIV):
    """Draw one image with its boxes, a grid of NDIV cells per side, and `label` as title."""
    ticks = (SIZE // NDIV) * np.arange(0, NDIV + 1)

    ax.imshow(1 - image, aspect=1, cmap='gray', vmin=0, vmax=1)

    # np.asarray lets both NumPy arrays and framework tensors be iterated box by box.
    for bx in np.asarray(boxes):
        for rct in rectangperiodic(bx):
            ax.add_patch(rct)

    ax.set_yticks(ticks)
    ax.set_xticklabels([])
    ax.set_xticks(ticks)
    # NOTE(review): these limits put y = 0 at the bottom, so the image is drawn with
    # its row axis pointing up (flipped w.r.t. imshow's default) -- confirm this is intended.
    ax.set_ylim(0, SIZE - 1)
    ax.set_xlim(0, SIZE - 1)
    ax.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)
    ax.set_title(label)


def display_generated(generator, nr, nc, NDIV=13):
    """Show an nr x nc grid of samples drawn one by one from `generator`.

    generator -- iterator yielding (image, (label, boxes)) samples
    nr, nc -- number of rows / columns of the figure
    NDIV -- number of grid cells per side
    """
    fig, axxes = plt.subplots(ncols=nc,
                              nrows=nr,
                              figsize=(3*nc, 3*nr),
                              sharey=True, sharex=True)

    for ax in np.ravel(axxes):
        a, lb = next(generator)
        l, b = lb
        _draw_sample(ax, a, np.squeeze(l).astype(np.int32), b, NDIV)

    fig.tight_layout()
    plt.show()


def display_batch(abatch, lbbatch, nr, nc, NDIV=13):
    """Show the first nr*nc samples of a batch in an nr x nc grid.

    abatch -- batch of images
    lbbatch -- pair (labels, boxes); labels either have a trailing dimension of 1
               (class ids) or hold per-class scores, which are arg-maxed
    nr, nc -- number of rows / columns of the figure
    NDIV -- number of grid cells per side
    """
    fig, axxes = plt.subplots(ncols=nc,
                              nrows=nr,
                              figsize=(3*nc, 3*nr),
                              sharey=True, sharex=True)

    ll, bb = lbbatch

    for i, ax in enumerate(np.ravel(axxes)):
        if ll.shape[-1] == 1:
            l = np.squeeze(ll[i]).astype(np.int32)
        else:
            l = np.argmax(ll[i])
        _draw_sample(ax, abatch[i], l, bb[i], NDIV)

    fig.tight_layout()
    plt.show()

In [None]:
# Ground truth: the first six images of the sample batch with their true labels and boxes.
display_batch(aa, llbb, 2, 3)

In [None]:
# Predictions: the same six images with the model's arg-max class and predicted boxes.
display_batch(aa, resllbb, 2, 3)