In [1]:
import tensorflow as tf
from losses import TrainState
from flax.training import checkpoints
from flax import jax_utils as flax_utils
from matplotlib import pyplot as plt
import jax
from jax import numpy as jnp

from models import super_simple
from models import utils as mutils
import sampling
import losses

from configs.vp.disk_ssim_continuous import get_config
import sde_lib
import datasets

from matplotlib import pyplot as plt
from matplotlib.widgets import Slider
from sklearn.datasets import make_swiss_roll
import os

2025-01-31 15:40:08.975322: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
%matplotlib qt

In [4]:
rng = jax.random.PRNGKey(42)
model_rng, sampling_rng = jax.random.split(rng)
config = get_config()
score_model, init_model_state, init_model_params = mutils.init_model(model_rng, config)

In [5]:
optimizer = losses.get_optimizer(config)

state = TrainState.create(tx=optimizer, apply_fn=score_model.apply,
                          params=init_model_params, mutable_state=init_model_state,
                          rng=rng)

In [6]:
workdir = '/home/komodo/Documents/uni/thesis/score_sde'
ckpt_model_state = checkpoints.restore_checkpoint(os.path.join(workdir, '../backups/disk_modern/ssim/checkpoints'), state)



In [45]:
if config.training.sde.lower() == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sampling_eps = 1e-5
else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")
inverse_scaler = datasets.get_data_inverse_scaler(config)

sampling_shape = (config.training.batch_size // jax.local_device_count(), config.data.image_size,
                config.data.image_size, config.data.num_channels)

In [42]:
# sampling_fn = pc_sampler
sampling_fn = sampling.get_sampling_fn(config, sde, score_model, sampling_shape, inverse_scaler, sampling_eps)

In [43]:
pstate = flax_utils.replicate(ckpt_model_state)
num_train_steps = config.training.n_iters

# In case there are multiple hosts (e.g., TPU pods), only log to host 0
rng = jax.random.fold_in(rng, jax.host_id())
rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
next_rng = jnp.asarray(next_rng)

a, b = sampling_fn(next_rng, pstate, jnp.array([sde.N]))



In [10]:
rng_vmap = jax.vmap(lambda r, t: sampling_fn(r, pstate, jnp.array([t])), in_axes=(0, None), out_axes=0)
both_vmap = jax.vmap(lambda r, t: rng_vmap(r, t), in_axes=(None, 0), out_axes=1)

rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() * 16 + 1)
next_rng = jnp.asarray(next_rng)[:, None, :]
next_rng.shape

ts = jnp.floor(jnp.linspace(0, sde.N, 50)).astype(jnp.int32)

a, b = both_vmap(next_rng, ts)
a.shape

(16, 50, 1, 128, 1, 1, 2)

In [11]:
a_sq = a.reshape((16, 50, config.training.batch_size, a.shape[-1]))
a_sq.shape

(16, 50, 128, 2)

In [12]:
jnp.mean(jnp.sqrt(jnp.sum(a_sq**2, axis=-1)), axis=(0, 2))

Array([0.8908451 , 0.87487286, 0.88806045, 0.8920609 , 0.8923133 ,
       0.8886476 , 0.8897487 , 0.90048707, 0.8990839 , 0.88678765,
       0.8871851 , 0.8868524 , 0.87637115, 0.8735732 , 0.86853105,
       0.86097294, 0.8458158 , 0.8418015 , 0.8364288 , 0.83557713,
       0.8196401 , 0.80558825, 0.79425514, 0.787498  , 0.7718704 ,
       0.7621784 , 0.74935544, 0.73565745, 0.7249646 , 0.71591586,
       0.70578986, 0.6920155 , 0.68628585, 0.67271775, 0.65858054,
       0.6535178 , 0.6532285 , 0.64719296, 0.6423175 , 0.64397675,
       0.646213  , 0.6448644 , 0.64511395, 0.64148045, 0.6421057 ,
       0.64803654, 0.6502265 , 0.6508821 , 0.6527535 , 0.6566098 ],      dtype=float32)

In [93]:
gt, colors = make_swiss_roll(1000, noise=0.3)
jnp.mean(jnp.sqrt(jnp.sum(gt**2, axis=-1)))

Array(14.74979, dtype=float32)

In [35]:
temp = a_sq.transpose((1, 0, 2, 3)).reshape((50, -1, a.shape[-1]))

In [36]:
xs = temp[:, :, 0]
ys = temp[:, :, 2 if config.data.dataset == 'SWIRL' else 1]

In [94]:
fig, ax = plt.subplots()
fig.set_figheight(5)
fig.set_figwidth(5)

circle = plt.Circle((0, 0), 1, color='r', fill=False)

ax.add_patch(circle)
line, = ax.plot(gt[:, 0], gt[:, 2], 'r+')

ax.grid(True)

plt.show()

In [38]:
fig, ax = plt.subplots()
fig.set_figheight(5)
fig.set_figwidth(5)

circle = plt.Circle((0, 0), 1, color='r', fill=False)

ax.add_patch(circle)
line, = ax.plot(xs[-1], ys[-1], 'r+')

axamp = fig.add_axes([0.03, 0.2, 0.0225, 0.63])
t_slider = Slider(
    ax=axamp,
    label="T",
    valmin=0,
    valmax=sde.N,
    valinit=sde.N,
    orientation="vertical",
    valstep=ts
)

# The function to be called anytime a slider's value changes
def update(val):
    ind = jnp.where(ts == val)[0]
    line.set_xdata(xs[ind])
    line.set_ydata(ys[ind])
    fig.canvas.draw_idle()


# register the update function with each slider
t_slider.on_changed(update)

ax.grid(True)

plt.show()