# Mamba Denoiser VS DiT

In [None]:
# This implementation is based on Dino Diffusion
# https://github.com/madebyollin/dino-diffusion

In [None]:
import random
import os
from collections import namedtuple
from pathlib import Path
from functools import lru_cache

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from dit_mamba_trainloader import Config
import numpy as np

In [None]:
# Override the shared settings on the Config object imported from
# dit_mamba_trainloader: 64x64 images with 3 channels.
# NOTE(review): shape repeats the image size literally; keep the two in sync.
Config.image_size = 64
Config.shape = (3, 64, 64)

# Load the Dataset

In [None]:
# IMPORTANT: make sure to download the dataset in install.ipynb
# https://github.com/cyizhuo/Stanford-Cars-dataset

# One training example: the clean image, its noised version, and the noise
# level (blend factor) that was used to produce the noised version.
Sample = namedtuple("Sample", ("im", "noisy_im", "noise_level"))

def alpha_blend(a, b, alpha):
    """Linearly interpolate between ``b`` and ``a``.

    Returns ``alpha * a + (1 - alpha) * b``: ``alpha == 1`` yields ``a`` and
    ``alpha == 0`` yields ``b``. Works for anything that supports ``*``,
    ``+`` and ``1 - alpha`` (plain numbers as well as tensors).
    """
    share_of_a = alpha * a
    share_of_b = (1 - alpha) * b
    return share_of_a + share_of_b

