# Latent DiffiT: Transformer-Based Diffusion Model for Image Generation

This notebook implements a complete Latent DiffiT pipeline for training and generating images with a transformer-based diffusion model.

## Setup and Imports

In [ ]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from diffusers import AutoencoderKL
from einops import rearrange

import os
import math
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed the CPU RNG and, when present, every CUDA device so runs are repeatable
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

## 1. Mount Google Drive and Set Up Directories

We mount Google Drive to read the dataset and to save model checkpoints and generated images.

In [ ]:
# Mount Google Drive so the dataset and output folders below are reachable
from google.colab import drive
drive.mount('/content/drive')

# Drive locations: ImageNet-256 class folders, generated samples, model checkpoints
DATA_DIR = "/content/drive/MyDrive/DiffiT_latent_space/image-net-256/archive/data"
OUTPUT_DIR = "/content/drive/MyDrive/DiffiT_latent_space/generate_image"
MODEL_DIR = "/content/drive/MyDrive/DiffiT_latent_space/models"

# Ensure the writable directories exist (no-op when they already do)
for _target_dir in (OUTPUT_DIR, MODEL_DIR):
    os.makedirs(_target_dir, exist_ok=True)

# Quick sanity check of the dataset layout (one sub-folder per class)
print(f"Dataset directory: {DATA_DIR}")
if not os.path.exists(DATA_DIR):
    print("Dataset directory not found! Please check the path.")
else:
    _class_dirs = os.listdir(DATA_DIR)
    print(f"Number of class folders: {len(_class_dirs)}")
    print(f"First 10 classes: {_class_dirs[:10]}")

## 2. Load VAE Model

We use the pre-trained VAE from Stable Diffusion 2 (base) to encode images into its latent space and to decode latents back into images. Its weights stay frozen throughout training.

In [ ]:
# Pretrained KL autoencoder from Stable Diffusion 2 (base), used frozen in inference mode
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="vae")
vae = vae.to(device).eval()

# Freeze every VAE parameter so the diffusion optimizer never touches them
vae.requires_grad_(False)

# Latent scaling factor from the VAE config, falling back to the usual SD value (0.18215)
scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
print(f"VAE scaling factor: {scaling_factor}")

# Define helper functions for encoding and decoding
def encode_to_latent(images):
    """
    Encode images into the scaled latent space of the frozen VAE.

    The raw posterior sample is MULTIPLIED by ``scaling_factor`` (the Stable
    Diffusion / diffusers convention ``z = z * scaling_factor``). This brings
    the latents to roughly unit variance, matching the standard-normal noise
    used by the diffusion process. ``decode_from_latent`` applies the exact
    inverse.

    Args:
        images: Tensor of shape [B, C, H, W] in range [-1, 1]

    Returns:
        latents: Tensor of shape [B, 4, H/8, W/8], scaled by ``scaling_factor``
    """
    with torch.no_grad():
        latent_dist = vae.encode(images).latent_dist
        latents = latent_dist.sample()
        latents = latents * scaling_factor
    return latents

def decode_from_latent(latents):
    """
    Decode scaled latents back to images with the frozen VAE.

    Undoes the scaling applied by ``encode_to_latent`` (``z / scaling_factor``)
    before calling the VAE decoder.

    Args:
        latents: Tensor of shape [B, 4, H/8, W/8] in the scaled latent space

    Returns:
        images: Tensor of shape [B, 3, H, W], nominally in range [-1, 1]
    """
    with torch.no_grad():
        latents_unscaled = latents / scaling_factor
        images = vae.decode(latents_unscaled).sample
    return images

# Smoke-test the VAE round trip. encode_to_latent documents inputs in [-1, 1],
# so draw a uniform random image in that range rather than unbounded Gaussian values.
test_image = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
test_latent = encode_to_latent(test_image)
test_recon = decode_from_latent(test_latent)

print(f"Test image shape: {test_image.shape}")
print(f"Test latent shape: {test_latent.shape}")
print(f"Test reconstruction shape: {test_recon.shape}")

## 3. Data Loading

Set up the ImageNet dataset and dataloaders.

