# Lab 2 Task 5
Implemented from article: <br>
https://levelup.gitconnected.com/building-stable-diffusion-from-scratch-using-python-f3ebc8c42da3

In [2]:
"""
%pip install einops
"""

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

# misc
import matplotlib.pyplot as plt
import tqdm
import math
import numpy as np
from einops import rearrange # For rearranging tensors

# Pick the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using: cuda:0


In [13]:
BATCH_SIZE = 64

# Only ToTensor is applied: images become float tensors (C, H, W) with values
# in [0, 1]; no normalization or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, downloaded into ./data on first use.
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Gaussian random Fourier features.
# Used for the time embedding: the score function depends on time, so the
# scalar time t is encoded as a vector of sinusoidal features.
class GaussianFourierProjection(nn.Module):
    """Encode scalar time values as random Fourier (sinusoidal) features.

    A vector of frequencies W is drawn once at construction as
    ``randn(embed_dim // 2) * scale`` (i.e. W ~ N(0, scale^2)) and is frozen.
    Each input value t is mapped to ``[sin(2*pi*t*W), cos(2*pi*t*W)]``.

    Args:
        embed_dim: requested feature width. The actual output width is
            ``2 * (embed_dim // 2)``, so an odd value is rounded down to even.
        scale: multiplier applied to the standard-normal frequencies.
    """

    def __init__(self, embed_dim, scale=30):
        super().__init__()
        half_dim = embed_dim // 2
        frequencies = torch.randn(half_dim) * scale
        # Kept as a Parameter but with requires_grad=False: no gradient is
        # computed for it, so training never changes the frequencies.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return sine and cosine features for ``x``.

        ``x`` is expected to be 1-D, shape (B,). The result has shape
        (B, 2 * (embed_dim // 2)): all sines first, then all cosines.
        """
        # Outer product of times and frequencies, scaled by 2*pi.
        angles = x[:, None] * self.W[None, :] * 2 * np.pi
        sines = torch.sin(angles)
        cosines = torch.cos(angles)
        return torch.cat((sines, cosines), dim=1)

In [9]:
# FC layer: output -> feature map
class Dense(nn.Module):
    """Fully connected layer whose output is reshaped into a 1x1 feature map.

    Applies ``nn.Linear(input_dim, output_dim)`` to the last axis of the input
    and then appends two singleton axes. A (B, input_dim) input therefore
    yields (B, output_dim, 1, 1). Presumably this is so the result can be
    broadcast-added to (B, C, H, W) convolutional feature maps (TODO: confirm
    against the score model once written).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        # Add trailing "height" and "width" axes of size 1.
        return projected.unsqueeze(-1).unsqueeze(-1)

In [None]:
# TBI (to be implemented)
def score_model(*args, **kwargs):
    """Placeholder for the time-dependent score network.

    Not implemented yet. It accepts any arguments, so a call shaped like
    ``score_model(x, t, y=...)`` reaches this function. It then fails with an
    explicit error instead of silently returning None or raising an arity
    TypeError.

    Raises:
        NotImplementedError: always, until the score model is written.
    """
    raise NotImplementedError(
        "score_model is not implemented yet; define the U-Net score model first."
    )

In [14]:
# Sampler
# Sampling: random image -> estimate the noise (via the score) -> remove it -> repeat.
# This can be done in several ways; here we use the Euler-Maruyama method (as in the article).
def em_sampler(model, marginal_prob_std, diff_coef,
               batch_size = BATCH_SIZE, num_steps = 500, y_tensor = None,
               x_shape = (1, 28, 28), eps = 1e-3):
    """Draw samples by integrating the reverse-time SDE with Euler-Maruyama.

    1. Start from Gaussian noise scaled by ``marginal_prob_std`` at t = 1.
    2. Walk a grid of ``num_steps`` evenly spaced times from 1 down to ``eps``.
    3. At each step, move the sample along the score (drift
       ``g(t)^2 * score * dt``).
    4. Then add Gaussian noise with standard deviation ``g(t) * sqrt(dt)``.

    Args:
        model: score network, called as ``model(x, t, y=y_tensor)``, where x
            has shape (batch_size, *x_shape) and t has shape (batch_size,).
            ``y`` is always passed (possibly None), so the model must accept it.
        marginal_prob_std: callable taking a (batch_size,) time tensor. Its
            result scales the initial noise at t = 1, one value per sample.
        diff_coef: callable taking a (batch_size,) time tensor and returning
            the diffusion coefficient g(t), one value per sample.
        batch_size: number of samples to draw.
        num_steps: number of points in the time grid (must be >= 2).
        y_tensor: optional conditioning forwarded to ``model`` as ``y``.
        x_shape: shape of a single sample (default (1, 28, 28), i.e. MNIST).
        eps: last (smallest) time on the grid; the grid stops short of 0.

    Returns:
        Tensor of shape (batch_size, *x_shape): the mean of the final update.
        No noise is added after the last step.

    Raises:
        ValueError: if ``num_steps < 2`` (the step size would be undefined).
    """
    if num_steps < 2:
        raise ValueError(f"num_steps must be at least 2, got {num_steps}")

    def _per_sample(v):
        # (batch_size,) -> (batch_size, 1, ..., 1) so it broadcasts over one sample.
        return v.reshape(-1, *([1] * len(x_shape)))

    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) * _per_sample(marginal_prob_std(t))
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x

    with torch.no_grad():
        # `tqdm` is imported as a module in this file, so the callable is tqdm.tqdm.
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diff_coef(batch_time_step)
            # Use the model passed in, not the global placeholder.
            mean_x = x + _per_sample(g**2) * model(x, batch_time_step, y=y_tensor) * step_size
            # Euler-Maruyama requires standard *normal* noise (randn), not uniform (rand).
            x = mean_x + torch.sqrt(step_size) * _per_sample(g) * torch.randn_like(x)

    # Return the noise-free mean of the last step.
    return mean_x

In [None]:
# Main score model
# We use a U-Net architecture, as described in the article, in its residual
# variant: skip connections are added rather than concatenated.


In [None]:
# Training setup.
# NOTE(review): the U-Net score model is not implemented yet, so there is no
# model to optimize. Replace this placeholder with the score-model instance.
model = None

epochs = 100
# NOTE(review): `train_loader` was built with BATCH_SIZE = 64; this value does
# not affect it. Rebuild the loader if 1024 is the intended training batch size.
batch_size = 1024

# Fail with an explicit message instead of the obscure
# "'ellipsis' object has no attribute 'parameters'" that `model = ...` produced.
if model is None:
    raise NotImplementedError(
        "Set `model` to the score model before creating the optimizer."
    )
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)