In [219]:
import os
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("\data", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)

In [220]:
class MNISTDataModule(L.LightningDataModule):
    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.13,), (0.31,)), #slight deviation from tutorial
            ]
        )
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        MNIST(self.data_dir, train = True, download = True)
        MNIST(self.data_dir, train = False, download = True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train = True, transform = self.transform)
            val_size = int(len(mnist_full) * 0.1)
            train_size = len(mnist_full) - val_size
            self.mnist_train, self.mnist_val = random_split(mnist_full, [54016, 5984]) #This split because % 64 = 0

        if stage == "test" or stage is None:
            self.minst_test = MNIST(self.data_dir, train = False, transform = self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size = self.batch_size, num_workers = self.num_workers)

    # Gan does not use validation set but included for learning lightning purposes
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size = self.batch_size, num_workers = self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size = self.batch_size, num_workers = self.num_workers)

In [221]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape, classes):
        super().__init__()
        self.img_shape = img_shape
        self.classes = classes
        self.label_embedding = nn.Embedding(self.classes, self.classes)

        self.model = nn.Sequential(
            *self._create_layer(latent_dim + self.classes, 128, False),
            *self._create_layer(128, 256),
            *self._create_layer(256, 512),
            *self._create_layer(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def _create_layer(self, size_in, size_out, normalize=True):
        layers = [nn.Linear(size_in, size_out)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, z, c):
        input = torch.cat((self.label_embedding(c), z), -1)
        img = self.model(input)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [222]:
class Discriminator(nn.Module):
    def __init__(self, img_shape, classes):
        super().__init__()
        self.classes = classes
        self.label_embedding = nn.Embedding(self.classes, self.classes)

        self.model = nn.Sequential(
            *self._create_layer(self.classes + int(np.prod(img_shape)), 2048, False, True),
            *self._create_layer(2048, 1024, True, True),
            *self._create_layer(1024, 512, True, True),
            *self._create_layer(512, 256, True, True),
            *self._create_layer(256, 128, False, False),
            *self._create_layer(128, 1, False, False),
            nn.Sigmoid()
        )

    def _create_layer(self, size_in, size_out, drop_out=True, act_func=True):
        layers = [nn.Linear(size_in, size_out)]
        if drop_out:
            layers.append(nn.Dropout(0.6)) 
        if act_func:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    #Increased dropout 0.4 -> 0.6 and added a 2048 -> 1024 layer seem to reduce noise but increase mode collapse
    #Added 4096 -> 2048 layer reduced surrounding noise even more but sometimes created inverted images
    #3096 -> 2048 layer just bad
    
    def forward(self, img, c):
        img_flat = img.view(img.size(0), -1)
        input = torch.cat((img_flat, self.label_embedding(c)), -1)
        validity = self.model(input)
        return validity

In [223]:
class GAN(L.LightningModule):
    def __init__(
        self, 
        channels, 
        width, 
        height, 
        latent_dim: int = 100, 
        lr: float = 1e-3,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        loss_type: str = "BCE",
        classes: int = 10
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape, classes=classes)
        self.discriminator = Discriminator(img_shape=data_shape, classes=classes)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        
        self.loss_type = loss_type
        self.classes = classes
        self.batch_size = batch_size

    def forward(self, z, c):
        return self.generator(z, c)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        imgs, labels = batch
        optimizer_g, optimizer_d = self.optimizers()

        # create noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)
        fake_labels = torch.randint(0, self.classes, (self.batch_size,), device=device)
        fake_labels = torch.full_like(fake_labels, fill_value=5, device=device)
        
        #train generator and generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z, fake_labels)


        # add sampled images to log
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        if self.loss_type == "vanilla":
            g_loss = -torch.mean(torch.log(self.discriminator(self(z, fake_labels), fake_labels)))
        else:
            g_loss = self.adversarial_loss(self.discriminator(self(z, fake_labels), fake_labels), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs, labels), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self(z, fake_labels).detach(), fake_labels), fake)

        if self.loss_type == "vanilla":
            d_loss = -torch.mean(torch.log(self.discriminator(imgs, labels)) + torch.log(1. - fake))
        else:
            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        # Origninal vanilla implemtaton did not set ADAM betas so skip that here
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        # opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr)
        # opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [opt_g, opt_d], []

In [224]:
dm = MNISTDataModule()
model = GAN(*dm.dims, loss_type = "BCE")
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=5,
)
trainer.fit(model, dm)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 1.5 M 
1 | discriminator | Discriminator | 14.4 M
------------------------------------------------
15.9 M    Trainable params
0         Non-trainable params
15.9 M    Total params
63.783    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/844 [97:42:10<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:40:16<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:32:28<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:31:32<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:31:04<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:30:47<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:30:09<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:29:33<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:28:54<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:28:41<?, ?it/s]
Epoch 0:   0%|          | 0/844 [97:24:17<?, ?it/s]
Epoch 0:   0%|          | 0/844 [96:47:00<?, ?it/s]
Epoch 4: 100%|██████████| 844/844 [00:40<00:00, 20.63it/s, v_num=140, g_loss=37.50, d_loss=31.20]  

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 844/844 [00:41<00:00, 20.53it/s, v_num=140, g_loss=37.50, d_loss=31.20]


In [226]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 68702), started 2 days, 18:34:53 ago. (Use '!kill 68702' to kill it.)