In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.cuda import amp
from tqdm import tqdm
import random
import numpy as np
import math
import os
from torch.utils.data import DataLoader

torch.manual_seed(0)

<torch._C.Generator at 0x79a3e45780f0>

In [12]:
# Hyperparameters
beta_start = 1e-4
beta_end = 0.02
steps = 1000
device_id = 0
image_size = 64
image_channel = 3
epochs = 300
lr = 3e-4
weight_decay = 0
batch_size = 1
num_class = 1000
pos_dim = 1024
dataset_filepath = "/media/danjie_tang/Danjie HDD/Imagenet/ILSVRC/Data/CLS-LOC/train"

In [3]:
# Constants used for diffusion model
beta = torch.linspace(beta_start, beta_end, steps).cuda(device_id)
sqrt_beta = torch.sqrt(beta).view(-1, 1, 1, 1)
alpha = 1 - beta
alphas_cumprod = torch.cumprod(alpha, axis=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).view(-1, 1, 1, 1)
one_minus_alphas_cumprod = 1 - alphas_cumprod
sqrt_one_minus_alphas_cumprod = torch.sqrt(one_minus_alphas_cumprod).view(-1, 1, 1, 1)
one_over_sqrt_alpha = 1/torch.sqrt(alpha).view(-1, 1, 1, 1)
one_minus_alpha = (1 - alpha).view(-1, 1, 1, 1)

In [4]:
# util.py

# Forward pass
def forward_pass(images, t):
    batch_sqrt_alphas_cumprod = sqrt_alphas_cumprod[t]
    batch_sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[t]
    noise = torch.randn_like(images).cuda(device_id)

    return batch_sqrt_alphas_cumprod * images + batch_sqrt_one_minus_alphas_cumprod * noise, noise

# Sampling
def sampling(model, labels, cfg_scale: int = 3):
    model.eval()
    with torch.no_grad():
        x = torch.randn(labels.shape[0], image_channel, image_size, image_size).cuda(device_id)

        for i in tqdm(range(steps-1, -1, -1)):
            t = torch.tensor([i]*labels.shape[0]).cuda(device_id)

            # Classifier free guidance
            predicted_noise_no_label = model(x, t, None)
            predicted_noise_with_label = model(x, t, labels)
            predicted_noise = torch.lerp(predicted_noise_no_label, predicted_noise_with_label, cfg_scale)

            if(i == 0):
                noise = torch.zeros_like(x).cuda(device_id)
            else:
                noise = torch.randn_like(x).cuda(device_id)

            x = one_over_sqrt_alpha[t] * (x - ((one_minus_alpha[t])/(sqrt_one_minus_alphas_cumprod[t]))*predicted_noise) + sqrt_beta[t] * noise

    model.train()

    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)

    for i in range(x.shape[0]):
        tensor = x[i].permute(1, 2, 0).to("cpu")
        plt.imshow(tensor)
        plt.show()

def positional_embedding(num_step: int, emb_dim: int) -> torch.Tensor:
    """
    Create positional embedding tensor.

    :param num_step: Number of time steps.
    :param emb_dim: Embedding dimension.
    :return: Positional embedding tensor.
    """
    matrix = torch.zeros(num_step, emb_dim)
    for i in range(num_step):
        for j in range(0, emb_dim, 2):
            matrix[i, j] = np.sin(i/(10000**(j/emb_dim)))
            if(j+1<emb_dim):
                matrix[i, j+1] = np.cos(i/(10000**(j/emb_dim)))

    return matrix

