In [87]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm
import math
import numpy as np
import os
from PIL import Image


In [88]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [89]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t_tensor: torch.Tensor) -> torch.Tensor: # Input parameter is t_tensor
        if t_tensor.dim() == 0:
            t_tensor = t_tensor.unsqueeze(0)
        
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        # Ensure 'embeddings' are created on the correct device
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = t_tensor.unsqueeze(1) * embeddings.unsqueeze(0) # Use t_tensor here
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [90]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, downsample: bool):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.activation = nn.SiLU()

        if downsample:
            self.pool = nn.MaxPool2d(2)
        else:
            self.pool = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.activation(self.norm1(self.conv1(x)))
        time_embedding = self.activation(self.time_mlp(t_emb))
        h = h + time_embedding.unsqueeze(-1).unsqueeze(-1)
        h = self.activation(self.norm2(self.conv2(h)))
        return self.pool(h)

In [91]:
class UpBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.activation = nn.SiLU()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x: torch.Tensor, skip_x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = torch.cat([x, skip_x], dim=1) 
        h = self.activation(self.norm1(self.conv1(x)))
        time_embedding = self.activation(self.time_mlp(t_emb))
        h = h + time_embedding.unsqueeze(-1).unsqueeze(-1)
        h = self.activation(self.norm2(self.conv2(h)))
        return h

