<a href="https://colab.research.google.com/github/arturovallemacias/diffusion_models/blob/main/reverse_me.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
USERNAME="arturovallemacias"
TOKEN="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"

# Configura el helper de credenciales para almacenarlas en caché
!git config --global credential.helper store

# Clona el repositorio utilizando el token personal
!git clone https://$USERNAME:$TOKEN@github.com/$USERNAME/diffusion_models.git


!git config --global user.email "arturo_valle@live.com"
!git config --global user.name "arturovallemacias"

Cloning into 'diffusion_models'...
remote: Enumerating objects: 24, done.
remote: Counting objects: 100% (24/24), done.
remote: Compressing objects: 100% (24/24), done.
remote: Total 24 (delta 9), reused 0 (delta 0), pack-reused 0
Receiving objects: 100% (24/24), 1.28 MiB | 15.46 MiB/s, done.
Resolving deltas: 100% (9/9), done.


In [20]:
import requests
from pathlib import Path

# Download helper functions from Learn PyTorch repo (if not already downloaded)
# NOTE(review): this fetches code from the tip of a third-party `main` branch
# and the notebook imports it right after; consider pinning a commit hash.
if Path("helper_functions.py").is_file():
  print("helper_functions.py already exists, skipping download")
else:
  print("Downloading helper_functions.py")
  # A timeout keeps the cell from hanging forever on a stalled connection.
  request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py", timeout=30)
  # Fail here on an HTTP error instead of writing the error page to disk,
  # which would otherwise surface later as a confusing import failure.
  request.raise_for_status()
  with open("helper_functions.py", "wb") as f:
    f.write(request.content)

from helper_functions import plot_predictions, plot_decision_boundary

Downloading helper_functions.py


In [8]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll

In [3]:
def sample_batch(batch_size, device='cpu'):
    """Draw a batch of 2-D points derived from a swiss roll.

    :param batch_size: number of points requested from make_swiss_roll
    :param device: device the resulting tensor is moved to
    :return: tensor with two columns: columns 2 and 0 of the generated roll,
        divided by 10, with the second column negated. No dtype cast is
        applied, so the dtype is whatever the numpy array carries.
    """
    roll, _ = make_swiss_roll(batch_size)

    # Keep columns 2 and 0 (in that order) and shrink them by a factor of 10.
    planar = roll[:, [2, 0]] / 10

    # Flip the sign of the second column only.
    axis_signs = np.array([1, -1])
    flipped = planar * axis_signs

    return torch.from_numpy(flipped).to(device)

In [17]:
class MLP(nn.Module):
    """Shared trunk followed by one independent output head per step.

    :param N: number of output heads (one per value of ``t``)
    :param data_dim: dimensionality of the input points
    :param hidden_dim: width of the hidden layers
    """

    def __init__(self, N=40, data_dim=2, hidden_dim=64):
        super().__init__()

        # Trunk applied to the input regardless of the step.
        self.network_head = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # One head per step; each emits data_dim * 2 values that forward()
        # splits into a mean half and a log-variance half.
        step_heads = []
        for _ in range(N):
            step_heads.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, data_dim * 2),
                )
            )
        self.network_tail = nn.ModuleList(step_heads)

    def forward(self, x, t):
        """Return ``(mu, std)`` produced by head ``t`` for the batch ``x``.

        :param x: input batch, [batch_size, data_dim]
        :param t: index of the head in ``network_tail`` to use
        :return: tuple ``(mu, std)``, each [batch_size, data_dim]
        """
        features = self.network_head(x)  # [batch_size, hidden_dim]
        head_out = self.network_tail[t](features)  # [batch_size, data_dim * 2]

        # First half is the mean; second half is exponentiated to a variance.
        mu, log_var = torch.chunk(head_out, 2, dim=1)
        std = torch.sqrt(torch.exp(log_var))

        return mu, std

In [25]:
# Smoke test: run a random batch through one head of a freshly constructed MLP.
model = MLP()
t = 5  # index of the head in network_tail to evaluate
x = torch.randn((64, 2))  # 64 random 2-D points
mu, std = model(x, t)


In [None]:
# Show the shapes of the mean and standard deviation returned by the model call.
print(mu.shape)
print(std.shape)

