<a href="https://colab.research.google.com/github/abhay-ramamurthy/Latent-Diffusion-Model/blob/main/Latent_Diffucion_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Latent Diffusion Model with VAE + Conditional DDPM + Class Conditioning + Full Debug Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import numpy as np
import os
import random

# -------------------- Setup --------------------
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, the generators of all CUDA devices are seeded too.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed the Python, NumPy and PyTorch RNGs before any model or data object is created.
set_seed(42)
# Autograd anomaly detection is switched on globally for everything that follows.
# NOTE(review): the PyTorch docs describe this mode as a debugging aid that slows
# execution — consider turning it off for the full 200-epoch run.
torch.autograd.set_detect_anomaly(True)

# -------------------- VAE --------------------
class Encoder(nn.Module):
    """Convolutional encoder producing the latent mean and log-variance.

    Three stride-2 convolutions (3 -> 32 -> 64 -> 128 channels, each followed
    by ReLU) feed two linear heads. The heads expect 128 * 4 * 4 flattened
    features, i.e. the conv stack must end on a 4x4 map (a 32x32 input after
    three halvings).
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = [3, 32, 64, 128]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

        flat_features = 128 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return `(mu, logvar)`, each of shape (batch, latent_dim)."""
        features = self.conv(x)
        features = features.view(x.size(0), -1)  # flatten everything but the batch axis
        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent to a 128x4x4 map; three stride-2
    transposed convolutions double the resolution each time (4 -> 8 -> 16 ->
    32), and a final stride-1 layer with Sigmoid emits 3 channels in [0, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)

        upsampling = []
        for c_in, c_out in ((128, 128), (128, 64), (64, 32)):
            upsampling.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            upsampling.append(nn.ReLU())
        output_head = [
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*upsampling, *output_head)

    def forward(self, z):
        """Decode latents of shape (batch, latent_dim) into images."""
        batch = z.size(0)
        feature_map = self.fc(z).view(batch, 128, 4, 4)
        return self.deconv(feature_map)

