In [1]:
import torch
import pytorch_lightning as pl
from torchvision.datasets import MNIST
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision import transforms
import torch.nn.functional as F

In [2]:
# MNIST images as tensors scaled to [-1, 1]. ToTensor yields values in [0, 1], and
# Normalize(0.5, 0.5) maps x -> (x - 0.5) / 0.5. This matches the tanh output range of
# the generators in this notebook. Normalizing with the dataset statistics
# (0.1307, 0.3081) would put real pixels in roughly [-0.42, 2.82], a range a tanh
# generator cannot reach, so the discriminator could separate real from fake trivially.
mnist = MNIST(root='', download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]))

In [3]:
class Generator(nn.Module):
    """Fully connected generator producing ``(N, 1, img_size, img_size)`` images in [-1, 1].

    NOTE(review): a later notebook cell defines another ``Generator`` class that
    shadows this one.

    Args:
        img_size: side length of the square output image. The input is flattened
            and must have ``img_size**2`` features.
        network_width: number of units in every hidden layer.
        network_depth: total number of Linear layers (input + hidden + output).
            Values below 2 behave like 2.
    """

    def __init__(self, img_size, network_width, network_depth):
        super().__init__()
        self.img_size = img_size
        self.network_depth = network_depth
        n_hidden = max(network_depth - 2, 0)
        self.l1 = nn.Linear(img_size**2, network_width)
        self.bn = nn.BatchNorm1d(network_width)
        # Each hidden layer gets its own weights and its own BatchNorm. Sharing a
        # single Linear/BatchNorm1d across depths would tie their weights and mix
        # running statistics of activations from different layers.
        self.hidden = nn.ModuleList(
            [nn.Linear(network_width, network_width) for _ in range(n_hidden)])
        self.hidden_bn = nn.ModuleList(
            [nn.BatchNorm1d(network_width) for _ in range(n_hidden)])
        self.l3 = nn.Linear(network_width, img_size**2)

    def forward(self, x):
        """Map ``x`` (any shape ``(N, ...)`` with ``img_size**2`` features per row) to images."""
        x = x.flatten(start_dim=1)
        x = self.bn(F.leaky_relu(self.l1(x)))
        for linear, norm in zip(self.hidden, self.hidden_bn):
            x = norm(F.leaky_relu(linear(x)))
        x = torch.tanh(self.l3(x))
        return x.view(-1, 1, self.img_size, self.img_size)

