In [1]:
import torch
import torch.nn as nn
import torch.optim as opt
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import os

# Reproducibility: seed both PyTorch's global RNG and Python's `random` module,
# because the notebook draws from each (torch.randn / torch.randint for noise and
# timesteps, random.random() for the label drop in the training loop).
# `seed` is also reused for the dataset-split generator in the ImageFolder cell.
# random.seed() is deliberately the last statement: it returns None, so the cell
# produces no Out[] value.
import random
seed = 42
torch.manual_seed(seed)
random.seed(seed)

In [2]:
# Hyperparameters
# Compute device: prefer Apple's Metal backend (the notebook's original target), but
# fall back to CUDA or the CPU so the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
weight_decay = 0.01  # AdamW weight decay
lr = 3e-4  # base (peak) learning rate handed to AdamW
batch_size = 60  # batch size used by every DataLoader
epochs = 1000  # number of passes over the training loader
dataset_filepath = "./ImagenetHighResolution"  # root folder for the ImageFolder dataset cell
image_size = 32  # images are resized to image_size x image_size; also the sampling resolution
num_steps = 1000  # number of diffusion timesteps in the noise schedule
# Per-channel statistics used by transforms.Normalize and inverted again in sampling()
channel_mean = [0.4918695390224457, 0.4826536178588867, 0.44717657566070557] # CIFAR10
channel_std = [0.20224887132644653, 0.19936397671699524, 0.2009899914264679] # CIFAR10
# channel_mean = [0.4753786623477936, 0.4518871307373047, 0.40141433477401733] # ImageNet
# channel_std = [0.22465701401233673, 0.22012709081172943, 0.22106702625751495] # ImageNet
num_class = 10  # number of rows in the UNet class-embedding table
channel_num = 128  # base channel width of the UNet
dropout_ratio = 0.1  # dropout probability passed to the UNet
block_per_resolution = 2  # residual blocks per resolution level in the UNet

