In [7]:
import jax
from jax import jit, vmap, grad, vjp
import jax.random as random
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from diffusionjax.run_lib import get_model, train
from diffusionjax.utils import get_score, get_sampler, get_times
from diffusionjax.solvers import EulerMaruyama, Inpainted, Projected
from diffusionjax.inverse_problems import get_pseudo_inverse_guidance
from diffusionjax.plot import plot_scatter, plot_score, plot_heatmap
import diffusionjax.sde as sde_lib
from absl import app, flags
from ml_collections.config_flags import config_flags
from flax import serialization
import time
import os

%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import Dataset


In [5]:

class CircleDataset(Dataset):
  """Dataset containing samples from the circle."""
  def __init__(self, num_samples):
    self.train_data = self.sample_circle(num_samples)

  def __len__(self):
    return self.train_data.shape[0]

  def __getitem__(self, idx):
    return self.train_data[idx]

  def sample_circle(self, num_samples):
    """Samples from the unit circle, angles split.

    Args:
      num_samples: The number of samples.

    Returns:
      An (num_samples, 2) array of samples.
    """
    alphas = jnp.linspace(0, 2 * jnp.pi * (1 - 1/num_samples), num_samples)
    xs = jnp.cos(alphas)
    ys = jnp.sin(alphas)
    samples = jnp.stack([xs, ys], axis=1)
    return samples

  def metric_names(self):
    return ['mean']

  def calculate_metrics_batch(self, batch):
    return vmap(lambda x: jnp.mean(x, axis=0))(batch)[0, 0]

  def get_data_scaler(self, config):
    def data_scaler(x):
      return x / jnp.sqrt(2)
    return data_scaler

  def get_data_inverse_scaler(self, config):
    def data_inverse_scaler(x):
      return x * jnp.sqrt(2)
    return data_inverse_scaler


In [6]:
num_epochs = 4000
num_samples = 8
circleobj = CircleDataset(num_samples)
samples = circleobj.sample_circle(num_samples)
N = samples.shape[1]
plot_scatter(samples=samples, index=(0, 1), fname="samples", lims=((-3, 3), (-3, 3)))
rng = random.PRNGKey(2023)

In [8]:
from diffusionjax.utils import get_linear_beta_function
beta, log_mean_coeff = get_linear_beta_function(
  beta_min=config.model.beta_min, beta_max=config.model.beta_max)\

sde = sde_lib.VP(beta, log_mean_coeff)
sde = VP()
>>>
>>> def log_hat_pt(x, t):
>>>     """
>>>     Empirical distribution score.
>>>
>>>     Args:
>>>     x: One location in $\mathbb{R}^2$
>>>     t: time
>>>     Returns:
>>>     The empirical log density, as described in the Jupyter notebook
>>>     .. math::
>>>         \hat{p}_{t}(x)
>>>     """
>>>     mean, std = sde.marginal_prob(samples, t)
>>>     potentials = jnp.sum(-(x - mean)**2 / (2 * std**2), axis=1)
>>>     return logsumexp(potentials, axis=0, b=1/num_samples)
>>>
>>> # Get a jax grad function, which can be batched with vmap
>>> nabla_log_hat_pt = jit(vmap(grad(log_hat_pt), in_axes=(0, 0), out_axes=(0)))
>>>
>>> # Running the reverse SDE with the empirical drift
>>> plot_score(score=nabla_log_hat_pt, t=0.01, area_min=-3, area_max=3, fname="empirical score")


NameError: name 'VP' is not defined