# Imports

In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import MultivariateNormal, Uniform
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Introduction to Gaussian distribution

PDF of the Gaussian distribution is defined as:
$$ p(x) = N(x; \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right),
 \qquad x \in \mathbb{R}.$$
 where the mean $\mu$ is the "center" of the distribution and the variance $\sigma^2$ describes how widely the distribution is spread around the mean.

The cool thing about normal distribution is that we can easily change the mean and variance of a random variable. Let's say we have $X \sim N(0, 1)$ and we want to change it to $Y \sim N(\mu, \sigma^2)$. We can do it as follows:
$$ Y = X * \sigma + \mu $$
Note that we're multiplying by standard deviation $\sigma$, not the variance!


In [None]:
def gaussian_pdf(x, mu, sigma):
    """Evaluate the density of N(mu, sigma^2) at ``x``.

    ``sigma`` is the standard deviation, not the variance.  ``x`` may be a
    scalar or a NumPy array; the result follows NumPy broadcasting.
    """
    variance = sigma ** 2
    normaliser = np.sqrt(2 * np.pi * variance)
    exponent = -((x - mu) ** 2) / (2 * variance)
    return 1 / normaliser * np.exp(exponent)

def gaussian_sample(n, mu, sigma):
    """Draw ``n`` samples from N(mu, sigma^2).

    Standard-normal draws X are mapped to Y = sigma * X + mu, where ``sigma``
    is the standard deviation.
    """
    return sigma * np.random.randn(n) + mu

In [None]:
# Evaluate three Gaussians on a shared grid, given as (mu, sigma) pairs:
# the standard normal, a shifted one and a wider one.
x = np.linspace(-4, 4, 1000)
pdf_standard, pdf_shifted, pdf_scaled = (
    gaussian_pdf(x, mu_i, sigma_i) for mu_i, sigma_i in ((0, 1), (-1, 1), (0, 2))
)

plt.figure(figsize=(10, 6))
# The legend entries appear to read "mean, variance" (sigma = 2 is shown as 4).
for curve, legend_entry in (
    (pdf_standard, '0, 1'),
    (pdf_shifted, '-1, 1'),
    (pdf_scaled, '0, 4'),
):
    plt.plot(x, curve, label=legend_entry)
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.show()

In [None]:
# Target distribution N(mu, sigma^2); sigma is the standard deviation.
mu, sigma = 10, 5

# Draw the samples' source density on a grid covering mu +/- 4 sigma.
x = np.linspace(-10, 30, 1000)
pdf = gaussian_pdf(x, mu, sigma)

# Empirical counterpart: 10000 draws obtained by shifting/scaling N(0, 1).
samples = gaussian_sample(10000, mu, sigma)

# Overlay the analytic curve on the normalised histogram of the samples.
plt.figure(figsize=(10, 6))
plt.plot(x, pdf)
plt.hist(samples, bins=50, density=True)
plt.xlabel('x')
plt.ylabel('Density')
plt.show()

# KL divergence
KL divergence between two distributions $p, q$ is defined as
$$ KL(p||q) = \int p(z) \log \frac{p(z)}{q(z)} dz $$
It measures how much two distributions differ from each other. Note that it is not symmetric, so the order of $p$ and $q$ matters!  

If both distributions are Gaussian, the KL divergence has a much simpler closed form:

$$ KL(p||q) = \log\frac{\sigma_q}{\sigma_p} + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2} - \frac12 $$

In [None]:
def kld(mu1, log_var1, mu2, log_var2):
    """KL divergence KL(p || q) between two univariate Gaussians.

    p = N(mu1, exp(log_var1)) and q = N(mu2, exp(log_var2)); both variances
    are passed as log-variances.  Works on scalars and NumPy arrays.
    """
    log_ratio = log_var2 - log_var1
    scaled_spread = (np.exp(log_var1) + (mu1 - mu2) ** 2) / np.exp(log_var2)
    return (log_ratio + scaled_spread - 1) / 2

In [None]:
# p = N(0, 1) and q = N(0, 4), each described by (mean, log-variance).
mu1, log_var1 = 0, np.log(1)
mu2, log_var2 = 0, np.log(4)
p_params = (mu1, log_var1)
q_params = (mu2, log_var2)

# Swapping the arguments shows the asymmetry; KL(p||p) is the zero baseline.
kld_12 = kld(*p_params, *q_params)
kld_21 = kld(*q_params, *p_params)
kld_11 = kld(*p_params, *p_params)

print(f'KL(p||q): {kld_12:.4f}')
print(f'KL(q||p): {kld_21:.4f}')
print(f'KL(p||p): {kld_11:.4f}')

# Generating 2D moons

