<a href="https://colab.research.google.com/github/avva14/image_generators/blob/main/conv_solidvit_semantic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
!pip install pillow --upgrade

In [None]:
!git clone https://github.com/avva14/common_utils.git

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Google Drive locations used by the notebook (all live under the /content/gdrive mount point).
# PATH_TO_TFDS is the data_dir given to tfds.load, PATH_TO_MODELS receives the saved
# model components, and PATH_TO_MOIRE is the directory listed to build `moirefiles`.
PATH_TO_TFDS = '/content/gdrive/MyDrive/tensorflow_datasets'
PATH_TO_MODELS = '/content/gdrive/MyDrive/models/moire'
PATH_TO_MOIRE = '/content/gdrive/MyDrive/Patterns/moiredata'

In [None]:
import os
import numpy as np
import cv2 as cv

In [None]:
import zipfile
from zipfile import ZipFile

In [None]:
import matplotlib.pyplot as plt
from math import ceil

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Parameters

In [None]:
# Load the MNIST train/test splits from the TFDS directory on Drive.
# download=False: the dataset has to be present there already.
ds = tfds.load('mnist', data_dir=PATH_TO_TFDS, download=False, split=['train', 'test'], shuffle_files=True)
# Endless NumPy iterators over the raw splits; only the training split is shuffled.
train_set = ds[0].cache().shuffle(1024).repeat().as_numpy_iterator()
test_set = ds[1].cache().repeat().as_numpy_iterator()

In [None]:
# Two independently seeded random states that are handed to the sample generators.
# NOTE(review): `rng_safe` is only referenced by the first `vgen_train` definition,
# which a later cell overrides -- confirm which generator is meant to use it.
rng = np.random.RandomState(1)
rng_safe = np.random.RandomState(21)

In [None]:
# Geometry of a sample and of the patch grid laid over it.
SIZE = 208  # side of the square, single-channel input image (pixels)
MNSZ = 28  # forwarded to the generators; presumably the MNIST digit side -- TODO confirm
NDIV = 13  # number of patches along one image side
NDIV2 = NDIV*NDIV  # total number of patches = length of the per-sample label vector
PSIZ = SIZE // NDIV  # side of one patch in pixels
MAX_NOISE = 0.5  # forwarded to the generators; presumably an upper bound on added noise -- TODO confirm
num_classes = 10  # forwarded to the generators; the classifier head is built with num_classes+1 outputs

In [None]:
# Full paths of every entry in the moire directory.
# os.listdir returns names in arbitrary order, so the listing is sorted to keep the
# file list (and presumably whatever the seeded generator picks from it) stable
# between runs and machines.
moirefiles = [os.path.join(PATH_TO_MOIRE, f) for f in sorted(os.listdir(PATH_TO_MOIRE))]

In [None]:
from common_utils.vit_generators import VitSolidTrainGenerator, VitSolidTestGenerator

## TF datasets

In [None]:
def vgen_test():
    """Build a new ``VitSolidTestGenerator`` over the MNIST test iterator.

    Used as the zero-argument callable for ``tf.data.Dataset.from_generator``.
    """
    return VitSolidTestGenerator(test_set, rng, 2, num_classes, MAX_NOISE, MNSZ, SIZE, NDIV)
def vgen_train():
    """Build a ``VitSolidTestGenerator`` over the MNIST training iterator, seeded by ``rng_safe``.

    NOTE(review): the next cell defines ``vgen_train`` again (with
    ``VitSolidTrainGenerator``); when cells are run top to bottom that later
    definition replaces this one.
    """
    return VitSolidTestGenerator(train_set, rng_safe, 2, num_classes, MAX_NOISE, MNSZ, SIZE, NDIV)

In [None]:
def vgen_train():
    """Build a new ``VitSolidTrainGenerator`` over the MNIST training iterator.

    Unlike the test generator it also receives the list of moire files.
    NOTE(review): this shadows the ``vgen_train`` defined in the previous cell and
    shares ``rng`` with ``vgen_test`` -- confirm both are intended.
    """
    return VitSolidTrainGenerator(train_set, rng, 2, num_classes, moirefiles, MAX_NOISE, MNSZ, SIZE, NDIV)

