# 🧙‍♂️ Training a flow matching model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/master/notebooks/notebook_train_flow.ipynb)

### Initial setup ⚙️

In [None]:
# Fetch the latest objects from the upstream repository.
# NOTE(review): `git fetch` needs the working directory to be inside a git
# checkout, but the clone only happens in the next cell — presumably this
# fails harmlessly on a fresh runtime; confirm the cell is still wanted.
!git fetch https://github.com/Jac-Zac/PML_DL_Final_Project.git

In [None]:
import os

repo_dir = "PML_DL_Final_Project"

if not os.path.exists(repo_dir):
    !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
else:
    print(f"Repository '{repo_dir}' already exists. Skipping clone.")

In [None]:
if os.path.isdir(repo_dir):
    %cd $repo_dir
    !pip install dotenv -q
else:
    print(f"Directory '{repo_dir}' not found. Please clone the repository first.")

### 📦 Imports

In [None]:
import torch
import numpy as np

from src.train.train import train
from src.utils.data import get_dataloaders
from src.utils.plots import plot_image_grid
from src.utils.environment import get_device, set_seed, load_pretrained_model

from src.models.flow import FlowMatching

# Since on a notebook we can have nicer bars
from tqdm.notebook import tqdm as tqdm_notebook

### 🛠️ Configuration Parameters

In [None]:
# --- What to train ---
dataset_name = "MNIST"
model_name = "unet"
# Training method handed to `train`; the value "flow" is also what switches
# the time embedding to "mlp" in the model kwargs further down.
# NOTE(review): the alternative value is presumably "diffusion" — confirm.
method = "flow"

# --- Optimisation ---
epochs = 20
batch_size = 128
learning_rate = 2e-3

# --- Reproducibility and checkpointing ---
seed = 1337
checkpoint_path = "checkpoints/last.ckpt"

### 🧪 Setup: Seed and Device

In [None]:
# Apply the configured seed, then resolve the device that the training and
# sampling cells below receive.
set_seed(seed)
device = get_device()

# Make sure the directory that holds `checkpoint_path` exists; with
# `exist_ok=True` this does nothing when it is already there.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

## 🧠 Model Training

#### 📥 Data Loading

In [None]:
# Build the training and validation loaders for the configured dataset.
# NOTE(review): the previous comment said each batch is
# (image, timestep, label); presumably the loader itself does not produce a
# timestep — confirm against src.utils.data.get_dataloaders.
train_loader, val_loader = get_dataloaders(
    dataset_name=dataset_name,
    batch_size=batch_size,
)

#### Training

In [None]:
# Number of label classes; also passed to the plotting helper when sampling.
# HACK: hard-coded to 10, which matches MNIST / FashionMNIST.
num_classes = 10

# Keyword arguments forwarded to `train` as `model_kwargs`.
model_kwargs = {
    # Reuse the single definition above so the two values cannot drift apart.
    "num_classes": num_classes,
    "out_channels": 1,
    "time_emb_dim": 128,
    # NOTE: Flow uses a learned ("mlp") time embedding, which is more sensible
    # there; any other method keeps the sinusoidal embedding.
    "time_embedding_type": "mlp" if method == "flow" else "sinusoidal",
}



In [None]:
# Run training with the settings defined in the earlier cells. The returned
# model is what the sampling cell below passes to `plot_image_grid`.
# NOTE(review): `learning_rate` from the configuration cell is not forwarded
# here — confirm whether `train` is supposed to receive it.
flow_model = train(
    method=method,
    model_kwargs=model_kwargs,
    num_epochs=epochs,
    dataloader=train_loader,
    val_loader=val_loader,
    device=device,
    checkpoint_path=checkpoint_path,
    use_wandb=True,  # presumably enables Weights & Biases logging — confirm
)

## 💡 Image Generation

#### 🛠️ Configuration Parameters

In [None]:
# Sampling settings handed to `plot_image_grid` in the next cell.
num_steps = 10
num_intermediate = 10
n_samples = 5  # number of classes to sample
save_dir = "samples"  # the grid image is read back from this directory

# Alternative checkpoint location (or use your last checkpoint).
# NOTE(review): not referenced by the cells visible here.
ckpt_path = "checkpoints/best_model.pth"

In [None]:
import os

import matplotlib.pyplot as plt
from PIL import Image

# Sampler instance handed to the plotting helper as `method_instance`.
# NOTE(review): img_size=28 presumably matches 28x28 MNIST images — confirm.
flow = FlowMatching(img_size=28, device=device)

# Every argument is passed by keyword, so the order below is irrelevant.
grid_options = {
    "model": flow_model,
    "method_instance": flow,
    "device": device,
    "n": n_samples,
    "num_classes": num_classes,
    "num_steps": num_steps,
    "num_intermediate": num_intermediate,
    "save_dir": save_dir,
}
plot_image_grid(**grid_options)

# Show the grid image that is expected under `save_dir` after plotting.
out_path = os.path.join(save_dir, "all_samples_grid.png")
display(Image.open(out_path))