In [281]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from typing import Tuple, Optional, Any
import math
from torch import Tensor
from einops import rearrange
from diffusers import DDPMScheduler
from torchvision import transforms, datasets
from tqdm import tqdm
from IPython.display import clear_output
import einops
from torch.utils.data.dataset import Dataset
from dataclasses import dataclass
from typing import List, Union, Dict, Any, Optional, BinaryIO
from PIL import Image, ImageColor, ImageDraw, ImageFont
import tqdm
import abc  # Abstract base class
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer


In [282]:
#@title Hyperparameters

# Number of diffusion timesteps.
T = 1_000

In [283]:
#@title Device / CUDA

# Pick the best available backend, in order of preference: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f'Using device: {device}')

Using device: cuda


In [284]:
@dataclass
class DataPack:
    """
    Bundle returned by the dataset helper functions: both splits with their
    loaders, the tensor <-> array conversion callables, and the shape hints
    that go with the dataset.
    """
    # Training split and the loader that iterates over it.
    train_dataset: Dataset
    train_loader: DataLoader
    # Validation split and the loader that iterates over it.
    val_dataset: Dataset
    val_loader: DataLoader
    # Transform applied to every raw image when the datasets are read.
    transform_to_tensor: Any
    # Callable turning a (C, H, W) tensor back into an (H, W, C) numpy array.
    transform_to_pil: Any
    # Channel counts of the images / expected model output.
    in_channels: int
    out_channels: int
    # Number of distinct class labels in the dataset.
    num_classes: int
    # Variable-length tuple of ints (the helpers pass 3 and 4 entries), hence
    # ``Tuple[int, ...]``; the meaning is defined by whoever consumes it
    # (presumably per-level multipliers for a UNet -- TODO confirm).
    recommended_steps: Tuple[int, ...]
    # Indexes into ``recommended_steps``; presumably the levels that should
    # use attention -- TODO confirm against the consumer.
    recommended_attn_step_indexes: List[int]

class MNISTTransformation:
    """Callable that inverts a (C, H, W) image tensor and returns it as an (H, W, C) numpy array."""

    def __call__(self, tensor: torch.Tensor):
        # Invert intensities: each value v becomes 1 - v.
        inverted = tensor.mul(-1).add(1)
        # Channels-first -> channels-last, then hand over as a CPU numpy array.
        channels_last = inverted.permute(1, 2, 0)
        return channels_last.detach().cpu().numpy()

def get_mnist_loader_and_transform(
    path_to_dataset: str = "../../../datasets",
    batch_size: int = 128,
    num_workers: int = 2
) -> DataPack:
    """
    Build the MNIST train/validation datasets and loaders (downloading the
    data into ``path_to_dataset`` if needed) and wrap them in a ``DataPack``.

    Args:
        path_to_dataset: Root directory handed to ``torchvision.datasets.MNIST``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.

    Returns:
        DataPack describing single-channel, 10-class data.
    """
    # Resize to 32x32 so the input size matches the UNet, then convert to a tensor.
    to_tensor = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    # Both loaders share the same settings (including shuffling).
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)

    train_split = torchvision.datasets.MNIST(
        root=path_to_dataset, download=True, transform=to_tensor
    )
    train_batches = DataLoader(train_split, **loader_kwargs)

    val_split = torchvision.datasets.MNIST(
        root=path_to_dataset, download=True, transform=to_tensor, train=False
    )
    val_batches = DataLoader(val_split, **loader_kwargs)

    return DataPack(
        train_dataset=train_split,
        train_loader=train_batches,
        val_dataset=val_split,
        val_loader=val_batches,
        transform_to_tensor=to_tensor,
        transform_to_pil=MNISTTransformation(),
        in_channels=1,
        out_channels=1,
        num_classes=10,
        recommended_steps=(1, 2, 4),
        recommended_attn_step_indexes=[1],
    )

class CifarTransformation:
    """Callable mapping a (C, H, W) tensor normalised around 0 to an (H, W, C) integer array in [0, 255]."""

    def __call__(self, tensor: torch.Tensor):
        # Undo the (x - 0.5) / 0.5 normalisation and rescale to the 0..255 range.
        rescaled = tensor * 127.5 + 127.5
        # Truncate to int64 first, then clamp anything that fell outside 0..255.
        pixels = rescaled.long().clip(0, 255)
        # Channels-first -> channels-last, then hand over as a CPU numpy array.
        channels_last = pixels.permute(1, 2, 0)
        return channels_last.detach().cpu().numpy()

