In [None]:
import torch as th
from torch.utils.data import DataLoader
import gzip
import pickle as pkl
import matplotlib.pyplot as plt
from music_diffusion_model.networks import Noiser, Denoiser
from music_diffusion_model.data import MNISTDataset

In [None]:
# Cosine noise schedule over T = `steps` diffusion steps:
#   f(t) = cos^2( (pi / 2) * (t / T + s) / (1 + s) ),   alpha_bar_t = f(t) / f(0).
# The offset `s` shifts t away from 0 so that beta_1 is not vanishingly small.
steps = 1000
s = 1e-2  # schedule offset (author's note: "+/- 15bits")
abscisse = th.arange(0, steps + 1)  # t = 0, 1, ..., T

angle = 0.5 * th.pi * (abscisse / steps + s) / (1 + s)
f_values = th.cos(angle) ** 2.0

# Normalise by f(0); the two shifted views give alpha_bar_t and alpha_bar_{t-1}.
f_normalized = f_values / f_values[0]
alphas_cum_prod = f_normalized[1:]        # alpha_bar_t,     t = 1..T
alphas_cum_prod_prev = f_normalized[:-1]  # alpha_bar_{t-1}, t = 1..T

# beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, capped at 0.999: near t = T,
# f(t) -> cos^2(pi / 2) ~ 0, so the ratio collapses and beta_t -> 1.
betas = th.clamp(1 - alphas_cum_prod / alphas_cum_prod_prev, max=0.999)

alphas = 1.0 - betas

sqrt_alphas_cum_prod = alphas_cum_prod.sqrt()
sqrt_one_minus_alphas_cum_prod = (1 - alphas_cum_prod).sqrt()

# Posterior variance beta~_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
# At t = 1, alpha_bar_0 = 1 would make it exactly 0, so alpha_bar_1 is used in
# its place; this makes beta~_1 equal to beta_1.
posterior_prev = alphas_cum_prod_prev.clone()
posterior_prev[0] = alphas_cum_prod[0]
betas_tiddle = betas * (1.0 - posterior_prev) / (1 - alphas_cum_prod)

