<a href="https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/main/notebooks/notebook_train_flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧙‍♂️ Training diffusion model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/master/notebooks/notebook_train_flow.ipynb)

### Initial setup ⚙️

In [2]:
import os

repo_dir = "PML_DL_Final_Project"

if not os.path.exists(repo_dir):
    !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
else:
    print(f"Repository '{repo_dir}' already exists. Skipping clone.")

Cloning into 'PML_DL_Final_Project'...
remote: Enumerating objects: 488, done.[K
remote: Counting objects: 100% (218/218), done.[K
remote: Compressing objects: 100% (151/151), done.[K
remote: Total 488 (delta 130), reused 135 (delta 67), pack-reused 270 (from 1)[K
Receiving objects: 100% (488/488), 449.85 KiB | 13.63 MiB/s, done.
Resolving deltas: 100% (274/274), done.


In [3]:
if os.path.isdir(repo_dir):
    %cd $repo_dir
    !pip install dotenv -q
else:
    print(f"Directory '{repo_dir}' not found. Please clone the repository first.")

/content/PML_DL_Final_Project


### 📦 Imports

In [49]:
import torch
import numpy as np

from src.train.train import train
from src.utils.data import get_dataloaders
from src.utils.plots import plot_image_grid
from src.utils.environment import get_device, set_seed, load_pretrained_model

# train
from typing import Optional

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from src.utils.environment import load_checkpoint
from src.utils.wandb import (
    initialize_wandb,
    log_epoch_metrics,
    log_sample_grid,
    log_training_step,
    save_best_model_artifact,
)
import src.utils.wandb as wandb



# Since on a notebook we can have nicer bars
from tqdm.notebook import tqdm as tqdm_notebook

## Flow Matching implementation

In [36]:
from typing import Optional

import torch
from torch import Tensor, nn


class FlowMatching:
    def __init__(self, img_size: int = 64, device: torch.device = torch.device("cpu")):
        self.img_size = img_size
        self.device = device

    def _sample_timesteps(self, batch_size: int) -> Tensor:
        return torch.rand(batch_size, device=self.device)

    def perform_training_step(
        self,
        model: nn.Module,
        x0: Tensor,  # noise ~ N(0,I)
        x1: Tensor,  # data in [-1,1]
        y: Optional[Tensor] = None,
    ) -> Tensor:
        B = x0.size(0)
        t = self._sample_timesteps(B)
        t4 = t.view(-1, 1, 1, 1)  # shape: [B, 1, 1, 1]
        x_t = (1 - t4) * x0 + t4 * x1  # linear OT path

        # True velocity & normalization
        dx = x1 - x0
        norm = dx.flatten(1).norm(dim=1, keepdim=True).view(-1, 1, 1, 1).clamp(min=1e-6)
        dx = dx / norm

        u_t = model(x_t, t, y=y)
        assert u_t.shape == dx.shape

        # Time-weighted MSE loss
        return ((u_t - dx).pow(2)).mean()

    @torch.no_grad()
    # def sample(
    #     self,
    #     model: nn.Module,
    #     x_init: Optional[Tensor] = None,
    #     steps: int = 100,
    #     y: Optional[Tensor] = None,
    #     log_intermediate: bool = False,
    #     t_sample_skip = 1,
    #     t_sample_times: Optional[list[int]] = None,
    # ) -> list[Tensor]:
    #     model.eval()
    #     B = (
    #         y.shape[0]
    #         if y is not None
    #         else (x_init.shape[0] if x_init is not None else 1)
    #     )
    #     C = x_init.shape[1] if x_init is not None else 1
    #     x_t = (
    #         x_init.to(self.device)
    #         if x_init is not None
    #         else torch.randn(B, C, self.img_size, self.img_size, device=self.device)
    #     )

    #     results = []
    #     dt = 1.0 / steps
    #     for i in range(steps):
    #         t = torch.full((B,), i / steps, device=self.device)
    #         u_t = model(x_t, t, y=y)
    #         x_t = x_t + u_t * dt
    #         if t_sample_times and i in t_sample_times:
    #             results.append(self.transform_sampled_image(x_t.clone()))
    #         elif i % t_sample_skip == 0:
    #             results.append(self.transform_sampled_image(x_t.clone()))

    #     results.append(self.transform_sampled_image(x_t))
    #     return results

    def sample(
        self,
        model: nn.Module,
        x_init: Optional[Tensor] = None,
        steps: int = 100,
        y: Optional[Tensor] = None,
        log_intermediate: bool = False,
        t_sample_times: Optional[list[int]] = None,
    ) -> list[Tensor]:
        model.eval()
        B = (
            y.shape[0]
            if y is not None
            else (x_init.shape[0] if x_init is not None else 1)
        )
        C = x_init.shape[1] if x_init is not None else 1
        x_t = (
            x_init.to(self.device)
            if x_init is not None
            else torch.randn(B, C, self.img_size, self.img_size, device=self.device)
        )

        results = []
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device)
            v = model(x_t, t, y=y)
            x_t = x_t + v * dt

            if log_intermediate and t_sample_times and i in t_sample_times:
                results.append(self.transform_sampled_image(x_t.clone()))

        return results

    @staticmethod
    def transform_sampled_image(image: Tensor) -> Tensor:
        return (image.clamp(-1, 1) + 1) / 2


### 🛠️ Configuration Parameters

In [12]:
epochs = 5
batch_size = 128
learning_rate = 2e-3
seed = 1337
checkpoint_path = "checkpoints/last.ckpt"
model_name = "unet"
method = "flow"  # or "flow"

### 🧪 Setup: Seed and Device

In [7]:
set_seed(seed)
device = get_device()
os.makedirs("checkpoints", exist_ok=True)

## 🧠 Model Training

#### 📥 Data Loading

In [8]:
# Returns DataLoaders that yield (image, timestep, label)
train_loader, val_loader = get_dataloaders(batch_size=batch_size)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 484kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.34MB/s]


#### Training

In [31]:
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    method_instance,
    device,
    use_wandb,
    grad_clip: float = 1.0,
):
    model.train()
    total_loss = 0.0

    for images, labels in tqdm_notebook(dataloader, desc="Training", leave=False):
        images = images.to(device).mul_(2).sub_(1)  # transforms image values from [0,1] → [-1,1]
        y = labels.to(device)
        x0 = torch.randn_like(images)

        optimizer.zero_grad(set_to_none=True)
        loss = method_instance.perform_training_step(model=model, x0=x0, x1=images, y=y)
        loss.backward()

        if grad_clip is not None:
            clip_grad_norm_(model.parameters(), max_norm=grad_clip)

        optimizer.step()

        loss_val = loss.item()
        total_loss += loss_val

        if use_wandb:
            log_training_step(loss_val)

    return total_loss / max(1, len(dataloader))  #average loss per batch for the epoch


def validate(model, val_loader, method_instance, device):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for images, labels in tqdm_notebook(val_loader, desc="Validating", leave=False):
            images = images.to(device).mul_(2).sub_(1)
            y = labels.to(device)
            x0 = torch.randn_like(images)

            loss = method_instance.perform_training_step(model=model, x0=x0, x1=images, y=y)
            total_loss += loss.item()

    return total_loss / len(val_loader)


def train(
    num_epochs: int,
    device: torch.device,
    dataloader,
    val_loader,
    learning_rate: float = 1e-3,
    use_wandb: bool = False,
    checkpoint_path: Optional[str] = None,
    model_name: str = "unet",
    model_kwargs: Optional[dict] = None,
    method: str = "flow",
):
    model_kwargs = model_kwargs or {}

    # Initialize scheduler ahead of load_checkpoint
    dummy_optimizer = torch.optim.Adam([torch.zeros(1)], lr=learning_rate)
    scheduler = CosineAnnealingLR(dummy_optimizer, T_max=num_epochs, eta_min=1e-5)

    model, optimizer, _, start_epoch, best_val_loss = load_checkpoint(
        model_name=model_name,
        checkpoint_path=checkpoint_path,
        device=device,
        optimizer_class=torch.optim.Adam,
        optimizer_kwargs={"lr": learning_rate},
        model_kwargs=model_kwargs,
        scheduler=scheduler,
    )

    # Re-create scheduler with real optimizer if not loaded
    if not hasattr(scheduler, "optimizer") or scheduler.optimizer is not optimizer:
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)

    # Initialize method instance
    if method == "diffusion":
        # method_instance = Diffusion(img_size=28, device=device)
        print("Ti sei scordato di cambiare in flow matching")
    elif method == "flow":
        method_instance = FlowMatching(img_size=28, device=device)
    else:
        raise ValueError(f"Unsupported method: {method}")

    wandb_run = None
    if use_wandb:
        wandb_run = initialize_wandb(
            project="flow-project",
            config={
                "epochs": num_epochs,
                "lr": learning_rate,
                "model": model_name,
                "num_classes": model_kwargs.get("num_classes"),
                "method": method,
            },
        )

    for epoch in range(start_epoch, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")

        train_loss = train_one_epoch(
            model, dataloader, optimizer, method_instance, device, use_wandb
        )
        val_loss = validate(model, val_loader, method_instance, device)

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        print(
            f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}"
        )

        # Check if this is the best model
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss

        if use_wandb:
            log_epoch_metrics(epoch, train_loss, val_loss, current_lr)
            log_sample_grid(model, method_instance, num_samples=5, num_timesteps=6)

            # Save best model artifact when we find a new best
            if is_best:
                save_best_model_artifact(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    epoch=epoch,
                    val_loss=val_loss,
                    train_loss=train_loss,
                )

    if wandb_run:
        wandb_run.finish()

    print(f"\nTraining complete. Best validation loss: {best_val_loss:.4f}")
    return model