def get_cifar10_loader_and_transform(
    path_to_dataset: str = "../../../datasets",
    batch_size: int = 128,
    num_workers: int = 2
) -> DataPack:
    """
    Build the CIFAR-10 train/validation datasets and loaders (downloading the
    data into ``path_to_dataset`` if needed) and wrap them in a ``DataPack``.

    Args:
        path_to_dataset: Root directory handed to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.

    Returns:
        DataPack describing three-channel, 10-class data.
    """
    # Convert to a tensor, then normalise every channel with mean 0.5 / std 0.5.
    channel_stats = (.5, .5, .5)
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ])
    to_array = CifarTransformation()

    # Both loaders share the same settings (including shuffling).
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)

    train_split = torchvision.datasets.CIFAR10(
        root=path_to_dataset, download=True, transform=to_tensor
    )
    train_batches = DataLoader(train_split, **loader_kwargs)

    val_split = torchvision.datasets.CIFAR10(
        root=path_to_dataset, download=True, transform=to_tensor, train=False
    )
    val_batches = DataLoader(val_split, **loader_kwargs)

    return DataPack(
        train_dataset=train_split,
        train_loader=train_batches,
        val_dataset=val_split,
        val_loader=val_batches,
        transform_to_tensor=to_tensor,
        transform_to_pil=to_array,
        in_channels=3,
        out_channels=3,
        num_classes=10,
        recommended_steps=(1, 2, 2, 2),
        recommended_attn_step_indexes=[1, 2],
    )



In [285]:
from diffusers.models.unets.unet_2d_blocks import (
    AttnDownBlock2D as DiffAttnDownBlock2D,
    DownBlock2D as DiffDownBlock2D,
    CrossAttnDownBlock2D as DiffCrossAttnDownBlock2D,
    AttnUpBlock2D as DiffAttnUpBlock2D,
    UpBlock2D as DiffUpBlock2D,
    ResnetBlock2D as DiffResnetBlock2D,
)

from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from diffusers import DDPMScheduler

