In [1]:
import tensorflow as tf
from losses import TrainState
from flax.training import checkpoints
from flax import jax_utils as flax_utils
from matplotlib import pyplot as plt
import jax
from jax import numpy as jnp

from models import super_simple
from models import utils as mutils
import sampling
import losses

from configs.vp.disk_ssim_continuous import get_config
import sde_lib
import datasets

from matplotlib import pyplot as plt
from matplotlib.widgets import Slider
from sklearn.datasets import make_swiss_roll
import os

2025-01-31 17:12:11.224391: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Select the Qt backend so figures open in separate interactive windows; the
# Slider widgets created in the plotting cells below rely on an interactive backend.
%matplotlib qt

### Load State

In [4]:
# Build a freshly initialised TrainState to act as the restore target, then
# replace its contents with the trained weights from the latest checkpoint.
rng = jax.random.PRNGKey(42)
model_rng, sampling_rng = jax.random.split(rng)
config = get_config()
score_model, init_model_state, init_model_params = mutils.init_model(model_rng, config)
optimizer = losses.get_optimizer(config)

state = TrainState.create(tx=optimizer, apply_fn=score_model.apply,
                          params=init_model_params, mutable_state=init_model_state,
                          rng=rng)

# The working directory can be overridden without editing the notebook by
# setting SCORE_SDE_WORKDIR; the default is the original hard-coded location.
workdir = os.environ.get('SCORE_SDE_WORKDIR', '/home/komodo/Documents/uni/thesis/score_sde')
checkpoint_dir = os.path.join(workdir, 'checkpoints')

# restore_checkpoint hands back the target unchanged when there is nothing to
# load, which would silently leave the notebook running on untrained weights.
if checkpoints.latest_checkpoint(checkpoint_dir) is None:
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}.")
ckpt_model_state = checkpoints.restore_checkpoint(checkpoint_dir, state)



### Setup SDE

In [5]:
# Instantiate the SDE named in the config. Each family also fixes the
# `sampling_eps` value that is later handed to the sampler.
sde_name = config.training.sde.lower()
if sde_name not in ('vpsde', 'subvpsde', 'vesde'):
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

if sde_name == 'vesde':
    sde = sde_lib.VESDE(
        sigma_min=config.model.sigma_min,
        sigma_max=config.model.sigma_max,
        N=config.model.num_scales,
    )
    sampling_eps = 1e-5
else:
    # Both variance-preserving variants take the same constructor arguments.
    sde_cls = sde_lib.VPSDE if sde_name == 'vpsde' else sde_lib.subVPSDE
    sde = sde_cls(
        beta_min=config.model.beta_min,
        beta_max=config.model.beta_max,
        N=config.model.num_scales,
    )
    sampling_eps = 1e-3

inverse_scaler = datasets.get_data_inverse_scaler(config)

# The global batch is divided evenly across the local devices.
per_device_batch = config.training.batch_size // jax.local_device_count()
sampling_shape = (
    per_device_batch,
    config.data.image_size,
    config.data.image_size,
    config.data.num_channels,
)

### Sample Points
Draws `num_rng` (16) batches of `batch_size` points at each of the `num_sample_steps` (50) time steps.

In [6]:
# Number of PRNG keys (per local device) that are each turned into one batch of samples.
num_rng = 16
# Number of evenly spaced values in [0, sde.N] at which samples are collected.
num_sample_steps = 50

sampling_fn = sampling.get_sampling_fn(config, sde, score_model, sampling_shape, inverse_scaler, sampling_eps)
pstate = flax_utils.replicate(ckpt_model_state)
num_train_steps = config.training.n_iters

# Fold the process index into the key so that each host draws different noise.
# jax.process_index() is the current name of the deprecated jax.host_id().
rng = jax.random.fold_in(rng, jax.process_index())
# Advance the key once and discard the extra sub-keys; this keeps the random
# stream identical to previous runs of this notebook.
rng = jax.random.split(rng, num=jax.local_device_count() + 1)[0]

