In [36]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import numpy as np

random_seed = 42
torch.manual_seed(random_seed)
BATCH_SIZE = 64
AVAIL_GPUS = min(1, torch.cuda.device_count())
NUM_WORKERS = int(os.cpu_count()/2)

In [2]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./data', batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,),(0.3081)),
            ]
        )

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        if stage=='fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000,5000])
        
        if stage=='test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [55]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)  # Assuming input images are 28x28
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 1)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the input
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.fc3(x))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.fc4(x)
        return torch.sigmoid(x)

In [56]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 28*28)  # Assuming output images are 28x28

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = torch.tanh(self.fc4(x))  # Use tanh to output values between -1 and 1
        x = x.view(-1, 1, 28, 28)  # Reshape to image format
        return x


In [60]:
class GAN(pl.LightningModule):
    def __init__(self, latent_dim=100, lr=0.002):
        super().__init__()
        self.save_hyperparameters()
        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator()

        # random noise 
        self.validation_z = torch.randn(6, self.hparams.latent_dim)
        self.automatic_optimization = False

    def forward(self, z):
        return self.generator(z)
    
    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch, batch_idx):
        real_imgs, _ = batch

        z = torch.randn(real_imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(real_imgs)

        opt_g, opt_d = self.optimizers()
        self.toggle_optimizer(opt_g)

        fake_imgs = self(z)
        y_hat = self.discriminator(fake_imgs)

        y = torch.ones(real_imgs.size(0), 1)
        y = y.type_as(real_imgs)

        g_loss = self.adversarial_loss(y_hat, y)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.untoggle_optimizer(opt_g)
        
        log_dict = {"g_loss":g_loss}
        # return {"loss":g_loss, "progress_bar": log_dict, "log": log_dict}
        self.toggle_optimizer(opt_d)
        y_hat_real = self.discriminator(real_imgs)
        y_real = torch.ones(real_imgs.size(0), 1)
        y_real = y_real.type_as(real_imgs)

        real_loss = self.adversarial_loss(y_hat_real, y_real)
        y_hat_fake = self.discriminator(self(z).detach())
        y_fake = torch.zeros(real_imgs.size(0),1)
        y_fake = y_fake.type_as(real_imgs)

        fake_loss = self.adversarial_loss(y_hat_fake, y_fake)
        d_loss = (real_loss + fake_loss) / 2
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)

        log_dict = {"d_loss":d_loss}
        # return {"loss":d_loss, "progress_bar": log_dict, "log": log_dict}

        self.log_dict({"d_loss":d_loss, "g_loss":g_loss})

    def configure_optimizers(self):
        lr = self.hparams.lr
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [opt_g, opt_d], []
    
    def plot_imgs(self):
        z = self.validation_z.type_as(self.generator.lin1.weight)
        sample_imgs = self(z).cpu()
        print('Epoch ', self.current_epoch)
        for i in range(sample_imgs.size(0)):
            plt.subplot(2,3,i+1)
            plt.tight_layout()
            plt.imshow(sample_imgs.detach()[i,0,:,:].tolist(), cmap='gray_r', interpolation='none')
            plt.title('Generated Data')
            plt.xticks([])
            plt.yticks([])
            plt.axis('off')
        plt.show()
    
    def on_epoch_end(self):
        self.plot_imgs()

In [61]:
dm = MNISTDataModule()
model = GAN()
# model.plot_imgs()

In [62]:
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\robot\anaconda3\envs\gan_env\lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.

  | Name          | Type          | Params | Mode 
--------------------------------------------------------
0 | generator     | Generator     | 579 K  | train
1 | discriminator | Discriminator | 242 K  | train
--------------------------------------------------------
821 K     Trainable params
0         Non-trainable params
821 K     Total params
3.288     Total estimated model params size (MB)
c:\Users\robot\anaconda3\envs\gan_env\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 9: 100%|██████████| 860/860 [02:18<00:00,  6.19it/s, v_num=20]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 860/860 [02:18<00:00,  6.19it/s, v_num=20]