In [286]:
class DDPM(nn.Module):
    """
    Conditional DDPM: a conditioned UNet noise predictor combined with a
    diffusers noise scheduler, usable for training (``forward``) and for
    sampling (``sample``).

    Args:
        T: Number of timesteps from which training timesteps are drawn.
        unet: Noise-prediction network; moved to ``device`` on construction.
        noise_scheduler: Scheduler supplying ``alphas_cumprod`` and the
            reverse-process ``step``.
        device: Device holding the UNet, the pre-computed buffers and all
            tensors created by this module.
    """
    def __init__(
        self,
        T: int,
        unet: "UNet2DConditionModel",
        noise_scheduler: "DDPMScheduler",
        device: Union[str, torch.device]
    ):
        super().__init__()
        self.T = T
        self.unet = unet.to(device)
        self.noise_sched = noise_scheduler
        self.device = device

        # Pre-compute terms from beta schedule
        bar_alpha_t = self.noise_sched.alphas_cumprod.to(device)
        self.register_buffer("sqrt_bar_alpha_t", torch.sqrt(bar_alpha_t))  # √ᾱ_t
        self.register_buffer("sqrt_minus_bar_alpha_t_schedule", torch.sqrt(1 - bar_alpha_t))  # √(1 - ᾱ_t)

        self.criterion = nn.MSELoss()

    def forward(self, batch: dict) -> torch.Tensor:
        """
        Forward diffusion + loss. Used in training.

        Args:
            batch: Mapping with at least ``"image"`` (a ``(B, C, H, W)`` tensor)
                and ``"text_embedding"`` (conditioning passed to the UNet as
                ``encoder_hidden_states``). Other keys are ignored.

        Returns:
            Scalar MSE between the predicted and the injected noise.
        """
        imgs = batch["image"].to(self.device)
        # The conditioning has to live on the same device as the UNet; do not
        # rely on the data loader having put it there already.
        text_embedding = batch["text_embedding"].to(self.device)

        b = imgs.shape[0]

        # Sample one time step per image
        t = torch.randint(low=0, high=self.T, size=(b,), device=self.device, dtype=torch.long)

        # Sample Gaussian noise (same shape, dtype and device as the images)
        noise = torch.randn_like(imgs)

        # q(x_t | x_0): noisy image generation
        noisy_imgs = (
            self.sqrt_bar_alpha_t[t].view(b, 1, 1, 1) * imgs
            + self.sqrt_minus_bar_alpha_t_schedule[t].view(b, 1, 1, 1) * noise
        )

        # Predict noise with UNet
        pred_noise = self.unet(noisy_imgs, t, encoder_hidden_states=text_embedding).sample

        # Compute MSE loss between predicted and true noise
        return self.criterion(pred_noise, noise)

    def predict_noise(self, noisy_imgs: torch.Tensor, t: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Predict noise given noisy image x_t and timestep t (used during inference).

        ``cond`` is forwarded to the UNet as ``encoder_hidden_states``.
        """
        return self.unet(noisy_imgs, t, encoder_hidden_states=cond).sample

    def sample(
        self,
        shape: tuple,
        text_encoder=None,
        batch=None,
        cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Generate images by running the full reverse process from pure noise.

        The conditioning is resolved once, in this order:

        1. ``cond`` if given (ready-made encoder hidden states);
        2. ``batch["text_embedding"]`` if ``batch`` is a dict (the format
           consumed by ``forward``);
        3. ``text_encoder(txt, return_dict=False)[0]`` where ``batch`` is an
           ``(imgs, txt, lbls)`` triple.

        Args:
            shape: ``(b, c, h, w)`` of the images to generate.
            text_encoder: Callable used only for case 3.
            batch: Dict or ``(imgs, txt, lbls)`` triple, see above.
            cond: Conditioning tensor whose batch size matches ``b``.

        Returns:
            Tensor of shape ``shape`` on ``self.device``.

        Raises:
            ValueError: If no source of conditioning is supplied.
        """
        b, c, h, w = shape

        with torch.no_grad():
            if cond is None:
                if isinstance(batch, dict):
                    cond = batch["text_embedding"]
                elif text_encoder is not None and batch is not None:
                    _, txt, _ = batch
                    # The prompt does not change between timesteps, so encode
                    # it once instead of once per denoising step.
                    cond = text_encoder(txt, return_dict=False)[0]
                else:
                    raise ValueError(
                        "sample() needs either `cond`, a dict batch with "
                        "'text_embedding', or `text_encoder` together with `batch`."
                    )
            encoder_hidden_states = cond.to(self.device)

            sample = torch.randn((b, c, h, w), device=self.device)

            for t in reversed(range(self.noise_sched.config.num_train_timesteps)):
                t_tensor = torch.full((b,), t, device=self.device, dtype=torch.long)

                # Predict noise
                noise_pred = self.unet(sample, t_tensor, encoder_hidden_states).sample

                # Every item shares the same timestep, so a single scheduler
                # step with a scalar timestep handles the whole batch.
                sample = self.noise_sched.step(
                    model_output=noise_pred,
                    timestep=t,
                    sample=sample
                ).prev_sample

        return sample

In [287]:
# batch_size = 64
# data = get_mnist_loader_and_transform(batch_size=batch_size)

cond_dim = 10

# UNet for 32x32 single-channel images, conditioned through encoder hidden
# states with 512 features.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=1,
    out_channels=1,
    block_out_channels=(112, 224, 336, 448),
    layers_per_block=1,
    norm_num_groups=16,
    encoder_hid_dim=512,
    # `addition_embed_type` must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'.
    addition_embed_type="text",
)
unet = unet.to(device)

# Wrap the UNet and a T-step DDPM scheduler in the training/sampling module.
model = DDPM(
    T=T,
    unet=unet,
    noise_scheduler=DDPMScheduler(num_train_timesteps=T),
    device=device,
)


In [288]:
# Count only parameters that receive gradients, so the number matches the
# "trainable" wording of the message below.
num_of_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters in the model: {num_of_params:,}")

Number of trainable parameters in the model: 64,325,521


In [None]:
class MNISTWithTextLabel(torch.utils.data.Dataset):
    """
    MNIST wrapped so that every item also carries the digit's English name
    and a CLIP text embedding of that name.

    Only ten distinct prompts exist, so their embeddings are computed once at
    construction time and looked up in ``__getitem__`` instead of running the
    text encoder for every sample.

    Args:
        root: Root directory handed to ``datasets.MNIST``.
        train: Selects the training (True) or test (False) split.
        transform: Transform applied to each image by ``datasets.MNIST``.
    """
    def __init__(self, root, train=True, transform=None):
        self.dataset = datasets.MNIST(root=root, train=train, download=True, transform=transform)
        self.label_to_text = {
            0: "zero", 1: "one", 2: "two", 3: "three", 4: "four",
            5: "five", 6: "six", 7: "seven", 8: "eight", 9: "nine"
        }

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", use_safetensors=True)
        self.text_encoder = self.text_encoder.eval().to(device)  # disable training

        # label -> embedding of shape [1, hidden_dim], kept on the encoder's device.
        self.label_to_embedding = {
            label: self._encode_text(text)
            for label, text in self.label_to_text.items()
        }

    def _encode_text(self, text: str) -> torch.Tensor:
        """Return the mean-pooled CLIP embedding of ``text``, shape [1, hidden_dim]."""
        # Each prompt is tokenised on its own, so no padding tokens enter the mean.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.text_encoder.device)

        with torch.no_grad():
            hidden = self.text_encoder(**inputs).last_hidden_state  # shape: [1, seq_len, hidden_dim]

        # Average over the token axis.
        return hidden.mean(dim=1)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]

        return {
            "image": image,
            "label": label,
            "text": self.label_to_text[label],
            "text_embedding": self.label_to_embedding[label]  # shape: [1, hidden_dim]
        }

