In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset, Dataset, Subset
import torchvision
from torchvision import transforms, datasets
import copy
import math
import random
from torch.cuda import amp
from torch.utils.data import random_split
torch.manual_seed(0)

In [None]:
# Hyperparameters

# -- Diffusion noise schedule: `steps` betas spaced linearly from beta_start to beta_end --
steps = 1000
beta_start, beta_end = 1e-4, 0.02

# -- Data --
image_size = 64
image_channel = 3
num_class = 200
batch_size = 8

# -- Model: width of the sinusoidal timestep embedding --
pos_dim = 1024

# -- Optimisation --
epochs = 200
lr = 1e-3
weight_decay = 1e-3
gradient_accumulation_step = 1  # optimizer steps once every this many batches

# -- Hardware: prefer the GPU when one is visible --
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# This is the utils file
def zero_out(layer):
    """Set every parameter of ``layer`` to zero in place and return the same module.

    The parameters keep their ``requires_grad`` flag; only their values are
    overwritten, and the write is hidden from autograd.
    """
    with torch.no_grad():
        for param in layer.parameters():
            param.zero_()
    return layer

def positional_embedding_creator(num_step: int, pos_dim: int):
    """Build a sinusoidal embedding table with one row per step.

    For every even column ``j`` the entry is ``sin(i / 10000**(j / pos_dim))``
    and the following odd column ``j + 1`` holds the cosine of the same angle.
    When ``pos_dim`` is odd the last column is a sine with no cosine partner.

    Args:
        num_step: number of rows (step indices ``0 .. num_step - 1``).
        pos_dim: number of columns of the embedding.

    Returns:
        Tensor of shape ``(num_step, pos_dim)`` in the default float dtype.
    """
    # Angles are computed in float64 (the precision the scalar math used
    # before) and only rounded when written into the output table.
    step_index = torch.arange(num_step, dtype=torch.float64).unsqueeze(1)
    even_columns = torch.arange(0, pos_dim, 2, dtype=torch.float64)
    angles = step_index / (10000.0 ** (even_columns / pos_dim))

    matrix = torch.zeros(num_step, pos_dim)
    matrix[:, 0::2] = torch.sin(angles).to(matrix.dtype)
    # For odd pos_dim there is one fewer odd column than even columns.
    matrix[:, 1::2] = torch.cos(angles[:, : pos_dim // 2]).to(matrix.dtype)

    return matrix

In [None]:
# Diffusion model

# AdaGN according to paper "Diffusion Models Beat GANs on Image Synthesis"
class AdaNorm(nn.Module):
    """Group normalisation followed by an embedding-driven scale and shift.

    The normalisation itself has no learnable affine parameters; the
    modulation comes entirely from ``emb``, which is split into two equal
    halves along dim 1 (scale first, shift second).
    """

    def __init__(self, num_channel: int):
        super().__init__()
        # 16 channels per group (the setting the group-norm paper reports as best).
        groups = int(num_channel / 16)
        self.gnorm = nn.GroupNorm(groups, num_channel, affine=False)

    def forward(self, tensor: torch.Tensor, emb: torch.Tensor):
        """Return ``gnorm(tensor) * (1 + scale) + shift``.

        ``emb`` must broadcast against ``tensor`` after being halved on dim 1.
        """
        scale, shift = emb.chunk(2, dim=1)
        normalised = self.gnorm(tensor)
        return normalised * (1 + scale) + shift


class MyGroupNorm(nn.Module):
    """Parameter-free group normalisation with 16 channels per group."""

    def __init__(self, num_channel: int):
        super().__init__()
        # 16 channels per group (the setting the group-norm paper reports as best).
        groups = int(num_channel / 16)
        self.gnorm = nn.GroupNorm(groups, num_channel, affine=False)

    def forward(self, tensor: torch.Tensor):
        """Normalise ``tensor`` group-wise over its channel dimension."""
        normalised = self.gnorm(tensor)
        return normalised


class ResBlock(nn.Module):
    """Residual block conditioned on an embedding, with optional 2x resampling.

    Main path:  GroupNorm -> SiLU -> resample -> conv1 -> AdaNorm(emb) -> SiLU -> conv2
    Skip path:  resample -> conv3 (1x1 conv when channel counts differ, else identity)

    ``conv2`` is passed through ``zero_out`` at construction, so the main path
    starts out contributing zeros to the sum.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels of the produced feature map.
        emb_dim: width of the conditioning embedding fed to ``forward``.
        up: upsample by 2 (nearest neighbour). Takes precedence over ``down``.
        down: downsample by 2 (average pooling).
    """

    def __init__(self, in_channel: int, out_channel: int, emb_dim: int = 1024, up: bool = False, down: bool = False):
        super().__init__()
        # NOTE(review): `self.emb` is registered but never called in forward();
        # the conditioning actually goes through `self.embed` below. It is kept
        # so that parameter initialisation order and state_dict keys stay the same.
        self.emb = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, 2*out_channel))

        # Spatial resampling shared by the main path and the skip path.
        if up:
            resample = nn.Upsample(scale_factor=2, mode='nearest')
        elif down:
            resample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            resample = nn.Identity()
        self.change_size = resample

        # Normalisation layers
        self.gnorm1 = MyGroupNorm(in_channel)
        self.gnorm2 = AdaNorm(out_channel)

        # Convolutions; conv3 only exists to match channel counts on the skip path
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.conv3 = (
            nn.Conv2d(in_channel, out_channel, kernel_size=1)
            if in_channel != out_channel
            else nn.Identity()
        )

        # Stage pipelines assembled from the layers above
        self.input = nn.Sequential(self.gnorm1, nn.SiLU(), self.change_size, self.conv1)
        self.output = nn.Sequential(nn.SiLU(), zero_out(self.conv2))
        self.skip_connection = nn.Sequential(self.change_size, self.conv3)

        # Maps the conditioning embedding to 2 * out_channel values
        # (consumed by AdaNorm as a scale half and a shift half).
        self.embed = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, 2 * out_channel),
        )

    def forward(self, tensor: torch.Tensor, emb: torch.Tensor):
        """Apply the block to ``tensor`` conditioned on ``emb`` and return skip + residual."""
        batch = tensor.shape[0]
        # Reshape to (batch, 2*out_channel, 1, 1) so it broadcasts over the spatial dims.
        scale_shift = self.embed(emb).view(batch, -1, 1, 1)

        residual = self.output(self.gnorm2(self.input(tensor), scale_shift))
        shortcut = self.skip_connection(tensor)

        return shortcut + residual