In [None]:
def generate_moons(width=1.0):
    """Build a 2-D "two moons" point cloud as a float32 tensor of shape (N, 2).

    Each moon is a half-annulus of radii in [5 - width, 5 + width): the first
    spans angles [0, pi) and is shifted by (-2.5, -1.0), the second spans
    [pi, 2*pi) and is shifted by (+2.5, +1.0).  Every point is then perturbed
    by noise drawn with ``torch.rand`` and scaled by ``width``.
    """
    radii = np.arange(5 - width, 5 + width, 0.1 * width)

    def arc(angles, dx, dy):
        # One crescent: every radius swept over every angle, shifted by (dx, dy).
        return [
            [r * np.cos(a) + dx, r * np.sin(a) + dy]
            for r in radii
            for a in angles
        ]

    upper = arc(np.arange(0, np.pi, 0.01), -2.5, -1.0)
    lower = arc(np.arange(np.pi, 2 * np.pi, 0.01), 2.5, 1.0)

    points = torch.tensor(upper + lower)
    jitter = torch.rand(points.shape) * width
    points += jitter
    return points.float()

In [None]:
class InMemDataLoader(object):
    """Minimal DataLoader look-alike that serves batches from in-memory tensors.

    The inputs are copied into a ``TensorDataset`` once; iterating the loader
    indexes that dataset with a whole batch of indices at a time, so each
    batch is a tuple with one tensor per input.

    Args:
        tensors: iterable of tensors (or array-likes accepted by
            ``torch.tensor``) sharing the same first dimension.
        batch_size: number of samples per batch.
        shuffle: draw indices with a ``RandomSampler`` instead of sequentially.
        sampler: custom index sampler; mutually exclusive with ``shuffle``.
        batch_sampler: custom sampler yielding index batches; mutually
            exclusive with ``batch_size``, ``shuffle``, ``sampler`` and
            ``drop_last``.
        drop_last: drop the final incomplete batch.

    Raises:
        ValueError: on mutually exclusive options, or when ``batch_size``,
            ``sampler`` or ``drop_last`` is reassigned after construction.
    """

    __initialized = False

    def __init__(self, tensors, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, drop_last=False):
        """A torch dataloader that fetches data from memory."""
        # Always store private copies.  Existing tensors are cloned explicitly:
        # torch.tensor(existing_tensor) copies as well but emits a UserWarning.
        tensors = [
            tensor.detach().clone() if isinstance(tensor, torch.Tensor)
            else torch.tensor(tensor)
            for tensor in tensors
        ]
        dataset = torch.utils.data.TensorDataset(*tensors)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            # Batch composition is fully delegated to the custom batch sampler.
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.RandomSampler(dataset)
                else:
                    sampler = torch.utils.data.SequentialSampler(dataset)
            batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        # From here on __setattr__ rejects changes to the batching attributes.
        self.__initialized = True

    def __setattr__(self, attr, val):
        """Forbid changing batching-related attributes once construction is done."""
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super().__setattr__(attr, val)

    def __iter__(self):
        """Yield one tuple of tensors per batch of indices."""
        for batch_indices in self.batch_sampler:
            yield self.dataset[batch_indices]

    def __len__(self):
        """Number of batches produced by one pass over the loader."""
        return len(self.batch_sampler)

    def to(self, device):
        """Move every stored tensor to ``device`` and return the loader itself."""
        self.dataset.tensors = tuple(t.to(device) for t in self.dataset.tensors)
        return self

In [None]:
# Build the two-moons point cloud and wrap it in a shuffling in-memory loader.
moons = generate_moons(width=1.0)
moons_dl = InMemDataLoader([moons], batch_size=2048, shuffle=True)

# Quick visual sanity check of the raw data.
moons_x, moons_y = moons[:, 0], moons[:, 1]
plt.scatter(moons_x, moons_y, s=0.5)
plt.show()


# Variational Autoencoder (VAE)

VAE has two modules:
1. Encoder encoding an initial data point $x$ into latent representation $z$. More precisely, the encoder returns $\mu_z$ and $\log \sigma^2_z$ that are used to sample $z$ using reparametrization trick:
$$ z = \exp(\log (\sigma^2_z) / 2) * \epsilon + \mu_z $$
where $\epsilon \sim N(0, I)$.
2. Decoder decoding the latent $z$ into reconstructed $x\_recon$.

Training iteration contains the following steps:
1. Sample data $x$ from the dataset.
2. Use the Encoder to get $\mu_z$ and $\log \sigma^2_z$.
3. Use the reparametrization trick to sample $z$.
4. Use the Decoder to reconstruct $z$ to $x\_recon$.
5. Compute loss:
$$ L = MSE(x\_recon, x) + KL\big(N(\mu_z, \sigma_z^2) || N(0, I)\big) $$

