<a href="https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/main/notebooks/notebook_train_flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧙‍♂️ Training a flow matching model


### Initial setup ⚙️

In [1]:
# Pull updates into an existing clone; on a fresh runtime the next cell clones the repository
!git -C PML_DL_Final_Project pull 2>/dev/null || echo "No local clone yet; it will be cloned in the next cell."


In [2]:
import os

repo_dir = "PML_DL_Final_Project"

if not os.path.exists(repo_dir):
    !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
else:
    print(f"Repository '{repo_dir}' already exists. Skipping clone.")

Cloning into 'PML_DL_Final_Project'...
remote: Enumerating objects: 648, done.
remote: Counting objects: 100% (106/106), done.
remote: Compressing objects: 100% (67/67), done.
remote: Total 648 (delta 58), reused 69 (delta 39), pack-reused 542 (from 1)
Receiving objects: 100% (648/648), 2.09 MiB | 5.25 MiB/s, done.
Resolving deltas: 100% (382/382), done.


In [3]:
if os.path.isdir(repo_dir):
    %cd $repo_dir
    !pip install -q python-dotenv
else:
    print(f"Directory '{repo_dir}' not found. Please clone the repository first.")

/content/PML_DL_Final_Project


### 📦 Imports

In [4]:
import torch
import numpy as np

from src.train.train import train
from src.utils.data import get_dataloaders
from src.utils.plots import plot_image_grid
from src.utils.environment import get_device, set_seed, load_pretrained_model

from src.models.flow import FlowMatching

# Notebook-friendly progress bars
from tqdm.notebook import tqdm as tqdm_notebook

### 🛠️ Configuration Parameters

In [5]:
epochs = 20
batch_size = 128
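learning_rate = 2e-3  # NOTE: not passed to train() below
seed = 1337
checkpoint_path = "checkpoints/last.ckpt"
model_name = "unet"  # NOTE: not passed to train() below
method = "flow"  # training method passed to train(); this notebook uses flow matching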

### 🧪 Setup: Seed and Device

In [6]:
set_seed(seed)
device = get_device()
os.makedirs("checkpoints", exist_ok=True)
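
`set_seed` and `get_device` come from `src/utils/environment.py`, which is not shown in this notebook. As an illustration only, helpers of this kind usually look something like the sketch below; `set_seed_sketch` and `get_device_sketch` are hypothetical names, and the project's real functions may seed or select devices differently.

```python
import random

import numpy as np

try:
    import torch  # optional: seeded and used for device selection if installed
except ImportError:
    torch = None


def set_seed_sketch(seed):
    """Hypothetical helper: seed the Python, NumPy and (if present) PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # PyTorch's default generators
        torch.cuda.manual_seed_all(seed)  # all GPUs; silently ignored without CUDA


def get_device_sketch():
    """Hypothetical helper: use the GPU when available, otherwise the CPU."""
    if torch is not None and torch.cuda.is_available():
        return "cuda"
    return "cpu"


set_seed_sketch(1337)
first = (random.random(), float(np.random.rand()))
set_seed_sketch(1337)
second = (random.random(), float(np.random.rand()))
print(first == second, get_device_sketch())
```

Re-seeding with the same value reproduces the same random draws, which is what makes runs comparable; bit-exact results on a GPU can additionally require deterministic cuDNN settings.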

## 🧠 Model Training

#### 📥 Data Loading

In [7]:
# Returns DataLoaders that yield (image, timestep, label)
train_loader, val_loader = get_dataloaders(batch_size=batch_size)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 542kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.48MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.66MB/s]


#### Training
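
This notebook trains with `method = "flow"`, i.e. flow matching. A common formulation (the straight-line, rectified-flow style path) is sketched below in NumPy so the objective is visible without a GPU: sample noise x0, draw t uniformly in [0, 1], move to x_t = (1 - t) x0 + t x1 on the line toward the data x1, and regress the network output onto the velocity x1 - x0. Whether `FlowMatching` in `src/models/flow.py` uses this exact path, time convention, and loss is an assumption; `flow_matching_loss` and `velocity_fn` are hypothetical stand-ins.

```python
import numpy as np


def flow_matching_loss(velocity_fn, x1, labels, rng):
    """Flow matching loss on a straight noise-to-data path (a sketch).

    Assumes t=0 is Gaussian noise and t=1 is data. `velocity_fn(x_t, t, labels)`
    stands in for the U-Net and must return an array shaped like x1.
    """
    b = x1.shape[0]
    x0 = rng.standard_normal(x1.shape)  # noise sample, same shape as the images
    t = rng.uniform(0.0, 1.0, size=b)  # one time per image
    t_b = t.reshape(b, *([1] * (x1.ndim - 1)))  # broadcast t over C, H, W
    x_t = (1.0 - t_b) * x0 + t_b * x1  # point on the straight line x0 -> x1
    target = x1 - x0  # velocity of that line (constant in t)
    pred = velocity_fn(x_t, t, labels)
    return float(np.mean((pred - target) ** 2))


# Toy usage: an all-zero "model" on all-zero "images" gives a loss close to 1,
# since the target is then -x0 with x0 ~ N(0, I)
rng = np.random.default_rng(0)
images = np.zeros((64, 1, 28, 28))
labels = rng.integers(0, 10, size=64)
print(flow_matching_loss(lambda x, t, y: np.zeros_like(x), images, labels, rng))
```

In a PyTorch training loop the same computation would run on tensors, followed by `loss.backward()` and `optimizer.step()`.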

In [9]:
# HACK: The number of classes is hard-coded (MNIST = 10)
num_classes = 10
model_kwargs = {
    "num_classes": num_classes,
    "out_channels": 1,
    "time_emb_dim": 128,
    # NOTE: Use a learned (MLP) time embedding for flow matching, which is more sensible
    # for its continuous time; other methods use a sinusoidal embedding
    "time_embedding_type": "mlp" if method == "flow" else "sinusoidal",
}


# NOTE: Instead of calling train() directly, you can write your own training code here.
# See src/train/train.py for how the checkpoints are saved.

# NOTE: You can also copy the code of train() into a cell above this one and modify it
# inside the notebook, similarly to what was done for the FlowMatching class.

# If you call train() directly, you can use the model it returns for sampling below.
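
The `time_embedding_type` switch above chooses how the scalar time is turned into a feature vector of size `time_emb_dim`. The sketch below contrasts the standard sinusoidal embedding with a small learned MLP; the exact layers behind `"mlp"` in this project are not shown, so `mlp_embedding` is a hypothetical stand-in.

```python
import numpy as np


def sinusoidal_embedding(t, dim, max_period=10_000.0):
    """Map timesteps t (shape [B]) to [B, dim] sine/cosine features (dim even)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to about 1 / max_period
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]  # [B, half]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


def mlp_embedding(t, w1, b1, w2, b2):
    """Hypothetical learned embedding: a 2-layer MLP with SiLU on the raw scalar t."""
    h = t[:, None] @ w1 + b1  # [B, hidden]
    h = h / (1.0 + np.exp(-h))  # SiLU: h * sigmoid(h)
    return h @ w2 + b2  # [B, dim]


rng = np.random.default_rng(0)
dim, hidden = 128, 256  # dim matches time_emb_dim above
t = np.array([0.0, 0.25, 1.0])  # continuous flow matching times
w1, b1 = rng.standard_normal((1, hidden)), np.zeros(hidden)
w2, b2 = rng.standard_normal((hidden, dim)) / np.sqrt(hidden), np.zeros(dim)
print(sinusoidal_embedding(t, dim).shape, mlp_embedding(t, w1, b1, w2, b2).shape)
```

One plausible reading of the NOTE: with t in [0, 1] and max_period = 10000, the sinusoidal arguments t * freq never exceed 1 and are at most 0.1 for three quarters of the frequencies, so most features barely change across the whole time range. A learned embedding can adapt its own scale; some codebases instead rescale t (for example by 1000) before the sinusoidal embedding.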


In [10]:
flow_model = train(
    num_epochs=epochs,
    device=device,
    dataloader=train_loader,
    val_loader=val_loader,
    use_wandb=True,  # prompts for a WandB API key if WANDB_API_KEY is not set
    checkpoint_path=checkpoint_path,
    model_kwargs=model_kwargs,
    method=method,
)



WANDB_API_KEY environment variable not set. Please enter your WandB API key: <redacted>

wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: jacopozac (jac-zac) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin

Epoch 1/20
Train Loss: 0.8725 | Val Loss: 0.6424 | LR: 0.000994
New best model saved! Epoch 1, Val Loss: 0.6424

Epoch 2/20
Train Loss: 0.6032 | Val Loss: 0.5769 | LR: 0.000976
New best model saved! Epoch 2, Val Loss: 0.5769

Epoch 3/20
Train Loss: 0.5672 | Val Loss: 0.5625 | LR: 0.000946
New best model saved! Epoch 3, Val Loss: 0.5625

Epoch 4/20
Train Loss: 0.5531 | Val Loss: 0.5567 | LR: 0.000905
New best model saved! Epoch 4, Val Loss: 0.5567

Epoch 5/20
Train Loss: 0.5476 | Val Loss: 0.5292 | LR: 0.000855
New best model saved! Epoch 5, Val Loss: 0.5292

Epoch 6/20

KeyboardInterrupt: 
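
The LR column starts just below 1e-3, not at the `learning_rate = 2e-3` set in the configuration cell, which is consistent with `learning_rate` never being passed to `train()`. The five logged values are within about 1.5e-6 of a cosine annealing schedule from 1e-3 over the 20 epochs, as the check below computes. This is only an inference from the printed numbers: the schedule, its base rate, and any minimum LR inside `train()` are not shown, and `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, total_epochs=20, eta_min=0.0):
    """Cosine annealing: base_lr at epoch 0, decaying to eta_min at total_epochs."""
    return eta_min + (base_lr - eta_min) * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))


logged = [0.000994, 0.000976, 0.000946, 0.000905, 0.000855]  # LR column of the log above
for epoch, lr in enumerate(logged, start=1):
    print(f"epoch {epoch}: logged {lr:.6f} | cosine from 1e-3: {cosine_lr(epoch):.6f}")
```

The small gap at epoch 5 could come from a nonzero minimum LR or a slightly different schedule; reading `src/train/train.py` settles it.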

## 💡 Image Generation

#### 🛠️ Configuration Parameters

In [None]:
n_samples = 5  # number of classes to sample
save_dir = "samples"
max_steps = 1000
num_timesteps = 6  # passed to plot_image_grid as num_intermediate

ckpt_path = "checkpoints/best_model.pth"  # or use your last checkpoint (checkpoint_path)
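
`max_steps` and `num_timesteps` are forwarded to `plot_image_grid` as `max_steps` and `num_intermediate`. Assuming they mean the number of ODE integration steps and the number of snapshots shown per sample (neither `plot_image_grid` nor `FlowMatching` is shown here), generating from a flow matching model amounts to integrating dx/dt = v(x, t, y) from noise at t = 0 to data at t = 1. A fixed-step Euler sketch in NumPy, with `euler_sample` and `velocity_fn` as hypothetical names:

```python
import numpy as np


def euler_sample(velocity_fn, x0, labels, max_steps=1000, num_intermediate=6):
    """Integrate dx/dt = velocity_fn(x, t, labels) from t=0 to t=1 with Euler steps.

    Returns the final sample and `num_intermediate` snapshots taken at roughly
    evenly spaced steps; the last snapshot is the final sample.
    """
    dt = 1.0 / max_steps
    # Step indices (1-based) after which a snapshot is stored
    keep = {int(s) for s in np.linspace(1, max_steps, num_intermediate).round()}
    x, snapshots = x0.copy(), []
    for step in range(max_steps):
        t = np.full(x.shape[0], step * dt)  # same time for every sample in the batch
        x = x + dt * velocity_fn(x, t, labels)  # Euler update
        if step + 1 in keep:
            snapshots.append(x.copy())
    return x, snapshots


# Toy check: a constant velocity field moves x0 exactly onto a target x1 by t = 1
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 1, 28, 28))
x1 = np.ones_like(x0)
x_final, snaps = euler_sample(lambda x, t, y: x1 - x0, x0, labels=None)
print(np.allclose(x_final, x1), len(snaps))
```

Euler is the simplest solver. For a well-trained flow with nearly straight paths, far fewer than 1000 steps may suffice, and higher-order solvers (e.g. Heun or RK4) trade extra network calls per step for accuracy.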

In [None]:
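from IPython.display import display
from PIL import Image

# 💫 Create the flow matching sampler
flow = FlowMatching(img_size=28, device=device)

# NOTE: `flow_model` is only defined if the train() call above returned normally.
# After an interrupted run (like the KeyboardInterrupt above), load the best checkpoint
# from `ckpt_path` first, e.g. with `load_pretrained_model` (check its signature in
# src/utils/environment.py).
if "flow_model" not in globals():
    raise RuntimeError(f"No trained model in memory; load one from '{ckpt_path}' first.")

plot_image_grid(
    flow_model,
    flow,
    num_intermediate=num_timesteps,
    n=n_samples,
    max_steps=max_steps,
    save_dir=save_dir,
    device=device,
    num_classes=num_classes,
)

# Show the combined grid image saved in save_dir
out_path = os.path.join(save_dir, "all_samples_grid.png")
display(Image.open(out_path))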