In [None]:
# Both generators are declared with the same (image, labels) layout, so the
# signature is written once and shared:
#   image : (SIZE, SIZE, 1) float32
#   labels: (NDIV2,) float32, one entry per patch of the NDIV x NDIV grid
sample_signature = (
    tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=np.float32),
    # Explicit one-element tuple: a bare `(NDIV2)` is just the integer NDIV2.
    tf.TensorSpec(shape=(NDIV2,), dtype=np.float32),
)
dataset_test = tf.data.Dataset.from_generator(vgen_test, output_signature=sample_signature)
dataset_train = tf.data.Dataset.from_generator(vgen_train, output_signature=sample_signature)

In [None]:
# Number of samples per batch for both the training and the test pipeline.
BATCHSIZE = 128

In [None]:
# Batched views of the two generator-backed datasets; these are what fit/evaluate consume.
batched_test = dataset_test.batch(BATCHSIZE)
batched_train = dataset_train.batch(BATCHSIZE)

In [None]:
# NumPy iterators used below to pull individual batches for inspection and plotting.
batched_test_iterator = batched_test.as_numpy_iterator()
batched_train_iterator = batched_train.as_numpy_iterator()

In [None]:
# Pull one training batch (images, labels) and show the array shapes.
aa, mm = next(batched_train_iterator)
aa.shape, mm.shape

In [None]:
# Pull one test batch; this rebinds `aa`/`mm`, so the cells below work on test data.
aa, mm = next(batched_test_iterator)
aa.shape, mm.shape

## Visualize data