In [22]:
# NOTE: Currently assumes 10 classes are hardcoded
num_classes = 10
model_kwargs = {"num_classes": num_classes}


# NOTE: Instead of using train directly you can write here your custom traiing code
# You can take inspiration from train to see how the checkpoints are saved

# NOTE: You can also directly copy all the code from train a cell above this and modify it inside the notebook
# similarly to what was done for the Flow Matching Class

# But if you use it directly you can directly use model you have from the train


In [33]:
flow_model = train(
    num_epochs=epochs,
    device= device,
    dataloader = train_loader,
    val_loader = val_loader,
    use_wandb = True,
    checkpoint_path = checkpoint_path,
    model_kwargs = model_kwargs,
    method = method
)



WANDB_API_KEY environment variable not set. Please enter your WandB API key: 2b5251f5decb927fd6bf99b552a2fa2175bd8f98


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc



Epoch 1/5


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validating:   0%|          | 0/79 [00:00<?, ?it/s]

Train Loss: 0.0006 | Val Loss: 0.0002 | LR: 0.000905
New best model saved! Epoch 1, Val Loss: 0.0002

Epoch 2/5


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validating:   0%|          | 0/79 [00:00<?, ?it/s]

Train Loss: 0.0002 | Val Loss: 0.0002 | LR: 0.000658
New best model saved! Epoch 2, Val Loss: 0.0002

