In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import gc
import copy

# ---- Hyperparameters --------------------------------------------------------
# Data: CIFAR-10 images are 3 x 32 x 32.
IMAGE_SIZE = 32
CHANNELS = 3

# Optimisation
NUM_EPOCHS = 200
BATCH_SIZE = 512
LEARNING_RATE = 2e-4

# Diffusion schedule: linearly spaced betas over TIMESTEPS steps.
TIMESTEPS = 1000
BETA_START = 1e-4
BETA_END = 0.02

# Model
BASE_CHANNELS = 64  # U-Net width multiplier (raised from 32 for more capacity)
# EMA averaging horizon is roughly 1 / (1 - EMA_DECAY) = 10,000 optimizer steps.
# NOTE(review): at 98 steps/epoch x 200 epochs (~19.6k steps) that is about half
# the whole run, so the final EMA weights still carry ~0.9999**19600 ~= 14% of the
# initial random weights; consider a smaller decay or an EMA warmup.
EMA_DECAY = 0.9999

# Runtime and checkpointing
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_INTERVAL = 1  # epochs between regular checkpoints
CHECKPOINT_PATH = "diffusion_checkpoint.pt"
BEST_MODEL_PATH = "diffusion_best_model.pt"

# Helper functions
def extract(a, t, shape):
    """Pick ``a[t[i]]`` for every batch item, shaped to broadcast against ``shape``.

    Args:
        a: schedule tensor indexed along its last axis by timestep (1-D in this file).
        t: long tensor of timesteps, one per batch item.
        shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(t.shape[0], 1, ..., 1)`` with ``len(shape)`` dimensions.
    """
    picked = a.gather(-1, t)
    singleton_dims = [1] * (len(shape) - 1)
    return picked.reshape(t.shape[0], *singleton_dims)

# Memory management function
def clear_memory():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory.

    Collecting garbage first drops unreachable Python objects (and the tensors they
    hold) so that ``torch.cuda.empty_cache`` can hand more unoccupied cached blocks
    back to the driver. Returns nothing.
    """
    gc.collect()
    torch.cuda.empty_cache()

# Create cosine learning rate scheduler
def get_cosine_schedule(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over ``num_warmup_steps``. It then
    follows ``0.5 * (1 + cos(2*pi*num_cycles*progress))``, clipped at 0, where
    ``progress`` goes from 0 to 1 over the remaining steps. With the default
    ``num_cycles=0.5`` this is a single half-cosine from 1 down to 0.

    Args:
        optimizer: optimizer whose learning rates are scaled.
        num_warmup_steps: length of the linear warmup, in scheduler steps.
        num_training_steps: total number of scheduler steps.
        num_cycles: number of cosine periods over the decay phase.
        last_epoch: forwarded to ``LambdaLR`` (``-1`` starts fresh).
    """
    # Both denominators are fixed for the schedule's lifetime; max(1, ...) avoids /0.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

# Self-attention module for diffusion model
class SelfAttention(nn.Module):
    """Pre-norm multi-head self-attention over the spatial positions of a feature map.

    Input and output are ``(B, C, H, W)``. The ``H*W`` positions form the token
    sequence, with ``C`` as the embedding width. Attention and a small feed-forward
    MLP are each applied with a residual connection.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # 4 heads; nn.MultiheadAttention requires channels divisible by 4.
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        # Pre-norm feed-forward block, added residually after attention.
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = x.flatten(-2, -1).transpose(1, 2)  # (B, C, H*W) -> (B, H*W, C)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        hidden = attended + tokens  # residual around attention
        hidden = hidden + self.ff_self(hidden)  # residual around feed-forward
        # (B, H*W, C) -> (B, C, H, W)
        return hidden.transpose(1, 2).reshape(-1, self.channels, height, width)

# Create noise scheduler
class NoiseScheduler:
    """Linear-beta DDPM noise schedule with forward-process sampling and training loss.

    All schedule tensors have shape ``(timesteps,)`` and are moved to ``device``.
    """

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
        """Precompute beta_t, alpha_t, cumulative products and derived constants."""
        self.timesteps = timesteps

        # beta_t: forward-process variances, linearly spaced.
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)

        # alpha_t = 1 - beta_t, and alpha_bar_t = prod_{s<=t} alpha_s.
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with 1.0 prepended for t = 0.
        leading_one = torch.tensor([1.0]).to(device)
        self.alphas_cumprod_prev = torch.cat([leading_one, self.alphas_cumprod[:-1]])

        # Coefficients of x_0 and of the noise in q(x_t | x_0).
        self.sqrt_alphas_cumprod = self.alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod).sqrt()

        # Variance of the posterior q(x_{t-1} | x_t, x_0); exactly 0 at t = 0.
        self.posterior_variance = (
            self.betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        )

    def _coef_at(self, schedule, t, ndim):
        """Return ``schedule[t]`` per batch item, shaped ``(B, 1, ..., 1)`` with ``ndim`` dims."""
        return schedule.gather(-1, t).reshape(t.shape[0], *([1] * (ndim - 1)))

    def q_sample(self, x_0, t, noise=None):
        """Draw x_t ~ q(x_t | x_0) = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps (DDPM eq. 4).

        Args:
            x_0: clean batch.
            t: long tensor of timesteps, one per batch item.
            noise: optional eps with the same shape as ``x_0``; sampled if omitted.
        """
        noise = torch.randn_like(x_0) if noise is None else noise
        signal_coef = self._coef_at(self.sqrt_alphas_cumprod, t, x_0.dim())
        noise_coef = self._coef_at(self.sqrt_one_minus_alphas_cumprod, t, x_0.dim())
        return signal_coef * x_0 + noise_coef * noise

    def loss_fn(self, model, x_0, t, noise=None):
        """Simplified DDPM objective: MSE between the true and the model-predicted noise.

        ``model`` is called as ``model(x_t, t)`` and must return a tensor shaped like ``x_0``.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)
        return F.mse_loss(model(x_t, t), noise)

