In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_swiss_roll
from helper_plot import hdr_plot_style

# Apply the project's plot styling (helper from the local helper_plot module).
hdr_plot_style()
# NOTE(review): `device` is not referenced by any visible cell — the model and all
# tensors below are created on the default device. 'mps' is Apple-silicon only.
device = "mps"

In [2]:
def sample_batch(size, noise=0.5, random_state=None):
    """Draw a 2-D Swiss-roll point cloud.

    Samples ``size`` points from scikit-learn's 3-D Swiss roll, keeps the
    first and third coordinates (the plane the spiral lives in) and divides
    them by 10 to shrink the spread.

    Args:
        size: number of points to draw.
        noise: standard deviation of the Gaussian noise forwarded to
            ``make_swiss_roll``.
        random_state: seed or ``RandomState`` forwarded to ``make_swiss_roll``.
            ``None`` (the default) keeps the previous behaviour of drawing from
            numpy's global RNG; pass an int for a reproducible dataset.

    Returns:
        ``np.ndarray`` of shape ``(size, 2)``.
    """
    x, _ = make_swiss_roll(size, noise=noise, random_state=random_state)
    return x[:, [0, 2]] / 10.0

In [3]:
# Training set transposed to shape (2, N): row 0 holds x, row 1 holds y.
data = sample_batch(10_000).T

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# Sliced score matching.
# J·v is obtained with torch.autograd.functional.jvp, so the full Jacobian of
# the score network is never materialised.
def sliced_score_matching(model, samples):
    """Sliced score-matching loss for a score network.

    For each sample x and a random unit direction v, the per-sample objective is

        v^T J_s(x) v + 0.5 * (v^T s(x))^2

    where s = ``model`` and J_s is its Jacobian with respect to the input.

    Args:
        model: callable mapping ``samples`` to scores; its output is expected to
            have the same shape as ``samples`` (it is multiplied elementwise
            with the projection directions).
        samples: (batch, dim) tensor of data points. It is NOT modified
            (the previous version flipped its ``requires_grad`` flag in place).

    Returns:
        Scalar tensor: the objective averaged over the batch. It stays
        differentiable with respect to the model parameters (``create_graph=True``).
    """
    # One random projection direction per sample, normalised to unit length.
    vectors = torch.randn_like(samples)
    vectors = vectors / torch.linalg.vector_norm(vectors, dim=-1, keepdim=True)

    # scores = s(x), jvp = J_s(x) v. create_graph=True keeps both in the autograd
    # graph so loss.backward() reaches the model parameters.
    scores, jvp = torch.autograd.functional.jvp(model, samples, vectors, create_graph=True)

    # 0.5 * (v^T s)^2: project onto v first, then square.
    norm_loss = (scores * vectors).sum(dim=-1) ** 2 / 2.0

    # v^T J v: random-projection estimate of the Jacobian trace.
    jacob_loss = (jvp * vectors).sum(dim=-1)

    return (jacob_loss + norm_loss).mean()

In [5]:
def denoising_score_mathcing(scorenet, samples, sigma=0.01):
    """Denoising score-matching loss at a single noise level ``sigma``.

    Perturbs ``samples`` with Gaussian noise of standard deviation ``sigma`` and
    regresses the network's score at the perturbed points onto
    -(x_noisy - x) / sigma**2, the score of the Gaussian perturbation kernel.

    Args:
        scorenet: callable returning scores for a batch of inputs.
        samples: batch of clean data; dimension 0 is treated as the batch axis.
        sigma: noise standard deviation.

    Returns:
        Scalar tensor: 0.5 * squared error summed over the flattened feature
        axes, averaged over the batch.
    """
    noisy = samples + torch.randn_like(samples) * sigma
    goal = -1 / (sigma ** 2) * (noisy - samples)
    predicted = scorenet(noisy)

    # Flatten everything except the batch axis so arbitrary input ranks work.
    goal_flat = goal.view(goal.shape[0], -1)
    predicted_flat = predicted.view(predicted.shape[0], -1)

    per_sample = (predicted_flat - goal_flat).pow(2).sum(dim=-1)
    return 1 / 2 * per_sample.mean(dim=0)

In [6]:
# Score network: R^2 -> R^2 MLP with two hidden Softplus layers of width 128.
model = nn.Sequential(
    nn.Linear(2, 128),
    nn.Softplus(),
    nn.Linear(128, 128),
    nn.Softplus(),
    nn.Linear(128, 2),
)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Back to (N, 2) row-per-point layout, as float32 for the network.
dataset = torch.tensor(data.T, dtype=torch.float32)

