In [None]:
from fastai.vision.all import *
from models.GAN import Generator, Critic, GANLearner
from fastai.vision.gan import generate_noise
from torchsummary import summary
from util.util import GANImageBlock

In [None]:
path = Path('data/celebA/img_align_celeba/img_align_celeba')

In [None]:
Path.BASE_PATH = path

In [None]:
datablock = DataBlock(blocks=(TransformBlock, GANImageBlock),
                      get_items=get_image_files,
                      get_x=generate_noise,
                      splitter=RandomSplitter(seed=42),
                      item_tfms=Resize(64)
                     )

In [None]:
dataloaders = datablock.dataloaders(path, bs=64)

In [None]:
dataloaders.show_batch(max_n=9)

In [None]:
generator = Generator(z_dim=100,
                      unflattened_shape=[512, 4, 4],
                      upsample_scale=[2, 2, 2, 2],
                      filters=[256, 128, 64, 3],
                      kernels=[5, 5, 5, 5],
                      strides=[1, 1, 1, 1],
                      batch_norm_mom=0.9,
                      dropout_prob=None
                     )

In [None]:
summary(generator, (100,))

In [None]:
critic = Critic(input_shape=[3, 64, 64],
                filters=[64, 128, 256, 512],
                kernels=[5, 5, 5, 5],
                strides=[2, 2, 2, 2],
                batch_norm_mom=None,
                dropout_prob=None
               )

In [None]:
summary(critic, (3, 64, 64))

In [None]:
# Wire the generator and critic into a WGAN-GP learner (project GANLearner factory).
learner = GANLearner.wgangp(
    generator=generator,
    critic=critic,
    dataloaders=dataloaders,
    # NOTE(review): fastai's `Adam` defaults to mom=0.9, sqr_mom=0.99. The WGAN-GP
    # paper trains with beta1=0, beta2=0.9, e.g. `partial(Adam, mom=0., sqr_mom=0.9)`.
    # Consider that if training is unstable.
    opt_func=Adam,
)

# Recorder flags: record metrics on the training batches and skip recording
# validation loss/metrics.
learner.recorder.train_metrics = True
learner.recorder.valid_metrics = False

In [None]:
# Learning-rate range test. fastai's `lr_find` normally restores the weights afterwards;
# confirm this holds for the custom GANLearner. With two alternating networks, the
# resulting curve is only a rough guide.
learner.lr_find()

In [None]:
# Train with a fixed learning rate (assuming the fastai `fit(n_epoch, lr)` signature:
# 10 epochs at 8e-4).
# NOTE(review): fastai's Adam applies decoupled weight decay 0.01 by default; fastai's
# own GAN example trains with wd=0. Confirm whether decay on the critic is intended.
learner.fit(10, 8e-4)