For generation:
1. Sample $z \sim N(0, I)$.
2. Use the Decoder to get $x\_recon$.

In [None]:
class VAE(nn.Module):
    """Small fully-connected variational autoencoder.

    The encoder maps ``x`` to ``2 * z_dim`` numbers that are split into the
    latent mean and log-variance; the decoder maps a latent sample back to
    data space.

    Args:
        in_dim: dimensionality of the data.
        hid_dim: width of the two hidden layers in encoder and decoder.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, in_dim=2, hid_dim=128, z_dim=2):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.z_dim = z_dim

        self.encoder = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, 2 * z_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, in_dim),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            ``(x_recon, z_mu, z_log_var)``.
        """
        sampled_z, z_mu, z_log_var = self.encode(x)
        x_recon = self.decoder(sampled_z)
        return x_recon, z_mu, z_log_var

    def encode(self, x):
        """Return ``(sampled_z, z_mu, z_log_var)`` for the input ``x``.

        ``sampled_z`` is drawn with the reparametrization trick:
        ``z = exp(log_var / 2) * eps + mu`` with ``eps ~ N(0, I)``.
        """
        z_mu_log_var = self.encoder(x)
        # Split along the feature axis (the last one), so un-batched inputs
        # work as well as the usual (batch, features) layout.
        z_mu, z_log_var = torch.chunk(z_mu_log_var, 2, dim=-1)
        # Draw the noise next to z_mu (same device and dtype) rather than on
        # a module-level `device`, so the model works wherever it lives.
        epsilon = torch.randn_like(z_mu)
        sampled_z = epsilon * torch.exp(z_log_var / 2) + z_mu
        return sampled_z, z_mu, z_log_var

In [None]:
def kld(mu1, log_var1, mu2, log_var2):
    """Element-wise KL(N(mu1, exp(log_var1)) || N(mu2, exp(log_var2))).

    All arguments are tensors (broadcastable against each other); variances
    are given as log-variances.
    """
    return (
        log_var2 - log_var1 + (log_var1.exp() + (mu1 - mu2) ** 2) / (log_var2.exp()) - 1
    ) / 2


def kl_loss(z_mu, z_log_var):
    """KL divergence from N(z_mu, exp(z_log_var)) to the standard normal prior.

    The element-wise divergences are summed over all entries and divided by
    the batch size ``z_mu.shape[0]``.
    """
    # Prior N(0, I): zero mean and zero log-variance.  The tensor is created
    # next to z_mu (same device and dtype) instead of on a module-level
    # `device`, so the loss works wherever the inputs live.
    prior = z_mu.new_zeros(1)
    kl_div = kld(z_mu, z_log_var, prior, prior)
    return kl_div.sum() / z_mu.shape[0]

In [None]:
# Hyper-parameters of the VAE experiment.
lr = 0.0003
z_dim = 2
hid_dim = 64

# Build the model, place it on the selected device and let Adam train it.
vae = VAE(hid_dim=hid_dim, z_dim=z_dim)
vae = vae.to(device)
optimizer = optim.Adam(vae.parameters(), lr=lr)

In [None]:
# Train the VAE for 3000 epochs; every 200 epochs print the running losses and
# plot a reconstruction and a batch of freshly generated points.
for i in range(3000):
    # Per-epoch accumulators, weighted by batch size so they can be divided by
    # the dataset size when printed.
    recon_loss_acc = 0.0
    kl_acc = 0.0
    vae.train()
    for x, in moons_dl:
        # Each batch is a 1-tuple; unpack it and move the data to the model's device.
        x = x.float().to(device)

        x_recon, z_mu, z_log_var = vae(x)

        # Total loss = reconstruction error + KL term (kl_loss is defined elsewhere in the file).
        recon_loss = F.mse_loss(x_recon, x)
        kl = kl_loss(z_mu, z_log_var)
        loss = recon_loss + kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        recon_loss_acc += recon_loss.item() * len(x)
        kl_acc += kl.item() * len(x)

    if i % 200 == 0:
        # "avg mean" / "avg std" are computed from the last batch of the epoch only.
        print(
            f"Epoch: {i} loss: {(recon_loss_acc + kl_acc) / len(moons) :.4f} recon_loss: {recon_loss_acc / len(moons) :.4f} kl_loss: {kl_acc / len(moons) :.4f} avg mean: {z_mu.detach().mean() :.4f} avg std: {torch.exp(z_log_var.detach() / 2).mean() :.4f}"
        )
        vae.eval()
        with torch.no_grad():
            # Plot the reconstruction of the last training batch (the tensor
            # computed in the loop above is reused, not recomputed here).
            x_recon = x_recon.cpu()

            plt.scatter(x_recon[:, 0], x_recon[:, 1])
            plt.title("Reconstruction")
            plt.show()

            # Generate new data by decoding latents drawn from the N(0, I) prior.
            z = torch.randn(500, z_dim).to(device)
            x_gen = vae.decoder(z)
            x_gen = x_gen.cpu()

            plt.scatter(x_gen[:, 0], x_gen[:, 1])
            plt.title("Generation")
            plt.show()

