<a href="https://colab.research.google.com/github/Reennon/gen-ai-cv-lab-1/blob/main/notebooks/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN CIFAR-10 Training and Experimentation

This notebook demonstrates training a GAN on the CIFAR-10 dataset using a modular training pipeline implemented in PyTorch Lightning.


In [1]:
!git clone https://github.com/Reennon/gen-ai-cv-lab-1.git
%cd gen-ai-cv-lab-1

Cloning into 'gen-ai-cv-lab-1'...
remote: Enumerating objects: 335, done.
remote: Counting objects: 100% (36/36), done.
remote: Compressing objects: 100% (28/28), done.
remote: Total 335 (delta 15), reused 28 (delta 8), pack-reused 299 (from 1)
Receiving objects: 100% (335/335), 326.76 KiB | 18.15 MiB/s, done.
Resolving deltas: 100% (152/152), done.
/content/gen-ai-cv-lab-1


In [2]:
!pip install -r requirements.txt

Collecting python-dotenv==1.0.1 (from -r requirements.txt (line 1))
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting pytorch-lightning==2.4.0 (from -r requirements.txt (line 3))
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting omegaconf==2.3.0 (from -r requirements.txt (line 6))
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting matplotlib==3.9.3 (from -r requirements.txt (line 7))
  Downloading matplotlib-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning==2.4.0->-r requirements.txt (line 3))
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning==2.4.0->-r requirements.txt (line 3))
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf==2.3.0->-r requirements.t

In [3]:
!git pull

Already up to date.


In [4]:
import os
import dotenv
import wandb
import torch

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from google.colab import userdata
from matplotlib import pyplot as plt

from src.visualization.gan_visualizer import GANVisualizer
from src.training.trainer import train_model
from src.models.gan import GAN


In [5]:
os.environ["WANDB_KEY"] = userdata.get("wandb_key")
!echo "WANDB_KEY=$WANDB_KEY" >> .env

In [6]:
dotenv.load_dotenv()

True
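
`dotenv.load_dotenv()` reads `NAME=value` lines from a `.env` file. The sketch below illustrates that file format using only the standard library. The helpers `write_env_var` and `read_env_file` are hypothetical names for illustration, the parser is deliberately simplified and is not python-dotenv's implementation, and the key value is a dummy.

```python
import os
import tempfile


def write_env_var(path, name, value):
    """Append one NAME=value line to a .env-style file."""
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(f"{name}={value}\n")


def read_env_file(path):
    """Parse simple NAME=value lines, skipping blanks, comments, and bare values."""
    values = {}
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            name, _, value = line.partition("=")
            values[name.strip()] = value.strip()
    return values


# Work in a temporary directory so no real .env file is touched.
with tempfile.TemporaryDirectory() as tmp:
    env_path = os.path.join(tmp, ".env")
    write_env_var(env_path, "WANDB_KEY", "dummy-key-for-illustration")
    loaded = read_env_file(env_path)

print(loaded)
```

A line that contains only the secret, with no `NAME=` prefix, would be skipped by this parser, which is why the variable name has to be written alongside the value.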

In [7]:
parameters = OmegaConf.load("./params/gan.yml")
wandb.login(key=os.environ["WANDB_KEY"])

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [8]:
wandb_project_name = "cifar-10-gan"
device = "gpu"

In [9]:
# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Per-channel mean/std of 0.5 maps [0, 1] to [-1, 1]
])

# Load datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:12<00:00, 13.5MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [10]:
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(torch.min(images), torch.max(images))

tensor(-1.) tensor(1.)
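
The printed range follows from `Normalize` with a mean and standard deviation of 0.5: `ToTensor` yields values in [0, 1], and each value `x` then becomes `(x - 0.5) / 0.5`. Below is a minimal pure-Python sketch of that arithmetic and its inverse, which is the step needed before displaying images that live in [-1, 1]. The function names are hypothetical and are not part of the repository.

```python
def normalize(x, mean=0.5, std=0.5):
    """Map a pixel value from [0, 1] to [-1, 1], mirroring Normalize(0.5, 0.5)."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Invert normalize(): map a value from [-1, 1] back to [0, 1]."""
    return y * std + mean


pixels = [0.0, 0.25, 0.5, 1.0]
normalized = [normalize(p) for p in pixels]
restored = [denormalize(v) for v in normalized]

print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
print(restored)    # [0.0, 0.25, 0.5, 1.0]
```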


In [11]:
images.shape

torch.Size([64, 3, 32, 32])
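
The shape reads as (batch, channels, height, width). As a rough sanity check on the loaders, the sketch below computes the number of batches per epoch, assuming the standard CIFAR-10 split of 50,000 training and 10,000 test images and the `DataLoader` default `drop_last=False`. The helper `batches_per_epoch` is a hypothetical name.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per pass over the dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


batch_size = 64
train_batches = batches_per_epoch(50_000, batch_size)  # last batch holds 16 images
val_batches = batches_per_epoch(10_000, batch_size)

# Number of float values in one full batch of shape (64, 3, 32, 32).
values_per_batch = batch_size * 3 * 32 * 32

print(train_batches, val_batches, values_per_batch)  # 782 157 196608
```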

In [12]:
hparams = parameters.hyperparameters

In [13]:
# Edit the hparams dict here to run an experiment; wandb will log the difference
# hparams["lr"] = 1e-4
# hparams["epochs"] = 20
# hparams['latent_dim'] = 48

parameters.run_parameters.experiment_name = f"gan-{hparams['latent_dim']}-lr-{hparams['lr']}-epochs-{hparams['epochs']}"

dict(hparams)

{'lr': 0.0002, 'epochs': 50, 'latent_dim': 100}
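
The run name is derived from the hyperparameters, so each override produces a distinct name in wandb. A minimal sketch of that naming scheme follows, with a plain dict standing in for the OmegaConf node. The function `experiment_name` is a hypothetical helper, not part of the repository.

```python
def experiment_name(hparams):
    """Build a run name from the hyperparameters, using the notebook's pattern."""
    return f"gan-{hparams['latent_dim']}-lr-{hparams['lr']}-epochs-{hparams['epochs']}"


base = {"lr": 0.0002, "epochs": 50, "latent_dim": 100}
# Override two values, as the commented-out lines in the cell would.
override = {**base, "latent_dim": 48, "epochs": 20}

print(experiment_name(base))      # gan-100-lr-0.0002-epochs-50
print(experiment_name(override))  # gan-48-lr-0.0002-epochs-20
```

Because the name is computed once, any override has to be applied before the name is assigned, otherwise the logged name will not match the values actually used.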

In [14]:
wandb.finish()

In [15]:
parameters

{'run_parameters': {'experiment_name': 'gan-100-lr-0.0002-epochs-50'}, 'hyperparameters': {'lr': 0.0002, 'epochs': 50, 'latent_dim': 100}, 'training': {'accelerator': 'gpu', 'devices': 1}}

In [None]:
# Train the GAN
model = train_model(GAN, parameters, train_loader, val_loader, wandb_project_name)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mrkovalch[0m ([33mrkovalchuk[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type          | Params | Mode 
--------------------------------------------------------
0 | generator     | Generator     | 3.9 M  | train
1 | discriminator | Discriminator | 1.7 M  | train
--------------------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.226    Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  z = Variable(Tensor(np.random.normal(0, 1, (z.shape[0], self.latent_dim))))


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


In [None]:
GAN(parameters.hyperparameters)

In [None]:
# Show the link to the wandb dashboard for this run
from IPython.display import display
wandb_url = wandb.run.get_url()
display(f"Wandb Dashboard: {wandb_url}")


In [None]:
wandb.finish()

In [None]:
# Switch the trained model to evaluation mode
model.eval()

# Create the visualizer
visualizer = GANVisualizer(model, latent_dim=hparams["latent_dim"])  # keep in sync with the trained model

# Visualize generated images
visualizer.visualize_generated_images(num_samples=16)
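
The internals of `GANVisualizer` are not shown in this notebook. The training log suggests latent vectors are drawn from a standard normal distribution of size `latent_dim`, so the sketch below assumes the visualizer samples the same way and arranges the outputs in a near-square grid. Both helpers are hypothetical, and only the standard library is used, so no generator is actually run.

```python
import math
import random


def sample_latents(num_samples, latent_dim, seed=0):
    """Draw num_samples latent vectors with independent N(0, 1) entries."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
            for _ in range(num_samples)]


def grid_shape(num_samples):
    """Pick a near-square (rows, cols) layout that fits num_samples images."""
    cols = math.ceil(math.sqrt(num_samples))
    rows = math.ceil(num_samples / cols)
    return rows, cols


latents = sample_latents(num_samples=16, latent_dim=100)
rows, cols = grid_shape(len(latents))

print(len(latents), len(latents[0]), rows, cols)  # 16 100 4 4
```

The latent dimension passed at sampling time has to match the one the generator was trained with, which is why reading it from the hyperparameters is safer than hard-coding it.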