def sinusoidal_positional_encoding_2d(height: int, width: int, channel: int) -> torch.Tensor:
    """
    Generate a 2D sinusoidal positional encoding.

    :param height: The height of the encoding.
    :param width: The width of the encoding.
    :param channel: The number of channels in the encoding.
    :return: A tensor of shape (height, width, channel) containing the 2D positional encoding.
    """
    if channel % 2 != 0:
        raise ValueError("The 'channel' dimension must be an even number.")

    # First, build in (height, width, channel) format
    pe = torch.zeros(height, width, channel)

    half_ch = channel // 2

    # Precompute the exponent for row and column
    row_div_term = torch.exp(
        -math.log(10000.0) * (torch.arange(0, half_ch, 2).float() / half_ch)
    )
    col_div_term = torch.exp(
        -math.log(10000.0) * (torch.arange(0, half_ch, 2).float() / half_ch)
    )

    for h in range(height):
        for w in range(width):
            # Encode row index (h) into the first half of the channels
            for i in range(0, half_ch, 2):
                pe[h, w, i]     = math.sin(h * row_div_term[i // 2])
                pe[h, w, i + 1] = math.cos(h * row_div_term[i // 2])

            # Encode column index (w) into the second half of the channels
            for j in range(0, half_ch, 2):
                pe[h, w, half_ch + j]     = math.sin(w * col_div_term[j // 2])
                pe[h, w, half_ch + j + 1] = math.cos(w * col_div_term[j // 2])

    # Permute to get the shape (channel, width, height).
    # Currently pe is (height, width, channel) = (H, W, C)
    # We want (C, W, H), so we do permute(2, 1, 0).
    pe = pe.permute(2, 1, 0)  # => (channel, width, height)

    return pe

def zero_out(layer):
    for p in layer.parameters():
        p.detach().zero_()
    return layer

In [5]:
# model.py

class AdaNorm(nn.Module):
    def __init__(self, num_channel: int, channel_per_group: int = 16, emb_dim: int = 1024):
        super().__init__()
        assert num_channel % channel_per_group == 0, "num_channel must be divisible by channel_per_group"
        num_group = num_channel // channel_per_group
        self.embedding_proj = nn.Sequential(
            nn.Linear(emb_dim, 2 * num_channel),
            nn.ReLU()
        )
        self.gnorm = nn.GroupNorm(num_group, num_channel, affine=False)

    def forward(self, tensor: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        """
        Perform adanormalization on input tensor.

        :param tensor: Input tensor to be normalized.
        :param embedding: Embedding tensor containing time embedding and potentially class embedding.
        :return: Normalized tensor.
        """
        embedding = self.embedding_proj(embedding)
        embedding = embedding.view(embedding.shape[0], embedding.shape[1], 1, 1)
        scale, shift = torch.chunk(embedding, 2, dim=1)

        tensor = self.gnorm(tensor)
        # tensor = tensor * torch.sigmoid(scale) + shift
        tensor = tensor * (1 + scale) + shift
        return tensor

class ResBlock(nn.Module):
    def __init__(self, in_channel: int, out_channel: int, emb_dim: int = 1024, up: bool = False, down: bool = False, channel_per_group: int = 16):
        super().__init__()

        # Upsampling or downsampling only for skip connection
        self.up = up
        self.down = down

        # Normalization layers
        self.norm1 = AdaNorm(in_channel, emb_dim=emb_dim)
        self.norm2 = AdaNorm(out_channel, emb_dim=emb_dim)

        # Convolution layers
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.conv2 = zero_out(nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1))

        # Skip connection
        if in_channel != out_channel or up or down:
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1),
                nn.Upsample(scale_factor=2) if up else nn.Identity(),
                nn.AvgPool2d(kernel_size=2) if down else nn.Identity(),
            )
        else:
            self.skip_connection = nn.Identity()

    def forward(self, tensor: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        skip_tensor = self.skip_connection(tensor)

        # Main path
        tensor = self.norm1(tensor, embedding)
        tensor = F.relu(tensor)
        if self.up:
            tensor = F.interpolate(tensor, scale_factor=2)
        if self.down:
            tensor = F.avg_pool2d(tensor, kernel_size=2)
        tensor = self.conv1(tensor)
        tensor = self.norm2(tensor, embedding)
        tensor = F.relu(tensor)
        tensor = self.conv2(tensor)

        tensor += skip_tensor
        return tensor


class SelfAttentionBlock(nn.Module):
    def __init__(self, embedding_dim: int, image_size: int, head_dim: int = 64, channel_per_group: int = 16):
        super().__init__()
        self.head_dim: int = head_dim
        self.num_head: int = embedding_dim // head_dim
        self.scale: float = head_dim ** -0.5
        self.num_pixel = image_size ** 2
        self.gnorm1 = nn.GroupNorm(embedding_dim // channel_per_group, embedding_dim)
        self.gnorm2 = nn.GroupNorm(embedding_dim // channel_per_group, embedding_dim)

        # QKV projection
        self.qkv_proj = nn.Linear(embedding_dim, embedding_dim * 3)

        # Output layer
        self.output = zero_out(nn.Conv2d(embedding_dim, embedding_dim, kernel_size=1))

        # Positional embedding for patches
        self.positional_encoding = nn.Parameter(sinusoidal_positional_encoding_2d(image_size, image_size, embedding_dim))
        self.positional_encoding.requires_grad_(False)

        # Feed Forward Layer
        self.ffn1 = nn.Conv2d(embedding_dim, embedding_dim * 8, kernel_size=1)
        self.ffn2 = nn.Conv2d(embedding_dim * 8, embedding_dim, kernel_size=1)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        skip_tensor = tensor

        tensor = self.gnorm1(tensor)

        # Reshape for self attention
        batch_size, channel, height, width = tensor.shape
        tensor = tensor + self.positional_encoding
        tensor = tensor.view(batch_size, channel, self.num_pixel)
        tensor = tensor.permute(0, 2, 1)

        tensor = self.qkv_proj(tensor)

        query, key, value = torch.chunk(tensor, 3, dim=-1)
        query = query.view(batch_size, self.num_pixel, self.num_head, self.head_dim)
        key = key.view(batch_size, self.num_pixel, self.num_head, self.head_dim)
        value = value.view(batch_size, self.num_pixel, self.num_head, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Self attention
        attention_raw = torch.matmul(query, key.transpose(2, 3))
        attention_scaled = attention_raw * self.scale
        attention_score = torch.softmax(attention_scaled, dim=-1)
        value = torch.matmul(attention_score, value)

        # Reshape for self attention output
        tensor = value.transpose(1, 2).contiguous()
        tensor = tensor.view(batch_size, self.num_pixel, channel)
        tensor = tensor.permute(0, 2, 1)
        tensor = tensor.reshape(batch_size, channel, height, width)
        tensor = self.output(tensor)

        tensor = tensor + skip_tensor

        # Feed Forward Layer
        tensor = self.gnorm2(tensor)
        tensor = self.ffn1(tensor)
        tensor = F.relu(tensor)
        tensor = self.ffn2(tensor)

        return tensor

class UNet(nn.Module):
    def __init__(self, image_channel: int = 3, image_size: int = 64, channels: list[int] = [64, 128, 256, 512], attention_channels = [128, 256, 512], depth: int = 2, emb_dim: int = 1024, num_step: int = 1000, num_classes: int = 10, channel_per_group: int = 16, patch_size: int = 2, head_dim: int = 32):
        super().__init__()
        self.encoder = nn.ModuleList([nn.ModuleList([nn.Conv2d(image_channel, channels[0], 3, padding=1)])])
        self.decoder = nn.ModuleList()
        skip_channel = [channels[0]]
        image_size = [image_size // (2**i) for i in range(len(channels))]

        self.positional_encoding = nn.Embedding(num_step, emb_dim)
        self.positional_encoding.weight.data.copy_(positional_embedding(num_step, emb_dim))
        self.positional_encoding.weight.requires_grad = False

        # Encoder
        for i in range(len(channels)):
            for _ in range(depth):
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i], channels[i], emb_dim = emb_dim, channel_per_group=channel_per_group))
                if channels[i] in attention_channels:
                    layer.append(SelfAttentionBlock(channels[i], image_size[i]))
                self.encoder.append(layer)
                skip_channel.append(channels[i])

            # Down projection
            if i != len(channels)-1:
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i], channels[i + 1], down=True, emb_dim = emb_dim, channel_per_group=channel_per_group))
                self.encoder.append(layer)
                skip_channel.append(channels[i+1])

        # Bottleneck
        self.bottle_neck = nn.ModuleList([
            ResBlock(channels[-1], channels[-1], channel_per_group=channel_per_group),
            SelfAttentionBlock(channels[-1], image_size[-1]),
            ResBlock(channels[-1], channels[-1], channel_per_group=channel_per_group),
        ])

        # Decoder
        for i in range(len(channels)-1, -1, -1):
            for _ in range(depth):
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i] + skip_channel.pop(), channels[i], emb_dim = emb_dim))
                if channels[i] in attention_channels:
                    layer.append(SelfAttentionBlock(channels[i], image_size[i]))
                self.decoder.append(layer)

            # Up projection
            if i != 0:
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i] + skip_channel.pop(), channels[i - 1], up=True, emb_dim = emb_dim))
                self.decoder.append(layer)
            else:
                layer = nn.ModuleList()
                layer.append(ResBlock(channels[i] + skip_channel.pop(), channels[0], emb_dim = emb_dim))
                self.decoder.append(layer)

        self.time_embedding_proj = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU()
        )
        self.class_embedding = nn.Embedding(num_classes, emb_dim)

        # Output kernels to change back to image channel
        self.out = nn.Sequential(
            nn.GroupNorm(channels[0] // channel_per_group, channels[0], affine=False),
            nn.SiLU(),
            zero_out(nn.Conv2d(channels[0], image_channel, kernel_size = 1)),
        )

    def forward(self, tensor: torch.Tensor, time_step: torch.Tensor, label: torch.Tensor = None) -> torch.Tensor:
        """
        Diffusion model.

        :param tensor: Input tensor.
        :param time_step: Time step tensor.
        :param label: Label tensor.
        :return: Predicted noise.
        """
        embedding = self.positional_encoding(time_step)
        embedding = self.time_embedding_proj(embedding)

        if label != None:
            class_embedding = self.class_embedding(label)
            embedding = embedding + class_embedding

        skip_connection = []

        # Encoder
        for layer in self.encoder:
            for module in layer:
                if(isinstance(module, ResBlock)):
                    tensor = module(tensor, embedding)
                else:
                    tensor = module(tensor)

            skip_connection.append(tensor)

        # Bottleneck
        for module in self.bottle_neck:
            if(isinstance(module, ResBlock)):
                tensor = module(tensor, embedding)
            else:
                tensor = module(tensor)

        # Decoder
        for layer in self.decoder:
            tensor = torch.concatenate((tensor, skip_connection.pop()), dim = 1)
            for module in layer:
                if(isinstance(module, ResBlock)):
                    tensor = module(tensor, embedding)
                else:
                    tensor = module(tensor)

        tensor = self.out(tensor)

        return tensor

In [6]:
# Define image transformations for preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def valid_image_folder(path: str) -> bool:
    # Check if file starts with '._' or ends with '.DS_Store'
    filename = os.path.basename(path)
    if filename.startswith("._") or filename == ".DS_Store": # Stupid MacOS
        return False
    
    return True

# Use ImageFolder to automatically label images based on folder names
dataset = datasets.ImageFolder(root=dataset_filepath, is_valid_file=valid_image_folder, transform=transform)

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create DataLoaders for training and validation
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

NameError: name 'DataLoader' is not defined

In [9]:
# Instantiate the model
unet = UNet(num_classes=num_class).cuda(device_id)
print("This model has", sum(p.numel() for p in unet.parameters()), "parameters.")
scaler = amp.GradScaler()
loss_train = []
loss_valid = []

This model has 118823491 parameters.


  scaler = amp.GradScaler()


In [10]:
# Set up optimizer and loss
optimizer = opt.AdamW(unet.parameters(), lr = lr, weight_decay = weight_decay)
criterion = nn.MSELoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs*len(train_dataloader), eta_min=1e-5)

