In [35]:
import tensorflow as tf
from losses import TrainState
from flax.training import checkpoints
from flax import jax_utils as flax_utils
from matplotlib import pyplot as plt
import jax
from jax import numpy as jnp

from models import super_simple
from models import utils as mutils
import sampling
import losses

from configs.vp.disk_inner_ssim_continuous import get_config
import sde_lib
import datasets

from matplotlib import pyplot as plt
from matplotlib.widgets import Slider
import os

In [36]:
# Select matplotlib's Qt backend so figures open in separate interactive windows;
# the Slider widget used in the final plot of this notebook needs an interactive backend.
%matplotlib qt

### Load State

In [46]:
# Fixed seed so that model initialisation (and everything derived from `rng`) is repeatable.
rng = jax.random.PRNGKey(42)
model_rng, sampling_rng = jax.random.split(rng)
config = get_config()
score_model, init_model_state, init_model_params = mutils.init_model(model_rng, config)
optimizer = losses.get_optimizer(config)

# Freshly initialised train state; it is only used as the template that the
# checkpoint is restored into.
state = TrainState.create(tx=optimizer, apply_fn=score_model.apply,
                          params=init_model_params, mutable_state=init_model_state,
                          rng=rng)

# Directory that contains the `checkpoints/` folder. Set SCORE_SDE_WORKDIR to run the
# notebook on another machine; the previous hard-coded location remains the default.
workdir = os.environ.get('SCORE_SDE_WORKDIR', '/home/komodo/Documents/uni/thesis/score_sde')
ckpt_dir = os.path.join(workdir, 'checkpoints')
ckpt_model_state = checkpoints.restore_checkpoint(ckpt_dir, state)

# restore_checkpoint returns the passed-in target itself when it finds no checkpoint.
# Fail loudly instead of silently visualising an untrained model.
if ckpt_model_state is state:
    raise FileNotFoundError(f"No checkpoint found in {ckpt_dir!r}; set SCORE_SDE_WORKDIR "
                            "to the training work directory.")



### Setup SDE

In [50]:
# Instantiate the SDE named in the config, together with the sampling epsilon that
# goes with it (1e-3 for the VP variants, 1e-5 for VE).
_sde_kind = config.training.sde.lower()
if _sde_kind == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max,
                        N=config.model.num_scales)
    sampling_eps = 1e-5
elif _sde_kind in ('vpsde', 'subvpsde'):
    # Both VP flavours take the same constructor arguments.
    _vp_cls = sde_lib.VPSDE if _sde_kind == 'vpsde' else sde_lib.subVPSDE
    sde = _vp_cls(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                  N=config.model.num_scales)
    sampling_eps = 1e-3
else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

inverse_scaler = datasets.get_data_inverse_scaler(config)

# Per-device batch of square samples: (batch, size, size, channels).
_per_device_batch = config.training.batch_size // jax.local_device_count()
sampling_shape = (
    _per_device_batch,
    config.data.image_size,
    config.data.image_size,
    config.data.num_channels,
)

### Sample Points
Draws `num_rng` × `batch_size` (16 × batch size) samples at each of the `num_sample_steps` = 50 evaluated time steps.

In [51]:
num_rng = 16           # number of independent RNG keys (sampling runs) per local device
num_sample_steps = 50  # number of step indices at which samples are collected

sampling_fn = sampling.get_sampling_fn(config, sde, score_model, sampling_shape, inverse_scaler, sampling_eps)
pstate = flax_utils.replicate(ckpt_model_state)
num_train_steps = config.training.n_iters

# Inner vmap: one sampler call per RNG key at a fixed step index (key axis -> output axis 0).
# Outer vmap: repeat that for every step index (step axis -> output axis 1).
rng_vmap = jax.vmap(lambda r, t: sampling_fn(r, pstate, jnp.array([t])), in_axes=(0, None), out_axes=0)
both_vmap = jax.vmap(rng_vmap, in_axes=(None, 0), out_axes=1)

# `jax.process_index()` replaces the deprecated alias `jax.host_id()`.
# NOTE(review): `rng` is reassigned here, so re-running this cell without re-running the
# "Load State" cell draws a different set of samples.
rng = jax.random.fold_in(rng, jax.process_index())
keys = jax.random.split(rng, num=jax.local_device_count() * num_rng + 1)
rng = keys[0]
# Extra singleton axis after the key axis — presumably the per-device axis the sampler
# expects; confirm against sampling.get_sampling_fn.
next_rng = keys[1:, None, :]

# Step indices spread evenly over [0, sde.N] (both end points included).
ns = jnp.floor(jnp.linspace(0, sde.N, num_sample_steps)).astype(jnp.int32)

a, b = both_vmap(next_rng, ns)
a_sq = a.reshape((num_rng, num_sample_steps, config.training.batch_size, a.shape[-1]))
# Put the step axis first and merge the RNG and batch axes: (steps, num_rng * batch, dim).
temp = a_sq.transpose((1, 0, 2, 3)).reshape((num_sample_steps, -1, a.shape[-1]))
xs = temp[:, :, 0]   # first coordinate of every sample
ys = temp[:, :, -1]  # last coordinate of every sample
# Mean squared norm of the samples at each collected step.
jnp.mean(jnp.sum(temp**2, axis=-1), axis=1)



