In [None]:
import os
import glob
import re

from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl

from PIL import Image

In [None]:
# --- Paths -------------------------------------------------------------------
# Built with os.path.join so the separator is correct on every OS. The previous
# backslash literal only resolved to a nested folder on Windows and relied on
# an invalid string escape sequence.
DATA_DIR = os.path.join('data', 'Abstract_gallery')
CHECKPOINTS_DIR = 'checkpoints'        # Lightning default_root_dir
GENERATED_IMGS_DIR = 'generated_imgs'  # per-epoch sample grids and the final GIF

# --- Data --------------------------------------------------------------------
IMAGE_SIZE = (64, 64)  # target size handed to transforms.Resize
BATCH_SIZE = 64
IMAGE_CHANNELS = 3     # channel count of generator output / discriminator input
NUM_WORKERS = 0        # DataLoader worker processes (0 = load in the main process)

# --- Model / training --------------------------------------------------------
# NOTE(review): LATENT_SIZE is not referenced by any other visible cell; the
# model is built with latent_dim=100. Confirm which value is intended.
LATENT_SIZE = 256
EPOCHS = 300

# --- Reproducibility ---------------------------------------------------------
random_seed = 42
torch.manual_seed(random_seed)

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
def save_images(image_tensor, epoch, num_images=25):
    """Write a grid of images to ``GENERATED_IMGS_DIR/generated_img_<epoch>.png``.

    Args:
        image_tensor: Batch of images. Values are shifted with ``(x + 1) / 2``
            before saving, i.e. the input is expected in ``[-1, 1]`` (as
            produced by a Tanh output layer).
        epoch: Value embedded in the file name.
        num_images: Maximum number of images (taken from the start of the
            batch) placed in the grid; the grid has 5 images per row.
    """
    # A fresh checkout has no output folder yet, and the image writer does not
    # create missing directories, so make sure it exists first.
    os.makedirs(GENERATED_IMGS_DIR, exist_ok=True)
    path = os.path.join(GENERATED_IMGS_DIR, 'generated_img' + '_' + str(epoch) + '.png')

    # Detach before doing arithmetic so no autograd graph is built for the
    # rescaling, then move the result to host memory for saving.
    images = ((image_tensor.detach() + 1) / 2).cpu()
    image_grid = make_grid(images[:num_images], nrow=5)
    save_image(image_grid, path)

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``module.apply(weights_init)``.

    The decision is made on the module's class name:

    * name contains ``'Conv'``      -> weights drawn from N(0.0, 0.02)
    * name contains ``'BatchNorm'`` -> weights drawn from N(1.0, 0.02), bias set to 0
    * anything else                 -> left untouched
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
def get_noise(cur_batch_size, z_dim):
    """Sample standard-normal latent vectors of shape ``(cur_batch_size, z_dim, 1, 1)``.

    The tensor is allocated on the notebook-level ``device``.
    """
    return torch.randn((cur_batch_size, z_dim, 1, 1), device=device)

In [None]:
class AbstractGalleryDataset(Dataset):
    """Flat-folder image dataset: every regular file in ``root_dir`` is one sample.

    Args:
        root_dir: Directory containing the image files. Sub-directories are
            ignored.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        image_paths: Sorted list of the file paths that make up the dataset.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # List the folder once. Using the same files-only list for both
        # __len__ and __getitem__ keeps indices consistent (a sub-folder can no
        # longer shift or break them), avoids re-listing the directory on every
        # sample, and sorting makes the index -> file mapping independent of
        # the OS listing order, so seeded splits are reproducible.
        with os.scandir(root_dir) as entries:
            self.image_paths = sorted(entry.path for entry in entries if entry.is_file())

    def __len__(self):
        """Number of regular files found in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load sample ``index`` as an RGB image and apply ``transform`` if set."""
        # Image.open is lazy and keeps the file handle open; convert() reads
        # the pixels into a new image, so the handle can be closed right away.
        # Forcing RGB also keeps grayscale / RGBA / palette files compatible
        # with a 3-channel pipeline.
        with Image.open(self.image_paths[index]) as image_file:
            image = image_file.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

In [None]:
class AbstractGalleryDataModule(pl.LightningDataModule):
    """LightningDataModule serving the abstract-gallery images for GAN training.

    Images are resized to ``IMAGE_SIZE``, converted to tensors and normalised
    per channel with mean 0.5 / std 0.5, i.e. ``(x - 0.5) / 0.5``.

    Args:
        data_dir: Folder containing the image files.
        batch_size: Batch size used by both dataloaders.
        train_fraction: Share of the dataset assigned to the training split;
            the remainder becomes the test split. The default reproduces the
            previous fixed 1950 / 832 split on a 2782-image dataset.
    """

    def __init__(self, data_dir=DATA_DIR, batch_size=BATCH_SIZE, train_fraction=1950 / 2782):
        super().__init__()
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f'train_fraction must be within [0, 1], got {train_fraction}')
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_fraction = train_fraction
        self.transform = transforms.Compose(
            [
                transforms.Resize(IMAGE_SIZE),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        )

    def setup(self, stage: Optional[str] = None):
        """Build the dataset and split it into ``self.train`` / ``self.test``."""
        # Honour the directory given to the constructor rather than the global.
        dataset = AbstractGalleryDataset(self.data_dir, transform=self.transform)

        # Derive the split sizes from the actual dataset length: random_split
        # raises when the lengths do not add up to len(dataset), so fixed
        # numbers only work for one exact folder content.
        n_total = len(dataset)
        n_train = round(n_total * self.train_fraction)
        self.train, self.test = random_split(dataset, [n_train, n_total - n_train])

    def train_dataloader(self):
        """Training loader; reshuffled every epoch so batches differ between epochs."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True,
                          num_workers=NUM_WORKERS)

    def test_dataloader(self):
        """Loader over the held-out split (fixed order)."""
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=NUM_WORKERS)