Epoch 3/5


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validating:   0%|          | 0/79 [00:00<?, ?it/s]

Train Loss: 0.0002 | Val Loss: 0.0002 | LR: 0.000352
New best model saved! Epoch 3, Val Loss: 0.0002

Epoch 4/5


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validating:   0%|          | 0/79 [00:00<?, ?it/s]

Train Loss: 0.0002 | Val Loss: 0.0001 | LR: 0.000105
New best model saved! Epoch 4, Val Loss: 0.0001

Epoch 5/5


Training:   0%|          | 0/469 [00:00<?, ?it/s]

Validating:   0%|          | 0/79 [00:00<?, ?it/s]

Train Loss: 0.0001 | Val Loss: 0.0001 | LR: 0.000010
New best model saved! Epoch 5, Val Loss: 0.0001


0,1
epoch,▁▃▅▆█
learning_rate,█▆▄▂▁
train/loss_epoch,█▂▁▁▁
train/loss_step,█▇▆▄▄▄▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁
val/loss,█▃▂▁▁

0,1
best_train_loss,0.00015
best_val_epoch,5.0
best_val_loss,0.00014
best_val_lr,1e-05
epoch,5.0
learning_rate,1e-05
train/loss_epoch,0.00015
train/loss_step,0.00012
val/loss,0.00014