class VAE(nn.Module):
    """Variational autoencoder pairing `Encoder` and `Decoder`."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample `mu + eps * std` with `eps ~ N(0, I)` and `std = exp(logvar / 2)`.

        Drawing the noise separately keeps the sample differentiable with
        respect to `mu` and `logvar`.
        """
        sigma = (0.5 * logvar).exp()
        gaussian = torch.randn_like(sigma)
        return mu + gaussian * sigma

    def forward(self, x):
        """Return `(reconstruction, mu, logvar, z)` for the input batch `x`."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar, z

# -------------------- Conditional U-Net --------------------
class ConditionalUNet(nn.Module):
    """Small convolutional noise predictor conditioned on class label and timestep.

    Despite the name there is no down/up-sampling path: three 3x3 convolutions
    are applied to the noisy latent concatenated channel-wise with a
    conditioning map broadcast over the spatial grid.

    The conditioning vector is the sum of a learned class embedding and a
    projected sinusoidal timestep embedding. Earlier versions accepted `t` but
    ignored it, leaving the network unable to tell how noisy its input was.

    Args:
        latent_dim: Channel count of the latent (and of the conditioning map).
        num_classes: Number of distinct class labels.
    """

    def __init__(self, latent_dim, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_embedding = nn.Embedding(num_classes, latent_dim)
        self.model = nn.Sequential(
            nn.Conv2d(latent_dim * 2, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, latent_dim, 3, 1, 1)
        )
        # Created after the pre-existing layers so their seeded initialisation
        # is the same as before this module was added.
        self.time_mlp = nn.Sequential(
            nn.Linear(latent_dim, latent_dim), nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )

    def _timestep_embedding(self, t):
        """Return a (batch, latent_dim) sinusoidal embedding of integer timesteps."""
        half = self.latent_dim // 2
        positions = torch.arange(half, dtype=torch.float32, device=t.device)
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = 10000.0 ** (-positions / max(half, 1))
        angles = t.reshape(-1, 1).float() * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if emb.size(1) < self.latent_dim:
            # Odd latent_dim: pad one zero column so widths match.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
        return emb

    def forward(self, x, t, y):
        """Predict the noise contained in `x`.

        Args:
            x: Noisy latent, (batch, latent_dim, H, W).
            t: Integer timesteps, one per batch element.
            y: Integer class labels, one per batch element.

        Returns:
            Tensor shaped (batch, latent_dim, H, W).
        """
        cond = self.label_embedding(y) + self.time_mlp(self._timestep_embedding(t))
        cond_map = cond.view(y.size(0), -1, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        x_cat = torch.cat([x, cond_map], dim=1)
        return self.model(x_cat)

# -------------------- Diffusion --------------------
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced noise schedule.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: Variance at the first step. Defaults to 1e-4.
        beta_end: Variance at the last step. Defaults to 0.02.

    Returns:
        1-D tensor of `timesteps` values running from `beta_start` to `beta_end`.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, noise, betas):
    """Apply the forward diffusion to `x0` at timesteps `t`.

    Computes `sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, where
    `alpha_bar` is the cumulative product of `1 - betas`.

    Args:
        x0: Clean input; the `.view(-1, 1, 1, 1)` broadcast implies 4-D.
        t: Integer timestep indices, one per batch element.
        noise: Noise tensor shaped like `x0`.
        betas: 1-D noise schedule.
    """
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    signal_table = torch.sqrt(alpha_bar).to(x0.device)
    noise_table = torch.sqrt(1.0 - alpha_bar).to(x0.device)
    # Pick each sample's coefficient and shape it for broadcasting.
    signal_scale = signal_table[t].view(-1, 1, 1, 1)
    noise_scale = noise_table[t].view(-1, 1, 1, 1)
    return signal_scale * x0 + noise_scale * noise

def diffusion_loss(model, x0, t, noise, betas):
    """Mean-squared error between the model's noise prediction and `noise`.

    `x0` is noised with `q_sample` at timesteps `t`; `model` is then called as
    `model(noisy, t)` and its output is compared against the injected noise.
    """
    noisy = q_sample(x0, t, noise, betas)
    return F.mse_loss(model(noisy, t), noise)

# -------------------- Dataset --------------------
# CIFAR-10 training split; the only preprocessing is ToTensor (no normalization step).
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=256, shuffle=True)

# -------------------- Init --------------------
dev = "cuda" if torch.cuda.is_available() else "cpu"
# Both networks use a 64-dim latent; the literal 64 in the training loop below must match.
vae = VAE(latent_dim=64).to(dev)
unet = ConditionalUNet(latent_dim=64).to(dev)
# Separate optimizers: each loss only updates its own network.
optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)
optimizer_ddpm = torch.optim.Adam(unet.parameters(), lr=1e-4)

# Number of diffusion steps and the matching noise schedule.
T = 1000
betas = linear_beta_schedule(T).to(dev)

# -------------------- Training --------------------
# The VAE and the noise predictor are trained in the same loop: every batch
# first takes a VAE step, then a DDPM step on that batch's (detached) latents.
for epoch in range(200):
    for i, (img, labels) in enumerate(loader):
        img = img.to(dev)
        labels = labels.to(dev)

        # --- VAE step: reconstruction MSE + 0.001 * per-sample-averaged KL term ---
        recon, mu, logvar, z = vae(img)
        vae_loss = F.mse_loss(recon, img) + 0.001 * (
            -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / img.size(0)
        )
        optimizer_vae.zero_grad()
        vae_loss.backward()
        optimizer_vae.step()

        # --- DDPM step ---
        # detach() keeps the diffusion loss from back-propagating into the VAE.
        # The (B, 64) latent is reshaped to (B, 64, 1, 1) and tiled to (B, 64, 8, 8),
        # so every spatial position carries the same vector.
        # NOTE(review): requires_grad_() makes the latent a grad-tracking leaf, but only
        # unet.parameters() are optimized — looks unnecessary; confirm.
        z_latent = z.detach().clone().reshape(z.size(0), 64, 1, 1).repeat(1, 1, 8, 8).requires_grad_()
        # One uniformly drawn timestep in [0, T) per sample.
        t = torch.randint(0, T, (img.size(0),), device=dev)
        noise = torch.randn_like(z_latent)

        # The lambda binds this batch's labels so diffusion_loss can call model(x, t).
        ddpm_loss = diffusion_loss(lambda x, t: unet(x, t, labels), z_latent, t, noise, betas)
        optimizer_ddpm.zero_grad()
        ddpm_loss.backward()
        optimizer_ddpm.step()

        # Log every 50th batch.
        if i % 50 == 0:
            print(f"[Epoch {epoch+1} Batch {i+1}] VAE Loss: {vae_loss.item():.4f}, DDPM Loss: {ddpm_loss.item():.4f}")

# -------------------- Sampling Functions --------------------
def p_sample(model, zt, t, betas):
    """Draw z_{t-1} from the reverse transition given z_t (one DDPM sampling step).

    Uses the standard DDPM posterior mean

        mean = (z_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)

    with alpha_t = 1 - beta_t and alpha_bar_t its cumulative product. The
    previous code divided by sqrt(alpha_bar_t) instead of sqrt(alpha_t), which
    multiplies the latent by a large factor at high t and makes sampling diverge.

    Args:
        model: Callable `model(z, t_batch)` returning the predicted noise.
        zt: Current noisy latent, 4-D with batch first.
        t: Python int timestep shared by the whole batch.
        betas: 1-D noise schedule.

    Returns:
        Tensor shaped like `zt`. Noise with variance beta_t is added for t > 0;
        at t == 0 the mean is returned without noise.
    """
    betas = betas.to(zt.device)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)

    # Scalar coefficients for this step, shaped to broadcast over (B, C, H, W).
    beta_t = betas[t].view(-1, 1, 1, 1)
    sqrt_recip_alpha_t = (1.0 / torch.sqrt(alphas[t])).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar[t]).view(-1, 1, 1, 1)

    t_batch = torch.full((zt.shape[0],), t, dtype=torch.long, device=zt.device)
    eps_theta = model(zt, t_batch)

    mean = sqrt_recip_alpha_t * (zt - beta_t * eps_theta / sqrt_one_minus_alpha_bar_t)
    if t > 0:
        return mean + torch.sqrt(beta_t) * torch.randn_like(zt)
    return mean

def sample_with_label(model, shape, betas, T, y):
    """Run the full reverse diffusion chain conditioned on class labels `y`.

    Starts from Gaussian noise of the given `shape` and applies `p_sample`
    for t = T-1 ... 0.

    The loop runs under `torch.no_grad()`: sampling needs no gradients, and
    without it the autograd graph of all T model calls would be kept alive.

    Args:
        model: Module called as `model(z, t_batch, y)`; its first parameter's
            device decides where sampling runs.
        shape: Shape of the latent to generate, batch first.
        betas: 1-D noise schedule.
        T: Number of reverse steps.
        y: Class labels, one per batch element.

    Returns:
        The final latent, detached from any autograd graph.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        # Drawn on CPU and then moved, so the CPU generator is consumed as before.
        zt = torch.randn(shape).to(device)
        for step in reversed(range(T)):
            zt = p_sample(lambda x, t: model(x, t, y), zt, step, betas)
    return zt

# -------------------- Save Conditioned DDPM Samples (Fixed + Debug + UNet Bypass Test) --------------------

# NOTE(review): nothing in this section is wrapped in torch.no_grad(), so the
# decoder / UNet calls below run with autograd enabled.

# 1. Sample latent with DDPM + class labels
# Eight random class ids in [0, 10); the latent shape matches the (B, 64, 8, 8) tiling used in training.
sample_labels = torch.randint(0, 10, (8,), device=dev)
latent_ddpm = sample_with_label(unet, shape=(8, 64, 8, 8), betas=betas, T=T, y=sample_labels)

# 2. Average spatial latent to get shape [8, 64] for VAE decoder
latent_vectors = latent_ddpm.mean(dim=[2, 3])

# 3. Debug: print latent stats
print("Mean latent vector value:", latent_vectors.mean().item())
print("Std latent vector value:", latent_vectors.std().item())

# 4. Decode via VAE
final_images = vae.decoder(latent_vectors)
os.makedirs("outputs", exist_ok=True)
save_image(final_images, "outputs/generated_conditioned_ddpm.png")
print("✅ Saved: outputs/generated_conditioned_ddpm.png")

# 4B. Debug: test ConditionalUNet embedding fusion only (no diffusion loop)
# A single UNet forward pass on random input at t = 0; its spatially averaged
# output is decoded directly.
# NOTE(review): the UNet is trained (see diffusion_loss) to output a noise
# prediction, not a latent, so this image is a plumbing check only.
z_debug = torch.randn(8, 64, 8, 8).to(dev)
y_debug = torch.randint(0, 10, (8,), device=dev)
z_fused = unet(z_debug, t=torch.zeros(8, dtype=torch.long).to(dev), y=y_debug)
latent_fused_vectors = z_fused.mean(dim=[2, 3])
fused_images = vae.decoder(latent_fused_vectors)
save_image(fused_images, "outputs/debug_fused_unet_latents.png")
print("🧪 Saved debug UNet fusion output: outputs/debug_fused_unet_latents.png")

# 5. Test decoder with encoder latents (baseline)
# Takes one shuffled training batch; vae(img) already returns a reconstruction,
# which is discarded here, and the sampled z is decoded again explicitly.
img, _ = next(iter(loader))
img = img.to(dev)
_, _, _, z = vae(img)
test_recon = vae.decoder(z)
save_image(test_recon, "outputs/test_decoder_output.png")
print("🧪 Test decoder output saved: outputs/test_decoder_output.png")

# 6. Test decoder with pure random latents
# Standard-normal vectors fed straight to the decoder, bypassing the diffusion model.
rand_latents = torch.randn(8, 64).to(dev)
test_rand = vae.decoder(rand_latents)
save_image(test_rand, "outputs/test_rand_input.png")
print("🧪 Test random input saved: outputs/test_rand_input.png")



100%|██████████| 170M/170M [00:14<00:00, 12.0MB/s]


[Epoch 1 Batch 1] VAE Loss: 0.0649, DDPM Loss: 1.0046
[Epoch 1 Batch 51] VAE Loss: 0.0476, DDPM Loss: 0.9315
[Epoch 1 Batch 101] VAE Loss: 0.0442, DDPM Loss: 0.8106
[Epoch 1 Batch 151] VAE Loss: 0.0386, DDPM Loss: 0.6978
[Epoch 2 Batch 1] VAE Loss: 0.0379, DDPM Loss: 0.6389
[Epoch 2 Batch 51] VAE Loss: 0.0399, DDPM Loss: 0.5695
[Epoch 2 Batch 101] VAE Loss: 0.0360, DDPM Loss: 0.5475
[Epoch 2 Batch 151] VAE Loss: 0.0374, DDPM Loss: 0.4967
[Epoch 3 Batch 1] VAE Loss: 0.0373, DDPM Loss: 0.4755
[Epoch 3 Batch 51] VAE Loss: 0.0372, DDPM Loss: 0.4386
[Epoch 3 Batch 101] VAE Loss: 0.0354, DDPM Loss: 0.4235
[Epoch 3 Batch 151] VAE Loss: 0.0350, DDPM Loss: 0.4231
[Epoch 4 Batch 1] VAE Loss: 0.0346, DDPM Loss: 0.3985
[Epoch 4 Batch 51] VAE Loss: 0.0355, DDPM Loss: 0.3869
[Epoch 4 Batch 101] VAE Loss: 0.0344, DDPM Loss: 0.3588
[Epoch 4 Batch 151] VAE Loss: 0.0349, DDPM Loss: 0.3614
[Epoch 5 Batch 1] VAE Loss: 0.0335, DDPM Loss: 0.3466
[Epoch 5 Batch 51] VAE Loss: 0.0343, DDPM Loss: 0.3471
[Epoch 

In [None]:
# Latent Diffusion Model with VAE + Conditional DDPM + Class Conditioning + Full Debug Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import numpy as np
import os
import random

# -------------------- Setup --------------------
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, the generators of all CUDA devices are seeded too.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed the Python, NumPy and PyTorch RNGs before any model or data object is created.
set_seed(42)
# Autograd anomaly detection is switched on globally for everything that follows.
# NOTE(review): the PyTorch docs describe this mode as a debugging aid that slows
# execution — consider turning it off for the full 200-epoch run.
torch.autograd.set_detect_anomaly(True)

# -------------------- VAE --------------------
class Encoder(nn.Module):
    """Convolutional encoder producing the latent mean and log-variance.

    Three stride-2 convolutions (3 -> 32 -> 64 -> 128 channels, each followed
    by ReLU) feed two linear heads. The heads expect 128 * 4 * 4 flattened
    features, i.e. the conv stack must end on a 4x4 map (a 32x32 input after
    three halvings).
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = [3, 32, 64, 128]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

        flat_features = 128 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return `(mu, logvar)`, each of shape (batch, latent_dim)."""
        features = self.conv(x)
        features = features.view(x.size(0), -1)  # flatten everything but the batch axis
        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent to a 128x4x4 map; three stride-2
    transposed convolutions double the resolution each time (4 -> 8 -> 16 ->
    32), and a final stride-1 layer with Sigmoid emits 3 channels in [0, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)

        upsampling = []
        for c_in, c_out in ((128, 128), (128, 64), (64, 32)):
            upsampling.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            upsampling.append(nn.ReLU())
        output_head = [
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*upsampling, *output_head)

    def forward(self, z):
        """Decode latents of shape (batch, latent_dim) into images."""
        batch = z.size(0)
        feature_map = self.fc(z).view(batch, 128, 4, 4)
        return self.deconv(feature_map)

class VAE(nn.Module):
    """Variational autoencoder pairing `Encoder` and `Decoder`."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample `mu + eps * std` with `eps ~ N(0, I)` and `std = exp(logvar / 2)`.

        Drawing the noise separately keeps the sample differentiable with
        respect to `mu` and `logvar`.
        """
        sigma = (0.5 * logvar).exp()
        gaussian = torch.randn_like(sigma)
        return mu + gaussian * sigma

    def forward(self, x):
        """Return `(reconstruction, mu, logvar, z)` for the input batch `x`."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar, z

# -------------------- Conditional U-Net --------------------
class ConditionalUNet(nn.Module):
    """Small convolutional noise predictor conditioned on class label and timestep.

    Despite the name there is no down/up-sampling path: three 3x3 convolutions
    are applied to the noisy latent concatenated channel-wise with a
    conditioning map broadcast over the spatial grid.

    The conditioning vector is the sum of a learned class embedding and a
    projected sinusoidal timestep embedding. Earlier versions accepted `t` but
    ignored it, leaving the network unable to tell how noisy its input was.

    Args:
        latent_dim: Channel count of the latent (and of the conditioning map).
        num_classes: Number of distinct class labels.
    """

    def __init__(self, latent_dim, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_embedding = nn.Embedding(num_classes, latent_dim)
        self.model = nn.Sequential(
            nn.Conv2d(latent_dim * 2, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, latent_dim, 3, 1, 1)
        )
        # Created after the pre-existing layers so their seeded initialisation
        # is the same as before this module was added.
        self.time_mlp = nn.Sequential(
            nn.Linear(latent_dim, latent_dim), nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )

    def _timestep_embedding(self, t):
        """Return a (batch, latent_dim) sinusoidal embedding of integer timesteps."""
        half = self.latent_dim // 2
        positions = torch.arange(half, dtype=torch.float32, device=t.device)
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = 10000.0 ** (-positions / max(half, 1))
        angles = t.reshape(-1, 1).float() * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if emb.size(1) < self.latent_dim:
            # Odd latent_dim: pad one zero column so widths match.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
        return emb

    def forward(self, x, t, y):
        """Predict the noise contained in `x`.

        Args:
            x: Noisy latent, (batch, latent_dim, H, W).
            t: Integer timesteps, one per batch element.
            y: Integer class labels, one per batch element.

        Returns:
            Tensor shaped (batch, latent_dim, H, W).
        """
        cond = self.label_embedding(y) + self.time_mlp(self._timestep_embedding(t))
        cond_map = cond.view(y.size(0), -1, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        x_cat = torch.cat([x, cond_map], dim=1)
        return self.model(x_cat)

# -------------------- Diffusion --------------------
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced noise schedule.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: Variance at the first step. Defaults to 1e-4.
        beta_end: Variance at the last step. Defaults to 0.02.

    Returns:
        1-D tensor of `timesteps` values running from `beta_start` to `beta_end`.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, noise, betas):
    """Apply the forward diffusion to `x0` at timesteps `t`.

    Computes `sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, where
    `alpha_bar` is the cumulative product of `1 - betas`.

    Args:
        x0: Clean input; the `.view(-1, 1, 1, 1)` broadcast implies 4-D.
        t: Integer timestep indices, one per batch element.
        noise: Noise tensor shaped like `x0`.
        betas: 1-D noise schedule.
    """
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    signal_table = torch.sqrt(alpha_bar).to(x0.device)
    noise_table = torch.sqrt(1.0 - alpha_bar).to(x0.device)
    # Pick each sample's coefficient and shape it for broadcasting.
    signal_scale = signal_table[t].view(-1, 1, 1, 1)
    noise_scale = noise_table[t].view(-1, 1, 1, 1)
    return signal_scale * x0 + noise_scale * noise

def diffusion_loss(model, x0, t, noise, betas):
    """Mean-squared error between the model's noise prediction and `noise`.

    `x0` is noised with `q_sample` at timesteps `t`; `model` is then called as
    `model(noisy, t)` and its output is compared against the injected noise.
    """
    noisy = q_sample(x0, t, noise, betas)
    return F.mse_loss(model(noisy, t), noise)

# -------------------- Dataset --------------------
# CIFAR-10 training split; the only preprocessing is ToTensor (no normalization step).
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=256, shuffle=True)

# -------------------- Init --------------------
dev = "cuda" if torch.cuda.is_available() else "cpu"
# Both networks use a 64-dim latent; the literal 64 in the training loop below must match.
vae = VAE(latent_dim=64).to(dev)
unet = ConditionalUNet(latent_dim=64).to(dev)
# Separate optimizers: each loss only updates its own network.
optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)
optimizer_ddpm = torch.optim.Adam(unet.parameters(), lr=1e-4)

# Number of diffusion steps and the matching noise schedule.
T = 1000
betas = linear_beta_schedule(T).to(dev)

# -------------------- Training --------------------
# The VAE and the noise predictor are trained in the same loop: every batch
# first takes a VAE step, then a DDPM step on that batch's (detached) latents.
for epoch in range(200):
    for i, (img, labels) in enumerate(loader):
        img = img.to(dev)
        labels = labels.to(dev)

        # --- VAE step: reconstruction MSE + 0.001 * per-sample-averaged KL term ---
        recon, mu, logvar, z = vae(img)
        vae_loss = F.mse_loss(recon, img) + 0.001 * (
            -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / img.size(0)
        )
        optimizer_vae.zero_grad()
        vae_loss.backward()
        optimizer_vae.step()

        # --- DDPM step ---
        # detach() keeps the diffusion loss from back-propagating into the VAE.
        # The (B, 64) latent is reshaped to (B, 64, 1, 1) and tiled to (B, 64, 8, 8),
        # so every spatial position carries the same vector.
        # NOTE(review): requires_grad_() makes the latent a grad-tracking leaf, but only
        # unet.parameters() are optimized — looks unnecessary; confirm.
        z_latent = z.detach().clone().reshape(z.size(0), 64, 1, 1).repeat(1, 1, 8, 8).requires_grad_()
        # One uniformly drawn timestep in [0, T) per sample.
        t = torch.randint(0, T, (img.size(0),), device=dev)
        noise = torch.randn_like(z_latent)

        # The lambda binds this batch's labels so diffusion_loss can call model(x, t).
        ddpm_loss = diffusion_loss(lambda x, t: unet(x, t, labels), z_latent, t, noise, betas)
        optimizer_ddpm.zero_grad()
        ddpm_loss.backward()
        optimizer_ddpm.step()

        # Log every 50th batch.
        if i % 50 == 0:
            print(f"[Epoch {epoch+1} Batch {i+1}] VAE Loss: {vae_loss.item():.4f}, DDPM Loss: {ddpm_loss.item():.4f}")

# -------------------- Sampling Functions --------------------
def p_sample(model, zt, t, betas):
    """Draw z_{t-1} from the reverse transition given z_t (one DDPM sampling step).

    Uses the standard DDPM posterior mean

        mean = (z_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)

    with alpha_t = 1 - beta_t and alpha_bar_t its cumulative product. The
    previous code divided by sqrt(alpha_bar_t) instead of sqrt(alpha_t), which
    multiplies the latent by a large factor at high t and makes sampling diverge.

    Args:
        model: Callable `model(z, t_batch)` returning the predicted noise.
        zt: Current noisy latent, 4-D with batch first.
        t: Python int timestep shared by the whole batch.
        betas: 1-D noise schedule.

    Returns:
        Tensor shaped like `zt`. Noise with variance beta_t is added for t > 0;
        at t == 0 the mean is returned without noise.
    """
    betas = betas.to(zt.device)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)

    # Scalar coefficients for this step, shaped to broadcast over (B, C, H, W).
    beta_t = betas[t].view(-1, 1, 1, 1)
    sqrt_recip_alpha_t = (1.0 / torch.sqrt(alphas[t])).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar[t]).view(-1, 1, 1, 1)

    t_batch = torch.full((zt.shape[0],), t, dtype=torch.long, device=zt.device)
    eps_theta = model(zt, t_batch)

    mean = sqrt_recip_alpha_t * (zt - beta_t * eps_theta / sqrt_one_minus_alpha_bar_t)
    if t > 0:
        return mean + torch.sqrt(beta_t) * torch.randn_like(zt)
    return mean

def sample_with_label(model, shape, betas, T, y):
    """Run the full reverse diffusion chain conditioned on class labels `y`.

    Starts from Gaussian noise of the given `shape` and applies `p_sample`
    for t = T-1 ... 0.

    The loop runs under `torch.no_grad()`: sampling needs no gradients, and
    without it the autograd graph of all T model calls would be kept alive.

    Args:
        model: Module called as `model(z, t_batch, y)`; its first parameter's
            device decides where sampling runs.
        shape: Shape of the latent to generate, batch first.
        betas: 1-D noise schedule.
        T: Number of reverse steps.
        y: Class labels, one per batch element.

    Returns:
        The final latent, detached from any autograd graph.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        # Drawn on CPU and then moved, so the CPU generator is consumed as before.
        zt = torch.randn(shape).to(device)
        for step in reversed(range(T)):
            zt = p_sample(lambda x, t: model(x, t, y), zt, step, betas)
    return zt

