In [3]:
import fastai
from fastai.vision.all import *
from fastdownload import FastDownload
import matplotlib.pyplot as plt
from models.GAN import GANModule
from torchsummary import summary

In [None]:
# Root directory for the dataset, relative to the notebook's working directory.
path = Path('./data')

In [None]:
# Expose the data directory as BASE_PATH on the Path class
# (presumably used by fastai/fastcore to display paths relative to it — confirm).
Path.BASE_PATH = path

In [None]:
# Fetch the Quick, Draw! camel bitmaps only when no local 'archive' folder exists yet.
archive_dir = path/'archive'
if not archive_dir.exists():
    camel_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/camel.npy'
    loader = FastDownload(base=path.name, module=fastai.data)
    # The same URL goes to both calls, in this order: update first, then download.
    loader.update(camel_url)
    loader.download(camel_url)

In [None]:
def get_quickdraw_tensors(path):
    """Load every ``.npy`` bitmap file found under `path` into one image tensor.

    Each file is treated as one class: all drawings in the i-th file (in sorted
    path order) receive the integer label ``i``.

    Args:
        path: Directory passed to ``get_files`` to look for ``.npy`` files.

    Returns:
        A tuple ``(x, y)``. ``x`` has shape ``(N, 1, 28, 28)`` and holds the
        drawings of all files concatenated along the first axis. ``y`` has
        shape ``(N,)`` and holds the int64 class index of each drawing.

    Raises:
        FileNotFoundError: If no ``.npy`` file is found under `path`.
    """
    # Sort so label assignment is deterministic regardless of the order get_files returns.
    files = sorted(get_files(path, extensions=['.npy']))
    if not files:
        # Previously this fell through to `return x, y` with both names unbound.
        raise FileNotFoundError(f"No .npy files found under {path!r}")

    images, targets = [], []
    for label, f in enumerate(files):
        subject = torch.from_numpy(np.load(f))
        # One entry per drawing along axis 0; restore the (channel, height, width) layout.
        images.append(subject.view(subject.shape[0], 1, 28, 28))
        targets.append(torch.full((subject.shape[0],), label, dtype=torch.long))

    # Concatenate once at the end instead of re-copying the accumulated tensor per file.
    return torch.cat(images), torch.cat(targets)

In [None]:
# Load every .npy file under `path` into an image tensor plus integer labels.
camels, labels = get_quickdraw_tensors(path)

In [None]:
# Sanity check: get_quickdraw_tensors reshapes the drawings to (N, 1, 28, 28).
camels.shape

In [None]:
# Preview a drawing; permute moves the channel axis last, i.e. (1, 28, 28) -> (28, 28, 1).
plt.imshow(camels[0].permute(1, 2, 0))

In [None]:
# Preview another drawing the same way.
plt.imshow(camels[10].permute(1, 2, 0))

In [None]:
# Build the GAN from the project's models.GAN module. GANModule is the only class this
# notebook imports from that module; the previous name `GAN` is never defined here and
# raises NameError on a fresh kernel.
# NOTE(review): keyword meanings are defined by GANModule, which is not shown in this
# notebook. By name, the disc_* arguments configure the discriminator and the gen_*
# arguments the generator — confirm against models/GAN.py.
gan = GANModule(input_shape=[1, 28, 28],
                disc_conv_filters=[64, 64, 128, 128],
                disc_conv_kernels=[5, 5, 5, 5],
                disc_conv_strides=[2, 2, 2, 1],
                disc_batch_norm_mom=None,
                disc_dropout_prob=0.4,
                gen_unflattened_shape=[64, 7, 7],
                gen_upsample_scale=[2, 2, 1, 1],
                gen_conv_filters=[128, 64, 64, 1],
                gen_conv_kernels=[5, 5, 5, 5],
                gen_conv_strides=[1, 1, 1, 1],
                gen_batch_norm_mom=0.9,
                gen_dropout_prob=None,
                z_dim=100
               )

In [None]:
# Summarize the discriminator for a single-channel 28x28 input (same shape as input_shape above).
summary(gan.discriminator, (1, 28, 28))

In [None]:
# Summarize the generator for a 100-element latent vector (the z_dim passed when building `gan`).
summary(gan.generator, (100,))

In [None]:
# Sample a pool of generator outputs: n_batches batches of batch_size latent vectors each.
n_batches, batch_size, z_dim = 125, 64, 100  # z_dim matches the value passed when building `gan`
# Use whatever device the generator's weights live on instead of assuming CUDA is present.
device = next(gan.generator.parameters()).device

fake_batches = []
# Sampling only: no_grad stops autograd from keeping the graph of every batch alive,
# which otherwise grows memory with each of the n_batches forward passes.
with torch.no_grad():
    for _ in range(n_batches):
        noise = torch.normal(0, 1, (batch_size, z_dim)).to(device)
        fake_batches.append(gan.generator(noise))

# Concatenate once instead of re-copying the growing tensor on every iteration.
generated_imgs = torch.cat(fake_batches)

In [None]:
# Shape of the accumulated generator outputs built in the previous cell.
generated_imgs.shape

In [None]:
# detach() drops the autograd link and cpu() moves the sample to host memory for plotting.
# permute puts axis 0 last — assumes each sample is (C, H, W); confirm against the generator.
plt.imshow(generated_imgs[0].detach().cpu().permute(1, 2, 0))

In [None]:
# Second generated sample, displayed the same way.
plt.imshow(generated_imgs[1].detach().cpu().permute(1, 2, 0))