In [None]:
from math import prod

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from torchvision.datasets import MNIST

In [None]:
# MNIST with train=True (the training split); download=True asks torchvision
# to fetch the files into data/.
train_data = MNIST(
    root="data/",
    train=True,
    download=True,
)

In [None]:
class NoiseScheduler:
    """Forward-diffusion noise schedule with linearly increasing variance.

    The per-step variance ``beta_t`` grows linearly with ``t`` from
    ``beta_min`` (at ``t = 0``) to ``beta_max`` (at ``t = T``).

    Args:
        T: Number of diffusion steps spanned by the schedule.
        beta_min: Variance at ``t = 0``.
        beta_max: Variance at ``t = T``.
    """

    def __init__(self, T=1000, beta_min=1e-4, beta_max=0.02):
        self.T = T
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_beta_t(self, t):
        """Return the noise variance ``beta_t`` at step ``t``.

        Values of ``t`` outside ``[0, T]`` are not rejected; they simply
        extrapolate the same straight line.
        """
        # Linear noise-scaling
        fraction = t / self.T
        return self.beta_min + fraction * (self.beta_max - self.beta_min)

    def get_alpha_t(self, t):
        """Return ``alpha_t = 1 - beta_t`` for step ``t``."""
        return 1.0 - self.get_beta_t(t)

    def get_alpha_hat_t(self, t):
        """Return the cumulative product ``alpha_1 * alpha_2 * ... * alpha_t``.

        ``t`` must be an integer. For ``t <= 0`` the product is empty and the
        result is ``1``. The product is recomputed on every call, so the cost
        is linear in ``t``.
        """
        return prod(self.get_alpha_t(s) for s in range(1, t + 1))

    def add_noise(self, x_o, t, rng=None):
        """Sample a noised version of ``x_o`` at step ``t`` in one shot.

        Computes ``sqrt(alpha_hat_t) * x_o + noise`` where ``noise`` is
        zero-mean Gaussian with standard deviation ``sqrt(1 - alpha_hat_t)``.

        Args:
            x_o: Clean input array (anything exposing ``.shape`` that supports
                scalar multiplication and array addition).
            t: Integer diffusion step. ``t == 0`` returns ``x_o`` itself,
                not a copy.
            rng: Optional random source exposing ``normal(loc, scale, size)``,
                e.g. ``np.random.default_rng(seed)``. When omitted, the global
                ``np.random`` state is used, exactly as before.

        Returns:
            The noised array, with the same shape as ``x_o``.
        """
        if t == 0:
            return x_o

        # Cumulative product of the alphas, not the single-step alpha_t.
        alpha_hat = self.get_alpha_hat_t(t)

        # Fall back to the module-level generator so existing callers (and
        # np.random.seed) keep behaving exactly as they did.
        sampler = np.random if rng is None else rng
        noise = sampler.normal(
            loc=0, scale=np.sqrt(1.0 - alpha_hat), size=x_o.shape
        )

        return np.sqrt(alpha_hat) * x_o + noise

In [None]:
class Scaling:
    """Conversions between 8-bit pixel values and the ``[-1, 1]`` float range."""

    @staticmethod
    def transform(img):
        """Map pixel values from ``[0, 255]`` to floats in ``[-1, 1]``.

        Args:
            img: Anything ``np.asarray`` accepts (e.g. a PIL image or array)
                holding 8-bit pixel values.

        Returns:
            A float array of the same shape, where 0 maps to -1 and 255 to 1.
        """
        t_img = np.asarray(img) / 255
        t_img = t_img * 2 - 1
        return t_img

    @staticmethod
    def inverse_transform(output):
        """Map floats nominally in ``[-1, 1]`` back to an 8-bit PIL image.

        Values outside ``[-1, 1]`` (expected once noise has been added) are
        saturated to black/white instead of wrapping around in the ``uint8``
        cast, and values are rounded to the nearest grey level rather than
        truncated so that ``inverse_transform(transform(img))`` reproduces
        ``img`` exactly.

        Args:
            output: Float array of pixel values on the ``[-1, 1]`` scale.

        Returns:
            A ``PIL.Image.Image`` built from the resulting ``uint8`` array.
        """
        pixels = (output + 1) / 2
        pixels = pixels * 255
        # Round first, then clip: an unclipped float -> uint8 cast wraps (or is
        # platform-dependent) for anything outside [0, 255].
        pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
        noise_img = Image.fromarray(pixels)
        return noise_img


# Diffusion steps at which the noised image is displayed.
Ts = [0, 1, 5, 50, 500]
scheduler = NoiseScheduler()
# First field of the first dataset sample.
img = train_data[0][0]

# A single row of panels, one per entry in Ts.
f, axarr = plt.subplots(nrows=1, ncols=len(Ts), figsize=(12, 6))

# The rescaled clean image is identical for every panel, so compute it once.
x_0 = Scaling.transform(img)

for idx, t in enumerate(Ts):
    ax = axarr[idx]
    noisy = scheduler.add_noise(x_0, t)
    ax.imshow(Scaling.inverse_transform(noisy))
    ax.set_title(f"t={t}")