# Time embedding module
class SinusoidalTimeEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding of (possibly fractional) timesteps.

    With ``half = dim // 2`` frequencies ``f_i = 10000 ** (-i / max(half - 1, 1))``,
    the output is ``[sin(t*f_0), ..., sin(t*f_{half-1}), cos(t*f_0), ..., cos(t*f_{half-1})]``
    along a new trailing axis. A ``(B,)`` input gives ``(B, dim)``; a ``(B, 1)`` input
    gives ``(B, 1, dim)``.

    Args:
        dim: embedding width. An odd ``dim`` is zero-padded by one trailing column so
            the output really is ``dim`` wide. ``dim`` of 2 or 3 (a single frequency)
            no longer divides by zero.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Clamp the denominator: half_dim == 1 would otherwise raise ZeroDivisionError.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only fill 2 * (dim // 2) columns; pad the last one.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

# Improved U-Net model with attention
class ImprovedUNet(nn.Module):
    """Three-level U-Net noise predictor with self-attention at 8x8, at 4x4 and in the bottleneck.

    Resolution comments assume 32x32 inputs (CIFAR-10):
    32 -> 16 -> 8 -> 4 (bottleneck) -> 8 -> 16 -> 32.
    At five points a per-channel bias, linearly projected from a shared time
    embedding of width ``4 * base_channels``, is added to the feature map.

    Args:
        in_channels: channels of the noisy input image.
        out_channels: channels of the predicted noise.
        base_channels: width multiplier ``bc``. Every GroupNorm uses 8 groups, so
            ``bc`` must be divisible by 8.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        bc = base_channels
        time_dim = bc * 4

        # NOTE: submodules are registered in exactly the original order, so parameter
        # order (which optimizer state in checkpoints depends on) and seeded
        # initialisation are unchanged.
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(bc),
            nn.Linear(bc, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )
        self.init_conv = nn.Conv2d(in_channels, bc, 3, padding=1)

        # Encoder: widen channels, then halve the resolution with a stride-2 conv.
        self.down1 = self._down_block(bc, bc * 2)      # 32 -> 16
        self.down2 = self._down_block(bc * 2, bc * 4)  # 16 -> 8
        self.down3 = self._down_block(bc * 4, bc * 8)  # 8 -> 4

        self.attn1 = SelfAttention(bc * 4)  # applied at 8x8
        self.attn2 = SelfAttention(bc * 8)  # applied at 4x4

        wide = bc * 8
        self.middle = nn.Sequential(
            nn.Conv2d(wide, wide, 3, padding=1),
            nn.GroupNorm(8, wide),
            nn.GELU(),
            SelfAttention(wide),
            nn.Conv2d(wide, wide, 3, padding=1),
            nn.GroupNorm(8, wide),
            nn.GELU(),
            nn.Conv2d(wide, wide, 3, padding=1),
        )

        # One projection per injection point: time_dim -> that stage's channel count.
        self.time_proj1 = nn.Linear(time_dim, bc)
        self.time_proj2 = nn.Linear(time_dim, bc * 2)
        self.time_proj3 = nn.Linear(time_dim, bc * 4)
        self.time_proj4 = nn.Linear(time_dim, bc * 8)
        self.time_proj5 = nn.Linear(time_dim, bc * 8)

        # Decoder: input width = upsampled path + skip connection.
        self.up0 = self._up_block(bc * 8 + bc * 8, bc * 8)  # cat(middle, x3): 4 -> 8
        self.up1 = self._up_block(bc * 8 + bc * 4, bc * 4)  # cat(up0, x2):    8 -> 16
        self.up2 = self._up_block(bc * 4 + bc * 2, bc * 2)  # cat(up1, x1):   16 -> 32

        self.final = nn.Sequential(
            nn.Conv2d(bc * 2 + bc, bc, 3, padding=1),  # cat(up2, x0)
            nn.GroupNorm(8, bc),
            nn.GELU(),
            nn.Conv2d(bc, out_channels, 3, padding=1),
        )

    @staticmethod
    def _down_block(cin, cout):
        """Conv3x3 -> GroupNorm -> GELU -> stride-2 Conv4x4 (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.GroupNorm(8, cout),
            nn.GELU(),
            nn.Conv2d(cout, cout, 4, stride=2, padding=1),
        )

    @staticmethod
    def _up_block(cin, cout):
        """Stride-2 ConvTranspose4x4 (doubles H and W) -> GroupNorm -> GELU."""
        return nn.Sequential(
            nn.ConvTranspose2d(cin, cout, 4, stride=2, padding=1),
            nn.GroupNorm(8, cout),
            nn.GELU(),
        )

    @staticmethod
    def _add_time(features, proj, t_emb):
        """Add ``proj(t_emb)`` to ``features`` as a per-channel bias broadcast over H and W."""
        bias = proj(t_emb)
        return features + bias.view(bias.shape[0], bias.shape[1], 1, 1)

    def forward(self, x, t):
        # Timesteps -> float with a trailing axis -> MLP -> drop that axis -> (B, 4*bc).
        t_emb = self.time_mlp(t.unsqueeze(-1).type(torch.float)).squeeze(1)

        x0 = self._add_time(self.init_conv(x), self.time_proj1, t_emb)                # bc
        x1 = self._add_time(self.down1(x0), self.time_proj2, t_emb)                   # 2bc
        x2 = self.attn1(self._add_time(self.down2(x1), self.time_proj3, t_emb))      # 4bc, 8x8
        x3 = self.attn2(self._add_time(self.down3(x2), self.time_proj4, t_emb))      # 8bc, 4x4
        xm = self._add_time(self.middle(x3), self.time_proj5, t_emb)                  # 8bc, 4x4

        up = self.up0(torch.cat([xm, x3], dim=1))  # 16bc -> 8bc
        up = self.up1(torch.cat([up, x2], dim=1))  # 12bc -> 4bc
        up = self.up2(torch.cat([up, x1], dim=1))  # 6bc  -> 2bc
        return self.final(torch.cat([up, x0], dim=1))  # 3bc -> out_channels

# EMA (Exponential Moving Average) for model weights
class EMAModel:
    """Keep an exponential moving average (EMA) copy of a model's weights.

    ``self.model`` is a frozen deep copy in eval mode. Each :meth:`update` applies
    ``ema = d * ema + (1 - d) * param`` to every parameter and copies buffers verbatim.

    Args:
        model: the model being trained; deep-copied once at construction.
        decay: EMA decay ``d``. The averaging horizon is roughly ``1 / (1 - decay)``
            updates (10,000 for 0.9999). In runs that are short relative to that
            horizon, the copy keeps a sizable share of the initial weights.
        warmup: if True, each update uses ``min(decay, (1 + n) / (10 + n))``, where
            ``n`` is the number of earlier updates, so early updates track the model
            closely. The default False keeps the constant-decay behavior.
    """

    def __init__(self, model, decay=0.9999, warmup=False):
        self.model = copy.deepcopy(model)
        self.model.eval()
        self.decay = decay
        self.warmup = warmup
        # Count of update() calls on this object. Only the EMA weights are stored by
        # save_checkpoint, so this restarts at 0 after a resume.
        self.num_updates = 0
        self.model.requires_grad_(False)

    def current_decay(self):
        """Return the decay that the next :meth:`update` call will use."""
        if not self.warmup:
            return self.decay
        n = self.num_updates
        return min(self.decay, (1 + n) / (10 + n))

    def update(self, model):
        """Blend ``model``'s parameters into the EMA copy and copy its buffers."""
        decay = self.current_decay()
        with torch.no_grad():
            for ema_param, param in zip(self.model.parameters(), model.parameters()):
                ema_param.mul_(decay).add_(param, alpha=1 - decay)
            # Buffers (e.g. norm statistics) are copied, not averaged.
            for ema_buffer, buffer in zip(self.model.buffers(), model.buffers()):
                ema_buffer.copy_(buffer)
        self.num_updates += 1

# Training function with memory optimization
def train_epoch(model, ema_model, scheduler, dataloader, optimizer, epoch, lr_scheduler):
    """Run one epoch of noise-prediction (epsilon) DDPM training.

    For every batch: sample ``t`` uniformly from ``[0, scheduler.timesteps)``,
    compute ``scheduler.loss_fn``, take an optimizer step, update the EMA copy, and
    step the LR schedule once.

    Returns:
        Mean per-batch loss over the epoch, or ``nan`` if the dataloader yielded no
        batches. ``nan`` never compares below a best loss, so an empty epoch cannot
        be recorded as "best".
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    # Sample from the scheduler's own length so t always indexes its tables,
    # even if it was built with a different length than the global TIMESTEPS.
    num_timesteps = scheduler.timesteps

    with tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}") as pbar:
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(DEVICE)
            batch_size = images.shape[0]

            # One random timestep per image.
            t = torch.randint(0, num_timesteps, (batch_size,), device=DEVICE, dtype=torch.long)

            optimizer.zero_grad()
            loss = scheduler.loss_fn(model, images, t)
            loss.backward()
            optimizer.step()

            ema_model.update(model)

            # The LR schedule is defined per optimizer step, so advance it per batch.
            lr_scheduler.step()

            epoch_loss += loss.item()
            num_batches += 1
            pbar.set_postfix({"loss": epoch_loss / num_batches, "lr": optimizer.param_groups[0]['lr']})

            if batch_idx % 100 == 0:
                clear_memory()

    # Divide by the batches actually seen: works for unsized loaders and avoids /0.
    return epoch_loss / num_batches if num_batches else float("nan")

def save_checkpoint(model, ema_model, optimizer, epoch, loss, lr_scheduler, best_loss=float('inf'), filename=CHECKPOINT_PATH):
    """Persist the full training state to ``filename``.

    The checkpoint holds model, EMA, optimizer and LR-scheduler state dicts, the
    epoch, the epoch's loss, the best loss so far, and the schedule/model
    hyperparameters needed to rebuild the objects.

    The file is written to ``<filename>.tmp`` first and then renamed over
    ``filename`` with ``os.replace``. An interrupted save therefore cannot leave a
    truncated file in place of the previous good checkpoint.
    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "ema_model_state_dict": ema_model.model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        "epoch": epoch,
        "loss": loss,
        "best_loss": best_loss,
        "scheduler_params": {
            "timesteps": TIMESTEPS,
            "beta_start": BETA_START,
            "beta_end": BETA_END
        },
        "model_params": {
            "in_channels": CHANNELS,
            "out_channels": CHANNELS,
            "base_channels": BASE_CHANNELS
        }
    }
    tmp_filename = f"{filename}.tmp"
    try:
        torch.save(checkpoint, tmp_filename)
        # Atomic on POSIX/Windows when both paths are on the same filesystem.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise
    print(f"Checkpoint saved to {filename}")

