In [None]:
import os
import argparse

import torch
from pl_bolts.models.gans import DCGAN
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torchvision.utils import make_grid
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

import numpy as np
import matplotlib.pyplot as plt

from typing import Optional, List, Any

In [None]:
# Fix the global random seed so repeated runs of the notebook are comparable.
seed_everything(seed=6)

In [None]:
# Central place for every tunable value used by the cells below.
config = argparse.Namespace(
    # Use a GPU only when one is actually present, so the notebook also runs
    # (slowly) on CPU-only machines instead of failing at Trainer creation.
    gpu=torch.cuda.is_available(),
    # Root folder handed to ImageFolder (one sub-directory per class).
    data_dir="../input/pytorch-challange-flower-dataset/dataset",
    n_epochs=40,
    batch_size=128,
    # Side length (pixels) of the centre crop taken before down-scaling.
    min_image_size=500,
    # Side length (pixels) of the images fed to the GAN.
    image_size=64,
    n_channels=3,
)

In [None]:
class ImageDataModule(pl.LightningDataModule):
    """Lightning data module serving images from an ``ImageFolder`` directory.

    Every image is centre-cropped to ``min_image_size`` pixels, resized to
    ``image_size`` and normalised with mean 0.5 / std 0.5 per channel
    (three channels are assumed by the normalisation).

    Args:
        data_dir: Root directory passed to ``ImageFolder``.
        min_image_size: Side length, in pixels, of the centre crop applied
            before resizing.
        image_size: Target size passed to ``T.Resize`` after cropping.
        batch_size: Number of images per training batch.
    """

    def __init__(self, data_dir: str, min_image_size: int,
                 image_size: Optional[int] = 64,
                 batch_size: Optional[int] = 1) -> None:
        super().__init__()

        self.data_dir = data_dir
        # Previously accepted but silently dropped; kept so setup() can use it.
        self.min_image_size = min_image_size
        self.batch_size = batch_size
        self.image_size = image_size

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the transform pipeline and the training dataset.

        Args:
            stage: Stage name supplied by Lightning; not used here.
        """
        pipeline = T.Compose([
            # Crop size comes from the constructor instead of a literal 500.
            T.CenterCrop(self.min_image_size),
            T.Resize(self.image_size),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.train_dataset = ImageFolder(
            self.data_dir, transform=pipeline
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True)

In [None]:
def demo_single(model: torch.nn.Module, latent_dim: int = 100) -> None:
    """Create a demo of a single image generated by model.

    Args:
        model: Generator-like module called with a ``(1, latent_dim)`` noise
            tensor and returning a ``(1, C, H, W)`` image batch.
        latent_dim: Length of the noise vector fed to the model.

    The output is rescaled with ``x * 0.5 + 0.5`` before plotting, which
    assumes the model emits values in [-1, 1] - TODO confirm for models other
    than the one trained in this notebook.
    """
    # Sample on whatever device the model lives on (CPU if it has no params).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Generate in eval mode without autograd, then restore the previous mode
    # so calling the demo never alters the model's training state.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Standard-normal noise, consistent with demo_grid (torch.rand
            # would draw from a uniform [0, 1) distribution instead).
            noise = torch.randn(1, latent_dim, device=device)
            single = model(noise).squeeze(0) * 0.5 + 0.5
    finally:
        model.train(was_training)

    # Matplotlib needs a CPU, channels-last image with values in [0, 1].
    single = single.clamp(0, 1).cpu()
    plt.imshow(single.permute(1, 2, 0))
    plt.show()

In [None]:
def demo_grid(model: torch.nn.Module, n_rows: int = 3, n_cols: int = 5,
              latent_dim: int = 100) -> None:
    """Create an n_rows x n_cols (default 3 x 5) demo of images generated by model.

    Args:
        model: Generator-like module called with an
            ``(n_rows * n_cols, latent_dim)`` noise tensor and returning an
            image batch of shape ``(N, C, H, W)``.
        n_rows: Number of image rows in the grid.
        n_cols: Number of images per row.
        latent_dim: Length of each noise vector fed to the model.
    """
    # Sample on whatever device the model lives on (CPU if it has no params).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Generate in eval mode without autograd, then restore the previous mode
    # so calling the demo never alters the model's training state.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(n_rows * n_cols, latent_dim, device=device)
            samples = model(noise)
    finally:
        model.train(was_training)

    # make_grid's ``nrow`` is the number of images per row, i.e. the columns;
    # normalize=True rescales the grid into [0, 1] for display.
    grid = make_grid(samples.cpu(), nrow=n_cols, normalize=True)
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

In [None]:
# Data: flower images prepared according to the notebook configuration.
flower_dm = ImageDataModule(
    data_dir=config.data_dir,
    min_image_size=config.min_image_size,
    image_size=config.image_size,
    batch_size=config.batch_size,
)

# Model: DCGAN configured for the number of channels in the images.
model = DCGAN(image_channels=config.n_channels)

# Trainer: request one GPU when config.gpu is truthy, none otherwise.
trainer = pl.Trainer(
    max_epochs=config.n_epochs,
    gpus=int(bool(config.gpu)),
)

In [None]:
# Train the model on the flower data module.
trainer.fit(model, datamodule=flower_dm)

# Show a single generated sample first, then a grid of samples.
for render in (demo_single, demo_grid):
    render(model)