<a href="https://colab.research.google.com/github/Reennon/gen-ai-cv-lab-1/blob/main/notebooks/norm_flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN with Normalizing Flow: CIFAR-10 Training and Experimentation

This notebook demonstrates training a GAN with a normalizing-flow component (`GANWithFlow`) on the CIFAR-10 dataset, using a modular training pipeline implemented in PyTorch Lightning.


In [1]:
!git clone https://github.com/Reennon/gen-ai-cv-lab-1.git
%cd gen-ai-cv-lab-1

Cloning into 'gen-ai-cv-lab-1'...
remote: Enumerating objects: 377, done.[K
remote: Counting objects: 100% (78/78), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 377 (delta 37), reused 55 (delta 22), pack-reused 299 (from 1)[K
Receiving objects: 100% (377/377), 526.52 KiB | 2.06 MiB/s, done.
Resolving deltas: 100% (174/174), done.
/content/gen-ai-cv-lab-1


In [None]:
!pip install -r requirements.txt

Collecting python-dotenv==1.0.1 (from -r requirements.txt (line 1))
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting pytorch-lightning==2.4.0 (from -r requirements.txt (line 3))
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting omegaconf==2.3.0 (from -r requirements.txt (line 6))
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting matplotlib==3.9.3 (from -r requirements.txt (line 7))
  Downloading matplotlib-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning==2.4.0->-r requirements.txt (line 3))
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning==2.4.0->-r requirements.txt (line 3))
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf==2.3.0->-r requirements.t

In [None]:
# Update the cloned repository to the latest upstream commit before importing project code.
!git pull

In [None]:
import os
import dotenv
import wandb
import torch

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from google.colab import userdata
from matplotlib import pyplot as plt

from src.visualization.gan_visualizer import GANVisualizer
from src.training.trainer import train_model
from src.models.norm_flow import GANWithFlow


In [None]:
# Expose the W&B API key, stored as the Colab secret "wandb_key", to this process.
# The key is intentionally NOT appended to .env: echoing the bare value (without a
# `WANDB_KEY=` prefix) is not loadable by dotenv, it duplicates the line on every
# re-run, and it leaves the secret in plain text on disk.
os.environ["WANDB_KEY"] = userdata.get("wandb_key")

In [None]:
# Load variables from a local .env file (if one exists) into the process environment.
dotenv.load_dotenv()

In [None]:
# Load the experiment configuration (its `hyperparameters` and `run_parameters`
# sections are used below) and authenticate with Weights & Biases.
parameters = OmegaConf.load("./params/gan.yml")
wandb.login(key=os.environ["WANDB_KEY"])

In [None]:
# W&B project that training runs are logged under, and the intended accelerator.
# NOTE(review): `device` is not passed to `train_model` nor read by any later cell here.
wandb_project_name, device = "cifar-10-norm-flow", "gpu"

In [None]:
# Preprocessing: convert images to tensors, then normalize each of the 3 RGB
# channels with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# CIFAR-10 train and test splits (downloaded into ./data if not already present);
# the test split serves as the validation set.
train_dataset, val_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Large shuffled training batches (64 * 32 = 2048 images); small ordered validation batches.
train_loader = DataLoader(train_dataset, batch_size=64 * 32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32, num_workers=2)


In [None]:
# Sanity check on one training batch: after normalization the pixel values
# should lie roughly in [-1, 1].
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(images.min(), images.max())

In [None]:
# Shape of the sampled image batch.
images.shape

In [None]:
# Shortcut to the hyperparameter section of the loaded config.
# NOTE(review): presumably a reference to the same OmegaConf node, so edits made via
# `hparams` below also change `parameters` (which is what `train_model` receives) — confirm.
hparams = parameters.hyperparameters

In [None]:
# Per-experiment overrides: edit these values for each experiment; wandb logs the
# resulting differences from the defaults.
hparams.update({"lambda_lp": 10.0, "epochs": 150, "latent_dim": 150})

# Name the run after the two hyperparameters being varied.
parameters.run_parameters.experiment_name = "norm-flow-{}-{}".format(
    hparams["latent_dim"], hparams["lambda_lp"]
)

dict(hparams)

In [None]:
# Close any W&B run that is still open (e.g. from a previous execution) before training.
wandb.finish()

In [None]:
# Display the full configuration that is passed to training, including the overrides above.
parameters

In [None]:
# Train GANWithFlow with the configuration above, logging to the W&B project.
# The returned object is used as the trained model in the cells below.
model = train_model(GANWithFlow, parameters, train_loader, val_loader, wandb_project_name)


In [None]:
# Inspect the trained model's architecture.
# (This cell previously called `GAN(...)`, a name never imported or defined in this
# notebook, so it raised NameError.)
model

In [None]:
# Show a link to the W&B dashboard of the currently active run.
from IPython.display import display

# `wandb.run` is None when no run is active (e.g. it was already finished), in which
# case calling `.get_url()` on it would raise AttributeError.
if wandb.run is None:
    wandb_url = None
    display("No active wandb run; no dashboard URL available.")
else:
    wandb_url = wandb.run.get_url()
    display(f"Wandb Dashboard: {wandb_url}")


In [None]:
# Close the W&B run now that training and logging are done.
wandb.finish()

In [None]:
# Switch the trained model to evaluation mode (no weights are reloaded here).
model.eval()

# Sample latents with the same size the model was configured with: the previous
# hard-coded 100 disagreed with the latent_dim override (150) set earlier, and
# presumably produced a shape mismatch in the generator.
visualizer = GANVisualizer(model, latent_dim=parameters.hyperparameters["latent_dim"])

# Generate and plot 16 samples.
visualizer.visualize_generated_images(num_samples=16)

In [None]:
# Release (unassign) the Colab runtime so it stops consuming compute once the run is done.
# Note: this ends the session, so all kernel state is lost.
from google.colab import runtime
runtime.unassign()