In [2]:
# IMPORTING LIBRARIES

import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt
import pytorch_lightning as pl

# Seed PyTorch's global RNG so runs of this notebook are repeatable.
random_seed = 42
torch.manual_seed(random_seed)

# Mini-batch size used as the default for the data loaders below.
BATCH_SIZE = 128
# Use at most one GPU; evaluates to 0 when no CUDA device is visible.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Half of the logical cores. os.cpu_count() may return None when the count
# cannot be determined, so fall back to 1 instead of raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2

In [3]:
# CREATING DATA MODULE

class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that downloads MNIST and serves train/val/test loaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(self, data_dir="./data",
                 batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Convert PIL images to tensors, then normalise the single channel.
        # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std - confirm.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    def prepare_data(self):
        """Download the train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` builds the train/val datasets, ``"test"`` builds
                the test dataset, and ``None`` builds all of them.
        """
        # ASSIGNING TRAIN/VAL DATASETS
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # Hold out 5,000 of the training images for validation.
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # ASSIGNING TESTING DATASET
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the dataloader over the training subset."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the dataloader over the validation subset."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the dataloader over the test split."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [4]:
# DISCRIMINATOR: Detect if Data is Fake or Not -> 1 output [0, 1]

class Discriminator(nn.Module):
    """Small CNN that scores each input image with a single value in [0, 1].

    The flatten step in ``forward`` expects 320 features per sample
    (20 channels x 4 x 4), which is what 1x28x28 inputs produce.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; the second conv is followed by
        # channel-wise dropout.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Fully connected head ending in one output unit.
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        """Return a sigmoid score of shape ``[n, 1]`` for a batch of images ``x``."""
        # Stage 1: conv -> 2x2 max-pool -> ReLU
        feats = self.conv1(x)
        feats = F.relu(F.max_pool2d(feats, kernel_size=2))

        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU
        feats = self.conv2_drop(self.conv2(feats))
        feats = F.relu(F.max_pool2d(feats, kernel_size=2))

        # Collapse the feature maps to one 320-long vector per sample
        flat = feats.view(-1, 320)

        hidden = F.relu(self.fc1(flat))
        # Functional dropout has to be told the module's train/eval mode
        hidden = F.dropout(hidden, training=self.training)

        score = self.fc2(hidden)
        return torch.sigmoid(score)

In [None]:
# GENERATOR: Generates Fake Data

class Generator(nn.Module):
    """Maps latent vectors to single-channel 28x28 images.

    Args:
        latent_dim: Size of the latent vector fed to the first linear layer.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.lin1 = nn.Linear(latent_dim, 7*7*64)  # [n, 3136], reshaped to [n, 64, 7, 7] in forward
        self.ct1 = nn.ConvTranspose2d(64, 32, 4, stride=2)  # [n, 32, 16, 16]
        self.ct2 = nn.ConvTranspose2d(32, 16, 4, stride=2)  # [n, 16, 34, 34]
        self.conv = nn.Conv2d(16, 1, kernel_size=7)  # [n, 1, 28, 28]

    def forward(self, x):
        """Generate images from a batch of latent vectors.

        Args:
            x: Tensor of shape ``[n, latent_dim]``.

        Returns:
            Tensor of shape ``[n, 1, 28, 28]``; no output activation is applied.
        """
        # PASS LATENT SPACE INPUT IN LINEAR LAYER AND RESHAPE
        x = self.lin1(x)
        x = F.relu(x)
        x = x.view(-1, 64, 7, 7)  # 64 feature maps of 7x7

        # UPSAMPLE (TRANSPOSED CONV) TO 16X16 (32 FEATURE MAPS)
        x = self.ct1(x)
        x = F.relu(x)

        # UPSAMPLE TO 34X34 (16 FEATURE MAPS)
        x = self.ct2(x)
        x = F.relu(x)

        # CONVOLUTION TO 28X28 (1 FEATURE MAP)
        return self.conv(x)