In [None]:
def get_grid(data):
    """Return the points of an integer grid spanning the range of ``data``.

    ``data`` is an (N, 2) array.  The result is an (M, 2) array listing every
    horizontal grid line followed by every vertical one, with a NaN row after
    each line so that a single line plot of the result is interrupted between
    consecutive grid lines.
    """
    lower = np.floor(data.min(0))
    upper = np.ceil(data.max(0))
    x_ticks = np.arange(lower[0], upper[0] + 1, 1)
    y_ticks = np.arange(lower[1], upper[1] + 1, 1)
    xg, yg = np.meshgrid(x_ticks, y_ticks)

    def polyline(coords):
        # Rows of the mesh first, then its columns; each line gets a trailing NaN.
        pieces = []
        for mat in (coords, coords.T):
            gap = np.zeros((mat.shape[0], 1)) + np.nan
            pieces.append(np.hstack((mat, gap)).ravel())
        return np.hstack(pieces)

    return np.vstack((polyline(xg), polyline(yg))).T

In [None]:
# Work on a random subset of 1000 data points plus a grid covering them.
subset_idx = np.random.permutation(moons.shape[0])[:1000]
data = np.array(moons)[subset_idx]
grid = get_grid(data)

# Colour every point by its x coordinate, rescaled to [0, 1].
x_coord = data[:, 0]
x_lo, x_hi = min(x_coord), max(x_coord)
data_colors = plt.cm.jet((x_coord - x_lo) / (x_hi - x_lo))

plt.figure(figsize=(8, 4))

# Left panel: points and grid in data space.
plt.subplot(1, 2, 1)
plt.plot(grid[:, 0], grid[:, 1], color="gray", alpha=0.3)
plt.scatter(data[:, 0], data[:, 1], color=data_colors, s=1.0)
_ = plt.axis("equal")
plt.title("Data in original space")

vae.eval()

# Right panel: the same points and grid after encoding (the sampled latent,
# i.e. the first value returned by vae.encode, is what gets plotted).
plt.subplot(1, 2, 2)
with torch.no_grad():
    enc_grid = vae.encode(torch.from_numpy(grid).to(device).float())[0]
    enc_data = vae.encode(torch.from_numpy(data).to(device).float())[0]
enc_grid = enc_grid.cpu().numpy()
enc_data = enc_data.cpu().numpy()

plt.plot(enc_grid[:, 0], enc_grid[:, 1], color="gray", alpha=0.3)
plt.scatter(enc_data[:, 0], enc_data[:, 1], color=data_colors, s=1.0)
_ = plt.axis("equal")
plt.title("Data in latent space")

In [None]:
# Latent points drawn from the standard-normal prior.
latent_samples = torch.randn(1000, z_dim)

# Colour every latent point by its first coordinate, rescaled to [0, 1].
first_coord = latent_samples[:, 0]
coord_min, coord_max = min(first_coord), max(first_coord)
latent_colors = (first_coord - coord_min) / (coord_max - coord_min)
latent_colors = plt.cm.jet(latent_colors.numpy())

latent_grid = get_grid(latent_samples.numpy())

# Push both the samples and the grid through the decoder.
vae.eval()
with torch.no_grad():
    x_gen = vae.decoder(latent_samples.to(device)).cpu()
    grid_gen = vae.decoder(torch.from_numpy(latent_grid).float().to(device)).cpu()

plt.figure(figsize=(12, 6))

# Left panel: samples and grid in latent space.
plt.subplot(1, 2, 1)
plt.plot(latent_grid[:, 0], latent_grid[:, 1], color="gray", alpha=0.3)
plt.scatter(latent_samples[:, 0], latent_samples[:, 1], color=latent_colors, s=1)
_ = plt.axis("equal")
plt.title("Z in latent space")

# Right panel: their images in data space, with matching colours.
plt.subplot(1, 2, 2)
plt.plot(grid_gen[:, 0], grid_gen[:, 1], color="gray", alpha=0.3)
plt.scatter(x_gen[:, 0], x_gen[:, 1], color=latent_colors, s=1)
_ = plt.axis("equal")
plt.title("Generated data in original space")

# Generative Adversarial Network (GAN)
GAN is built with two networks:
1. Discriminator judging if a given sample is real or not.
2. Generator creating samples from initial noise.