In [None]:
def display_batch(abatch, mbatch, nr, nc):
    """Plot samples side by side with their per-patch labels.

    The figure has ``nr`` rows and ``2*nc`` columns; consecutive axis pairs show
    sample ``j``: first the image (drawn as ``1 - image``) with a dot at the
    centre of every patch whose label is > 0, then a grid with ``label - 1``
    written into those patches.

    Args:
        abatch: indexable batch of images.
        mbatch: indexable batch of labels; a 1-D entry is cast to int32, any
            other entry is reduced with ``argmax`` over its last axis.
        nr: number of subplot rows.
        nc: number of samples per row.
    """
    fig, axes_grid = plt.subplots(nrows=nr,
                                  ncols=2*nc,
                                  figsize=(6*nc, 3*nr),
                                  sharex=False, sharey=False)
    flat_axes = np.ravel(axes_grid)
    cell_ticks = np.arange(0, NDIV+1)

    def draw_grid(axis, ticks, upper):
        # Same tick positions and limits on both axes plus the green patch grid.
        axis.set_yticks(ticks)
        axis.set_xticks(ticks)
        axis.set_ylim(0, upper)
        axis.set_xlim(0, upper)
        axis.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)

    # Even-indexed axes carry the image, the following odd-indexed axis its labels.
    for j, (image_ax, label_ax) in enumerate(zip(flat_axes[0::2], flat_axes[1::2])):
        image = abatch[j]
        if len(mbatch[j].shape) == 1:
            labels = mbatch[j].astype(np.int32)
        else:
            labels = np.argmax(mbatch[j], axis=-1)
        labelled = np.where(labels > 0)[0]

        image_ax.imshow(1-image, aspect=1, cmap='gray', vmin=0, vmax=1)
        image_ax.scatter(PSIZ*(labelled % NDIV) + PSIZ//2, PSIZ*(labelled // NDIV) + PSIZ//2, s=2)
        draw_grid(image_ax, PSIZ*cell_ticks, SIZE-1)

        for p in labelled:
            label_ax.text((p%NDIV)+0.3,(p//NDIV)+0.3,f'{labels[p]-1}')
        draw_grid(label_ax, cell_ticks, NDIV)

    fig.tight_layout()
    plt.show()

In [None]:
def displayres_batch(abatch, mbatch, resbatch, nr, nc):
    """Plot samples with their true and predicted per-patch labels.

    The figure has ``nr`` rows and ``3*nc`` columns; consecutive axis triples
    show sample ``j``: the image (drawn as ``1 - image``) with a dot on every
    patch whose true label is > 0, a grid with ``true - 1`` in those patches,
    and a grid with ``predicted - 1`` in every patch whose prediction is > 0
    (black when it equals the true label, red otherwise).

    Args:
        abatch: indexable batch of images.
        mbatch: indexable batch of label vectors (cast to int32).
        resbatch: indexable batch of per-patch scores; the prediction is the
            ``argmax`` over the last axis.
        nr: number of subplot rows.
        nc: number of samples per row.
    """
    fig, axes_grid = plt.subplots(nrows=nr,
                                  ncols=3*nc,
                                  figsize=(9*nc, 3*nr),
                                  sharex=False, sharey=False)
    flat_axes = np.ravel(axes_grid)
    cell_ticks = np.arange(0, NDIV+1)

    def draw_grid(axis, ticks, upper):
        # Same tick positions and limits on both axes plus the green patch grid.
        axis.set_yticks(ticks)
        axis.set_xticks(ticks)
        axis.set_ylim(0, upper)
        axis.set_xlim(0, upper)
        axis.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)

    axis_triples = zip(flat_axes[0::3], flat_axes[1::3], flat_axes[2::3])
    for j, (image_ax, truth_ax, pred_ax) in enumerate(axis_triples):
        image = abatch[j]
        truth = mbatch[j].astype(np.int32)
        pred = np.argmax(resbatch[j], axis=-1)
        truth_cells = np.where(truth > 0)[0]
        pred_cells = np.where(pred > 0)[0]

        # Panel 1: the sample itself with markers on the labelled patches.
        image_ax.imshow(1-image, aspect=1, cmap='gray', vmin=0, vmax=1)
        image_ax.scatter(PSIZ*(truth_cells % NDIV) + PSIZ//2, PSIZ*(truth_cells // NDIV) + PSIZ//2, s=2)
        draw_grid(image_ax, PSIZ*cell_ticks, SIZE-1)

        # Panel 2: ground-truth labels.
        for p in truth_cells:
            truth_ax.text((p%NDIV)+0.3,(p//NDIV)+0.3,f'{truth[p]-1}')
        draw_grid(truth_ax, cell_ticks, NDIV)

        # Panel 3: predictions, colour-coded by agreement with the ground truth.
        for p in pred_cells:
            col = 'black' if pred[p]==truth[p] else 'red'
            pred_ax.text((p%NDIV)+0.3,(p//NDIV)+0.3,f'{pred[p]-1}',c=col)
        draw_grid(pred_ax, cell_ticks, NDIV)

    fig.tight_layout()
    plt.show()

In [None]:
display_batch(aa, mm, 3, 2)

## Attention mask

In [None]:
from common_utils.picutils import intersection, intersectionp

In [None]:
def nnneighbors(p1, p2):
    """Tell whether patch ``p2`` lies in the 3x3 neighbourhood of patch ``p1``.

    Patch indices are flat positions on the NDIV x NDIV grid
    (column = index % NDIV, row = index // NDIV).

    NOTE(review): relies on ``intersection`` / ``intersectionp`` from
    ``common_utils.picutils``; presumably they test interval overlap, the
    latter periodically with period NDIV -- confirm against that module.
    """
    row1, col1 = divmod(p1, NDIV)
    row2, col2 = divmod(p2, NDIV)
    near_in_x = intersection(col1-1, col1+1, col2, col2)
    near_in_y = intersectionp(row1-1, row1+1, row2, row2, NDIV)
    return near_in_x and near_in_y

In [None]:
# Boolean (NDIV2, NDIV2) matrix: entry [a, b] is nnneighbors(a, b) for flat patch indices a and b.
atten_mask = np.array(
    [[nnneighbors(a, b) for b in range(NDIV2)] for a in range(NDIV2)]
).astype(bool)

In [None]:
PSIZ

## Visualize mask

In [None]:
# One panel per labelled patch of the first sample in the current batch: the
# image with a marker on every patch that the attention mask enables for it.
ixtoshow = np.where(mm[0]>0)[0]
nc = 2
nr = ceil(len(ixtoshow)/nc)

fig, axxes = plt.subplots(nrows=nr,
                          ncols=nc,
                          figsize=(2*nc, 2*nr),
                          sharex=True, sharey=True)
axxes = np.ravel(axxes)
grid_ticks = PSIZ*np.arange(0, NDIV+1)

# Panels left over after the last labelled patch stay blank.
for x in axxes[len(ixtoshow):]:
    x.axis('off')

for p, x in zip(ixtoshow, axxes):
    allowed = np.where(atten_mask[p])[0]
    x.imshow(1-aa[0], aspect=1, cmap='gray')
    x.scatter((allowed % NDIV)*PSIZ+PSIZ//2,
              (allowed // NDIV)*PSIZ+PSIZ//2)
    x.set_ylim(0,SIZE)
    x.set_xlim(0,SIZE)
    x.set_yticks(grid_ticks)
    x.set_xticks(grid_ticks)
    x.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)

fig.tight_layout()
plt.show()

## Model

In [None]:
from keras.models import Model, load_model
from keras.layers import Input

In [None]:
from keras.layers import Conv2D, Lambda, Cropping2D, Concatenate, ReLU, MaxPooling2D

In [None]:
from keras.layers import Layer, Embedding
from keras.layers import LayerNormalization, MultiHeadAttention, Add, Flatten, Dropout, Dense
from keras.layers import Reshape, Permute

In [None]:
# Model hyper-parameters.
projection_dim = 64  # width of each patch embedding; also used as key_dim of the attention heads
num_heads = 4  # attention heads per MultiHeadAttention layer
# Dense widths inside each SkippedMultiLayer; the list length also sets how many
# attention/MLP pairs SelfAttention builds.
transformer_units = [projection_dim * 2, projection_dim]
eps = 1e-6  # epsilon of every LayerNormalization
DROP = 0.1  # dropout rate handed to SelfAttention

In [None]:
class Conv2DPeriodic(Layer):
    """2-D convolution with wrap-around (periodic) padding along axis 1.

    The last ``margin`` slices of axis 1 are prepended and the first ``margin``
    slices appended before a ``padding='same'`` convolution; the same number of
    slices is then cropped from both ends of axis 1, so the output keeps the
    input's extent. Axis 2 only gets the convolution's own 'same' padding.

    Args:
        filters: number of convolution filters.
        kernel_size: (height, width) of the kernel; ``margin`` is
            ``(kernel_size[0] - 1) // 2``.
    """
    def __init__(self, filters, kernel_size):
        super(Conv2DPeriodic, self).__init__()
        margin = (kernel_size[0] - 1) // 2
        # Kept so that call() can skip the wrap-around when there is nothing to wrap.
        self.margin = margin
        self.la1 = Lambda(lambda x:x[:,-margin:,:,:])
        self.la2 = Lambda(lambda x:x[:,:margin,:,:])
        self.conv = Conv2D(filters=filters, kernel_size=kernel_size, padding='same')
        self.merge = Concatenate(axis=1)
        self.crop = Cropping2D((margin,0))

    def call(self, x):
        if self.margin == 0:
            # With margin 0 the slice `x[:, -0:]` selects the whole tensor, which
            # would double axis 1; a kernel this small needs no periodic padding.
            return self.conv(x)
        xt = self.la1(x)
        xb = self.la2(x)
        xe = self.merge([xt,x,xb])
        x = self.conv(xe)
        x = self.crop(x)
        return x

In [None]:
class ContractingBlock(Layer):
    """Two periodic convolutions (each followed by ReLU), 2x2 max-pooling, dropout.

    Args:
        num_channles: number of filters of both convolutions.
        kernel_size: kernel size of both convolutions (default ``(5,5)``, the
            previously hard-coded value).
        droprate: dropout rate applied after pooling (default ``0.1``, the
            previously hard-coded value).
    """
    def __init__(self, num_channles, kernel_size=(5,5), droprate=0.1):
        super(ContractingBlock, self).__init__()
        self.con1 = Conv2DPeriodic(num_channles, kernel_size=kernel_size)
        self.con2 = Conv2DPeriodic(num_channles, kernel_size=kernel_size)
        self.relu = ReLU()
        self.pool = MaxPooling2D((2,2))
        self.drop = Dropout(droprate)
    def call(self, x):
        x = self.con1(x)
        x = self.relu(x)
        x = self.con2(x)
        x = self.relu(x)
        # Halves both spatial axes.
        x = self.pool(x)
        x = self.drop(x)
        return x

In [None]:
class ConvPatcherEmbedder(Model):
    """Convolutional patch encoder that adds a learned position embedding.

    The image passes through four ``ContractingBlock``s (widths
    ``projectiondim//2``, ``projectiondim//2``, ``projectiondim``,
    ``projectiondim``), is reshaped to ``(-1, projectiondim)`` tokens and summed
    with an embedding of the positions ``0 .. patchnum*patchnum - 1``.

    Args:
        imgsize: side of the input image.
        projectiondim: width of each output token.
        patchnum: number of patches per image side.
    """
    def __init__(self, imgsize, projectiondim, patchnum):
        super().__init__()

        self.dim = projectiondim
        self.patchnum2 = patchnum*patchnum
        self.patchsize = imgsize // patchnum

        half_dim = projectiondim//2
        self.conv1 = ContractingBlock(half_dim)
        self.conv2 = ContractingBlock(half_dim)
        self.conv3 = ContractingBlock(projectiondim)
        self.conv4 = ContractingBlock(projectiondim)
        self.reshape = Reshape((-1,projectiondim))
        self.positions = tf.range(start=0, limit=self.patchnum2, delta=1)

        self.position_embedding = Embedding(input_dim=self.patchnum2, output_dim=projectiondim)

    def call(self, inputs):
        tokens = inputs
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            tokens = block(tokens)
        tokens = self.reshape(tokens)
        # The sum only works when the number of tokens equals patchnum*patchnum.
        return tokens + self.position_embedding(self.positions)

    def get_config(self):
        # NOTE(review): these keys are not the constructor's parameter names, so
        # the dict cannot be fed back into __init__ as-is -- confirm intent.
        return {"project_dim": self.dim, "patch_size": self.patchsize}

In [None]:
# Patch encoder for SIZE x SIZE images producing NDIV*NDIV tokens of width projection_dim.
pe = ConvPatcherEmbedder(SIZE, projection_dim, NDIV)

In [None]:
respe = pe(aa, training=False)
respe.shape

In [None]:
pe.summary()

In [None]:
class SkippedMultiHeadAttention(Layer):
    """Layer-normalised multi-head self-attention with an additive skip connection.

    Args:
        numheads: number of attention heads.
        projectiondim: ``key_dim`` of the attention layer.
        droprate: dropout rate of the attention layer.
        mask: optional attention mask; converted to a boolean tensor when given.

    ``call`` returns ``(output, attention_scores)``.
    """
    def __init__(self, numheads, projectiondim, droprate, mask=None):
        super().__init__()
        self.ln = LayerNormalization(epsilon=eps)
        self.add = Add()
        self.mha = MultiHeadAttention(num_heads=numheads, key_dim=projectiondim, dropout=droprate)
        self.mask = None if mask is None else tf.cast(tf.convert_to_tensor(mask), tf.bool)

    def call(self, x):
        normed = self.ln(x)
        attended, scores = self.mha(normed, normed, attention_mask=self.mask, return_attention_scores=True)
        # NOTE(review): the skip adds the *normalised* input, whereas
        # SkippedMultiLayer adds its raw input -- confirm this is intended.
        return self.add([attended, normed]), scores

class SkippedMultiLayer(Layer):
    """Layer-normalised GELU MLP with an additive skip connection to its input.

    Args:
        transformerunits: widths of the successive Dense layers.
        dropoutrate: dropout rate applied after every Dense layer.
    """
    def __init__(self, transformerunits, dropoutrate):
        super().__init__()
        self.ln = LayerNormalization(epsilon=eps)
        self.add = Add()
        self.drop = Dropout(dropoutrate)
        self.denses = [Dense(width, activation=tf.nn.gelu) for width in transformerunits]

    def call(self, x2):
        hidden = self.ln(x2)
        for dense in self.denses:
            hidden = self.drop(dense(hidden))
        # Residual connection back to the un-normalised input.
        return self.add([hidden, x2])

class OutputMultiLayer(Layer):
    """Layer normalisation, flattening and a GELU MLP with dropout.

    Args:
        transformerunits: widths of the successive Dense layers.
        dropoutrate: dropout rate applied after every Dense layer.
    """
    def __init__(self, transformerunits, dropoutrate):
        super().__init__()
        self.ln = LayerNormalization(epsilon=eps)
        self.flat = Flatten()
        self.drop = Dropout(dropoutrate)
        self.denses = [Dense(width, activation=tf.nn.gelu) for width in transformerunits]

    def call(self, x):
        hidden = self.flat(self.ln(x))
        for dense in self.denses:
            hidden = self.drop(dense(hidden))
        return hidden

class SelfAttention(Model):
    """Stack of masked attention blocks, each followed by a skipped MLP.

    Args:
        numheads: attention heads per block.
        projectiondim: ``key_dim`` of every attention block.
        attenmask: attention mask shared by all blocks.
        transformerunits: Dense widths of every MLP; one attention/MLP pair is
            created per entry of this sequence.
        drop: dropout rate for attention and MLP layers.

    ``call`` returns the transformed tokens and the attention scores of the
    last block.
    """
    def __init__(self, numheads, projectiondim, attenmask, transformerunits, drop):
        super().__init__()
        self.dim = projectiondim
        self.heads = numheads
        self.units = transformerunits
        self.drop = drop
        self.mha = []
        self.ml = []
        # Only the length of `transformerunits` matters here; each pair is
        # created attention-first so the layers are built in the original order.
        for _ in transformerunits:
            self.mha.append(SkippedMultiHeadAttention(numheads, projectiondim, drop, attenmask))
            self.ml.append(SkippedMultiLayer(transformerunits, drop))

    def call(self, x):
        for attention, mlp in zip(self.mha, self.ml):
            x, s = attention(x)
            x = mlp(x)
        return x, s

    def get_config(self):
        # NOTE(review): keys differ from the constructor's parameter names and the
        # attention mask is not included -- confirm before relying on from_config.
        config = {}
        config["project_dim"] = self.dim
        config["num_heads"] = self.heads
        config["transformer_units"] = self.units
        config["drop"] = self.drop
        return config

In [None]:
# Attention stack; every block receives the nearest-neighbour mask `atten_mask`.
sa = SelfAttention(num_heads, projection_dim, atten_mask, transformer_units, DROP)

In [None]:
ressa, resscore = sa(respe, training=False)
ressa.shape, resscore.shape

In [None]:
sa.summary()

In [None]:
class ToSemanticFeature(Model):
    """Classification head: a ReLU Dense layer followed by a linear Dense layer.

    Args:
        hidden: width of the hidden Dense layer.
        numclass: number of output units (no activation is applied to them).
    """
    def __init__(self, hidden, numclass):
        super().__init__()
        self.dens1 = Dense(hidden, activation='relu')
        self.dens2 = Dense(numclass)

    def call(self, x):
        return self.dens2(self.dens1(x))

In [None]:
# Per-token head with num_classes+1 outputs; presumably the extra output is the
# "no digit" label 0 that the plotting helpers skip -- TODO confirm.
sf = ToSemanticFeature(2*projection_dim, num_classes+1)

In [None]:
ressf = sf(ressa, training=False)
ressf.shape

In [None]:
sf.summary()

## Load model

In [None]:
from keras.models import Model, load_model
from keras.layers import Input

In [None]:
# pe = load_model(os.path.join(PATH_TO_MODELS, "pe_semantic_v05"))
# sa = load_model(os.path.join(PATH_TO_MODELS, "sa_semantic_v05"))
# sf = load_model(os.path.join(PATH_TO_MODELS, "sf_semantic_v05"))

## Assemble model

In [None]:
def create_vit_semantic(encod, selfattention, finall):
    """Chain encoder, attention stack and head into one functional model.

    Args:
        encod: maps a (SIZE, SIZE, 1) image to patch tokens.
        selfattention: maps tokens to ``(features, attention_scores)``.
        finall: maps features to per-token logits.

    Returns:
        A Keras ``Model`` named ``'vit_detect'`` with outputs ``[logits, scores]``.
    """
    image_in = Input(shape=(SIZE, SIZE, 1))
    features, scores = selfattention(encod(image_in))
    return Model(inputs=image_in, outputs=[finall(features), scores], name='vit_detect')

In [None]:
# Assemble the full model from the three components built above and print its layout.
vit_model = create_vit_semantic(pe, sa, sf)
vit_model.summary()

In [None]:
resaa, resscore = vit_model(aa, training=False)
resaa.shape

## Training

In [None]:
from keras.losses import SparseCategoricalCrossentropy
from keras.metrics import SparseCategoricalAccuracy
from keras.optimizers import Adam

In [None]:
# NOTE(review): `cat_metr` is not referenced by any later cell shown here; the
# training wrapper is compiled with `nonzeroaccuracy` instead.
cat_metr = SparseCategoricalAccuracy()
# The head emits raw logits, hence from_logits=True.
cat_loss = SparseCategoricalCrossentropy(from_logits=True)

In [None]:
cat_loss(mm, resaa)

In [None]:
def nonzeroaccuracy(y_true, y_pred):
    """Accuracy restricted to positions whose true label is greater than zero.

    Args:
        y_true: integer-valued labels (any numeric dtype; size-1 axes are squeezed).
        y_pred: scores whose last axis is reduced with ``argmax``.

    Returns:
        Scalar float64 tensor: correctly predicted positions with a true label
        > 0 divided by the number of such positions, or 0 when there are none
        (the plain division used before produced NaN in that case).
    """
    true_labels = tf.cast(tf.squeeze(y_true), tf.int64)
    pred_labels = tf.argmax(y_pred, axis=-1)
    is_labelled = tf.greater(true_labels, 0)
    is_hit = tf.logical_and(is_labelled, tf.equal(pred_labels, true_labels))

    # Counted directly as float64 so the ratio keeps the dtype of the former
    # integer true-division.
    n_labelled = tf.math.count_nonzero(is_labelled, dtype=tf.float64)
    n_hit = tf.math.count_nonzero(is_hit, dtype=tf.float64)

    return tf.math.divide_no_nan(n_hit, n_labelled)

In [None]:
class ModelTrain(Model):
    """Training wrapper that runs custom train/test steps around ``mdl``.

    The wrapped model must return a pair ``(logits, scores)``; only the logits
    enter the loss and the metric.

    Args:
        mdl: the model to train.
    """
    def __init__(self, mdl, **kwargs):
        super().__init__(**kwargs)
        self.model = mdl

    def compile(self, optimizer, losspos, metricpos, **kwargs):
        """Store the optimizer, the loss and the metric used by the custom steps.

        Args:
            optimizer: optimizer whose ``apply_gradients`` updates the wrapped model.
            losspos: callable ``(y_true, logits) -> loss``.
            metricpos: callable ``(y_true, logits) -> metric value``.
        """
        super().compile(**kwargs)
        self.opt = optimizer
        self.loss_posit = losspos
        self.metr_posit = metricpos

    def train_step(self, batch, **kwargs):
        """Run one optimisation step and return ``{"loss", "acc"}`` for the batch."""
        X, y = batch
        with tf.GradientTape() as tape:
            # training=True must be passed explicitly: inside a custom train_step
            # nothing else switches Dropout (and similar layers) to training mode.
            bx, sc = self.model(X, training=True)
            loss = self.loss_posit(y, bx)
        # Taken outside the tape context so the gradient computation itself is
        # not recorded.
        grad = tape.gradient(loss, self.model.trainable_variables)
        self.opt.apply_gradients(zip(grad, self.model.trainable_variables))
        acc = self.metr_posit(y, bx)
        return {"loss":loss, "acc":acc}

    def test_step(self, batch, **kwargs):
        """Evaluate one batch in inference mode and return ``{"loss", "acc"}``."""
        X, y = batch
        bx, sc = self.model(X, training=False)
        loss = self.loss_posit(y, bx)
        acc = self.metr_posit(y, bx)
        return {"loss":loss, "acc":acc}

    def call(self, X, **kwargs):
        """Forward to the wrapped model."""
        return self.model(X, **kwargs)

In [None]:
vit_train = ModelTrain(vit_model)

In [None]:
# Positional arguments are (optimizer, loss, metric) as defined by ModelTrain.compile.
vit_train.compile(Adam(learning_rate=0.0001), cat_loss, nonzeroaccuracy)

In [None]:
vit_train.evaluate(batched_test, steps=1)

In [None]:
# Both datasets repeat indefinitely, so the epoch length and the size of the
# validation pass are fixed explicitly (100 training batches, 1 validation batch).
history = vit_train.fit(batched_train, steps_per_epoch=100, epochs=20, validation_data=batched_test, validation_steps=1)

In [None]:
# Save the three components separately in TensorFlow SavedModel format.
# The format is passed by keyword: the second positional parameter of
# Model.save is `overwrite`, so a positional "tf" never reached `save_format`.
# NOTE(review): these names ("*_conv_semantic_v05") differ from the ones in the
# commented-out load cell ("*_semantic_v05") -- confirm which set is current.
pe.save(os.path.join(PATH_TO_MODELS, "pe_conv_semantic_v05"), save_format="tf")
sa.save(os.path.join(PATH_TO_MODELS, "sa_conv_semantic_v05"), save_format="tf")
sf.save(os.path.join(PATH_TO_MODELS, "sf_conv_semantic_v05"), save_format="tf")

In [None]:
resaa, resscore = vit_model(aa, training=False)
resaa.shape

In [None]:
displayres_batch(aa, mm, resaa, 3, 2)