# 🧙‍♂️ Sample Generation with Pretrained Model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/master/notebook.ipynb)

### Initial setup ⚙️

In [8]:
import os

repo_dir = "PML_DL_Final_Project"

if not os.path.exists(repo_dir):
    !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
else:
    print(f"Repository '{repo_dir}' already exists. Skipping clone.")

Cloning into 'PML_DL_Final_Project'...
remote: Enumerating objects: 298, done.
remote: Counting objects: 100% (298/298), done.
remote: Compressing objects: 100% (209/209), done.
remote: Total 298 (delta 160), reused 218 (delta 86), pack-reused 0 (from 0)
Receiving objects: 100% (298/298), 80.87 KiB | 637.00 KiB/s, done.
Resolving deltas: 100% (160/160), done.


In [9]:
if os.path.isdir(repo_dir):
    %cd $repo_dir
    !pip install dotenv -q
else:
    print(f"Directory '{repo_dir}' not found. Please clone the repository first.")

/content/PML_DL_Final_Project/PML_DL_Final_Project


### 📦 Imports

In [10]:
import torch
import numpy as np

from src.train.train import train
from src.utils.data import get_dataloaders
from src.models.diffusion import Diffusion
from src.utils.plots import plot_image_grid
from src.utils.environment import get_device, set_seed, load_pretrained_model

# In a notebook, use tqdm's notebook front end for nicer progress bars.
# The package is spelled `tqdm`; the previous `tdqm` spelling fails at import
# time with ModuleNotFoundError.
import tqdm.notebook as tqdm

### 🛠️ Configuration Parameters

In [11]:
# What to train. These values are forwarded to train() in the training cell.
model_name = "unet"
method = "diffusion"  # alternative: "flow"
checkpoint = None  # passed as checkpoint_path, e.g. "checkpoints/last.ckpt"

# Optimisation settings.
epochs, batch_size = 10, 128
learning_rate = 1e-3

# Seed handed to set_seed() in the setup cell.
seed = 1337

### 🧪 Setup: Seed and Device

In [12]:
# Create the local "checkpoints" folder up front (no error if it already
# exists) — presumably where training stores its checkpoints.
os.makedirs("checkpoints", exist_ok=True)

# Apply the configured seed, then let the project helper choose the device
# that the later cells pass to training, model loading and sampling.
set_seed(seed)
device = get_device()

## 🧠 Model Training

#### 📥 Data Loading

In [13]:
# Build the training and validation loaders with the configured batch size.
# NOTE(review): the download progress bars in this cell's output suggest the
# dataset is fetched on first use — confirm in src.utils.data.
train_loader, val_loader = get_dataloaders(batch_size=batch_size)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 501kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.48MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.9MB/s]


#### Training

In [None]:
# NOTE: the number of classes is currently hard-coded to 10.
num_classes = 10
model_kwargs = dict(num_classes=num_classes)

# Train with the settings from the configuration cell and keep the returned
# model for the generation cells below.
# NOTE(review): the loading cell passes time_emb_dim=128 and says it must
# match training; here it is left to the default — confirm they agree.
model = train(
    model_name=model_name,
    model_kwargs=model_kwargs,
    method=method,
    dataloader=train_loader,
    val_loader=val_loader,
    num_epochs=epochs,
    learning_rate=learning_rate,
    checkpoint_path=checkpoint,
    device=device,
    use_wandb=True,
)

## 💡 Image Generation

#### 🛠️ Configuration Parameters

In [None]:
# Which model to load, and from where.
model_name = "unet"
ckpt_path = "checkpoints/best_model.pth"  # or point this at your last checkpoint

# Sampling settings: number of images to generate and the starting timestep
# used to build the list of intermediate steps.
n_samples, max_steps = 5, 1000

# Folder that receives the rendered image grid.
save_dir = "samples"

#### 🔌 Load Pretrained Model

In [None]:
# Total number of class labels (e.g. the digits 0–9 for MNIST).
num_classes = 10

# time_emb_dim must match the value the checkpoint was trained with.
model_kwargs = dict(num_classes=num_classes, time_emb_dim=128)

# Rebuild the architecture named by model_name and load its weights from
# ckpt_path onto the selected device.
model = load_pretrained_model(
    ckpt_path=ckpt_path,
    model_name=model_name,
    model_kwargs=model_kwargs,
    device=device,
)

#### 💨 Initialize Diffusion Process

In [None]:
# Timesteps at which to capture the sampling progression:
# num_intermediate + 1 evenly spaced integers running from max_steps down to 0.
num_intermediate = 5
intermediate_steps = np.linspace(
    start=max_steps,
    stop=0,
    num=num_intermediate + 1,
    dtype=int,
).tolist()

# Diffusion sampler configured with img_size=28 on the selected device.
diffusion = Diffusion(device=device, img_size=28)


#### 🖼️ Generate Samples

In [None]:
# One label per sample: 0, 1, 2, ..., wrapping around at num_classes.
y = (torch.arange(n_samples) % num_classes).to(device)

# Class-conditional sampling. The sampler is given the intermediate timesteps
# and asked to log them; the display cell stacks the result, so it is
# presumably a sequence with one batch tensor per logged step.
all_samples_grouped = diffusion.sample(
    model=model,
    t_sample_times=intermediate_steps,
    log_intermediate=True,
    y=y,  # labels for conditional generation
)

# The check-mark emoji was stored as mojibake; restored so the message prints
# as intended.
print(f"✅ Generated {n_samples} samples with labels: {y.tolist()}")

#### 🧱 Reshape & Display

In [None]:
# Arrange every logged snapshot into one image grid, save it, and show it
# inline (works in Colab or Jupyter).
from IPython.display import display
from PIL import Image

# Stack the per-timestep batches, then move the batch axis to the front and
# merge the first two axes so each sample's snapshots sit next to each other.
# Shape notes are carried over from the original comments.
stacked = torch.stack(all_samples_grouped)  # (T, B, C, H, W)
permuted = stacked.permute(1, 0, 2, 3, 4)  # (B, T, C, H, W)
flat_samples = permuted.flatten(start_dim=0, end_dim=1)  # (B*T, C, H, W)

# Write the grid image into save_dir, creating the folder if needed.
os.makedirs(save_dir, exist_ok=True)
out_path = os.path.join(save_dir, "all_samples_grid.png")
plot_image_grid(
    flat_samples,
    out_path,
    num_samples=n_samples,
    timesteps=intermediate_steps,
)

display(Image.open(out_path))