# Homework 1: Score Matching and Langevin Dynamics

In the first part of this homework, we implement Langevin Dynamics on a simple 2D Gaussian Mixture Model (GMM).

You may not use `torch.distributions` anywhere in this homework, nor any other library that performs similar calculations, such as `scipy.stats`.

A GPU only helps in the final section (image generation). Everything else runs in less than a second, so we recommend running the earlier parts of the homework on a local CPU.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import numpy as np
import random
import matplotlib.pyplot as plt

## Part 1: Computing the density and score of a Gaussian Mixture Model

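A Gaussian Mixture Model is a weighted combination of several Gaussian distributions. Here we use a mixture of two isotropic Gaussians, where a latent label $Y$ selects the component:

$$
X \mid Y = 1 \sim \mathcal{N}(\mu_1, \sigma_1^2 I), \qquad X \mid Y = 0 \sim \mathcal{N}(\mu_2, \sigma_2^2 I),
$$

where $Y \sim \text{Ber}(p)$ and $\sigma_1, \sigma_2$ are standard deviations. The resulting density is $p_X(x) = p\,\mathcal{N}(x; \mu_1, \sigma_1^2 I) + (1 - p)\,\mathcal{N}(x; \mu_2, \sigma_2^2 I)$. In our case, $X \in \mathbb{R}^2$.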

Implement the following functions to compute the density and the score (the gradient of the log-density with respect to $x$) of the GMM. You may add any helper functions you need. Compute the score analytically rather than with torch's autograd, which would greatly slow it down.
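Since autograd is off the table, one option is the closed-form score. Writing $w_1(x)$ for the posterior probability that $x$ came from the first component, the score of the mixture density $p_X$ is $\nabla_x \log p_X(x) = -w_1(x)\,\frac{x - \mu_1}{\sigma_1^2} - (1 - w_1(x))\,\frac{x - \mu_2}{\sigma_2^2}$. The sketch below assumes isotropic components with scalar standard deviations, as in the next cell, and computes the weights in log space (`torch.logsumexp`, `torch.softmax`) to avoid underflow far from both means. All function names here are hypothetical, chosen so they do not clash with the functions you are asked to implement.

```python
import math
import torch


def log_gaussian_2d(X, mu, sigma):
    """Log-density of N(mu, sigma^2 I) in 2D at each row of X: (N, 2) -> (N,)."""
    var = torch.as_tensor(sigma, dtype=X.dtype) ** 2
    sq_dist = ((X - mu) ** 2).sum(dim=1)
    return -sq_dist / (2 * var) - torch.log(2 * math.pi * var)


def gmm_log_joint(X, mu_1, mu_2, sigma_1, sigma_2, p):
    """Per-component log(prior * likelihood): tensor of shape (N, 2)."""
    p = torch.as_tensor(p, dtype=X.dtype)
    return torch.stack([
        torch.log(p) + log_gaussian_2d(X, mu_1, sigma_1),
        torch.log1p(-p) + log_gaussian_2d(X, mu_2, sigma_2),
    ], dim=1)


def density_gmm_sketch(X, mu_1, mu_2, sigma_1, sigma_2, p):
    # p_X(x) = sum_k exp(log prior_k + log N_k(x)), computed stably with logsumexp
    return torch.logsumexp(gmm_log_joint(X, mu_1, mu_2, sigma_1, sigma_2, p), dim=1).exp()


def score_gmm_sketch(X, mu_1, mu_2, sigma_1, sigma_2, p):
    # Responsibilities w_k(x) = prior_k N_k(x) / p_X(x), shape (N, 2)
    w = torch.softmax(gmm_log_joint(X, mu_1, mu_2, sigma_1, sigma_2, p), dim=1)
    # Score of each Gaussian component: -(x - mu_k) / sigma_k^2
    s_1 = -(X - mu_1) / sigma_1 ** 2
    s_2 = -(X - mu_2) / sigma_2 ** 2
    return w[:, :1] * s_1 + w[:, 1:] * s_2
```

The softmax weights sum to one per point, so the score is a convex combination of the two component scores: near $\mu_1$ it points toward $\mu_1$, and between the modes it bends toward whichever component is more likely.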

In [None]:
mu_1 = torch.tensor([1., 1.])
mu_2 = torch.tensor([-1., -1.])

sigma_1 = torch.sqrt(torch.tensor(0.1))
sigma_2 = torch.sqrt(torch.tensor(0.1))

p_first = torch.tensor(0.15)  # Probability of sampling from the first Gaussian (p in the formula above).

In [None]:
def density_gmm(X, mu_1, mu_2, sigma_1, sigma_2, p):
    """
    :param X: input data: tensor of shape (N, 2)
    :param mu_1: mean of the first Gaussian: tensor of shape (2,)
    :param mu_2: mean of the second Gaussian: tensor of shape (2,)
    :param sigma_1: standard deviation of the first Gaussian: scalar tensor
    :param sigma_2: standard deviation of the second Gaussian: scalar tensor
    :param p: probability of the first Gaussian: scalar tensor
    :return: density of the Gaussian Mixture Model at each row of X: tensor of shape (N,)
    """
    pass

In [None]:
def score_gmm(X, mu_1, mu_2, sigma_1, sigma_2, p):
    """
    :param X: input data: tensor of shape (N, 2)
    :param mu_1: mean of the first Gaussian: tensor of shape (2,)
    :param mu_2: mean of the second Gaussian: tensor of shape (2,)
    :param sigma_1: standard deviation of the first Gaussian: scalar tensor
    :param sigma_2: standard deviation of the second Gaussian: scalar tensor
    :param p: probability of the first Gaussian: scalar tensor
    :return: score of the Gaussian Mixture Model at each row of X: tensor of shape (N, 2)
    pass

### Visualization

Using a heatmap (with plt.imshow) and an arrow plot (plt.quiver), generate a plot of the density function and score.
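A minimal sketch of steps 1-4 below, assuming `torch.meshgrid(..., indexing='xy')` (the `indexing` argument exists in recent PyTorch versions). With this indexing, the row index follows $y$ and the column index follows $x$, which is the layout `plt.imshow(..., origin='lower')` expects. The helpers `make_grid` and `eval_on_grid` are hypothetical; pass your `density_gmm` or `score_gmm` with the other arguments bound (e.g. via a lambda) as `fn`.

```python
import torch


def make_grid(lo, hi, n):
    """Return an (n, n, 2) grid of 2D points with grid[i, j] = (x_j, y_i)."""
    coords = torch.linspace(lo, hi, n)
    # indexing='xy': grid_x varies along columns, grid_y along rows,
    # so row 0 holds the smallest y (the bottom row with origin='lower').
    grid_x, grid_y = torch.meshgrid(coords, coords, indexing='xy')
    return torch.stack([grid_x, grid_y], dim=-1)


def eval_on_grid(fn, grid):
    """Apply fn, mapping (N, 2) points to (N,) or (N, 2) values, at every grid point."""
    n_rows, n_cols, _ = grid.shape
    values = fn(grid.reshape(-1, 2))
    return values.reshape(n_rows, n_cols, *values.shape[1:])
```

For example, `score_grid = make_grid(-1.5, 1.5, SCORE_GRID_SIZE)` and `score_values = eval_on_grid(lambda pts: score_gmm(pts, mu_1, mu_2, sigma_1, sigma_2, p_first), score_grid)` give tensors that `plt.quiver` accepts directly.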

In [None]:
# Steps:
# 1. Create a grid from -1.5 to 1.5 in both dimensions, with 100 points in each dimension (using torch.linspace and torch.meshgrid with indexing='xy', so that rows index y as plt.imshow with origin='lower' expects)
# 2. Compute the density of the Gaussian Mixture Model at each point on the grid and store it in a tensor of shape (100, 100)
# 3. For the quiver plot, create a grid from -1.5 to 1.5 in both dimensions, with 20 points in each dimension
# 4. Compute the score of the Gaussian Mixture Model at each point on the grid and store it in a tensor of shape (20, 20, 2)
# 5. Create a quiver plot with the grid points as the X and Y coordinates and the score as the U and V components
# 6. Create a heatmap of the density, with the grid points as the X and Y coordinates and the density as the color, becoming darker as the density is larger

DENSITY_GRID_SIZE = 100
SCORE_GRID_SIZE = 20

density_grid = torch.zeros(DENSITY_GRID_SIZE, DENSITY_GRID_SIZE, 2)  # Replace this with the grid coordinates, using torch.linspace and torch.meshgrid
score_grid = torch.zeros(SCORE_GRID_SIZE, SCORE_GRID_SIZE, 2)  # Replace this with the grid coordinates, using torch.linspace and torch.meshgrid
density_values = torch.zeros(DENSITY_GRID_SIZE, DENSITY_GRID_SIZE)  # Fill this with the density values
score_values = torch.zeros(SCORE_GRID_SIZE, SCORE_GRID_SIZE, 2)  # Fill this with the score values

#### ADD YOUR CODE HERE ####



#### DONE WITH SOLUTION ####

plt.imshow(density_values, extent=(-1.5, 1.5, -1.5, 1.5), origin='lower', cmap='Oranges', alpha=0.6)
plt.quiver(score_grid[:, :, 0], score_grid[:, :, 1], score_values[:, :, 0], score_values[:, :, 1], color='black')
plt.title('Score Values and Density Heatmap of Gaussian Mixture Model')
plt.show()

Example output for a standard Gaussian:

![image.png](attachment:image.png)

## Part 2: Training a score network

Implement a model that takes a point $x \in \mathbb{R}^2$ and predicts the score of the GMM at that point.

In [None]:
class ScoreMatcher(nn.Module):
    pass

In [None]:
def train(mu_1, mu_2, sigma_1, sigma_2, p, train_size, num_epochs, lr):
    """ Return a trained ScoreMatcher model """
    pass

Print the loss at each epoch, and make sure the final loss is lower than 0.01.
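A minimal sketch of explicit score matching, assuming a small MLP. The architecture and the names `ScoreMLPSketch` and `train_explicit_sm_sketch` are hypothetical choices, not a required design. Because the true score is available for the GMM, the loss is a plain squared error between the network output and the target score, and training here is full-batch with Adam.

```python
import torch
import torch.nn as nn


class ScoreMLPSketch(nn.Module):
    """Hypothetical score network: a small MLP mapping a 2D point to a 2D score."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, x):
        return self.net(x)


def train_explicit_sm_sketch(model, x, target_score, num_epochs, lr, verbose=True):
    """Full-batch explicit score matching: regress model(x) onto the known score."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        # Squared error summed over the 2 coordinates, averaged over the points
        loss = ((model(x) - target_score) ** 2).sum(dim=1).mean()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if verbose:
            print(f'epoch {epoch + 1}: loss {loss.item():.5f}')
    return losses
```

In the notebook, `x` would be `train_size` points sampled from the GMM and `target_score` the output of `score_gmm` on those points. How the 0.01 threshold maps onto this loss depends on whether you sum or average over the two coordinates.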

In [None]:
model = train(mu_1, mu_2, sigma_1, sigma_2, p_first, 1000, 1000, 0.01)

Create the same visualization as above, replacing the true score with the score predicted by the model.

In [None]:
density_grid = None
score_grid = None
density_values = None
score_values = None

#### ADD YOUR CODE HERE ####


#### DONE WITH SOLUTION ####

plt.imshow(density_values, extent=(-1.5, 1.5, -1.5, 1.5), origin='lower', cmap='Oranges', alpha=0.6)
plt.quiver(score_grid[:, :, 0], score_grid[:, :, 1], score_values[:, :, 0], score_values[:, :, 1], color='black')
plt.title('Predicted Score Values of Gaussian Mixture Model')
plt.show()

Compare the visualizations and explain what could be causing the differences.

## Part 3: Sampling

In this section we sample from our distribution using three methods:
1. Direct sampling with torch functions.
2. Langevin Dynamics with the score model we trained.
3. Annealed Langevin Dynamics with the true score function.

Sample 1000 points directly from the GMM and display them in a scatter plot.
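Direct (ancestral) sampling needs only `torch.rand` and `torch.randn`, so it stays within the ban on `torch.distributions`: draw the label $Y$ by comparing a uniform sample with $p$, then shift and scale standard normal noise by the chosen component's mean and standard deviation. `sample_gmm_sketch` is a hypothetical helper.

```python
import torch


def sample_gmm_sketch(n, mu_1, mu_2, sigma_1, sigma_2, p):
    """Ancestral sampling: draw Y ~ Ber(p), then X | Y from the chosen Gaussian."""
    y = (torch.rand(n) < p).unsqueeze(1)  # (n, 1), True -> first component
    z = torch.randn(n, 2)                 # standard normal noise, shared by both branches
    x_first = mu_1 + sigma_1 * z
    x_second = mu_2 + sigma_2 * z
    return torch.where(y, x_first, x_second)  # the (n, 1) mask broadcasts over both coordinates
```

`plt.scatter(x[:, 0], x[:, 1], s=2)` then displays the samples.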

In [None]:
N = 1000

# YOUR CODE HERE


Sample 1000 points with Langevin Dynamics, using the trained score model.
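A minimal sketch of (unadjusted) Langevin dynamics, assuming the update $x_{t+1} = x_t + \frac{\eta}{2}\nabla_x \log p(x_t) + \sqrt{\eta}\,z_t$ with $z_t \sim \mathcal{N}(0, I)$, the convention used by Song and Ermon. Some texts use $\eta\nabla_x \log p + \sqrt{2\eta}\,z$ instead, which only rescales the step size. `score_fn` is any callable mapping an (N, 2) tensor to (N, 2) scores, such as the trained model; `langevin_sketch` is a hypothetical name.

```python
import torch


@torch.no_grad()
def langevin_sketch(score_fn, x0, step_size, n_steps):
    """Run n_steps of unadjusted Langevin dynamics starting from x0."""
    x = x0.clone()
    for _ in range(n_steps):
        z = torch.randn_like(x)
        x = x + 0.5 * step_size * score_fn(x) + step_size ** 0.5 * z
    return x
```

In the notebook this could be called as `langevin_sketch(model, torch.rand(1000, 2) * 3 - 1.5, LANGEVIN_STEP_SIZE, LANGEVIN_NUM_STEPS)`; starting uniformly over the plotted square is an assumption, not a requirement.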

In [None]:
LANGEVIN_STEP_SIZE = 0.01
LANGEVIN_NUM_STEPS = 10000

### YOUR CODE HERE ###


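Sample 1000 points using Annealed Langevin Dynamics with the true score function, and plot the samples after each noise scale.
As the noise gets smaller, the step size of Langevin Dynamics should decrease as well, so update it according to the following formula, where $\sigma_T$ is the last (smallest) noise scale:

$$\eta_t = \eta_\text{original} \cdot \frac{\sigma_t^2}{\sigma_T^2}$$

Because the components are Gaussian, the noised distribution $p_\sigma$ (the GMM convolved with $\mathcal{N}(0, \sigma^2 I)$) is again a GMM with the same means and weights, where each component's variance grows from $\sigma_k^2$ to $\sigma_k^2 + \sigma^2$. Its true score is therefore the GMM score with each $\sigma_k$ replaced by $\sqrt{\sigma_k^2 + \sigma^2}$.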
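A sketch of the annealed loop, assuming the same Langevin update convention ($x \leftarrow x + \frac{\eta_t}{2}\nabla_x \log p_{\sigma_t}(x) + \sqrt{\eta_t}\,z$) and a hypothetical `score_fn(x, sigma)` that returns the score of the noised distribution. The step size is rescaled per noise scale with the formula above, and a snapshot is kept after each scale for plotting.

```python
import torch


@torch.no_grad()
def annealed_langevin_sketch(score_fn, x0, noise_scales, base_step, n_steps_each):
    """Annealed Langevin dynamics over decreasing noise scales.

    score_fn(x, sigma) returns the score of the data distribution convolved with N(0, sigma^2 I).
    Returns the final samples and a snapshot taken after each noise scale.
    """
    x = x0.clone()
    sigma_last = noise_scales[-1]  # smallest noise scale, sigma_T in the formula
    snapshots = []
    for sigma in noise_scales:
        step = base_step * (sigma / sigma_last) ** 2  # eta_t = eta * sigma_t^2 / sigma_T^2
        for _ in range(n_steps_each):
            z = torch.randn_like(x)
            x = x + 0.5 * step * score_fn(x, sigma) + step ** 0.5 * z
        snapshots.append(x.clone())
    return x, snapshots
```

For the GMM, `score_fn` could be `lambda x, s: score_gmm(x, mu_1, mu_2, torch.sqrt(sigma_1**2 + s**2), torch.sqrt(sigma_2**2 + s**2), p_first)`, following the variance argument above. With the given `NOISE_SCALES` (2.0 down to 0.01) and `LANGEVIN_STEP_SIZE = 1e-5`, the step size starts at $10^{-5} \cdot (2 / 0.01)^2 = 0.4$ and ends at $10^{-5}$. The number of steps per scale is not fixed by the assignment, so `n_steps_each` is yours to choose.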

In [None]:
NOISE_SCALES = torch.logspace(torch.log10(torch.tensor(2.0)), torch.log10(torch.tensor(0.01)), 10)
LANGEVIN_STEP_SIZE = 1e-5
print(f'Noise scales, from largest to smallest (geometric sequence): {NOISE_SCALES}')

### YOUR CODE HERE ###



## Putting it all together!

In this section you will use everything you have learned to generate MNIST images. MNIST is provided in your virtual machines and is also available through torchvision. The ideas are based on the paper Generative Modeling by Estimating Gradients of the Data Distribution (Song and Ermon, 2019), and you may use the paper's official implementation as a reference. We recommend going over the first tutorial before starting.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 1 - Denoising Score Matching

When the true score cannot be computed directly (unlike the GMM above), we cannot simply minimize the MSE between the model output and the true score. Instead, we use the denoising score matching loss, adapted to our annealed noise schedule:

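$$\ell_\text{denoise}(\theta; \sigma) := \frac{1}{2}\mathbb{E}_{x\sim p_\text{data}} \mathbb{E}_{\tilde{x}\sim q_\sigma(\cdot | x)} \left[ \Vert s_\theta(\tilde{x}, \sigma) - \nabla_{\tilde{x}} \log q_\sigma(\tilde{x}|x)\Vert_2^2 \right]$$
$$L_\text{denoise}(\theta) = \frac{1}{L}\sum_{i=1}^L \lambda(\sigma_i)\,\ell_\text{denoise}(\theta; \sigma_i) = \mathbb{E}_{i \sim \text{Uni}\{1, \dots, L\}} \left[ \sigma_i^2\,\ell_\text{denoise}(\theta; \sigma_i) \right]$$

where $s_\theta$ is the model, $\sigma_1 > \sigma_2 > \dots > \sigma_L$ are the noise scales, and the weighting is $\lambda(\sigma_i) = \sigma_i^2$. With Gaussian noise, $q_\sigma(\tilde{x}|x) = \mathcal{N}(\tilde{x}; x, \sigma^2 I)$, so the regression target is $\nabla_{\tilde{x}} \log q_\sigma(\tilde{x}|x) = -(\tilde{x} - x)/\sigma^2$.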
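A minimal sketch of a one-sample Monte Carlo estimate of $L_\text{denoise}$, assuming the weighting $\lambda(\sigma) = \sigma^2$, an independent noise-scale index per example, and a model called as `model(x_tilde, idx)` like the RefineNet below. Since $\tilde{x} = x + \sigma z$ with $z \sim \mathcal{N}(0, I)$, the target $-(\tilde{x} - x)/\sigma^2$ equals $-z/\sigma$, and $\sigma^2 \Vert s_\theta + z/\sigma\Vert^2 = \Vert \sigma s_\theta + z \Vert^2$, which keeps every noise scale on a similar footing. The function name is hypothetical, and `noise_scales` must be on the same device as `x`.

```python
import torch


def denoising_score_loss_sketch(model, x, noise_scales):
    """One-sample Monte Carlo estimate of L_denoise with lambda(sigma) = sigma^2.

    model(x_tilde, idx) is assumed to return a tensor shaped like x_tilde, where idx
    holds the noise-scale index of each example.
    """
    b = x.shape[0]
    # i ~ Uniform{0, ..., L-1}, drawn independently for every example in the batch
    idx = torch.randint(0, len(noise_scales), (b,), device=x.device)
    sigma = noise_scales[idx].view(b, *([1] * (x.dim() - 1)))  # broadcastable to x
    noise = torch.randn_like(x)
    x_tilde = x + sigma * noise  # x_tilde ~ q_sigma(. | x) = N(x, sigma^2 I)
    target = -noise / sigma      # grad log q_sigma(x_tilde | x) = -(x_tilde - x) / sigma^2
    pred = model(x_tilde, idx)
    per_example = 0.5 * (sigma ** 2 * (pred - target) ** 2).flatten(1).sum(dim=1)
    return per_example.mean()
```

Summing over pixels matches the squared norm in the formula; averaging instead only rescales the reported loss.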

In [None]:
def denoising_score_loss(model, x, noise_scales):
    """
    model: The score matching model.
    x: The input data [BATCH_SIZE, WIDTH, HEIGHT]
    noise_scales: The noise scales to use for denoising [NUM_NOISE_SCALES] (sigma_1, sigma_2, ..., sigma_L)
    """

    ### YOUR CODE HERE ###
    pass

### 2 - UNet

A natural choice for a model that receives an image and returns an output with the same dimensions as the image is a UNet. Below is a UNet-style architecture (a simplified RefineNet, the architecture family used in the paper) already implemented for you. You do not need to implement or change anything in this section; if you do make changes, please describe them below.
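The noise level enters the network only through `CondInstanceNorm`, which implements the paper's conditional instance normalization++. Reading off its `forward` method: for noise-scale index $k$ and channel $c$ with spatial mean $\mu_c$ and standard deviation $s_c$, and with $m$ and $v$ the mean and standard deviation of $\{\mu_c\}$ across channels (both standard deviations include `eps`, and the code sets $v = 1$ when there is a single channel),

$$
z_c = \gamma_{k,c}\,\frac{x_c - \mu_c}{s_c} + \beta_{k,c} + \alpha_{k,c}\,\frac{\mu_c - m}{v}
$$

so each of `gamma`, `beta`, and `alpha` stores one learned row per noise scale, selected by `noise_scale_idx`.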

In [None]:
class CondInstanceNorm(nn.Module):
    def __init__(self, in_channels, n_noise_scale=10, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(n_noise_scale, in_channels))
        self.beta = nn.Parameter(torch.zeros(n_noise_scale, in_channels))
        self.alpha = nn.Parameter(torch.zeros(n_noise_scale, in_channels))
        self.eps = eps

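    def forward(self, x, noise_scale_idx):
        bsz = x.shape[0]
        if isinstance(noise_scale_idx, int):
            # Allow a single integer index for the whole batch by expanding it to shape (bsz,)
            noise_scale_idx = torch.full((bsz,), noise_scale_idx, dtype=torch.long, device=x.device)
        gamma = self.gamma[noise_scale_idx].view(bsz, -1, 1, 1)  # (bsz, in_channels, 1, 1)
        beta = self.beta[noise_scale_idx].view(bsz, -1, 1, 1)
        alpha = self.alpha[noise_scale_idx].view(bsz, -1, 1, 1)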


        mu = x.mean(dim=(2, 3), keepdim=True)  # (batch_size, in_channels, 1, 1)
        var = x.var(dim=(2, 3), keepdim=True)  # (batch_size, in_channels, 1, 1)
        sigma = torch.sqrt(var + self.eps)  # (batch_size, in_channels, 1, 1)
        

        x = (x - mu) / sigma  # (batch_size, in_channels, height, width)
        x = gamma * x + beta  # (batch_size, in_channels, height, width)

        m = mu.mean(dim=1, keepdim=True)  # (batch_size, 1, 1, 1)
        if mu.shape[1] == 1:
            s = torch.ones_like(mu)
        else:
            v = mu.var(dim=1, keepdim=True)  # (batch_size, 1, 1, 1)
            s = torch.sqrt(v + self.eps)  # (batch_size, 1, 1, 1)

        x = x + alpha * (mu - m) / s  # (batch_size, in_channels, height, width)
        return x


class ResidualConvUnit(nn.Module):
    def __init__(self, channels, norm=True, kernel_size=3, n_noise_scale=10):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding='same')
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding='same')
        self.norm1 = CondInstanceNorm(channels, n_noise_scale) if norm else None
        self.norm2 = CondInstanceNorm(channels, n_noise_scale) if norm else None
        self.act = nn.ELU()

    def forward(self, x, noise_scale_idx):
        # x: (batch_size, in_channels, height, width)
        h = self.norm1(x, noise_scale_idx) if self.norm1 is not None else x
        h = self.act(h)
        h = self.conv1(h)

        h = self.norm2(h, noise_scale_idx) if self.norm2 is not None else h
        h = self.act(h)
        h = self.conv2(h)
        
        return x + h
    
    

class StridedConvUnit(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, n_noise_scale=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding='same')
        self.norm1 = CondInstanceNorm(in_channels, n_noise_scale)
        self.norm2 = CondInstanceNorm(out_channels, n_noise_scale)
        self.act = nn.ELU()
    
    def forward(self, x, noise_scale_idx):
        # x: (batch_size, in_channels, height, width)

        h = self.norm1(x, noise_scale_idx)
        h = self.act(h)
        h = self.conv1(h)

        h = self.norm2(h, noise_scale_idx)
        h = self.act(h)
        h = self.conv2(h)
        
        return h


class DilatedConvUnit(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=2, n_noise_scale=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation, padding=dilation)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding='same')
        self.norm1 = CondInstanceNorm(in_channels, n_noise_scale)
        self.norm2 = CondInstanceNorm(out_channels, n_noise_scale)
        self.act = nn.ELU()

    def forward(self, x, noise_scale_idx):
        # x: (batch_size, in_channels, height, width)

        h = self.norm1(x, noise_scale_idx)
        h = self.act(h)
        h = self.conv1(h)

        h = self.norm2(h, noise_scale_idx)
        h = self.act(h)
        h = self.conv2(h)
        
        return h
    


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_layers=2, downsample='stride', dilation=2, n_noise_scale=10):
        assert downsample in ['stride', 'dilation']
        super().__init__()
        self.downsample = downsample
        self.main = nn.ModuleList([])
        for _ in range(n_layers):
            self.main.append(ResidualConvUnit(in_channels, n_noise_scale=n_noise_scale))
        
        if downsample == 'stride':
            self.main.append(StridedConvUnit(in_channels, out_channels, n_noise_scale=n_noise_scale))
        elif downsample == 'dilation':
            self.main.append(DilatedConvUnit(in_channels, out_channels, dilation=dilation, n_noise_scale=n_noise_scale))

    def forward(self, x, noise_scale_idx):
        # x: (batch_size, in_channels, height, width)

        for layer in self.main:
            x = layer(x, noise_scale_idx)
        
        return x
    

class AdaptiveConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_noise_scale=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = ResidualConvUnit(out_channels, n_noise_scale=n_noise_scale)
        self.conv3 = ResidualConvUnit(out_channels, n_noise_scale=n_noise_scale)

    def forward(self, x, noise_scale_idx):
        # x: (batch_size, in_channels, height, width)

        h = self.conv1(x)
        h = self.conv2(h, noise_scale_idx)
        h = self.conv3(h, noise_scale_idx)
        
        return h
    

class MultiResolutionFusion(nn.Module):
    def __init__(self, channels, n_noise_scale=10):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding='same')
        self.norm1 = CondInstanceNorm(channels, n_noise_scale)
        self.norm2 = CondInstanceNorm(channels, n_noise_scale)

    def forward(self, x, y=None, noise_scale_idx=0):
        if y is None:
            return x
        else:
            h1 = self.norm1(x, noise_scale_idx)
            h1 = self.conv1(h1)

            h2 = self.norm2(y, noise_scale_idx)
            h2 = self.conv2(h2)

            return h1 + h2
        

class ResidualPoolingBlock(nn.Module):
    def __init__(self, channels, n_noise_scale=10):
        super().__init__()
        self.norm1 = CondInstanceNorm(channels, n_noise_scale)
        self.pool1 = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.norm2 = CondInstanceNorm(channels, n_noise_scale)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding='same')

    def forward(self, x, noise_scale_idx):
        h = self.norm1(x, noise_scale_idx)
        h = self.pool1(h)
        h = self.norm2(h, noise_scale_idx)
        h = self.conv1(h)

        return h


class ChainedResidualPool(nn.Module):
    def __init__(self, channels, n_noise_scale=10):
        super().__init__()
        self.act = nn.ELU()
        self.pool1 = ResidualPoolingBlock(channels, n_noise_scale=n_noise_scale)
        self.pool2 = ResidualPoolingBlock(channels, n_noise_scale=n_noise_scale)


    def forward(self, x, noise_scale_idx):
        x = self.act(x)
        h = self.pool1(x, noise_scale_idx)
        x = x + h
        h = self.pool2(h, noise_scale_idx)
        x = x + h

        return x

class RefineNetBlock(nn.Module):
    def __init__(self, x1_in, x2_in, channels, n_noise_scale=10):
        super().__init__()
        self.adap_x1 = AdaptiveConvBlock(x1_in, channels, n_noise_scale=n_noise_scale)
        self.adap_x2 = AdaptiveConvBlock(x2_in, channels, n_noise_scale=n_noise_scale)

        self.fusion = MultiResolutionFusion(channels, n_noise_scale=n_noise_scale)
        self.pool = ChainedResidualPool(channels, n_noise_scale=n_noise_scale)

        self.out = ResidualConvUnit(channels, n_noise_scale=n_noise_scale)

    def forward(self, x1, x2=None, noise_scale_idx=0):
        h1 = self.adap_x1(x1, noise_scale_idx)
        h2 = self.adap_x2(x2, noise_scale_idx) if x2 is not None else None
        h = self.fusion(h1, h2, noise_scale_idx)
        h = self.pool(h, noise_scale_idx)
        h = self.out(h, noise_scale_idx)

        return h


class RefineNet(nn.Module):
    def __init__(self, in_channels, hidden_channels=(128, 256), n_noise_scale=10):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, hidden_channels[0], n_layers=2, downsample='stride', n_noise_scale=n_noise_scale)
        self.res2 = ResidualBlock(hidden_channels[0], hidden_channels[1], n_layers=2, downsample='dilation', dilation=2, n_noise_scale=n_noise_scale)
        ### YOU CAN ADD MORE RESIDUAL BLOCKS HERE ###
        
        self.refine1 = RefineNetBlock(x1_in=hidden_channels[-1], x2_in=hidden_channels[-1], channels=hidden_channels[-1], n_noise_scale=n_noise_scale)
        self.refine2 = RefineNetBlock(x1_in=hidden_channels[-2], x2_in=hidden_channels[-1], channels=hidden_channels[-2], n_noise_scale=n_noise_scale)
        ### EVERY RESIDUAL BLOCK SHOULD BE FOLLOWED BY A REFINE BLOCK ###

        self.up_norm = CondInstanceNorm(hidden_channels[-2], n_noise_scale)
        self.up_conv = nn.ConvTranspose2d(hidden_channels[-2], hidden_channels[-2], kernel_size=3, stride=2, padding=1, output_padding=1)
        self.out = AdaptiveConvBlock(hidden_channels[-2], in_channels, n_noise_scale=n_noise_scale)


    def forward(self, x, noise_scale_idx):
        h1 = self.res1(x, noise_scale_idx)
        h2 = self.res2(h1, noise_scale_idx)

        h = self.refine1(h2, x2=None, noise_scale_idx=noise_scale_idx)  # Each refine block fuses a skip connection (x1) with the deeper path (x2); the deepest block has no x2
        h = self.refine2(h1, h, noise_scale_idx)

        h = self.up_norm(h, noise_scale_idx)
        h = self.up_conv(h)
        h = self.out(h, noise_scale_idx)

        return h
    
def q_sample(x, sigma, noise=None):
    # x: (B, C, H, W)
    # sigma: (B, )
    if noise is None:
        noise = torch.randn_like(x)
    while sigma.dim() < x.dim():
        sigma = sigma.unsqueeze(-1)
    return x + sigma * noise
    


### 3 - Training

Train a UNet on MNIST with the loss implemented above. We recommend saving checkpoints (the model's state dict) in case the machine shuts down or crashes; training should not take more than an hour. All hyperparameters are provided below and should work if everything is implemented correctly, but you are free to change them for better performance.
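A generic sketch of the epoch loop with per-epoch checkpointing via `torch.save(model.state_dict(), path)`. `loss_fn(model, x)` is a placeholder; in this notebook it could be `lambda m, x: denoising_score_loss(m, x, NOISE_SCALES)`. The function name and checkpoint path are hypothetical, and the model is assumed to be on its target device before the optimizer is created.

```python
import os
import torch


def train_with_checkpoints_sketch(model, loader, loss_fn, optimizer, n_epochs, ckpt_path):
    """Train for n_epochs, print the mean loss per epoch, and overwrite a checkpoint each epoch."""
    device = next(model.parameters()).device
    history = []
    for epoch in range(n_epochs):
        model.train()
        total, count = 0.0, 0
        for x, _ in loader:  # (images, labels) batches; labels are unused
            x = x.to(device)
            loss = loss_fn(model, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item() * x.shape[0]
            count += x.shape[0]
        history.append(total / count)
        print(f'epoch {epoch + 1}/{n_epochs}: loss {history[-1]:.4f}')
        torch.save(model.state_dict(), ckpt_path)
        assert os.path.exists(ckpt_path)
    return history
```

To resume after a crash, rebuild the model and call `model.load_state_dict(torch.load(ckpt_path))` before continuing.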

In [None]:
N_EPOCHS = 30
BATCH_SIZE = 64
NOISE_SCALES = torch.logspace(torch.log10(torch.tensor(1.0)), torch.log10(torch.tensor(0.01)), 10).to(device)
LR = 1e-3  # Learning rate for the training - not the same as the step size of Langevin Dynamics
NUM_LANGEVIN_STEPS = 100  # Number of steps of Langevin Dynamics for each noise scale.
LANGEVIN_STEP_SIZE = 2e-5  # Base step size in Annealed Langevin Dynamics

In [None]:
# use MNIST as train set
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

In [None]:
### YOUR CODE HERE ###


### 4 - Sampling

Generate images using Annealed Langevin Dynamics with the model you trained above, and visualize the process using the function below or any other method. The images do not need to be perfect, but they must be clearly recognizable as handwritten digits on a black background. Example imperfect output:

![image.png](attachment:image.png)
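A sketch of the sampling loop that could sit inside `sample_refinenet`, under stated assumptions: samples start from uniform noise on $[0, 1]$ (an assumption), the network receives the index `i` of the current noise scale, the step size follows the annealed formula from Part 3, and snapshots are clamped to $[0, 1]$ only for display. The helper name is hypothetical, and `model.eval()` should be called first, as `sample_refinenet` does.

```python
import torch


@torch.no_grad()
def sample_with_model_sketch(model, shape, noise_scales, n_steps, eps, device='cpu'):
    """Annealed Langevin dynamics driven by a noise-conditional model(x, idx).

    Returns one snapshot of the batch after each noise scale, clamped to [0, 1] for display.
    """
    x = torch.rand(shape, device=device)  # uniform initialization (assumption)
    sigma_last = noise_scales[-1]         # smallest noise scale
    snapshots = []
    for i, sigma in enumerate(noise_scales):
        idx = torch.full((shape[0],), i, dtype=torch.long, device=device)
        step = eps * (sigma / sigma_last) ** 2
        for _ in range(n_steps):
            x = x + 0.5 * step * model(x, idx) + step ** 0.5 * torch.randn_like(x)
        snapshots.append(x.clamp(0.0, 1.0).cpu())
    return snapshots
```

Each snapshot can then fill one column of the figure in `sample_refinenet`, e.g. `ax[r, c].imshow(snapshots[c][r, 0], cmap='gray')`.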

In [None]:
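def sample_refinenet(model, shape, noise_scales, device, n_steps, eps, visualize_dynamics_file):
    """Generates images using the model that was trained with score matching.
    model: The trained score model, called as model(x, noise_scale_idx).
    shape: Shape of the generated batch, e.g. (n_images, 1, 28, 28); shape[0] sets the number of rows in the figure.
    noise_scales: Sigmas for annealed LD, from largest to smallest.
    device: Device on which to run the sampling.
    n_steps: Number of LD steps for each sigma.
    eps: Base step size in LD (scaled for each sigma as in the annealed LD formula).
    visualize_dynamics_file: Path of a file in which to save the visualization of the dynamics (the denoising process of the Langevin dynamics).
    """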
    model.eval()
    with torch.no_grad():
        fig, ax = plt.subplots(shape[0], len(noise_scales), figsize=(len(noise_scales)*2, shape[0]*2))
        fig.suptitle('Sigma')

        ### YOUR CODE HERE ###


In [None]:
# Sample from the RefineNet model