## GANs with the MNIST dataset
We'll look at how to build a simple generator network and a simple discriminator (critic) network, and use them to try to generate images that resemble the MNIST dataset.

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *

In [None]:
# Fetch the MNIST archive referenced by `URLs.MNIST` via fastai's `untar_data`;
# `path` is the location of the extracted dataset used by all cells below.
path  = untar_data(URLs.MNIST)

In [None]:
# List the entries of the dataset directory (the split cell below expects
# `training` and `testing` sub-folders).
path.ls()

In [None]:
# Batch size and image size used by the data pipeline below.
bs = 128
size = 28

In [None]:
# Collect the image files found under `path`; convert_mode='L' requests
# single-channel (grayscale) loading.
il = ImageList.from_folder(path, convert_mode='L')

In [None]:
# Set fastai's default colormap to 'binary' (presumably picked up by the
# image display calls below — confirm against fastai's `show` helpers).
defaults.cmap='binary'

In [None]:
# Split by folder: files under `training/` form the training set,
# files under `testing/` the validation set.
sd = il.split_by_folder(train='training', valid='testing')

In [None]:
# Label every image from the folder it lives in (for MNIST presumably the
# digit class — verify with the printed label below).
ll = sd.label_from_folder()

In [None]:
# Take the first training item as an (input, label) pair for a sanity check.
x,y = ll.train[0]

In [None]:
# Display the sample image, then print its label and its shape.
x.show()
print(y,x.shape)

Since these images are digits, we don't want to apply any warping, flipping or zoom augmentation, as the results could cease to look like digits. We only apply some random padding.

In [None]:
# Augmentation is limited to random padding (zero-filled, 3 pixels) at `size`.
# The pair is (first list: the rand_pad transforms, second list: empty);
# presumably (training tfms, validation tfms) as fastai expects — TODO confirm.
tfms = (list(rand_pad(padding=3, size=size, mode='zeros')), [])

In [None]:
# Attach the transforms to the labelled lists. The return value is discarded,
# so this relies on `transform` updating `ll` in place.
ll.transform(tfms)
# Build the DataBunch with batch size `bs`. `normalize()` is called without
# explicit stats (presumably fastai derives them from the data — TODO confirm).
data = ll.databunch(bs=bs).normalize()

It is always a good idea to look at these stats if you're not sure what format the image is in.

In [None]:
def get_stats(tensor):
    """Return the `(mean, std, min, max)` of `tensor` as a 4-tuple.

    Each entry is whatever the corresponding reduction method of `tensor`
    returns; nothing is converted.
    """
    summary = (
        tensor.mean(),
        tensor.std(),
        tensor.min(),
        tensor.max(),
    )
    return summary

In [None]:
# Dataset items are (input, label) pairs, so index 0 selects the image.
# The last element is the label, whose statistics say nothing about the
# pixel format this cell is meant to inspect.
get_stats(data.train_ds[1][0].data)

In [None]:
def get_data(bs,size):
    """Build the DataBunch used for GAN training from the images under `path`.

    bs   -- batch size passed to `databunch`
    size -- image size passed to `transform`

    Uses the notebook-level `path` and `tfms`. Items are created with
    `noise_sz=100`, nothing is held out (`split_none`), and the items are
    labelled through `noop`. Only the targets are normalized
    (`do_x=False, do_y=True`) with mean 0.5 and std 0.5.
    """
    items = GANItemList.from_folder(path, noise_sz=100,convert_mode='L')
    labelled = items.split_none().label_from_func(noop)
    transformed = labelled.transform(tfms,size=size,tfm_y=True)
    bunch = transformed.databunch(bs=bs)
    # Assuming pixel values in [0, 1], mean/std of 0.5 rescales targets to [-1, 1].
    target_stats = [torch.tensor([0.5]), torch.tensor([0.5])]
    return bunch.normalize(stats = target_stats,do_x=False,do_y=True)

In [None]:
# Rebind `data` to the GAN DataBunch built with the current `bs` and `size`.
data = get_data(bs,size)

In [None]:
# Preview a few samples (the argument 2 is presumably the number of rows —
# confirm against fastai's `show_batch`).
data.show_batch(2)

We define some helper functions to quickly build `conv_blocks`, i.e. layers that combine a 2D convolution, a ReLU-family activation and (optionally) batch normalization, since we use them quite often and don't want to make mistakes.