# Inner vmap: one sampler call per PRNG key (axis 0 of the key array).
# Outer vmap: repeat for every entry of `ns`, stacking the results on axis 1.
rng_vmap = jax.vmap(lambda r, t: sampling_fn(r, pstate, jnp.array([t])), in_axes=(0, None), out_axes=0)
both_vmap = jax.vmap(rng_vmap, in_axes=(None, 0), out_axes=1)

rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() * num_rng + 1)
# Keep a singleton axis in front of every key.
# NOTE(review): presumably the device axis expected by sampling_fn — confirm.
next_rng = jnp.asarray(next_rng)[:, None, :]

# Integer values from 0 to sde.N (inclusive) passed to the sampler as its third argument.
ns = jnp.floor(jnp.linspace(0, sde.N, num_sample_steps)).astype(jnp.int32)

# Only the first of the two sampler outputs is used below.
samples, _second_output = both_vmap(next_rng, ns)
feature_dim = samples.shape[-1]
samples_by_key = samples.reshape((num_rng, num_sample_steps, config.training.batch_size, feature_dim))
# Put the time axis first and merge (key, batch) into a single axis of points.
temp = samples_by_key.transpose((1, 0, 2, 3)).reshape((num_sample_steps, -1, feature_dim))
xs = temp[:, :, 0]
ys = temp[:, :, 2 if config.data.dataset == 'SWIRL' else 1]
# Mean Euclidean norm of the points at every time step (shown as the cell output).
jnp.mean(jnp.sqrt(jnp.sum(temp**2, axis=-1)), axis=1)



Array([1.227253  , 1.1333348 , 1.1525204 , 1.1521152 , 1.1544287 ,
       1.1699781 , 1.1798029 , 1.2024733 , 1.184582  , 1.18432   ,
       1.180551  , 1.1799562 , 1.1803519 , 1.1787719 , 1.1718082 ,
       1.1735826 , 1.1826493 , 1.1861877 , 1.1916692 , 1.195907  ,
       1.1885736 , 1.1740906 , 1.1756004 , 1.1647692 , 1.1478615 ,
       1.1424778 , 1.1340883 , 1.1232352 , 1.111315  , 1.103612  ,
       1.1003463 , 1.0780939 , 1.0647472 , 1.0348638 , 1.0045075 ,
       0.97948414, 0.9696005 , 0.9378516 , 0.9058557 , 0.8757992 ,
       0.85386527, 0.8176763 , 0.7860434 , 0.75315857, 0.7263725 ,
       0.70824707, 0.69089925, 0.6772013 , 0.6685418 , 0.66535664],      dtype=float32)

### Visualise Data

Prepare ground truth

In [7]:
# Collect the 'image' entry of the first `num_rng` elements of the test split
# and reshape them to the layout of a single time slice of `temp`.
_, test_ds, _ = datasets.get_dataset(config)
gt_batches = [batch['image'] for batch in test_ds.take(num_rng)]
gt = tf.stack(gt_batches).numpy().reshape(temp.shape[1:])

Visualise ground truth

In [8]:
# Scatter the ground-truth points together with the unit circle for reference.
fig, ax = plt.subplots(figsize=(5, 5))

circle = plt.Circle((0, 0), 1, color='r', fill=False)
ax.add_patch(circle)

# Pick the vertical coordinate the same way the sampling cell does for `ys`,
# so ground truth and generated samples show the same projection.
gt_y_col = 2 if config.data.dataset == 'SWIRL' else 1
ax.plot(gt[:, 0], gt[:, gt_y_col], 'r+')

ax.grid(True)

plt.show()

Visualise generated samples (drag the slider to change the time step)

In [52]:
# Interactive scatter of the generated points; the slider selects which entry
# of `ns` is displayed. All handles use cell-specific names so that later
# plotting cells cannot rebind them: the callback must keep drawing on *this*
# figure, and the Slider needs a live reference to stay responsive.
samples_fig, samples_ax = plt.subplots(figsize=(10, 10))

# Plot limits; also reused by the score-field cells further down.
left = -2.
right = 2.

