<a href="https://colab.research.google.com/github/avva14/image_generators/blob/main/check_generators.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
# %pip installs into the environment of the running kernel; !pip runs in a subshell
# and can end up targeting a different Python.
%pip install pillow --upgrade

In [None]:
# Fetch the helper repo; `common_utils.single_generators` is imported from it further down.
!git clone https://github.com/avva14/common_utils.git

In [None]:
from google.colab import drive
# Mount Google Drive; the dataset and pattern paths used below live under /content/gdrive.
drive.mount('/content/gdrive')

In [None]:
# Passed to tfds.load as data_dir; MNIST is loaded with download=False, so it must already be there.
PATH_TO_TFDS = '/content/gdrive/MyDrive/tensorflow_datasets'
# Directory whose entries are listed to build the moire file list given to the train generator.
PATH_TO_MOIRE = '/content/gdrive/MyDrive/Patterns/moiredata'

In [None]:
import os
import numpy as np
import cv2 as cv

In [None]:
import zipfile
from zipfile import ZipFile

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from math import ceil

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Display utils

In [None]:
def rectangperiodic(bx, edgecolor='r'):
    '''
    bx -- a single box [x, y, w, h]
    edgecolor -- matplotlib colour used for the rectangle outline
    Returns a list of Rectangle patches for the box under periodic boundary
    conditions along y: the y coordinate is wrapped modulo SIZE, and a box that
    extends past y = SIZE is split into two pieces.
    '''
    x, w, h = bx[0], bx[2], bx[3]
    ypos = bx[1] % SIZE
    if ypos + h <= SIZE:
        # The box fits inside the image. A box that ends exactly at SIZE also
        # belongs here; splitting it would add a zero-height rectangle at y = 0.
        return [Rectangle((x, ypos), w, h, edgecolor=edgecolor, fill=False)]
    # The box crosses the boundary: one piece up to SIZE, the remainder from y = 0.
    return [
        Rectangle((x, ypos), w, SIZE-ypos, edgecolor=edgecolor, fill=False),
        Rectangle((x, 0), w, ypos+h-SIZE, edgecolor=edgecolor, fill=False),
    ]
def display_generated(generator, nr, nc, NDIV=13):
    '''
    Plot an nr x nc grid of samples pulled one by one from `generator`.

    generator -- iterator yielding (image, (label, boxes))
    nr, nc -- number of subplot rows and columns
    NDIV -- number of grid divisions drawn across the SIZE-wide image

    Each image is shown inverted in grayscale, every box is outlined through
    rectangperiodic, and the squeezed integer label becomes the subplot title.
    '''
    step = SIZE // NDIV
    ticks = step*np.arange(0, NDIV+1)

    fig, grid = plt.subplots(ncols=nc,
                             nrows=nr,
                             figsize=(3*nc, 3*nr),
                             sharey=True, sharex=True)

    for ax in np.ravel(grid):
        image, (label, boxes) = next(generator)
        title = np.squeeze(label).astype(np.int32)

        ax.imshow(1-image, aspect=1, cmap='gray', vmin=0, vmax=1)

        for box in boxes:
            for patch in rectangperiodic(box):
                ax.add_patch(patch)

        # Grid lines every `step` pixels; tick labels on x are hidden.
        ax.set_yticks(ticks)
        ax.set_xticklabels([])
        ax.set_xticks(ticks)
        ax.set_ylim(0, SIZE-1)
        ax.set_xlim(0, SIZE-1)
        ax.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)
        ax.set_title(title)

    fig.tight_layout()
    plt.show()
def display_batch(abatch, lbbatch, nr, nc, NDIV=13):
    '''
    Plot the first nr*nc samples of a batch on an nr x nc grid.

    abatch -- batch of images, indexed along the first axis
    lbbatch -- pair (labels, boxes), each indexed along the first axis
    nr, nc -- number of subplot rows and columns
    NDIV -- number of grid divisions drawn across the SIZE-wide image

    Each image is shown inverted in grayscale, every box is outlined through
    rectangperiodic, and the squeezed integer label becomes the subplot title.
    '''
    PSIZ = SIZE // NDIV

    fig, axxes = plt.subplots(ncols=nc,
                              nrows=nr,
                              figsize=(3*nc, 3*nr),
                              sharey=True, sharex=True)

    # Unpack the argument, not a same-shaped global that happens to exist in the notebook.
    ll, bb = lbbatch

    axxes = np.ravel(axxes)

    for i, ax in enumerate(axxes):
        a = abatch[i]
        b = bb[i]
        l = np.squeeze(ll[i]).astype(np.int32)

        ax.imshow(1-a, aspect=1, cmap='gray', vmin=0, vmax=1)

        for bx in b:
            rcts = rectangperiodic(bx)
            for rct in rcts:
                ax.add_patch(rct)

        # Grid lines every PSIZ pixels; tick labels on x are hidden.
        ax.set_yticks(PSIZ*np.arange(0, NDIV+1))
        ax.set_xticklabels([])
        ax.set_xticks(PSIZ*np.arange(0, NDIV+1))
        ax.set_ylim(0,SIZE-1)
        ax.set_xlim(0,SIZE-1)
        ax.grid(color='g', linestyle='-.', linewidth=0.7, alpha=0.95)
        ax.set_title(l)

    fig.tight_layout()
    plt.show()