class SelfAttention(nn.Module):
    """Multi-head dot-product attention over a packed query/key/value tensor.

    The input's channel dimension is split into three equal chunks (q, k, v),
    each of which is then divided among ``num_head`` heads.
    """

    def __init__(self, channel: int):
        super().__init__()
        # One attention head per 32 channels.
        self.num_head = int(channel/32)

    def forward(self, tensor: torch.Tensor):
        """Attend over the last dimension of a ``(batch, 3 * C, length)`` tensor.

        Returns a tensor of shape ``(batch, C, length)``.
        """
        batch, packed_channel, length = tensor.shape
        heads = self.num_head
        head_dim = packed_channel // 3 // heads
        query, key, value = tensor.chunk(3, dim=1)

        # Scaling q and k by head_dim**-0.25 each is equivalent to dividing the
        # logits by sqrt(head_dim), but is more stable in fp16 than dividing
        # afterwards (approach taken from the "Diffusion Models Beat GANs on
        # Image Synthesis" reference code).
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        query = (query * scale).view(batch * heads, head_dim, length)
        key = (key * scale).view(batch * heads, head_dim, length)
        value = value.reshape(batch * heads, head_dim, length)

        logits = torch.einsum("bct,bcs->bts", query, key)
        # Softmax in float32, then cast back to the working dtype.
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = torch.einsum("bts,bcs->bct", attn, value)
        return mixed.reshape(batch, -1, length)


