<a href="https://colab.research.google.com/github/avva14/image_generators/blob/main/check_vit_generators.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
!pip install pillow --upgrade

In [None]:
!git clone https://github.com/avva14/common_utils.git

In [None]:
# Mount Google Drive under /content/gdrive; the dataset and moire paths configured below live there.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Directory passed to tfds.load as data_dir (the dataset is loaded with download=False).
PATH_TO_TFDS = '/content/gdrive/MyDrive/tensorflow_datasets'
# Directory whose entries are listed to build the moire file list for the train generator.
PATH_TO_MOIRE = '/content/gdrive/MyDrive/Patterns/moiredata'

In [None]:
import os
import numpy as np
import cv2 as cv

In [None]:
import zipfile
from zipfile import ZipFile

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from math import ceil

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Display utils

In [None]:
def _draw_sample_pair(ax_img, ax_grid, a, m):
    """Draw one sample on a pair of axes: image on the left, cell labels on the right.

    Relies on the notebook-level constants SIZE, NDIV and PSIZ.

    Args:
        ax_img: axes receiving the image, a pixel-space cell grid and one
            marker at the centre of every labelled cell.
        ax_grid: axes receiving the same cells in grid units, each labelled
            cell annotated with its class index.
        a: image array, drawn inverted (``1 - a``) on a gray colormap.
            Assumed to span SIZE x SIZE pixels -- TODO confirm against the generators.
        m: per-cell label array with one row per grid cell in row-major order
            (cell ``p`` sits at column ``p % NDIV``, row ``p // NDIV``).
            Only columns ``1:`` are read; column 0 is ignored here
            (presumably a background class -- TODO confirm).
    """
    scores = m[:, 1:]
    # 0 for cells whose class scores are all zero, otherwise 1 + index of the best class.
    k = np.any(scores, axis=-1) * (np.argmax(scores, axis=-1) + 1)
    # Best score per cell; reused below as the opacity of the text label.
    f = np.max(scores, axis=-1)

    ixes = np.where(k > 0)[0]
    cols = ixes % NDIV
    rows = ixes // NDIV

    # Left panel: image in pixel coordinates with the cell grid on top.
    ax_img.imshow(1-a, aspect=1, cmap='gray', vmin=0, vmax=1)
    ax_img.scatter(PSIZ*cols + PSIZ//2, PSIZ*rows + PSIZ//2, s=2)
    ax_img.set_yticks(PSIZ*np.arange(0, NDIV+1))
    ax_img.set_xticks(PSIZ*np.arange(0, NDIV+1))
    # NOTE(review): ascending y-limits after imshow draw image row 0 at the
    # bottom; kept unchanged -- confirm this orientation is intended.
    ax_img.set_ylim(0,SIZE-1)
    ax_img.set_xlim(0,SIZE-1)
    ax_img.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)

    # Right panel: same cells in grid units. Markers go to the cell centres;
    # pixel coordinates here would fall outside the 0..NDIV limits.
    ax_grid.scatter(cols + 0.5, rows + 0.5, s=2)
    for p, col, row in zip(ixes, cols, rows):
        ax_grid.text(col+0.3, row+0.3, f'{k[p]-1}', alpha=f[p])

    ax_grid.set_yticks(np.arange(0, NDIV+1))
    ax_grid.set_xticks(np.arange(0, NDIV+1))
    ax_grid.set_ylim(0,NDIV)
    ax_grid.set_xlim(0,NDIV)
    ax_grid.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)


def display_generated(generator, nr, nc):
    """Plot ``nr * nc`` samples pulled from ``generator``.

    Each sample takes two adjacent panels (image, label grid), so the figure
    has ``nr`` rows and ``2 * nc`` columns of axes.

    Args:
        generator: iterator yielding ``(image, labels)`` pairs; it is advanced
            exactly ``nr * nc`` times.
        nr: number of figure rows.
        nc: number of samples per row.
    """
    fig, axxes = plt.subplots(ncols=2*nc,
                              nrows=nr,
                              figsize=(6*nc, 3*nr),
                              sharey=False, sharex=False)

    axxes = np.ravel(axxes)

    # Consecutive axes form (image, label-grid) pairs.
    for ax_img, ax_grid in zip(axxes[0::2], axxes[1::2]):
        a, m = next(generator)
        _draw_sample_pair(ax_img, ax_grid, a, m)

    fig.tight_layout()
    plt.show()

def display_batch(abatch, mbatch, nr, nc):
    """Plot the first ``nr * nc`` samples of an already materialised batch.

    Same layout as ``display_generated``. When the batch holds fewer than
    ``nr * nc`` samples the remaining panels are hidden instead of raising
    ``IndexError``.

    Args:
        abatch: sequence/array of images, indexed along the first axis.
        mbatch: sequence/array of label arrays aligned with ``abatch``.
        nr: number of figure rows.
        nc: number of samples per row.
    """
    fig, axxes = plt.subplots(ncols=2*nc,
                              nrows=nr,
                              figsize=(6*nc, 3*nr),
                              sharey=False, sharex=False)

    axxes = np.ravel(axxes)
    available = min(len(abatch), len(mbatch))

    for j, (ax_img, ax_grid) in enumerate(zip(axxes[0::2], axxes[1::2])):
        if j >= available:
            # Nothing left to draw: blank the unused pair of panels.
            ax_img.set_axis_off()
            ax_grid.set_axis_off()
            continue
        _draw_sample_pair(ax_img, ax_grid, abatch[j], mbatch[j])

    fig.tight_layout()
    plt.show()

## Parameters