# Shared preprocessing for both splits: resize to 32x32, then convert to a tensor.
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

# Text-labelled MNIST, training and test split.
train_dataset = MNISTWithTextLabel("../../datasets", train=True, transform=transform)
val_dataset = MNISTWithTextLabel("../../datasets", train=False, transform=transform)

# Shuffled loaders with 64 items per batch.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=True)


In [290]:
def sample_and_plot(model: DDPM, device: str, num_samples: int = 8, cond: Optional[torch.Tensor] = None):
    """
    Generate ``num_samples`` 1x32x32 images with ``model`` and show them in a
    4-column grid.

    The reverse process is driven through ``model.predict_noise`` and the
    model's scheduler so that ready-made embeddings can serve as conditioning.

    Args:
        model: DDPM to sample from; it is switched to eval mode.
        device: Device on which the starting noise and conditioning are placed.
        num_samples: Number of images to generate.
        cond: Optional conditioning of shape ``(num_samples, seq_len, hidden_dim)``
            handed to the UNet as encoder hidden states. When omitted, random
            conditioning is used, so the images are not tied to any prompt.
    """
    model.eval()
    with torch.no_grad():
        if cond is None:
            # The feature size must equal what the UNet was configured to
            # accept, otherwise its conditioning layers reject the tensor.
            unet_config = model.unet.config
            hidden_dim = unet_config.encoder_hid_dim or unet_config.cross_attention_dim
            cond = torch.randn((num_samples, 10, hidden_dim), device=device)  # random condition
        cond = cond.to(device)

        samples = torch.randn((num_samples, 1, 32, 32), device=device)
        scheduler = model.noise_sched
        for step in reversed(range(scheduler.config.num_train_timesteps)):
            timesteps = torch.full((num_samples,), step, device=device, dtype=torch.long)
            noise_pred = model.predict_noise(samples, timesteps, cond)
            # All items share the timestep, so one scheduler step covers the batch.
            samples = scheduler.step(
                model_output=noise_pred,
                timestep=step,
                sample=samples
            ).prev_sample

        samples = samples.cpu()
        grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True, value_range=(0, 1))

        plt.figure(figsize=(6, 6))
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.title("Generated Samples")
        plt.show()

# Train

In [291]:
def train(
    model: DDPM,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    device: str,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    sample_every: int = 1,
):
    """
    Train ``model`` for ``epochs`` epochs and evaluate it after each one.

    Both loaders must yield batches in the format ``model`` accepts as its
    single argument; the model call itself returns the loss.

    Args:
        model: DDPM whose call returns the training loss for a batch.
        optimizer: Optimizer already bound to the model's parameters.
        epochs: Number of passes over ``train_dataloader``.
        device: Device forwarded to ``sample_and_plot``.
        train_dataloader: Loader providing the training batches.
        val_dataloader: Loader providing the validation batches.
        sample_every: Samples are plotted on epochs divisible by this value.

    Returns:
        ``(training_losses, val_losses)``: per-epoch mean losses.
    """
    training_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        training_loss = 0.0
        pbar = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch}")

        for index, batch in enumerate(pbar):
            loss = model(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            training_loss += loss.item()
            # Show the running mean of the epoch so far.
            pbar.set_postfix(loss=training_loss / (index + 1))

        if epoch % sample_every == 0:
            sample_and_plot(model, device)

        training_loss /= len(train_dataloader)
        training_losses.append(training_loss)

        # Validation: same batch format and same model call as in training,
        # just without gradient tracking or optimizer steps.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                val_loss += model(batch).item()
        val_loss /= len(val_dataloader)
        val_losses.append(val_loss)

        print(f"Epoch {epoch} | Train Loss: {training_loss:.4f} | Val Loss: {val_loss:.4f}")

    sample_and_plot(model, device)
    return training_losses, val_losses

In [None]:
# Ten epochs with Adam at lr=2e-4; sample_every keeps its default value.
train(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=torch.optim.Adam(model.parameters(), lr=2e-4),
    epochs=10,
    device=device,
)

Epoch 0:   0%|          | 0/938 [00:00<?, ?it/s]