In [None]:
class Generator(nn.Module):
    """DCGAN generator: maps ``n x latent_dim x 1 x 1`` noise to
    ``n x IMAGE_CHANNELS x 64 x 64`` images squashed by Tanh.

    Args:
        latent_dim: Number of channels of the input noise tensor.
    """

    def __init__(self, latent_dim):
        super().__init__()
        width = 64  # base number of feature maps

        def up_block(in_ch, out_ch, stride, padding):
            # transposed conv -> batch norm -> ReLU
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Layers are created in the same order as before, so positions inside
        # the Sequential (and therefore state_dict keys) are unchanged.
        self.model = nn.Sequential(
            # n x latent_dim x 1 x 1  ->  n x 64*8 x 4 x 4
            *up_block(latent_dim, width * 8, stride=1, padding=0),
            # ->  n x 64*4 x 8 x 8
            *up_block(width * 8, width * 4, stride=2, padding=1),
            # ->  n x 64*2 x 16 x 16
            *up_block(width * 4, width * 2, stride=2, padding=1),
            # ->  n x 64 x 32 x 32
            *up_block(width * 2, width, stride=2, padding=1),
            # ->  n x IMAGE_CHANNELS x 64 x 64
            nn.ConvTranspose2d(width, IMAGE_CHANNELS, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from the noise batch ``x``."""
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps ``n x IMAGE_CHANNELS x 64 x 64`` images to one
    sigmoid score per image, returned with shape ``(n, 1)``.
    """

    def __init__(self):
        super().__init__()
        width = 64  # base number of feature maps

        def down_block(in_ch, out_ch):
            # strided conv (halves the spatial size) -> batch norm -> LeakyReLU
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]

        # Layers are created in the same order as before, so positions inside
        # the Sequential (and therefore state_dict keys) are unchanged.
        self.model = nn.Sequential(
            # n x IMAGE_CHANNELS x 64 x 64  ->  n x 64 x 32 x 32
            *down_block(IMAGE_CHANNELS, width),
            # ->  n x 64*2 x 16 x 16
            *down_block(width, width * 2),
            # ->  n x 64*4 x 8 x 8
            *down_block(width * 2, width * 4),
            # ->  n x 64*8 x 4 x 4
            *down_block(width * 4, width * 8),
            # ->  n x 1 x 1 x 1
            nn.Conv2d(width * 8, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score the image batch ``x``; trailing dimensions are flattened."""
        scores = self.model(x)
        return scores.view(scores.size(0), -1)

In [None]:
class DCGAN(pl.LightningModule):
    """DCGAN wrapped as a LightningModule using two optimizers.

    Optimizer 0 updates the generator, optimizer 1 the discriminator; both use
    binary cross-entropy on the discriminator's sigmoid output.

    Args:
        lr: Learning rate of both Adam optimizers.
        latent_dim: Channel count of the noise fed to the generator.
        b1: Adam beta1.
        b2: Adam beta2.
    """

    def __init__(self, lr=0.0002, latent_dim=100, b1=0.5, b2=0.999):
        super().__init__()
        self.save_hyperparameters()

        self.generator = Generator(self.hparams.latent_dim)
        self.discriminator = Discriminator()

        self.generator.apply(weights_init)
        self.discriminator.apply(weights_init)

        # Latest generator output. Written by generator_step, reused by
        # discriminator_step and by the end-of-epoch image dump.
        self.gen_imgs = None

    def forward(self, z):
        """Generate images from the noise batch ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted probabilities and targets."""
        return F.binary_cross_entropy(y_hat, y)

    def generator_step(self, x, z):
        """Generator loss: make the discriminator label fakes as real.

        ``x`` (the real batch) is not used here; it is accepted so both step
        methods share one signature.
        """
        # generate fake images and keep them for the discriminator step
        self.gen_imgs = self.generator(z)

        y_hat_fake = self.discriminator(self.gen_imgs)

        generator_loss = self.adversarial_loss(y_hat_fake, torch.ones_like(y_hat_fake))
        self.log('generator_loss', generator_loss, prog_bar=True, sync_dist=True)
        return generator_loss

    def discriminator_step(self, x, z):
        """Discriminator loss: mean of the real-vs-1 and fake-vs-0 BCE terms.

        Reuses the fakes stored by the preceding ``generator_step`` (detached,
        so no gradient reaches the generator); ``z`` is not used here.
        """
        y_hat_real = self.discriminator(x)
        y_hat_fake = self.discriminator(self.gen_imgs.detach())

        real_loss = self.adversarial_loss(y_hat_real, torch.ones_like(y_hat_real))
        fake_loss = self.adversarial_loss(y_hat_fake, torch.zeros_like(y_hat_fake))
        discriminator_loss = (real_loss + fake_loss) / 2
        self.log('discriminator_loss', discriminator_loss, prog_bar=True, sync_dist=True)
        return discriminator_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Dispatch to the generator (``optimizer_idx == 0``) or discriminator
        (``optimizer_idx == 1``) step and return its loss.
        """
        real_imgs = batch

        # sample noise, matched to the batch's dtype/device
        z = get_noise(real_imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(real_imgs)

        loss = 0.0
        if optimizer_idx == 0:
            loss = self.generator_step(real_imgs, z)

        elif optimizer_idx == 1:
            loss = self.discriminator_step(real_imgs, z)
        return loss

    def configure_optimizers(self):
        """Return ``[generator_optimizer, discriminator_optimizer]`` and no schedulers."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        generator_optimizer = optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        return [generator_optimizer, discriminator_optimizer], []

    def on_train_epoch_end(self):
        """Save a grid of the last generated batch after every training epoch.

        This replaces the generic ``on_epoch_end`` hook, which Lightning
        deprecated in 1.6 and removed in 1.8, and which also fired for
        validation/test epochs where no generated batch may exist.
        """
        if self.gen_imgs is not None:
            save_images(self.gen_imgs, self.current_epoch)

In [None]:
# Arguments are spelled out (they equal the constructor defaults) so the cell
# shows exactly what is being trained.
dm = AbstractGalleryDataModule(data_dir=DATA_DIR, batch_size=BATCH_SIZE)
# NOTE(review): latent_dim is 100 here while the config cell defines
# LATENT_SIZE = 256, which nothing uses — confirm which value is intended.
model = DCGAN(lr=0.0002, latent_dim=100, b1=0.5, b2=0.999)

In [None]:
# Choose the accelerator from what is actually available: a hard-coded 'gpu'
# makes Trainer construction fail on CPU-only machines, although the rest of
# the notebook already supports a CPU fallback through `device`.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer = pl.Trainer(max_epochs=EPOCHS,
                     accelerator=accelerator,
                     devices=1,
                     default_root_dir=CHECKPOINTS_DIR,
                     # presumably "once per epoch": ceil(1950 / 64) = 31 batches — TODO confirm
                     log_every_n_steps=31)
trainer.fit(model, dm)

In [None]:
def _epoch_of(path):
    """Epoch number encoded in a ``generated_img_<epoch>.png`` file name."""
    # Only look at the file name: digits elsewhere in the path (for example in
    # a parent folder) must not leak into the sort key.
    return int(re.sub(r'\D', '', os.path.basename(path)))


# Collect the per-epoch sample grids in epoch order.
frame_paths = sorted(glob.glob(f'{GENERATED_IMGS_DIR}/generated_img_*.png'), key=_epoch_of)

if frame_paths:
    frames = []
    for frame_path in frame_paths:
        # copy() loads the pixels into memory so the file handle can be
        # released immediately instead of staying open for every frame.
        with Image.open(frame_path) as frame_file:
            frames.append(frame_file.copy())

    frame_one = frames[0]
    # append_images holds only the frames AFTER the first one; passing the full
    # list would show the first frame twice.
    frame_one.save(f'{GENERATED_IMGS_DIR}/generated_images.gif', format='GIF',
                   append_images=frames[1:], save_all=True,
                   duration=200, loop=0, optimize=False)
else:
    print(f'No generated_img_*.png files found in {GENERATED_IMGS_DIR!r}; nothing to animate.')