In [1]:
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms

# Get data

In [2]:
# Mini-batch size used by both the training and the test loader.
batch_size = 64

# ToTensor yields floats in [0, 1]; Normalize with mean 0.5 / std 0.5
# then rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both MNIST splits are fetched into ./data on first use and share the
# same preprocessing pipeline.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Only the training loader shuffles; evaluation order stays fixed.
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Sanity check: shape of the image tensor in one training batch.
next(iter(trainloader))[0].shape

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:04<00:00, 2289354.05it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 278435.09it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2338590.87it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1344237.14it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






torch.Size([64, 1, 28, 28])

# Noise Scheduler

betas = $\beta_t$, linearly spaced between `beta_start` and `beta_end` (one value per timestep)

alphas = 1 - $\beta$ 

alphas_cum = $\overline{\alpha_t}$ = $\prod_{i=1}^t \alpha_i$

sqrt_alpha_cum = $\sqrt{\overline{\alpha_t}}$

sqrt_1_minus_alpha_cum = $\sqrt{1 -\overline{\alpha_t}}$

In [3]:
class NoiseScheduler:
    """Linear-beta noise schedule for a DDPM-style diffusion process.

    Precomputes one value per timestep for each schedule term used by the
    forward (noising) and reverse (denoising) steps:

    * ``betas``                  -- linearly spaced from ``beta_start`` to ``beta_end``
    * ``alphas``                 -- ``1 - betas``
    * ``alphas_cum``             -- cumulative product of ``alphas``
    * ``sqrt_alpha_cum``         -- ``sqrt(alphas_cum)``
    * ``sqrt_1_minus_alpha_cum`` -- ``sqrt(1 - alphas_cum)``

    Args:
        t: Number of diffusion timesteps (length of every schedule tensor).
        beta_start: First value of the beta schedule.
        beta_end: Last value of the beta schedule.
    """

    def __init__(self, t, beta_start, beta_end):
        self.t = t
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, t)
        self.alphas = 1 - self.betas
        self.alphas_cum = torch.cumprod(self.alphas, 0)
        self.sqrt_alpha_cum = torch.sqrt(self.alphas_cum)
        self.sqrt_1_minus_alpha_cum = torch.sqrt(1 - self.alphas_cum)

    def _per_sample(self, schedule, t, batch):
        """Look up ``schedule[t]`` and shape it to broadcast against ``batch``."""
        batch_size = batch.shape[0]
        # The schedule tensors are created on the CPU; move to the batch's
        # device before indexing so a GPU batch / GPU ``t`` works too.
        coef = schedule.to(batch.device)[t].reshape(batch_size)
        # One trailing singleton axis per non-batch dimension:
        # (B,) -> (B, 1, ..., 1).
        return coef.reshape(batch_size, *([1] * (batch.dim() - 1)))

    def add_noise(self, batch, noise, t):
        """Forward diffusion: noise ``batch`` to the per-sample timesteps ``t``.

        Args:
            batch: Clean samples; the first dimension is the batch dimension.
            noise: Noise tensor broadcastable against ``batch``.
            t: Integer timestep indices, one per sample in ``batch``.

        Returns:
            ``sqrt_alpha_cum[t] * batch + sqrt_1_minus_alpha_cum[t] * noise``,
            with the coefficients broadcast over the non-batch dimensions.
        """
        signal_scale = self._per_sample(self.sqrt_alpha_cum, t, batch)
        noise_scale = self._per_sample(self.sqrt_1_minus_alpha_cum, t, batch)
        return signal_scale * batch + noise_scale * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Reverse diffusion: take one step from timestep ``t`` to ``t - 1``.

        Args:
            xt: Noisy sample at timestep ``t``.
            noise_pred: Predicted noise for ``xt``.
            t: Current timestep index (a single timestep for the whole call).

        Returns:
            A tuple ``(x_prev, x0)``. ``x_prev`` is the posterior mean plus
            scaled Gaussian noise, or the bare mean when ``t == 0``. ``x0``
            is the estimate of the clean sample, clamped to ``[-1, 1]``.
        """
        device = xt.device
        beta_t = self.betas.to(device)[t]
        alpha_t = self.alphas.to(device)[t]
        alpha_cum_t = self.alphas_cum.to(device)[t]
        sqrt_alpha_cum_t = self.sqrt_alpha_cum.to(device)[t]
        sqrt_1_minus_alpha_cum_t = self.sqrt_1_minus_alpha_cum.to(device)[t]

        # Invert the forward equation to estimate the clean sample.
        x0 = (xt - (sqrt_1_minus_alpha_cum_t * noise_pred)) / sqrt_alpha_cum_t
        x0 = torch.clamp(x0, -1, 1)

        mean = xt - ((beta_t * noise_pred) / sqrt_1_minus_alpha_cum_t)
        mean = mean / torch.sqrt(alpha_t)

        # The final step is deterministic: no noise is added at t == 0.
        if t == 0:
            return mean, x0

        # Posterior variance: (1 - alphas_cum[t-1]) / (1 - alphas_cum[t]) * beta_t.
        alpha_cum_prev = self.alphas_cum.to(device)[t - 1]
        variance = (1 - alpha_cum_prev) / (1 - alpha_cum_t)
        variance = variance * beta_t
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(device)

        return mean + sigma * z, x0
      

