In [1]:
import tensorflow as tf
from losses import TrainState
from flax.training import checkpoints
from flax import jax_utils as flax_utils
from matplotlib import pyplot as plt
import jax
from jax import numpy as jnp

from models import super_simple
from models import utils as mutils
import sampling
import losses

from configs.vp.swirl_ssim_continuous import get_config
import sde_lib
import datasets

from matplotlib import pyplot as plt
from matplotlib.widgets import Slider
from sklearn.datasets import make_swiss_roll
import os

2025-01-31 11:56:48.673339: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Select the Qt GUI backend so figures open in their own windows.
# NOTE(review): the Slider widget used in a later plotting cell needs an
# interactive backend like this one to respond to input.
%matplotlib qt

In [4]:
# Configuration object from configs.vp.swirl_ssim_continuous.
config = get_config()

# Root PRNG key with a fixed seed, split into one key for model
# initialisation and one spare key.
# NOTE(review): `rng` itself is still used by later cells after this split;
# JAX recommends not reusing a key once it has been split - confirm intended.
rng = jax.random.PRNGKey(42)
model_rng, sampling_rng = jax.random.split(rng, 2)

# Build the score network together with its initial mutable state and
# parameters.
score_model, init_model_state, init_model_params = mutils.init_model(
    model_rng, config
)

In [5]:
# Optimizer built from the config; it becomes the `tx` of the train state.
optimizer = losses.get_optimizer(config)

# Freshly initialised train state. It serves as the template that the
# checkpoint is restored into.
state = TrainState.create(
    apply_fn=score_model.apply,
    params=init_model_params,
    mutable_state=init_model_state,
    tx=optimizer,
    rng=rng,
)

In [6]:
# Root directory of the training run. Override it with the SCORE_SDE_WORKDIR
# environment variable; the fallback is the original author's local path.
workdir = os.environ.get('SCORE_SDE_WORKDIR',
                         '/home/komodo/Documents/uni/thesis/score_sde')
ckpt_dir = os.path.join(workdir, 'checkpoints')

# restore_checkpoint hands back `state` unchanged when it finds no checkpoint,
# which would silently sample from an untrained model. Fail loudly instead.
if checkpoints.latest_checkpoint(ckpt_dir) is None:
    raise FileNotFoundError(f"No checkpoint found in {ckpt_dir!r}")

# Restore the latest checkpoint into the structure of `state`.
ckpt_model_state = checkpoints.restore_checkpoint(ckpt_dir, state)



In [7]:
# Pick the SDE class named in the config (case-insensitive) together with the
# eps value that is later passed to the sampler factory.
sde_name = config.training.sde.lower()

if sde_name in ('vpsde', 'subvpsde'):
    # Both VP variants take the same beta schedule arguments.
    sde_cls = sde_lib.VPSDE if sde_name == 'vpsde' else sde_lib.subVPSDE
    sde = sde_cls(beta_min=config.model.beta_min,
                  beta_max=config.model.beta_max,
                  N=config.model.num_scales)
    sampling_eps = 1e-3
elif sde_name == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                        sigma_max=config.model.sigma_max,
                        N=config.model.num_scales)
    sampling_eps = 1e-5
else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

# Inverse of the data scaling, obtained from the dataset helpers.
inverse_scaler = datasets.get_data_inverse_scaler(config)

# Shape of one per-device batch of samples: the training batch size divided
# evenly across local devices, followed by the configured data dimensions.
per_device_batch = config.training.batch_size // jax.local_device_count()
sampling_shape = (
    per_device_batch,
    config.data.image_size,
    config.data.image_size,
    config.data.num_channels,
)

In [8]:
# Sampler built from the config, the chosen SDE, the score model, the
# per-device sample shape, the inverse data scaler and the eps value.
sampling_fn = sampling.get_sampling_fn(
    config,
    sde,
    score_model,
    sampling_shape,
    inverse_scaler,
    sampling_eps,
)

In [25]:
# Replicate the restored state with flax's jax_utils helper before handing it
# to the sampler.
pstate = flax_utils.replicate(ckpt_model_state)
num_train_steps = config.training.n_iters

