<a href="https://colab.research.google.com/github/Expan75/team20-adverserial-artists/blob/dev-mnist/inspo_GAN_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GAN starter code
Corresponding tutorial: [https://youtu.be/_pIMdDWK5sc](https://youtu.be/_pIMdDWK5sc)

In [None]:
# Reinstall Lightning and torchmetrics so both packages come from the same clean install.
# `%pip` (rather than `!pip`) targets the environment of the running kernel, and the
# versions are pinned to the ones this notebook was last run with.
%pip uninstall -y pytorch_lightning torchmetrics
%pip install -q pytorch_lightning==2.2.2 torchmetrics==1.3.2

Found existing installation: pytorch-lightning 2.2.2
Uninstalling pytorch-lightning-2.2.2:
  Successfully uninstalled pytorch-lightning-2.2.2
Found existing installation: torchmetrics 1.3.2
Uninstalling torchmetrics-1.3.2:
  Successfully uninstalled torchmetrics-1.3.2
Collecting pytorch_lightning
  Using cached pytorch_lightning-2.2.2-py3-none-any.whl (801 kB)
Collecting torchmetrics
  Using cached torchmetrics-1.3.2-py3-none-any.whl (841 kB)
Installing collected packages: torchmetrics, pytorch_lightning
Successfully installed pytorch_lightning-2.2.2 torchmetrics-1.3.2


In [None]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt

import pytorch_lightning as pl


# Seed PyTorch's global RNG so weight initialisation and sampled noise are repeatable.
random_seed = 42
torch.manual_seed(random_seed)

# Default batch size for the data module.
BATCH_SIZE = 128
# 1 when at least one CUDA device is visible, otherwise 0.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Use half of the CPU cores for data loading. os.cpu_count() may return None when the
# core count cannot be determined, so fall back to a single core (-> 0 workers).
NUM_WORKERS = (os.cpu_count() or 1) // 2

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that downloads MNIST and serves train/val/test loaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(self, data_dir="./data",
                 batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Map pixels from [0, 1] to [-1, 1] so real images share the value range
                # of the Generator's Tanh output. The usual MNIST statistics
                # (0.1307, 0.3081) would put real pixels in roughly [-0.42, 2.82],
                # a range the generator can never produce.
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` (no datasets are kept)."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` (all of them when ``stage`` is None)."""
        # Train/val split; also built for "validate" so val_dataloader() has its dataset.
        if stage in (None, "fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Held-out test set
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; shuffled so every epoch sees a new batch order."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
# Discriminator (the "detective"): classifies an image as real or fake -> a single output in [0, 1]
class Discriminator(nn.Module):
    """Small CNN that scores an image as real or fake.

    The layer sizes are fixed for single-channel 28x28 inputs: two 5x5
    convolutions, each followed by 2x2 max-pooling, leave 20 * 4 * 4 = 320
    features for the fully connected head. ``forward`` returns one sigmoid
    probability per image, shape ``(N, 1)``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Fully connected classification head
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        # Stage 1: conv -> 2x2 max-pool -> ReLU
        features = self.conv1(x)
        features = F.relu(F.max_pool2d(features, 2))

        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))

        # Flatten to (N, 320) for the fully connected layers
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout is only active while the module is in training mode
        hidden = F.dropout(hidden, training=self.training)

        logits = self.fc2(hidden)
        return torch.sigmoid(logits)