def save_best_model(model, ema_model, optimizer, epoch, loss, lr_scheduler, best_loss, filename=BEST_MODEL_PATH):
    """Checkpoint to ``filename`` only when ``loss`` is strictly below ``best_loss``.

    Returns the best loss after this call (``loss`` if it improved, else
    ``best_loss``), so the caller can carry it into the next epoch.
    """
    if not loss < best_loss:
        return best_loss
    # The improved loss becomes the checkpoint's recorded best_loss.
    save_checkpoint(model, ema_model, optimizer, epoch, loss, lr_scheduler, loss, filename)
    print(f"New best model saved with loss: {loss:.6f}")
    return loss

def load_checkpoint(model, ema_model, optimizer, lr_scheduler, filename=CHECKPOINT_PATH):
    """Restore training state from ``filename`` into the given objects, if the file exists.

    The model and optimizer are always restored. EMA and LR-scheduler state are
    restored when present in the file (and, for the scheduler, when one is given).

    Returns:
        ``(start_epoch, best_loss)``: ``(saved epoch + 1, saved best_loss or inf)``
        when a checkpoint was loaded, otherwise ``(0, inf)``.
    """
    if not os.path.exists(filename):
        return 0, float('inf')

    # Map to CPU so a checkpoint written on a GPU machine also opens on a CPU-only
    # one. load_state_dict then copies values into the already-placed tensors,
    # and the optimizer moves its state to each parameter's device.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    if "ema_model_state_dict" in checkpoint:
        ema_model.model.load_state_dict(checkpoint["ema_model_state_dict"])

    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if "lr_scheduler_state_dict" in checkpoint and lr_scheduler is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

    start_epoch = checkpoint["epoch"] + 1
    best_loss = checkpoint.get("best_loss", float('inf'))
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.6f}")
    return start_epoch, best_loss

# Training
def train_model():
    """Train ImprovedUNet on CIFAR-10, resuming from CHECKPOINT_PATH when it exists.

    The latest state is written to CHECKPOINT_PATH every SAVE_INTERVAL epochs. The
    state with the lowest mean training loss so far goes to BEST_MODEL_PATH.
    """
    # NOTE(review): nothing in this notebook writes into "samples" yet.
    os.makedirs("samples", exist_ok=True)

    # ToTensor gives [0, 1]; (x - 0.5) / 0.5 maps each channel to [-1, 1].
    to_model_range = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainloader = DataLoader(
        torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=to_model_range),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
    )

    model = ImprovedUNet(in_channels=CHANNELS, out_channels=CHANNELS, base_channels=BASE_CHANNELS).to(DEVICE)
    ema_model = EMAModel(model, decay=EMA_DECAY)
    scheduler = NoiseScheduler(timesteps=TIMESTEPS, beta_start=BETA_START, beta_end=BETA_END, device=DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # The LR schedule counts optimizer steps (one per batch); the first 10% is warmup.
    total_steps = NUM_EPOCHS * len(trainloader)
    lr_scheduler = get_cosine_schedule(
        optimizer=optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    start_epoch, best_loss = load_checkpoint(model, ema_model, optimizer, lr_scheduler)

    for epoch in range(start_epoch, NUM_EPOCHS):
        clear_memory()
        loss = train_epoch(model, ema_model, scheduler, trainloader, optimizer, epoch, lr_scheduler)
        best_loss = save_best_model(model, ema_model, optimizer, epoch, loss, lr_scheduler, best_loss)
        if (epoch + 1) % SAVE_INTERVAL == 0:
            save_checkpoint(model, ema_model, optimizer, epoch, loss, lr_scheduler, best_loss)

    print("Training completed!")

# Run the training
if __name__ == "__main__":
    # Release cached-but-unused CUDA memory (e.g. left over from earlier notebook
    # runs) before building the model.
    torch.cuda.empty_cache()

    # train_epoch runs entirely in fp32 (no autocast / GradScaler). Earlier code only
    # checked that torch.cuda.amp exists and then claimed mixed precision was on,
    # which was misleading; report what actually happens instead.
    print(f"Training on {DEVICE} in full fp32 precision (mixed precision is not enabled)")

    train_model()

Using mixed precision training


100%|██████████| 170M/170M [00:02<00:00, 75.5MB/s]
Epoch 1/200: 100%|██████████| 98/98 [00:11<00:00,  8.39it/s, loss=0.967, lr=1e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.967425
Checkpoint saved to diffusion_checkpoint.pt


Epoch 2/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.438, lr=2e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.438193
Checkpoint saved to diffusion_checkpoint.pt


Epoch 3/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.164, lr=3e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.164149
Checkpoint saved to diffusion_checkpoint.pt


Epoch 4/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.114, lr=4e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.114455
Checkpoint saved to diffusion_checkpoint.pt


Epoch 5/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0901, lr=5e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.090100
Checkpoint saved to diffusion_checkpoint.pt


Epoch 6/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0746, lr=6e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.074604
Checkpoint saved to diffusion_checkpoint.pt


Epoch 7/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0655, lr=7e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.065502
Checkpoint saved to diffusion_checkpoint.pt


Epoch 8/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0605, lr=8e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.060493
Checkpoint saved to diffusion_checkpoint.pt


Epoch 9/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0568, lr=9e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.056751
Checkpoint saved to diffusion_checkpoint.pt


Epoch 10/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0535, lr=0.0001]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.053545
Checkpoint saved to diffusion_checkpoint.pt


Epoch 11/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0524, lr=0.00011]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.052441
Checkpoint saved to diffusion_checkpoint.pt


Epoch 12/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.051, lr=0.00012]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.050959
Checkpoint saved to diffusion_checkpoint.pt


Epoch 13/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0492, lr=0.00013]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.049197
Checkpoint saved to diffusion_checkpoint.pt


Epoch 14/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0482, lr=0.00014]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.048171
Checkpoint saved to diffusion_checkpoint.pt


Epoch 15/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0465, lr=0.00015]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.046509
Checkpoint saved to diffusion_checkpoint.pt


Epoch 16/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0459, lr=0.00016]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.045856
Checkpoint saved to diffusion_checkpoint.pt


Epoch 17/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0451, lr=0.00017]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.045055
Checkpoint saved to diffusion_checkpoint.pt


Epoch 18/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0435, lr=0.00018]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.043532
Checkpoint saved to diffusion_checkpoint.pt


Epoch 19/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0427, lr=0.00019]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.042685
Checkpoint saved to diffusion_checkpoint.pt


Epoch 20/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0418, lr=0.0002]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.041781
Checkpoint saved to diffusion_checkpoint.pt


Epoch 21/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0415, lr=0.0002]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.041525
Checkpoint saved to diffusion_checkpoint.pt


