# Autoencoder CIFAR-10 Training and Experimentation

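This notebook demonstrates training an autoencoder on the CIFAR-10 dataset using a modular training pipeline implemented in PyTorch Lightning. Metrics are logged to Weights & Biases, and reconstructions from a saved checkpoint are visualized at the end.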


In [None]:
!git clone !git clone https://github.com/Reennon/gen-ai-cv-lab-1.git
%cd gen-ai-cv-lab-1
!pip install -r requirements.txt

In [None]:
import os
import dotenv
import wandb
import torch

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from src.visualization.base_visualizer import BaseVisualizer
from src.training.trainer import train_model
from src.models.autoencoder import Autoencoder


In [None]:
# Populate os.environ from a .env file (presumably supplies WANDB_KEY, read in the next cell).
# override=False: variables already set in the environment are left untouched.
env_file_loaded = dotenv.load_dotenv(override=False)
env_file_loaded

In [None]:
# Experiment configuration; the hyperparameters used below live under its `hyperparameters` key.
config_path = "./params/autoencoder.yaml"
parameters = OmegaConf.load(config_path)

# Authenticate with Weights & Biases using the API key from the environment.
wandb.login(key=os.environ["WANDB_KEY"])

In [None]:
# W&B project name and compute device for this notebook.
# NOTE(review): neither name is referenced by a visible cell; presumably consumed elsewhere — TODO confirm.
wandb_project_name = "cifar-10-vae"
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Build the CIFAR-10 train/validation splits and their DataLoaders.
SEED = 42          # seeds torch's global RNG and the training shuffle order for reproducible runs
BATCH_SIZE = 64
NUM_WORKERS = 2

torch.manual_seed(SEED)

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    # Additional transforms like normalization can be added here
])

# Load datasets (downloaded into ./data if not already present)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders; a dedicated generator makes the training shuffle order
# independent of any other RNG consumption in the notebook.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


In [None]:
# Work on a private copy of the hyperparameter node so the per-experiment overrides
# in the next cell do not mutate `parameters` itself; re-running this cell therefore
# restores the values held in `parameters`.
import copy

hparams = copy.deepcopy(parameters.hyperparameters)

In [None]:
# Per-experiment hyperparameter overrides: edit these values before training.
# NOTE(review): presumably train_model logs the resulting config to wandb — TODO confirm.
experiment_overrides = {"lr": 3e-3}
for override_key, override_value in experiment_overrides.items():
    hparams[override_key] = override_value

# Show the effective hyperparameters for this run as a plain dict.
dict(hparams)

In [None]:
# Run the training pipeline from src.training.trainer on the loaders built above.
train_model(
    Autoencoder,
    hparams,
    train_loader,
    val_loader,
)


In [None]:
# Link to the W&B dashboard where the training metrics were logged.
from IPython.display import Markdown, display

# wandb.run is None when no run is active (e.g. if train_model finished it) — guard
# instead of raising AttributeError.
if wandb.run is not None:
    wandb_url = wandb.run.url
    display(Markdown(f"Wandb Dashboard: [{wandb_url}]({wandb_url})"))
else:
    wandb_url = None
    display(Markdown("No active wandb run; open the project page in the W&B UI to view the metrics."))


In [None]:
# Restore the trained Autoencoder from a checkpoint and visualize its reconstructions.
# TODO: point this at the checkpoint written by train_model (e.g. the best model's path).
CHECKPOINT_PATH = 'path/to/best_checkpoint.ckpt'

# load_from_checkpoint is a classmethod that builds and returns a new model, so its
# result must be assigned; calling it on an instance does not load weights into it.
# NOTE(review): assumes Autoencoder saves its init hyperparameters into the checkpoint
# (save_hyperparameters); otherwise pass them as keyword arguments — TODO confirm.
model = Autoencoder.load_from_checkpoint(CHECKPOINT_PATH)
model.eval()

# Visualize original and reconstructed images; gradients are not needed for inference.
visualizer = BaseVisualizer(model, val_loader)
with torch.no_grad():
    reconstruction_output = visualizer.visualize_reconstructions()
reconstruction_output