In [7]:
def train(n_steps=5000, log_every=1000):
    """Fit the global ``model`` with denoising score matching on ``dataset``.

    Runs ``n_steps`` full-batch Adam updates using the notebook-level ``model``,
    ``optimizer`` and ``dataset``. A step counter is redrawn in place on one
    line, and the scalar loss is printed every ``log_every`` steps (step 0 included).

    Args:
        n_steps: number of optimisation steps (default 5000, the previous
            hard-coded value).
        log_every: print the loss every this many steps; 0 or None disables
            loss printing.
    """
    for step in range(n_steps):
        print(f'\r{step}/{n_steps}', end='   ')
        loss = denoising_score_mathcing(model, dataset)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if log_every and step % log_every == 0:
            # .item() prints the number rather than the tensor repr with grad_fn.
            print(loss.item())

# train()    

In [8]:
def plot_gradients(model, data, plot_scatter=True, lim=(-1.5, 2.0), n_grid=50):
    """Draw the score field of ``model`` as a quiver plot on a square grid.

    Arrow lengths are log-compressed: each score vector s is rescaled to
    (approximately) length log(1 + ||s||), so large scores do not swamp the plot.

    Args:
        model: callable taking a float32 (n, 2) CPU tensor and returning (n, 2) scores.
        data: array-like unpacked along its first axis into x and y for the scatter
            (i.e. shape (2, N)).
        plot_scatter: whether to overlay ``data`` as red points.
        lim: (low, high) extent used for both axes and for the sampling grid.
        n_grid: number of grid points per axis.
    """
    low, high = lim
    axis = np.linspace(low, high, n_grid)
    xx = np.stack(np.meshgrid(axis, axis), axis=-1).reshape(-1, 2)

    # Inference only: skip building an autograd graph, and work in numpy from here on.
    with torch.no_grad():
        scores = model(torch.from_numpy(xx).float()).cpu().numpy()
    scores_norm = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)
    scores_log1p = scores / (scores_norm + 1e-9) * np.log1p(scores_norm)

    plt.figure(figsize=(16, 12))
    if plot_scatter:
        plt.scatter(*data, alpha=0.3, color='red', edgecolors='white', s=40)
    plt.quiver(xx[:, 0], xx[:, 1], scores_log1p[:, 0], scores_log1p[:, 1], width=0.002, color='white')
    plt.xlim(low, high)
    plt.ylim(low, high)

# plot_gradients(model, data)

In [9]:
def sample_langevin(model, x, n_steps=10, eps=1e-3, decay=.9, temperature=1.0):
    """Run Langevin dynamics with a geometrically decaying step size.

    Each step applies

        x <- x + (eps / 2) * model(x) + sqrt(eps) * temperature * z,   z ~ N(0, I)

    and then multiplies ``eps`` by ``decay``.

    Args:
        model: callable returning the score at ``x`` (same shape as ``x``).
        x: starting point(s).
        n_steps: number of Langevin updates.
        eps: initial step size.
        decay: multiplicative step-size decay applied after every update.
        temperature: scale applied to the injected noise.

    Returns:
        Tensor of shape ``(n_steps + 1, *x.shape)``: the starting point followed
        by every iterate.
    """
    x_sequence = [x.unsqueeze(0)]
    for _ in range(n_steps):
        # randn_like draws noise on x's device and dtype; torch.randn(x.size())
        # always used the default device and broke for non-CPU inputs.
        z_t = torch.randn_like(x)
        x = x + (eps / 2) * model(x) + (np.sqrt(eps) * temperature * z_t)
        x_sequence.append(x.unsqueeze(0))
        eps *= decay
    return torch.cat(x_sequence)

In [10]:
def plot_langevin(x=None, arrow_gap=0.04):
    """Overlay a Langevin trajectory of the global ``model`` on its score field.

    Draws the background with ``plot_gradients(model, data)``. Then it plots each
    iterate of ``sample_langevin`` as a green dot with an arrow to the next
    iterate. Each arrow is shortened by ``arrow_gap`` so its head stops short of
    the next dot.

    Args:
        x: starting point; defaults to (1.5, -1.5).
        arrow_gap: length trimmed from the end of each arrow (default 0.04,
            the previous hard-coded value).
    """
    if x is None:
        x = torch.tensor([1.5, -1.5])
    samples = sample_langevin(model, x).detach()

    plot_gradients(model, data)
    plt.scatter(samples[:, 0], samples[:, 1], color='green', edgecolors='white', s=150)

    deltas = samples[1:] - samples[:-1]
    lengths = torch.linalg.vector_norm(deltas, dim=-1, keepdim=True)
    # Trim arrow_gap off every arrow. The clamps make steps shorter than the gap,
    # or of zero length, collapse to nothing instead of flipping direction or
    # producing NaN from a 0/0 division.
    scale = (1 - arrow_gap / lengths.clamp_min(1e-12)).clamp_min(0)
    deltas = deltas * scale

    for (x0, y0), (dx, dy) in zip(samples[:-1].tolist(), deltas.tolist()):
        if dx == 0 and dy == 0:
            continue  # nothing visible to draw for a collapsed arrow
        plt.arrow(x0, y0, dx, dy, width=1e-4, head_width=2e-2, color='green', linewidth=3)