In [ ]:
def get_imagenet_dataloader(data_dir, batch_size=32, num_workers=2):
    """Build a shuffled DataLoader over an ImageFolder-style directory.

    Every image is resized to 256x256, converted to a [0, 1] tensor and then
    normalized with mean/std 0.5 per channel, which maps it to [-1, 1].

    Args:
        data_dir: root directory with one sub-folder per class.
        batch_size: images per batch.
        num_workers: DataLoader worker processes.

    Returns:
        ``(dataloader, dataset)`` where ``dataset`` is the underlying ImageFolder.
    """
    half = [0.5, 0.5, 0.5]
    preprocess = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=half, std=half),
        ]
    )

    image_folder = torchvision.datasets.ImageFolder(root=data_dir, transform=preprocess)
    loader = DataLoader(
        image_folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader, image_folder

# Build the training loader; the number of classes sizes the model's label embedding
batch_size = 32
dataloader, dataset = get_imagenet_dataloader(DATA_DIR, batch_size=batch_size)
num_classes = len(dataset.classes)

print(f"Dataset loaded with {len(dataset)} images in {num_classes} classes")
print(f"Batch size: {batch_size}")

# Display some sample images
def show_batch(dataloader, num_images=16):
    """Plot a 4-column grid of images from one batch and print their class names.

    Class names come from ``dataloader.dataset.classes``, so the function works
    for whatever loader it is given instead of relying on a global ``dataset``.

    Args:
        dataloader: loader yielding ``(images, labels)`` with images in [-1, 1].
        num_images: maximum number of images to show (default 16).
    """
    images, labels = next(iter(dataloader))
    images = images[:num_images]
    labels = labels[:num_images]

    images = (images + 1) / 2  # Convert from [-1, 1] to [0, 1] for display
    grid = vutils.make_grid(images, nrow=4, padding=2, normalize=False)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.show()

    class_list = dataloader.dataset.classes
    class_names = [class_list[int(label)] for label in labels]
    print("Classes:", class_names)

# Best-effort preview: a display problem should not stop the rest of the notebook
try:
    show_batch(dataloader)
except Exception as err:
    print(f"Error displaying batch: {err}")

## 4. Latent DiffiT Model Implementation

Implement the transformer-based diffusion model that operates in the VAE latent space.

In [ ]:
# 4.1 Time Embedding and Utilities

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embeddings for (possibly fractional) timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``[B, dim]`` tensor. The first
    ``dim // 2`` columns hold sines and the next ``dim // 2`` hold cosines of the
    timestep at geometrically spaced frequencies from 1 down to 1/10000. When
    ``dim`` is odd, one zero column is appended so the width always equals
    ``dim``. This matters because downstream layers are built as
    ``nn.Linear(dim, ...)``.

    Raises:
        ValueError: if ``dim < 2``.
    """
    def __init__(self, dim):
        super().__init__()
        if dim < 2:
            raise ValueError(f"dim must be at least 2, got {dim}")
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero when half_dim == 1 (dim 2 or 3)
        exponent = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -exponent)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Zero-pad one column so the output width matches self.dim
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: the input gated by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class TimeEmbedding(nn.Module):
    """Turn timesteps into ``model_dim``-wide conditioning vectors.

    Sinusoidal encoding of width ``time_embed_dim``, then Linear -> Swish -> Linear.
    """

    def __init__(self, time_embed_dim, model_dim):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        layers = [
            SinusoidalPositionEmbeddings(time_embed_dim),
            nn.Linear(time_embed_dim, model_dim),
            Swish(),
            nn.Linear(model_dim, model_dim),
        ]
        # Kept as a single Sequential named ``time_embed`` so state_dict keys are unchanged
        self.time_embed = nn.Sequential(*layers)

    def forward(self, time):
        embedded = self.time_embed(time)
        return embedded

class LabelEmbedding(nn.Module):
    """Embed integer class labels and project them to ``model_dim`` features.

    Lookup table of width ``embed_dim`` followed by Linear -> Swish -> Linear.
    """

    def __init__(self, num_classes, embed_dim, model_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.projection = nn.Sequential(
            *[nn.Linear(embed_dim, model_dim), Swish(), nn.Linear(model_dim, model_dim)]
        )

    def forward(self, labels):
        return self.projection(self.embedding(labels))

In [ ]:
# 4.2 Time-dependent Multi-head Self-Attention (TMSA)

class TimeDependentMultiHeadAttention(nn.Module):
    """Time-dependent Multi-head Self-Attention (TMSA).

    Queries, keys and values are each the sum of a spatial projection of the
    tokens and a temporal projection of the conditioning embedding, e.g.
    ``q = x @ W_qs + e @ W_qt``. Attention is therefore modulated by whatever
    the caller passes as ``time_emb``.

    Args:
        dim: width of the tokens and of the conditioning embedding.
        heads: number of attention heads.
        dim_head: width of each head.
        dropout: dropout on the attention weights and on the output projection.
        max_seq_len: side of the learned ``[heads, max_seq_len, max_seq_len]``
            pairwise bias table. Sequences longer than this are attended
            WITHOUT the bias. The default of 49 keeps the parameter shape (and
            existing checkpoints) unchanged.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, max_seq_len=49):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.max_seq_len = max_seq_len
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # Spatial projection weights (Wqs, Wks, Wvs)
        self.to_q_spatial = nn.Linear(dim, inner_dim, bias=False)
        self.to_k_spatial = nn.Linear(dim, inner_dim, bias=False)
        self.to_v_spatial = nn.Linear(dim, inner_dim, bias=False)

        # Temporal projection weights (Wqt, Wkt, Wvt)
        self.to_q_temporal = nn.Linear(dim, inner_dim, bias=False)
        self.to_k_temporal = nn.Linear(dim, inner_dim, bias=False)
        self.to_v_temporal = nn.Linear(dim, inner_dim, bias=False)

        # Output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        # Learned bias per head for each (query position, key position) pair.
        # NOTE: indexed by absolute positions, not relative offsets, despite the name.
        self.rel_pos_bias = nn.Parameter(torch.zeros(heads, max_seq_len, max_seq_len))

    def _split_heads(self, t):
        """Reshape ``[B, N, heads * dim_head]`` into ``[B, heads, N, dim_head]``."""
        batch_size, seq_len, _ = t.shape
        return t.reshape(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)

    def forward(self, x, time_emb):
        """Attend over ``x`` ([B, N, dim]) conditioned on ``time_emb`` ([B, dim]).

        Returns a tensor of shape [B, N, dim].
        """
        batch_size, seq_len, _ = x.shape
        cond = time_emb.unsqueeze(1)  # [B, 1, dim]

        # Temporal terms have sequence length 1 and broadcast over all N tokens
        q = self._split_heads(self.to_q_spatial(x)) + self._split_heads(self.to_q_temporal(cond))
        k = self._split_heads(self.to_k_spatial(x)) + self._split_heads(self.to_k_temporal(cond))
        v = self._split_heads(self.to_v_spatial(x)) + self._split_heads(self.to_v_temporal(cond))

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # The bias table only covers the first max_seq_len positions
        if seq_len <= self.max_seq_len:
            dots = dots + self.rel_pos_bias[:, :seq_len, :seq_len].unsqueeze(0)

        attn = self.dropout(self.attend(dots))

        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
        return self.to_out(out)

