<a href="https://colab.research.google.com/github/FurqanBhat/GANs/blob/main/WGANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch Lightning, which is imported below as `pl`.
!pip install pytorch-lightning

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch.autograd as autograd

import os




In [None]:
# Global configuration for the notebook.
random_seed = 42
# Seed torch's global RNG (used e.g. by random_split and weight init below).
torch.manual_seed(random_seed)

BATCH_SIZE = 128
# Use at most one GPU; 0 when CUDA is unavailable.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Half the CPU cores for DataLoader workers. os.cpu_count() may return None
# when the count cannot be determined; fall back to 0 (load in-process).
NUM_WORKERS = (os.cpu_count() or 0) // 2

In [None]:
class MNISTDataModule(pl.LightningDataModule):
  """LightningDataModule serving normalised MNIST digits.

  The 60k MNIST training images are randomly split 55000/5000 into
  train/val; the MNIST test split is used for testing.

  Args:
    data_dir: Directory MNIST is downloaded to and read from.
    batch_size: Batch size used by every DataLoader.
    num_workers: Number of worker processes used by every DataLoader.
  """

  def __init__(self, data_dir="./data", batch_size=BATCH_SIZE,
               num_workers=NUM_WORKERS):
    super().__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.num_workers = num_workers

    # Tensor conversion followed by Normalize(mean, std) with the standard
    # MNIST statistics.
    self.transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

  def prepare_data(self):
    """Download the MNIST train and test splits into ``data_dir``."""
    MNIST(self.data_dir, train=True, download=True)
    MNIST(self.data_dir, train=False, download=True)

  def setup(self, stage=None):
    """Create the datasets required for ``stage``.

    ``"fit"``, ``"validate"`` or ``None`` builds the train/val split;
    ``"test"`` or ``None`` builds the test set.
    """
    # Assign train/val datasets ("validate" needs mnist_val as well).
    if stage in ("fit", "validate", None):
      mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
      self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    # Assign test dataset.
    if stage in ("test", None):
      self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

  def train_dataloader(self):
    """Return the training DataLoader, reshuffled every epoch."""
    # Shuffle so each epoch visits the training images in a new order.
    return DataLoader(self.mnist_train, batch_size=self.batch_size,
                      shuffle=True, num_workers=self.num_workers)

  def val_dataloader(self):
    """Return the validation DataLoader (fixed order)."""
    return DataLoader(self.mnist_val, batch_size=self.batch_size,
                      num_workers=self.num_workers)

  def test_dataloader(self):
    """Return the test DataLoader (fixed order)."""
    return DataLoader(self.mnist_test, batch_size=self.batch_size,
                      num_workers=self.num_workers)



In [None]:
class Discriminator(nn.Module):
  """Small CNN used as the WGAN critic for single-channel 28x28 images.

  ``forward`` returns one unbounded score per image with shape (n, 1);
  there is deliberately no sigmoid, as WGAN critics output raw scores.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    # 20 channels * 4 * 4 spatial positions left after the second pooling.
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 1)

  def forward(self, x):
    """Score a batch of images ``x`` (expected shape (n, 1, 28, 28))."""
    x = F.relu(F.max_pool2d(self.conv1(x), 2))  # (28,28)->(24,24)->(12,12)
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # (12,12)->(8,8)->(4,4)
    # Keep the batch dimension explicit: view(-1, 320) would silently regroup
    # features across samples for inputs that are not 28x28; flatten(1)
    # instead lets fc1 fail loudly on a size mismatch.
    x = x.flatten(1)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    # Raw, unbounded critic score (no sigmoid) as required by the WGAN loss.
    return self.fc2(x)



In [None]:
# Generator: maps latent vectors to single-channel 28x28 images. There is no
# output activation (e.g. tanh), so pixel values are unbounded.
class Generator(nn.Module):
    """Upsample a batch of latent vectors to images of shape (n, 1, 28, 28).

    Args:
        latent_dim: Length of each latent vector passed to ``forward``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Project to 64 feature maps of 7x7, reshaped in forward().
        self.lin1 = nn.Linear(latent_dim, 7 * 7 * 64)
        # Transposed conv: (7 - 1) * 2 + 4 = 16   -> [n, 32, 16, 16]
        self.ct1 = nn.ConvTranspose2d(64, 32, 4, stride=2)
        # Transposed conv: (16 - 1) * 2 + 4 = 34  -> [n, 16, 34, 34]
        self.ct2 = nn.ConvTranspose2d(32, 16, 4, stride=2)
        # Valid conv: 34 - 7 + 1 = 28             -> [n, 1, 28, 28]
        self.conv = nn.Conv2d(16, 1, kernel_size=7)

    def forward(self, x):
        """Map latent batch ``x`` of shape (n, latent_dim) to (n, 1, 28, 28)."""
        features = F.relu(self.lin1(x)).view(-1, 64, 7, 7)
        for upsample in (self.ct1, self.ct2):
            features = F.relu(upsample(features))
        # Final 7x7 valid convolution down to one 28x28 channel.
        return self.conv(features)