class CustomStanfordCarsDataset(Dataset):
    """Folder-per-class image dataset that yields (clean, noisy, noise level) samples.

    ``img_dir`` is expected to contain one sub-directory per car model; every
    file inside such a sub-directory is treated as an image of that class.

    Args:
        img_dir: Root directory holding the per-class sub-directories.
        transform: Optional callable applied to the loaded PIL image. It must
            return a tensor, because ``__getitem__`` calls ``torch.rand_like``
            on its result.

    Attributes:
        car_names: Sorted names of the class sub-directories.
        image_paths: Path of every image, grouped by class, sorted by name.
        labels: Index into ``car_names`` for each entry of ``image_paths``.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        # Only real sub-directories are classes. Stray files (README, .DS_Store)
        # and hidden folders (.ipynb_checkpoints) would otherwise crash
        # os.listdir below or show up as bogus classes.
        self.car_names = sorted(
            name for name in os.listdir(img_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(img_dir, name))
        )

        for idx, car_name in enumerate(self.car_names):
            car_dir = os.path.join(img_dir, car_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> image mapping reproducible across runs and machines.
            for img_name in sorted(os.listdir(car_dir)):
                img_path = os.path.join(car_dir, img_name)
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(idx)  # Use the index of the car name as the label

    def __len__(self):
        """Return the number of images found under ``img_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it with a freshly noised copy.

        Returns:
            Sample(im, noisy_im, noise_level) where ``noise_level`` is a
            uniform random tensor of shape (1, 1, 1) and
            ``noisy_im = noise_level * noise + (1 - noise_level) * im`` with
            uniform random ``noise``. The class label is not part of the
            sample.
        """
        image = Image.open(self.image_paths[idx]).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # A new noise draw on every access, so the same index gives a
        # different noisy image each time.
        noise = torch.rand_like(image)
        noise_level = torch.rand(1, 1, 1)
        noisy_im = alpha_blend(noise, image, noise_level)
        return Sample(image, noisy_im, noise_level)

def load_transformed_dataset():
    """Return the Stanford Cars train and test splits as one concatenated dataset.

    Every image is resized to Config.image_size x Config.image_size, randomly
    flipped horizontally, and converted to a tensor. Pixel values stay in
    [0, 1]; no rescaling to [-1, 1] is applied.
    """
    pipeline = transforms.Compose([
        transforms.Resize((Config.image_size, Config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # scales pixel values into [0, 1]
    ])
    split_dirs = ('Stanford-Cars-dataset/train', 'Stanford-Cars-dataset/test')
    splits = [
        CustomStanfordCarsDataset(split_dir, transform=pipeline)
        for split_dir in split_dirs
    ]
    return torch.utils.data.ConcatDataset(splits)
def show_tensor_image(image):
    """Display a CHW image tensor, or a batch/iterable of them, inline.

    Args:
        image: Either a single (C, H, W) tensor with values in [0, 1], or a
            4-D batch tensor / any iterable of (C, H, W) tensors. Multiple
            images are laid out side by side along the width axis.

    Values outside [0, 1] are clamped before the uint8 conversion so that
    slightly out-of-range pixels saturate instead of wrapping around. The
    tensor is detached and moved to the CPU first, so CUDA tensors and
    tensors that track gradients can be shown as well.
    """
    # Batches and iterables are concatenated horizontally into one wide image.
    if not isinstance(image, torch.Tensor) or image.ndim == 4:
        image = torch.cat(tuple(image), -1)

    hwc = image.detach().cpu().permute(1, 2, 0)  # CHW to HWC
    pixels = (hwc.clamp(0, 1) * 255.).numpy().astype(np.uint8)
    # `display` is the notebook's built-in rich display function.
    display(transforms.ToPILImage()(pixels))

# Build the dataset once; it is reused for the preview, training, sampling
# and FID cells below.
d_train = load_transformed_dataset()

# View the Dataset

In [None]:
def demo_dataset(dataset, n=16):
    """Print the dataset size and show ``n`` randomly chosen samples.

    Three image strips are displayed: the noisy inputs, a strip visualising
    each sample's noise level as a uniform grey block, and the clean target
    images. Samples are drawn with replacement.
    """
    print(f"Dataset has {len(dataset)} samples (not counting augmentation).")
    print("Here are some samples from the dataset:")
    picked = random.choices(dataset, k=n)

    noisy_inputs = [sample.noisy_im for sample in picked]
    # Broadcast the (1, 1, 1) noise level to a 3 x 16 x image_size block so it
    # lines up with the image columns when shown.
    level_blocks = [
        sample.noise_level.expand(3, 16, Config.image_size) for sample in picked
    ]
    clean_targets = [sample.im for sample in picked]

    print("Inputs")

    show_tensor_image(noisy_inputs)
    show_tensor_image(level_blocks)
    print("Target Outputs")
    show_tensor_image(clean_targets)
# Preview a random handful of training samples.
demo_dataset(d_train)

# Initializing the Model

In [None]:
# For Mamba, make sure n_layers and embed_dim are both divisible by 8
from dataclasses import dataclass, asdict

@dataclass
class MambaConfig:
    """Hyperparameters for the Mamba denoiser.

    The fields are expanded with ``asdict`` and passed as keyword arguments
    to ``MambaDenoiser`` when the model is built, so the field names must
    match that constructor's parameter names.
    """

    # Same value as Config.image_size set earlier in this notebook.
    image_size: int = 64
    # presumably the side length of each square patch — confirm in MambaDenoiser
    patch_size: int = 2
    # presumably the per-token embedding width — confirm in MambaDenoiser
    embed_dim: int = 160
    dropout: float = 0
    n_layers: int = 8
    # 3 matches the channel count in Config.shape.
    n_channels: int = 3

# Use the default hyperparameters declared on the dataclass.
mamba_config = MambaConfig()

In [None]:
from mamba_denoiser import MambaDenoiser
from dit import DiT

# Instantiate both denoisers on the configured device. The Mamba model takes
# its hyperparameters from mamba_config; the DiT ones are given inline
# (hidden_size = 12 * 32 = 384 with 12 heads, input_size matching the 64x64 images).
mamba_model = MambaDenoiser(**asdict(mamba_config)).to(Config.device)
transformer_model = DiT(depth=8, hidden_size=12*32, patch_size=2, num_heads=12, input_size = 64).to(Config.device)


# Report trainable parameter counts (in millions) so the two models can be
# compared at a glance.
print(f"Mamba model has {sum(p.numel() for p in mamba_model.parameters() if p.requires_grad) / 1e6:.1f} million trainable parameters.")
print(f"Transformer model has {sum(p.numel() for p in transformer_model.parameters() if p.requires_grad) / 1e6:.1f} million trainable parameters.")

In [None]:
def weight_average(w_prev, w_new, n):
    """Blend the running average of a weight with its newest value.

    Intended as the ``avg_fn`` of ``torch.optim.swa_utils.AveragedModel``.
    The share kept from the previous average is ``n / 10``, capped at 0.9,
    so early updates follow the new weights closely and later ones behave
    like an exponential moving average with decay 0.9.

    Args:
        w_prev: Current averaged value.
        w_new: Latest value from the model being trained.
        n: Number of updates averaged so far.
    """
    ramp = n / 10
    alpha = ramp if ramp < 0.9 else 0.9
    return alpha_blend(w_prev, w_new, alpha)
    
# Wrap each model in an AveragedModel that tracks a running average of its
# weights using weight_average as the update rule. The averaged copies are
# the ones used for sampling and FID below.
avg_mamba_model = torch.optim.swa_utils.AveragedModel(mamba_model, avg_fn=weight_average)
avg_transformer_model = torch.optim.swa_utils.AveragedModel(transformer_model, avg_fn=weight_average)

# Train Models

In [None]:
from dit_mamba_trainloader import Trainer, generate_images

In [None]:
# Train the Mamba denoiser, then checkpoint both the averaged and the raw
# weights so they can be reloaded without retraining.
mamba_trainer = Trainer(mamba_model, avg_mamba_model, d_train, batch_size=32, learning_rate=3e-4)
# 6*60*60 = 21600; presumably a training-time budget in seconds (6 hours) —
# change it if necessary.
mamba_trainer.train(n_seconds=6*60*60) # change the training time if necessary 
torch.save(avg_mamba_model.state_dict(), 'avg_mamba_model.pth')
torch.save(mamba_model.state_dict(), 'mamba_model.pth')

In [None]:
# Train the DiT (Transformer) denoiser with the same batch size and time
# budget as the Mamba run, then checkpoint the averaged and raw weights.
# NOTE(review): no learning_rate is passed here, so the Trainer default is
# used, whereas the Mamba run sets 3e-4 explicitly — confirm this is intended.
transformer_trainer = Trainer(transformer_model, avg_transformer_model, d_train, batch_size=32)
transformer_trainer.train(n_seconds=6*60*60)
torch.save(avg_transformer_model.state_dict(), 'avg_transformer_model.pth')
torch.save(transformer_model.state_dict(), 'transformer_model.pth')

In [None]:
# Reload the saved checkpoints (lets the cells below run without retraining).
# map_location remaps the stored tensors onto Config.device, so checkpoints
# written on a GPU machine can also be loaded where that GPU is unavailable.
avg_mamba_model.load_state_dict(torch.load('avg_mamba_model.pth', map_location=Config.device))
mamba_model.load_state_dict(torch.load('mamba_model.pth', map_location=Config.device))
avg_transformer_model.load_state_dict(torch.load('avg_transformer_model.pth', map_location=Config.device))
transformer_model.load_state_dict(torch.load('transformer_model.pth', map_location=Config.device))

# Generate the images

In [None]:
def demo_sample_grids(dataset, model, rows=6, cols=6, n_steps=100, step_size=2):
    """Return a PIL image with a grid of real images next to a grid of generated ones.

    Each grid has ``rows`` x ``cols`` tiles; the real grid is on the left,
    the generated grid on the right, separated by a 32-pixel-wide white bar.

    Args:
        dataset: Source of real samples; random items are drawn and their
            ``.im`` field is used.
        model: Denoiser handed to ``generate_images``.
        rows, cols: Grid dimensions.
        n_steps, step_size: Forwarded to ``generate_images``.
    """
    torch.manual_seed(16)  # change the seed if necessary
    real_strips = []
    fake_strips = []
    for _ in tqdm(range(rows)):
        # One grid row at a time: tiles are joined along the width axis.
        real_tiles = [random.choice(dataset).im for _ in range(cols)]
        real_strips.append(torch.cat(real_tiles, -1))
        generated = generate_images(model, n_images=cols, n_steps = n_steps, step_size = step_size)
        fake_strips.append(torch.cat(tuple(generated), -1))
    # Rows are stacked along the height axis to form each grid.
    real_grid = torch.cat(real_strips, -2)
    fake_grid = torch.cat(fake_strips, -2).cpu()
    spacer = torch.ones_like(real_grid[..., :32])
    return TF.to_pil_image(torch.cat([real_grid, spacer, fake_grid], -1))

In [None]:
# Generate images with the averaged Mamba model; the returned PIL image
# (real grid | generated grid) is shown as the cell output.
demo_sample_grids(d_train, avg_mamba_model, n_steps=100, step_size=3)

In [None]:
# Generate images with the averaged Transformer model, using the same
# sampling settings as the Mamba cell for a like-for-like comparison.
demo_sample_grids(d_train, avg_transformer_model, n_steps=100, step_size=3)

# FID score

In [None]:
from fid import fid_score

def calculate_FID(dataset, model, n_steps=300, step_size=1, n_samples=100, batch_size=100):
    """Compute the FID between real dataset images and images generated by ``model``.

    The first ``n_samples`` items of ``dataset`` serve as the real set, and
    exactly ``n_samples`` images are generated for the fake set.

    Args:
        dataset: Indexable dataset whose items expose an ``.im`` image tensor.
        model: Denoiser handed to ``generate_images``.
        n_steps, step_size: Forwarded to ``generate_images``.
        n_samples: Number of real and of generated images to compare.
        batch_size: Maximum number of images generated per call.

    Returns:
        Whatever ``fid_score`` returns for the two image lists.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    real_list = [dataset[i].im for i in range(n_samples)]

    # Full batches plus a final partial one, so the number of generated
    # images always equals n_samples, including when n_samples is smaller
    # than, or not a multiple of, batch_size.
    batch_counts = [
        min(batch_size, n_samples - start)
        for start in range(0, n_samples, batch_size)
    ]
    fake_list = []
    for n_images in tqdm(batch_counts):
        batch = generate_images(model, n_images=n_images, n_steps=n_steps, step_size=step_size)
        fake_list.extend(batch)

    return fid_score(real_list, fake_list, device=Config.device)


In [None]:
# Initial value; overwritten by the FID computation below.
mamba_fid = 0
# Gradients are not needed for sampling/evaluation, so tracking is disabled.
with torch.no_grad():
    # Same sampling settings (n_steps=100, step_size=3) as the image grids above.
    mamba_fid = calculate_FID(d_train, avg_mamba_model, n_steps=100, step_size=3, n_samples=16000)
print("FID score for mamba model: ", mamba_fid)

In [None]:
# Initial value; overwritten by the FID computation below.
transformer_fid = 0
# Gradients are not needed for sampling/evaluation, so tracking is disabled.
with torch.no_grad():
    # Same settings and sample count as the Mamba FID so the scores are comparable.
    transformer_fid = calculate_FID(d_train, avg_transformer_model, n_steps=100, step_size=3, n_samples=16000)
print("FID score for transformer model: ", transformer_fid)