class UNet(nn.Module):
    def __init__(self, in_channels: int = 1, out_channels: int = 1, time_emb_dim: int = 256):
        super().__init__()
        self.time_mlp = PositionalEncoding(time_emb_dim)

        self.inc = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.down1 = ConvBlock(64, 128, time_emb_dim, downsample=True)
        self.down2 = ConvBlock(128, 256, time_emb_dim, downsample=True)

        self.bot1 = nn.Conv2d(256, 256, 3, padding=1)
        self.bot2 = nn.Conv2d(256, 256, 3, padding=1)

        self.up1 = UpBlock(256 + 128, 128, time_emb_dim) 
        self.up2 = UpBlock(128 + 64, 64, time_emb_dim)
        self.outc = nn.Conv2d(64, out_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.time_mlp(t)

        x1 = F.relu(self.inc(x))
        x2 = self.down1(x1, t_emb)
        x3 = self.down2(x2, t_emb)

        x3 = F.relu(self.bot1(x3))
        x3 = F.relu(self.bot2(x3))

        x = self.up1(x3, x2, t_emb)
        x = self.up2(x, x1, t_emb)

        output = self.outc(x)
        return output

In [92]:
class DDIMScheduler:
    def __init__(self, timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02):
        self.timesteps = timesteps
        # Ensure all pre-computed tensors are on the correct device
        self.betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(device) # Move to device
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0).to(device) # Move to device
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(device) # Move to device
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod).to(device) # Move to device

    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        # timesteps_tensor is already on device from the training loop
        sqrt_alpha_prod_t = self.sqrt_alphas_cumprod[timesteps].view(-1, 1, 1, 1) # timesteps already on device
        sqrt_one_minus_alpha_prod_t = self.sqrt_one_minus_alphas_cumprod[timesteps].view(-1, 1, 1, 1) # timesteps already on device

        noisy_samples = sqrt_alpha_prod_t * original_samples + sqrt_one_minus_alpha_prod_t * noise
        return noisy_samples

    def get_ddim_sampling_timesteps(self, num_inference_steps: int) -> torch.Tensor:
        assert num_inference_steps <= self.timesteps, \
            f"num_inference_steps ({num_inference_steps}) cannot be greater than total timesteps ({self.timesteps})"
        
        skip = self.timesteps // num_inference_steps
        timesteps = list(range(0, self.timesteps, skip))
        if timesteps[-1] != self.timesteps - 1:
            timesteps.append(self.timesteps - 1)
        
        return torch.tensor(timesteps[::-1], dtype=torch.long).to(device) # Ensure timesteps are on device

    def ddim_step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor, eta: float = 0.0) -> tuple[torch.Tensor, torch.Tensor]:
        # timestep is already on device from the sampling loop
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod_prev[timestep]

        sigma_t = eta * torch.sqrt((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * \
                    torch.sqrt(1 - alpha_prod_t / alpha_prod_t_prev)

        pred_x0 = (sample - torch.sqrt(1 - alpha_prod_t) * model_output) / torch.sqrt(alpha_prod_t)
        dir_xt = torch.sqrt(1 - alpha_prod_t_prev - sigma_t**2) * model_output
        
        noise_term = torch.randn_like(sample) if eta > 0 else torch.zeros_like(sample)

        prev_sample = torch.sqrt(alpha_prod_t_prev) * pred_x0 + dir_xt + sigma_t * noise_term
        
        return prev_sample, pred_x0


In [93]:
class DiffusionPipeline:
    def __init__(self, model: nn.Module, scheduler: DDIMScheduler, image_size: int, 
                 dataset: Dataset, batch_size: int, lr: float, epochs: int, 
                 save_dir: str = "ddim_pipeline_results"):
        self.model = model.to(device)
        self.scheduler = scheduler
        self.image_size = image_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.lr = lr
        self.epochs = epochs
        self.save_dir = save_dir
        
        os.makedirs(save_dir, exist_ok=True)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        
        class TransformedDataset(Dataset):
            def __init__(self, base_dataset, transform):
                self.base_dataset = base_dataset
                self.transform = transform
            
            def __len__(self):
                return len(self.base_dataset)
            
            def __getitem__(self, idx):
                img, _ = self.base_dataset[idx]
                return self.transform(img)

        self.train_dataset = TransformedDataset(dataset, self.transform)
        self.dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count()//2)
    
    def train(self):
        self.model.train()
        for epoch in range(self.epochs):
            total_loss = 0
            for batch_idx, images in enumerate(tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")):
                images = images.to(device)
                
                timesteps = torch.randint(0, self.scheduler.timesteps, (images.shape[0],), device=device)
                noise = torch.randn_like(images).to(device)
                noisy_images = self.scheduler.add_noise(images, noise, timesteps)
                
                self.optimizer.zero_grad()
                predicted_noise = self.model(noisy_images, timesteps)
                
                loss = self.criterion(predicted_noise, noise)
                loss.backward()
                self.optimizer.step()
                
                total_loss += loss.item()
            
            avg_loss = total_loss / len(self.dataloader)
            print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

            if (epoch + 1) % 5 == 0 or epoch == self.epochs - 1:
                torch.save(self.model.state_dict(), os.path.join(self.save_dir, f"unet_epoch_{epoch+1}.pt"))
                self.generate_samples(num_samples=4, num_inference_steps=50, epoch=epoch+1)

    @torch.no_grad()
    def generate_samples(self, num_samples: int, num_inference_steps: int, eta: float = 0.0, epoch: int = None):
        self.model.eval()
        
        sample = torch.randn(num_samples, self.model.inc.in_channels, self.image_size, self.image_size).to(device)
        ddim_timesteps = self.scheduler.get_ddim_sampling_timesteps(num_inference_steps)
        
        for i, t in enumerate(tqdm(ddim_timesteps, desc="DDIM Sampling")):
            model_output = self.model(sample, t.unsqueeze(0).repeat(num_samples))
            sample, _ = self.scheduler.ddim_step(model_output, t, sample, eta)

        sample = (sample.clamp(-1, 1) + 1) / 2
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        
        for i in range(num_samples):
            img = Image.fromarray((sample[i] * 255).astype(np.uint8).squeeze())
            filename = f"generated_sample_{'epoch_' + str(epoch) + '_' if epoch is not None else ''}{i}.png"
            img.save(os.path.join(self.save_dir, filename))
        
        self.model.train()

In [94]:
if __name__ == "__main__":
    image_size = 28
    in_channels = 1
    epochs = 20
    batch_size = 128
    learning_rate = 1e-4
    timesteps = 1000
    num_inference_steps = 50
    eta = 0.0

    mnist_dataset = MNIST(root="./data", train=True, download=True)

    model = UNet(in_channels=in_channels, out_channels=in_channels, time_emb_dim=256)
    scheduler = DDIMScheduler(timesteps=timesteps)

    pipeline = DiffusionPipeline(
        model=model,
        scheduler=scheduler,
        image_size=image_size,
        dataset=mnist_dataset,
        batch_size=batch_size,
        lr=learning_rate,
        epochs=epochs
    )
    
    print("Starting DDIM training pipeline...")
    pipeline.train()
    print("DDIM training complete.")

    print("Generating final samples with trained model...")
    pipeline.generate_samples(num_samples=16, num_inference_steps=num_inference_steps, eta=eta, epoch="final")
    print("Final samples generated and saved.")


    loaded_model = UNet(in_channels=in_channels, out_channels=in_channels, time_emb_dim=256).to(device)
    loaded_model.load_state_dict(torch.load("ddim_pipeline_results/unet_epoch_50.pt"))
    print("Loaded trained model for further sampling.")
    loaded_pipeline = DiffusionPipeline(loaded_model, scheduler, image_size, mnist_dataset, batch_size, learning_rate, epochs)
    loaded_pipeline.generate_samples(num_samples=16, num_inference_steps=num_inference_steps, eta=eta, epoch="loaded_model_test")


Starting DDIM training pipeline...


Epoch 1/20: 100%|██████████| 469/469 [00:49<00:00,  9.47it/s]


Epoch 1, Loss: 0.0976


Epoch 2/20: 100%|██████████| 469/469 [00:51<00:00,  9.08it/s]


Epoch 2, Loss: 0.0403


Epoch 3/20: 100%|██████████| 469/469 [00:51<00:00,  9.18it/s]


Epoch 3, Loss: 0.0349


Epoch 4/20: 100%|██████████| 469/469 [00:51<00:00,  9.16it/s]


Epoch 4, Loss: 0.0332


Epoch 5/20: 100%|██████████| 469/469 [00:50<00:00,  9.20it/s]


Epoch 5, Loss: 0.0309


DDIM Sampling: 100%|██████████| 51/51 [00:00<00:00, 231.75it/s]
Epoch 6/20: 100%|██████████| 469/469 [00:50<00:00,  9.22it/s]


Epoch 6, Loss: 0.0298


Epoch 7/20: 100%|██████████| 469/469 [00:50<00:00,  9.23it/s]


Epoch 7, Loss: 0.0288


Epoch 8/20: 100%|██████████| 469/469 [00:50<00:00,  9.21it/s]


Epoch 8, Loss: 0.0281


Epoch 9/20: 100%|██████████| 469/469 [00:50<00:00,  9.21it/s]


Epoch 9, Loss: 0.0278


Epoch 10/20: 100%|██████████| 469/469 [00:50<00:00,  9.20it/s]


Epoch 10, Loss: 0.0272


DDIM Sampling: 100%|██████████| 51/51 [00:00<00:00, 379.75it/s]
Epoch 11/20: 100%|██████████| 469/469 [00:50<00:00,  9.21it/s]


Epoch 11, Loss: 0.0267


Epoch 12/20: 100%|██████████| 469/469 [00:50<00:00,  9.22it/s]


Epoch 12, Loss: 0.0268


Epoch 13/20: 100%|██████████| 469/469 [00:50<00:00,  9.21it/s]


Epoch 13, Loss: 0.0263


Epoch 14/20: 100%|██████████| 469/469 [00:50<00:00,  9.20it/s]


Epoch 14, Loss: 0.0258


Epoch 15/20: 100%|██████████| 469/469 [00:51<00:00,  9.19it/s]


Epoch 15, Loss: 0.0255


DDIM Sampling: 100%|██████████| 51/51 [00:00<00:00, 344.01it/s]
Epoch 16/20: 100%|██████████| 469/469 [00:51<00:00,  9.19it/s]


Epoch 16, Loss: 0.0253


Epoch 17/20: 100%|██████████| 469/469 [00:50<00:00,  9.21it/s]


Epoch 17, Loss: 0.0252


Epoch 18/20: 100%|██████████| 469/469 [00:51<00:00,  9.17it/s]


Epoch 18, Loss: 0.0252


Epoch 19/20: 100%|██████████| 469/469 [00:51<00:00,  9.13it/s]


Epoch 19, Loss: 0.0250


Epoch 20/20: 100%|██████████| 469/469 [00:51<00:00,  9.13it/s]


Epoch 20, Loss: 0.0249


DDIM Sampling: 100%|██████████| 51/51 [00:00<00:00, 367.28it/s]


DDIM training complete.
Generating final samples with trained model...


DDIM Sampling: 100%|██████████| 51/51 [00:00<00:00, 185.03it/s]


Final samples generated and saved.


FileNotFoundError: [Errno 2] No such file or directory: 'ddim_pipeline_results/unet_epoch_50.pt'