class Attention(nn.Module):
    """Residual self-attention layer for 4-D feature maps.

    The two trailing dimensions are flattened into one sequence axis, the
    sequence is normalised, projected to packed q/k/v, attended, projected back
    and added to the (flattened) input. The output projection is passed through
    ``zero_out`` at construction.
    """

    def __init__(self, channel: int):
        super().__init__()
        self.gnorm = MyGroupNorm(channel)
        self.qkv = nn.Conv1d(channel, channel * 3, 1)  # 1x1 conv producing packed q, k, v
        self.attention = SelfAttention(channel)
        self.output = zero_out(nn.Conv1d(channel, channel, 1))

    def forward(self, tensor: torch.Tensor):
        """Return a tensor with the same 4-D shape as ``tensor``."""
        batch, channel, width, height = tensor.shape

        # Flatten the two trailing dims into one sequence axis; this flattened
        # input is also what the residual connection adds back.
        flat = tensor.reshape(batch, channel, -1)

        attended = self.output(self.attention(self.qkv(self.gnorm(flat))))

        return (attended + flat).reshape(batch, channel, width, height)


class UNet(nn.Module):
    """U-Net noise predictor conditioned on a timestep embedding and an optional class label.

    The encoder is a stem convolution followed, for each channel width in
    ``[160, 320, 640, 1280]``, by ``depth`` residual blocks (with attention at
    widths 320/640/1280) and a downsampling residual block between widths.
    Every encoder layer's output is stored and later concatenated into the
    matching decoder layer. The decoder mirrors the encoder with ``depth + 1``
    blocks per width and an upsampling block between widths.

    Args:
        image_channel: channels of the input and of the predicted output.
        depth: residual blocks per width in the encoder.
        emb_dim: width of the timestep/class embedding. It is used for every
            residual block, including the bottleneck ones.
        num_step: accepted for backward compatibility; not used by this class.
        num_classes: size of the class-embedding table.
    """

    def __init__(self, image_channel: int = 3, depth: int = 2, emb_dim: int = 1024, num_step = 1000, num_classes = 10):
        super().__init__()

        # Create model architecture
        channels = [160, 320, 640, 1280]
        attention_channel = [320, 640, 1280]
        self.encoder = nn.ModuleList([nn.ModuleList([nn.Conv2d(image_channel, channels[0], 3, padding=1)])])
        self.decoder = nn.ModuleList()

        # Channel count of each stored encoder output, consumed in reverse by the decoder.
        skip_channel = [channels[0]]

        # Encoder
        for i in range(len(channels)):
            for _ in range(depth):
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i], channels[i], emb_dim = emb_dim))
                if channels[i] in attention_channel:
                    layer.append(Attention(channels[i]))
                self.encoder.append(layer)
                skip_channel.append(channels[i])

            if i != len(channels)-1:
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i], channels[i + 1], down=True, emb_dim = emb_dim))
                self.encoder.append(layer)
                skip_channel.append(channels[i+1])

        # Bottleneck. emb_dim must be forwarded here as well: without it these
        # blocks fall back to ResBlock's default of 1024 and forward() fails
        # for any other embedding width.
        self.bottle_neck = nn.ModuleList([
            ResBlock(channels[-1], channels[-1], emb_dim = emb_dim),
            Attention(channels[-1]),
            ResBlock(channels[-1], channels[-1], emb_dim = emb_dim),
        ])

        # Decoder
        for i in range(len(channels)-1, -1, -1):
            for block in range(depth+1):
                layer = nn.ModuleList()
                # Input is the running tensor concatenated with one stored encoder output.
                layer.append(ResBlock(channels[i] + skip_channel.pop(), channels[i], emb_dim = emb_dim))
                if channels[i] in attention_channel:
                    layer.append(Attention(channels[i]))

                # Last block of every width except the first one also upsamples.
                if i != 0 and block == depth:
                    layer.append(ResBlock(channels[i], channels[i - 1], up=True, emb_dim = emb_dim))

                self.decoder.append(layer)

        # Create time embedding
        self.time_embedding = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim)
        )

        # Create class embedding
        self.class_embedding = nn.Embedding(num_classes, emb_dim)

        # Output kernels to change back to image channel
        self.out = nn.Sequential(
            MyGroupNorm(channels[0]),
            nn.SiLU(),
            zero_out(nn.Conv2d(channels[0], image_channel, 3, padding=1)),
        )

    @staticmethod
    def _run_modules(modules, tensor: torch.Tensor, embedding: torch.Tensor):
        """Apply ``modules`` in order; only ResBlocks receive the embedding."""
        for module in modules:
            if isinstance(module, ResBlock):
                tensor = module(tensor, embedding)
            else:
                tensor = module(tensor)
        return tensor

    def forward(self, tensor: torch.Tensor, time_embedding: torch.Tensor, label: torch.Tensor | None):
        """Predict a tensor with ``image_channel`` channels from ``tensor``.

        Args:
            tensor: input feature map (4-D; the attention layers unpack four dims).
            time_embedding: per-sample timestep embedding of width ``emb_dim``.
            label: class indices, or ``None`` for the unconditional prediction.
        """
        # Creating embedding
        embedding = self.time_embedding(time_embedding)
        if label is not None:
            embedding = embedding + self.class_embedding(label)

        skip_connection = []

        # Encoder: remember every layer's output for the decoder
        for layer in self.encoder:
            tensor = self._run_modules(layer, tensor, embedding)
            skip_connection.append(tensor)

        # Bottleneck
        tensor = self._run_modules(self.bottle_neck, tensor, embedding)

        # Decoder: consume the stored outputs in reverse order
        for layer in self.decoder:
            tensor = torch.cat((tensor, skip_connection.pop()), dim = 1)
            tensor = self._run_modules(layer, tensor, embedding)

        return self.out(tensor)