In [ ]:
# 4.3 Feed Forward Network

class FeedForward(nn.Module):
    """Token-wise MLP plus an additive per-sample offset derived from the conditioning embedding."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Applied independently to every token: dim -> hidden_dim -> dim
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )
        # Applied once per sample to the conditioning vector
        self.time_mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            Swish(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x, time_emb):
        offset = self.time_mlp(time_emb).unsqueeze(1)  # [B, 1, dim], broadcast over tokens
        token_features = self.net(x)
        return token_features + offset

In [ ]:
# 4.4 Latent DiffiT Transformer Block

class LatentDiffiTTransformerBlock(nn.Module):
    """Pre-norm transformer block: TMSA, then a time-conditioned MLP, each with a residual."""

    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=None, dropout=0.0):
        super().__init__()
        hidden = mlp_dim if mlp_dim else dim * 4
        self.norm1 = nn.LayerNorm(dim)
        self.attn = TimeDependentMultiHeadAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = FeedForward(dim, hidden, dropout=dropout)

    def forward(self, x, time_emb):
        attn_out = self.attn(self.norm1(x), time_emb)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x), time_emb)
        return x + mlp_out

In [ ]:
# 4.5 Encoder class for patching latent representations

class Encoder(nn.Module):
    """Patchify VAE latents into a token sequence with learned position embeddings.

    Args:
        img_size: side length of the input IMAGE. The latent side is
            ``img_size // 8``, the VAE downsampling factor assumed by this notebook.
        patch_size: side length of each square latent patch.
        hidden_dim: token width.
        latent_channels: channels of the latent tensor (4 for the SD VAE).

    Raises:
        ValueError: if the latent side is not a positive multiple of ``patch_size``.
    """
    def __init__(self, img_size=256, patch_size=16, hidden_dim=768, latent_channels=4):
        super().__init__()

        latent_size = img_size // 8  # 32x32 for 256x256 images

        # Explicit check (not ``assert``) so it still runs under ``python -O``
        if patch_size <= 0 or latent_size <= 0 or latent_size % patch_size != 0:
            raise ValueError(
                f"latent_size ({latent_size}) must be a positive multiple of patch_size ({patch_size})"
            )

        self.patch_size = patch_size
        self.latent_size = latent_size
        self.hidden_dim = hidden_dim
        self.patches_per_side = latent_size // patch_size
        self.num_patches = self.patches_per_side ** 2

        # Patch embedding layer (similar to ViT): one token per non-overlapping patch
        self.patch_embedding = nn.Conv2d(
            in_channels=latent_channels,
            out_channels=hidden_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # Learned position embedding, one vector per patch
        self.position_embedding = nn.Parameter(torch.zeros(1, self.num_patches, hidden_dim))

    def forward(self, latents):
        """Map ``[B, latent_channels, S, S]`` latents to ``[B, num_patches, hidden_dim]`` tokens."""
        patches = self.patch_embedding(latents)  # [B, hidden_dim, patches_per_side, patches_per_side]

        # [B, C, h, w] -> [B, h*w, C] in row-major order ('b c h w -> b (h w) c')
        embedded = patches.flatten(2).transpose(1, 2)
        return embedded + self.position_embedding

In [ ]:
# 4.6 Unpatchify to convert from sequence back to grid format

class Unpatchify(nn.Module):
    """Convert a token sequence back into a spatial feature map.

    The ``L`` tokens are laid out row-major on a ``sqrt(L) x sqrt(L)`` grid,
    which is then bilinearly upsampled by ``patch_size`` to
    ``[B, hidden_dim, H, W]``.

    NOTE(review): this is an interpolation, not an exact inverse of the patch
    embedding. There is no learned per-pixel projection, so detail within a
    patch comes only from bilinear blending of neighbouring tokens.

    Raises:
        ValueError: if the channel count differs from ``hidden_dim`` or ``L`` is
            not a perfect square.
    """
    def __init__(self, patch_size, hidden_dim):
        super().__init__()
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

    def forward(self, x):
        """
        x: (B, L, C) - batch size, number of patches, channels
        return: (B, C, H, W) - grid with hidden_dim channels
        """
        B, L, C = x.shape
        if C != self.hidden_dim:
            raise ValueError(f"Input channels must be {self.hidden_dim}, got {C}")
        patches_per_side = math.isqrt(L)  # exact integer root, no float rounding
        if patches_per_side * patches_per_side != L:
            raise ValueError(f"Number of patches must be a perfect square, got {L}")
        H = W = patches_per_side * self.patch_size

        # [B, L, C] -> [B, side, side, C] -> [B, C, side, side]
        grid = x.reshape(B, patches_per_side, patches_per_side, C).permute(0, 3, 1, 2)
        return F.interpolate(grid, size=(H, W), mode='bilinear', align_corners=False)

In [ ]:
# 4.7 Decoder to predict noise

class Decoder(nn.Module):
    """Small conv head that maps unpatchified features to a latent-space noise prediction.

    Two (3x3 conv -> GELU -> BatchNorm) stages at ``hidden_dim`` channels,
    followed by a 3x3 conv down to ``out_channels``. Spatial size is preserved.
    """

    def __init__(self, in_channels, hidden_dim, out_channels=4):
        super().__init__()
        layers = []
        for stage_in in (in_channels, hidden_dim):
            layers += [
                nn.Conv2d(stage_in, hidden_dim, kernel_size=3, padding=1),
                nn.GELU(),
                nn.BatchNorm2d(hidden_dim),
            ]
        layers.append(nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1))
        # One Sequential named ``decoder`` keeps the original state_dict keys (decoder.0 ... decoder.6)
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: (B, C, H, W) - input from unpatchify, C = hidden_dim
        return: (B, 4, H, W) - predicted noise in latent space
        """
        prediction = self.decoder(x)
        return prediction