# Fold the process index into the key so that, on multi-host setups, every
# host draws from a different random stream. jax.process_index() is the
# replacement for the deprecated jax.host_id().
rng = jax.random.fold_in(rng, jax.process_index())

# One sampling key per local device, plus a fresh `rng` for later cells.
rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
next_rng = jnp.asarray(next_rng)

# Single sampler call with sde.N as the third argument; the results are
# overwritten by the batched run in a later cell.
a, b = sampling_fn(next_rng, pstate, jnp.array([sde.N]))



In [50]:
NUM_CHAINS = 16  # independent key sets, i.e. sampler runs per time value
NUM_TIMES = 50   # number of time values passed to the sampler

# Inner map: one sampler call per key set (axis 0 of the keys); the time
# value is shared by all of them.
rng_vmap = jax.vmap(lambda r, t: sampling_fn(r, pstate, jnp.array([t])),
                    in_axes=(0, None), out_axes=0)
# Outer map: repeat the inner map for every time value, reusing the same keys,
# and stack the results along axis 1.
both_vmap = jax.vmap(rng_vmap, in_axes=(None, 0), out_axes=1)

# Each sampler call receives one key per local device, the same layout as the
# direct sampling_fn call earlier in the notebook. With a single device this
# equals the previous `[:, None, :]` indexing.
n_local_devices = jax.local_device_count()
rng, *next_rng = jax.random.split(rng, num=n_local_devices * NUM_CHAINS + 1)
next_rng = jnp.asarray(next_rng).reshape(NUM_CHAINS, n_local_devices, -1)

# Integer time values evenly spread over [0, sde.N].
ts = jnp.floor(jnp.linspace(0, sde.N, NUM_TIMES)).astype(jnp.int32)

a, b = both_vmap(next_rng, ts)
a.shape


(16, 50, 1, 128, 1, 1, 3)

In [63]:
# Collapse the middle axes of `a` into one batch axis, keeping the two leading
# axes (key sets, time values) and the trailing coordinate axis. The sizes
# come from `a` itself, so nothing here has to be kept in sync with the
# constants used when sampling.
a_sq = a.reshape(a.shape[0], a.shape[1], -1, a.shape[-1])
a_sq.shape

(16, 50, 128, 3)

In [74]:
# Euclidean norm of every sample (last axis), then the mean over axes 0 and 2,
# which leaves one value per entry of axis 1 (the time values).
sample_norms = jnp.sqrt(jnp.sum(a_sq**2, axis=-1))
jnp.mean(sample_norms, axis=(0, 2))

Array([ 1.1501262,  1.1625515,  1.2062881,  1.2344209,  1.2479209,
        1.2449558,  1.2364532,  1.2285364,  1.2229407,  1.2179016,
        1.2154219,  1.2090038,  1.2083977,  1.2106754,  1.2220304,
        1.2350386,  1.2503494,  1.2825582,  1.3160365,  1.3565032,
        1.3989656,  1.4543663,  1.5367174,  1.6435182,  1.7760359,
        2.0158653,  2.244335 ,  2.5211372,  2.8446605,  3.2235012,
        3.6534164,  4.141734 ,  4.687155 ,  5.276929 ,  5.9064536,
        6.577108 ,  7.281826 ,  8.187635 ,  8.942745 ,  9.691716 ,
       10.430546 , 11.145516 , 11.812059 , 12.429026 , 12.9862585,
       13.46254  , 13.850713 , 14.145807 , 14.33746  , 14.429163 ],      dtype=float32)

In [93]:
# Reference swiss-roll point cloud. `random_state` is fixed so the points, and
# therefore the number below and the plots, are the same on every run.
gt, colors = make_swiss_roll(n_samples=1000, noise=0.3, random_state=42)

# Mean Euclidean norm of the reference points, for comparison with the
# per-time mean norms of the generated samples.
jnp.mean(jnp.sqrt(jnp.sum(gt**2, axis=-1)))

Array(14.74979, dtype=float32)