Epoch 22/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.04, lr=0.0002]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.040035
Checkpoint saved to diffusion_checkpoint.pt


Epoch 23/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0394, lr=0.0002]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.039442
Checkpoint saved to diffusion_checkpoint.pt


Epoch 24/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0396, lr=0.0002]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 25/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0385, lr=0.0002]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.038480
Checkpoint saved to diffusion_checkpoint.pt


Epoch 26/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.038, lr=0.000199]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.038035
Checkpoint saved to diffusion_checkpoint.pt


Epoch 27/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0385, lr=0.000199]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 28/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0372, lr=0.000199]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.037185
Checkpoint saved to diffusion_checkpoint.pt


Epoch 29/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0375, lr=0.000199]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 30/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0375, lr=0.000198]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 31/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.037, lr=0.000198]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.036985
Checkpoint saved to diffusion_checkpoint.pt


Epoch 32/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0371, lr=0.000198]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 33/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0363, lr=0.000197]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.036293
Checkpoint saved to diffusion_checkpoint.pt


Epoch 34/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0366, lr=0.000197]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 35/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0362, lr=0.000197]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.036223
Checkpoint saved to diffusion_checkpoint.pt


Epoch 36/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0358, lr=0.000196]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.035751
Checkpoint saved to diffusion_checkpoint.pt


Epoch 37/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0354, lr=0.000196]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.035355
Checkpoint saved to diffusion_checkpoint.pt


Epoch 38/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0351, lr=0.000195]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.035055
Checkpoint saved to diffusion_checkpoint.pt


Epoch 39/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0352, lr=0.000195]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 40/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0351, lr=0.000194]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 41/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0351, lr=0.000193]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 42/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0354, lr=0.000193]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 43/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0349, lr=0.000192]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034919
Checkpoint saved to diffusion_checkpoint.pt


Epoch 44/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0349, lr=0.000191]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 45/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0347, lr=0.000191]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034675
Checkpoint saved to diffusion_checkpoint.pt


Epoch 46/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0348, lr=0.00019]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 47/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0344, lr=0.000189]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034352
Checkpoint saved to diffusion_checkpoint.pt


Epoch 48/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0343, lr=0.000188]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034314
Checkpoint saved to diffusion_checkpoint.pt


Epoch 49/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0342, lr=0.000187]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034175
Checkpoint saved to diffusion_checkpoint.pt


Epoch 50/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0342, lr=0.000187]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 51/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0341, lr=0.000186]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034134
Checkpoint saved to diffusion_checkpoint.pt


Epoch 52/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0341, lr=0.000185]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034081
Checkpoint saved to diffusion_checkpoint.pt


Epoch 53/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.034, lr=0.000184]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.034027
Checkpoint saved to diffusion_checkpoint.pt


Epoch 54/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0337, lr=0.000183]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.033662
Checkpoint saved to diffusion_checkpoint.pt


Epoch 55/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0336, lr=0.000182]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.033612
Checkpoint saved to diffusion_checkpoint.pt


Epoch 56/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0339, lr=0.000181]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 57/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0334, lr=0.00018]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.033354
Checkpoint saved to diffusion_checkpoint.pt


Epoch 58/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0331, lr=0.000179]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.033080
Checkpoint saved to diffusion_checkpoint.pt


Epoch 59/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0335, lr=0.000178]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 60/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0337, lr=0.000177]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 61/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0334, lr=0.000175]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 62/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0339, lr=0.000174]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 63/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0334, lr=0.000173]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 64/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0332, lr=0.000172]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 65/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0326, lr=0.000171]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.032571
Checkpoint saved to diffusion_checkpoint.pt


Epoch 66/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0335, lr=0.000169]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 67/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0329, lr=0.000168]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 68/200: 100%|██████████| 98/98 [00:10<00:00,  9.74it/s, loss=0.0329, lr=0.000167]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 69/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0328, lr=0.000166]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 70/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0332, lr=0.000164]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 71/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0334, lr=0.000163]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 72/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0331, lr=0.000162]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 73/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0325, lr=0.00016]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.032495
Checkpoint saved to diffusion_checkpoint.pt


Epoch 74/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0323, lr=0.000159]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.032299
Checkpoint saved to diffusion_checkpoint.pt


Epoch 75/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0327, lr=0.000157]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 76/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0324, lr=0.000156]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 77/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0324, lr=0.000154]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 78/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0322, lr=0.000153]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.032210
Checkpoint saved to diffusion_checkpoint.pt


Epoch 79/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0327, lr=0.000152]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 80/200: 100%|██████████| 98/98 [00:10<00:00,  9.73it/s, loss=0.0326, lr=0.00015]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 81/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.032, lr=0.000148]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.031981
Checkpoint saved to diffusion_checkpoint.pt


Epoch 82/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.033, lr=0.000147]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 83/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0325, lr=0.000145]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 84/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0325, lr=0.000144]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 85/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0322, lr=0.000142]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 86/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0323, lr=0.000141]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 87/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0324, lr=0.000139]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 88/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0322, lr=0.000137]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 89/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0322, lr=0.000136]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 90/200: 100%|██████████| 98/98 [00:10<00:00,  9.72it/s, loss=0.0322, lr=0.000134]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 91/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0321, lr=0.000133]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 92/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.032, lr=0.000131]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 93/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0321, lr=0.000129]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 94/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0322, lr=0.000128]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 95/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0323, lr=0.000126]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 96/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0316, lr=0.000124]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.031621
Checkpoint saved to diffusion_checkpoint.pt


Epoch 97/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0318, lr=0.000122]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 98/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0318, lr=0.000121]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 99/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0319, lr=0.000119]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 100/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0314, lr=0.000117]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.031437
Checkpoint saved to diffusion_checkpoint.pt


Epoch 101/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0316, lr=0.000116]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 102/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0321, lr=0.000114]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 103/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0318, lr=0.000112]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 104/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0311, lr=0.00011]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.031109
Checkpoint saved to diffusion_checkpoint.pt