Training complete. Best validation loss: 0.0001


## 💡 Image Generation

#### 🛠️ Configuration Parameters

In [53]:
def log_sample_grid(
    model, diffusion, num_samples=5, num_timesteps=6, max_timesteps=1000
):
    """Generate and log sample grid showing diffusion process."""
    t_sample_times = torch.linspace(
        max_timesteps, 0, steps=num_timesteps, dtype=torch.int32
    ).tolist()

    # Create batch of conditioning labels [0, 1, 2, ..., num_samples-1]
    y = torch.arange(num_samples, device=diffusion.device)

    # Sample all at once in batch
    all_samples_grouped = diffusion.sample(
        model,
        t_sample_times=t_sample_times,
        log_intermediate=True,
        y=y,
    )
    # all_samples_grouped shape: (T, B, C, H, W)

    # Rearrange to (B, T, C, H, W)
    stacked = torch.stack(all_samples_grouped)  # (T, B, C, H, W)
    permuted = stacked.permute(1, 0, 2, 3, 4)  # (B, T, C, H, W)

    # For each sample in batch, create a horizontal grid of its timesteps
    rows = []
    for sample_idx in range(num_samples):
        sample_images = permuted[sample_idx]  # (T, C, H, W)
        row_grid = vutils.make_grid(
            sample_images, nrow=num_timesteps, normalize=True, value_range=(-1, 1)
        )
        rows.append(row_grid)

    # Stack rows vertically
    final_grid = torch.cat(rows, dim=1)

    # Log to wandb
    wandb.log(
        {
            "sampling_intermediate_steps": wandb.Image(
                final_grid.permute(1, 2, 0).cpu().numpy()
            )
        }
    )


In [51]:
n_samples = 5
save_dir = "samples"
max_steps = 1000
model_name = "unet"
num_timesteps = 6

ckpt_path = "checkpoints/best_model.pth"  # or use your last checkpoint

In [52]:
from PIL import Image
import torchvision.utils as vutils

# flow = FlowMatching(img_size=28, device = device)

# 💫 Create diffusion sampler
flow = FlowMatching(img_size=28, device=device)

log_sample_grid(model= flow_model, diffusion=flow, num_samples=n_samples, num_timesteps=num_timesteps, max_timesteps=max_steps)
# plot_image_grid(
#     flow_model,
#     flow,
#     n=n_samples,
#     max_steps=max_steps,
#     save_dir=save_dir,
#     device=device,
#     num_classes=num_classes,
# )

# out_path = os.path.join(save_dir, "all_samples_grid.png")
# display(Image.open(out_path))



# NOTE: Currently assumes 10 classes are hardcoded
#see in to plot_image to make it for flow amtching

# plot_image_grid(
#     model,
#     flow,
#     n=n_samples,
#     max_steps=max_steps,
#     save_dir=save_dir,
#     device=device
#     num_classes=num_classes
# )

# But if you use it directly you can directly use model you have from the train

# model = train(
# num_epochs=epochs,
# device=device,
# dataloader=train_loader,
# val_loader=val_loader,
# learning_rate=learning_rate,
# use_wandb=True,
# checkpoint_path=checkpoint,
# model_name=model_name,
# model_kwargs=model_kwargs,
# method=method,
# )


AttributeError: module 'src.utils.wandb' has no attribute 'log'