In [1]:
import os
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from apex import amp
import pytorch_lightning as pl
from matplotlib import pyplot as plt


In [2]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves MNIST as 1 x 64 x 64 tensors.

    Images are converted to tensors, resized and center-cropped to 64 x 64 and
    normalized with mean 0.5 / std 0.5. The 60000-image training set is split
    into 55000 training and 5000 validation samples.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.Normalize((0.5,), (0.5,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 64, 64)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are missing."""
        for is_train in (True, False):
            MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ('fit', 'test' or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with the batch size / worker settings of this module."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return self._make_loader(self.mnist_test)

In [3]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to a 64 x 64 image.

    Five transposed convolutions grow a ``N x latent_dim x 1 x 1`` input to
    ``N x img_shape[0] x 64 x 64``; the final ``Tanh`` bounds values to [-1, 1].

    Args:
        latent_dim: Number of channels of the latent input.
        img_shape: Sequence whose first entry is the number of image channels
            produced by the last layer and whose second entry is reused as the
            base feature-map count (the layers use 16x, 8x, 4x and 2x of it).
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        out_channels = img_shape[0]
        base = img_shape[1]

        # Layer order and indices are kept stable: other code reaches into
        # ``self.model[0].weight``.
        self.model = nn.Sequential(
            # N x latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, base * 16, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(base * 16),
            nn.ReLU(True),

            # N x (base*16) x 4 x 4
            nn.ConvTranspose2d(base * 16, base * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base * 8),
            nn.ReLU(True),

            # N x (base*8) x 8 x 8
            nn.ConvTranspose2d(base * 8, base * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.ReLU(True),

            # N x (base*4) x 16 x 16
            nn.ConvTranspose2d(base * 4, base * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base * 2),
            nn.ReLU(True),

            # N x (base*2) x 32 x 32
            nn.ConvTranspose2d(base * 2, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            # N x out_channels x 64 x 64
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from latent noise ``z`` of shape ``N x latent_dim x 1 x 1``."""
        return self.model(z)
    
class Discriminator(nn.Module):
    """Convolutional discriminator scoring images as real (1) or fake (0).

    For a 64 x 64 input the spatial size shrinks 64 -> 32 -> 16 -> 8 -> 5 -> 1,
    so each image yields a single sigmoid probability.

    Args:
        img_shape: Sequence whose first entry is the number of input image
            channels and whose second entry is reused as the base feature-map
            count (the layers use 1x, 2x, 4x and 8x of it).
    """

    def __init__(self, img_shape):
        super().__init__()
        in_channels = img_shape[0]
        base = img_shape[1]

        # Layer order and indices are kept stable for code indexing ``self.model``.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base, base * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base * 2, base * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base * 4, base * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(base * 8, 1, 4, 2, 0, bias=False),

            nn.Sigmoid()
        )

    def forward(self, img):
        """Return the real/fake probabilities for ``img`` flattened to one column."""
        scores = self.model(img)
        return scores.view(-1, 1)
    



