In [1]:
%pylab inline
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from diffrax import (
    diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt,
    VirtualBrownianTree, ReversibleHeun, WeaklyDiagonalControlTerm,
)
from scipy.integrate import simps
tfd = tfp.distributions
tfb = tfp.bijectors
import pickle

import haiku as hk
from models import SmallUResNet
from normalization import SNParamsTree

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib




In [2]:
def load_denoising_model(path):
    """Load a pickled denoiser checkpoint.

    The file is expected to hold a 3-element sequence that unpacks as
    ``(params, state, opt_state)``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at checkpoint files you produced yourself.

    Args:
      path: filesystem path of the pickled checkpoint.

    Returns:
      A tuple ``(params, state, opt_state)``.
    """
    with open(path, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    params, state, opt_state = checkpoint
    return params, state, opt_state

In [3]:
def score_fn_denoiser(y, s, model, params, state):
    """Query the denoiser network for its score at input ``y`` and noise level ``s``.

    Args:
      y: network input (in this notebook a single (1, 128, 128, 1) image slice).
      s: noise level(s); must support ``.reshape``. It is reshaped to
        (-1, 1, 1, 1) so it broadcasts against a rank-4 input.
      model: Haiku ``transform_with_state`` object exposing ``apply``.
      params, state: network parameters and state passed to ``model.apply``.

    Returns:
      The first output of ``model.apply`` (the score). The updated state is discarded.
    """
    noise_level = s.reshape((-1, 1, 1, 1))
    outputs = model.apply(params, state, None, y, noise_level, is_training=False)
    score = outputs[0]
    return score

In [4]:
# Trained denoiser checkpoint: a pickled (params, state, opt_state) tuple.
filepath = './models/score_model_0.025/model-2.pckl'
params, state, opt_state = load_denoising_model(filepath)


def _denoiser_forward(x, sigma, is_training=False):
    """Haiku forward function: run a SmallUResNet on input ``x`` at noise level ``sigma``."""
    network = SmallUResNet()
    return network(x, sigma, is_training)


# Wrap the forward pass so it can be called as model.apply(params, state, rng, x, sigma, ...).
model = hk.transform_with_state(_denoiser_forward)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# not used at the moment
# def make_likelihood(sigma=0.):
#     """
#     Returns a Gaussian N(2.5, 0.1^2) convolved with a Gaussian N(0, sigma^2), i.e. N(2.5, 0.1^2 + sigma^2)
#     """
    
#     return tfd.Normal(2.5, jnp.sqrt(0.1**2 +sigma**2) )


In [6]:
def score_fn(t, x, model, params, state, pure_prior=True):
    """Score used in the sampler drift.

    Args:
      t: noise level forwarded to ``score_fn_denoiser`` as its ``s`` argument
        (so it must support ``.reshape``).
      x: current sample.
      model, params, state: Haiku transformed model and its parameters/state.
      pure_prior: if truthy, return the prior score from the denoiser. If falsy,
        return the constant 1. NOTE(review): this looks like a placeholder for a
        likelihood term; confirm before relying on it.

    Returns:
      The denoiser score for ``x`` at noise level ``t``, or the integer 1.
    """
    if not pure_prior:
        return 1
    return score_fn_denoiser(x, t, model, params, state)

# Sampling at fixed time

In [7]:
# Fixed-noise-level sampling: the SDE runs backwards from t0 to t1, but the
# score network is always queried at the constant noise level s_init.
t0, t1 = 2., 0.
s_init = jnp.ones((1))

# The drift ignores t; it only depends on the current sample y.
drift = lambda t, y, args: - 0.5 * score_fn(s_init, y, model, params, state)
diffusion = lambda t, y, args: jnp.ones_like(y)
solver = Euler()

@jax.jit
@jax.vmap
def get_samples(y, seed):
    """Integrate the SDE from t0 back to t1 for one starting sample.

    The function is vmapped, so at the call site ``y`` and ``seed`` must carry
    leading batch axes of the same size.

    Args:
      y: a single starting sample (one slice of the batch at the call site).
      seed: PRNG key driving this sample's Brownian path.

    Returns:
      The solution at t1, with the same shape as ``y``.
    """
    # diffrax queries the Brownian path over increasing time intervals even when
    # solving backwards (t0 > t1), so the tree is built over [min, max].
    brownian_motion = VirtualBrownianTree(
        min(t0, t1), max(t0, t1), tol=1e-6, shape=y.shape, key=seed
    )
    # Diagonal noise: each element of y gets its own Brownian increment. A scalar
    # (shape=()) Brownian motion under ControlTerm would add one shared increment
    # to every pixel.
    terms = MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))

    return diffeqsolve(terms, solver, t0, t1, dt0=-0.001, y0=y, max_steps=10_000).ys[0]

In [None]:
n_samples = 10

# Starting points: i.i.d. standard-normal images, one (1, 128, 128, 1) slice per sample.
noise_key = jax.random.PRNGKey(13)
initial_samples = jax.random.normal(key=noise_key, shape=(n_samples, 1, 128, 128, 1))

# One PRNG key per sample, so each trajectory gets its own Brownian path.
sde_keys = jax.random.split(jax.random.PRNGKey(3), n_samples)
res = get_samples(initial_samples, sde_keys)

In [None]:
# Annealed sampling: the noise level fed to the score network now depends on time.
t0, t1 = 5., 0.
# sigma = t / t0 decreases from 1 (at t0) to 0 (at t1) during the backward solve.
drift = lambda t, y, args: - 0.5 * score_fn(t / t0, y, model, params, state)
diffusion = lambda t, y, args: jnp.ones_like(y)
solver = Euler()
# Save the trajectory at jnp.linspace's default of 50 evenly spaced times.
saveat = SaveAt(ts=jnp.linspace(t0, t1))

# NOTE: this redefines get_samples from the fixed-time cell; later cells use this version.
@jax.jit
@jax.vmap
def get_samples(y, seed):
    """Integrate the annealed SDE from t0 back to t1 for one starting sample.

    The function is vmapped, so at the call site ``y`` and ``seed`` must carry
    leading batch axes of the same size.

    Args:
      y: a single starting sample (one slice of the batch at the call site).
      seed: PRNG key driving this sample's Brownian path.

    Returns:
      The solution at every time in ``saveat.ts``, stacked along a new leading axis.
    """
    # diffrax queries the Brownian path over increasing time intervals even when
    # solving backwards (t0 > t1), so the tree is built over [min, max].
    brownian_motion = VirtualBrownianTree(
        min(t0, t1), max(t0, t1), tol=1e-4, shape=y.shape, key=seed
    )
    # Diagonal noise: each element of y gets its own Brownian increment.
    terms = MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))
    return diffeqsolve(terms, solver, t0, t1, dt0=-0.001, y0=y, max_steps=10_000, saveat=saveat).ys

In [None]:
# Keep the fixed-time samples as starting points. JAX arrays are immutable, so
# rebinding `res` below cannot alter `ref_samples` and no copy is needed.
ref_samples = res

# get_samples is vmapped over (y, seed), so it needs exactly one key per starting sample.
res = get_samples(
    ref_samples,
    jax.random.split(jax.random.PRNGKey(8), ref_samples.shape[0])
)