They play a min-max game with each other. The Generator tries to generate samples that are hard to distinguish from real ones by the Discriminator. Meanwhile, the Discriminator tries to spot the difference between real and fake samples.

Training looks as follows:
1. Get data point $x$.
2. Sample initial noise $z \sim N(0, I)$.
3. Generate $x\_fake$ from $z$ using the Generator.
4. Do binary classification on $x$ and $x\_fake$ with the Discriminator.
5. Compute the loss for the Generator (the non-saturating form, minimized during training):
$$ L_G = -\log(D(G(z)))$$
6. Compute the loss for the Discriminator:
$$ L_D = -\big[\log (D(x)) + \log(1 - D(G(z)))\big] $$

To generate new data:
1. Sample initial noise $z \sim N(0, I)$.
2. Generate $x\_fake$ from $z$ using the Generator.


In [None]:
class Generator(nn.Module):
    """MLP that maps a noise vector of size ``in_dim`` to a point of size ``out_dim``."""

    def __init__(self, in_dim=2, hid_dim=128, out_dim=2):
        super().__init__()
        self.in_dim, self.hid_dim, self.out_dim = in_dim, hid_dim, out_dim

        # Two hidden ReLU layers followed by an unbounded linear output.
        stack = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, out_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.layers(x)


class Discriminator(nn.Module):
    """MLP classifier whose sigmoid output lies in (0, 1) for every input sample."""

    def __init__(self, in_dim=2, hid_dim=128, out_dim=1):
        super().__init__()
        self.in_dim, self.hid_dim, self.out_dim = in_dim, hid_dim, out_dim

        # Two hidden ReLU layers, then a linear layer squashed by a sigmoid.
        stack = [
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, out_dim),
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.layers(x)

In [None]:
def generator_loss(DG, eps=1e-6):
    """Generator loss ``-mean(log(DG + eps))``.

    ``DG`` holds the discriminator outputs for generated samples; ``eps``
    keeps the logarithm finite when an output is exactly zero.
    """
    log_scores = torch.log(DG + eps)
    return -log_scores.mean()


def discriminator_loss(DR, DG, eps=1e-6):
    """Discriminator loss ``-mean(log(DR + eps) + log(1 - DG + eps))``.

    ``DR`` / ``DG`` hold the discriminator outputs for real / generated
    samples; ``eps`` keeps both logarithms finite at the extremes.
    """
    real_term = torch.log(DR + eps)
    fake_term = torch.log(1 - DG + eps)
    return -(real_term + fake_term).mean()

In [None]:
# Hyper-parameters of the GAN experiment.
lr = 0.0001
hid_dim = 64
z_dim = 2

# The generator consumes z_dim-dimensional noise; the discriminator keeps its
# default input size.  Each network gets its own Adam optimizer.
G = Generator(in_dim=z_dim, hid_dim=hid_dim).to(device)
D = Discriminator(hid_dim=hid_dim).to(device)
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr=lr) for net in (G, D))

In [None]:
# Train the GAN for 4500 epochs, alternating one generator step and one
# discriminator step per batch; every 100 epochs plot 1000 generated points.
for i in range(4500):
    # Per-epoch accumulators, weighted by batch size so they can be divided by
    # the dataset size in the plot title.
    G_loss_acc = 0.0
    D_loss_acc = 0.0
    G.train()
    D.train()
    for x, in moons_dl:
        # Each batch is a 1-tuple; unpack it and move the data to the device.
        x = x.float().to(device)

        # Generate fake data from z ~ N(0, I), one latent per real sample.
        z = torch.randn(x.size(0), z_dim, device=device)
        x_fake = G(z)

        # Detached version of x_fake (cut off from the generator's graph);
        # it is used below to train the Discriminator.
        x_fake_detached = x_fake.detach()

        # Generator loss: score the (non-detached) fakes with the Discriminator.
        G_loss = generator_loss(D(x_fake))

        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Discriminator loss on real data and on the detached fakes, so that
        # nothing is backpropagated through the generator.
        D_loss = discriminator_loss(D(x), D(x_fake_detached))

        # zero_grad() here also clears the gradients that G_loss.backward()
        # left on the Discriminator's parameters.
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        G_loss_acc += G_loss.item() * len(x)
        D_loss_acc += D_loss.item() * len(x)

    if i % 100 == 0:
        G.eval()
        with torch.no_grad():
            # Fresh latents, independent of the training batches.
            z = torch.randn(1000, z_dim, device=device)
            x_gen = G(z).cpu()
            plt.scatter(x_gen[:, 0], x_gen[:, 1])
            plt.title(
                f"Epoch: {i} Generator loss: {G_loss_acc / len(moons) :.4f} Discriminator loss: {D_loss_acc / len(moons) :.4f}"
            )
            plt.show()