In [None]:
def conv2d(ni,nf,kernel_size=4,stride=2,bn=True):
    """Conv block: Conv2d (padding 1, no bias) -> LeakyReLU(0.2) -> optional BatchNorm2d.

    ni, nf      -- number of input / output channels
    kernel_size -- convolution kernel size
    stride      -- convolution stride
    bn          -- when truthy, `nn.BatchNorm2d(nf)` is appended after the activation

    Returns the layers wrapped in an `nn.Sequential`.
    """
    conv = nn.Conv2d(ni,nf,kernel_size=kernel_size,stride=stride,padding=1,bias=False)
    act = nn.LeakyReLU(0.2,True)  # negative slope 0.2, applied in place
    if not bn:
        return nn.Sequential(conv, act)
    return nn.Sequential(conv, act, nn.BatchNorm2d(nf))

`nn.ConvTranspose2d` is a transposed convolution, also called a deconvolution or a fractionally-strided convolution. Instead of downsampling (which is what convolutions with stride > 1 do), it upsamples, i.e. it increases the grid size.

In [None]:
def convt2d(ni,nf,kernel_size=4,stride=2,padding=1,bn=True):
    """Transposed-conv block: ConvTranspose2d (no bias) -> ReLU -> optional BatchNorm2d.

    ni, nf      -- number of input / output channels
    kernel_size -- transposed-convolution kernel size
    stride      -- transposed-convolution stride
    padding     -- NOTE(review): accepted but never forwarded; the layer is always
                   built with padding=1. Honoring it would change the output size
                   for any caller passing a different value, so the behavior is
                   deliberately left as is here.
    bn          -- when truthy, `nn.BatchNorm2d(nf)` is appended after the activation

    Returns the layers wrapped in an `nn.Sequential`.
    """
    upconv = nn.ConvTranspose2d(ni,nf,kernel_size=kernel_size,stride=stride,padding=1,bias=False)
    act = nn.ReLU(True)  # in-place ReLU
    if not bn:
        return nn.Sequential(upconv, act)
    return nn.Sequential(upconv, act, nn.BatchNorm2d(nf))

In [None]:
from fastai.vision.gan import AvgFlatten

In [None]:
# Critic: a stack of conv blocks that shrinks a 3-channel image step by step.
# Sizes in the comments are spatial sizes for a 28x28 input, using
# out = floor((n + 2*padding - kernel_size) / stride) + 1.
critic = nn.Sequential(
    conv2d(ni=3, nf=8, bn=False),                   # 28 -> 14
    conv2d(ni=8, nf=8, kernel_size=3, stride=1),    # 14 -> 14
    conv2d(ni=8, nf=16),                            # 14 -> 7
    conv2d(ni=16, nf=32),                           # 7 -> 3
    nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=0, bias=False),  # 3 -> 1
    AvgFlatten(),                                   # make into vector
)

In [None]:
# Generator: a stack of transposed-conv blocks ending in Tanh (outputs in [-1, 1]).
# Sizes in the comments are spatial sizes assuming a 1x1 noise input, using
# out = (n - 1) * stride - 2 * padding + kernel_size.
# The first block no longer passes `padding=0`: `convt2d` builds every layer with
# padding=1 regardless of that argument, so it had no effect and only suggested a
# 4x4 output that never occurred. The network built here is unchanged.
generator = nn.Sequential(convt2d(100,8), # 1 -> 2
                          convt2d(8,16), # 2 -> 4
                          convt2d(16,32,stride=3), # 4 -> 11
                          convt2d(32,16), # 11 -> 22
                          convt2d(16,16,stride=1), # 22 -> 23
                          nn.ConvTranspose2d(16, 3, kernel_size=4, stride=1, padding=1,bias=False), # 23 -> 24
                          nn.Tanh())

In [None]:
# Build a WGAN learner from the data, generator and critic defined above.
# The optimizer is Adam with betas=(0., 0.99); weight decay is disabled (wd=0.).
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func = partial(optim.Adam, betas = (0.,0.99)), wd=0.)

In [None]:
# Train: first argument 10 (epochs), second 1e-4 (learning rate).
learn.fit(10,1e-4)

In [None]:
# Second experiment: keep the batch size, raise the image size to 64,
# and rebuild the GAN DataBunch accordingly.
bs,size=128,64
data = get_data(bs,size)

In [None]:
# Replace the hand-written networks with fastai's `basic_generator` /
# `basic_critic`, configured for `size`-pixel, 3-channel images.
generator = basic_generator(in_size=size,n_channels=3,n_features=64,n_extra_layers=1)
critic = basic_critic(in_size=size,n_channels=3,n_features=64,n_extra_layers=1)

In [None]:
# Build a new WGAN learner for the 64-pixel data and the fastai-provided models.
# The optimizer is Adam with betas=(0., 0.99); weight decay is disabled (wd=0.).
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func = partial(optim.Adam, betas = (0.,0.99)), wd=0.)

In [None]:
# Train: first argument 10 (epochs), second 1e-4 (learning rate).
learn.fit(10,1e-4)