Array([2.016763  , 1.7723984 , 1.8877945 , 1.9363682 , 1.950089  ,
       2.020043  , 2.0649948 , 2.0568416 , 2.0921037 , 2.0961423 ,
       2.1371517 , 2.1302905 , 2.0573492 , 2.069267  , 2.0702255 ,
       2.0944643 , 2.0210187 , 2.007275  , 1.9687707 , 1.948046  ,
       1.9317667 , 1.9014664 , 1.8822906 , 1.8842808 , 1.8674684 ,
       1.8818185 , 1.8182663 , 1.801059  , 1.7159332 , 1.6768332 ,
       1.6334797 , 1.5566115 , 1.496969  , 1.4129115 , 1.332406  ,
       1.2440677 , 1.1406351 , 1.0438135 , 0.94005775, 0.842557  ,
       0.75944215, 0.6677152 , 0.58697784, 0.5152065 , 0.43940556,
       0.37313795, 0.32205784, 0.28674954, 0.26061177, 0.2581532 ],      dtype=float32)

### Visualise Data

Prepare ground truth

In [52]:
# Only the second dataset returned by get_dataset (bound to test_ds) is used here.
_, test_ds, _ = datasets.get_dataset(config)
# Take `num_rng` elements and lay them out like one time step of the generated samples.
gt_images = [sample['image'] for sample in test_ds.take(num_rng)]
gt = tf.stack(gt_images).numpy().reshape(temp.shape[1:])

Visualise ground truth

In [61]:
fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

# Symmetric axis range: 20% beyond the largest absolute coordinate of the last sampled step.
limits = 1.2 * jnp.max(jnp.abs(temp[-1]))
left, right = -limits, limits

# Unit circle around the origin as a visual reference.
circle = plt.Circle((0, 0), 1, color='r', fill=False)
ax.add_patch(circle)

# Ground-truth points: first coordinate against last coordinate.
gt_x = gt[:, 0]
gt_y = gt[:, -1]
line, = ax.plot(gt_x, gt_y, 'r+')

ax.set_xlim(left, right)
ax.set_ylim(left, right)
ax.grid(True)

plt.show()

Compute Scores

In [62]:
N = 64  # grid resolution per axis; the model is evaluated on N * N points
ticks = jnp.linspace(left, right, N)
coords = jnp.stack(jnp.meshgrid(ticks, ticks), axis=-1).reshape(-1, gt.shape[-1])

ckpt = ckpt_model_state
# For one time t: evaluate the model on every grid point (t is broadcast to all N**2 points).
# The output is multiplied by t**0.75, so arrow lengths are rescaled per time step.
time_vmap = jax.vmap(lambda t: ckpt.apply_fn({'params': ckpt.params}, coords, jnp.tile(t, N**2),
                                             train=False, mutable=False) * t**0.75)

# Time grid from sde.T down to sampling_eps with sde.N entries (valid indices 0 .. sde.N - 1).
timesteps = jnp.linspace(sde.T, sampling_eps, sde.N)
# `ns` ends at sde.N, one past the last valid index. Clip explicitly instead of relying on
# JAX's implicit index clamping, and gather all entries in one vectorised lookup.
timesteps = timesteps[jnp.clip(ns, 0, sde.N - 1)]

# Negated model output, one field per selected time step.
scores = -1. * time_vmap(timesteps).squeeze()
scores.shape

(50, 4096, 2)

### Plot Data and Scores

Note that the score magnitudes are not drawn to scale: each score field is multiplied by $t^{0.75}$ before plotting.

In [63]:
fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ax.set_xlim(left, right)
ax.set_ylim(left, right)

# Unit circle around the origin as a visual reference.
circle = plt.Circle((0, 0), 1, color='r', fill=False)
ax.add_patch(circle)

# Initial view: the last collected step (matches the slider's initial value below).
last_scores = scores[-1]
quiver = ax.quiver(coords[:, 0], coords[:, 1], last_scores[:, 0], last_scores[:, 1])
line, = ax.plot(xs[-1], ys[-1], 'rx', alpha=0.7)

# Narrow vertical axes on the left edge of the figure that host the slider.
axamp = fig.add_axes([0.01, 0.2, 0.0225, 0.63])
# The slider can only take the step indices stored in `ns`.
t_slider = Slider(
    ax=axamp,
    label="T",
    orientation="vertical",
    valmin=0,
    valmax=sde.N,
    valinit=sde.N,
    valstep=ns,
)

# The function to be called anytime a slider's value changes
def update(val):
    """Redraw the score field and the sample points for the slider value `val`.

    `val` is matched to the closest entry of `ns`; the position of that entry selects
    which precomputed frame of `scores`, `xs` and `ys` is shown.
    """
    # Scalar nearest-match index: always exactly one frame, and the artists receive
    # 1-D data rather than arrays with an extra leading axis.
    ind = int(jnp.argmin(jnp.abs(ns - val)))
    quiver.set_UVC(scores[ind, :, 0], scores[ind, :, 1])
    line.set_xdata(xs[ind])
    line.set_ydata(ys[ind])
    fig.canvas.draw_idle()


# Register the update callback on the time-step slider (this figure has a single slider).
t_slider.on_changed(update)

# Show grid lines on the main axes.
ax.grid(True)