In [None]:
# Noise vectors drawn from the standard normal, as during training.
latent_samples = torch.randn(1000, z_dim)

# Colour every latent point by its first coordinate, rescaled to [0, 1].
z_first = latent_samples[:, 0]
z_low, z_high = min(z_first), max(z_first)
latent_colors = (z_first - z_low) / (z_high - z_low)
latent_colors = plt.cm.jet(latent_colors.numpy())

latent_grid = get_grid(latent_samples.numpy())

G.eval()
# Project the latent samples and the latent grid into data space.
with torch.no_grad():
    x_gen = G(latent_samples.to(device))
    x_gen = x_gen.cpu()
    grid_gen = G(torch.from_numpy(latent_grid).float().to(device))
    grid_gen = grid_gen.cpu()

plt.figure(figsize=(12, 6))

# Left panel: samples and grid in latent space.
plt.subplot(1, 2, 1)
plt.plot(latent_grid[:, 0], latent_grid[:, 1], color="gray", alpha=0.3)
plt.scatter(latent_samples[:, 0], latent_samples[:, 1], color=latent_colors, s=1)
_ = plt.axis("equal")
plt.title("Z in latent space")

# Right panel: their images under the generator, with matching colours.
plt.subplot(1, 2, 2)
plt.plot(grid_gen[:, 0], grid_gen[:, 1], color="gray", alpha=0.3)
plt.scatter(x_gen[:, 0], x_gen[:, 1], color=latent_colors, s=1)
_ = plt.axis("equal")
plt.title("Generated data in original space")

# Diffusion model

Diffusion models can be seen as a combination of hierarchical VAE where there are multiple hidden states and normalizing flows where the goal is to learn invertible mapping from data distribution to Gaussian one.
  
  For a given schedule of noise level $\{\beta_t\}_{t=1}^T$, we can define two processes:
  1. *Forward process* that gradually adds small noise to the initial image $x_0$, ending up in a normally-distributed sample $x_T$, following the equation:
  $$q(\mathbf{x}_t \vert \mathbf{x}_0) = N(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})$$
  where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$.

  2. *Backward process* used during inference to remove noise from initial Gaussian sample $x_T$ up until a clean image $x_0$, following the $p(x_{t-1}|x_t)$ distribution.

  The general goal is to learn the $p(x_{t-1}|x_t)$ distribution. However, it is intractable and we just assume it's Gaussian. Then, we do something similar to VAE, and look at the variational posterior $q(x_{t-1} | x_t, x_0)$ that is also Gaussian and is tractable.


We can summarize the training in the following steps:
  1. Sample timesteps for the batch of $x_0$.
  2. Add noise according to $q(\mathbf{x}_t \vert \mathbf{x}_0)$.
  3. Predict noise using denoiser backbone from $x_t$ and timestep.
  4. Calculate MSE loss between the true noise and predicted one.

And finally the inference:
  1. Sample $x_T$ from Gaussian distribution.
  2. Loop over timesteps in reversed order:
    1. Predict noise from the current timestep and $x_t$.
    2. Get parameters for $p(x_{t-1}|x_t)$: mean from $q(x_{t-1} | x_t, x_0)$ and a fixed variance.
    3. Sample $x_{t-1}$ from $p(x_{t-1}|x_t)$.

Note: There are a couple of simplifications used in this exercise (justified in the literature as well):
  1. We use MSE as a loss, omitting ELBO.
  2. We don't learn the variance of $p(x_{t-1}|x_t)$. Instead, we just use a fixed schedule for it.
  3. Since we're working with the moons dataset, we don't need to discretize $x_0$.

In [None]:
class ConditionalLinear(nn.Module):
    """Linear layer whose output is scaled element-wise by a learned timestep embedding."""

    def __init__(self, in_dim, out_dim, n_timesteps):
        super().__init__()
        self.out_dim = out_dim
        self.fc = nn.Linear(in_dim, out_dim)
        # One learnable vector of size out_dim per timestep index.
        self.embed = nn.Embedding(n_timesteps, out_dim)

    def forward(self, x, t):
        """Apply the linear map to ``x`` and gate it with the embedding of ``t``."""
        projected = self.fc(x)
        gate = self.embed(t).reshape(-1, self.out_dim)
        return gate * projected

