In [1]:
from matplotlib import pyplot as plt
import torch
import torch.nn as nn

import numpy as np

In [2]:
def generate_star(n_spikes=5, inner_radius=0.4, outer_radius=1.0, n_samples=1000, center=(0, 0)):
    points = []
    angle_step = np.pi / n_spikes

    # Generate the star's vertices
    vertices = []
    for i in range(2 * n_spikes):
        angle = i * angle_step
        radius = outer_radius if i % 2 == 0 else inner_radius

        x = radius * np.cos(angle) + center[0]
        y = radius * np.sin(angle) + center[1]
        vertices.append([x, y])
    vertices.append(vertices[0])

    # Sample points along the star's edges
    vertices = np.array(vertices)
    sampled_points = []

    for i in range(len(vertices) - 1):
        start_point = vertices[i]
        end_point = vertices[i + 1]

        # Interpolate points along the edge
        for t in np.linspace(0, 1, n_samples // (len(vertices) - 1)):
            point = (1 - t) * start_point + t * end_point
            sampled_points.append(point)

    return np.array(sampled_points)

In [3]:
star = generate_star(n_samples=5000)

In [4]:
import torch
import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas, alphas_cumprod

In [5]:
def q_sample(x_0, t, alphas_cumprod, noise=None):
    if noise is None:
        noise = torch.randn_like(x_0)

    alphas_cumprod_t = alphas_cumprod[t].view(-1, 1)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod_t)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod_t)

    return sqrt_alphas_cumprod * x_0 + sqrt_one_minus_alphas_cumprod * noise

In [6]:
class TinyUNet2D(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.net = nn.Sequential(
            nn.Linear(2 + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, x, t):
        t = t.float()
        t_embed = self.time_embed(t.view(-1, 1))
        # [B, 1] --> [B, H]
        # x: [B, 2] + [B, H] -- > [B, H + 2] -- > [B, 2]
        x_input = torch.cat([x, t_embed], dim=1)

        return self.net(x_input)

In [7]:
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import functional as F

In [79]:
device = "cuda" if torch.cuda.is_available() else "cpu"

star = generate_star(n_samples=5000)
dataset = TensorDataset(torch.tensor(star, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

model = TinyUNet2D(256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 2500
timesteps = 1000

for epoch in range(1, num_epochs + 1):
  model.train()
  total_loss = 0.

  for batch in dataloader:
    batch = batch[0].to(device)
    # [B, 2]            | -- > cat([B,2 ], [B, H]) --> [B, H+2] -- > [B, H]
    # [B, 1] --> [B, H] |
    t = torch.randint(0, timesteps, (batch.size(0),), device=device)

    _, alphas_cumprod = cosine_beta_schedule(timesteps)
    alphas_cumprod = alphas_cumprod.to(device)

    noise = torch.randn_like(batch)
    x_t = q_sample(batch, t, alphas_cumprod, noise)

    pred_noise = model(x_t, t)

    loss = F.mse_loss(pred_noise, noise)

    total_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


  if epoch % 100 == 0:
    print(f"Epoch: {epoch} Train Loss: {total_loss / len(dataloader)}")

Epoch: 100 Train Loss: 0.7969829320907593
Epoch: 200 Train Loss: 0.38335553407669065
Epoch: 300 Train Loss: 0.3621167317032814
Epoch: 400 Train Loss: 0.35263767242431643
Epoch: 500 Train Loss: 0.34834235459566115
Epoch: 600 Train Loss: 0.3291810601949692
Epoch: 700 Train Loss: 0.3424492970108986
Epoch: 800 Train Loss: 0.3304730996489525
Epoch: 900 Train Loss: 0.34369952976703644
Epoch: 1000 Train Loss: 0.32376574724912643
Epoch: 1100 Train Loss: 0.33083646893501284
Epoch: 1200 Train Loss: 0.3245365709066391
Epoch: 1300 Train Loss: 0.323319086432457
Epoch: 1400 Train Loss: 0.31863618791103365
Epoch: 1500 Train Loss: 0.295289670675993
Epoch: 1600 Train Loss: 0.31735763475298884
Epoch: 1700 Train Loss: 0.3059804379940033
Epoch: 1800 Train Loss: 0.31012887358665464
Epoch: 1900 Train Loss: 0.29451688975095747
Epoch: 2000 Train Loss: 0.30268605798482895
Epoch: 2100 Train Loss: 0.2940620809793472
Epoch: 2200 Train Loss: 0.2945985645055771
Epoch: 2300 Train Loss: 0.3054888390004635
Epoch: 2400

In [27]:
@torch.no_grad()
def sample_ddpm(model, shape, betas, device, returned_t=None):
    model.eval()
    T = len(betas)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0).to(device)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.], device=device), alphas_cumprod[:-1]])

    eps = 1e-5
    x_t = torch.randn(shape, device=device)

    to_return = []
    for t in reversed(range(T)):
        t_tensor = torch.full((shape[0],), t, device=device).long()

        beta_t = betas[t].to(device)
        alpha_t = alphas[t].to(device)

        sqrt_one_minus_ac = torch.sqrt(1 - alphas_cumprod[t] + eps)
        sqrt_recip_alpha = torch.sqrt(1. / (alpha_t + eps))

        eps_theta = model(x_t, t_tensor)

        model_mean = sqrt_recip_alpha * (x_t - beta_t / sqrt_one_minus_ac * eps_theta)

        if t > 0:
            noise = torch.randn_like(x_t)
            posterior_var = beta_t * (1 - alphas_cumprod_prev[t]) / (1 - alphas_cumprod[t] + eps)
            x_t = model_mean + torch.sqrt(posterior_var + eps) * noise
        else:
            x_t = model_mean

        if returned_t and t in returned_t:
          to_return.append(x_t)

    to_return.append(x_t)

    return to_return

In [73]:
device = "cuda"
betas, alphas = cosine_beta_schedule(1000)
betas = betas.to(device)
preds = sample_ddpm(model, (1500, 2), betas, device, returned_t=[100*i for i in range(1, 11)])

In [80]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def animate_2d_samples(sampled_steps, interval=200, save_path="ddpm_evolution.gif"):
    fig, ax = plt.subplots()
    scat = ax.scatter([], [], s=10)

    def init():
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        return scat,

    def update(frame):
        data = sampled_steps[frame].detach().cpu()  # [N, 2]
        scat.set_offsets(data.cpu().numpy())
        ax.set_title(f"Timestep {frame * 100}")
        return scat,

    ani = animation.FuncAnimation(fig, update, frames=len(sampled_steps), init_func=init, blit=True, interval=interval)
    ani.save(save_path, writer='pillow')
    plt.close()

In [81]:
animate_2d_samples(preds, interval=500)