In [34]:
import sys
sys.path.insert(0, '..')
import torch
from scipy import integrate
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from tqdm import tqdm
import threading
import functools

# Value bound to the `sigma` keyword of both SDE helper functions below.
sigma = 2

from utils.kfp import diffusion_coeff, marginal_prob_std
from network.network import ScoreNet

# Versions of the SDE helpers with `sigma` already fixed, so the rest of the
# notebook can call them with just the time argument.
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)


# Default absolute/relative tolerance handed to the ODE solver.
# NOTE(review): assigned again, with the same value, in the cell that defines ode_sampler.
error_tolerance = 1e-5

# Number of image channels; one checkpoint and one sampling thread is used per channel.
channels = 3
# Batch size that diffuse_sample requests from the sampler.
N = 20
# Height and width of the noise tensor drawn by the sampler.
H = 28
W = 28

In [35]:
## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5
def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                device='cpu',
                z=None,
                eps=1e-3):
  """Integrate the probability-flow ODE and return a strip of snapshot frames.

  `batch_size` trajectories are started from Gaussian noise, each scaled by the
  perturbation std at its own time on an even grid from 1 down to `eps`
  (trajectory 0 uses t=1, the last one uses t=eps). All trajectories are then
  integrated together from t=1 to t=eps with SciPy's RK45 solver, and one
  snapshot per trajectory is picked out of the solver output.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based
      model. Called as `score_model(x, t)` with float32 tensors.
    marginal_prob_std: A function that returns the standard deviation
      of the perturbation kernel for a tensor of times.
    diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of trajectories (and starting noise levels) to simulate.
    atol: Tolerance of absolute errors.
    rtol: Tolerance of relative errors.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    z: Optional offset added to the scaled noise. If None, the trajectories
      start from the scaled noise alone.
    eps: The smallest time step for numerical stability.

  Returns:
    A NumPy array with `batch_size + 1` frames stacked along axis 0: trajectory 0
    at the first solver time, then trajectory `i` at solver step
    `int(i + n_steps / batch_size)` for `i = 0 .. batch_size - 2`, and finally
    the last trajectory at the last solver time.

  Raises:
    ValueError: If `batch_size` is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError(f"batch_size must be at least 1, got {batch_size}")

  # One starting noise level per trajectory. The grid is sized by `batch_size`
  # (rather than a notebook-level constant) so it always lines up with the
  # leading axis of the noise tensor, and it is created on `device` so the
  # product below does not mix devices.
  t = torch.tensor(np.linspace(1., eps, batch_size), device=device)

  # Create the latent code
  noise = torch.randn(batch_size, 1, H, W, device=device) \
    * marginal_prob_std(t)[:, None, None, None]
  initial_x = noise if z is None else z + noise

  shape = initial_x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper of the score-based model for use by the ODE solver."""
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
    with torch.no_grad():
      score = score_model(sample, time_steps)
    # The solver works on flat float64 vectors.
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)

  def ode_func(t, x):
    """The ODE function for use by the ODE solver."""
    time_steps = np.ones((shape[0],)) * t
    g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
    return  -0.5 * (g**2) * score_eval_wrapper(x, time_steps)

  # Run the black-box ODE solver.
  res = integrate.solve_ivp(ode_func, (1., eps), initial_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
  print(f"\nNumber of function evaluations: {res.nfev}")

  n_traj = shape[0]
  n_steps = len(res.t)
  frames = [res.y[:, 0].reshape(shape)[0][None]]
  for i in range(n_traj - 1):
    # Clamp so that a solver run with few accepted steps cannot index past the
    # last column of res.y.
    # NOTE(review): `i + n_steps / n_traj` only walks the first few solver steps;
    # if evenly spaced snapshots were intended this may need to be
    # `i * n_steps / n_traj` -- confirm before changing.
    idx = min(int(i + (n_steps / n_traj)), n_steps - 1)
    frames.append(res.y[:, idx].reshape(shape)[i][None])
  frames.append(res.y[:, -1].reshape(shape)[-1][None])
  return np.concatenate(frames, axis = 0)

In [36]:
# function for sampling on a thread
def diffuse_sample(channel, samples, diffusion_coeff, marginal_prob_std):
  """Load one channel's score model, run the ODE sampler and store the result.

  Args:
    channel: Index used to pick the checkpoint file
      `../model_cifar_thread_{channel}.pth`.
    samples: A list; the sampler output for this channel is appended to it.
    diffusion_coeff: Function returning the diffusion coefficient of the SDE.
    marginal_prob_std: Function returning the std of the perturbation kernel;
      also passed to the ScoreNet constructor.
  """
  model_score = ScoreNet(marginal_prob_std=marginal_prob_std)
  file = f'../model_cifar_thread_{channel}.pth'
  # The sampler below runs with its default device ('cpu'), so map the stored
  # tensors to the CPU; otherwise a checkpoint saved from a GPU cannot be
  # loaded on a machine without CUDA.
  # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
  ckpt = torch.load(file, map_location='cpu')
  model_score.load_state_dict(ckpt)
  model_score.eval()

  sample_batch_size = N
  sampler = ode_sampler

  # Generate samples using the specified sampler.
  output = sampler(model_score,
                  marginal_prob_std,
                  diffusion_coeff,
                  sample_batch_size)

  samples.append(output)

In [49]:
#@title Sample each channel on a thread
# Every worker appends to its own list, so the channel order of the assembled
# image is fixed by `ch` instead of by whichever thread happens to finish first.
channel_outputs = [[] for _ in range(channels)]
threads = []

# diffuse all three channels concurrently
for ch in range(channels):
  worker = threading.Thread(target=diffuse_sample, args=[ch, channel_outputs[ch], diffusion_coeff_fn, marginal_prob_std_fn])
  worker.start()
  threads.append(worker)

for thread in threads:
  thread.join()

# A worker that raised leaves its list empty; fail loudly instead of silently
# building an image with missing channels.
missing = [ch for ch, outputs in enumerate(channel_outputs) if not outputs]
if missing:
  raise RuntimeError(f'no samples were produced for channel(s) {missing}')

# Stack the per-channel results along the channel axis, in channel order.
samples = np.concatenate([outputs[0] for outputs in channel_outputs], axis = 1)
for i in range(samples.shape[0]):
  print(f'{samples[i].mean(), samples[i].std()}')
  # Min-max scale each frame and move the channel axis last for imshow.
  plt.imshow(((samples[i] - samples[i].min())/(samples[i].max() - samples[i].min())).transpose(1, 2, 0))
  plt.show()




Number of function evaluations: 488

Number of function evaluations: 488

Number of function evaluations: 482


In [None]:
def get_frame(i):
  """Return frame `i` of `samples`, min-max scaled, with its leading axis moved last."""
  image = samples[i]
  lo = image.min()
  hi = image.max()
  scaled = (image - lo) / (hi - lo)
  return scaled.transpose(1, 2, 0)

fig = plt.figure(figsize=(8,8));

im = plt.imshow(get_frame(0), animated=True)

def frame(i):
  """Draw frame `i` into the existing image artist and return it for the animation."""
  im.set_data(get_frame(i));

  return [im]

# Frames are indexed along axis 0 of `samples`; the last axis is the image
# width, which would run past the available frames.
animation_fig = animation.FuncAnimation(fig, frame, frames=samples.shape[0], interval=10, blit=False,repeat_delay=2);

animation_fig.save("animation.gif");