In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np

In [None]:
# Prompt vocabularies: each digit 0-9 maps to its numeral ("7") and to its English word ("seven")
_DIGIT_WORDS = ("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine")
digit_to_string = {digit: str(digit) for digit in range(len(_DIGIT_WORDS))}
digit_to_word = dict(enumerate(_DIGIT_WORDS))

# Simple prompt encoder (embedding layer)
class PromptEncoder(nn.Module):
  """Map prompt tokens to per-channel FiLM parameters (gamma, beta).

  Each prompt string is looked up in ``vocab``, embedded with an
  ``nn.Embedding``, and projected by two independent linear layers into a
  ``gamma`` and a ``beta`` vector of size ``hidden_dim``.

  Args:
    vocab: dict mapping each prompt string to an integer row index; indices
      must lie in ``[0, len(vocab))`` because the embedding has ``len(vocab)`` rows.
    emb_dim: size of the token embedding.
    hidden_dim: size of each returned gamma/beta vector.
  """
  def __init__(self, vocab, emb_dim=8, hidden_dim=32):
    super().__init__()
    self.vocab = vocab
    self.emb = nn.Embedding(len(vocab), emb_dim)

    self.to_gamma = nn.Linear(emb_dim, hidden_dim)
    self.to_beta = nn.Linear(emb_dim, hidden_dim)

  def forward(self, labels):
    """Encode a batch of prompt strings.

    Args:
      labels: sequence of prompt strings (e.g. ["zero", "one", "two"]).

    Returns:
      (gamma, beta), each of shape [len(labels), hidden_dim].

    Raises:
      KeyError: if a label is not a key of ``vocab``.
    """
    # Build the index tensor as int64 directly on the embedding's device.
    # The explicit dtype matters: torch.tensor([]) defaults to float32, which
    # nn.Embedding rejects, so an empty batch used to crash.
    idxs = torch.tensor(
        [self.vocab[label] for label in labels],
        dtype=torch.long,
        device=self.emb.weight.device,
    )

    emb = self.emb(idxs)

    gamma = self.to_gamma(emb)
    beta = self.to_beta(emb)

    return gamma, beta
  
def apply_film(x, gamma, beta):
  """Feature-wise linear modulation: scale and shift every channel of ``x``.

  Args:
    x: feature map indexed as [B, C, H, W].
    gamma: per-sample, per-channel scale indexed as [B, C].
    beta: per-sample, per-channel shift indexed as [B, C].

  Returns:
    gamma * x + beta, with gamma/beta broadcast over the spatial dims.

  NOTE(review): gamma is used as-is (no ``1 + gamma`` offset), so an untrained
  encoder does not start out as the identity modulation.
  """
  def _per_channel(v):
    # [B, C] -> [B, C, 1, 1] so the value broadcasts across H and W
    return v[:, :, None, None]

  scale, shift = _per_channel(gamma), _per_channel(beta)
  return x.mul(scale).add(shift)