In [None]:
# MNIST train/test splits read from the Drive-backed TFDS directory (never downloaded here).
ds = tfds.load(
    'mnist',
    data_dir=PATH_TO_TFDS,
    download=False,
    split=['train', 'test'],
    shuffle_files=True,
)

# Endless numpy iterators over both splits; only the train split gets a shuffle buffer.
train_set = (
    ds[0]
    .cache()
    .shuffle(1024)
    .repeat()
    .as_numpy_iterator()
)
test_set = (
    ds[1]
    .cache()
    .repeat()
    .as_numpy_iterator()
)

In [None]:
# Seeded random state handed to every generator created below.
# NOTE(review): the same instance is shared by the test and train generators, so presumably
# the samples one of them draws depend on how much the other has consumed -- confirm.
rng = np.random.RandomState(1)

In [None]:
# Side length of the generated image; the datasets below declare samples of shape (SIZE, SIZE, 1).
SIZE = 416
# Passed to the generators; presumably the side length of an MNIST digit -- confirm.
MNSZ = 28
# Number of grid cells along each side of the image.
NDIV = 16
# Total number of grid cells; first dimension of the label array in the dataset signatures.
NDIV2 = NDIV*NDIV
# Side length of one grid cell in pixels (416 // 16 = 26).
PSIZ = SIZE // NDIV
# Passed to the generators; its exact semantics are defined in common_utils.vit_generators.
MAX_NOISE = 0.5
# Number of classes; the label array has num_classes+1 columns.
num_classes = 10

In [None]:
# os.listdir returns entries in arbitrary order; sort them so that the list handed to the
# seeded train generator is identical from run to run.
moirefiles = [os.path.join(PATH_TO_MOIRE, f) for f in sorted(os.listdir(PATH_TO_MOIRE))]

In [None]:
from common_utils.vit_generators import VitTrainGenerator, VitTestGenerator

## Test generator

In [None]:
# Sample generator fed by the endless MNIST test iterator.
# NOTE(review): the meaning of the positional argument 2 is defined in common_utils.vit_generators -- confirm there.
testd = VitTestGenerator(test_set, rng, 2, num_classes, MAX_NOISE, MNSZ, SIZE, NDIV)

In [None]:
%%time
# Time 100 draws from the test generator, then show the shapes of the last (image, labels) pair.
for _ in range(100):
    a, m = next(testd)
a.shape, m.shape

In [None]:
# Draw 2 rows x 4 samples; this pulls 8 more samples from the test generator.
display_generated(testd, 2, 4)

## Train generator

In [None]:
# Sample generator fed by the endless MNIST train iterator; unlike the test generator it also receives the moire file list.
# NOTE(review): the meaning of the positional argument 3 is defined in common_utils.vit_generators -- confirm there.
traind = VitTrainGenerator(train_set, rng, 3, num_classes, moirefiles, MAX_NOISE, MNSZ, SIZE, NDIV)

In [None]:
%%time
# Time 100 draws from the train generator, then show the shapes of the last (image, labels) pair.
for _ in range(100):
    a, m = next(traind)
a.shape, m.shape

In [None]:
# Draw 2 rows x 4 samples; this pulls 8 more samples from the train generator.
display_generated(traind, 2, 4)

## TF datasets

In [None]:
def ugen_test():
    """Build a fresh VitTestGenerator; used as the zero-argument factory for tf.data.Dataset.from_generator.

    The new generator reuses the notebook-level ``test_set`` iterator and ``rng``.
    """
    return VitTestGenerator(test_set, rng, 2, num_classes, MAX_NOISE, MNSZ, SIZE, NDIV)
def ugen_train():
    """Build a fresh VitTrainGenerator; used as the zero-argument factory for tf.data.Dataset.from_generator.

    The new generator reuses the notebook-level ``train_set`` iterator, ``rng`` and ``moirefiles``.
    """
    return VitTrainGenerator(train_set, rng, 3, num_classes, moirefiles, MAX_NOISE, MNSZ, SIZE, NDIV)

In [None]:
# Wrap the generator factories as tf.data datasets. Both declare the same element layout:
# a (SIZE, SIZE, 1) float32 image and a (NDIV2, num_classes+1) float32 label array.
dataset_test = tf.data.Dataset.from_generator(
    ugen_test,
    output_signature=(
        tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(NDIV2, num_classes + 1), dtype=tf.float32),
    ),
)

dataset_train = tf.data.Dataset.from_generator(
    ugen_train,
    output_signature=(
        tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(NDIV2, num_classes + 1), dtype=tf.float32),
    ),
)

In [None]:
# Number of samples per batch for both datasets.
BATCHSIZE = 64

In [None]:
# Group consecutive samples of each dataset into batches of BATCHSIZE.
batched_test = dataset_test.batch(BATCHSIZE)
batched_train = dataset_train.batch(BATCHSIZE)

In [None]:
# Iterators that yield the batches as numpy arrays.
batched_test_iterator = batched_test.as_numpy_iterator()
batched_train_iterator = batched_train.as_numpy_iterator()

In [None]:
# Pull one test batch with the builtin next(), the same way the raw generators are consumed
# earlier in the notebook, then show the batch shapes.
aa, mm = next(batched_test_iterator)
aa.shape, mm.shape

In [None]:
# Show the first 2 x 4 = 8 samples of the test batch.
display_batch(aa, mm, 2, 4)

In [None]:
# Pull one train batch with the builtin next(), the same way the raw generators are consumed
# earlier in the notebook, then show the batch shapes.
aa, mm = next(batched_train_iterator)
aa.shape, mm.shape

In [None]:
# Show the first 2 x 4 = 8 samples of the train batch.
display_batch(aa, mm, 2, 4)