In [23]:
class DiffusionModel():
    """Forward (noising) and reverse (denoising) diffusion over ``T`` steps.

    :param T: number of diffusion steps
    :param model: network called as ``model(x, t)`` with a 0-based step index
        and returning ``(mu, std)`` for the reverse step
    :param dim: dimensionality of the data points
    """

    def __init__(self, T, model: nn.Module, dim=2):

        # Sigmoid-shaped noise schedule with values between 1e-5 and 3e-1.
        self.betas = torch.sigmoid(torch.linspace(-18, 10, T)) * (3e-1 - 1e-5) + 1e-5
        self.alphas = 1 - self.betas
        # Running product of the alphas, used by forward_process.
        self.alphas_bar = torch.cumprod(self.alphas, 0)

        self.T = T
        self.model = model
        self.dim = dim

    def _check_step(self, t):
        """Assert that the 1-based step ``t`` lies in ``[1, T]``."""
        assert t > 0, 't should be greater than 0'
        # The original compared self.T with itself, which could never fail.
        assert t <= self.T, f't should be lower or equal than {self.T}'

    def forward_process(self, x0, t):
        """
        Noise ``x0`` directly to step ``t``.

        :param x0: clean data batch
        :param t: Number of diffusion steps (1-based, ``1 <= t <= T``)
        :return: ``sqrt(alphas_bar[t-1]) * x0 + sqrt(1 - alphas_bar[t-1]) * eps``
            with ``eps`` drawn from a standard normal
        :raises AssertionError: if ``t`` is outside ``[1, T]``
        """

        self._check_step(t)

        t = t - 1  # Because we start indexing at 0

        mu = torch.sqrt(self.alphas_bar[t]) * x0
        std = torch.sqrt(1 - self.alphas_bar[t])
        epsilon = torch.randn_like(x0)

        return mu + epsilon * std  # data ~ N(mu, std)

    def reverse_process(self, xt, t):
        """
        Take one denoising step from ``xt`` using the wrapped model.

        :param xt: batch at step ``t``
        :param t: Number of diffusion steps (1-based, ``1 <= t <= T``)
        :return: ``mu + eps * std`` where ``(mu, std) = model(xt, t - 1)``
        :raises AssertionError: if ``t`` is outside ``[1, T]``
        """

        self._check_step(t)

        t = t - 1  # Because we start indexing at 0

        mu, std = self.model(xt, t)
        epsilon = torch.randn_like(xt)

        return mu + epsilon * std  # data ~ N(mu, std)

    def sample(self, batch_size):
        """
        Run the reverse process starting from standard normal noise.

        :param batch_size: number of points to generate
        :return: list of ``T + 1`` tensors ordered from the final result
            (index 0) back to the initial noise (index ``T``). No reverse step
            is applied at ``t == 1``, so entries 0 and 1 are the same tensor.
        """

        noise = torch.randn((batch_size, self.dim))
        x = noise

        samples = [x]
        for t in range(self.T, 0, -1):

            # The last iteration only records the current state again.
            if t != 1:
                x = self.reverse_process(x, t)

            samples.append(x)

        # Reverse so that index 0 holds the final result.
        return samples[::-1]


In [27]:
# Load the object stored in the "model_paper1" file from the working directory.
# NOTE(review): the recorded output of this cell is a FileNotFoundError, so the
# file was not where the notebook ran — presumably it sits inside the cloned
# diffusion_models/ directory; confirm the path. The same load also appears in
# another cell of this notebook.
# NOTE(review): torch.load unpickles its input; only load files you trust.
mlp_model = torch.load("model_paper1")

FileNotFoundError: ignored

In [26]:
# Clean reference batch used for the forward-process experiments.
x0 = sample_batch(3_000)

# NOTE(review): torch.load unpickles its input; only load files you trust.
mlp_model = torch.load("model_paper1")

# Wrap the network that was just loaded. The previous code passed `model`
# here, i.e. whatever that name was bound to before (a fresh, untrained MLP on
# a first run), so the loaded checkpoint was never used. Re-running the cell
# would also have wrapped the DiffusionModel inside another DiffusionModel.
model = DiffusionModel(40, mlp_model)

# Noise the clean batch to step 20 of the 40 diffusion steps.
xT = model.forward_process(x0, 20)

In [None]:
# Generate 1000 points with the reverse process (assumes `model` is the
# DiffusionModel built earlier). The result is a list of T + 1 tensors ordered
# from the final output (index 0) back to the initial noise (index T).
samples = model.sample(1000)

In [None]:
# Scatter the batch stored at index t of `samples`.
# NOTE(review): sample() returns its list reversed, so for a 40-step model
# index 40 is the initial Gaussian noise and index 0 is the final output —
# confirm that t = 40 is the snapshot intended here.
t = 40
plt.scatter(samples[t][:, 0].data.numpy(), samples[t][:, 1].data.numpy())

In [None]:
# Mean and standard deviation of the noised batch xT, reduced over dimension 0.
print(xT.mean(0))
print(xT.std(0))

In [None]:
# Three-panel figure: the clean batch and two noised versions of it.
fontsize = 14
fig = plt.figure(figsize=(10, 3))

# Panels in display order: x0 itself, x0 noised to step 20, x0 noised to step 40.
data = [x0, model.forward_process(x0, 20), model.forward_process(x0, 40)]
panel_titles = [r'$t=0$', r'$t=\frac{T}{2}$', r'$t=T$']

for i, (panel_points, panel_title) in enumerate(zip(data, panel_titles)):
    plt.subplot(1, 3, 1 + i)

    coords = panel_points.data.numpy()
    plt.scatter(coords[:, 0], coords[:, 1], alpha=0.1)

    # Same square window on every panel so the spread can be compared.
    plt.xlim([-2, 2])
    plt.ylim([-2, 2])
    plt.gca().set_aspect('equal')

    # Only the left-most panel carries the row label.
    if i == 0:
        plt.ylabel(r'$q(\mathbf{x}^{(0..T)})$', fontsize=fontsize)
    plt.title(panel_title, fontsize=fontsize)

plt.savefig('forward_process.png', bbox_inches='tight')