<a href="https://colab.research.google.com/github/avva14/image_generators/blob/main/conv_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

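Semantic segmentation of MNIST digits with a U-Net whose convolutions wrap around periodically along the image height. Training and test images are produced by the generators in `common_utils.unet_generators` (the training generator additionally receives moiré pattern files), and the network predicts `num_classes + 1` logits per pixel.

## Imports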

In [None]:
# Upgrade Pillow in the running kernel's environment (%pip, unlike !pip, targets the kernel).
# NOTE(review): pin the required version (e.g. pillow==X.Y.Z) for reproducible runs.
%pip install --upgrade pillow

In [None]:
# Helper package that provides the U-Net data generators imported below.
# Skip the clone when the directory already exists so the cell survives a re-run.
!test -d common_utils || git clone https://github.com/avva14/common_utils.git

In [None]:
# Mount Google Drive (Colab only) so the PATH_TO_* directories defined next are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Locations on the mounted Google Drive:
# pre-downloaded TFDS data, saved models, and the moiré pattern files
# that are handed to the training generator.
PATH_TO_TFDS = '/content/gdrive/MyDrive/tensorflow_datasets'
PATH_TO_MODELS = '/content/gdrive/MyDrive/models/moire'
PATH_TO_MOIRE = '/content/gdrive/MyDrive/Patterns/moiredata'

In [None]:
import os
import numpy as np
import cv2 as cv

In [None]:
import zipfile
from zipfile import ZipFile

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from math import ceil

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Data and parameters

In [None]:
# Load MNIST from the pre-downloaded TFDS directory (download=False: nothing is fetched).
# Both splits are cached and repeated indefinitely; only the training split is shuffled.
# The numpy iterators are consumed by the U-Net mask generators defined below.
ds = tfds.load('mnist', data_dir=PATH_TO_TFDS, download=False, split=['train', 'test'], shuffle_files=True)
train_set = ds[0].cache().shuffle(1024).repeat().as_numpy_iterator()
test_set = ds[1].cache().repeat().as_numpy_iterator()

In [None]:
# Seed the stochastic parts of the notebook.
# `rng` drives both data generators (train and test share it, so their draws interleave).
# tf.random.set_seed covers TF ops created after this cell, e.g. Keras weight
# initialisation and dropout. The dataset shuffle in the previous cell was created
# before this call and is not affected.
SEED = 1
rng = np.random.RandomState(SEED)
tf.random.set_seed(SEED)

In [None]:
# SIZE: side of the square generated images/masks. It must be divisible by 2**4 for
# the 4-stage U-Net below (208 = 16 * 13).
SIZE = 208
# MNSZ: side length of an MNIST digit image.
MNSZ = 28
# MAX_NOISE: noise cap passed to the generators (exact meaning is defined in common_utils; TODO confirm).
MAX_NOISE = 0.5
# MAX_NUM: presumably the maximum number of digits per generated image (verify in common_utils).
MAX_NUM = 3
# num_classes: digit classes. The model outputs num_classes + 1 channels; label 0 is
# excluded by the non-zero accuracy metrics (presumably background).
num_classes = 10

In [None]:
# Moiré pattern files passed to the training generator. The list is sorted so its order
# (and anything the seeded `rng` picks from it) does not depend on os.listdir's
# arbitrary, filesystem-specific ordering.
moirefiles = sorted(os.path.join(PATH_TO_MOIRE, f) for f in os.listdir(PATH_TO_MOIRE))

In [None]:
from common_utils.unet_generators import UnetMaskTestGenerator, UnetMaskTrainGenerator

## TF datasets

In [None]:
def ugen_test():
    """Build a fresh test-split generator (no moiré files) for tf.data.from_generator."""
    generator = UnetMaskTestGenerator(test_set, rng, MAX_NUM, MAX_NOISE, MNSZ, SIZE)
    return generator
def ugen_train():
    """Build a fresh training generator that also receives the moiré pattern files."""
    generator = UnetMaskTrainGenerator(train_set, rng, MAX_NUM, moirefiles, MAX_NOISE, MNSZ, SIZE)
    return generator

In [None]:
# Wrap each generator in a tf.data.Dataset. Both yield (image, mask) pairs of
# single-channel SIZE x SIZE float32 arrays. The test dataset is built first, then
# the train dataset; the generator expression keeps no extra names at notebook level.
dataset_test, dataset_train = (
    tf.data.Dataset.from_generator(
        make_generator,
        output_signature=(
            tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=np.float32),
            tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=np.float32),
        ),
    )
    for make_generator in (ugen_test, ugen_train)
)

In [None]:
# Examples per batch, shared by the training and evaluation pipelines.
BATCHSIZE = 128

In [None]:
# Batched pipelines consumed directly by evaluate()/fit() below.
batched_test = dataset_test.batch(BATCHSIZE)
batched_train = dataset_train.batch(BATCHSIZE)

In [None]:
# Numpy iterators for pulling sample batches by hand. The train iterator feeds the
# `aa, mm` sample below; the test iterator is not used in the visible notebook.
batched_test_iterator = batched_test.as_numpy_iterator()
batched_train_iterator = batched_train.as_numpy_iterator()

## Model

In [None]:
from keras.models import Model, load_model

In [None]:
from keras.layers import Input, Layer, Conv2D, Conv2DTranspose, Lambda
from keras.layers import Dropout, MaxPooling2D
from keras.layers import ReLU, Concatenate, Cropping2D, UpSampling2D

In [None]:
class Conv2DPeriodic(Layer):
    """2-D convolution with periodic (wrap-around) padding along the height axis.

    The last ``margin`` input rows are prepended and the first ``margin`` rows
    appended (``margin = (kernel_height - 1) // 2``), then a ``padding='same'``
    Conv2D runs. The extra output rows are cropped afterwards, so the output
    height equals the input height. The width axis only receives the ordinary
    zero padding of ``padding='same'``.

    Args:
        filters: number of output channels.
        kernel_size: int or ``(kh, kw)``; only ``kh`` sets the wrap margin.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super(Conv2DPeriodic, self).__init__(**kwargs)
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        margin = (kernel_size[0] - 1) // 2
        self.margin = margin
        # With margin == 0 (kernel height 1 or 2), x[:, -0:] is the WHOLE tensor.
        # Wrapping it would silently double the height, so the wrap layers
        # are only built when there is something to wrap.
        if margin > 0:
            self.la1 = Lambda(lambda x: x[:, -margin:, :, :])
            self.la2 = Lambda(lambda x: x[:, :margin, :, :])
            self.merge = Concatenate(axis=1)
            self.crop = Cropping2D((margin, 0))
        self.conv = Conv2D(filters=filters, kernel_size=kernel_size, padding='same')

    def call(self, x):
        if self.margin > 0:
            xt = self.la1(x)  # bottom rows, placed above the top edge
            xb = self.la2(x)  # top rows, placed below the bottom edge
            x = self.merge([xt, x, xb])
        x = self.conv(x)
        if self.margin > 0:
            x = self.crop(x)  # remove the `margin` wrapped rows at top and bottom
        return x
class ConvTransposePeriodic(Layer):
    """Transposed 2-D convolution with periodic padding along the height axis.

    ``margin = (kernel_height - 1) // 2`` rows are wrapped around the top and
    bottom before a ``padding='same'`` Conv2DTranspose. With stride ``s``,
    each wrapped input row becomes ``s`` output rows, so ``s * margin`` rows
    are cropped from each side. The output height is therefore ``s`` times
    the input height.

    Args:
        filters: number of output channels.
        kernel_size: int or ``(kh, kw)``; only ``kh`` sets the wrap margin.
        strides: int or ``(sh, sw)``; ``sh`` scales the crop.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, filters, kernel_size, strides, **kwargs):
        super(ConvTransposePeriodic, self).__init__(**kwargs)
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(strides, int):
            strides = (strides, strides)
        margin = (kernel_size[0] - 1) // 2
        self.margin = margin
        # margin == 0 would make x[:, -0:] the whole tensor; only wrap when needed.
        if margin > 0:
            self.la1 = Lambda(lambda x: x[:, -margin:, :, :])
            self.la2 = Lambda(lambda x: x[:, :margin, :, :])
            self.mer = Concatenate(axis=1)
            # The crop scales with the height stride (previously hard-coded to 2).
            self.crp = Cropping2D((strides[0] * margin, 0))
        self.cont = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')

    def call(self, x):
        if self.margin > 0:
            xt = self.la1(x)  # bottom rows, placed above the top edge
            xb = self.la2(x)  # top rows, placed below the bottom edge
            x = self.mer([xt, x, xb])
        x = self.cont(x)
        if self.margin > 0:
            x = self.crp(x)
        return x

In [None]:
# Dropout rate used by the contracting and expanding blocks.
DROP = 0.1
# Channels after the input 1x1 projection; the channel budget doubles at each contracting stage.
hidden_dim = 16

In [None]:
class ContractingBlock(Layer):
    """U-Net encoder stage.

    Runs two 5x5 periodic convolutions, each followed by ReLU, then a 2x2
    max-pool that halves height and width, then dropout.

    Args:
        input_channels: filter count for both convolutions (despite the name,
            this is also the block's output channel count).
        drop: dropout rate applied after pooling.
    """
    def __init__(self, input_channels, drop=DROP):
        super(ContractingBlock, self).__init__()
        kernel = (5, 5)
        self.conv1 = Conv2DPeriodic(input_channels, kernel)
        self.conv2 = Conv2DPeriodic(input_channels, kernel)
        self.activation = ReLU()
        self.maxpool = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
        self.drop = Dropout(drop)

    def call(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.activation(conv(x))
        return self.drop(self.maxpool(x))

In [None]:
class FeatureMapBlock(Layer):
    """Pointwise (1x1) convolution that maps its input to ``output_channels`` channels.

    No activation is applied. The network uses it for the input projection,
    the bottleneck, and the final per-pixel logits.

    Args:
        output_channels: number of filters of the 1x1 convolution.
        name: layer name (required here; used as 'input'/'bottleneck'/'output').
    """
    def __init__(self, output_channels, name):
        super(FeatureMapBlock, self).__init__(name=name)
        self.conv = Conv2D(filters=output_channels, kernel_size=(1, 1))

    def call(self, x):
        return self.conv(x)

In [None]:
class ExpandingBlock(Layer):
    """U-Net decoder stage.

    Upsamples ``x`` by 2 in height and width with a periodic transposed
    convolution and concatenates the matching encoder feature map on the
    channel axis. It then applies dropout and two 5x5 periodic convolutions,
    each followed by ReLU.

    Args:
        input_channels: filter count for the upsampling and both convolutions
            (i.e. the block's output channel count).
        drop: dropout rate applied to the concatenated features.
    """
    def __init__(self, input_channels, drop=DROP):
        super(ExpandingBlock, self).__init__()
        self.upsample = ConvTransposePeriodic(input_channels, kernel_size=(5,5), strides=(2,2))
        self.conv1 = Conv2DPeriodic(input_channels, kernel_size=(5,5))
        self.conv2 = Conv2DPeriodic(input_channels, kernel_size=(5,5))
        self.activation = ReLU()
        self.drop = Dropout(drop)
        # Created once here rather than inside call(). Sublayers belong in
        # __init__ so Keras tracks them and call() does not build a new layer
        # every time it runs or is traced.
        self.concat = Concatenate(axis=-1)

    def call(self, x, skip_con_x):
        """Upsample ``x`` and fuse it with the skip connection ``skip_con_x``.

        ``skip_con_x`` must have twice the height/width of ``x`` so the
        channel concatenation lines up.
        """
        x = self.upsample(x)
        x = self.concat([x, skip_con_x])
        x = self.drop(x)
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.activation(x)
        return x

In [None]:
def unet(hidden_channels, num_stages, out_channels, size=SIZE):
    """Build a U-Net whose convolutions wrap periodically along the height axis.

    Args:
        hidden_channels: channels after the input 1x1 projection. The channel
            budget doubles after every contracting stage.
        num_stages: number of contracting/expanding stages; each halves (and
            later restores) height and width.
        out_channels: channels of the final 1x1 projection (per-pixel logits).
        size: side of the square single-channel input (defaults to SIZE).

    Returns:
        A keras ``Model`` named 'unet' mapping (size, size, 1) inputs to
        (size, size, out_channels) outputs.

    Raises:
        ValueError: if ``size`` is not divisible by ``2**num_stages``.
            Otherwise pooling would floor odd sizes, and the upsampled maps
            would not match their skip connections.
    """
    if size % (2 ** num_stages) != 0:
        raise ValueError(
            f"size={size} must be divisible by 2**num_stages={2 ** num_stages}")

    inputs = Input(shape=(size, size, 1))
    p = FeatureMapBlock(hidden_channels, 'input')(inputs)

    xlist = []  # skip connections, consumed in reverse order by the decoder
    nc = hidden_channels
    for _ in range(num_stages):
        xlist.append(p)
        p = ContractingBlock(nc)(p)
        nc *= 2

    p = FeatureMapBlock(nc, 'bottleneck')(p)

    for _ in range(num_stages):
        # Integer division: `/=` would turn the channel count into a float
        # that is then passed as `filters` to the convolution layers.
        nc //= 2
        z = xlist.pop()
        p = ExpandingBlock(nc)(p, z)

    outputs = FeatureMapBlock(out_channels, 'output')(p)
    model = Model(inputs=inputs, outputs=outputs, name='unet')

    return model

In [None]:
# 4 stages take 208 x 208 down to 13 x 13 at the bottleneck. There are
# num_classes + 1 output channels of per-pixel logits.
unet_semantic = unet(hidden_dim, 4, num_classes+1)
unet_semantic.summary()

In [None]:
# Pull one training batch (images `aa`, label masks `mm`) for shape checks and the plots below.
aa, mm = next(batched_train_iterator)
aa.shape, mm.shape

In [None]:
# Forward pass of the untrained model as a shape check: expect (batch, SIZE, SIZE, num_classes + 1).
resaa = unet_semantic(aa, training=False)
resaa.shape

## Training

In [None]:
from keras.losses import SparseCategoricalCrossentropy
from keras.metrics import Metric, SparseCategoricalAccuracy
from keras.optimizers import Adam

In [None]:
# Labels are integer class ids (sparse), and the model's output layer has no
# activation, hence from_logits=True.
sprs_metr = SparseCategoricalAccuracy()
sprs_loss = SparseCategoricalCrossentropy(from_logits=True)

In [None]:
class NotZeroAccuracy(Metric):
    """Pixel accuracy restricted to pixels whose true label is non-zero.

    Over all batches since the last reset, the metric accumulates two counts:
    the non-zero-label pixels, and how many of those are predicted correctly
    (argmax of ``y_pred`` equals the label). ``result()`` is their ratio, or 0
    if no non-zero pixel has been seen yet.
    """
    def __init__(self, name="notzero_accuracy", **kwargs):
        super(NotZeroAccuracy, self).__init__(name=name, **kwargs)
        # Correctly predicted non-zero-label pixels.
        self.not_zeros = self.add_weight(name='nz', initializer='zeros')
        # All non-zero-label pixels.
        self.total = self.add_weight(name='total', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch. ``sample_weight`` is accepted (Keras passes it by keyword) but ignored."""
        # Reshape to y_pred's leading dims: this drops a trailing singleton channel if
        # present, without ever squeezing away a size-1 batch axis.
        labels = tf.reshape(tf.cast(y_true, tf.int64), tf.shape(y_pred)[:-1])
        predictions = tf.argmax(y_pred, axis=-1)
        true_nonzero = tf.greater(labels, 0)
        where_equals = tf.logical_and(true_nonzero, tf.equal(predictions, labels))

        self.not_zeros.assign_add(tf.cast(tf.math.count_nonzero(where_equals), tf.float32))
        self.total.assign_add(tf.cast(tf.math.count_nonzero(true_nonzero), tf.float32))

    def result(self):
        # Pooled ratio. The previous version returned a running SUM of per-batch
        # ratios, which grew with the number of batches.
        return tf.math.divide_no_nan(self.not_zeros, self.total)

    def reset_state(self):
        self.not_zeros.assign(0.)
        self.total.assign(0.)


In [None]:
# NOTE(review): this stateful metric is instantiated but never passed to compile();
# the function `nonzeroaccuracy` is used there instead.
sprs_nz_metr = NotZeroAccuracy()

In [None]:
def nonzeroaccuracy(y_true, y_pred):
    """Accuracy over the pixels of one batch whose true label is non-zero.

    Args:
        y_true: integer-valued labels, either with a trailing singleton channel
            or without one (assumed (B, H, W, 1) as produced by the datasets).
        y_pred: per-class scores/logits with classes on the last axis.

    Returns:
        Scalar float32: correctly predicted non-zero pixels divided by all
        non-zero pixels. It is 0 when the batch has no non-zero labels, instead
        of NaN, which would turn Keras' epoch average of this metric into NaN.
    """
    # Match y_pred's leading dims; unlike an axis-less squeeze, this never drops a size-1 batch.
    labels = tf.reshape(tf.cast(y_true, tf.int64), tf.shape(y_pred)[:-1])
    predictions = tf.argmax(y_pred, axis=-1)
    true_nonzero = tf.greater(labels, 0)
    where_equals = tf.logical_and(true_nonzero, tf.equal(predictions, labels))

    denom = tf.cast(tf.math.count_nonzero(true_nonzero), tf.float32)
    numer = tf.cast(tf.math.count_nonzero(where_equals), tf.float32)
    return tf.math.divide_no_nan(numer, denom)

In [None]:
# Compile with Adam at lr=1e-3.
# NOTE(review): the next cell recompiles with lr=1e-4, so under Run All this setting is overridden.
unet_semantic.compile(optimizer=Adam(learning_rate=0.001), loss=sprs_loss, metrics=[sprs_metr, nonzeroaccuracy])

In [None]:
# Recompile with a 10x smaller learning rate. Weights are kept, but a new Adam instance starts fresh.
# NOTE(review): this replaces the previous cell's compile; keep only one of the two for Run All.
unet_semantic.compile(optimizer=Adam(learning_rate=0.0001), loss=sprs_loss, metrics=[sprs_metr, nonzeroaccuracy])

In [None]:
# Evaluate on 10 test batches. The underlying MNIST iterator repeats indefinitely,
# so an explicit step count bounds the evaluation.
unet_semantic.evaluate(batched_test, steps=10)

In [None]:
# 10 epochs of 120 steps each (BATCHSIZE images per step).
# NOTE(review): validation_steps=1 validates on a single batch per epoch, so the curve is noisy.
history = unet_semantic.fit(
    batched_train, steps_per_epoch=120, epochs=10, validation_data=batched_test, validation_steps=1)

In [None]:
# Save as a TensorFlow SavedModel. In Keras 2 the signature is
# save(filepath, overwrite=True, save_format=None), so the format must be passed by
# keyword; passed positionally, "tf" was being taken as `overwrite`.
unet_semantic.save(os.path.join(PATH_TO_MODELS, "conv_unet_v00"), save_format="tf")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from math import ceil

In [None]:
# One colour per output class (11 = num_classes + 1); index 0 (white) is label 0.
set_colors = ["#FFFFFF", "#FFAAAA", "#FFAA77", "#FFAA22", "#AAAA00", "#AA7700", "#FF4400", "#FF0000", "#AA0000", "#770000", "#220000"]
cmap = ListedColormap(set_colors, name="custom_cmap")

In [None]:
def display_batch(abatch, mbatch, nr, nc, NDIV=13):
    """Plot sample images next to their label maps on an ``nr`` x ``nc`` grid.

    Axes are filled row-major in pairs: axis ``2*j`` shows image ``j``
    (inverted grayscale, so ink is dark) and axis ``2*j + 1`` shows its label
    map, coloured with the module-level ``cmap``. Axes left over once the
    batch is exhausted are hidden.

    Args:
        abatch: batch of images, presumably (B, H, W, 1) with values in [0, 1]
            (the grayscale range is fixed by vmin/vmax).
        mbatch: batch of masks. If the last axis has size 1, it is taken as
            integer labels. Otherwise the argmax over the last axis is shown,
            so raw per-class logits from the model can be passed directly.
        nr: number of subplot rows.
        nc: number of subplot columns.
        NDIV: number of guide-grid cells per side (13 divides SIZE=208 exactly).
    """
    size = abatch.shape[1]
    ticks = (size // NDIV) * np.arange(0, NDIV + 1)
    n_items = min(abatch.shape[0], mbatch.shape[0])

    fig, axxes = plt.subplots(ncols=nc,
                              nrows=nr,
                              figsize=(3*nc, 3*nr),
                              sharey=True, sharex=True)

    axxes = np.ravel(axxes)

    for i, ax in enumerate(axxes):
        j = i // 2
        if j >= n_items:
            # More panels than (image, mask) pairs: hide them instead of raising IndexError.
            ax.set_axis_off()
            continue

        if i % 2 == 0:
            a = np.asarray(abatch[j])
            ax.imshow(1 - a, aspect=1, cmap='gray', vmin=0, vmax=1)
        else:
            # Convert per item (works for numpy arrays and TF tensors alike).
            mask = np.asarray(mbatch[j])
            labels = mask[..., 0] if mask.shape[-1] == 1 else np.argmax(mask, axis=-1)
            # 'nearest' keeps class colours crisp. interpolation=None would fall
            # back to the rcParams default, which may blend neighbouring labels.
            ax.imshow(labels, aspect=1, cmap=cmap, vmin=0, vmax=cmap.N - 1,
                      interpolation='nearest')

        ax.set_xticks(ticks)
        ax.set_xticklabels([])  # after set_xticks, so the empty labels pair with fixed ticks
        ax.set_yticks(ticks)
        ax.set_xlim(0, size - 1)
        # imshow draws row 0 at the top. ylim(0, size - 1) would flip every panel
        # upside down, so keep the image orientation.
        ax.set_ylim(size - 1, 0)
        ax.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)

    fig.tight_layout()
    plt.show()

In [None]:
# Re-run inference on the same sample batch, now with the trained weights.
resaa = unet_semantic(aa, training=False)
resaa.shape

In [None]:
# Ground truth: sample images next to their label masks.
display_batch(aa, mm, 3, 4)

In [None]:
# Predictions: display_batch shows the argmax over the model's per-class logits.
display_batch(aa, resaa, 3, 4)