## Parameters

In [None]:
# Read both MNIST splits from the Drive-backed TFDS directory (download=False: nothing is fetched).
ds = tfds.load(
    'mnist',
    split=['train', 'test'],
    data_dir=PATH_TO_TFDS,
    download=False,
    shuffle_files=True,
)
# Both splits become repeating numpy iterators; only the train split is shuffled (buffer of 1024).
train_split, test_split = ds
train_set = train_split.cache().shuffle(1024).repeat().as_numpy_iterator()
test_set = test_split.cache().repeat().as_numpy_iterator()

In [None]:
# Fixed-seed random state handed to every generator constructed below.
rng = np.random.RandomState(1)

In [None]:
# Side of the square generated image in pixels (see the TensorSpec shapes and the display helpers).
SIZE = 416
# Passed to the generators; presumably the side of an MNIST digit -- TODO confirm.
MNSZ = 28
# Passed to the generators; presumably an upper bound on the added noise -- TODO confirm.
MAX_NOISE = 0.5
# Not referenced in the visible part of this notebook.
num_classes = 10

In [None]:
# os.listdir returns entries in arbitrary order; sorting keeps the list stable between
# sessions so any rng-driven choice from it is repeatable.
moirefiles = [os.path.join(PATH_TO_MOIRE, f) for f in sorted(os.listdir(PATH_TO_MOIRE))]

In [None]:
from common_utils.single_generators import SingleTestGenerator, SingleTrainGenerator

## Test generator

In [None]:
# Test-sample generator fed by the MNIST test iterator.
testd = SingleTestGenerator(test_set, rng, MAX_NOISE, MNSZ, SIZE)

In [None]:
%%time
# Time 100 draws from the test generator, then show the shapes of the last image, label and boxes.
for _ in range(100):
    a, lb = next(testd)
    l, b = lb
a.shape, l.shape, b.shape

In [None]:
# Plot a 2x3 grid of freshly generated test samples.
display_generated(testd, 2, 3)

## Train generator

In [None]:
# Train-sample generator fed by the MNIST train iterator and the list of moire files.
traind = SingleTrainGenerator(train_set, rng, moirefiles, MAX_NOISE, MNSZ, SIZE)

In [None]:
%%time
# Time 100 draws from the train generator, then show the shapes of the last image, label and boxes.
for _ in range(100):
    a, lb = next(traind)
    l, b = lb
a.shape, l.shape, b.shape

In [None]:
# Plot a 2x3 grid of freshly generated train samples.
display_generated(traind, 2, 3)

## TF datasets

In [None]:
def single_test():
    '''
    Zero-argument factory for tf.data.Dataset.from_generator: builds a new
    SingleTestGenerator over the notebook-level test_set iterator and rng.
    '''
    generator = SingleTestGenerator(test_set, rng, MAX_NOISE, MNSZ, SIZE)
    return generator

In [None]:
def single_train():
    '''
    Zero-argument factory for tf.data.Dataset.from_generator: builds a new
    SingleTrainGenerator over the notebook-level train_set iterator, rng and
    moire file list.
    '''
    generator = SingleTrainGenerator(train_set, rng, moirefiles, MAX_NOISE, MNSZ, SIZE)
    return generator

In [None]:
# Every generator item is (image, (label, boxes)); both datasets use the same layout,
# so the signature is defined once. Shapes are explicit tuples: `(1,)`, not `(1)`,
# which is just the int 1.
sample_signature = (
    tf.TensorSpec(shape=(SIZE, SIZE, 1), dtype=tf.float32),
    (
        tf.TensorSpec(shape=(1,), dtype=tf.float32),
        tf.TensorSpec(shape=(1, 4), dtype=tf.float32),
    ),
)
dataset_test = tf.data.Dataset.from_generator(
    single_test,
    output_signature=sample_signature,
)
dataset_train = tf.data.Dataset.from_generator(
    single_train,
    output_signature=sample_signature,
)

In [None]:
# Samples per batch; the 2x3 display calls below show only the first 6 of each batch.
BATCHSIZE = 12

In [None]:
# Group consecutive generator samples into batches of BATCHSIZE.
batched_test = dataset_test.batch(BATCHSIZE)
batched_train = dataset_train.batch(BATCHSIZE)

In [None]:
# Iterators that yield each batch as numpy arrays instead of tensors.
batched_test_iterator = batched_test.as_numpy_iterator()
batched_train_iterator = batched_train.as_numpy_iterator()

In [None]:
# Pull one test batch with the built-in next(), the same way the generators are advanced
# earlier in the notebook, then show the batch shapes.
aa, llbb = next(batched_test_iterator)
ll, bb = llbb
aa.shape, ll.shape, bb.shape

In [None]:
# Plot the first 6 samples of the test batch on a 2x3 grid.
display_batch(aa, llbb, 2, 3)

In [None]:
# Pull one train batch with the built-in next(), the same way the generators are advanced
# earlier in the notebook, then show the batch shapes.
aa, llbb = next(batched_train_iterator)
ll, bb = llbb
aa.shape, ll.shape, bb.shape

In [None]:
# Plot the first 6 samples of the train batch on a 2x3 grid.
display_batch(aa, llbb, 2, 3)