class Denoiser(nn.Module):
    """Noise-prediction network: two timestep-conditioned layers and a linear head."""

    def __init__(self, data_dim, hid_dim, n_timesteps):
        super().__init__()

        self.cl1 = ConditionalLinear(data_dim, hid_dim, n_timesteps)
        self.cl2 = ConditionalLinear(hid_dim, hid_dim, n_timesteps)
        # The output layer is a plain linear map; it does not see the timestep.
        self.cl3 = nn.Linear(hid_dim, data_dim)
        self.activation = nn.SiLU()

    def forward(self, x, t):
        """Map the input ``x`` and timestep indices ``t`` to a tensor of size ``data_dim``."""
        hidden = x
        for conditioned in (self.cl1, self.cl2):
            hidden = self.activation(conditioned(hidden, t))
        return self.cl3(hidden)

In [None]:
class Diffusion(nn.Module):
    """DDPM-style diffusion model wrapped around a noise-predicting ``denoiser``.

    The noise schedule ``beta`` is linear between ``beta_start`` and
    ``beta_end``.  The schedule and every constant derived from it are stored
    as buffers: they follow ``.to(device)`` and are part of ``state_dict()``
    under the same names as before, but are not returned by ``parameters()``
    and therefore never reach the optimizer.

    Args:
        denoiser: module called as ``denoiser(x_t, timesteps)`` that predicts
            the noise contained in ``x_t``.
        n_timesteps: number of diffusion steps.
        data_dim: dimensionality of one sample (used by ``sample``).
        beta_start: first value of the linear noise schedule.
        beta_end: last value of the linear noise schedule.
    """

    def __init__(self, denoiser, n_timesteps=10, data_dim=2, beta_start=1e-5, beta_end=0.9):
        super().__init__()
        self.denoiser = denoiser
        self.n_timesteps = n_timesteps
        self.data_dim = data_dim

        self.register_buffer("beta", torch.linspace(beta_start, beta_end, n_timesteps))
        self.set_params()

    def forward(self, x0):
        """Return the training loss for a batch of clean samples ``x0``.

        A random timestep is drawn per sample, ``x0`` is noised accordingly
        and the loss is the MSE between the true and the predicted noise.
        """
        # Inputs living on another device than the model are moved over
        # instead of failing with a device mismatch.
        x0 = x0.to(self.device)
        timesteps = torch.randint(self.n_timesteps, (x0.shape[0],), device=x0.device)
        eps, xt = self.get_noisy_sample(x0, timesteps)

        eps_pred = self.denoiser(xt, timesteps)

        return F.mse_loss(eps_pred, eps)

    def sample(self, batch_size=1):
        """Generate ``batch_size`` samples by running the reverse process from pure noise."""
        with torch.no_grad():
            xt = torch.randn(batch_size, self.data_dim, device=self.device)

            for t in reversed(range(self.n_timesteps)):
                timesteps = torch.full((xt.shape[0],), t, dtype=torch.long, device=xt.device)
                eps_pred = self.denoiser(xt, timesteps)
                mean, logvar = self.get_p_params(xt, timesteps, eps_pred)
                # Noise is added at every step, including t == 0, where its
                # standard deviation is sqrt(beta[0]).
                noise = torch.randn_like(xt)
                xt = mean + noise * torch.exp(logvar / 2)

        return xt

    def get_p_params(self, xt, timesteps, eps_pred):
        """Return mean and log-variance of p(x_{t-1} | x_t).

        The variance is not learned: the fixed schedule ``beta`` is used.  The
        mean is the mean of q(x_{t-1} | x_t, x_0) with x_0 estimated from
        ``eps_pred``.
        """
        p_logvar = torch.log(self._lookup(self.beta, timesteps, xt.ndim))
        p_mean = self.get_q_params(xt, timesteps, eps_pred)
        return p_mean, p_logvar

    def get_q_params(self, xt, timesteps, eps_pred):
        """Return the mean of q(x_{t-1} | x_t, x_0), with x_0 predicted from ``eps_pred``."""
        # Predict x0 by inverting the forward process:
        # x0 = sqrt(1 / alpha_bar) * xt - sqrt(1 / alpha_bar - 1) * eps
        coef1_x0 = self._lookup(self.coef1_x0, timesteps, xt.ndim)
        coef2_x0 = self._lookup(self.coef2_x0, timesteps, xt.ndim)
        x0 = coef1_x0 * xt - coef2_x0 * eps_pred

        # The posterior mean is a fixed linear combination of x0 and xt.
        coef1_q = self._lookup(self.coef1_q, timesteps, xt.ndim)
        coef2_q = self._lookup(self.coef2_q, timesteps, xt.ndim)
        q_mean = coef1_q * x0 + coef2_q * xt

        return q_mean

    def get_noisy_sample(self, x0, timesteps):
        """Sample x_t ~ q(x_t | x_0) and return ``(eps, x_t)``.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I).
        Both returned tensors live on the model's device.
        """
        x0 = x0.to(self.device)
        eps = torch.randn_like(x0)
        signal_scale = torch.sqrt(self._lookup(self.alpha_bar, timesteps, x0.ndim))
        noise_scale = torch.sqrt(1 - self._lookup(self.alpha_bar, timesteps, x0.ndim))
        xt = signal_scale * x0 + noise_scale * eps
        return eps, xt

    def set_params(self):
        """Derive and register every constant that depends on the ``beta`` schedule."""
        alpha = 1 - self.beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        # alpha_bar shifted by one step, with alpha_bar_{-1} := 1.  new_ones
        # keeps the device and dtype of the schedule.
        alpha_bar_prev = torch.cat([alpha_bar.new_ones(1), alpha_bar[:-1]])

        constants = {
            "alpha": alpha,
            "alpha_bar": alpha_bar,
            "alpha_bar_prev": alpha_bar_prev,
            # to calculate x0 from eps_pred
            "coef1_x0": torch.sqrt(1.0 / alpha_bar),
            "coef2_x0": torch.sqrt(1.0 / alpha_bar - 1),
            # for q(x_{t-1} | x_t, x_0)
            "coef1_q": self.beta * torch.sqrt(alpha_bar_prev) / (1.0 - alpha_bar),
            "coef2_q": (1.0 - alpha_bar_prev) * torch.sqrt(alpha) / (1.0 - alpha_bar),
        }
        for name, value in constants.items():
            self.register_buffer(name, value)

    def broadcast(self, arr, dim=2):
        """Insert singleton axes after the first one until ``arr`` has ``dim`` dimensions."""
        while arr.dim() < dim:
            arr = arr[:, None]
        return arr.to(self.device)

    def _lookup(self, table, timesteps, ndim):
        """Pick the per-sample entries of ``table`` and shape them to broadcast over a batch."""
        # Index tensors must live on the CPU or on the table's own device.
        return self.broadcast(table[timesteps.to(table.device)], dim=ndim)

    @property
    def device(self):
        """Device of the denoiser's parameters."""
        return next(self.denoiser.parameters()).device