In [None]:
# Generator: produces fake images shaped like the real data [1, 28, 28], with values in [-1, 1]
class Generator(nn.Module):
    """Image-conditioned generator: mixes an input image with latent noise.

    The image and the noise vector are each turned into a 16-channel feature
    map of the input's spatial size, concatenated along the channel axis
    (32 channels) and reduced back to ``image_channels`` by two convolutions.

    Args:
        image_channels: Number of channels of the input and output images.
        latent_dim: Length of the noise vector ``z``.
        image_size: Height and width of the (square) images.
    """

    def __init__(self, image_channels=1, latent_dim=100, image_size=28):
        super().__init__()
        # Number of feature maps each pathway contributes to the concatenation.
        feature_channels = 16

        # Image pathway: padding=1 with 3x3 kernels keeps the spatial size unchanged.
        self.image_path = nn.Sequential(
            nn.Conv2d(image_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, feature_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Noise pathway: the linear layer must emit exactly as many values as the
        # (feature_channels, image_size, image_size) map that Unflatten builds from
        # them. Emitting only image_size * image_size values makes Unflatten raise.
        self.noise_path = nn.Sequential(
            nn.Linear(latent_dim, feature_channels * image_size * image_size),
            nn.ReLU(),
            nn.Unflatten(1, (feature_channels, image_size, image_size)),
        )

        # Fuse both pathways (2 * feature_channels input channels) into an image.
        self.combine_path = nn.Sequential(
            nn.Conv2d(2 * feature_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, image_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # keeps the output in [-1, 1]
        )

    def forward(self, x, z):
        """Generate images from conditioning images and noise.

        Args:
            x: Images of shape ``(N, image_channels, image_size, image_size)``.
            z: Noise of shape ``(N, latent_dim)``.

        Returns:
            Tensor of the same shape as ``x`` with values in [-1, 1].
        """
        img_features = self.image_path(x)
        noise_features = self.noise_path(z)
        # Concatenate along the channel axis
        combined = torch.cat((img_features, noise_features), dim=1)
        return self.combine_path(combined)

In [None]:
# GAN
class GAN(pl.LightningModule):
    """Lightning module that trains a Generator against a Discriminator.

    Two optimizers are used (one per network). Lightning 2.x only supports
    that under manual optimization, so ``training_step`` runs both updates
    itself and takes no ``optimizer_idx`` argument.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        lr: Learning rate shared by both Adam optimizers.
    """

    def __init__(self, latent_dim=100, lr=0.0002):
        super().__init__()
        self.save_hyperparameters()
        # Required by Lightning 2.x whenever configure_optimizers returns several optimizers.
        self.automatic_optimization = False

        self.generator = Generator(latent_dim=self.hparams.latent_dim)
        self.discriminator = Discriminator()

        # Fixed noise, reused every epoch so the plotted samples are comparable.
        self.validation_z = torch.randn(6, self.hparams.latent_dim)

    def forward(self, img, z):
        """Generate images from conditioning images ``img`` and noise ``z``."""
        return self.generator(img, z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted probabilities and 0/1 targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Run one generator update followed by one discriminator update.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch within the epoch (unused).
        """
        real_imgs, _ = batch
        opt_g, opt_d = self.optimizers()
        n_imgs = real_imgs.size(0)

        # Noise and targets, created on the same device/dtype as the images
        z = torch.randn(n_imgs, self.hparams.latent_dim).type_as(real_imgs)
        y_real = torch.ones(n_imgs, 1).type_as(real_imgs)
        y_fake = torch.zeros(n_imgs, 1).type_as(real_imgs)

        # --- Generator: maximise log(D(G(x, z))) ---
        # toggle_optimizer freezes the parameters that belong to the other optimizer.
        self.toggle_optimizer(opt_g)
        fake_imgs = self(real_imgs, z)
        g_loss = self.adversarial_loss(self.discriminator(fake_imgs), y_real)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.untoggle_optimizer(opt_g)

        # --- Discriminator: maximise log(D(x)) + log(1 - D(G(x, z))) ---
        self.toggle_optimizer(opt_d)
        # How well can it label real images as real?
        real_loss = self.adversarial_loss(self.discriminator(real_imgs), y_real)
        # How well can it label generated images as fake? detach() keeps the
        # generator out of the discriminator's graph.
        fake_imgs = self(real_imgs, z).detach()
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), y_fake)
        d_loss = (real_loss + fake_loss) / 2
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)

        self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True)

    def configure_optimizers(self):
        """Return ``[generator_optimizer, discriminator_optimizer]`` and no schedulers."""
        lr = self.hparams.lr
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [opt_g, opt_d], []

    def _generate_validation_samples(self):
        """Generate images from the first validation images and the fixed noise.

        The data module is taken from a ``datamodule`` attribute set on the model
        if there is one, otherwise from the attached trainer. Its datasets must
        already have been built by ``setup``.
        """
        datamodule = getattr(self, "datamodule", None)
        if datamodule is None:
            datamodule = self.trainer.datamodule
        real_imgs, _ = next(iter(datamodule.val_dataloader()))

        # The generator needs one image per noise vector, so trim both to the same length.
        real_imgs = real_imgs[: self.validation_z.size(0)].to(self.device)
        z = self.validation_z[: real_imgs.size(0)].type_as(real_imgs)
        with torch.no_grad():
            return self(real_imgs, z).cpu()

    def plot_imgs(self, sample_imgs=None):
        """Show up to six generated images in a 2x3 grid.

        Args:
            sample_imgs: Images of shape ``(N, C, H, W)`` to plot; only channel 0
                is drawn. When omitted, fresh samples are generated from
                validation images and the fixed noise.
        """
        if sample_imgs is None:
            sample_imgs = self._generate_validation_samples()
        sample_imgs = sample_imgs.detach().cpu()

        print('epoch', self.current_epoch)
        plt.figure()
        # The grid has 2 x 3 = 6 cells, so never draw more than six images.
        for i in range(min(6, sample_imgs.size(0))):
            plt.subplot(2, 3, i + 1)
            plt.imshow(sample_imgs[i, 0, :, :], cmap='gray', interpolation='none')
            plt.title("Generated Data")
            plt.xticks([])
            plt.yticks([])
            plt.axis('off')
        plt.tight_layout()
        plt.show()

    def on_train_epoch_end(self):
        """Plot generated samples after every training epoch."""
        self.plot_imgs(self._generate_validation_samples())

    def on_epoch_end(self):
        """Legacy hook name kept for existing callers; forwards to ``on_train_epoch_end``."""
        self.on_train_epoch_end()




In [None]:
# Build the MNIST data module and the GAN with its default hyperparameters.
dm = MNISTDataModule(
    data_dir='path_to_data',
    batch_size=64,
    num_workers=4,
)
model = GAN()
# Keep a reference to the data module on the model itself.
model.datamodule = dm

In [None]:
# Train with Lightning for 20 epochs; the data module supplies the dataloaders.
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, datamodule=dm)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to path_to_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 178668107.35it/s]


Extracting path_to_data/MNIST/raw/train-images-idx3-ubyte.gz to path_to_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to path_to_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 114494984.71it/s]


Extracting path_to_data/MNIST/raw/train-labels-idx1-ubyte.gz to path_to_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to path_to_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 82263487.53it/s]

Extracting path_to_data/MNIST/raw/t10k-images-idx3-ubyte.gz to path_to_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to path_to_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20820250.02it/s]


Extracting path_to_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to path_to_data/MNIST/raw



INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


RuntimeError: Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx` argument from `training_step`, set `self.automatic_optimization = False` and access your optimizers in `training_step` with `opt1, opt2, ... = self.optimizers()`.

In [None]:
# Set number of epochs
num_epochs = 20

# Plain-PyTorch training loop (no Lightning Trainer involved).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Outside a Trainer nobody calls these hooks for us; without setup() the data
# module has no `mnist_train` dataset and train_dataloader() fails.
dm.prepare_data()
dm.setup(stage="fit")
train_loader = dm.train_dataloader()
num_steps = len(train_loader)

# configure_optimizers() returns ([generator_opt, discriminator_opt], schedulers)
opt_g, opt_d = model.configure_optimizers()[0]
generator, discriminator = model.generator, model.discriminator

# Training loop
for epoch in range(num_epochs):
    # Set model to training mode
    model.train()

    # Iterate over the training dataset
    for batch_idx, (real_imgs, _) in enumerate(train_loader):
        real_imgs = real_imgs.to(device)
        n_imgs = real_imgs.size(0)
        z = torch.randn(n_imgs, model.hparams.latent_dim, device=device)
        y_real = torch.ones(n_imgs, 1, device=device)
        y_fake = torch.zeros(n_imgs, 1, device=device)

        # Generator step: make the discriminator label generated images as real
        opt_g.zero_grad()
        g_loss = model.adversarial_loss(discriminator(generator(real_imgs, z)), y_real)
        g_loss.backward()
        opt_g.step()

        # Discriminator step: label real images as real and generated ones as fake.
        # zero_grad() also discards the gradients the generator step left on the
        # discriminator; detach() keeps this step from touching the generator.
        opt_d.zero_grad()
        fake_imgs = generator(real_imgs, z).detach()
        real_loss = model.adversarial_loss(discriminator(real_imgs), y_real)
        fake_loss = model.adversarial_loss(discriminator(fake_imgs), y_fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        opt_d.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Step {batch_idx+1}/{num_steps}, '
                  f"Generator Loss: {g_loss.item():.4f}, "
                  f"Discriminator Loss: {d_loss.item():.4f}")


AttributeError: 'MNISTDataModule' object has no attribute 'mnist_train'