In [None]:
from diffuse.vae_train import Encoder

In [None]:
from diffuse.conditional import impl_log_likelihood
from functools import partial
import jax.numpy as jnp

# Force JAX onto the CPU backend. JAX only honours this setting if it runs
# before a backend is initialized, so keep it ahead of any JAX computation.
import jax
jax.config.update("jax_platform_name", "cpu")

encoder = Encoder(32)
weight_path = 'vae_model-32/20250310_013841/vae_params.npz'
# Loading an .npz returns a lazily-read archive that holds an open file handle;
# use it as a context manager so the handle is closed once the encoder
# parameters have been extracted into a plain Python object via `.item()`.
# NOTE(review): allow_pickle=True unpickles the stored objects, which can run
# arbitrary code — only load checkpoints from a trusted source.
with jnp.load(weight_path, allow_pickle=True) as vae_trained:
    encoder_params = vae_trained['encoder'].item()

# Pre-bind the encoder, its parameters and the fixed likelihood hyperparameters.
# NOTE(review): the meaning of b, a, m and measurement_noise is defined by
# `impl_log_likelihood` — TODO document them here once confirmed.
likelihood_fn = partial(impl_log_likelihood, encoder, encoder_params, b=0, a=1, m=0.0001, measurement_noise=0.5)

In [None]:
from diffuse.conditional import CondSDEImplicit
from diffuse.sde import SDE, SDEState
from diffuse.sde import LinearSchedule
from diffuse.unet import UNet

# --- Diffusion hyperparameters -------------------------------------------
tf = 2.0                    # end time of the diffusion time grid
n_t = 256                   # number of points on the time grid
ts = jnp.linspace(0, tf, n_t)
dts = jnp.diff(ts)          # n_t - 1 increments, each tf / (n_t - 1)
# NOTE(review): dt is tf / n_t, which differs slightly from the grid step
# tf / (n_t - 1) in `dts`. It is passed to the UNet below and is presumably the
# value the network was trained with, so it is kept — confirm with training code.
dt = tf / n_t
# Presumably a noise schedule linear in t between b_min and b_max over [t0, T].
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=tf)

# score network
def nn_score(x, t):
    """Evaluate the UNet ``score_net`` with the loaded ``params`` at ``(x, t)``.

    ``score_net`` and ``params`` are module-level globals looked up at call
    time, so this must only be called after they have been defined.
    """
    score = score_net.apply(params, x, t)
    return score

score_nn_path = "ann_3499.npz"
score_net = UNet(dt, 64, upsampling="pixel_shuffle")
# Load the trained UNet weights. The .npz archive is used as a context manager
# so its file handle is closed once the params dict has been extracted.
# NOTE(review): allow_pickle=True unpickles the stored objects, which can run
# arbitrary code — only load checkpoints from a trusted source.
with jnp.load(score_nn_path, allow_pickle=True) as nn_trained:
    params = nn_trained["params"].item()

# define sde: conditional SDE combining the beta schedule, the VAE likelihood
# and the learned score.
cond_sde = CondSDEImplicit(beta=beta, likelihood_fn=likelihood_fn, tf=tf, score=nn_score)

In [None]:
import matplotlib.pyplot as plt

# Load MNIST and pool the train and test images into a single array along the
# sample axis. The .npz archive is used as a context manager so its underlying
# file handle is closed once both splits have been read.
with jnp.load("dataset/mnist.npz") as data:
    xs = jnp.concatenate([data["x_train"], data["x_test"]], axis=0)

def preprocess(data): # Normalize to [-1, 1]
    """Linearly rescale ``data`` so its global min maps to -1 and its max to 1.

    Args:
        data: array of any shape; min/max are taken over all elements.

    Returns:
        A floating-point array of the same shape with values in [-1, 1].
        A constant input (max == min) would otherwise divide by zero and
        produce NaNs; it is mapped to all zeros (the midpoint of [-1, 1]).
    """
    max_val = data.max()
    min_val = data.min()
    span = max_val - min_val
    # Substitute 1 for a zero span so the division stays finite; that branch's
    # values are discarded by the final `where`. Using `where` instead of a
    # Python `if` keeps the function traceable under jax.jit.
    safe_span = jnp.where(span > 0, span, 1)
    scaled = (data - min_val) / safe_span * 2 - 1
    return jnp.where(span > 0, scaled, 0.0)

xs = preprocess(xs)
xi = xs[0]  # first image; passed to the reverse sampler below as `xi`

# Visual sanity check of the normalized image, with a colorbar for its range.
fig, ax = plt.subplots()
image = ax.imshow(xi.reshape(28, 28), cmap="gray")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
from functools import partial

key = jax.random.PRNGKey(0)
num_generations = 1
ground_truth_shape = (28, 28, 1)

# Sampling code.
# Split the root key before use: the original drew the prior noise from `key`
# and also split `key` for the sampler, reusing one key for two purposes and
# correlating their randomness (see JAX's PRNG "sharp bits").
init_key, sampler_key = jax.random.split(key)
init_samples = jax.random.normal(init_key, (num_generations, *ground_truth_shape))  # standard-normal prior draw
keys = jax.random.split(sampler_key, num_generations)  # one key per generation
tfs = jnp.full((num_generations,), tf)  # every sample starts at the final time tf
state_f = SDEState(position=init_samples, t=tfs)

# NOTE(review): y=10 is forwarded unchanged to `cond_sde.reverso`; its meaning
# is defined there — TODO document.
revert_sde = partial(cond_sde.reverso, score=nn_score, dts=dts, y=10, xi=xi)

# Denoise: vmap maps over the leading (generation) axis of `keys` and `state_f`.
state_f, history = jax.vmap(revert_sde)(keys, state_f)

In [None]:
# Show the first generated sample after reverse-time denoising.
fig, ax = plt.subplots(figsize=(10, 5))
denoised = state_f.position[0].reshape(28, 28)
ax.imshow(denoised, cmap="gray")
ax.set_title("Denoised Image")
plt.show()