In [17]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [18]:
random_seed = 42
torch.manual_seed(random_seed)

BATCH_SIZE = 128
GPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = int(os.cpu_count() / 2)

In [19]:
print(GPU)
print(NUM_WORKERS)

cpu
4


In [20]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, 
                 data_dir: str,
                 batch_size: int = BATCH_SIZE,
                 num_workers: int = NUM_WORKERS
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.137,), (0.3081, )),
            ]
        )
        self.dims = (1,28,28)
        self.num_classes = 10
    
    def prepare_data(self):
        # Downloading the datasets
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        
        # Assign test dataset for use in Dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )
        
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
    

In [21]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize = True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize == True:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [22]:
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
    
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        
        return validity

In [None]:
class GAN(pl.LightningModule):
    def __init__(self, latent_dim = 100, lr = 0.002):
        super().__init__()
        self.save_hyperparameters()
        
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator()
        
        # Random noise
        self.validation_z = torch.randn(6, self.hparams.latent_dim)
    
    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        pass
    
    def configure_optimizers(self):
        lr = self.hparams.lr
        opt_g = torch.optim.Adam(self.generator.p)
        pass


In [25]:
dm = MNISTDataModule(data_dir="../data/")
model = GAN(*dm.dims)
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=5,
)
trainer.fit(model, dm)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name          | Type          | Params | Mode  | In sizes | Out sizes     
------------------------------------------------------------------------------------
0 | generator     | Generator     | 1.5 M  | train | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K  | train | ?        | ?             
------------------------------------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.174     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 430/430 [00:13<00:00, 32.83it/s, v_num=1, g_loss=2.600, d_loss=0.0806] 

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 430/430 [00:13<00:00, 32.72it/s, v_num=1, g_loss=2.600, d_loss=0.0806]
