# 🧙‍♂️ Training a diffusion model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/master/notebooks/notebook_train_diff.ipynb)

### Initial setup ⚙️

In [None]:
import os

repo_dir = "PML_DL_Final_Project"

if not os.path.exists(repo_dir):
    !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
else:
    print(f"Repository '{repo_dir}' already exists. Skipping clone.")

In [None]:
if os.path.isdir(repo_dir):
    %cd $repo_dir
    !pip install dotenv -q
else:
    print(f"Directory '{repo_dir}' not found. Please clone the repository first.")

### 📦 Imports

In [None]:
import torch
import numpy as np

from src.train.train import train
from src.utils.data import get_dataloaders
from src.models.diffusion import Diffusion
from src.utils.plots import plot_image_grid
from src.utils.environment import get_device, set_seed, load_pretrained_model

# In a notebook, tqdm.notebook provides nicer (widget-based) progress bars
import tqdm.notebook as tqdm

### 🛠️ Configuration Parameters

In [None]:
# --- Reproducibility ---
seed = 1337

# --- Model / training objective ---
model_name = "unet"
method = "diffusion"  # or "flow"

# --- Optimisation hyperparameters ---
epochs = 20
batch_size = 128
learning_rate = 2e-3

# Forwarded to train() as `checkpoint_path`; None means no checkpoint is given
checkpoint = None  # e.g., "checkpoints/last.ckpt"

### 🧪 Setup: Seed and Device

In [None]:
# Apply the configured seed and pick the device used by the rest of the notebook.
# NOTE(review): set_seed/get_device come from src.utils.environment; which RNGs
# are seeded and how the device is selected is not visible here — confirm there.
set_seed(seed)
device = get_device()
# Make sure the checkpoint directory exists; exist_ok=True keeps re-runs of
# this cell from raising when the directory is already present.
os.makedirs("checkpoints", exist_ok=True)

## 🧠 Model Training

#### 📥 Data Loading

In [None]:
# Build the training and validation loaders using the configured batch size.
# Per the original note, the DataLoaders yield (image, timestep, label).
# NOTE(review): the batch layout is defined in src.utils.data — confirm there.
train_loader, val_loader = get_dataloaders(batch_size=batch_size)

#### Training

In [None]:
# NOTE: the number of classes is hardcoded to 10 rather than derived from the data
num_classes = 10
model_kwargs = dict(num_classes=num_classes)

# Run training with the configuration defined above; the returned model is
# reused directly by the image-generation cells further down.
model = train(
    # data
    dataloader=train_loader,
    val_loader=val_loader,
    # optimisation
    num_epochs=epochs,
    learning_rate=learning_rate,
    device=device,
    # model / objective
    model_name=model_name,
    model_kwargs=model_kwargs,
    method=method,
    # checkpointing and logging
    checkpoint_path=checkpoint,
    use_wandb=True,
)

## 💡 Image Generation

#### 🛠️ Configuration Parameters

In [None]:
# --- Sampling settings ---
n_samples = 5
num_steps = 1000

# --- Output location for generated images ---
save_dir = "samples"

# --- Model selection ---
model_name = "unet"
# NOTE(review): ckpt_path is not read by the cells visible in this notebook;
# sampling below uses the in-memory `model` returned by training.
ckpt_path = "checkpoints/best_model.pth"  # or use your last checkpoint

#### 💨 Initialize Diffusion Process

In [None]:
from PIL import Image

# 💫 Build the diffusion object used for sampling (img_size=28)
diffusion = Diffusion(img_size=28, device=device)

# Generate samples with the trained model using the settings defined above.
# NOTE(review): presumably writes the image grid into `save_dir` — confirm in
# src.utils.plots.
plot_image_grid(
    model,
    diffusion,
    # sampling
    n=n_samples,
    num_steps=num_steps,
    num_classes=num_classes,
    num_intermediate=5,
    # execution / output
    device=device,
    save_dir=save_dir,
)

# Show the saved grid inline in the notebook
out_path = os.path.join(save_dir, "all_samples_grid.png")
display(Image.open(out_path))