class EMA:
    """Exponential moving average of a model's parameters.

    ``beta`` is the weight given to the running average; ``step`` counts how
    many times ``step_ema`` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend each parameter of ``current_model`` into ``ma_model`` (paired by position)."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for source_param, averaged_param in pairs:
            averaged_param.data = self.update_average(averaged_param.data, source_param.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` when there is no old value."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one step.

        For the first ``step_start_ema`` calls the EMA model is simply
        overwritten with the current weights; afterwards it is averaged.
        """
        warming_up = self.step < step_start_ema
        if warming_up:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy the full state dict of ``model`` into ``ema_model``."""
        ema_model.load_state_dict(model.state_dict())

In [None]:
# Constants used for diffusion model
# Base 1-D schedules, one entry per timestep.
beta = torch.linspace(beta_start, beta_end, steps).to(device)
alpha = 1 - beta
alphas_cumprod = torch.cumprod(alpha, axis=0)
one_minus_alphas_cumprod = 1 - alphas_cumprod

# Derived tables reshaped to (steps, 1, 1, 1) so that indexing them with a
# batch of timesteps broadcasts against 4-D image batches.
sqrt_beta = beta.sqrt().view(-1, 1, 1, 1)
sqrt_alphas_cumprod = alphas_cumprod.sqrt().view(-1, 1, 1, 1)
sqrt_one_minus_alphas_cumprod = one_minus_alphas_cumprod.sqrt().view(-1, 1, 1, 1)
one_over_sqrt_alpha = (1 / alpha.sqrt()).view(-1, 1, 1, 1)
one_minus_alpha = (1 - alpha).view(-1, 1, 1, 1)

In [None]:
# Forward pass
def forward_pass(images, t):
    """Noise ``images`` to the diffusion timesteps ``t``.

    Computes ``sqrt_alphas_cumprod[t] * images + sqrt_one_minus_alphas_cumprod[t] * noise``
    with freshly drawn standard-normal ``noise``.

    Args:
        images: batch of images.
        t: one timestep index per image, used to index the schedule tables.

    Returns:
        ``(noisy_images, noise)`` where ``noise`` is the noise that was mixed in.
    """
    noise = torch.randn_like(images).to(device)
    signal_scale = sqrt_alphas_cumprod[t]
    noise_scale = sqrt_one_minus_alphas_cumprod[t]

    noisy_images = signal_scale * images + noise_scale * noise
    return noisy_images, noise

In [None]:
# Positional embedding
# Precompute the sinusoidal table once (one row per diffusion timestep,
# `pos_dim` columns) and keep it on `device`; training and sampling look rows
# up by timestep index instead of recomputing them.
pos_emb_matrix = positional_embedding_creator(steps, pos_dim).to(device)

In [None]:
# Sampling(inference)

def sampling(model, labels, cfg_scale: int = 3):
    """Generate one image per entry of ``labels`` and display each with matplotlib.

    Starts from Gaussian noise and walks the timesteps from ``steps - 1`` down
    to 0. At every step the model is evaluated twice (without and with the
    labels) and the two predictions are combined with
    ``torch.lerp(uncond, cond, cfg_scale)`` (classifier-free guidance).

    The model is switched to eval mode for the duration of the call and is put
    back into whatever mode it was in before, even if sampling raises.

    Args:
        model: noise predictor called as ``model(x, pos_emb, labels_or_None)``.
        labels: 1-D tensor of class indices; its length sets the batch size.
        cfg_scale: guidance weight passed to ``torch.lerp``.

    Returns:
        None. The images are shown via ``plt.imshow`` / ``plt.show``.
    """
    num_samples = labels.shape[0]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.randn(num_samples, image_channel, image_size, image_size).to(device)

            for i in range(steps-1, -1, -1):
                # Build the timestep index directly on the device so the table
                # look-ups below do not need a host-to-device copy every step.
                t = torch.full((num_samples,), i, dtype=torch.long, device=device)
                pos_emb = pos_emb_matrix[t]

                # Classifier free guidance
                predicted_noise_no_label = model(x, pos_emb, None)
                predicted_noise_with_label = model(x, pos_emb, labels)
                predicted_noise = torch.lerp(predicted_noise_no_label, predicted_noise_with_label, cfg_scale)

                # No fresh noise is injected on the final step.
                noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)

                x = one_over_sqrt_alpha[t] * (x - ((one_minus_alpha[t])/(sqrt_one_minus_alphas_cumprod[t]))*predicted_noise) + sqrt_beta[t] * noise
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    # Map from [-1, 1] to uint8 [0, 255] and move channels last for imshow.
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    images = x.permute(0, 2, 3, 1).to("cpu")

    for i in range(images.shape[0]):
        plt.imshow(images[i].numpy())
        plt.show()

In [None]:
# Define the transformation: resize, convert to a tensor, then normalise each
# channel with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Load the training set (one sub-folder per class)
trainset = datasets.ImageFolder(root='./tiny-imagenet-200/train', transform=transform)

# Split the dataset 80 / 20 into training and validation subsets
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
train_dataset, val_dataset = random_split(trainset, [train_size, val_size])

# Create the DataLoaders for both subsets
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Instantiate the model on the target device and report its size
unet = UNet(num_classes=num_class).to(device)
print("This model has", sum(map(torch.numel, unet.parameters())), "parameters.")

# Loss scaler for mixed-precision training
scaler = amp.GradScaler()

# Per-epoch mean losses, appended to by the training loop
loss_train, loss_valid = [], []

In [None]:
# Set up optimizer and loss
criterion = nn.MSELoss()
optimizer = opt.AdamW(unet.parameters(), lr=lr, weight_decay=weight_decay)
# Cosine decay from `lr` down to 1e-5 over `epochs` scheduler steps
scheduler = opt.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

ema = EMA(0.9999) # 0.9999 according to the diffusion model beat GANs paper.
# The EMA copy is never trained directly: eval mode, gradients disabled.
ema_model = copy.deepcopy(unet)
ema_model.eval()
ema_model.requires_grad_(False)

In [None]:
# Training
# Each epoch: one pass over the training set (mixed precision, optional
# gradient accumulation, gradient clipping, EMA update), one gradient-free pass
# over the validation set, a scheduler step, and a batch of preview samples
# drawn from the EMA model.
for epoch in range(epochs):
    train_loss_list = []
    valid_loss_list = []

    unet.train()
    for batch_idx, (images, label) in enumerate(tqdm(train_dataloader)):
        # Preparing for forward pass
        images = images.to(device)
        label = label.to(device)
        # Draw from the full range [0, steps): sampling() evaluates the model at
        # every timestep down to 0, so timestep 0 must be trained as well.
        t = torch.randint(0, steps, size = (images.shape[0], ), device = device)
        pos_emb = pos_emb_matrix[t]
        x_t, noise = forward_pass(images, t)

        # Classifier free guidance: drop the labels for ~10% of the batches so
        # the model also learns the unconditional prediction.
        if random.random() < 0.1:
            label = None

        # Only the forward pass and the loss run under autocast.
        with amp.autocast():
            predicted_noise = unet(x_t, pos_emb, label)
            loss = criterion(predicted_noise, noise)

        # Log the true batch loss, before it is divided for accumulation.
        train_loss_list.append(loss.item())

        # Back propagation (scaled to avoid fp16 underflow)
        scaler.scale(loss / gradient_accumulation_step).backward()

        if (batch_idx + 1) % gradient_accumulation_step == 0 or (batch_idx + 1) == len(train_dataloader):
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so max_norm applies to the real gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            # Zero out grad
            optimizer.zero_grad()

            # EMA update
            ema.step_ema(ema_model, unet)

    # Validation: no parameter updates, so no autograd graph is needed.
    unet.eval()
    with torch.no_grad():
        for images, label in tqdm(valid_dataloader):
            # Preparing for forward pass
            images = images.to(device)
            label = label.to(device)
            t = torch.randint(0, steps, size = (images.shape[0], ), device = device)
            pos_emb = pos_emb_matrix[t]
            x_t, noise = forward_pass(images, t)

            # Forward pass
            with amp.autocast():
                predicted_noise = unet(x_t, pos_emb, label)
                loss = criterion(predicted_noise, noise)
            valid_loss_list.append(loss.item())
    unet.train()

    # Step the learning rate
    scheduler.step()

    print(f"Epoch #{epoch}")
    print(f"Current learning rate is {optimizer.param_groups[0]['lr']}")
    print("Train Loss is:", sum(train_loss_list)/len(train_loss_list))
    loss_train.append(sum(train_loss_list)/len(train_loss_list))
    print("Valid Loss is:", sum(valid_loss_list)/len(valid_loss_list))
    loss_valid.append(sum(valid_loss_list)/len(valid_loss_list))

    # Preview: six samples of class 26 from the EMA weights
    label = torch.tensor([26, 26, 26, 26, 26, 26]).to(device)
    sampling(ema_model, label)

# Testing

In [None]:
# Draw six samples of class 0 from the EMA model with guidance scale 4
label = torch.tensor([0] * 6).to(device)
sampling(ema_model, label, cfg_scale=4)

In [None]:
# Save the EMA model. This serialises the whole module object, not just its
# state_dict.
# NOTE(review): the file name says CIFAR10, but the data cell loads
# ./tiny-imagenet-200/train — confirm the intended name before relying on it.
torch.save(ema_model, "diffusion_CIFAR10.pth")

In [None]:
# Save the raw (non-EMA) training model. This serialises the whole module
# object, not just its state_dict.
# NOTE(review): the file name says CIFAR10, but the data cell loads
# ./tiny-imagenet-200/train — confirm the intended name before relying on it.
torch.save(unet, "diffusion_CIFAR10_student.pth")