# 🧙‍♂️ Training a diffusion model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/master/notebooks/notebook_train.ipynb)

### Initial setup ⚙️

In [1]:
import os
import sys

# Name of the directory the repository is cloned into; also used in the
# status messages printed by the setup logic below.
repo_dir = "PML_DL_Final_Project"

def in_colab():
    """Report whether this notebook is running on Google Colab.

    Detection is based solely on the presence of the ``COLAB_GPU``
    environment variable; its value is not inspected.
    """
    return os.environ.get('COLAB_GPU') is not None

if in_colab():
    # In Colab: clone repo if not present
    if not os.path.exists(repo_dir):
        !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
        os.chdir(repo_dir)
        # Install requirements quietly
        !pip install -r requirements.txt -q
    else:
        os.chdir(repo_dir)
        print(f"Repository '{repo_dir}' already exists. Skipping clone.")
else:
    # Local: assume repo is already cloned
    print(f"Local Run, make sure you are inside '{repo_dir}' with the latest updates (git pull).")
    print(f"Moving to root directory to have correct access to all of the files")
    os.chdir("..")

Local Run, make sure you are inside 'PML_DL_Final_Project' with the latest updates (git pull).
Moving to root directory to have correct access to all of the files


### 📦 Imports

In [12]:
import torch
from torch import nn, Tensor
import numpy as np

from src.utils.data import get_dataloaders
from src.models.unet import DiffusionUNet
from src.utils.environment import get_device, set_seed, load_pretrained_model

# In a notebook we can use the nicer widget-based progress bars
import tqdm.notebook as tqdm

### 🛠️ Configuration Parameters

In [6]:
# --- Reproducibility ---
seed = 1337

# --- Optimisation hyper-parameters ---
epochs = 20
batch_size = 128
learning_rate = 2e-3

# --- Model selection and checkpointing ---
model_name = "unet"
method = "diffusion"  # or "flow"
checkpoint = None  # e.g., "checkpoints/last.ckpt"

### 🧪 Setup: Seed and Device

In [7]:
# Seed the run with the configured value (which RNGs are covered is decided
# by the project helper `set_seed` — confirm in src.utils.environment).
set_seed(seed)
# Select the compute device through the project helper
# (presumably CUDA/MPS/CPU depending on availability — TODO confirm).
device = get_device()
# Ensure the checkpoint directory exists; no error if it is already there.
os.makedirs("checkpoints", exist_ok=True)

## 🧠 Model Training

#### 📥 Data Loading

In [8]:
# Build the training and validation DataLoaders with the configured batch size.
# Each batch is expected to unpack as (image, timestep, label); the training
# loop relies on that three-element structure.
train_loader, val_loader = get_dataloaders(batch_size=batch_size)

#### Model

In [16]:
class Flow(nn.Module):
    """Velocity-field model for flow matching, backed by ``DiffusionUNet``.

    The public interface works on flattened images: ``forward`` takes and
    returns tensors of shape ``(batch, 784)`` and reshapes them to
    ``(batch, 1, 28, 28)`` only around the UNet call.

    Args:
        dim: Flattened image size. Currently unused (the 28x28 shape is
            hard-coded in ``forward``); kept so that existing
            ``Flow(dim=...)`` calls keep working.
        h: Hidden width left over from the earlier MLP version. Unused; kept
            for the same backward-compatibility reason.
    """

    def __init__(self, dim: int = 28*28, h: int = 256):
        super().__init__()
        # The UNet takes the place of the earlier MLP
        self.model = DiffusionUNet(in_channels=1, out_channels=1)

    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        """Predict the velocity ``dx/dt`` at time ``t`` for the state ``x_t``.

        Args:
            t: Time values, passed to the UNet unchanged. The training loop
                supplies shape ``(batch, 1)`` — NOTE(review): confirm this is
                the shape ``DiffusionUNet`` expects for its time input.
            x_t: Flattened images; any shape that can be viewed as
                ``(-1, 1, 28, 28)``.

        Returns:
            The UNet output flattened to ``(batch, -1)`` so that it matches
            the flattened regression target.
        """
        # Restore the image layout expected by the UNet
        x_images = x_t.view(-1, 1, 28, 28)
        out = self.model(x_images, t)
        return out.view(out.size(0), -1)  # flatten output to match dx_t shape

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one midpoint step.

        This is the explicit midpoint (second-order Runge-Kutta) rule, not a
        plain Euler step: the velocity is first evaluated at ``t_start`` to
        estimate the state half-way through the interval, and the velocity at
        that midpoint is then used for the full update.

        Args:
            x_t: Current flattened state, shape ``(batch, 784)``.
            t_start: Single-element tensor with the start time; it is
                broadcast to one value per batch row.
            t_end: End time; must broadcast against a ``(batch, 1)`` tensor.

        Returns:
            The state at ``t_end``, same shape as ``x_t``.
        """
        # One time value per sample, as a column
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        dt = t_end - t_start
        # Half step with the velocity at the start of the interval
        x_mid = x_t + self(t_start, x_t) * dt / 2
        # Full step with the velocity evaluated at the midpoint
        return x_t + dt * self(t_start + dt / 2, x_mid)


#### Training

In [None]:
# Build the model on the selected device and use the hyper-parameters from the
# configuration cell (`epochs`, `learning_rate`) instead of hard-coded values.
flow = Flow().to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

flow.train()
for epoch in range(epochs):
    running_loss = 0.0
    for images, _, _ in train_loader:
        # Flatten each image to a vector; x_1 is the data end of the path
        x_1 = images.to(device).view(images.size(0), -1)
        # x_0 is the noise end of the path
        x_0 = torch.randn_like(x_1)
        # One uniformly sampled time in [0, 1) per example
        t = torch.rand(len(x_1), 1, device=x_1.device)
        # Point on the straight line between noise and data at time t
        x_t = (1 - t) * x_0 + t * x_1
        # Regression target: the constant velocity of that straight line
        dx_t = x_1 - x_0
        optimizer.zero_grad()
        loss = loss_fn(flow(t=t, x_t=x_t), dx_t)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # max(..., 1) guards against an empty loader
    mean_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch + 1}/{epochs} - mean train loss: {mean_loss:.4f}")

print('Training completed')