In [1]:
import os
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import FashionMNIST

import matplotlib.pyplot as plt

import pytorch_lightning as pl

In [2]:
# --- Reproducibility: seed torch's global RNG before anything draws from it ---
random_seed = 42
torch.manual_seed(random_seed)

# --- Hardware: prefer CUDA when it is available, otherwise stay on the CPU ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Data ---
BATCH_SIZE = 64      # samples per training batch
CHANNELS_IMG = 1     # image channels (also sizes the Normalize mean/std lists)
IMAGE_SIZE = 64      # images are resized to this size before entering the networks
NUM_CLASSES = 10     # number of label embeddings in generator and discriminator

# --- Model width ---
FEATURES_DISC = 64   # base channel count of the discriminator
FEATURES_GEN = 64    # base channel count of the generator
GEN_EMBEDDING = 100  # size of the generator's label embedding vector

# --- Optimisation ---
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
N_CRITIC = 5         # intended number of critic updates per generator update
LAMBDA_GP = 10       # weight of the gradient-penalty term in the critic loss

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module serving FashionMNIST resized to IMAGE_SIZE.

    Images are resized, converted to tensors and normalised with mean 0.5 and
    std 0.5 per channel, which maps ToTensor's [0, 1] range onto [-1, 1].

    Args:
        data_dir: Directory the dataset is downloaded to / read from.
        batch_size: Number of samples per training batch.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(self, data_dir="./data",
                 batch_size=BATCH_SIZE, num_workers=1):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.Resize(IMAGE_SIZE),
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
                ),
            ]
        )

    def prepare_data(self):
        """Download both FashionMNIST splits into ``data_dir`` if missing."""
        FashionMNIST(self.data_dir, train=True, download=True)
        FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test" or None for both)."""
        # Assign train/val datasets
        if stage == "fit" or stage is None:
            mnist_full = FashionMNIST(self.data_dir, train=True, transform=self.transform)
            # The whole training split is used for training; the validation
            # subset is intentionally empty. The length is taken from the
            # dataset instead of hard-coding 60000.
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [len(mnist_full), 0]
            )

        # Assign test dataset
        if stage == "test" or stage is None:
            self.mnist_test = FashionMNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader, reshuffled at the start of every epoch."""
        # Without shuffle=True every epoch would replay the same batches in
        # the same order (random_split only permutes the indices once).
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )


In [4]:
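# Critic ("detective"): judges real vs. fake with a single output score. In WGAN this
# score is unbounded -- it is not squashed into [0, 1] by a sigmoid.
# Class labels are added as embeddings (reshaped into an extra input channel).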
class Discriminator(nn.Module):
    """Label-conditioned convolutional critic.

    The class label is looked up in an embedding of ``img_size * img_size``
    values, reshaped into one extra image plane and concatenated to the input
    as an additional channel. Five stride-2 convolutions follow; for 64x64
    inputs they reduce the map 64 -> 32 -> 16 -> 8 -> 4 -> 1. The last
    convolution has no activation, so the output is an unbounded score.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Channel count after the first convolution; later stages
            use 2x, 4x and 8x this value.
        num_classes: Number of distinct labels in the embedding table.
        img_size: Height/width of the input images (sizes the label plane).
    """

    def __init__(self, channels_img, features_d, num_classes, img_size):
        super().__init__()
        self.image_size = img_size

        # First stage has no normalisation; "+ 1" accounts for the label plane.
        stem = [
            nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three normalised down-sampling stages, doubling the width each time.
        widths = [features_d * factor for factor in (1, 2, 4, 8)]
        body = [
            self._block(w_in, w_out, 4, 2, 1)
            for w_in, w_out in zip(widths[:-1], widths[1:])
        ]
        # Final projection to a single score per sample (no sigmoid).
        head = [nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0)]

        self.disc = nn.Sequential(*stem, *body, *head)
        self.embed = nn.Embedding(num_classes, img_size * img_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free Conv2d -> affine InstanceNorm2d -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        # InstanceNorm instead of BatchNorm, so samples are normalised independently.
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x, labels):
        """Score images ``x`` given their integer class ``labels``."""
        n_samples = labels.shape[0]
        label_plane = self.embed(labels).view(
            n_samples, 1, self.image_size, self.image_size
        )
        conditioned = torch.cat([x, label_plane], dim=1)
        return self.disc(conditioned)

In [5]:
# Generate fake data: output shaped like the resized real images [channels_img, 64, 64], with values in [-1, 1]
class Generator(nn.Module):
    """Label-conditioned transposed-convolution generator.

    The label embedding is appended to the noise along the channel axis and
    the result is up-sampled 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels. A final
    Tanh bounds the output to [-1, 1].

    Args:
        latent_dim: Number of noise channels.
        channels_img: Number of channels of the generated images.
        features_g: Base width; the stages use 16x, 8x, 4x and 2x this value.
        num_classes: Number of distinct labels in the embedding table.
        img_size: Stored on the instance as ``img_size``; not used by the layers.
        embed_size: Length of the label embedding vector.
    """

    def __init__(self,
                 latent_dim,
                 channels_img,
                 features_g,
                 num_classes,
                 img_size,
                 embed_size,
                ):
        super().__init__()
        self.img_size = img_size

        # (width multiplier, stride, padding) for each up-sampling stage; the
        # first stage expands the 1x1 input to 4x4, the others double the size.
        stage_specs = ((16, 1, 0), (8, 2, 1), (4, 2, 1), (2, 2, 1))

        layers = []
        in_channels = latent_dim + embed_size
        for multiplier, stride, padding in stage_specs:
            out_channels = features_g * multiplier
            layers.append(self._block(in_channels, out_channels, 4, stride, padding))
            in_channels = out_channels

        # Last up-sampling step maps to image channels (32x32 -> 64x64).
        layers.append(
            nn.ConvTranspose2d(in_channels, channels_img, kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)
        self.embed = nn.Embedding(num_classes, embed_size)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a bias-free ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x, labels):
        """Generate images from noise ``x`` (N x latent_dim x 1 x 1) and ``labels``."""
        # Give the embedding two trailing singleton axes so it matches the noise layout.
        label_vec = self.embed(labels)[:, :, None, None]
        conditioned = torch.cat([x, label_vec], dim=1)
        return self.gen(conditioned)

In [6]:
def initialize_weights(model):
    """Apply DCGAN-style initialisation to every layer of ``model`` in place.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale is drawn from N(1, 0.02) and its shift set to 0.
      Centring the scale on 0 instead would shrink every normalised
      activation by roughly a factor of 50 at the start of training.

    BatchNorm layers created with ``affine=False`` have no parameters and are
    skipped. Other module types are left at their defaults.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [7]:
class GAN(pl.LightningModule):
    """Conditional WGAN with gradient penalty, trained with manual optimization.

    The critic (``self.discriminator``) is updated on every batch with the
    loss ``E[critic(fake)] - E[critic(real)] + lambda_gp * GP``. The generator
    is updated once every ``n_critic`` batches with ``-E[critic(fake)]``.

    Args:
        latent_dim: Number of noise channels fed to the generator.
        lr: Learning rate shared by both Adam optimizers.
        b1: Adam beta1.
        b2: Adam beta2.
        batch_size: Stored as a hyper-parameter; not read by this class.
        n_critic: Number of critic updates per generator update.
        lambda_gp: Weight of the gradient-penalty term.
        **kwargs: Extra values recorded by ``save_hyperparameters``.
    """

    def __init__(self,
                 latent_dim=100,
                 lr=LEARNING_RATE,
                 b1: float = 0.0,
                 b2: float = 0.9,
                 batch_size: int = BATCH_SIZE,
                 n_critic: int = N_CRITIC,
                 lambda_gp: float = LAMBDA_GP,
                 **kwargs,):
        super().__init__()
        # Two optimizers with a custom schedule -> the steps are driven by hand.
        self.automatic_optimization = False
        self.save_hyperparameters()

        self.generator = Generator(latent_dim=self.hparams.latent_dim,
                                   channels_img=CHANNELS_IMG,
                                   features_g=FEATURES_GEN,
                                   num_classes=NUM_CLASSES,
                                   img_size=IMAGE_SIZE,
                                   embed_size=GEN_EMBEDDING)
        self.discriminator = Discriminator(channels_img=CHANNELS_IMG,
                                           features_d=FEATURES_DISC,
                                           num_classes=NUM_CLASSES,
                                           img_size=IMAGE_SIZE)

        initialize_weights(self.generator)
        initialize_weights(self.discriminator)

        # Fixed noise so the images logged after each epoch are comparable.
        self.validation_z = torch.rand(6, self.hparams.latent_dim, 1, 1)

    def forward(self, z, labels):
        """Generate images from noise ``z`` conditioned on ``labels``."""
        return self.generator(z, labels)

    def gradient_penalty(self, labels, real, fake, device=None):
        """Return the WGAN-GP penalty ``mean((||grad critic(x_hat)||_2 - 1)^2)``.

        ``x_hat`` is a per-sample random interpolation between ``real`` and
        ``fake``.

        Args:
            labels: Class labels passed to the critic with the interpolates.
            real: Batch of real images (N x C x H x W).
            fake: Batch of generated images, same shape as ``real``.
            device: Device for the interpolation weights. Defaults to
                ``real.device`` so the tensors always live together.

        Returns:
            Scalar tensor that can be back-propagated through.
        """
        if device is None:
            device = real.device

        n_samples = real.shape[0]
        # One mixing weight per sample; broadcasting spreads it over C, H, W.
        epsilon = torch.rand((n_samples, 1, 1, 1), device=device, dtype=real.dtype)
        interpolated_images = (real * epsilon + fake * (1 - epsilon)).requires_grad_(True)

        # calculate discriminator's score
        mixed_scores = self.discriminator(interpolated_images, labels)

        # create_graph=True keeps the gradient differentiable, because the
        # penalty itself is part of the critic loss.
        gradient = torch.autograd.grad(
            inputs=interpolated_images,
            outputs=mixed_scores,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradient = gradient.reshape(gradient.shape[0], -1)  # flatten all other dims
        gradient_norm = gradient.norm(2, dim=1)  # per-sample L2 norm
        return torch.mean((gradient_norm - 1) ** 2)

    def training_step(self, batch, batch_idx=0):
        """Run one critic update and, every ``n_critic`` batches, one generator update.

        Args:
            batch: Tuple ``(real_imgs, labels)``.
            batch_idx: Index of the batch within the epoch. The generator is
                updated when ``batch_idx % n_critic == 0``, so a direct call
                without an index updates both networks.
        """
        opt_g, opt_d = self.optimizers()
        real_imgs, labels = batch
        n_critic = max(1, int(self.hparams.n_critic))

        # sample noise directly on the device / dtype of the batch
        z = torch.rand(
            real_imgs.shape[0], self.hparams.latent_dim, 1, 1,
            device=real_imgs.device, dtype=real_imgs.dtype,
        )

        # train critic: max E[critic(real)] - E[critic(fake)]
        self.toggle_optimizer(opt_d)
        # One generator pass serves both the critic score and the penalty;
        # detach so no gradient reaches the generator.
        fake_imgs = self(z, labels).detach()
        y_hat_real = self.discriminator(real_imgs, labels).reshape(-1)
        y_hat_fake = self.discriminator(fake_imgs, labels).reshape(-1)
        gp = self.gradient_penalty(labels=labels, real=real_imgs, fake=fake_imgs)
        d_loss = torch.mean(y_hat_fake) - torch.mean(y_hat_real) + self.hparams.lambda_gp * gp

        opt_d.zero_grad()
        # The graph is not needed after this backward pass, so it is not retained.
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)
        self.log("d_loss", d_loss, prog_bar=True)

        # train generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        if batch_idx % n_critic == 0:
            self.toggle_optimizer(opt_g)
            y_hat = self.discriminator(self(z, labels), labels).reshape(-1)
            g_loss = -torch.mean(y_hat)

            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()
            self.untoggle_optimizer(opt_g)
            self.log("g_loss", g_loss, prog_bar=True)

    def configure_optimizers(self):
        """Return the generator and critic Adam optimizers, in that order.

        The critic/generator update ratio is applied in ``training_step``;
        a per-optimizer ``frequency`` entry has no effect under manual
        optimization.
        """
        lr = self.hparams.lr
        betas = (self.hparams.b1, self.hparams.b2)
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        return [opt_g, opt_d], []

    def on_train_epoch_end(self):
        """Log a grid of images generated from the fixed validation noise."""
        experiment = getattr(self.logger, "experiment", None)
        # Only loggers exposing ``add_image`` (e.g. TensorBoard) can take the grid.
        if not hasattr(experiment, "add_image"):
            return

        # Use the module's own device, which is where Lightning put the weights.
        z = self.validation_z.to(self.device)
        labels = torch.tensor([0, 1, 2, 9, 4, 8], device=self.device)

        # log sampled images; no autograd graph is needed for logging
        with torch.no_grad():
            sample_imgs = self(z, labels)
        grid_fake = torchvision.utils.make_grid(sample_imgs, normalize=True)
        experiment.add_image("generated_images", grid_fake, self.current_epoch)
        

# New Section

In [9]:
!pip install wandb

Collecting wandb
  Obtaining dependency information for wandb from https://files.pythonhosted.org/packages/35/d3/6bfe29e4ba1eb2400d478caf8e3af9a1c366390390069cda59a7c6bf6063/wandb-0.16.1-py3-none-any.whl.metadata
  Downloading wandb-0.16.1-py3-none-any.whl.metadata (9.8 kB)
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Obtaining dependency information for GitPython!=3.1.29,>=1.0.0 from https://files.pythonhosted.org/packages/8d/c4/82b858fb6483dfb5e338123c154d19c043305b01726a67d89532b8f8f01b/GitPython-3.1.40-py3-none-any.whl.metadata
  Downloading GitPython-3.1.40-py3-none-any.whl.metadata (12 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Obtaining dependency information for sentry-sdk>=1.0.0 from https://files.pythonhosted.org/packages/aa/39/a40c841782b775eec1602a82387b4e91322ccafd842fd60fc4deb9f13f7d/sentry_sdk-1.39.1-py2.py3-none-any.whl.metadata
  Downloading sentry_sdk-1.39.1-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading 

In [8]:
# Weights & Biases client; needs the wandb package to be installed in this kernel.
import wandb
# Authenticate this session with W&B (presumably prompts for an API key when
# none is stored -- TODO confirm).
wandb.login()

ModuleNotFoundError: No module named 'wandb'

In [None]:
# Start a W&B run in the "GAN" project; sync_tensorboard=True asks wandb to
# pick up what is written to TensorBoard (this notebook logs through it).
wandb.init(project="GAN", sync_tensorboard=True)

In [None]:
# Build the data module and the GAN with their default hyper-parameters
# (taken from the constants in the configuration cell).
dm = MNISTDataModule()
model = GAN()

In [None]:
# Train for NUM_EPOCHS epochs; accelerator="auto" lets Lightning choose the hardware.
trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=NUM_EPOCHS,
)
trainer.fit(model, dm)

In [None]:
# Load the TensorBoard notebook extension and open it on the lightning_logs/
# directory to inspect the logged losses and generated images.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Close the W&B run that was opened with wandb.init.
wandb.finish()