In [1]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets



In [2]:
# ---- Data loading ----------------------------------------------------------
# Side length (pixels) every training image is resized / cropped to by the
# preprocessing transform.
image_size = 64
# Samples per training batch.
batch_size = 128
# Worker processes used by the DataLoader.
workers = 4

# ---- Model architecture ----------------------------------------------------
# Channels in the training images (3 for colour images, 1 for grayscale).
nc = 1
# Length of the latent vector z fed to the generator.
nz = 100
# Base width of the generator feature maps.
ngf = 64
# Base width of the discriminator feature maps.
ndf = 64

# ---- Optimisation ----------------------------------------------------------
# Number of passes over the training set.
num_epochs = 2
# Learning rate shared by both Adam optimizers.
lr = 2e-4
# First-moment decay (beta1) of both Adam optimizers.
beta1 = 0.5

# Device selection is done where the Trainer is created, not here.

In [3]:
class DCGAN(pl.LightningModule):
    """DCGAN (generator + discriminator) trained with manual optimization.

    The layer widths (``nz``, ``ngf``, ``ndf``, ``nc``) and the optimizer
    settings (``lr``, ``beta1``) are read from the module-level configuration
    variables. The strided layers below are laid out for 64 x 64 images.
    """

    def __init__(self):
        """Build both networks and apply the DCGAN weight initialisation."""
        super().__init__()

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

        # The discriminator ends in a Sigmoid, so plain BCE on probabilities.
        self.criterion = nn.BCELoss()

        self.generator = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
        self.generator.apply(self.weights_init)

        self.discriminator = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        self.discriminator.apply(self.weights_init)

    def forward(self, x):
        """Run the generator on a batch of latent vectors.

        Args:
            x: latent batch; ``training_step`` passes shape ``(N, nz, 1, 1)``.

        Returns:
            The generator output for ``x`` (values in [-1, 1] due to Tanh).
        """
        fake = self.generator(x)
        return fake

    def training_step(self, batch, batch_idx):
        """Perform one discriminator update followed by one generator update.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            None. Both optimizers are stepped here (manual optimization) and
            the two losses are logged as ``loss_d`` and ``loss_g``.
        """
        optim_generator, optim_discriminator = self.optimizers()

        X, _ = batch
        # Use the actual batch length: the last batch of an epoch may be short.
        n_samples = X.shape[0]

        real_label = torch.ones((n_samples,), device=self.device)   # 1 means image considered real
        fake_label = torch.zeros((n_samples,), device=self.device)  # 0 means image considered fake

        # Generate fake images with generator from batch of latent vectors
        noise = torch.randn(n_samples, nz, 1, 1, device=self.device)
        fake = self(noise)  # uses forward function

        # DISCRIMINATOR: push real images towards 1 and generated ones towards 0.
        d_real = self.discriminator(X).view(-1)
        loss_d_real = self.criterion(d_real, real_label)

        d_fake = self.discriminator(fake.detach()).view(-1)  # detach to not backprop through generator
        loss_d_fake = self.criterion(d_fake, fake_label)

        loss_d = loss_d_real + loss_d_fake

        optim_discriminator.zero_grad()
        self.manual_backward(loss_d)
        optim_discriminator.step()

        # GENERATOR: re-score the same fakes with the just-updated discriminator
        # and reward the generator when they are classified as real.
        d_fake = self.discriminator(fake).view(-1)
        loss_g = self.criterion(d_fake, real_label)

        optim_generator.zero_grad()
        # This backward pass also leaves gradients on the discriminator's
        # parameters; they are cleared by zero_grad() at the start of the next
        # discriminator update, before they could be applied.
        self.manual_backward(loss_g)
        optim_generator.step()

        # Nothing is returned in manual optimization, so log explicitly to make
        # the training dynamics visible in the progress bar and the logger.
        self.log_dict(
            {"loss_d": loss_d.detach(), "loss_g": loss_g.detach()},
            prog_bar=True,
        )

    # def validation_step(self, batch, batch_idx):

    def configure_optimizers(self):
        """Create one Adam optimizer per network.

        Returns:
            ``(generator_optimizer, discriminator_optimizer)``; the order must
            match the unpacking of ``self.optimizers()`` in ``training_step``.
        """
        optim_generator = optim.Adam(self.generator.parameters(), lr=lr, betas=(beta1, 0.999))
        optim_discriminator = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
        return optim_generator, optim_discriminator

    def weights_init(self, m):
        """Initialise a single module in place; meant for ``nn.Module.apply``.

        Modules whose class name contains ``Conv`` get weights drawn from
        N(0, 0.02); those containing ``BatchNorm`` get weights from N(1, 0.02)
        and zero bias. Every other module is left untouched.

        Args:
            m: the module visited by ``apply``.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)


In [4]:
# MNIST digits, resized to image_size x image_size and rescaled from [0, 1]
# to [-1, 1] so real images share the value range of the generator's Tanh.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    # One-element tuples (note the trailing commas): per-channel mean / std
    # for single-channel images. `(0.5)` is just the float 0.5.
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(
    root='../data',
    train=True,      # training split (also the default), stated explicitly
    download=True,   # fetch on first run; skipped when the files already exist
    transform=transform,
)

# Create the dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [5]:
# Seed the random number generators before the model is built so that weight
# initialisation, batch shuffling and latent noise repeat across runs.
pl.seed_everything(42)

model = DCGAN()
trainer = pl.Trainer(
    # Train on one GPU when CUDA is available, otherwise fall back to the CPU
    # instead of failing on machines without a GPU.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=num_epochs,
)
trainer.fit(model, dataloader)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type       | Params
---------------------------------------------
0 | criterion     | BCELoss    | 0     
1 | generator     | Sequential | 3.6 M 
2 | discriminator | Sequential | 2.8 M 
---------------------------------------------
6.3 M     Trainable params
0         Non-trainable params
6.3 M     Total params
25.353    Total estimated model params size (MB)


Epoch 1: 100%|██████████| 469/469 [00:28<00:00, 16.71it/s, v_num=63]