In [None]:
# Simple UNet for small images (28x28)
class SmallUNet(torch.nn.Module):
  """Three-level UNet conditioned on a timestep and on FiLM parameters.

  Spatial path: two 2x max-pools on the way down (28 -> 14 -> 7 for MNIST)
  and two stride-2 transposed convolutions on the way up, with skip
  connections from enc2 and enc1. H and W must therefore be divisible by 4
  so the upsampled maps match the skip tensors they are concatenated with.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the output map (the noise estimate in this notebook).
    base_channels: width of the first level; deeper levels use 2x this.
    num_timesteps: divisor applied to the integer timestep before the time MLP.
      The default of 1000 keeps the original scaling; passing the diffusion
      step count (e.g. ``T``) maps timesteps onto [0, 1).
  """
  def __init__(self, in_channels=1, out_channels=1, base_channels=16, num_timesteps=1000):
    super().__init__()
    if num_timesteps <= 0:
      raise ValueError(f"num_timesteps must be positive, got {num_timesteps}")
    self.base_channels = base_channels
    self.num_timesteps = num_timesteps

    # Encoder (same resolution within a level; self.pool halves between levels)
    self.enc1 = nn.Sequential(nn.Conv2d(in_channels, base_channels, 3, padding=1), nn.SiLU())
    self.enc2 = nn.Sequential(nn.Conv2d(base_channels, base_channels*2, 3, padding=1), nn.SiLU())
    self.enc3 = nn.Sequential(nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1), nn.SiLU())

    # Bottleneck conv
    self.dec3 = nn.Sequential(nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1), nn.SiLU())

    # up3(d3) has 2*base channels, concatenated with e2 (2*base) -> 4*base
    self.up3 = nn.ConvTranspose2d(base_channels*2, base_channels*2, kernel_size=2, stride=2)
    self.dec2 = nn.Sequential(nn.Conv2d(base_channels*4, base_channels, 3, padding=1), nn.SiLU())

    # up2(d2) has base channels, concatenated with e1 (base) -> 2*base
    self.up2 = nn.ConvTranspose2d(base_channels, base_channels, kernel_size=2, stride=2)
    self.dec1 = nn.Conv2d(base_channels*2, out_channels, 3, padding=1)

    self.pool = nn.MaxPool2d(2)
    # Not used by forward(): the decoder upsamples with up3/up2 (ConvTranspose2d).
    # Kept because it holds no parameters and existing code may reference it.
    self.up = nn.Upsample(scale_factor=2, mode='nearest')

    # Timestep embedding: scaled scalar t -> 2*base features, added at the bottleneck
    self.time_mlp = nn.Sequential(
        nn.Linear(1, base_channels*2),
        nn.ReLU(),
        nn.Linear(base_channels*2, base_channels*2)
    )

  def forward(self, x, t, gamma, beta):
    """Run the UNet on ``x`` at timesteps ``t`` with FiLM conditioning.

    Args:
      x: [B, in_channels, H, W] input; H and W must be divisible by 4.
      t: [B] integer timesteps.
      gamma, beta: [B, F] FiLM parameters with F >= 2*base_channels. Levels
        with ``base_channels`` features use the first ``base_channels``
        columns; the wider levels use the first ``2*base_channels`` columns.

    Returns:
      [B, out_channels, H, W] tensor.

    Raises:
      ValueError: if gamma or beta has fewer than 2*base_channels features.
    """
    wide = self.base_channels * 2
    if gamma.shape[1] < wide or beta.shape[1] < wide:
      raise ValueError(
          f"gamma/beta need at least {wide} features, got "
          f"{gamma.shape[1]} and {beta.shape[1]}"
      )

    # t: [B] timesteps -> [B, 1] float, scaled before the MLP
    t = t[:, None].float() / self.num_timesteps
    t_emb = self.time_mlp(t)[:, :, None, None]  # [B, 2*base, 1, 1]

    # Encoder
    e1 = self.enc1(x)
    e1 = apply_film(e1, gamma[:, :self.base_channels], beta[:, :self.base_channels])

    e2 = self.enc2(self.pool(e1))
    e2 = apply_film(e2, gamma[:, :wide], beta[:, :wide])

    e3 = self.enc3(self.pool(e2)) + t_emb
    e3 = apply_film(e3, gamma[:, :wide], beta[:, :wide])

    # Decoder
    d3 = self.dec3(e3)
    d3 = apply_film(d3, gamma[:, :wide], beta[:, :wide])

    d2 = self.dec2(torch.cat([self.up3(d3), e2], dim=1))
    d2 = apply_film(d2, gamma[:, :self.base_channels], beta[:, :self.base_channels])

    out = self.dec1(torch.cat([self.up2(d2), e1], dim=1))
    return out

In [None]:
# Prepare MNIST training data (downloaded on first run)
DATA_ROOT = './data'
BATCH_SIZE = 64

transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor, pixel values scaled to [0, 1]
])
mnist_train = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(mnist_train, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Build vocab for prompt encoder: prompt token -> embedding row index
use_words = True  # Set False to use strings "1", "2", ...
token_source = digit_to_word if use_words else digit_to_string
vocab_list = [token_source[digit] for digit in range(10)]
vocab = dict(zip(vocab_list, range(len(vocab_list))))

print(f"Vocab: {vocab}")

In [None]:
# -------------------------
# 3️⃣ Diffusion utils
# -------------------------
T = 100  # number of diffusion steps
# Linear variance schedule beta_t, per-step retention alpha_t = 1 - beta_t,
# and the running product alpha_bar_t = alpha_0 * ... * alpha_t.
beta = np.linspace(start=1e-4, stop=0.02, num=T)
alpha = 1.0 - beta
alpha_bar = alpha.cumprod()

def q_sample(x0, t, noise=None):
    """Draw x_t from q(x_t | x_0), the closed-form forward diffusion.

    x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise

    Args:
        x0: clean batch; its first dim lines up with ``t``.
        t: integer timesteps used to index the notebook-level ``alpha_bar``.
        noise: optional noise tensor shaped like ``x0``; fresh standard
            normal noise is drawn when omitted.

    Returns:
        Tensor shaped like ``x0`` (same device and dtype).
    """
    eps = torch.randn_like(x0) if noise is None else noise

    # alpha_bar is a numpy array, so index it with host-side timesteps and move
    # the looked-up values back to x0's device/dtype.
    abar = torch.tensor(alpha_bar[t.cpu()], device=x0.device, dtype=x0.dtype)

    # Append trailing singleton dims until abar broadcasts against x0.
    missing_dims = x0.dim() - abar.dim()
    abar = abar.reshape(tuple(abar.shape) + (1,) * missing_dims)

    return abar.sqrt() * x0 + (1 - abar).sqrt() * eps

In [None]:
# Model, optimizer, loss: UNet and prompt encoder are trained jointly by one Adam optimizer
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
unet = SmallUNet().to(device)
prompt_encoder = PromptEncoder(vocab).to(device)
trainable_params = [*unet.parameters(), *prompt_encoder.parameters()]
optimizer = optim.Adam(trainable_params, lr=1e-3)
loss_fn = nn.MSELoss()

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this run go under this directory
TB_LOG_DIR = "./logs/diffusion/"
writer = SummaryWriter(TB_LOG_DIR)

In [None]:
# Training loop - Diffusion model training
# The model learns to predict the noise that was added to create noisy images
NUM_EPOCHS = 15

for epoch in range(NUM_EPOCHS):
    unet.train()
    prompt_encoder.train()

    # Sum of per-sample losses and number of samples seen this epoch. The last
    # batch can be smaller, so a plain mean of batch losses would over-weight it.
    running_loss = 0.0
    num_seen = 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # Move data to device
        imgs = imgs.to(device)
        batch_size = imgs.size(0)

        # Get prompt embeddings (same token set the vocab was built from)
        prompts = [digit_to_word[label.item()] if use_words else digit_to_string[label.item()] for label in labels]
        prompt_gamma, prompt_beta = prompt_encoder(prompts)

        # Sample random timesteps for each image in the batch
        t = torch.randint(0, T, (batch_size,), device=device)

        # Sample noise and create noisy images at timestep t (forward diffusion)
        noise = torch.randn_like(imgs)
        x_t = q_sample(imgs, t, noise=noise)

        # Forward pass: model predicts the noise that was added
        predicted_noise = unet(x_t, t, prompt_gamma, prompt_beta)

        # Loss: compare predicted noise with actual noise
        loss = loss_fn(predicted_noise, noise)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss_fn (MSELoss, mean reduction) averages over the batch; re-weight
        # by batch size so the epoch figure is an exact per-sample mean.
        running_loss += loss.item() * batch_size
        num_seen += batch_size

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches
    avg_loss = running_loss / max(num_seen, 1)

    writer.add_scalar("Loss/Train", avg_loss, epoch+1)
    print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")

# SummaryWriter writes events asynchronously; flush so the final epochs are
# on disk for TensorBoard without waiting for the periodic flush.
writer.flush()

In [None]:
# Generate new image from text prompt - DDPM sampling
import matplotlib.pyplot as plt

def generate_image_ddpm(prompt_text, model, prompt_encoder, device, num_steps=50, betas=None):
    """Generate one 28x28 image for ``prompt_text`` with DDPM ancestral sampling.

    The reverse process visits ``num_steps`` timesteps spread evenly from the
    last training step down to 0. When fewer steps than the training schedule
    are used, timesteps are skipped, so each update must use the *effective*
    schedule between consecutive visited timesteps ``t`` and ``t_prev``:

        alpha_step = alpha_bar[t] / alpha_bar[t_prev]   (alpha_bar[t_prev] := 1 after the last step)
        beta_step  = 1 - alpha_step

    When ``num_steps`` equals the schedule length, this reduces (up to float
    rounding) to the standard one-step-at-a-time DDPM update with sigma_t^2 = beta_t.

    Args:
        prompt_text: prompt token; must be a key of the prompt encoder's vocab.
        model: noise-prediction network, called as model(x_t, t, gamma, beta).
        prompt_encoder: callable mapping a list of prompts to (gamma, beta).
        device: device to sample on.
        num_steps: number of reverse steps (>= 1). Values larger than the
            schedule length are effectively capped (duplicate timesteps are dropped).
        betas: optional 1-D noise schedule; defaults to the notebook-level ``beta``.

    Returns:
        numpy array of the squeezed [28, 28] image, clamped to [0, 1].

    Raises:
        ValueError: if ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    schedule = np.asarray(beta if betas is None else betas, dtype=np.float64)
    alpha_bars = np.cumprod(1.0 - schedule)

    # Evenly spaced, strictly decreasing, de-duplicated timesteps (ending at 0
    # whenever num_steps >= 2).
    timesteps = np.unique(np.linspace(len(schedule) - 1, 0, num_steps).astype(int))[::-1]

    model.eval()
    with torch.no_grad():
        # Start with pure noise
        img = torch.randn(1, 1, 28, 28).to(device)

        # Get prompt embedding
        prompt_gamma, prompt_beta = prompt_encoder([prompt_text])

        for i, t_val in enumerate(timesteps):
            is_last = i == len(timesteps) - 1

            # Effective per-jump schedule between t_val and the next visited timestep.
            alpha_bar_t = float(alpha_bars[t_val])
            alpha_bar_prev = 1.0 if is_last else float(alpha_bars[timesteps[i + 1]])
            alpha_step = alpha_bar_t / alpha_bar_prev
            beta_step = 1.0 - alpha_step

            t = torch.tensor([int(t_val)], device=device, dtype=torch.long)

            # Predict noise
            predicted_noise = model(img, t, prompt_gamma, prompt_beta)

            # Posterior mean of x_{t_prev} given x_t and the noise estimate
            img = (img - (beta_step / (1.0 - alpha_bar_t) ** 0.5) * predicted_noise) / alpha_step ** 0.5

            # Add fresh noise on every step except the final one
            if not is_last:
                img = img + beta_step ** 0.5 * torch.randn_like(img)

        # Clamp to the [0, 1] pixel range of the training data
        img = torch.clamp(img, 0, 1)

        # Convert to numpy for display
        generated_img = img.cpu().squeeze().numpy()
        return generated_img

def generate_image_simple(prompt_text, model, prompt_encoder, device):
    """One-shot heuristic generation, kept only as a visual baseline.

    Draws Gaussian noise re-centred at 0.5 (std 0.5), asks the model for a
    noise estimate at timestep ``T // 2`` and subtracts 30% of it. The step
    size is an ad-hoc constant, not derived from the noise schedule.

    Returns:
        numpy array of the squeezed image, clamped to [0, 1].
    """
    model.eval()
    with torch.no_grad():
        start = torch.randn(1, 1, 28, 28).to(device) * 0.5 + 0.5
        gamma_p, beta_p = prompt_encoder([prompt_text])

        mid_t = torch.tensor([T // 2], device=device).long()
        eps_hat = model(start, mid_t, gamma_p, beta_p)

        result = (start - eps_hat * 0.3).clamp(0, 1)
        return result.cpu().squeeze().numpy()

# Sample each test prompt with both generators and show them side by side:
# DDPM sampling on the top row, the one-shot heuristic on the bottom row.
# Uses `unet` / `prompt_encoder` from the cells above; results are only
# meaningful after the training cell has run.
prompts_to_test = ["one", "two", "three", "four", "five"]
generators = [("DDPM", generate_image_ddpm), ("Simple", generate_image_simple)]

# squeeze=False keeps `axes` 2-D even when only one prompt is tested.
fig, axes = plt.subplots(len(generators), len(prompts_to_test), figsize=(15, 6), squeeze=False)
fig.suptitle('Image Generation: DDPM (top) vs Simple (bottom)')

for col, prompt in enumerate(prompts_to_test):
    for row, (method, generate) in enumerate(generators):
        image = generate(prompt, unet, prompt_encoder, device)
        ax = axes[row, col]
        ax.imshow(image, cmap='gray')
        ax.set_title(f'{method}: "{prompt}"')
        ax.axis('off')

plt.tight_layout()
plt.show()