samples_ax.set_xlim(left, right)
samples_ax.set_ylim(left, right)

circle = plt.Circle((0, 0), 1, color='r', fill=False)
samples_ax.add_patch(circle)

# Start with the last time slice, matching the slider's initial value.
samples_line, = samples_ax.plot(xs[-1], ys[-1], 'r+')

samples_slider_ax = samples_fig.add_axes([0.01, 0.2, 0.0225, 0.63])
samples_slider = Slider(
    ax=samples_slider_ax,
    label="T",
    valmin=0,
    valmax=sde.N,
    valinit=sde.N,
    orientation="vertical",
    valstep=ns
)


def update_samples(val):
    """Redraw the scatter with the points of the time slice closest to `val`."""
    # Nearest-match lookup gives a scalar index even if `val` is not an exact
    # member of `ns`, so the line always receives 1-D data.
    step_idx = int(jnp.argmin(jnp.abs(ns - val)))
    samples_line.set_xdata(xs[step_idx])
    samples_line.set_ydata(ys[step_idx])
    samples_fig.canvas.draw_idle()


# Register the callback that runs whenever the slider value changes.
samples_slider.on_changed(update_samples)

samples_ax.grid(True)

plt.show()

### Visualising Gradient Field

In [48]:
# Evaluate the model on a regular N x N grid of 2-D points for every selected time.
N = 64
ticks = jnp.linspace(left, right, N)
coords = jnp.stack(jnp.meshgrid(ticks, ticks), axis=-1).reshape(-1, 2)

ckpt = ckpt_model_state
# One model call per time value; the time is broadcast to every grid point.
# NOTE(review): the t**0.75 factor looks like an ad-hoc rescaling of the
# output for display purposes — confirm.
time_vmap = jax.vmap(lambda t: ckpt.apply_fn({'params': ckpt.params}, coords, jnp.tile(t, coords.shape[0]),
                                             train=False, mutable=False) * t**0.75)

# `ns` contains the value sde.N, one past the last valid index of the
# sde.N-long time grid. Clip explicitly rather than relying on JAX silently
# clamping out-of-range indices, and gather all times in one indexing op.
time_grid = jnp.linspace(sde.T, sampling_eps, sde.N)
timesteps = time_grid[jnp.clip(ns, 0, sde.N - 1)]

# NOTE(review): the sign flip presumably converts the raw model output into
# the plotted score direction — confirm against the model's parameterisation.
scores = -1. * time_vmap(timesteps).squeeze()
scores.shape

(50, 4096, 2)

In [49]:
# Interactive quiver plot of the precomputed `scores`; the slider selects the
# time slice. Handles and the callback use cell-specific names so they do not
# shadow the figure, slider and `update` callback of the samples plot above.
scores_fig, scores_ax = plt.subplots(figsize=(10, 10))

scores_ax.set_xlim(left, right)
scores_ax.set_ylim(left, right)

circle = plt.Circle((0, 0), 1, color='r', fill=False)
scores_ax.add_patch(circle)

# Start with the last time slice, matching the slider's initial value.
quiver = scores_ax.quiver(coords[:, 0], coords[:, 1], scores[-1, :, 0], scores[-1, :, 1])

scores_slider_ax = scores_fig.add_axes([0.01, 0.2, 0.0225, 0.63])
scores_slider = Slider(
    ax=scores_slider_ax,
    label="T",
    valmin=0,
    valmax=sde.N,
    valinit=sde.N,
    orientation="vertical",
    valstep=ns
)


def update_scores(val):
    """Redraw the arrows with the score field of the time slice closest to `val`."""
    # A scalar index keeps U and V one-dimensional, one entry per grid point.
    step_idx = int(jnp.argmin(jnp.abs(ns - val)))
    quiver.set_UVC(scores[step_idx, :, 0], scores[step_idx, :, 1])
    scores_fig.canvas.draw_idle()


# Register the callback that runs whenever the slider value changes.
scores_slider.on_changed(update_scores)

scores_ax.grid(True)

plt.show()
