# 🧙‍♂️ Sample Generation with a Pretrained Model + Last-Layer Laplace Approximation (LLLA)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Jac-Zac/PML_DL_Final_Project/blob/master/notebooks/notebook_llla.ipynb)

### Initial setup ⚙️

In [1]:
import os

repo_dir = "PML_DL_Final_Project"  # directory name created by cloning the project


def in_colab():
    """Return True when the notebook is running inside Google Colab.

    Detection relies on the ``COLAB_GPU`` environment variable, which Colab
    sets for its runtimes.
    """
    return os.environ.get("COLAB_GPU") is not None

if in_colab():
    # In Colab: clone repo if not present
    if not os.path.exists(repo_dir):
        !git clone https://github.com/Jac-Zac/PML_DL_Final_Project.git
        os.chdir(repo_dir)
        # Install requirements quietly
        !pip install -r requirements.txt -q
    else:
        os.chdir(repo_dir)
        print(f"Repository '{repo_dir}' already exists. Skipping clone.")
else:
    # Local: assume repo is already cloned
    print(f"Local Run, make sure you are inside '{repo_dir}' with the latest updates (git pull).")
    print(f"Moving to root directory to have correct access to all of the files")
    os.chdir("..")

Local Run, make sure you are inside 'PML_DL_Final_Project' with the latest updates (git pull).
Moving to root directory to have correct access to all of the files


### 📦 Imports

In [3]:
import torch

from src.utils.data import get_dataloaders
from src.models.diffusion import Diffusion
from src.utils.plots import plot_image_grid
from src.utils.environment import get_device, set_seed, load_pretrained_model

# Since on a notebook we can have nicer bars
import tqdm.notebook as tqdm

### 🧪 Setup: Seed and Device

In [4]:
# Create the checkpoint folder up front so later cells can write into it,
# then seed the RNGs through the project's `set_seed` helper and pick the
# compute device with `get_device`.
os.makedirs("checkpoints", exist_ok=True)
seed = 1337
set_seed(seed)
device = get_device()

## 💡 Image Generation

#### 🛠️ Configuration Parameters

In [5]:
# --- Sampling configuration ---
n_samples = 5         # number of images to generate
max_steps = 1000      # step budget (only referenced by the commented-out plotting call below)
method = "diffusion"  # or "flow"
model_name = "unet"

# --- Output / checkpoint locations ---
save_dir = "samples"
# Local checkpoint path (or use your last checkpoint). The model-loading cell
# below overrides this with a W&B artifact reference.
ckpt_path = "checkpoints/best_model.pth"

#### 🔌 Load Pretrained Model

In [None]:
import wandb
# Model identifier as understood by the project's `get_model` factory.
model_name = "unet"
# W&B artifact reference for the best checkpoint
# (presumably entity/project/artifact:version — confirm in W&B).
ckpt_path = "jac-zac/diffusion-project/best-model:v22"

# Number of class labels (e.g. digits 0–9 for MNIST).
num_classes = 10

# Architecture kwargs; these must match the configuration used for training.
model_kwargs = {
    "num_classes": num_classes,
    "time_emb_dim": 128,
}

# Load the pretrained MAP model from the best checkpoint. Because `ckpt_path`
# is an artifact reference and `use_wandb=True`, the checkpoint is looked up
# directly through the W&B API rather than on local disk.
model = load_pretrained_model(
    model_name=model_name,
    ckpt_path=ckpt_path,
    device=device,
    model_kwargs=model_kwargs,
    use_wandb=True,
)

#### Define the `QUDiffusion` Class

In [12]:

from typing import List, Optional, Tuple

import torch
from torch import Tensor, nn