# -------------------- Save Conditioned DDPM Samples (Fixed + Debug + UNet Bypass Test) --------------------

# NOTE(review): nothing in this section is wrapped in torch.no_grad(), so the
# decoder / UNet calls below run with autograd enabled.

# 1. Sample latent with DDPM + class labels
# Eight random class ids in [0, 10); the latent shape matches the (B, 64, 8, 8) tiling used in training.
sample_labels = torch.randint(0, 10, (8,), device=dev)
latent_ddpm = sample_with_label(unet, shape=(8, 64, 8, 8), betas=betas, T=T, y=sample_labels)

# 2. Average spatial latent to get shape [8, 64] for VAE decoder
latent_vectors = latent_ddpm.mean(dim=[2, 3])

# 3. Debug: print latent stats
print("Mean latent vector value:", latent_vectors.mean().item())
print("Std latent vector value:", latent_vectors.std().item())

# 4. Decode via VAE
final_images = vae.decoder(latent_vectors)
os.makedirs("outputs", exist_ok=True)
save_image(final_images, "outputs/generated_conditioned_ddpm.png")
print("✅ Saved: outputs/generated_conditioned_ddpm.png")

# 4B. Debug: test ConditionalUNet embedding fusion only (no diffusion loop)
# A single UNet forward pass on random input at t = 0; its spatially averaged
# output is decoded directly.
# NOTE(review): the UNet is trained (see diffusion_loss) to output a noise
# prediction, not a latent, so this image is a plumbing check only.
z_debug = torch.randn(8, 64, 8, 8).to(dev)
y_debug = torch.randint(0, 10, (8,), device=dev)
z_fused = unet(z_debug, t=torch.zeros(8, dtype=torch.long).to(dev), y=y_debug)
latent_fused_vectors = z_fused.mean(dim=[2, 3])
fused_images = vae.decoder(latent_fused_vectors)
save_image(fused_images, "outputs/debug_fused_unet_latents.png")
print("🧪 Saved debug UNet fusion output: outputs/debug_fused_unet_latents.png")