In [None]:
# Hyper-parameters of the diffusion experiment.
lr = 0.001
n_timesteps = 10
data_dim = 2
hid_dim = 128

# The denoiser is wrapped by the diffusion model, which is then moved to the
# selected device as a whole and trained with Adam.
denoiser = Denoiser(data_dim, hid_dim, n_timesteps)
diffusion = Diffusion(denoiser, n_timesteps=n_timesteps, data_dim=data_dim)
diffusion = diffusion.to(device)
optimizer = optim.Adam(diffusion.parameters(), lr=lr)

In [None]:
# Every timestep index of the diffusion process, in increasing order.
timesteps = torch.arange(n_timesteps)
# One batch of clean samples; the loader yields 1-tuples, hence the unpacking.
(x0,) = next(iter(moons_dl))

In [None]:
# Noise one fixed batch to every timestep and plot the result.  The batch is
# first moved to the model's device: the schedule constants live there, so a
# CPU batch would fail with a device mismatch when the model is on the GPU.
x0_on_device = x0.float().to(diffusion.device)
for step in timesteps:
    step = int(step)
    # One copy of the current timestep index per sample in the batch.
    t = torch.full(
        (x0_on_device.shape[0],), step, dtype=torch.int32, device=diffusion.device
    )
    # get_noisy_sample returns (eps, x_t); only the noisy sample is plotted.
    xt = diffusion.get_noisy_sample(x0_on_device, t)[1]
    xt = xt.cpu()
    plt.figure()
    # Fixed axis limits keep the frames comparable across timesteps.
    plt.xlim([-9, 9])
    plt.ylim([-6, 6])
    plt.scatter(xt[:, 0], xt[:, 1])
    plt.title(f'Noisy samples for timestep {step}')

In [None]:
# Train the diffusion model for 3000 epochs; every 100 epochs draw 2048
# samples from the model and plot them.
for i in range(3000):
    # Sum of per-batch losses; divided by the number of batches when reported.
    loss_acc = 0.0
    diffusion.train()
    for x0, in moons_dl:
        # Move the batch to the training device, as the VAE and GAN loops in
        # this file do; a CPU batch fails against a model living on the GPU.
        x0 = x0.float().to(device)

        optimizer.zero_grad()
        # Diffusion.forward returns the noise-prediction MSE for the batch.
        loss = diffusion(x0)
        loss.backward()
        optimizer.step()

        loss_acc += loss.item()

    if i % 100 == 0:
        diffusion.eval()
        with torch.no_grad():
            samples = diffusion.sample(2048)
            samples = samples.cpu()

            plt.scatter(samples[:, 0], samples[:, 1])
            plt.title(f"Epoch: {i} loss: {loss_acc / len(moons_dl) :.4f}")
            plt.show()