Epoch 105/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0314, lr=0.000109]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 106/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0325, lr=0.000107]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 107/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0314, lr=0.000105]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 108/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0314, lr=0.000103]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 109/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0314, lr=0.000102]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 110/200: 100%|██████████| 98/98 [00:10<00:00,  9.71it/s, loss=0.0312, lr=0.0001]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 111/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0317, lr=9.83e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 112/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0312, lr=9.65e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 113/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.0319, lr=9.48e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 114/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0313, lr=9.3e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 115/200: 100%|██████████| 98/98 [00:10<00:00,  9.66it/s, loss=0.0314, lr=9.13e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 116/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0313, lr=8.95e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 117/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0306, lr=8.78e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030617
Checkpoint saved to diffusion_checkpoint.pt


Epoch 118/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0311, lr=8.61e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 119/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0319, lr=8.44e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 120/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0316, lr=8.26e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 121/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0314, lr=8.09e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 122/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0308, lr=7.92e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 123/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0313, lr=7.75e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 124/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.031, lr=7.58e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 125/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0315, lr=7.41e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 126/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0319, lr=7.24e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 127/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0313, lr=7.08e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 128/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0313, lr=6.91e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 129/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0313, lr=6.74e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 130/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0312, lr=6.58e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 131/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0315, lr=6.42e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 132/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0309, lr=6.25e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 133/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0311, lr=6.09e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 134/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0313, lr=5.93e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 135/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0305, lr=5.77e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030506
Checkpoint saved to diffusion_checkpoint.pt


Epoch 136/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0309, lr=5.62e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 137/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0311, lr=5.46e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 138/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.031, lr=5.31e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 139/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0311, lr=5.15e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 140/200: 100%|██████████| 98/98 [00:10<00:00,  9.66it/s, loss=0.031, lr=5e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 141/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0311, lr=4.85e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 142/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0305, lr=4.7e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 143/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0309, lr=4.55e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 144/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0308, lr=4.41e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 145/200: 100%|██████████| 98/98 [00:10<00:00,  9.66it/s, loss=0.0309, lr=4.26e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 146/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0307, lr=4.12e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 147/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0307, lr=3.98e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 148/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0304, lr=3.84e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030449
Checkpoint saved to diffusion_checkpoint.pt


Epoch 149/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0308, lr=3.71e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 150/200: 100%|██████████| 98/98 [00:10<00:00,  9.66it/s, loss=0.0303, lr=3.57e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030321
Checkpoint saved to diffusion_checkpoint.pt


Epoch 151/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0307, lr=3.44e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 152/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0303, lr=3.31e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030256
Checkpoint saved to diffusion_checkpoint.pt


Epoch 153/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0305, lr=3.18e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 154/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0309, lr=3.05e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 155/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0304, lr=2.93e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 156/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0303, lr=2.81e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 157/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0305, lr=2.69e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 158/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0307, lr=2.57e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 159/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0311, lr=2.45e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 160/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0303, lr=2.34e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 161/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0304, lr=2.23e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 162/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0303, lr=2.12e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 163/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0303, lr=2.01e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 164/200: 100%|██████████| 98/98 [00:10<00:00,  9.70it/s, loss=0.0306, lr=1.91e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 165/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0305, lr=1.81e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 166/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0305, lr=1.71e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 167/200: 100%|██████████| 98/98 [00:10<00:00,  9.63it/s, loss=0.0302, lr=1.61e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030189
Checkpoint saved to diffusion_checkpoint.pt


Epoch 168/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0304, lr=1.52e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 169/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0303, lr=1.43e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 170/200: 100%|██████████| 98/98 [00:10<00:00,  9.68it/s, loss=0.0306, lr=1.34e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 171/200: 100%|██████████| 98/98 [00:10<00:00,  9.69it/s, loss=0.0306, lr=1.25e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 172/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0305, lr=1.17e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 173/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.031, lr=1.09e-5]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 174/200: 100%|██████████| 98/98 [00:10<00:00,  9.61it/s, loss=0.0301, lr=1.01e-5]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.030118
Checkpoint saved to diffusion_checkpoint.pt


Epoch 175/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0299, lr=9.37e-6]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.029891
Checkpoint saved to diffusion_checkpoint.pt


Epoch 176/200: 100%|██████████| 98/98 [00:10<00:00,  9.66it/s, loss=0.0305, lr=8.65e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 177/200: 100%|██████████| 98/98 [00:10<00:00,  9.62it/s, loss=0.0304, lr=7.95e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 178/200: 100%|██████████| 98/98 [00:10<00:00,  9.60it/s, loss=0.0297, lr=7.28e-6]


Checkpoint saved to diffusion_best_model.pt
New best model saved with loss: 0.029730
Checkpoint saved to diffusion_checkpoint.pt


Epoch 179/200: 100%|██████████| 98/98 [00:10<00:00,  9.61it/s, loss=0.0302, lr=6.64e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 180/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0307, lr=6.03e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 181/200: 100%|██████████| 98/98 [00:10<00:00,  9.62it/s, loss=0.0303, lr=5.45e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 182/200: 100%|██████████| 98/98 [00:10<00:00,  9.63it/s, loss=0.0305, lr=4.89e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 183/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0307, lr=4.37e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 184/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0303, lr=3.87e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 185/200: 100%|██████████| 98/98 [00:10<00:00,  9.63it/s, loss=0.0304, lr=3.41e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 186/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.0307, lr=2.97e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 187/200: 100%|██████████| 98/98 [00:10<00:00,  9.62it/s, loss=0.0303, lr=2.56e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 188/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.0306, lr=2.19e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 189/200: 100%|██████████| 98/98 [00:10<00:00,  9.62it/s, loss=0.0297, lr=1.84e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 190/200: 100%|██████████| 98/98 [00:10<00:00,  9.63it/s, loss=0.0303, lr=1.52e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 191/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.0307, lr=1.23e-6]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 192/200: 100%|██████████| 98/98 [00:10<00:00,  9.62it/s, loss=0.0299, lr=9.73e-7]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 193/200: 100%|██████████| 98/98 [00:10<00:00,  9.62it/s, loss=0.0303, lr=7.45e-7]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 194/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.0303, lr=5.48e-7]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 195/200: 100%|██████████| 98/98 [00:10<00:00,  9.63it/s, loss=0.0302, lr=3.81e-7]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 196/200: 100%|██████████| 98/98 [00:10<00:00,  9.65it/s, loss=0.0298, lr=2.44e-7]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 197/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0302, lr=1.37e-7]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 198/200: 100%|██████████| 98/98 [00:10<00:00,  9.63it/s, loss=0.0302, lr=6.09e-8]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 199/200: 100%|██████████| 98/98 [00:10<00:00,  9.64it/s, loss=0.0303, lr=1.52e-8]


Checkpoint saved to diffusion_checkpoint.pt


Epoch 200/200: 100%|██████████| 98/98 [00:10<00:00,  9.67it/s, loss=0.0301, lr=0]


Checkpoint saved to diffusion_checkpoint.pt
Training completed!


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import torchvision
import os
import gc
import matplotlib.pyplot as plt
import math

# --- Sampling-only configuration ---------------------------------------------
# NOTE: this cell re-binds CHECKPOINT_PATH. The training cell used that name for the
# rolling per-epoch checkpoint; here it points at the best-loss checkpoint instead.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
CHECKPOINT_PATH = "diffusion_best_model.pt"  # best-loss weights are used for sampling
SAMPLE_BATCH_SIZE = 4                        # deliberately tiny to keep sampling memory low
BASE_CHANNELS = 64                           # must equal the value the model was trained with

def clear_memory():
    """Release memory held by dead Python objects and by PyTorch's CUDA cache.

    Garbage collection runs first so that tensors kept alive only by reference
    cycles are freed before the CUDA caching allocator is asked to give back
    its unused blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def extract(a, t, shape):
    """Pick ``a[..., t]`` per batch element and reshape it for broadcasting.

    Args:
        a: tensor of per-timestep values, indexed along its last axis.
        t: index tensor passed to ``gather``; its first dimension is the batch size B.
        shape: shape of the tensor the result will broadcast against.

    Returns:
        Tensor of shape ``(B, 1, ..., 1)`` with ``len(shape)`` dimensions in total.
    """
    picked = a.gather(-1, t)
    singleton_dims = [1] * (len(shape) - 1)
    return picked.reshape(t.shape[0], *singleton_dims)

# Time embedding module
class SinusoidalTimeEmbedding(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of timesteps.

    The output's last axis has exactly ``dim`` features:
    - The first ``dim // 2`` are ``sin(t * f_k)``.
    - The next ``dim // 2`` are ``cos(t * f_k)``.
    - The frequencies ``f_k`` are spaced geometrically from 1 down to 1/10000.

    When ``dim`` is odd, one zero column is appended so the width still equals ``dim``.
    That keeps it compatible with a following ``nn.Linear(dim, ...)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: float tensor of timesteps. A new axis is inserted at position 1, so a
               1-D ``(B,)`` input yields ``(B, dim)`` and a ``(B, 1)`` input yields
               ``(B, 1, dim)``.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids ZeroDivisionError for dim in {2, 3}; identical for dim >= 4.
        embeddings = math.log(10000) / max(half_dim - 1, 1)
        embeddings = torch.exp(torch.arange(half_dim, device=t.device) * -embeddings)
        embeddings = t[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) features; pad the last one with zero.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

# Self-attention module for diffusion model
class SelfAttention(nn.Module):
    """Pre-norm multi-head self-attention over the spatial positions of a feature map.

    Each of the H*W positions becomes a token with ``channels`` features. The block
    has two residual stages: a LayerNorm + 4-head attention stage, then a
    LayerNorm -> Linear -> GELU -> Linear stage. Input and output shapes are both
    ``(B, channels, H, W)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # Submodule names are state_dict keys of saved checkpoints; keep them unchanged.
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        height, width = x.shape[-2:]
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.flatten(start_dim=-2).transpose(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        hidden = tokens + attended
        hidden = hidden + self.ff_self(hidden)
        # (B, H*W, C) -> (B, C, H, W)
        return hidden.transpose(1, 2).view(-1, self.channels, height, width)

# Improved U-Net model with attention (identical to training model)
class ImprovedUNet(nn.Module):
    """Three-level U-Net noise predictor with self-attention and additive time conditioning.

    For a 32x32 input, the encoder goes 32 -> 16 -> 8 -> 4 and the decoder mirrors it.
    Self-attention is applied at the 8x8 and 4x4 levels and inside the bottleneck.
    A sinusoidal timestep embedding goes through an MLP. Per-level linear projections
    of it are then added to the feature maps.

    Attribute names and the layer order inside each ``nn.Sequential`` form the
    state_dict keys. They must match the training-time model for checkpoints to load.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c = base_channels
        time_dim = c * 4

        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(c),
            nn.Linear(c, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, c, 3, padding=1)

        # Encoder: each stage widens the channels, then halves the resolution.
        self.down1 = self._down_block(c, c * 2)      # 32 -> 16
        self.down2 = self._down_block(c * 2, c * 4)  # 16 -> 8
        self.down3 = self._down_block(c * 4, c * 8)  # 8 -> 4

        self.attn1 = SelfAttention(c * 4)  # applied at 8x8
        self.attn2 = SelfAttention(c * 8)  # applied at 4x4

        self.middle = nn.Sequential(
            nn.Conv2d(c * 8, c * 8, 3, padding=1),
            nn.GroupNorm(8, c * 8),
            nn.GELU(),
            SelfAttention(c * 8),
            nn.Conv2d(c * 8, c * 8, 3, padding=1),
            nn.GroupNorm(8, c * 8),
            nn.GELU(),
            nn.Conv2d(c * 8, c * 8, 3, padding=1),
        )

        # Project the shared time embedding to each level's channel width.
        self.time_proj1 = nn.Linear(time_dim, c)
        self.time_proj2 = nn.Linear(time_dim, c * 2)
        self.time_proj3 = nn.Linear(time_dim, c * 4)
        self.time_proj4 = nn.Linear(time_dim, c * 8)
        self.time_proj5 = nn.Linear(time_dim, c * 8)

        # Decoder inputs are [features, skip] concatenations, hence the summed widths.
        self.up0 = self._up_block(c * 8 + c * 8, c * 8)  # 4 -> 8
        self.up1 = self._up_block(c * 8 + c * 4, c * 4)  # 8 -> 16
        self.up2 = self._up_block(c * 4 + c * 2, c * 2)  # 16 -> 32

        self.final = nn.Sequential(
            nn.Conv2d(c * 2 + c, c, 3, padding=1),
            nn.GroupNorm(8, c),
            nn.GELU(),
            nn.Conv2d(c, out_channels, 3, padding=1),
        )

    @staticmethod
    def _down_block(in_ch, out_ch):
        """3x3 conv -> GroupNorm -> GELU -> 4x4 stride-2 conv (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1),
        )

    @staticmethod
    def _up_block(in_ch, out_ch):
        """4x4 stride-2 transposed conv (doubles H and W) -> GroupNorm -> GELU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.GELU(),
        )

    @staticmethod
    def _broadcast_time(emb):
        """Reshape a per-sample ``(B, C)`` projection to ``(B, C, 1, 1)`` for spatial broadcasting."""
        return emb.view(emb.shape[0], emb.shape[1], 1, 1)

    def forward(self, x, t):
        # (B,) timesteps -> (B, 1) floats -> MLP -> drop the singleton axis -> (B, 4 * base_channels)
        temb = self.time_mlp(t.unsqueeze(-1).type(torch.float)).squeeze(1)

        # Encoder, with the time signal added at every level.
        h0 = self.init_conv(x) + self._broadcast_time(self.time_proj1(temb))
        h1 = self.down1(h0) + self._broadcast_time(self.time_proj2(temb))
        h2 = self.attn1(self.down2(h1) + self._broadcast_time(self.time_proj3(temb)))
        h3 = self.attn2(self.down3(h2) + self._broadcast_time(self.time_proj4(temb)))

        # Bottleneck.
        mid = self.middle(h3) + self._broadcast_time(self.time_proj5(temb))

        # Decoder with skip connections.
        u0 = self.up0(torch.cat([mid, h3], dim=1))
        u1 = self.up1(torch.cat([u0, h2], dim=1))
        u2 = self.up2(torch.cat([u1, h1], dim=1))
        return self.final(torch.cat([u2, h0], dim=1))

# Create improved noise scheduler for sampling with DDIM
class ImprovedSamplingScheduler:
    """Linear-beta diffusion schedule with DDPM (ancestral) and DDIM reverse samplers.

    Args:
        timesteps: number of forward diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        device: device on which the schedule tensors are allocated.
        ddim_steps: number of evenly spaced diffusion steps DDIM visits. They are stored
            ascending in ``ddim_timesteps`` (0 ... T-1 when ddim_steps >= 2).
        ddim_eta: DDIM stochasticity; 0 is deterministic, larger values add more
            fresh noise per step.
    """

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda", ddim_steps=50, ddim_eta=0.0):
        """Initialize diffusion process parameters with DDIM support."""
        self.timesteps = timesteps
        self.ddim_steps = ddim_steps
        self.ddim_eta = ddim_eta
        # Ascending subsequence of diffusion steps; the DDIM loop walks it from the end to index 0.
        self.ddim_timesteps = torch.linspace(0, timesteps - 1, ddim_steps).long().to(device)
        self.ddim_timestep_map = self._get_ddim_timestep_map()

        # Define forward process variances (betas)
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)

        # Calculate alphas and other constants
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(device), self.alphas_cumprod[:-1]])

        # For posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)

    def _get_ddim_timestep_map(self):
        """Map each diffusion step used by DDIM to its index in ``ddim_timesteps``."""
        return {int(t.item()): i for i, t in enumerate(self.ddim_timesteps)}

    def p_sample_ddim(self, model, x_t, t_index):
        """One DDIM step from ``ddim_timesteps[t_index]`` to the previous, less noisy DDIM step.

        Update rule, with abar = alphas_cumprod:
            x0_hat = clamp((x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t), -1, 1)
            sigma  = eta * sqrt((1 - abar_prev) / (1 - abar_t) * (1 - abar_t / abar_prev))
            x_prev = sqrt(abar_prev) * x0_hat + sqrt(1 - abar_prev - sigma^2) * eps + sigma * z

        Args:
            model: noise-prediction network called as ``model(x, t)`` with ``t`` of shape (B,).
            x_t: current sample at diffusion step ``ddim_timesteps[t_index]``.
            t_index: index into ``ddim_timesteps``. The sampling loop walks it from
                ``ddim_steps - 1`` down to 0.

        Returns:
            The sample at ``ddim_timesteps[t_index - 1]``. When ``t_index == 0`` it is the
            clamped clean-image estimate, since abar_prev = 1.
        """
        with torch.no_grad():
            t = self.ddim_timesteps[t_index]

            # Use model to predict noise
            pred_noise = model(x_t, t.unsqueeze(0).repeat(x_t.shape[0]))

            alpha_cumprod_t = self.alphas_cumprod[t]
            if t_index > 0:
                # ddim_timesteps is ascending, so the less-noisy target is the previous entry.
                alpha_cumprod_prev = self.alphas_cumprod[self.ddim_timesteps[t_index - 1]]
            else:
                # Final step: land on the clean image (abar = 1).
                alpha_cumprod_prev = torch.ones_like(alpha_cumprod_t)

            # Predict original sample
            sqrt_recip_alpha_cumprod = (1 / torch.sqrt(alpha_cumprod_t)).view(-1, 1, 1, 1)
            sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod_t).view(-1, 1, 1, 1)
            pred_x0 = (x_t - sqrt_one_minus_alpha_cumprod * pred_noise) * sqrt_recip_alpha_cumprod
            pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

            # Stochastic part (eta > 0); never on the final step.
            add_noise = self.ddim_eta > 0 and t_index > 0
            if add_noise:
                sigma = self.ddim_eta * torch.sqrt(
                    (1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_prev)
                )
            else:
                sigma = torch.zeros_like(alpha_cumprod_t)

            # Direction pointing to x_t. sigma^2 of the target noise budget is supplied
            # by fresh noise below, so it is removed here (clamped against rounding).
            dir_coef = torch.sqrt(torch.clamp(1.0 - alpha_cumprod_prev - sigma ** 2, min=0.0))
            dir_xt = dir_coef.view(-1, 1, 1, 1) * pred_noise

            x_prev = torch.sqrt(alpha_cumprod_prev).view(-1, 1, 1, 1) * pred_x0 + dir_xt

            if add_noise:
                x_prev = x_prev + sigma.view(-1, 1, 1, 1) * torch.randn_like(x_t)

            return x_prev

    def p_sample_loop_ddim(self, model, shape, save_intermediate=False):
        """DDIM sampling loop from pure noise down to a clean sample.

        Returns the final sample. With ``save_intermediate=True`` it returns
        ``(sample, intermediates)``, where intermediates are CPU copies of the
        initial noise plus roughly 10 snapshots.
        """
        model.eval()

        # Start from pure noise
        x = torch.randn(shape).to(DEVICE)

        # For storing intermediate results
        intermediates = [] if save_intermediate else None
        if save_intermediate:
            intermediates.append(x.cpu())

        # Gradually denoise with DDIM (t_index runs from the noisiest step down to 0)
        for i in tqdm(reversed(range(self.ddim_steps)), total=self.ddim_steps, desc="DDIM Sampling"):
            # Free memory occasionally
            if i % 10 == 0:
                clear_memory()

            x = self.p_sample_ddim(model, x, i)

            # Save intermediate
            if save_intermediate and i % max(1, self.ddim_steps // 10) == 0:
                intermediates.append(x.cpu())

        if save_intermediate:
            return x, intermediates
        return x

    def p_sample(self, model, x_t, t):
        """One ancestral DDPM step, sampling x_{t-1} ~ p(x_{t-1} | x_t).

        Args:
            model: noise-prediction network called as ``model(x_t, t)``.
            x_t: current sample.
            t: long tensor of shape (B,). It is assumed to hold the same timestep for every
                element, since only ``t[0]`` decides whether noise is added.
        """
        with torch.no_grad():
            # Get model prediction (predicted noise)
            pred_noise = model(x_t, t)

            # Extract values for the current timestep
            alpha = self.alphas[t]
            alpha_cumprod = self.alphas_cumprod[t]
            alpha_cumprod_prev = self.alphas_cumprod_prev[t]

            # Calculate predicted x_0 with proper reshaping
            sqrt_recip_alpha_cumprod = (1 / torch.sqrt(alpha_cumprod)).view(-1, 1, 1, 1)
            sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod).view(-1, 1, 1, 1)

            pred_x0 = (x_t - sqrt_one_minus_alpha_cumprod * pred_noise) * sqrt_recip_alpha_cumprod
            pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

            # Posterior mean of q(x_{t-1} | x_t, x_0) with x_0 replaced by its prediction
            coef1_numerator = (self.betas[t] * torch.sqrt(alpha_cumprod_prev)).view(-1, 1, 1, 1)
            coef1_denominator = (1 - alpha_cumprod).view(-1, 1, 1, 1)
            coef1 = coef1_numerator / coef1_denominator

            coef2_numerator = ((1 - alpha_cumprod_prev) * torch.sqrt(alpha)).view(-1, 1, 1, 1)
            coef2_denominator = (1 - alpha_cumprod).view(-1, 1, 1, 1)
            coef2 = coef2_numerator / coef2_denominator

            mean = coef1 * pred_x0 + coef2 * x_t

            # Add noise if t > 0, otherwise return the mean
            if t[0] > 0:
                noise = torch.randn_like(x_t)
                std = torch.sqrt(self.posterior_variance[t]).view(-1, 1, 1, 1)
                return mean + std * noise
            return mean

    def p_sample_loop(self, model, shape, save_intermediate=False):
        """Generate samples with the full T-step ancestral DDPM reverse process.

        Returns the final sample. With ``save_intermediate=True`` it returns
        ``(sample, intermediates)``, with a CPU snapshot every 100 steps.
        """
        model.eval()

        # Start from pure noise
        x = torch.randn(shape).to(DEVICE)

        # For storing intermediate results
        intermediates = [] if save_intermediate else None
        if save_intermediate:
            intermediates.append(x.cpu())

        # Gradually denoise
        for t_idx in tqdm(reversed(range(self.timesteps)), total=self.timesteps, desc="DDPM Sampling"):
            # Use a single timestep for all batch elements
            t = torch.full((shape[0],), t_idx, device=DEVICE, dtype=torch.long)

            # Free memory periodically
            if t_idx % 100 == 0:
                clear_memory()

            # Sample
            x = self.p_sample(model, x, t)

            # Save intermediate
            if save_intermediate and t_idx % 100 == 0:
                intermediates.append(x.cpu())

        if save_intermediate:
            return x, intermediates
        return x

def load_model_for_sampling(ddim_steps=50):
    """Load a trained UNet and a matching sampling scheduler from disk.

    ``CHECKPOINT_PATH`` is tried first; if it does not exist,
    ``BEST_MODEL_PATH`` is used instead. EMA weights are loaded when the
    checkpoint contains them, otherwise the raw model weights.

    Args:
        ddim_steps: Number of DDIM steps configured on the returned scheduler.
            Defaults to 50, the previously hard-coded value.

    Returns:
        Tuple ``(model, scheduler, epoch)``. ``model`` is on ``DEVICE`` and in
        eval mode; ``epoch`` is ``None`` if the checkpoint has no "epoch" entry.

    Raises:
        FileNotFoundError: If neither checkpoint file exists.
    """
    # NOTE(review): the fallback assumes BEST_MODEL_PATH stores the same dict
    # layout ("model_state_dict"/"ema_model_state_dict", "model_params", ...)
    # as the regular checkpoint. Confirm against the code that writes it.
    candidates = [CHECKPOINT_PATH, BEST_MODEL_PATH]
    checkpoint_path = next((path for path in candidates if os.path.exists(path)), None)
    if checkpoint_path is None:
        raise FileNotFoundError(f"No checkpoint found (looked for: {', '.join(candidates)})")
    if checkpoint_path != CHECKPOINT_PATH:
        print(f"{CHECKPOINT_PATH} not found, using {checkpoint_path} instead")

    # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    # Architecture hyperparameters; fall back to the notebook's config constants.
    model_params = checkpoint.get("model_params", {
        "in_channels": CHANNELS,
        "out_channels": CHANNELS,
        "base_channels": BASE_CHANNELS,
    })

    model = ImprovedUNet(
        in_channels=model_params["in_channels"],
        out_channels=model_params["out_channels"],
        base_channels=model_params["base_channels"],
    ).to(DEVICE)

    # Prefer EMA weights when they were saved alongside the raw weights.
    if "ema_model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["ema_model_state_dict"])
        print("Using EMA model weights for sampling")
    else:
        model.load_state_dict(checkpoint["model_state_dict"])
        print("Using regular model weights for sampling")
    model.eval()

    # Noise-schedule parameters; fall back to the notebook's config constants.
    scheduler_params = checkpoint.get("scheduler_params", {
        "timesteps": TIMESTEPS,
        "beta_start": BETA_START,
        "beta_end": BETA_END,
    })

    scheduler = ImprovedSamplingScheduler(
        timesteps=scheduler_params["timesteps"],
        beta_start=scheduler_params["beta_start"],
        beta_end=scheduler_params["beta_end"],
        device=DEVICE,
        ddim_steps=ddim_steps,
    )

    return model, scheduler, checkpoint.get("epoch")

def generate_samples(save_dir="samples", num_samples=8, save_intermediate=False, use_ddim=True,
                     sample_batch_size=None):
    """Generate images from the trained model and save them as a PNG grid.

    Images are produced in batches of at most ``sample_batch_size`` (default:
    the notebook-level ``SAMPLE_BATCH_SIZE``). On a CUDA out-of-memory error
    the failing batch size is halved and the batch retried. The reduction is
    local to this call and does not modify ``SAMPLE_BATCH_SIZE``.

    Args:
        save_dir: Output directory for the PNGs (created if missing).
        num_samples: Total number of images to generate; must be >= 1.
        save_intermediate: If True, also save a figure of the denoising
            trajectory of the first batch.
        use_ddim: Use the scheduler's DDIM loop instead of full DDPM sampling.
        sample_batch_size: Optional per-call override of the batch size.

    Returns:
        CPU tensor of the generated images, rescaled to [0, 1].

    Raises:
        ValueError: If ``num_samples`` or the effective batch size is < 1.
        RuntimeError: Non-OOM sampling errors, or an OOM at batch size 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Local copy, so an OOM back-off does not leak into later calls via the global.
    batch_cap = SAMPLE_BATCH_SIZE if sample_batch_size is None else sample_batch_size
    if batch_cap < 1:
        raise ValueError(f"sample batch size must be >= 1, got {batch_cap}")

    os.makedirs(save_dir, exist_ok=True)

    model, scheduler, epoch = load_model_for_sampling()
    print(f"Loaded model from epoch {epoch}")

    clear_memory()

    sample_loop = scheduler.p_sample_loop_ddim if use_ddim else scheduler.p_sample_loop
    suffix = "_ddim" if use_ddim else ""

    all_samples = []
    remaining = num_samples
    # Only the first batch's trajectory is visualized; saving it for every
    # batch would just overwrite the same file and waste CPU copies.
    trajectory_pending = save_intermediate

    while remaining > 0:
        batch_size = min(batch_cap, remaining)
        print(f"Generating batch of {batch_size} samples")
        try:
            result = sample_loop(
                model,
                shape=(batch_size, CHANNELS, IMAGE_SIZE, IMAGE_SIZE),
                save_intermediate=trajectory_pending,
            )
        except RuntimeError as e:
            # Only OOMs are retried, and only while the batch can still shrink;
            # at batch size 1 halving cannot help, so retrying would loop forever.
            if "out of memory" not in str(e).lower() or batch_size == 1:
                raise
            result = None

        if result is None:
            # Clear the cache outside the except block: until it exits, the
            # exception's traceback keeps the failed attempt's tensors alive.
            batch_cap = max(1, batch_size // 2)
            print(f"OOM error, reducing batch size to {batch_cap}")
            clear_memory()
            continue

        if trajectory_pending:
            samples, intermediates = result
            visualize_diffusion_process(intermediates, f"{save_dir}/diffusion_process{suffix}.png")
            trajectory_pending = False
        else:
            samples = result

        all_samples.append(samples.cpu())
        remaining -= batch_size

        clear_memory()

    samples = torch.cat(all_samples, dim=0)

    # Map [-1, 1] -> [0, 1]. Clamp because sampler outputs are not guaranteed
    # to stay in range (e.g. the DDPM step's final mean is not clamped).
    samples = ((samples + 1) / 2).clamp(0.0, 1.0)

    grid = torchvision.utils.make_grid(samples, nrow=4)
    output_path = f"{save_dir}/final_samples_epoch_{epoch}{suffix}.png"
    torchvision.utils.save_image(grid, output_path)

    print(f"Generated {num_samples} samples saved to {output_path}")
    return samples

def visualize_diffusion_process(intermediate_samples, save_path, total_steps=1000):
    """Save a one-row figure of the first image from each intermediate batch.

    Args:
        intermediate_samples: Sequence of image batches with values in
            [-1, 1]. ``batch[0]`` is assumed to be a (C, H, W) tensor; the
            ``permute(1, 2, 0)`` below relies on that layout.
        save_path: File path the figure is written to.
        total_steps: Number of reverse steps the snapshots span. It is used
            only to label the panels. Defaults to 1000 (the previously
            hard-coded value, matching 1000-step DDPM). Pass the real count
            for other samplers so the titles are accurate.

    Does nothing when ``intermediate_samples`` is empty.
    """
    n_panels = len(intermediate_samples)
    if n_panels == 0:
        # plt.subplots(1, 0) would raise; there is simply nothing to draw.
        return

    # squeeze=False always returns a 2-D axes array, so a single panel needs
    # no separate code path.
    fig, axes = plt.subplots(1, n_panels, figsize=(20, 4), squeeze=False)
    step_stride = total_steps // (n_panels - 1) if n_panels > 1 else 0

    for i, (ax, samples) in enumerate(zip(axes[0], intermediate_samples)):
        # First image of the batch, moved to HWC layout for imshow.
        img = samples[0].detach().cpu().permute(1, 2, 0).numpy()
        img = np.clip((img + 1) / 2, 0, 1)  # [-1, 1] -> [0, 1]

        ax.imshow(img)
        if n_panels > 1:
            ax.set_title("Noise" if i == 0 else f"Step {i * step_stride}")
        else:
            ax.set_title(f"Step {i}")
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

# Sampling entry point: DDIM run first, then a DDPM run for comparison.
if __name__ == "__main__":
    # Start from a clean CUDA allocator state.
    clear_memory()

    try:
        # DDIM is cheap per sample, so ask for a larger grid.
        generate_samples(num_samples=16, save_intermediate=True, use_ddim=True)
        # Full-length DDPM on a smaller grid, for a side-by-side comparison.
        generate_samples(num_samples=4, save_intermediate=True, use_ddim=False)
    except RuntimeError as err:
        # Anything other than an out-of-memory failure is a real error.
        if "out of memory" not in str(err).lower():
            raise
        # Out of memory: free what we can and fall back to a minimal DDIM run.
        clear_memory()
        print("Retrying with fewer samples...")
        generate_samples(num_samples=4, save_intermediate=False, use_ddim=True)

Using EMA model weights for sampling
Loaded model from epoch 177
Generating batch of 4 samples


DDIM Sampling: 50it [00:01, 48.63it/s]


Generating batch of 4 samples


DDIM Sampling: 50it [00:00, 56.87it/s]


Generating batch of 4 samples


DDIM Sampling: 50it [00:00, 54.15it/s]


Generating batch of 4 samples


DDIM Sampling: 50it [00:00, 52.03it/s]


Generated 16 samples saved to samples/final_samples_epoch_177_ddim.png
Using EMA model weights for sampling
Loaded model from epoch 177
Generating batch of 4 samples


DDPM Sampling: 1000it [00:07, 138.55it/s]


Generated 4 samples saved to samples/final_samples_epoch_177.png


In [None]:
# Mount Google Drive at /content/drive (Colab-only API).
# NOTE(review): the recorded output of this cell is "ValueError: mount failed".
# Nothing in the sampling code above reads from /content/drive (checkpoint
# paths are relative). If Drive is meant to persist checkpoints, this cell
# should probably run near the top of the notebook, before training.
# TODO confirm intent.
from google.colab import drive
drive.mount('/content/drive')

ValueError: mount failed