class QUDiffusion(Diffusion):
    """
    Diffusion sampler with per-sample uncertainty tracking.

    Extends the base ``Diffusion`` class to drive a model whose forward pass
    returns a ``(noise_prediction, noise_variance)`` pair, such as the Laplace
    approximation wrapper used in this notebook. Alongside the generated
    samples it accumulates a running variance estimate over the reverse chain.
    """

    def __init__(self, *args, **kwargs):
        # Same constructor as the base class; kept as an explicit override point.
        super().__init__(*args, **kwargs)

    def sample_from_gaussian(self, mean: Tensor, var: Tensor) -> Tensor:
        """Draw ``mean + sqrt(var) * eps`` with ``eps ~ N(0, I)``.

        ``var`` is clamped to at least ``1e-8`` before the square root so tiny
        or slightly negative variances cannot produce NaNs.
        """
        std = torch.sqrt(torch.clamp(var, min=1e-8))
        return mean + std * torch.randn_like(mean)

    def perform_training_step(
        self,
        model: nn.Module,
        x_0: Tensor,
        y: Optional[Tensor] = None,
        t: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute the noise-prediction loss for one training batch.

        Args:
            model: Noise-prediction network. If it returns a ``(mean, var)``
                tuple (uncertainty-aware wrapper), only the mean is trained on.
            x_0: Clean input batch.
            y: Optional conditioning labels.
            t: Optional timesteps; drawn with ``self._sample_timesteps`` when
                omitted.

        Returns:
            The value of ``self.loss_simple(noise, noise_pred)``.
        """
        x_0 = x_0.to(self.device)
        if t is None:
            t = self._sample_timesteps(x_0.size(0))
        x_t, noise = self._sample_q(x_0, t)

        noise_pred = model(x_t, t, y=y)
        if isinstance(noise_pred, tuple):
            # Uncertainty-aware models return (mean, variance); use the mean.
            noise_pred = noise_pred[0]

        return self.loss_simple(noise, noise_pred)

    @torch.no_grad()
    def sample_step(
        self,
        model: nn.Module,
        x_t: Tensor,
        t: Tensor,
        y: Optional[Tensor] = None,
    ) -> Tensor:
        """Run one reverse-diffusion step and return ``x_{t-1}``.

        Overrides the base step so models returning ``(mean, variance)`` can be
        used; the variance is discarded here. ``sample_with_uncertainty`` is the
        loop that keeps it.
        """
        return self._sample_step_with_uncertainty(model, x_t, t, y)

    def _sample_step_with_uncertainty(
        self,
        model: nn.Module,
        x_t: Tensor,
        t: Tensor,
        y: Optional[Tensor] = None,
    ) -> Tensor:
        """Reverse step for a ``(mean, variance)`` model; returns only ``x_{t-1}``."""
        x_prev, _ = self._denoise_step(model, x_t, t, y=y)
        return x_prev

    def _denoise_step(
        self,
        model: nn.Module,
        x_t: Tensor,
        t: Tensor,
        y: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run one reverse step; return ``(x_{t-1}, predicted noise variance)``.

        Uses the DDPM update
        ``x_{t-1} = (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps_hat) / sqrt(a_t) + sqrt(beta_t) * z``
        with ``z ~ N(0, I)``. Both outputs come from a single model evaluation
        at ``(x_t, t)``, so the variance matches the input that produced the mean.
        """
        beta_t = self.beta[t].view(-1, 1, 1, 1)
        alpha_t = self.alpha[t].view(-1, 1, 1, 1)
        alpha_bar_t = self.alpha_bar[t].view(-1, 1, 1, 1)

        # Noise prediction together with its (approximate posterior) variance.
        noise_pred, noise_var = model(x_t, t, y=y)

        coef1 = 1.0 / alpha_t.sqrt()
        coef2 = (1.0 - alpha_t) / (1.0 - alpha_bar_t).sqrt()
        x_prev_mean = coef1 * (x_t - coef2 * noise_pred)

        # NOTE(review): the sampling loop runs i = noise_steps-1 .. 0, so with
        # `> 1` both the i == 1 and i == 0 steps are noise-free. Confirm this
        # matches the timestep convention used by the base Diffusion class.
        if t[0] > 1:
            scheduled_noise = torch.randn_like(x_t) * beta_t.sqrt()
        else:
            scheduled_noise = torch.zeros_like(x_t)

        return x_prev_mean + scheduled_noise, noise_var

    @torch.no_grad()
    def sample_with_uncertainty(
        self,
        model: nn.Module,
        t_sample_times: Optional[List[int]] = None,
        channels: int = 1,
        log_intermediate: bool = False,
        y: Optional[Tensor] = None,
        uncertainty_schedule: Optional[List[bool]] = None,
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Generate samples while accumulating a per-sample variance estimate.

        At each reverse step, the scheduled noise variance ``beta_t`` is added
        to a running per-pixel variance. So is the model's predicted noise
        variance, when the schedule enables it for that timestep. The variance
        comes from the same forward pass that produces the denoising mean, so
        each step costs one model evaluation.

        Args:
            model: Model returning ``(noise_prediction, noise_variance)``,
                e.g. the Laplace approximation wrapper.
            t_sample_times: Timesteps at which to log intermediate samples.
            channels: Number of image channels.
            log_intermediate: Whether to log intermediate samples.
            y: Optional conditioning labels; their count sets the batch size
                (1 when ``None``).
            uncertainty_schedule: Optional per-timestep flags, indexed by
                timestep, selecting where the model variance is accumulated.
                Missing entries count as ``False``. Defaults to every timestep.

        Returns:
            intermediates: Logged intermediate samples followed by the final
                sample, each passed through ``transform_sampled_image``.
            uncertainties: One CPU tensor per reverse step, ordered from the
                noisiest timestep down to 0. Each holds the summed running
                variance of every sample in the batch.

        Note:
            NOTE(review): the propagation is a simple additive heuristic. A
            full linear propagation would scale the running variance by
            ``coef1**2`` and the model variance by ``(coef1 * coef2)**2``.
        """
        was_training = model.training
        model.eval()
        batch_size = 1 if y is None else y.size(0)

        if uncertainty_schedule is None:
            schedule = [True] * self.noise_steps
        else:
            schedule = list(uncertainty_schedule)
            # Pad so every timestep has an entry (no-op if already long enough).
            schedule.extend([False] * (self.noise_steps - len(schedule)))

        log_times = set(t_sample_times) if (log_intermediate and t_sample_times) else set()

        x_t = torch.randn(
            batch_size, channels, self.img_size, self.img_size, device=self.device
        )
        var_x_t = torch.zeros_like(x_t)

        intermediates: List[Tensor] = []
        uncertainties: List[Tensor] = []

        try:
            for i in reversed(range(self.noise_steps)):
                t = torch.full((batch_size,), i, device=self.device, dtype=torch.long)

                x_t, noise_var = self._denoise_step(model, x_t, t, y=y)

                beta_t = self.beta[t].view(-1, 1, 1, 1)
                if schedule[i] and noise_var is not None:
                    var_x_t = var_x_t + noise_var + beta_t
                else:
                    # Only the scheduled noise contributes at this step.
                    var_x_t = var_x_t + beta_t

                uncertainties.append(var_x_t.sum(dim=(1, 2, 3)).cpu())

                if i in log_times:
                    intermediates.append(self.transform_sampled_image(x_t.clone()))

            intermediates.append(self.transform_sampled_image(x_t))
        finally:
            # Restore whatever mode the caller had the model in, even on error.
            model.train(was_training)

        return intermediates, uncertainties

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        t_sample_times: Optional[List[int]] = None,
        channels: int = 1,
        log_intermediate: bool = False,
        y: Optional[Tensor] = None,
        uncertainty_schedule: Optional[List[bool]] = None,
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Sample with uncertainty tracking; alias of ``sample_with_uncertainty``.

        Always returns ``(intermediates, uncertainties)``. Callers written for a
        sampler that returns only the samples must unpack the tuple.
        """
        return self.sample_with_uncertainty(
            model=model,
            t_sample_times=t_sample_times,
            channels=channels,
            log_intermediate=log_intermediate,
            y=y,
            uncertainty_schedule=uncertainty_schedule,
        )

NameError: name 'artifact_dir' is not defined

### 💪 Fit the Last-Layer Laplace Approximation

In [7]:
from src.models.llla_model import LaplaceApproxModel

batch_size = 128

# Prepare data loaders for the Laplace fit (only the training split is used).
train_loader, _ = get_dataloaders(batch_size=batch_size)

# Wrap the pretrained diffusion network (`model`, loaded in the
# "Load Pretrained Model" cell) in the last-layer Laplace approximation.
# Constructing the wrapper also fits it by default.
# NOTE: LaplaceApproxModel depends on the `laplace` package; if the import
# fails with ModuleNotFoundError, install it with `pip install laplace-torch`.
laplace_model = LaplaceApproxModel(model, train_loader, args=None, config=None)



ModuleNotFoundError: No module named 'laplace'

### 💨 Initialize the Diffusion Process and Generate Samples

In [None]:

# Initialize uncertainty-aware diffusion (same constructor as the base class).
diffusion = QUDiffusion(img_size=28, device=device)

# `sample` and `sample_with_uncertainty` are equivalent for QUDiffusion and both
# return (intermediates, uncertainties). Run the reverse chain only once.
# channels=1: the model is a 28x28 MNIST-digit model (single-channel images).
samples, uncertainties = diffusion.sample_with_uncertainty(
    model=laplace_model,
    channels=1,
)

# TODO: add a plotting helper that also visualizes the uncertainty.
# The planned call was roughly:
# plot_image_grid(
#     model,
#     method_instance,
#     num_intermediate=5,
#     n=args.n,
#     max_steps=args.max_steps,
#     save_dir=args.save_dir,
#     device=device,
#     num_classes=num_classes,
# )
