<a href="https://colab.research.google.com/github/Reennon/gen-ai-cv-lab-1/blob/main/notebooks/norm_flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN with Normalizing Flow: CIFAR-10 Training and Experimentation

This notebook demonstrates training a GAN with a normalizing flow on the CIFAR-10 dataset, using a modular training pipeline implemented in PyTorch Lightning.


In [1]:
!git clone https://github.com/Reennon/gen-ai-cv-lab-1.git
%cd gen-ai-cv-lab-1

Cloning into 'gen-ai-cv-lab-1'...
remote: Enumerating objects: 377, done.
remote: Counting objects: 100% (78/78), done.
remote: Compressing objects: 100% (56/56), done.
remote: Total 377 (delta 37), reused 55 (delta 22), pack-reused 299 (from 1)
Receiving objects: 100% (377/377), 526.52 KiB | 2.06 MiB/s, done.
Resolving deltas: 100% (174/174), done.
/content/gen-ai-cv-lab-1


In [2]:
!pip install -r requirements.txt

Collecting python-dotenv==1.0.1 (from -r requirements.txt (line 1))
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting pytorch-lightning==2.4.0 (from -r requirements.txt (line 3))
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting omegaconf==2.3.0 (from -r requirements.txt (line 6))
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting matplotlib==3.9.3 (from -r requirements.txt (line 7))
  Downloading matplotlib-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning==2.4.0->-r requirements.txt (line 3))
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning==2.4.0->-r requirements.txt (line 3))
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf==2.3.0->-r requirements.t

In [5]:
!git pull

remote: Enumerating objects: 15, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (4/4), done.
remote: Total 10 (delta 6), reused 10 (delta 6), pack-reused 0 (fr

In [6]:
import os
import dotenv
import wandb
import torch

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from google.colab import userdata
from matplotlib import pyplot as plt

from src.visualization.gan_visualizer import GANVisualizer
from src.training.trainer import train_model
from src.models.norm_flow import GANWithFlow


In [8]:
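# Read the W&B API key from Colab secrets and expose it to this session
os.environ["WANDB_KEY"] = userdata.get("wandb_key")
# Persist it as a KEY=VALUE line so dotenv can parse the .env file
!echo "WANDB_KEY=$WANDB_KEY" >> .env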

In [9]:
dotenv.load_dotenv()

True

In [10]:
parameters = OmegaConf.load("./params/norm_flow.yml")
wandb.login(key=os.environ["WANDB_KEY"])

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [11]:
wandb_project_name = "cifar-10-norm-flow"
device = "gpu"

In [12]:
# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Map each of the 3 RGB channels from [0, 1] to [-1, 1]
])

# Load datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64 * 32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 43.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [13]:
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(torch.min(images), torch.max(images))

tensor(-1.) tensor(1.)
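
The printed range of -1 to 1 follows from `transforms.Normalize` with a mean and standard deviation of 0.5 per channel, applied after `ToTensor()` has scaled pixels to [0, 1]. The sketch below reproduces that arithmetic on plain floats; the helper names `normalize` and `denormalize` are illustrative and are not part of the repository.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel mapping used by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, which brings values back to [0, 1] for plotting."""
    return y * std + mean


# The endpoints of the [0, 1] pixel range map to -1 and 1.
print(normalize(0.0), normalize(1.0))        # -1.0 1.0
# The inverse recovers the original endpoints.
print(denormalize(-1.0), denormalize(1.0))   # 0.0 1.0
```

Generated images live in the same [-1, 1] range if the generator is trained against this data, so the inverse mapping is typically applied before passing them to `plt.imshow`.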


In [14]:
images.shape

torch.Size([2048, 3, 32, 32])
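
The leading dimension of 2048 is the training batch size (`64 * 32`). As a quick sanity check, the number of batches per epoch can be estimated from the size of the CIFAR-10 training split (50,000 images), assuming the `DataLoader` default of keeping the last partial batch.

```python
import math

num_train_images = 50_000   # size of the CIFAR-10 training split
batch_size = 64 * 32        # same value passed to train_loader

# With drop_last=False (the DataLoader default), the final partial batch counts too.
batches_per_epoch = math.ceil(num_train_images / batch_size)
print(batch_size, batches_per_epoch)  # 2048 25
```

At 25 steps per epoch, an epoch is shorter than PyTorch Lightning's default `log_every_n_steps=50`, so per-step training logs may not appear unless that interval is lowered.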

In [15]:
hparams = parameters.hyperparameters

In [16]:
# Edit the hparams dict here for each experiment; wandb will log the difference
hparams["lambda_lp"] = 10.0
hparams["epochs"] = 150
hparams['latent_dim'] = 150

parameters.run_parameters.experiment_name = f"norm-flow-{hparams['latent_dim']}-{hparams['lambda_lp']}"

dict(hparams)

{'lr': 0.0002, 'epochs': 150, 'latent_dim': 150, 'lambda_lp': 10.0}

In [17]:
wandb.finish()

In [18]:
parameters

{'run_parameters': {'experiment_name': 'norm-flow-150-10.0'}, 'hyperparameters': {'lr': 0.0002, 'epochs': 150, 'latent_dim': 150, 'lambda_lp': 10.0}, 'training': {'accelerator': 'gpu', 'devices': 1}}

In [19]:
# Train the GAN with the normalizing flow
model = train_model(GANWithFlow, parameters, train_loader, val_loader, wandb_project_name)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
wandb: Currently logged in as: rkovalch (rkovalchuk). Use `wandb login --relogin` to force relogin


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | generator     | Generator       | 6.1 M  | train
1 | discriminator | Discriminator   | 663 K  | train
2 | norm_flow     | NormalizingFlow | 194 K  | train
----------------------------------------------------------
7.0 M     Trainable params
0         Non-trainable params
7.0 M     Total params
27.951    Total estimated model params size (MB)
80        Modules in train mode
0         Modules in eval mode
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:384: `ModelCheckpoint(monitor='val_loss')` could not find the monitored key in the returned metrics: ['d_loss', 'g_loss', 'epoch', 'step']. HINT: Did you call `log('val_loss', value)` in the `LightningModule`?
INFO:pytorch_lightning.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
# `GAN` is never imported in this notebook; the imported model class is GANWithFlow
GANWithFlow(parameters.hyperparameters)

In [None]:
# Show a link to the wandb dashboard for the active run
from IPython.display import display
wandb_url = wandb.run.get_url()
display(f"Wandb Dashboard: {wandb_url}")


In [None]:
wandb.finish()

In [None]:
# Switch the trained model to evaluation mode
model.eval()

# Create the visualizer; latent_dim must match the value used for training
visualizer = GANVisualizer(model, latent_dim=hparams["latent_dim"])

# Visualize generated images
visualizer.visualize_generated_images(num_samples=16)

In [None]:
from google.colab import runtime
runtime.unassign()