class Discriminator(nn.Module):
    """Fully connected discriminator returning a per-sample probability of shape ``(N, 1)``.

    NOTE(review): a later notebook cell defines another ``Discriminator`` class that
    shadows this one.

    Args:
        img_size: side length of the square input image. The input is flattened
            to ``img_size**2`` features.
        network_width: number of units in every hidden layer.
        network_depth: total number of Linear layers (input + hidden + output).
            Values below 2 behave like 2.
    """

    def __init__(self, img_size, network_width, network_depth):
        super().__init__()
        self.network_depth = network_depth
        n_hidden = max(network_depth - 2, 0)
        self.l1 = nn.Linear(img_size**2, network_width)
        self.bn = nn.BatchNorm1d(network_width)
        # Separate Linear + BatchNorm per hidden layer, so weights are not tied and
        # running statistics are not mixed across depths.
        self.hidden = nn.ModuleList(
            [nn.Linear(network_width, network_width) for _ in range(n_hidden)])
        self.hidden_bn = nn.ModuleList(
            [nn.BatchNorm1d(network_width) for _ in range(n_hidden)])
        self.l3 = nn.Linear(network_width, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores in [0, 1] for each sample of ``x``."""
        x = x.flatten(start_dim=1)
        x = self.bn(F.leaky_relu(self.l1(x)))
        for linear, norm in zip(self.hidden, self.hidden_bn):
            x = norm(F.leaky_relu(linear(x)))
        return self.sigmoid(self.l3(x))

In [3]:
class Generator(nn.Module):
    """MLP generator: latent noise ``(N, latent_dim)`` -> images ``(N, 1, img_size, img_size)`` in [-1, 1].

    Args:
        latent_dim: length of each input noise vector (default 100).
        img_size: side length of the square output image (default 28).
    """

    def __init__(self, latent_dim=100, img_size=28):
        super().__init__()
        self.img_size = img_size

        def block(in_feat, out_feat, normalize=True):
            """Return the layers Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's second positional argument is ``eps``, so
                # this uses eps=0.8, which is unusually large. It may have been intended
                # as a momentum value. Confirm before changing, since it affects training.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            # Plain arithmetic instead of np.prod: no dependency on numpy being imported.
            nn.Linear(1024, img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of single-channel images from noise ``z``."""
        img = self.model(z)
        return img.view(img.size(0), 1, self.img_size, self.img_size)


class Discriminator(nn.Module):
    """MLP discriminator: images ``(N, 1, img_size, img_size)`` -> probabilities ``(N, 1)``.

    Args:
        img_size: side length of the square input image (default 28).
    """

    def __init__(self, img_size=28):
        super().__init__()
        self.model = nn.Sequential(
            # Plain arithmetic instead of np.prod: no dependency on numpy being imported.
            nn.Linear(img_size * img_size, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the sigmoid "real" score for every image in the batch."""
        # flatten (unlike view) also accepts non-contiguous input tensors.
        img_flat = img.flatten(start_dim=1)
        return self.model(img_flat)

In [4]:
class GAN(pl.LightningModule):
    """GAN trained with manual optimization (``automatic_optimization = False``).

    Each dataloader batch holds ``batch_size * discriminator_training_loops`` real
    images. The batch is split into chunks of ``batch_size``, and each chunk drives
    one discriminator update. The generator is then updated once. Before every
    optimizer step, gradients are clipped element-wise to
    ``[-grad_clip_value, grad_clip_value]``. Losses and the largest absolute gradient
    before/after clipping are written through ``self.logger.experiment.add_scalar``
    at the current ``global_step``.

    Args:
        dataset: map-style dataset yielding ``(image, label)`` pairs; labels are ignored.
        image_size, generator_width, generator_depth, discriminator_width,
        discriminator_depth: stored as attributes only. ``Generator()`` and
            ``Discriminator()`` are built without arguments.
        batch_size: samples per discriminator/generator update.
        discriminator_training_loops: discriminator updates per generator update.
        num_workers: DataLoader worker processes.
        grad_clip_value: element-wise gradient clipping threshold (default 0.5).
    """

    # Length of the sampled noise vectors; must equal the Generator's input width.
    latent_dim = 100
    # Added inside log() to avoid log(0).
    _eps = 1e-16

    def __init__(self, dataset, image_size=28, batch_size=64, generator_width=800, generator_depth=6,
                 discriminator_width=400, discriminator_depth=3, discriminator_training_loops=1, num_workers=16,
                 grad_clip_value=0.5):
        super().__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        self.generator_width = generator_width
        self.generator_depth = generator_depth
        self.discriminator_width = discriminator_width
        self.discriminator_depth = discriminator_depth
        self.discriminator_training_loops = discriminator_training_loops
        self.grad_clip_value = grad_clip_value
        # NOTE(review): the width/depth settings above are not passed to the networks.
        self.generator = Generator()
        self.discriminator = Discriminator()
        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.dataset = dataset
        self.num_workers = num_workers

    def train_dataloader(self):
        """Shuffled loader whose batches cover all discriminator loops of one training step."""
        return DataLoader(self.dataset,
                          self.batch_size * self.discriminator_training_loops,
                          shuffle=True,
                          num_workers=self.num_workers)

    def training_step(self, batch, batch_idx):
        """Run ``discriminator_training_loops`` discriminator updates, then one generator update."""
        X, _ = batch
        generator_optimizer, discriminator_optimizer = self.optimizers()

        # Discriminator: one update per batch_size-sized chunk of real images.
        for data_batch in torch.split(X, self.batch_size):
            noise_batch = self._sample_noise(self.batch_size)
            discriminator_optimizer.zero_grad()
            loss_discriminator = self.discriminator_loss(data_batch, noise_batch)
            self.manual_backward(loss_discriminator)
            grad_max, grad_max_a = self._clip_gradients(self.discriminator)
            discriminator_optimizer.step()

        # Generator: a single update per training step.
        generator_optimizer.zero_grad()
        loss_generator = self.generator_loss(self._sample_noise(self.batch_size))
        self.manual_backward(loss_generator)
        grad_max2, grad_max2_a = self._clip_gradients(self.generator)
        generator_optimizer.step()

        self._log_scalars({
            "Generator_loss": loss_generator,
            "Discriminator_loss": loss_discriminator,
            "Grad_max_before": grad_max,
            "Grad_max_after": grad_max_a,
            "Grad_max_before2": grad_max2,
            "Grad_max_after2": grad_max2_a,
        })

    def configure_optimizers(self):
        """Adam for each network; returned in (generator, discriminator) order as unpacked in training_step."""
        generator_optimizer = Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        discriminator_optimizer = Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        return generator_optimizer, discriminator_optimizer

    def generator_loss(self, noise):
        """Non-saturating generator loss ``-mean(log D(G(noise)))``."""
        return -torch.mean(torch.log(self.discriminator(self.generator(noise)) + self._eps))

    def discriminator_loss(self, data, noise):
        """``0.5 * -mean(log D(real)) + 0.5 * -mean(log(1 - D(G(noise))))``.

        Generated images are detached: this loss updates only the discriminator, so
        backpropagating into the generator here would be wasted work.
        """
        fake = self.generator(noise).detach()
        real_term = torch.mean(torch.log(self.discriminator(data) + self._eps))
        fake_term = torch.mean(torch.log(1 - self.discriminator(fake) + self._eps))
        return -0.5 * real_term - 0.5 * fake_term

    def _sample_noise(self, n):
        """Draw ``n`` standard-normal latent vectors on this module's device (not hard-coded CUDA)."""
        return torch.randn(n, self.latent_dim, device=self.device)

    def _clip_gradients(self, module):
        """Clip ``module``'s gradients in place; return (max |grad| before, max |grad| after).

        Returns ``(0.0, 0.0)`` when no parameter of ``module`` has a gradient.
        """
        grads = [p.grad for p in module.parameters() if p.grad is not None]
        if not grads:
            return 0.0, 0.0
        # Absolute values: a signed max would hide large negative gradients.
        before = max(g.abs().max() for g in grads)
        torch.nn.utils.clip_grad_value_(module.parameters(), self.grad_clip_value)
        after = max(g.abs().max() for g in grads)
        return before, after

    def _log_scalars(self, scalars):
        """Write each ``tag -> value`` pair at the current global step (per batch, not per epoch)."""
        if self.logger is None:
            return
        writer = self.logger.experiment
        for tag, value in scalars.items():
            writer.add_scalar(tag, float(value), self.global_step)

In [5]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoard logs go under 'my_logs' with experiment name 'GAN'.
logger = TensorBoardLogger(save_dir='my_logs', name='GAN')

# Single-GPU trainer, capped at 500 epochs, reporting to the logger above.
tr = Trainer(
    logger=logger,
    gpus=1,
    max_epochs=500,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [6]:
# Load (or reload) the TensorBoard notebook extension, then open TensorBoard
# on 'my_logs/', the directory the TensorBoardLogger writes to.
%reload_ext tensorboard
%tensorboard --logdir my_logs/

In [None]:
# NOTE(review): imported here rather than in the imports cell. Kept because code in
# earlier cells may reference `np` when the networks are constructed; consider
# moving it to the top-of-notebook imports.
import numpy as np

model = GAN(dataset=mnist)
tr.fit(model=model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 1.5 M 
1 | discriminator | Discriminator | 533 K 
------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.174     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]