In [11]:
for epoch in range(epochs):
    train_loss_list = []
    valid_loss_list = []
    for images, label in tqdm(train_dataloader):
        # Zero out grad
        optimizer.zero_grad()

        # Preparing for forward pass
        images = images.cuda(device_id)
        label = label.cuda(device_id)
        time_step = torch.randint(1, steps, size = (images.shape[0], )).cuda(device_id)
        x_t, noise = forward_pass(images, time_step)

        # Classifier free guidance.
        if random.random() < 0.1:
            label = None

        # Forward pass
        with amp.autocast():
            predicted_noise = unet(x_t, time_step, label)
            loss = criterion(predicted_noise, noise)

        # Back propagation
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        # Record loss
        train_loss_list.append(loss.item())

        # Step the learning rate
        scheduler.step()

    if(epoch % 10 == 0):
        with torch.no_grad():
            for images, label in tqdm(valid_dataloader):
                # Preparing for forward pass
                images = images.cuda(device_id)
                label = label.cuda(device_id)
                time_step = torch.randint(1, steps, size = (images.shape[0], )).cuda(device_id)
                x_t, noise = forward_pass(images, time_step)

                # Forward pass
                with amp.autocast():
                    predicted_noise = unet(x_t, time_step, label)
                    loss = criterion(predicted_noise, noise)
                valid_loss_list.append(loss.item())

    print(f"Epoch #{epoch}")
    print(f"Current learning rate is {optimizer.param_groups[0]['lr']}")
    print("Train Loss is:", sum(train_loss_list)/len(train_loss_list))
    loss_train.append(sum(train_loss_list)/len(train_loss_list))
    if(epoch % 10 == 0):
        print("Valid Loss is:", sum(valid_loss_list)/len(valid_loss_list))
        loss_valid.append(sum(valid_loss_list)/len(valid_loss_list))
    if(epoch % 10 == 0):
        label = torch.tensor([0, 1]).cuda(device_id)
        sampling(unet, label)

  with amp.autocast():
  0%|                                                                     | 0/1153050 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (128) must match the size of tensor b (32) at non-singleton dimension 3

In [11]:
torch.save(unet, "diffusion_CIFAR10.pth")