In [None]:
class WGAN(pl.LightningModule):
    """Wasserstein GAN with gradient penalty (WGAN-GP).

    Uses manual optimisation: every batch performs one critic update, and
    batches with ``batch_idx % n_critic == 0`` additionally perform one
    generator update.

    Args:
        latent_dim: Length of the generator's latent vector.
        lr: Adam learning rate for both generator and critic (kept low for
            WGAN training).
        gp_lambda: Weight of the gradient-penalty term in the critic loss.
        n_critic: Generator update period in batches (must be >= 1; the
            default 5 matches the previous hard-coded schedule).
    """

    def __init__(self, latent_dim=100, lr=0.0001, gp_lambda=10, n_critic=5):
        super().__init__()
        if n_critic < 1:
            raise ValueError(f"n_critic must be >= 1, got {n_critic}")

        self.save_hyperparameters()

        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        # Used as a WGAN critic: outputs an unbounded score, not a probability.
        self.discriminator = Discriminator()

        # Fixed latent batch (6 x latent_dim) for visualising progress. Drawn
        # from the same standard normal the training step samples from
        # (torch.randn); uniform torch.rand would show outputs for a latent
        # region the generator is never trained on.
        self.validation_z = torch.randn(6, self.hparams.latent_dim)

        # Two optimisers with a custom update schedule -> manual optimisation.
        self.automatic_optimization = False

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Return the WGAN-GP penalty mean((||grad D(x_hat)||_2 - 1) ** 2).

        ``x_hat`` is a random point on the straight line between each real
        sample and the fake sample at the same batch index.

        Args:
            real_samples: Batch of real images (assumed (n, 1, 28, 28)).
            fake_samples: Batch of generated images with the same shape.

        Returns:
            Scalar tensor. The graph is kept (``create_graph=True``) so the
            penalty can be back-propagated into the critic.
        """
        # One interpolation weight per sample, uniform in [0, 1). Normal
        # weights (torch.randn) would produce points off the real-fake
        # segment, which is not where WGAN-GP enforces the gradient norm.
        alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_samples.device,
                           dtype=real_samples.dtype)
        interpolates = (alpha * real_samples
                        + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = self.discriminator(interpolates)
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.flatten(1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def forward(self, z):
        """Generate images from latent batch ``z`` via the generator."""
        return self.generator(z)

    def training_step(self, batch, batch_idx):
        """Run one critic update and, every ``n_critic`` batches, one generator update."""
        optim_g, optim_d = self.optimizers()

        real_imgs, _ = batch
        batch_size = real_imgs.size(0)

        # ---- Critic ----
        # zero_grad here also clears critic grads accumulated by the previous
        # generator backward pass.
        optim_d.zero_grad()
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        # Detached so the critic loss does not back-propagate into the generator.
        fake_imgs = self(z).detach()

        real_validity = self.discriminator(real_imgs)
        fake_validity = self.discriminator(fake_imgs)

        # Wasserstein critic loss: minimise D(fake) - D(real).
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)

        gradient_penalty = self.compute_gradient_penalty(real_imgs.detach(), fake_imgs)
        d_loss = d_loss + self.hparams.gp_lambda * gradient_penalty

        self.manual_backward(d_loss)
        optim_d.step()

        # ---- Generator ----
        if batch_idx % self.hparams.n_critic == 0:
            optim_g.zero_grad()
            z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
            fake_validity = self.discriminator(self(z))

            # Generator loss: maximise D(fake), i.e. minimise -D(fake).
            g_loss = -torch.mean(fake_validity)

            self.manual_backward(g_loss)
            optim_g.step()

            self.log_dict({"g_loss": g_loss}, prog_bar=True)

        self.log_dict({"d_loss": d_loss}, prog_bar=True)

    def configure_optimizers(self):
        """Return Adam optimisers in (generator, critic) order, as unpacked in training_step."""
        lr = self.hparams.lr
        optim_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.9))
        optim_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
        return [optim_g, optim_d], []

    def plot_imgs(self):
        """Plot generator samples for the fixed ``validation_z`` batch in a 2x3 grid."""
        # Match the generator's device and dtype.
        z = self.validation_z.type_as(self.generator.lin1.weight)
        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            sample_imgs = self(z).cpu()

        print(f"epoch {self.current_epoch}")
        plt.figure()
        for i in range(sample_imgs.size(0)):
            plt.subplot(2, 3, i + 1)
            # Drop the channel dimension: (1, 28, 28) -> (28, 28).
            plt.imshow(sample_imgs[i, 0, :, :], cmap="gray_r", interpolation="none")
            plt.title("Generated Data")
            plt.xticks([])
            plt.yticks([])
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    def on_train_epoch_end(self):
        """Lightning hook: show generator samples after every training epoch."""
        # Plot directly here instead of via a method named `on_epoch_end`:
        # that name is a legacy Lightning hook. Versions that still support
        # it invoke it themselves, which would plot twice per epoch; later
        # versions removed it.
        self.plot_imgs()

In [None]:
# Build the MNIST data pipeline and a WGAN-GP model, both with their defaults.
dm = MNISTDataModule()
model2 = WGAN()

In [None]:
# Preview generator samples before training starts (the model is untrained here).
model2.plot_imgs()

In [None]:
# Train on the (single) GPU when one was detected; otherwise fall back to one
# CPU device. accelerator="gpu" with devices=0 fails on machines without CUDA.
trainer = pl.Trainer(
    max_epochs=40,
    accelerator="gpu" if AVAIL_GPUS else "cpu",
    devices=AVAIL_GPUS if AVAIL_GPUS else 1,
)

In [None]:
# Run training; the data module supplies the train/val dataloaders.
trainer.fit(model=model2, datamodule=dm)