In [None]:
# Cosine schedule: alpha_bar_t (red) and beta_t (green) for t = 1..T.
fig, ax = plt.subplots()
ax.plot(abscisse[1:], alphas_cum_prod, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse[1:], betas, color="g", label=r"$\beta_t$")
ax.set(title="Cosine schedule", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Linear schedule: beta_t evenly spaced from 1e-4 to 2e-2 over t = 0..T.
betas_values = th.linspace(start=1e-4, end=2e-2, steps=steps + 1)
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# Linear schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Linear schedule: beta from 1e-4 to 2e-2", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Quadratic schedule beta_t = (t / T)^2: 0 at t = 0 and 1 at t = T, so the last
# alpha is 0 and the cumulative product ends at 0.
betas_values = (abscisse / steps) ** 2.
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# Quadratic schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Quadratic schedule: beta = (t/T)^2", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Exponential-base-2 schedule beta_t = 2^(t / T) - 1, spanning 0 at t = 0 to 1 at t = T.
betas_values = th.pow(2, abscisse / steps) - 1
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# 2^(t/T) - 1 schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Schedule: beta = 2^(t/T) - 1", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Cubic schedule beta_t = (t / T)^3: flatter than the quadratic near t = 0, still 1 at t = T.
betas_values = (abscisse / steps) ** 3
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# Cubic schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Cubic schedule: beta = (t/T)^3", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Logarithmic schedule beta_t = log(1 + t/T) / log(1e4), written as
# -log(1 / (1 + t/T)); it rises from 0 to log(2) / log(1e4) ~ 0.075.
log_base = th.log(th.tensor(1e4))
betas_values = -th.log(1 / (abscisse / steps + 1)) / log_base
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# Logarithmic schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Schedule: beta = log(1 + t/T) / log(1e4)", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Scaled quadratic schedule beta_t = 1e-2 * (t / T)^2 + 1e-4: from 1e-4 at t = 0
# up to 1.01e-2 at t = T.
betas_values = (abscisse / steps) ** 2 * 1e-2 + 1e-4
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# Scaled quadratic schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Schedule: beta = 1e-2 (t/T)^2 + 1e-4", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Exponential schedule beta_t = exp(2e-2 * t / T) - 1: from 0 up to e^0.02 - 1 ~ 2.02e-2.
betas_values = (2e-2 * abscisse / steps).exp() - 1.
alphas_values = 1 - betas_values
alphas_cum_prod_values = th.cumprod(alphas_values, dim=0)

In [None]:
# Exponential schedule: cumulative product of alphas (red) and beta_t (green).
fig, ax = plt.subplots()
ax.plot(abscisse, alphas_cum_prod_values, color="r", label=r"$\bar{\alpha}_t$")
ax.plot(abscisse, betas_values, color="g", label=r"$\beta_t$")
ax.set(title="Schedule: beta = exp(2e-2 t/T) - 1", xlabel="step t", ylabel="value")
ax.legend();

In [None]:
# Intentionally empty: this cell only held a leftover bare `plt.plot()` that drew an empty axes.

# Models

In [None]:
# First and last beta of the linear schedule used by the model cells below.
beta_1, beta_T = 1e-4, 2e-2

In [None]:
# Linear beta schedule: 250 evenly spaced values from beta_1 to beta_T.
# NOTE(review): rebinds `betas`, which held the cosine schedule from the first cell.
betas = th.linspace(start=beta_1, end=beta_T, steps=250)

In [None]:
# Derived quantities of the (linear) `betas` schedule for closed-form noising
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
alphas = 1 - betas
alphas_cum_prod = th.cumprod(alphas, dim=0)
sqrt_alphas_cum_prod = th.sqrt(alphas_cum_prod)
# Same name as in the cosine-schedule cell, so it no longer stays bound to the
# cosine tensor while `sqrt_alphas_cum_prod` now holds the linear one.
sqrt_one_minus_alphas_cum_prod = th.sqrt(1 - alphas_cum_prod)
# Older name kept as an alias (same tensor) for cells that still use it.
sqrt_minus_one_alphas_cum_prod = sqrt_one_minus_alphas_cum_prod

In [None]:
# MNIST data and a loader that yields batches of four samples in dataset order
# (DataLoader's default, no shuffling).
mnist_dataset = MNISTDataset()
dataloader = DataLoader(dataset=mnist_dataset, batch_size=4)

In [None]:
# Peek at the shape of the first batch. `next(iter(...))` fetches it directly
# and raises StopIteration on an empty loader instead of silently showing
# nothing (which the previous for/break loop did).
x = next(iter(dataloader))
x.size()

In [None]:
# First dataset item, then its first element (presumably the single image
# channel — TODO confirm the item layout), converted to float32.
x_0 = mnist_dataset[0][0].float()

In [None]:
# Diffusion step to inspect: indexes the schedule tensors in the x_t cell and is
# passed (as a tensor) to the Noiser / Denoiser further down.
t = 2

In [None]:
# Draw eps ~ N(0, I) with x_0's shape, dtype and device. A dedicated seeded
# generator makes x_t reproducible on Restart & Run All without reseeding
# torch's global RNG.
noise_generator = th.Generator(device=x_0.device).manual_seed(0)
noise = th.randn(
    x_0.size(), generator=noise_generator, dtype=x_0.dtype, device=x_0.device
)

In [None]:
# Closed-form forward noising at step t:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
signal_scale = sqrt_alphas_cum_prod[t]
noise_scale = sqrt_minus_one_alphas_cum_prod[t]
x_t = signal_scale * x_0 + noise_scale * noise

In [None]:
# Clean input image x_0.
plt.matshow(x_0, cmap="Greys")
plt.title("x_0 (clean)");

In [None]:
# Noised image x_t, drawn with the same "Greys" colormap as x_0 so the two
# images are directly comparable (it previously used the default colormap).
plt.matshow(x_t, cmap="Greys")
plt.title(f"x_t (t = {t})");

In [None]:
# Number of diffusion steps for the Noiser / Denoiser models below.
# NOTE(review): rebinds `steps` (1000 in the schedule-exploration cells at the top).
steps = 1024

In [None]:
# Forward-process noiser over `steps` steps.
# NOTE(review): presumably the positional args are (steps, beta_1, beta_T) —
# confirm against Noiser.__init__. If so, 0.2 differs from beta_T = 2e-2 used
# for the hand-built linear schedule above; confirm that is intended.
n = Noiser(steps, 1e-4, 0.2)

In [None]:
# Noise x_0 to step t with the Noiser. `x_0[None, None, :, :]` prepends two
# singleton axes (presumably batch and channel — confirm), and t is passed as a
# 2-D tensor [[t]]. The second return value is discarded; its meaning is not
# visible from here.
o, _ = n(x_0[None, None, :, :], th.tensor([[t]]))

In [None]:
# Inspect the Noiser output's shape; the plotting cell below indexes it as o[0, 0, 0].
o.size()

In [None]:
# Noiser output for the first batch element, drawn with the same "Greys"
# colormap as the other image cells (it previously used the default colormap).
plt.matshow(o[0, 0, 0], cmap="Greys")
plt.title(f"Noiser output (t = {t})");

In [None]:
# Build the denoiser network.
# NOTE(review): the positional args are opaque from here. They presumably are
# (in-channels=1, steps, a hidden/embedding size of 8, beta_1, beta_T, encoder
# channel pairs, decoder channel pairs); confirm against Denoiser.__init__.
# The 1e-4 / 0.2 pair matches the Noiser call above.
d = Denoiser(1, steps, 8, 1e-4, 0.2, [(16, 32)], [(32, 16)])

In [None]:
# Run the denoiser on the noised Noiser output `o` at step t (t as a 2-D tensor,
# mirroring the Noiser call). Presumably this predicts x_0 or the added noise —
# confirm in Denoiser.forward.
x_0_d = d(o, th.tensor([[t]]))

In [None]:
# Inspect the denoiser output's shape; the next cell indexes [0, 0, 0] to get a 2-D image.
x_0_d.size()

In [None]:
# Show the denoiser output for the first batch element. `.detach()` drops any
# autograd history so matplotlib can convert the tensor to an array.
plt.matshow(x_0_d[0, 0, 0].detach(), cmap="Greys")

In [None]:
# Generate a sample with the denoiser, starting from Gaussian noise of shape
# (1, 1, 32, 32). Presumably this runs the full reverse chain over `steps` steps
# (confirm in Denoiser.sample).
# NOTE(review): rebinds `o`, which previously held the Noiser output.
# Inference only, so autograd is disabled and no graph is built while sampling.
# This assumes sample() itself needs no gradients (e.g. no gradient guidance) — confirm.
# A seeded generator makes the starting noise reproducible.
sample_generator = th.Generator().manual_seed(0)
with th.no_grad():
    o = d.sample(th.randn(1, 1, 32, 32, generator=sample_generator))

In [None]:
# Show the generated sample (first batch element, first channel, assuming the
# (batch, channel, H, W) layout of the randn input in the sampling cell).
# `.detach()` drops any autograd history before plotting.
plt.matshow(o[0, 0].detach(), cmap="Greys")

In [None]:
# Display the denoiser's repr (its architecture summary) through the cell's
# rich output instead of print().
d