In [4]:
 class GAN(pl.LightningModule):
    """Lightning module that trains a ``Generator`` against a ``Discriminator``.

    Two Adam optimizers are configured (index 0: generator, index 1:
    discriminator) and both networks are scored with binary cross-entropy.

    Args:
        channels, width, height: Image dimensions, forwarded to both networks
            as ``img_shape=(channels, width, height)``.
        latent_dim: Channel count of the ``N x latent_dim x 1 x 1`` noise input.
        lr: Learning rate shared by both Adam optimizers.
        b1: First Adam beta coefficient.
        b2: Second Adam beta coefficient.
        batch_size: Accepted and stored as a hyperparameter; not read in this class.
        **kwargs: Extra keyword arguments; not read in this class.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 64,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed noise reused at every epoch end so logged samples are comparable.
        self.validation_z = torch.randn(8, self.hparams.latent_dim, 1, 1)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim, 1, 1)

    def forward(self, z):
        """Generate images from latent noise ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted probabilities and target labels."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator (``optimizer_idx == 0``) or discriminator (``== 1``) step.

        Returns:
            An ``OrderedDict`` with the loss under ``'loss'`` and the same value
            under ``'progress_bar'`` / ``'log'``; ``None`` for any other index.
        """
        imgs, _ = batch
        batch_size = imgs.size(0)

        # Sample noise with the configured latent size (not a fixed 100);
        # type_as matches dtype and device of the real images.
        z = torch.randn(batch_size, self.hparams.latent_dim, 1, 1)
        z = z.type_as(imgs)

        # train generator
        if optimizer_idx == 0:

            # generate images once and reuse them for logging and for the loss
            self.generated_imgs = self(z)

            # log sampled images
            # NOTE: always written at step 0, so each call overwrites the last.
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs, padding=2, normalize=True)
            self.logger.experiment.add_image('generated_images', grid, 0)

            # The generator wants its fakes classified as real, so the target
            # is all ones; type_as matches dtype and device of the batch.
            valid = torch.ones(batch_size, 1)
            valid = valid.type_as(imgs)

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label real images as real?
            valid = torch.ones(batch_size, 1)
            valid = valid.type_as(imgs)

            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label generated images as fake?
            fake = torch.zeros(batch_size, 1)
            fake = fake.type_as(imgs)

            # detach so this loss does not backpropagate into the generator
            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    def configure_optimizers(self):
        """Return ``[generator_optimizer, discriminator_optimizer]`` and no schedulers."""
        lr = self.hparams.lr
        betas = (self.hparams.b1, self.hparams.b2)

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        return [opt_g, opt_d], []

    def on_epoch_end(self):
        """Log a grid of images generated from the fixed validation noise."""
        # The noise was created in __init__; match the generator weights' dtype/device.
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images; no gradients are needed for logging
        with torch.no_grad():
            sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs, padding=2, normalize=True)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
   

In [5]:

# Build the data module with its default arguments and show the image dims it reports.
dm = MNISTDataModule()
dm.size()

(1, 64, 64)

In [6]:
# Build the GAN from the data module's dims (unpacked as channels, width, height)
# and display the generator's layers.
model = GAN(*dm.size())
model.generator


Generator(
  (model): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()

In [17]:
# Inspect the weight shape of the generator's first transposed-convolution layer.
s=model.generator.model[0].weight
s.size()

torch.Size([100, 1024, 4, 4])

In [14]:
# Reproduce the cast GAN.on_epoch_end applies to its fixed noise: match the
# dtype/device of the generator's first-layer weights, then check the shape.
validation_z=torch.randn(8, 100,1, 1)
z = validation_z.type_as(model.generator.model[0].weight)
z.size()

torch.Size([8, 100, 1, 1])

In [12]:
# Display the weights of the discriminator's first convolution layer.
# NOTE(review): this prints the whole tensor; `.size()` would be a lighter check.
model.discriminator.model[0].weight


Parameter containing:
tensor([[[[ 0.2467,  0.0032, -0.1278, -0.1803],
          [ 0.1709,  0.0856,  0.0804,  0.1430],
          [-0.1906, -0.2415, -0.2464,  0.0902],
          [ 0.1265,  0.2290,  0.1105,  0.2384]]],


        [[[-0.1610,  0.2318,  0.0982,  0.0828],
          [-0.0387,  0.1072, -0.1932,  0.1758],
          [-0.1745, -0.2169,  0.1665, -0.2090],
          [ 0.1505, -0.0163,  0.0807,  0.0605]]],


        [[[-0.0284,  0.0714,  0.1753, -0.1356],
          [ 0.1821,  0.1455, -0.2499,  0.2424],
          [ 0.1388,  0.1589, -0.1242, -0.0056],
          [-0.0895, -0.1250, -0.0850,  0.0216]]],


        ...,


        [[[-0.0224, -0.1795, -0.0660,  0.0231],
          [-0.0548, -0.2251,  0.0346,  0.1126],
          [ 0.2393, -0.1878, -0.0976,  0.2489],
          [-0.2247,  0.1017,  0.2456, -0.1515]]],


        [[[-0.2315,  0.1054, -0.1532,  0.1417],
          [ 0.2332,  0.2095,  0.0579,  0.0191],
          [-0.1184,  0.0687,  0.1775,  0.0041],
          [ 0.1787, -0.0408,  0.004

In [None]:
# Draw a single latent sample in the N x latent_dim x 1 x 1 layout the generator takes.
x = torch.randn(1, 100,1, 1)
x.size()


In [None]:
# Run the generator once on the single latent sample and show the image it produces.
# Gradients are not needed for a preview, so the forward pass runs under no_grad.
with torch.no_grad():
    out = model(x)

# Drop the batch axis and collapse the single channel to a 64 x 64 array for plotting.
sample_img = out.squeeze(0).view(64, 64).numpy()
plt.imshow(sample_img)
plt.show()

In [None]:
# Random stand-in for one 1-channel 64 x 64 image, then display the discriminator's layers.
x1 = torch.randn(1, 1,64, 64)
model.discriminator


In [None]:
# Run the discriminator on the stand-in image and check the size of its output.
o=model.discriminator(x1)
o.size()




In [None]:
# Train for one epoch. Use one GPU when CUDA is available; otherwise fall back
# to CPU so the cell does not fail on machines without a GPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=1)
trainer.fit(model, dm)

In [None]:
# Load the TensorBoard notebook extension, then open TensorBoard on the
# lightning_logs/ directory, served on localhost.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/   --host localhost 
