# 1. Define the U-net based Time-dependent Score Network

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools

class TimeEncoding(nn.Module):
    """ Random Fourier feature encoding of time steps.

    Each time value t is mapped to [sin(2*pi*W*t), cos(2*pi*W*t)], where W holds
    embed_dim // 2 frequencies drawn once from a normal distribution scaled by
    `scale` when the module is constructed.
    """

    def __init__(self, embed_dim, scale=30.0) -> None:
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        # Kept as a frozen Parameter: it is part of the state_dict but receives
        # no gradient, so gradient-based optimizers leave it unchanged.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # Outer product of times and frequencies; for a 1-D x of length B this
        # gives (B, embed_dim // 2). First half of the output is sin, second cos.
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        return torch.cat((sin_part, cos_part), dim=-1)
    

class Dense(nn.Module):
    """ Linear layer whose output is shaped to broadcast over feature maps.

    Maps (..., input_dim) to (..., output_dim, 1, 1), so the result can be added
    to a (batch, output_dim, H, W) activation.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.dense(x)
        # Append two singleton spatial axes.
        return projected.unsqueeze(-1).unsqueeze(-1)


class ScoreNet(nn.Module):
    """ Time-dependent score estimation model based on a U-Net.

    Takes a batch of noisy single-channel images x and their time steps t and
    returns a tensor shaped like x: the estimated score at time t, divided by
    marginal_prob_std(t).
    """

    def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256) -> None:
        """ Initialize a time-dependent score-based network.

        Args:
            marginal_prob_std: A function that takes a 1-D tensor of time steps and
                returns the standard deviation of the perturbation kernel at each t.
            channels: Number of feature channels at each of the four resolution
                levels. Every entry must be divisible by 32, because the GroupNorm
                layers below use 32 groups (gnorm1 uses 4). A tuple is used as
                the default to avoid a shared mutable default; lists also work.
            embed_dim: Dimensionality of the time embedding.
        """

        super().__init__()
        # Time embedding: random Fourier features followed by a learned linear map.
        self.embed = nn.Sequential(TimeEncoding(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim))

        # U-net encoder: spatial size decreases, channels increase.
        # (Spatial sizes in comments assume a 28x28 input; convs use no padding.)

        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)        # 28 -> 26
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)  # 26 -> 12
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)  # 12 -> 5
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)  # 5 -> 2
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # U-net decoder: spatial size increases, channels decrease. Inputs of
        # tconv3/tconv2/tconv1 are concatenated with the matching encoder output.

        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)  # 2 -> 5
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm5 = nn.GroupNorm(32, num_channels=channels[2])

        # output_padding=1 restores the even sizes lost by the strided convs.
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)  # 5 -> 12
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm6 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)  # 12 -> 26
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm7 = nn.GroupNorm(32, num_channels=channels[0])

        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)  # 26 -> 28

        self.act = lambda x: x * torch.sigmoid(x)   # swish / SiLU activation
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t):
        """ Estimate the score of x at time t.

        Args:
            x: Batch of images with a single channel, (batch, 1, H, W).
            t: 1-D tensor of time steps, one per example.

        Returns:
            Tensor with the same shape as x.
        """
        embed = self.act(self.embed(t))

        # Encoder: conv -> add time embedding -> GroupNorm -> swish.
        h1 = self.conv1(x)
        h1 += self.dense1(embed)    # inject time t (broadcast over H, W)
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)

        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)

        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)

        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoder with skip connections from the encoder.
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.tgnorm5(h)
        h = self.act(h)

        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.tgnorm6(h)
        h = self.act(h)

        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.tgnorm7(h)
        h = self.act(h)

        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Normalize the output by the perturbation std. The true score of the
        # perturbation kernel has norm proportional to 1 / std(t), so this lets
        # the network output stay O(1) across time steps.
        h = h / self.marginal_prob_std(t)[:, None, None, None]

        return h




# 2. Define SDE and Denoising Score Matching Objective

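The forward SDE used here (no drift term) is

$$d\mathbf{x} = \sigma^t \, d\mathbf{w}, \quad t \in [0, 1],$$

whose perturbation kernel $p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))$ is Gaussian with mean $\mathbf{x}(0)$ and standard deviation $\sqrt{\frac{\sigma^{2t} - 1}{2 \log \sigma}}$.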

In [19]:
device = 'cpu'
# device = 'cuda'

def marginal_prob_std(t, sigma):
    """ Compute the standard deviation of p_{0t}(x(t) | x(0)) at time t.

    For dx = sigma^t dw the std is sqrt((sigma^(2t) - 1) / (2 * log(sigma))).

    Args:
        t: A scalar or a tensor of time steps.
        sigma: The sigma in the SDE.

    Returns:
        A tensor of standard deviations on the module-level `device`.
    """

    # torch.as_tensor reuses an existing tensor (moving it to `device` if needed)
    # instead of copy-constructing it, which torch.tensor warns about on every call.
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2*t) - 1.) / 2. / np.log(sigma))


def diffusion_coeff(t, sigma):
    """ Compute the diffusion coefficient g(t) = sigma^t at time t.

    Note that there is no drift coefficient in this demo.

    Args:
        t: A scalar or a tensor of time steps.
        sigma: The sigma in the SDE.

    Returns:
        A tensor of diffusion coefficients on the module-level `device`.
    """

    return torch.as_tensor(sigma**t, device=device)


sigma = 25.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)    # non-parametric function
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)        # non-parametric function


In [4]:
def loss_fn(score_model: nn.Module, x: torch.Tensor, marginal_prob_std, eps=1e-5):
    """ The loss function for training score-based generative models.

    Denoising score matching weighted by lambda(t) = std(t)^2. With
    perturbed_x = x + std * z, the score of the perturbation kernel is -z / std, so
    std^2 * ||score - (-z / std)||^2 = ||score * std + z||^2, which is what is computed.

    Args:
        score_model: A PyTorch model instance that represents a time-dependent score-based model.
        x: A mini-batch of training data, shape (batch, channels, height, width).
        marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
        eps: A tolerance value for numerical stability. Times are drawn from [eps, 1)
            so that std(t) stays away from zero.

    Returns:
        A scalar tensor: the per-example loss averaged over the batch.
    """

    # Step 1: draw one time step per example, uniformly from [eps, 1).
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps  # [batch_size]

    # Step 2: sample x(t) ~ p_{0t}(x(t) | x) by reparameterization: x + std * z.
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]

    # Step 3: predict the score of the perturbed samples.
    score = score_model(perturbed_x, random_t)

    # Step 4: "+ z" because the target score is -z / std (see docstring).
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))

    return loss

In [16]:
from copy import deepcopy

class EMA(nn.Module):
    """ Keep an exponential moving average (EMA) copy of a model's state.

    Args:
        model: The model whose state is tracked. A deep copy is stored in
            ``self.module`` and put in eval mode.
        decay: Weight kept from the running average at each update:
            ``ema = decay * ema + (1 - decay) * model``.
        device: Optional device for the EMA copy. When given, the copy and the
            incoming tensors are moved there.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        # Copy of the model that accumulates the moving average of the weights.
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """ Write ``update_fn(ema_tensor, model_tensor)`` into every state entry in place. """
        with torch.no_grad():
            # state_dict() tensors share storage with the module, so copy_ updates it in place.
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
                    # averaged: a float blend would be truncated by copy_. Copy them as-is.
                    ema_v.copy_(model_v)

    def update(self, model):
        """ Blend ``model``'s current state into the EMA copy. """
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """ Overwrite the EMA copy with ``model``'s current state. """
        self._update(model, update_fn=lambda e, m: m)

# 3. Training a Score-based Model on MNIST Data

In [None]:
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm

score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 50 #@param {'type': 'integer'}
## size of a mini-batch
batch_size = 32 #@param {'type': 'integer'}
## learning rate
lr = 1e-4 #@param {'type' : 'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True) # gray scale images, values in [0, 1]
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.tqdm(range(n_epochs))

ema = EMA(score_model)
ckpt_path = f'ckpt_{n_epochs}.pth'
for epoch in tqdm_epoch:
    # training speed: Mac CPU: 250s/Epoch, Kaggle GPU: 30s/Epoch, Colab: 35s/Epoch
    avg_loss = 0.
    num_items = 0
    for x, _ in data_loader:  # labels are unused: the model is unconditional
        x = x.to(device)  # [batch_size x 1 x 28 x 28] [Batch x Channel x Height x Width]
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(score_model)
        # .item() converts the 0-dim loss tensor into a plain Python float.
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]

    # Report this epoch's average loss and checkpoint after every epoch, so an
    # interrupted run still leaves a usable model on disk.
    tqdm_epoch.set_description('Average ScoreMatching Loss: {:5f}'.format(avg_loss / num_items))
    torch.save(score_model.state_dict(), ckpt_path)

print('Average ScoreMatching Loss: {:5f}'.format(avg_loss / num_items))
# The EMA weights were otherwise computed but never persisted.
torch.save(ema.module.state_dict(), f'ckpt_{n_epochs}_ema.pth')

In [None]:
## The number of sampling steps
num_steps = 500

def euler_sampler(score_model,
                  marginal_prob_std,
                  diffusion_coeff,
                  batch_size=64,
                  num_steps=num_steps,
                  device='cuda',
                  eps=1e-3):
    """ Generate samples from score-based models with the Euler-Maruyama solver.

    Integrates the reverse-time SDE backward from t = 1 to t = eps on a uniform
    grid of num_steps points:
        x <- x + g(t)^2 * score(x, t) * dt + g(t) * sqrt(dt) * z.

    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate by calling this function once.
        num_steps: The number of sampling steps (discretized time points), at least 2.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step, for numerical stability.

    Returns:
        Samples of shape (batch_size, 1, 28, 28). The noise of the final step is
        not added (the last mean is returned).

    Raises:
        ValueError: if num_steps < 2 (the step size would be undefined).
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")

    # Step1: define initial time t and random sample from prior distribution
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
                * marginal_prob_std(t)[:, None, None, None]

    # Step2: define sampling reverse-time grid and time step
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]

    # Step3: solve reverse-time SDE using Euler-Maruyama
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)

    return mean_x


# Euler + Langevin Dynamics Sampling (Predictor-Corrector Sampler) to Generate Higher-Quality Data

In [None]:
## signal to noise ratio
SNR = 0.16
num_steps = 500
langevin_steps = 10

def pc_sampler(score_model,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=num_steps,
               corrector_steps=langevin_steps,
               snr=SNR,
               device='cuda',
               eps=1e-3):
    """ Generate samples from score-based models with the Predictor-Corrector method.

    At each time step, runs `corrector_steps` Langevin MCMC updates (corrector),
    then one Euler-Maruyama step of the reverse-time SDE (predictor).

    Args:
        score_model: A PyTorch model instance that represents a time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of SDE.
        batch_size: The number of samples to generate by calling this function once.
        num_steps: The number of sampling steps. Equivalent to the number of discretized time steps (at least 2).
        corrector_steps: The number of Langevin MCMC steps per time step.
        snr: signal to noise ratio, used to set the Langevin step size.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: A tolerance value for numerical stability.

    Returns:
        Samples of shape (batch_size, 1, 28, 28).

    Raises:
        ValueError: if num_steps < 2 (the step size would be undefined).
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")

    # Step1: define initial time t and random sample from prior distribution
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
                * marginal_prob_std(t)[:, None, None, None]

    # Step2: define sampling reverse-time grid and time step
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]

    # sqrt(D) for D pixels per sample; loop-invariant, so computed once.
    noise_norm = np.sqrt(np.prod(init_x.shape[1:]))

    # Step3: alternate Langevin correction and Euler prediction
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step

            # Corrector steps (Langevin MCMC): step size chosen so the ratio of
            # the gradient term to the noise term matches `snr`.
            for _ in range(corrector_steps):
                grad = score_model(x, batch_time_step)
                grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
                langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2
                # Langevin dynamics requires Gaussian noise (randn), not uniform (rand).
                x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)

            # Predictor step (Euler-Maruyama)
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
            x = mean_x + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)

    return mean_x

# Ordinary Differential Equation (Probability Flow ODE) Sampling

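The probability flow ODE has the same marginal densities $p_t(\mathbf{x})$ as the SDE above and is integrated from $t = 1$ back to $t = \epsilon$:

\begin{align*}
d \mathbf{x} = -\frac{1}{2} \sigma^{2t} \, \nabla_{\mathbf{x}} \log p_t(\mathbf{x}) \, dt
\end{align*}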

In [None]:
from scipy import integrate

## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5

def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                device='cuda',
                z=None,
                eps=1e-3
                ):
    """ Generate samples by solving the probability flow ODE with scipy's RK45.

    Integrates dx/dt = -1/2 * g(t)^2 * score(x, t) backward from t = 1 to t = eps.

    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient g(t).
        batch_size: The number of samples to generate (ignored when `z` is given).
        atol: Absolute tolerance for the ODE solver.
        rtol: Relative tolerance for the ODE solver.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        z: Optional starting point at t = 1. If None, draws (batch_size, 1, 28, 28)
            Gaussian noise scaled by marginal_prob_std(1).
        eps: The final (smallest) time, for numerical stability.

    Returns:
        A float32 tensor with the shape of the starting point.
    """

    # Step1 define initial time and x
    t = torch.ones(batch_size, device=device)
    if z is None:
        init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
    else:
        init_x = z

    shape = init_x.shape

    # Step2: define score prediction function and ordinary differential function
    def score_eval_wrapper(sample, time_steps):
        """ A wrapper of the score-based model for use by the ODE solver (flat float64 in/out). """

        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape(sample.shape[0])
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)

    def ode_func(t, x):
        """ The ODE right-hand side for use by the ODE solver. """

        time_steps = np.ones((shape[0],)) * t
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)

    # Step 3 call ODE solver to compute the predicted sample at time t = eps
    res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
    print(f"Number of function evaluation: {res.nfev}")

    # solve_ivp works in float64; cast back to float32 to match the other samplers.
    x = torch.tensor(res.y[:, -1], device=device, dtype=torch.float32).reshape(shape)

    return x

# Load the trained model and compare different sampling methods

In [None]:
from torchvision.utils import make_grid
import time

## load the pre-trained checkpoint from disk
device = 'cpu'
# device = 'cuda'

ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

sample_batch_size = 64
sampler = pc_sampler # pc_sampler, euler_sampler or ode_sampler

t1 = time.time()
## Generate samples using the specified sampler.
samples = sampler(score_model, marginal_prob_std_fn, diffusion_coeff_fn, sample_batch_size, device=device)

t2 = time.time()
print(f"{str(sampler)} sampling costs {t2-t1}s")

## sample visualization
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_grid.permute(1,2,0).cpu(), vmin=0., vmax=1.)
plt.show()

# Score-based diffusion model on MNIST: experiment results

After training for 50 epochs, the loss decreases to about 16.

ODE sampling: fast, low quality

Euler sampling: slow, middle quality

PC sampling: slowest, high quality