In [76]:
# Coordinates 0 and 2 of every sample; these are the two columns that the 2-D
# plots put on the horizontal and vertical axis.
xs = a_sq[..., 0]
ys = a_sq[..., 2]

In [77]:
# Sanity check: `xs` keeps the first three axes of `a_sq`; only the
# coordinate axis has been indexed away.
xs.shape

(16, 50, 128)

In [94]:
# 2-D view of the reference swiss roll: column 0 on the horizontal axis and
# column 2 on the vertical axis.
fig, ax = plt.subplots(figsize=(5, 5))

# Unit circle centred on the origin (presumably a visual scale reference).
circle = plt.Circle((0, 0), 1, color='r', fill=False)
ax.add_patch(circle)

# The returned handle is not kept: binding it to the global `line` would
# replace the handle that the slider callback in a later cell updates.
ax.plot(gt[:, 0], gt[:, 2], 'r+')

# Label the axes so the figure can be read on its own.
ax.set_xlabel('x (column 0)')
ax.set_ylabel('z (column 2)')
ax.set_title('Reference swiss roll')
ax.grid(True)

plt.show()

In [96]:
# Interactive 2-D view of the first key set's samples (index 0 on axis 0 of
# xs/ys), with a slider that selects which time value is shown.
# The figure, axes and line get names of their own: other plotting cells
# rebind the globals `fig`, `ax` and `line`, and the slider callback must keep
# pointing at this figure whatever order the cells are run in.
fig_t, ax_t = plt.subplots(figsize=(5, 5))

# Unit circle centred on the origin (presumably a visual scale reference).
circle = plt.Circle((0, 0), 1, color='r', fill=False)
ax_t.add_patch(circle)

# Start with the last time index, which matches the slider's initial value.
line_t, = ax_t.plot(xs[0, -1], ys[0, -1], 'r+')

# Narrow vertical axes on the left edge of the figure that host the slider.
axamp = fig_t.add_axes([0.03, 0.2, 0.0225, 0.63])
# Keep the Slider in a global so it is not garbage collected while the figure
# is open. `valstep=ts` restricts it to the sampled time values.
t_slider = Slider(
    ax=axamp,
    label="T",
    valmin=0,
    valmax=sde.N,
    valinit=sde.N,
    orientation="vertical",
    valstep=ts
)


def update(val):
    """Show the samples for the time value closest to the slider value.

    Args:
        val: current slider value, passed in by matplotlib.
    """
    # Nearest-match lookup: an exact `ts == val` comparison yields an empty
    # selection (and a blank plot) whenever the value does not match exactly.
    idx = int(jnp.argmin(jnp.abs(ts - val)))
    line_t.set_xdata(xs[0, idx])
    line_t.set_ydata(ys[0, idx])
    fig_t.canvas.draw_idle()


# Call `update` whenever the slider value changes.
t_slider.on_changed(update)

ax_t.set_xlabel('x (column 0)')
ax_t.set_ylabel('z (column 2)')
ax_t.grid(True)

plt.show()

In [92]:
# 3-D comparison of the reference swiss roll with generated samples.
fig = plt.figure(figsize=(5, 5))
# add_subplot already attaches the axes to the figure; no add_axes call needed.
ax = fig.add_subplot(111, projection="3d")

# Reference points, coloured by the second array returned by make_swiss_roll.
line_gt = ax.scatter(gt[:, 0], gt[:, 1], gt[:, 2], c=colors, s=50, alpha=1.)

# a_sq has axes (key set, time, batch, coordinate). Take the last time index
# (the one the 2-D slider view starts from), merge the key-set and batch axes
# into one list of points, and plot at most 1024 of them. Indexing
# `a_sq[:1024, k]` directly would pick time index k rather than coordinate k.
final_samples = a_sq[:, -1].reshape(-1, a_sq.shape[-1])[:1024]
# Separate name so the global `line` read by the slider callback stays intact.
line_samples = ax.scatter(final_samples[:, 0], final_samples[:, 1], final_samples[:, 2],
                          c='#ff0000', alpha=.25, marker='x')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.grid(True)

plt.show()