# 5. Test decoder with encoder latents (baseline)
# Takes one shuffled training batch; vae(img) already returns a reconstruction,
# which is discarded here, and the sampled z is decoded again explicitly.
img, _ = next(iter(loader))
img = img.to(dev)
_, _, _, z = vae(img)
test_recon = vae.decoder(z)
save_image(test_recon, "outputs/test_decoder_output.png")
print("🧪 Test decoder output saved: outputs/test_decoder_output.png")

# 6. Test decoder with pure random latents
# Standard-normal vectors fed straight to the decoder, bypassing the diffusion model.
rand_latents = torch.randn(8, 64).to(dev)
test_rand = vae.decoder(rand_latents)
save_image(test_rand, "outputs/test_rand_input.png")
print("🧪 Test random input saved: outputs/test_rand_input.png")



100%|██████████| 170M/170M [00:14<00:00, 12.0MB/s]


[Epoch 1 Batch 1] VAE Loss: 0.0649, DDPM Loss: 1.0046
[Epoch 1 Batch 51] VAE Loss: 0.0476, DDPM Loss: 0.9315
[Epoch 1 Batch 101] VAE Loss: 0.0442, DDPM Loss: 0.8106
[Epoch 1 Batch 151] VAE Loss: 0.0386, DDPM Loss: 0.6978
[Epoch 2 Batch 1] VAE Loss: 0.0379, DDPM Loss: 0.6389
[Epoch 2 Batch 51] VAE Loss: 0.0399, DDPM Loss: 0.5695
[Epoch 2 Batch 101] VAE Loss: 0.0360, DDPM Loss: 0.5475
[Epoch 2 Batch 151] VAE Loss: 0.0374, DDPM Loss: 0.4967
[Epoch 3 Batch 1] VAE Loss: 0.0373, DDPM Loss: 0.4755
[Epoch 3 Batch 51] VAE Loss: 0.0372, DDPM Loss: 0.4386
[Epoch 3 Batch 101] VAE Loss: 0.0354, DDPM Loss: 0.4235
[Epoch 3 Batch 151] VAE Loss: 0.0350, DDPM Loss: 0.4231
[Epoch 4 Batch 1] VAE Loss: 0.0346, DDPM Loss: 0.3985
[Epoch 4 Batch 51] VAE Loss: 0.0355, DDPM Loss: 0.3869
[Epoch 4 Batch 101] VAE Loss: 0.0344, DDPM Loss: 0.3588
[Epoch 4 Batch 151] VAE Loss: 0.0349, DDPM Loss: 0.3614
[Epoch 5 Batch 1] VAE Loss: 0.0335, DDPM Loss: 0.3466
[Epoch 5 Batch 51] VAE Loss: 0.0343, DDPM Loss: 0.3471
[Epoch 

In [None]:
# Latent Diffusion Model with VAE + Conditional DDPM + Class Conditioning + Full Debug Pipeline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import numpy as np
import os
import random

# -------------------- Setup --------------------
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, the generators of all CUDA devices are seeded too.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed the Python, NumPy and PyTorch RNGs before any model or data object is created.
set_seed(42)
# Autograd anomaly detection is switched on globally for everything that follows.
# NOTE(review): the PyTorch docs describe this mode as a debugging aid that slows
# execution — consider turning it off for the full 200-epoch run.
torch.autograd.set_detect_anomaly(True)

# -------------------- VAE --------------------
class Encoder(nn.Module):
    """Convolutional encoder producing the latent mean and log-variance.

    Three stride-2 convolutions (3 -> 32 -> 64 -> 128 channels, each followed
    by ReLU) feed two linear heads. The heads expect 128 * 4 * 4 flattened
    features, i.e. the conv stack must end on a 4x4 map (a 32x32 input after
    three halvings).
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = [3, 32, 64, 128]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)

        flat_features = 128 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return `(mu, logvar)`, each of shape (batch, latent_dim)."""
        features = self.conv(x)
        features = features.view(x.size(0), -1)  # flatten everything but the batch axis
        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent to a 128x4x4 map; three stride-2
    transposed convolutions double the resolution each time (4 -> 8 -> 16 ->
    32), and a final stride-1 layer with Sigmoid emits 3 channels in [0, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 4 * 4)

        upsampling = []
        for c_in, c_out in ((128, 128), (128, 64), (64, 32)):
            upsampling.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            upsampling.append(nn.ReLU())
        output_head = [
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*upsampling, *output_head)

    def forward(self, z):
        """Decode latents of shape (batch, latent_dim) into images."""
        batch = z.size(0)
        feature_map = self.fc(z).view(batch, 128, 4, 4)
        return self.deconv(feature_map)

class VAE(nn.Module):
    """Variational autoencoder pairing `Encoder` and `Decoder`."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample `mu + eps * std` with `eps ~ N(0, I)` and `std = exp(logvar / 2)`.

        Drawing the noise separately keeps the sample differentiable with
        respect to `mu` and `logvar`.
        """
        sigma = (0.5 * logvar).exp()
        gaussian = torch.randn_like(sigma)
        return mu + gaussian * sigma

    def forward(self, x):
        """Return `(reconstruction, mu, logvar, z)` for the input batch `x`."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar, z

# -------------------- Conditional U-Net --------------------
class ConditionalUNet(nn.Module):
    """Small convolutional noise predictor conditioned on class label and timestep.

    Despite the name there is no down/up-sampling path: three 3x3 convolutions
    are applied to the noisy latent concatenated channel-wise with a
    conditioning map broadcast over the spatial grid.

    The conditioning vector is the sum of a learned class embedding and a
    projected sinusoidal timestep embedding. Earlier versions accepted `t` but
    ignored it, leaving the network unable to tell how noisy its input was.

    Args:
        latent_dim: Channel count of the latent (and of the conditioning map).
        num_classes: Number of distinct class labels.
    """

    def __init__(self, latent_dim, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_embedding = nn.Embedding(num_classes, latent_dim)
        self.model = nn.Sequential(
            nn.Conv2d(latent_dim * 2, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, latent_dim, 3, 1, 1)
        )
        # Created after the pre-existing layers so their seeded initialisation
        # is the same as before this module was added.
        self.time_mlp = nn.Sequential(
            nn.Linear(latent_dim, latent_dim), nn.ReLU(),
            nn.Linear(latent_dim, latent_dim)
        )

    def _timestep_embedding(self, t):
        """Return a (batch, latent_dim) sinusoidal embedding of integer timesteps."""
        half = self.latent_dim // 2
        positions = torch.arange(half, dtype=torch.float32, device=t.device)
        # Geometric frequency ladder from 1 down to 1/10000.
        freqs = 10000.0 ** (-positions / max(half, 1))
        angles = t.reshape(-1, 1).float() * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if emb.size(1) < self.latent_dim:
            # Odd latent_dim: pad one zero column so widths match.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
        return emb

    def forward(self, x, t, y):
        """Predict the noise contained in `x`.

        Args:
            x: Noisy latent, (batch, latent_dim, H, W).
            t: Integer timesteps, one per batch element.
            y: Integer class labels, one per batch element.

        Returns:
            Tensor shaped (batch, latent_dim, H, W).
        """
        cond = self.label_embedding(y) + self.time_mlp(self._timestep_embedding(t))
        cond_map = cond.view(y.size(0), -1, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        x_cat = torch.cat([x, cond_map], dim=1)
        return self.model(x_cat)

# -------------------- Diffusion --------------------
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced noise schedule.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: Variance at the first step. Defaults to 1e-4.
        beta_end: Variance at the last step. Defaults to 0.02.

    Returns:
        1-D tensor of `timesteps` values running from `beta_start` to `beta_end`.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, noise, betas):
    """Apply the forward diffusion to `x0` at timesteps `t`.

    Computes `sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`, where
    `alpha_bar` is the cumulative product of `1 - betas`.

    Args:
        x0: Clean input; the `.view(-1, 1, 1, 1)` broadcast implies 4-D.
        t: Integer timestep indices, one per batch element.
        noise: Noise tensor shaped like `x0`.
        betas: 1-D noise schedule.
    """
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    signal_table = torch.sqrt(alpha_bar).to(x0.device)
    noise_table = torch.sqrt(1.0 - alpha_bar).to(x0.device)
    # Pick each sample's coefficient and shape it for broadcasting.
    signal_scale = signal_table[t].view(-1, 1, 1, 1)
    noise_scale = noise_table[t].view(-1, 1, 1, 1)
    return signal_scale * x0 + noise_scale * noise

def diffusion_loss(model, x0, t, noise, betas):
    """Mean-squared error between the model's noise prediction and `noise`.

    `x0` is noised with `q_sample` at timesteps `t`; `model` is then called as
    `model(noisy, t)` and its output is compared against the injected noise.
    """
    noisy = q_sample(x0, t, noise, betas)
    return F.mse_loss(model(noisy, t), noise)

# -------------------- Dataset --------------------
# CIFAR-10 training split; the only preprocessing is ToTensor (no normalization step).
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=256, shuffle=True)

# -------------------- Init --------------------
dev = "cuda" if torch.cuda.is_available() else "cpu"
# Both networks use a 64-dim latent; the literal 64 in the training loop below must match.
vae = VAE(latent_dim=64).to(dev)
unet = ConditionalUNet(latent_dim=64).to(dev)
# Separate optimizers: each loss only updates its own network.
optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)
optimizer_ddpm = torch.optim.Adam(unet.parameters(), lr=1e-4)

# Number of diffusion steps and the matching noise schedule.
T = 1000
betas = linear_beta_schedule(T).to(dev)

# -------------------- Training --------------------
# The VAE and the noise predictor are trained in the same loop: every batch
# first takes a VAE step, then a DDPM step on that batch's (detached) latents.
for epoch in range(200):
    for i, (img, labels) in enumerate(loader):
        img = img.to(dev)
        labels = labels.to(dev)

        # --- VAE step: reconstruction MSE + 0.001 * per-sample-averaged KL term ---
        recon, mu, logvar, z = vae(img)
        vae_loss = F.mse_loss(recon, img) + 0.001 * (
            -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / img.size(0)
        )
        optimizer_vae.zero_grad()
        vae_loss.backward()
        optimizer_vae.step()

        # --- DDPM step ---
        # detach() keeps the diffusion loss from back-propagating into the VAE.
        # The (B, 64) latent is reshaped to (B, 64, 1, 1) and tiled to (B, 64, 8, 8),
        # so every spatial position carries the same vector.
        # NOTE(review): requires_grad_() makes the latent a grad-tracking leaf, but only
        # unet.parameters() are optimized — looks unnecessary; confirm.
        z_latent = z.detach().clone().reshape(z.size(0), 64, 1, 1).repeat(1, 1, 8, 8).requires_grad_()
        # One uniformly drawn timestep in [0, T) per sample.
        t = torch.randint(0, T, (img.size(0),), device=dev)
        noise = torch.randn_like(z_latent)

        # The lambda binds this batch's labels so diffusion_loss can call model(x, t).
        ddpm_loss = diffusion_loss(lambda x, t: unet(x, t, labels), z_latent, t, noise, betas)
        optimizer_ddpm.zero_grad()
        ddpm_loss.backward()
        optimizer_ddpm.step()

        # Log every 50th batch.
        if i % 50 == 0:
            print(f"[Epoch {epoch+1} Batch {i+1}] VAE Loss: {vae_loss.item():.4f}, DDPM Loss: {ddpm_loss.item():.4f}")

# -------------------- Sampling Functions --------------------
def p_sample(model, zt, t, betas):
    """Draw z_{t-1} from the reverse transition given z_t (one DDPM sampling step).

    Uses the standard DDPM posterior mean

        mean = (z_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)

    with alpha_t = 1 - beta_t and alpha_bar_t its cumulative product. The
    previous code divided by sqrt(alpha_bar_t) instead of sqrt(alpha_t), which
    multiplies the latent by a large factor at high t and makes sampling diverge.

    Args:
        model: Callable `model(z, t_batch)` returning the predicted noise.
        zt: Current noisy latent, 4-D with batch first.
        t: Python int timestep shared by the whole batch.
        betas: 1-D noise schedule.

    Returns:
        Tensor shaped like `zt`. Noise with variance beta_t is added for t > 0;
        at t == 0 the mean is returned without noise.
    """
    betas = betas.to(zt.device)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)

    # Scalar coefficients for this step, shaped to broadcast over (B, C, H, W).
    beta_t = betas[t].view(-1, 1, 1, 1)
    sqrt_recip_alpha_t = (1.0 / torch.sqrt(alphas[t])).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar[t]).view(-1, 1, 1, 1)

    t_batch = torch.full((zt.shape[0],), t, dtype=torch.long, device=zt.device)
    eps_theta = model(zt, t_batch)

    mean = sqrt_recip_alpha_t * (zt - beta_t * eps_theta / sqrt_one_minus_alpha_bar_t)
    if t > 0:
        return mean + torch.sqrt(beta_t) * torch.randn_like(zt)
    return mean

def sample_with_label(model, shape, betas, T, y):
    """Run the full reverse diffusion chain conditioned on class labels `y`.

    Starts from Gaussian noise of the given `shape` and applies `p_sample`
    for t = T-1 ... 0.

    The loop runs under `torch.no_grad()`: sampling needs no gradients, and
    without it the autograd graph of all T model calls would be kept alive.

    Args:
        model: Module called as `model(z, t_batch, y)`; its first parameter's
            device decides where sampling runs.
        shape: Shape of the latent to generate, batch first.
        betas: 1-D noise schedule.
        T: Number of reverse steps.
        y: Class labels, one per batch element.

    Returns:
        The final latent, detached from any autograd graph.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        # Drawn on CPU and then moved, so the CPU generator is consumed as before.
        zt = torch.randn(shape).to(device)
        for step in reversed(range(T)):
            zt = p_sample(lambda x, t: model(x, t, y), zt, step, betas)
    return zt

# -------------------- Save Conditioned DDPM Samples (Fixed + Debug + UNet Bypass Test) --------------------

# NOTE(review): nothing in this section is wrapped in torch.no_grad(), so the
# decoder / UNet calls below run with autograd enabled.

# 1. Sample latent with DDPM + class labels
# Eight random class ids in [0, 10); the latent shape matches the (B, 64, 8, 8) tiling used in training.
sample_labels = torch.randint(0, 10, (8,), device=dev)
latent_ddpm = sample_with_label(unet, shape=(8, 64, 8, 8), betas=betas, T=T, y=sample_labels)

# 2. Average spatial latent to get shape [8, 64] for VAE decoder
latent_vectors = latent_ddpm.mean(dim=[2, 3])

# 3. Debug: print latent stats
print("Mean latent vector value:", latent_vectors.mean().item())
print("Std latent vector value:", latent_vectors.std().item())

# 4. Decode via VAE
final_images = vae.decoder(latent_vectors)
os.makedirs("outputs", exist_ok=True)
save_image(final_images, "outputs/generated_conditioned_ddpm.png")
print("✅ Saved: outputs/generated_conditioned_ddpm.png")

# 4B. Debug: test ConditionalUNet embedding fusion only (no diffusion loop)
# A single UNet forward pass on random input at t = 0; its spatially averaged
# output is decoded directly.
# NOTE(review): the UNet is trained (see diffusion_loss) to output a noise
# prediction, not a latent, so this image is a plumbing check only.
z_debug = torch.randn(8, 64, 8, 8).to(dev)
y_debug = torch.randint(0, 10, (8,), device=dev)
z_fused = unet(z_debug, t=torch.zeros(8, dtype=torch.long).to(dev), y=y_debug)
latent_fused_vectors = z_fused.mean(dim=[2, 3])
fused_images = vae.decoder(latent_fused_vectors)
save_image(fused_images, "outputs/debug_fused_unet_latents.png")
print("🧪 Saved debug UNet fusion output: outputs/debug_fused_unet_latents.png")

# 5. Test decoder with encoder latents (baseline)
# Takes one shuffled training batch; vae(img) already returns a reconstruction,
# which is discarded here, and the sampled z is decoded again explicitly.
img, _ = next(iter(loader))
img = img.to(dev)
_, _, _, z = vae(img)
test_recon = vae.decoder(z)
save_image(test_recon, "outputs/test_decoder_output.png")
print("🧪 Test decoder output saved: outputs/test_decoder_output.png")

# 6. Test decoder with pure random latents
# Standard-normal vectors fed straight to the decoder, bypassing the diffusion model.
rand_latents = torch.randn(8, 64).to(dev)
test_rand = vae.decoder(rand_latents)
save_image(test_rand, "outputs/test_rand_input.png")
print("🧪 Test random input saved: outputs/test_rand_input.png")



100%|██████████| 170M/170M [00:14<00:00, 12.0MB/s]


[Epoch 1 Batch 1] VAE Loss: 0.0649, DDPM Loss: 1.0046
[Epoch 1 Batch 51] VAE Loss: 0.0476, DDPM Loss: 0.9315
[Epoch 1 Batch 101] VAE Loss: 0.0442, DDPM Loss: 0.8106
[Epoch 1 Batch 151] VAE Loss: 0.0386, DDPM Loss: 0.6978
[Epoch 2 Batch 1] VAE Loss: 0.0379, DDPM Loss: 0.6389
[Epoch 2 Batch 51] VAE Loss: 0.0399, DDPM Loss: 0.5695
[Epoch 2 Batch 101] VAE Loss: 0.0360, DDPM Loss: 0.5475
[Epoch 2 Batch 151] VAE Loss: 0.0374, DDPM Loss: 0.4967
[Epoch 3 Batch 1] VAE Loss: 0.0373, DDPM Loss: 0.4755
[Epoch 3 Batch 51] VAE Loss: 0.0372, DDPM Loss: 0.4386
[Epoch 3 Batch 101] VAE Loss: 0.0354, DDPM Loss: 0.4235
[Epoch 3 Batch 151] VAE Loss: 0.0350, DDPM Loss: 0.4231
[Epoch 4 Batch 1] VAE Loss: 0.0346, DDPM Loss: 0.3985
[Epoch 4 Batch 51] VAE Loss: 0.0355, DDPM Loss: 0.3869
[Epoch 4 Batch 101] VAE Loss: 0.0344, DDPM Loss: 0.3588
[Epoch 4 Batch 151] VAE Loss: 0.0349, DDPM Loss: 0.3614
[Epoch 5 Batch 1] VAE Loss: 0.0335, DDPM Loss: 0.3466
[Epoch 5 Batch 51] VAE Loss: 0.0343, DDPM Loss: 0.3471
[Epoch 