In [3]:
class AdaNorm(nn.Module):
    """Conditioned normalisation layer.

    The feature map is batch-normalised without learnable affine parameters, then
    modulated with a per-channel scale and shift predicted from a conditioning
    embedding: ``normalised * (1 + scale) + shift``.

    Args:
        channel_num: number of channels of the feature map being normalised.
        embedding_dim: size of the conditioning embedding vector.
    """

    def __init__(self, channel_num: int, embedding_dim: int):
        super().__init__()
        self.normalization = nn.BatchNorm2d(channel_num, affine=False)
        # A single projection emits both halves: first the scales, then the shifts
        self.scale_shift = nn.Linear(embedding_dim, 2 * channel_num)
        self.activation = nn.SiLU()

    def forward(self, tensor: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        """Normalise ``tensor`` and modulate it with the scale/shift derived from ``embedding``."""
        normalised = self.normalization(tensor)

        modulation = self.scale_shift(self.activation(embedding))
        # Two trailing singleton axes let the per-channel values broadcast over height and width
        scale, shift = modulation[..., None, None].chunk(2, dim=1)

        return normalised * (1 + scale) + shift

class BigGANsBlock(nn.Module):
    """Residual convolution block conditioned on an embedding, with optional 2x resampling.

    Main path: BatchNorm -> SiLU -> (resample) -> 3x3 conv -> AdaNorm(embedding) -> SiLU
    -> Dropout -> 3x3 conv. Shortcut path: (resample) -> identity, or a 1x1 conv when the
    channel count changes. The two paths are summed.

    Args:
        in_channel: channels of the incoming feature map.
        embedding_dim: size of the conditioning embedding given to AdaNorm.
        dropout: dropout probability applied before the second convolution.
        out_channel: channels of the result; a falsy value (None or 0) keeps ``in_channel``.
        up: double the spatial size with nearest-neighbour interpolation.
        down: halve the spatial size with 2x2 max pooling (only used when ``up`` is False).
    """

    def __init__(self, in_channel: int, embedding_dim: int, dropout: float = 0.5, out_channel: int | None = None, up: bool = False, down: bool = False):
        super().__init__()
        # Fall back to the input width when no output width was requested
        out_channel = out_channel or in_channel

        # The pooling layer only exists on down-sampling blocks
        if down:
            self.down_sample = nn.MaxPool2d(2)

        # The shortcut needs a 1x1 projection only when the channel count changes
        self.skip_connection = (
            nn.Identity() if in_channel == out_channel else nn.Conv2d(in_channel, out_channel, 1)
        )

        self.batch_norm = nn.BatchNorm2d(in_channel, affine=False)
        self.activation = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channel, out_channel, 3, padding=1)
        self.ada_norm = AdaNorm(out_channel, embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding=1)
        self.up = up
        self.down = down

    def _resample(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply this block's spatial resampling; ``up`` takes precedence over ``down``."""
        if self.up:
            return F.interpolate(tensor, scale_factor=2, mode="nearest")
        if self.down:
            return self.down_sample(tensor)
        return tensor

    def forward(self, tensor: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        """Run the block on ``tensor`` conditioned on ``embedding`` and return main + shortcut."""
        # Shortcut path: resample the raw input, then match the channel count
        shortcut = self.skip_connection(self._resample(tensor))

        # Main path, first half: normalise and activate before resampling and convolving
        hidden = self.activation(self.batch_norm(tensor))
        hidden = self.conv1(self._resample(hidden))

        # Main path, second half: conditioned normalisation, activation, dropout, convolution
        hidden = self.activation(self.ada_norm(hidden, embedding))
        hidden = self.conv2(self.dropout(hidden))

        return hidden + shortcut

class Attention(nn.Module):
    """Multi-head self-attention over the pixels of a feature map, with a residual connection.

    Every spatial position is treated as one token whose features are the channels.

    Args:
        channel_num: channels of the feature map; must be divisible by ``channel_per_head``.
        channel_per_head: width of each attention head.
        dropout_ratio: dropout probability, applied to the attention weights and to the
            output projection (the same Dropout module is used for both).
    """

    def __init__(self, channel_num: int, channel_per_head: int = 64, dropout_ratio: float = 0.5):
        super().__init__()
        self.qkv = nn.Linear(channel_num, 3 * channel_num)
        self.o = nn.Linear(channel_num, channel_num)
        self.scaler = 1/math.sqrt(channel_per_head)
        self.attn_dropout = nn.Dropout(dropout_ratio)
        self.norm = nn.BatchNorm2d(channel_num, affine=False)

        assert channel_num % channel_per_head == 0
        self.num_head = channel_num // channel_per_head
        self.head_dim = channel_per_head

    def _to_heads(self, projected: torch.Tensor, batch_size: int, num_pixel: int) -> torch.Tensor:
        """Reshape (batch, pixel, channel) into (batch, head, pixel, head_dim)."""
        per_head = projected.contiguous().reshape(batch_size, num_pixel, self.num_head, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` plus the self-attention output computed from its normalised form."""
        residual = tensor
        batch_size, channel_num, height, width = tensor.shape
        num_pixel: int = height * width

        # Normalise, then flatten the spatial grid into a token sequence: (batch, pixel, channel)
        tokens = self.norm(tensor).reshape(batch_size, channel_num, num_pixel).permute(0, 2, 1).contiguous()

        # One projection yields query, key and value; split into three equal parts, then into heads
        query, key, value = (
            self._to_heads(part, batch_size, num_pixel)
            for part in self.qkv(tokens).split([channel_num] * 3, dim=-1)
        )

        # Scaled dot-product attention without any mask
        weights = torch.softmax(torch.matmul(query, key.transpose(2, 3)) * self.scaler, dim=-1)
        weights = self.attn_dropout(weights)
        context = torch.matmul(weights, value)

        # Merge the heads back: (batch, pixel, channel)
        context = context.transpose(1, 2).contiguous().reshape(batch_size, num_pixel, channel_num)

        # Output projection followed by dropout
        projected = self.attn_dropout(self.o(context))

        # Restore the (batch, channel, height, width) layout and add the residual
        projected = projected.permute(0, 2, 1).contiguous().reshape(batch_size, channel_num, height, width)
        return projected + residual

class UNet(nn.Module):
    """U-Net noise predictor conditioned on a timestep and, optionally, a class label.

    Resolution levels are created by halving ``image_size`` while it is at least 8; level
    ``k`` (0-based) uses ``channel_num * (k + 1)`` channels. Each level holds
    ``block_per_resolution`` residual blocks, each followed by self-attention when the
    level's resolution is 32, 16 or 8. Every encoder layer's output is kept as a skip
    connection and concatenated onto the input of the mirrored decoder layer.

    Args:
        image_channel: channels of the input image and of the prediction.
        image_size: spatial size of the (square) input. With ``image_size < 8`` no level
            is created and construction fails when the empty channel list is indexed.
        channel_num: base channel width.
        embedding_dim: size of the timestep / class embedding vectors.
        channel_per_head: head width used by every attention layer.
        dropout_ratio: dropout probability used by every residual block and attention layer.
        block_per_resolution: residual blocks per resolution level.
        num_step: number of rows in the learned timestep-embedding table.
        num_class: number of rows in the learned class-embedding table.
    """

    def __init__(self, image_channel: int = 3, image_size: int = 64, channel_num: int = 64, embedding_dim: int = 256, channel_per_head: int = 64, dropout_ratio: float = 0.5, block_per_resolution: int = 2, num_step: int = 1000, num_class: int = 1000):
        super().__init__()
        self.encoder = nn.ModuleList([nn.ModuleList([nn.Conv2d(image_channel, channel_num, 3, padding=1)])])
        self.decoder = nn.ModuleList()

        # Forward the constructor's dropout / head-width settings to every sub-module.
        # Without this the sub-modules silently fall back to their own defaults.
        def make_block(in_channel: int, out_channel: int | None = None, up: bool = False, down: bool = False) -> BigGANsBlock:
            return BigGANsBlock(in_channel, embedding_dim, dropout=dropout_ratio, out_channel=out_channel, up=up, down=down)

        def make_attention(channels: int) -> Attention:
            return Attention(channels, channel_per_head=channel_per_head, dropout_ratio=dropout_ratio)

        # Create config lists
        attention_resolution: list[int] = [32, 16, 8]
        resolution_list: list[int] = []
        channel_list: list[int] = []

        current_image_size: int = image_size
        counter = 1
        while current_image_size >= 8:
            resolution_list.append(current_image_size)
            channel_list.append(channel_num * counter)
            current_image_size //= 2
            counter += 1
        # Channel count of every encoder output, in order; the decoder pops from the end.
        # The first entry is the output of the input convolution.
        skip_channel = [channel_list[0]]

        # Learned embedding tables, indexed by timestep and by class label
        self.positional_embedding = nn.Parameter(torch.randn(num_step, embedding_dim) / embedding_dim ** 0.5)
        self.class_embedding = nn.Parameter(torch.randn(num_class, embedding_dim) / embedding_dim ** 0.5)

        # Encoder
        last_level = len(channel_list) - 1
        for i in range(len(channel_list)):
            for _ in range(block_per_resolution):
                layer = nn.ModuleList()
                layer.append(make_block(channel_list[i]))
                if resolution_list[i] in attention_resolution:
                    layer.append(make_attention(channel_list[i]))
                self.encoder.append(layer)
                skip_channel.append(channel_list[i])

            # Down projection to the next level (none after the last level)
            if i != last_level:
                layer = nn.ModuleList()
                layer.append(make_block(channel_list[i], out_channel=channel_list[i + 1], down=True))
                self.encoder.append(layer)
                skip_channel.append(channel_list[i + 1])

        # Bottleneck
        self.bottle_neck = nn.ModuleList([
            make_block(channel_list[-1]),
            make_attention(channel_list[-1]),
            make_block(channel_list[-1]),
        ])

        # Decoder: every layer consumes one skip connection, so its input width is the
        # running width plus the width of the popped encoder output.
        for i in range(last_level, -1, -1):
            for _ in range(block_per_resolution):
                layer = nn.ModuleList()
                layer.append(make_block(channel_list[i] + skip_channel.pop(), out_channel=channel_list[i]))
                if resolution_list[i] in attention_resolution:
                    layer.append(make_attention(channel_list[i]))
                self.decoder.append(layer)

            # Up projection to the previous level; the top level instead gets a final
            # same-resolution block that consumes the input convolution's skip connection.
            is_top_level = i == 0
            layer = nn.ModuleList()
            layer.append(make_block(
                channel_list[i] + skip_channel.pop(),
                out_channel=channel_list[0] if is_top_level else channel_list[i - 1],
                up=not is_top_level,
            ))
            self.decoder.append(layer)

        # Output kernels to change back to image channel
        self.out = nn.Sequential(
            nn.BatchNorm2d(channel_list[0], affine=False),
            nn.SiLU(),
            nn.Conv2d(channel_list[0], image_channel, kernel_size = 1),
        )

    @staticmethod
    def _run_modules(modules, tensor: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
        """Apply ``modules`` in order; only residual blocks receive the conditioning embedding."""
        for module in modules:
            if isinstance(module, BigGANsBlock):
                tensor = module(tensor, embedding)
            else:
                tensor = module(tensor)
        return tensor

    def forward(self, tensor: torch.Tensor, time_step: torch.Tensor, label: torch.Tensor | None = None) -> torch.Tensor:
        """Run the network on ``tensor``.

        Args:
            tensor: input batch of images.
            time_step: integer indices into the timestep-embedding table, one per sample.
            label: optional integer indices into the class-embedding table, one per sample;
                ``None`` runs the model without class conditioning.

        Returns:
            A tensor with ``image_channel`` channels.
        """
        embedding = self.positional_embedding[time_step]

        # Identity check: comparing a tensor with `!=` is not a reliable None test
        if label is not None:
            embedding = embedding + self.class_embedding[label]

        skip_connection = []

        # Encoder: remember every layer's output for the decoder
        for layer in self.encoder:
            tensor = self._run_modules(layer, tensor, embedding)
            skip_connection.append(tensor)

        # Bottleneck
        tensor = self._run_modules(self.bottle_neck, tensor, embedding)

        # Decoder: concatenate the most recent unused skip connection along the channel axis
        for layer in self.decoder:
            tensor = torch.concatenate((tensor, skip_connection.pop()), dim = 1)
            tensor = self._run_modules(layer, tensor, embedding)

        return self.out(tensor)

In [None]:
# Define image transformations for preprocessing:
# resize every image to image_size x image_size, convert it to a tensor, then
# normalize each channel with channel_mean / channel_std.
# NOTE(review): the active channel statistics in the hyperparameter cell are the ones
# labelled CIFAR10; the ImageNet values there are commented out — confirm which set
# should be used when training from this ImageFolder cell.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std)
])

def valid_image_folder(path: str) -> bool:
    """Tell ImageFolder whether ``path`` should be loaded as an image.

    Returns False for macOS metadata files — a basename that starts with ``._`` or that
    is exactly ``.DS_Store`` — and True for every other path. Only the final path
    component is inspected.
    """
    filename = os.path.basename(path)
    is_macos_metadata = filename.startswith("._") or filename == ".DS_Store"
    return not is_macos_metadata

# Use ImageFolder to automatically label images based on folder names
dataset = datasets.ImageFolder(root=dataset_filepath, is_valid_file=valid_image_folder, transform=transform)

# 80% / 10% / remainder split; the test split absorbs whatever the two int()
# truncations leave over, so the three sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
# A dedicated, seeded generator keeps the split identical across runs
g = torch.Generator().manual_seed(seed)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=g
)

# Create DataLoaders for training, validation and testing.
# Only the training loader is shuffled: the evaluation splits are already a random
# subset, and a fixed iteration order keeps evaluation passes comparable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# Define image transformations for preprocessing:
# resize to image_size x image_size, convert to a tensor, then normalize each channel
# with channel_mean / channel_std.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std)
])

# Load the training set (downloaded into ./data on first use)
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Load the test set; it serves as the validation set in the training loop
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create one DataLoader per split. Only the training loader is shuffled; a fixed
# iteration order keeps validation passes comparable between epochs.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [5]:
# Instantiate the model
unet = UNet(
    image_size=image_size,
    num_class=num_class,
    channel_num=channel_num,
    dropout_ratio=dropout_ratio,
    block_per_resolution=block_per_resolution,
    # Size the timestep-embedding table from the same constant that defines the noise
    # schedule, so changing num_steps cannot index past the end of the table.
    num_step=num_steps,
).to(device)
print("This model has", sum(p.numel() for p in unet.parameters()), "parameters.")
# Per-epoch mean losses, appended to by the training loop
loss_train = []
loss_valid = []

This model has 41046019 parameters.


In [6]:
# util.py
# Linear noise schedule: num_steps variances evenly spaced from beta_min to beta_max
beta_min, beta_max = 1e-4, 0.02

# Constants required for forward process.
# The device-side tensors are shaped (num_steps, 1, 1, 1) so that indexing them with a
# batch of timesteps broadcasts against (batch, channel, height, width) images.
beta = torch.linspace(beta_min, beta_max, num_steps)
alpha = 1 - beta
alphas_cumprod = torch.cumprod(alpha, dim=0)
one_minus_alphas_cumprod = 1 - alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod.sqrt().reshape(-1, 1, 1, 1).to(device)
sqrt_one_minus_alphas_cumprod = one_minus_alphas_cumprod.sqrt().reshape(-1, 1, 1, 1).to(device)

# Constants required for backward process
one_over_sqrt_alpha = 1 / alpha.sqrt().view(-1, 1, 1, 1).to(device)
one_minus_alpha = (1 - alpha).view(-1, 1, 1, 1).to(device)
sqrt_beta = beta.sqrt().view(-1, 1, 1, 1).to(device)

def forward_process(images: torch.Tensor, timesteps: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Noise ``images`` to the given ``timesteps`` in a single step.

    Args:
        images: batch of clean images.
        timesteps: integer indices into the schedule tensors, one per image.

    Returns:
        ``(noisy_images, noise)`` where ``noise`` is the standard-normal sample that was
        mixed in, i.e. ``sqrt_alphas_cumprod[t] * images + sqrt_one_minus_alphas_cumprod[t] * noise``.
    """
    noise = torch.randn_like(images)

    # Per-sample mixing coefficients looked up from the module-level schedule tensors
    signal_scale = sqrt_alphas_cumprod[timesteps]
    noise_scale = sqrt_one_minus_alphas_cumprod[timesteps]

    noisy_images = signal_scale * images + noise_scale * noise
    return noisy_images, noise

def sampling(model: UNet, labels: torch.Tensor, cfg_scale: int = 3) -> None:
    """Generate one image per entry of ``labels`` and display each with matplotlib.

    Starts from Gaussian noise and runs the reverse process over every timestep from
    ``num_steps - 1`` down to 0. At each step the noise prediction is extrapolated from
    the unconditional prediction towards the label-conditioned one by ``cfg_scale``.

    The model is switched to evaluation mode for the duration of the call and restored
    to its previous mode afterwards.

    Args:
        model: the noise-prediction network.
        labels: class indices, one per image to generate; expected on ``device``.
        cfg_scale: guidance weight; 0 uses only the unconditional prediction, 1 only the
            conditional one, larger values extrapolate beyond it.
    """
    # Dropout and BatchNorm must run in inference mode here; remember the caller's mode
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            num_images = labels.shape[0]
            x = torch.randn(num_images, 3, image_size, image_size).to(device)

            for i in tqdm(range(num_steps-1, -1, -1)):
                t = torch.full((num_images,), i, dtype=torch.long, device=device)

                # Classifier free guidance
                predicted_noise_no_label = model(x, t, None)
                predicted_noise_with_label = model(x, t, labels)
                predicted_noise = torch.lerp(predicted_noise_no_label, predicted_noise_with_label, cfg_scale)

                # Sample for input to next step; the final step (i == 0) adds no fresh noise
                if i == 0:
                    noise = torch.zeros_like(x)
                else:
                    noise = torch.randn_like(x)
                x = one_over_sqrt_alpha[t] * (x - ((one_minus_alpha[t])/(sqrt_one_minus_alphas_cumprod[t]))*predicted_noise) + sqrt_beta[t] * noise

            # Undo the dataset normalisation to turn back into [0, 1] range
            mean = torch.tensor(channel_mean, device=device).view(1, 3, 1, 1)
            std = torch.tensor(channel_std, device=device).view(1, 3, 1, 1)

            x = x * std + mean
            x = x.clamp(0, 1)

            # Visualize it: matplotlib expects channel-last arrays
            x = x.cpu().permute(0, 2, 3, 1).numpy()
            for image in x:
                plt.imshow(image)
                plt.axis('off')
                plt.show()
    finally:
        model.train(was_training)

In [7]:
# Set up loss and optimizer
# The loss is the mean squared error between predicted and true noise
criterion = nn.MSELoss()
# AdamW over every model parameter, using the base learning rate and weight decay
# from the hyperparameter cell
optimizer = opt.AdamW(
    unet.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [8]:
# The scheduler is stepped once per batch, so the schedule spans every batch of every epoch
total_steps = epochs * len(train_loader)
warmup_steps = int(total_steps * 0.05)

# Warmup: LR linearly increases from 1% of the base LR → base LR over warmup_steps
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)

# Cosine annealing: after warmup, decay towards eta_min over the remaining steps
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=(total_steps - warmup_steps), eta_min=1e-5)

# Combine them: switch from warmup to cosine annealing at warmup_steps
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps])

In [9]:
for epoch in range(epochs):
    train_loss_list = []
    valid_loss_list = []
    unet.train()
    for images, label in tqdm(train_loader):
        # Zero out grad
        optimizer.zero_grad()

        # Preparing for forward pass
        images = images.to(device)
        label = label.to(device)
        # Timesteps are 0-based indices into the schedule tensors and the model's timestep
        # embedding. sampling() runs down to index 0, so index 0 has to be trained as well.
        time_step = torch.randint(0, num_steps, size = (images.shape[0], )).to(device)
        x_t, noise = forward_process(images, time_step)

        # Classifier free guidance: drop the labels of the whole batch 10% of the time
        # so the model also learns the unconditional prediction.
        if random.random() < 0.1:
            label = None

        # Forward pass
        predicted_noise = unet(x_t, time_step, label)
        loss = criterion(predicted_noise, noise)

        # Back propagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=1.0)
        optimizer.step()

        # Record loss
        train_loss_list.append(loss.item())

        # Step the learning rate (once per batch, matching how total_steps was computed)
        scheduler.step()

    unet.eval()
    with torch.no_grad():
        for images, label in tqdm(val_loader):
            # Preparing for forward pass
            images = images.to(device)
            label = label.to(device)
            time_step = torch.randint(0, num_steps, size = (images.shape[0], )).to(device)
            x_t, noise = forward_process(images, time_step)

            # Forward pass (validation always uses the labels)
            predicted_noise = unet(x_t, time_step, label)
            loss = criterion(predicted_noise, noise)
            valid_loss_list.append(loss.item())

    # Epoch means, computed once and used for both the log and the history lists
    train_loss_mean = sum(train_loss_list)/len(train_loss_list)
    valid_loss_mean = sum(valid_loss_list)/len(valid_loss_list)

    print(f"Epoch #{epoch}")
    print(f"Current learning rate is {optimizer.param_groups[0]['lr']}")
    print("Train Loss is:", train_loss_mean)
    loss_train.append(train_loss_mean)
    print("Valid Loss is:", valid_loss_mean)
    loss_valid.append(valid_loss_mean)
    # Every 100 epochs (including epoch 0), draw one sample each for classes 0 and 1
    if epoch % 100 == 0:
        label = torch.tensor([0, 1]).to(device)
        sampling(unet, label)
    # Release cached allocator memory between epochs; this call only applies to the MPS backend
    if str(device).startswith("mps"):
        torch.mps.empty_cache()

 48%|███████████████████▊                     | 404/834 [07:27<07:56,  1.11s/it]


KeyboardInterrupt: 

In [None]:
# Persist the trained network by saving the module object itself.
# NOTE(review): saving the whole module ties the checkpoint to this notebook's class
# definitions; saving unet.state_dict() is the more portable form — confirm what
# downstream loaders expect before changing it.
# NOTE(review): the filename says batch 64, but batch_size is 60 in the hyperparameter
# cell — confirm which is intended.
torch.save(unet, "batch_64_1000_epoch_CIFAR10.pth")