In [ ]:
# 4.8 Latent DiffiT Transformer

class LatentDiffiTTransformer(nn.Module):
    """Stack of LatentDiffiT blocks conditioned on timestep (and optional class) embeddings.

    Input and output are token sequences of width ``dim``. The conditioning
    vector is the time embedding, plus the label embedding when labels are
    supplied. A final LayerNorm is applied to the output.
    """

    def __init__(
        self,
        dim,
        depth,
        heads=8,
        dim_head=64,
        mlp_dim=None,
        dropout=0.0,
        time_embed_dim=None,
        label_embed_dim=None,
        num_classes=1000
    ):
        super().__init__()
        self.dim = dim
        time_embed_dim = time_embed_dim if time_embed_dim else dim * 4
        label_embed_dim = label_embed_dim if label_embed_dim else dim

        self.time_embedding = TimeEmbedding(time_embed_dim, dim)
        self.label_embedding = LabelEmbedding(num_classes, label_embed_dim, dim)

        blocks = [
            LatentDiffiTTransformerBlock(
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                mlp_dim=mlp_dim,
                dropout=dropout,
            )
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.final_norm = nn.LayerNorm(dim)

    def combine_embeddings(self, time_emb, label_emb=None):
        """Return the conditioning vector: ``time_emb`` plus ``label_emb`` when one is given."""
        return time_emb if label_emb is None else time_emb + label_emb

    def forward(self, x, time, labels=None):
        conditioning = self.time_embedding(time)
        if labels is not None:
            conditioning = self.combine_embeddings(conditioning, self.label_embedding(labels))

        hidden = x
        for block in self.transformer_blocks:
            hidden = block(hidden, conditioning)
        return self.final_norm(hidden)

In [ ]:
# 4.9 Full LatentDiffiT Pipeline

class LatentDiffiTPipeline(nn.Module):
    """Complete LatentDiffiT pipeline: patchify -> transformer -> unpatchify -> conv decoder.

    ``forward`` maps noisy VAE latents ``[B, 4, img_size/8, img_size/8]``,
    timesteps, and optional class labels to a noise prediction of the same
    shape. ``sample`` generates decoded images.
    """
    def __init__(
        self,
        img_size=256,
        patch_size=16,
        hidden_dim=768,
        depth=12,
        heads=12,
        dim_head=64,
        mlp_dim=None,
        dropout=0.0,
        time_embed_dim=None,
        label_embed_dim=None,
        num_classes=1000
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        if mlp_dim is None:
            mlp_dim = hidden_dim * 4

        # Initialize components
        self.encoder = Encoder(
            img_size=img_size,
            patch_size=patch_size,
            hidden_dim=hidden_dim
        )

        self.transformer = LatentDiffiTTransformer(
            dim=hidden_dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            mlp_dim=mlp_dim,
            dropout=dropout,
            time_embed_dim=time_embed_dim,
            label_embed_dim=label_embed_dim,
            num_classes=num_classes
        )

        self.unpatchify = Unpatchify(patch_size=patch_size, hidden_dim=hidden_dim)

        self.decoder = Decoder(
            in_channels=hidden_dim,
            hidden_dim=hidden_dim // 2,
            out_channels=4  # Match latent space channels
        )

    def forward(self, noisy_latents, timesteps, labels=None):
        """Predict the noise contained in ``noisy_latents``.

        The input tensor is never modified, so callers such as ``sample`` can
        keep using it after the call.

        Args:
            noisy_latents: [B, 4, img_size/8, img_size/8] noisy latents.
            timesteps: [B] timesteps on the training scale used by the caller.
            labels: optional [B] integer class indices.

        Returns:
            [B, 4, img_size/8, img_size/8] predicted noise.
        """
        embedded = self.encoder(noisy_latents)                             # [B, num_patches, hidden_dim]
        transformer_output = self.transformer(embedded, timesteps, labels)
        unpatched = self.unpatchify(transformer_output)                    # [B, hidden_dim, H/8, W/8]
        return self.decoder(unpatched)

    def sample(self, num_samples, timesteps, device, labels=None, num_train_timesteps=1000):
        """Generate images with deterministic DDIM-style sampling.

        The training loop in this notebook corrupts latents as
        ``x_t = (1 - s) * x_0 + s * eps`` with ``s = t / num_train_timesteps``
        and trains the model to predict ``eps``. Sampling therefore visits
        ``timesteps`` evenly spaced training timesteps, from
        ``num_train_timesteps - 1`` down to 0. At each step it estimates
        ``x_0 = (x_t - s * eps_hat) / (1 - s)`` and re-noises that estimate to
        the next, lower noise level (0 after the last step).

        Call ``model.eval()`` first: the decoder contains BatchNorm layers.

        Args:
            num_samples: number of images to generate.
            timesteps: number of sampling steps (>= 1).
            device: device to sample on.
            labels: optional [num_samples] class indices; random classes when None.
            num_train_timesteps: size of the training timestep range. It must
                match ``num_timesteps`` used in training; the default 1000
                matches the notebook's training cell.

        Returns:
            Decoded images of shape [num_samples, 3, img_size, img_size], nominally in [-1, 1].

        Raises:
            ValueError: if ``timesteps < 1``.
        """
        if timesteps < 1:
            raise ValueError(f"timesteps must be at least 1, got {timesteps}")

        latent_size = self.img_size // 8
        latents = torch.randn(num_samples, 4, latent_size, latent_size, device=device)

        if labels is None:
            labels = torch.randint(0, self.num_classes, (num_samples,), device=device)

        # Training-scale timesteps to visit, noisiest first
        schedule = torch.linspace(num_train_timesteps - 1, 0, timesteps, device=device).round()
        # Noise level at each visited step, plus a final target level of 0 (clean)
        levels = torch.cat([schedule / num_train_timesteps, torch.zeros(1, device=device)])

        with torch.no_grad():
            for i in range(timesteps):
                s, s_next = levels[i], levels[i + 1]
                t_batch = schedule[i].repeat(num_samples)
                eps_hat = self.forward(latents, t_batch, labels)
                # s <= (T - 1) / T < 1, so 1 - s is never zero
                x0_hat = (latents - s * eps_hat) / (1 - s)
                latents = (1 - s_next) * x0_hat + s_next * eps_hat

            images = decode_from_latent(latents)

        return images

## 5. Initialize Model

Create the LatentDiffiT model with appropriate configurations.

In [ ]:
# Build the LatentDiffiT pipeline (depth reduced to 6 blocks for faster training)
model = LatentDiffiTPipeline(
    img_size=256,
    patch_size=16,
    hidden_dim=768,
    depth=6,
    heads=8,
    dim_head=64,
    num_classes=num_classes,
).to(device)

# Short model summary
_param_count = sum(p.numel() for p in model.parameters())
print(f"Model initialized with {_param_count:,} parameters")
print(f"Number of transformer blocks: {len(model.transformer.transformer_blocks)}")

## 6. Training Function

Define the training loop for LatentDiffiT.

In [ ]:
def _save_epoch_samples(model, epoch_num, device, output_dir, num_samples=16, sample_steps=50):
    """Sample ``num_samples`` images with the current weights and save them as ``epoch_<n>.png``."""
    model.eval()
    with torch.no_grad():
        print(f"Generating {num_samples} sample images...")
        # NOTE(review): model.sample must interpret timesteps on the same scale as the
        # num_timesteps used for training — confirm when changing either.
        generated_images = model.sample(num_samples=num_samples, timesteps=sample_steps, device=device)

    # Convert from [-1, 1] to [0, 1] for saving
    generated_images = (generated_images + 1) / 2

    output_path = os.path.join(output_dir, f"epoch_{epoch_num}.png")
    vutils.save_image(generated_images, output_path, nrow=4, padding=2)
    print(f"Generated images saved to {output_path}")


def _plot_loss_history(history, save_dir):
    """Plot the per-epoch average loss, save it as ``training_loss.png`` in ``save_dir`` and show it."""
    plt.figure(figsize=(10, 5))
    plt.plot(history['epoch'], history['loss'])
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.savefig(os.path.join(save_dir, 'training_loss.png'))
    plt.show()


def train_diffit(model, dataloader, num_epochs, num_timesteps, device, learning_rate, save_dir=MODEL_DIR, output_dir=OUTPUT_DIR):
    """Train the LatentDiffiT model to predict the noise added to VAE latents.

    For each batch:
      1. encode the images to latents;
      2. draw integer timesteps ``t`` in ``[0, num_timesteps)``;
      3. form ``x_t = (1 - s) * x_0 + s * noise`` with ``s = t / num_timesteps``;
      4. minimise the MSE between the predicted and the true noise with Adam.

    After every epoch a state_dict checkpoint is written to ``save_dir`` and a
    grid of 16 samples to ``output_dir``. Intermediate checkpoints are written
    every 500 batches. The loss curve is plotted at the end.

    Returns:
        dict with per-epoch ``'epoch'`` numbers and average ``'loss'`` values
        (NaN for an epoch in which the dataloader yielded no batches).
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    history = {'epoch': [], 'loss': []}

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.shape[0]

            # Random training timesteps, one per sample
            timesteps = torch.randint(0, num_timesteps, (batch_size,), device=device).float()

            # Encode images to latent space (the VAE is frozen)
            with torch.no_grad():
                latents = encode_to_latent(images)  # [B, 4, H/8, W/8]

            # Linearly interpolate between clean latents and Gaussian noise
            noise = torch.randn_like(latents)
            s = (timesteps / num_timesteps).view(-1, 1, 1, 1)
            noisy_latents = (1 - s) * latents + s * noise

            optimizer.zero_grad()
            predicted_noise = model(noisy_latents, timesteps, labels)  # [B, 4, H/8, W/8]
            loss = criterion(predicted_noise, noise)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()  # single host sync per step
            total_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix(loss=loss_value)

            # Save intermediate results for long training runs
            if batch_idx % 500 == 0 and batch_idx > 0:
                print(f"\nSaving intermediate model at epoch {epoch+1}, batch {batch_idx}")
                intermediate_path = os.path.join(save_dir, f"latent_diffit_epoch{epoch+1}_batch{batch_idx}.pth")
                torch.save(model.state_dict(), intermediate_path)

        # An empty dataloader yields NaN rather than dividing by zero
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

        history['epoch'].append(epoch + 1)
        history['loss'].append(avg_loss)

        checkpoint_path = os.path.join(save_dir, f"latent_diffit_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved to {checkpoint_path}")

        _save_epoch_samples(model, epoch + 1, device, output_dir)

    _plot_loss_history(history, save_dir)

    return history

## 7. Train the Model

Run the training process for the LatentDiffiT model.

In [ ]:
# Training hyper-parameters
num_epochs = 10
num_timesteps = 1000  # size of the training timestep range
learning_rate = 3e-5

print(f"Starting training for {num_epochs} epochs...")
history = train_diffit(
    model=model,
    dataloader=dataloader,
    num_epochs=num_epochs,
    num_timesteps=num_timesteps,
    device=device,
    learning_rate=learning_rate,
)

## 8. Generate Images

Load the trained model and generate images with different class labels.

In [ ]:
def generate_images(model, device, num_images=16, timesteps=50, class_ids=None, output_dir=OUTPUT_DIR):
    """Sample images from ``model``, save them to disk, and display them as a grid.

    Args:
        model: pipeline exposing ``sample(...)`` and ``num_classes``.
        device: device to sample on.
        num_images: number of images to generate (>= 1).
        timesteps: number of sampling steps passed to ``model.sample``.
        class_ids: optional sequence of class indices. It is cycled when shorter
            than ``num_images`` and truncated when longer, so the label batch
            always matches the sample batch. ``None`` picks random classes.
        output_dir: receives a timestamped grid PNG plus one PNG per image.

    Returns:
        Tensor of generated images in [0, 1], shape [num_images, 3, H, W].

    Raises:
        ValueError: if ``num_images < 1`` or ``class_ids`` is empty.
    """
    import time  # local import: the cell's module-level ``import time`` runs after this definition

    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    if class_ids is not None:
        if len(class_ids) == 0:
            raise ValueError("class_ids must contain at least one class index")
        labels = torch.as_tensor(class_ids, dtype=torch.long, device=device)
        print(f"Generating images for specific classes: {class_ids}")
        # Cycle the requested classes to cover num_images, then drop any extras
        repeats = -(-num_images // len(labels))  # ceiling division
        labels = labels.repeat(repeats)[:num_images]
    else:
        labels = torch.randint(0, model.num_classes, (num_images,), device=device)
        print("Generating images with random class labels")

    print(f"Generating {num_images} images with {timesteps} timesteps...")
    with torch.no_grad():
        start_time = time.perf_counter()
        generated_images = model.sample(num_samples=num_images, timesteps=timesteps, device=device, labels=labels)
        elapsed = time.perf_counter() - start_time

    # Map [-1, 1] to [0, 1]; clamp so both the saved PNGs and imshow get valid values
    generated_images = ((generated_images + 1) / 2).clamp(0, 1)
    grid_cols = math.isqrt(num_images)

    grid_path = os.path.join(output_dir, f"generated_grid_{time.strftime('%Y%m%d_%H%M%S')}.png")
    vutils.save_image(generated_images, grid_path, nrow=grid_cols, padding=2)

    for i, img in enumerate(generated_images):
        img_path = os.path.join(output_dir, f"generated_{i}_class{labels[i].item()}.png")
        vutils.save_image(img, img_path)

    print(f"Generation complete in {elapsed:.2f} seconds")
    print(f"Images saved to {output_dir}")

    plt.figure(figsize=(12, 12))
    grid = vutils.make_grid(generated_images, nrow=grid_cols, padding=2, normalize=False)
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title("Generated Images")
    plt.show()

    return generated_images

# Load the final-epoch checkpoint (edit the path to compare a different epoch)
best_model_path = os.path.join(MODEL_DIR, f"latent_diffit_epoch_{num_epochs}.pth")
if os.path.exists(best_model_path):
    print(f"Loading best model from {best_model_path}")
    # The checkpoint is a plain state_dict, so refuse to unpickle arbitrary objects
    state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
else:
    print("Best model not found, using the current model state.")

# Generate images with the trained model
import time  # Import time for timestamping

# Generate 16 random images
generate_images(model, device, num_images=16, timesteps=100)

# Generate specific classes. These indices are examples; keep only those that
# exist in this dataset so the label embedding lookup cannot go out of range.
specific_classes = [c for c in [15, 97, 182, 344] if c < model.num_classes]
if specific_classes:
    generate_images(model, device, num_images=4, timesteps=100, class_ids=specific_classes)
else:
    print("None of the example class indices exist in this dataset; skipping class-specific generation.")

## 9. Interactive Image Generation Tool

Create an interactive tool to generate images with different parameters.

In [ ]:
# Interactive generation UI; skipped when ipywidgets is not installed
try:
    from ipywidgets import interact, IntSlider, FloatSlider, Dropdown, Text
    import ipywidgets as widgets

    # Class names and their label indices
    class_names = sorted(dataset.classes)
    class_dict = {name: idx for idx, name in enumerate(class_names)}

    def interactive_generate(num_images=4, timesteps=100, class_name=class_names[0], seed=42):
        """Re-seed the RNGs, then generate ``num_images`` images of one class via generate_images."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        class_idx = class_dict[class_name]
        print(f"Generating {num_images} images of class '{class_name}' (index: {class_idx})")

        repeated_ids = [class_idx] * num_images
        return generate_images(model, device, num_images=num_images, timesteps=timesteps,
                               class_ids=repeated_ids)

    _controls = dict(
        num_images=IntSlider(min=1, max=16, step=1, value=4, description="Images:"),
        timesteps=IntSlider(min=10, max=200, step=10, value=100, description="Steps:"),
        class_name=Dropdown(options=class_names[:100], description="Class:"),  # first 100 only, for responsiveness
        seed=IntSlider(min=0, max=1000, step=1, value=42, description="Seed:"),
    )
    interact(interactive_generate, **_controls)

except ImportError:
    print("Interactive widgets not available. Skipping interactive tool.")

## 10. Conclusion

The LatentDiffiT model combines transformers with diffusion models to operate in the latent space of a pre-trained VAE. This approach allows for efficient training and generation of high-quality images.

Key components of the implementation:

1. **VAE Integration**: Using a pre-trained Stable Diffusion VAE to work in latent space
2. **Transformer Architecture**: Implementing the Time-dependent Multi-head Self-Attention (TMSA) mechanism
3. **Diffusion Process**: Training the model to predict noise in the latent space
4. **Class-Conditional Generation**: Support for generating images conditioned on class labels

The notebook provides a complete pipeline from data loading to training and image generation. You can experiment with different hyperparameters to optimize the quality of generated images.