In [None]:
import os

import jax
import jax.numpy as jnp
from jax import vmap

from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

from function_diffusion.utils.model_utils import (
    create_model,
    create_train_state,
    create_optimizer,
    compute_total_params,
)
from function_diffusion.utils.checkpoint_utils import (
    create_checkpoint_manager,
    restore_checkpoint,
)
from function_diffusion.utils.data_utils import create_dataloader
from function_diffusion.utils.baseline_utils import create_eval_step

from burgers.data_utils import create_dataset

In [None]:
from configs import base

# Load the experiment configuration for the FNO model; it drives model
# restoration and dataset/dataloader creation in the cells below.
fno_config = base.get_config('fno')
# Configurations for the other models (currently disabled):
# vit_config = base.get_config('vit')
# unet_config = base.get_config('unet')

In [None]:
def restore_model(config):
    """Rebuild a model's train state from a saved checkpoint and create its eval step.

    Args:
        config: Experiment configuration. ``config.model.model_name`` names the
            sub-directory (under the current working directory) that holds the
            ``ckpt`` folder, and ``config.saving`` is passed to the checkpoint
            manager factory. The whole object is also forwarded to the model,
            optimizer and train-state factories.

    Returns:
        A ``(state, eval_step)`` tuple: the restored train state, converted to a
        global array over a 1-D ``"batch"`` device mesh with an empty
        ``PartitionSpec``, and the evaluation function created for that mesh.
    """
    # Build the network and the optimizer needed to construct a train state.
    net = create_model(config)
    # The first return value (presumably the learning-rate schedule) is unused here.
    _, optimizer = create_optimizer(config)

    # Freshly initialised state; serves as the template the checkpoint is restored into.
    state = create_train_state(config, net, optimizer)
    num_params = compute_total_params(state)
    # The size estimate assumes 4 bytes per parameter.
    print(f"Model storage cost: {num_params * 4 / 1024 / 1024:.2f} MB of parameters")

    # Report the device topology.
    num_local_devices = jax.local_device_count()
    num_devices = jax.device_count()
    print(f"Number of devices: {num_devices}")
    print(f"Number of local devices: {num_local_devices}")

    # Checkpoints live in <cwd>/<model_name>/ckpt.
    run_dir = f"{config.model.model_name}"
    ckpt_dir = os.path.join(os.getcwd(), run_dir, "ckpt")
    manager = create_checkpoint_manager(config.saving, ckpt_dir)

    state = restore_checkpoint(manager, state)
    print(f"Model loaded from step {state.step}")

    # 1-D mesh over every device; the empty PartitionSpec leaves the state unsharded.
    device_mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), "batch")
    state = multihost_utils.host_local_array_to_global_array(state, device_mesh, P())

    return state, create_eval_step(net, device_mesh)

In [None]:
# Restore the FNO train state from its checkpoint and build the matching eval step.
fno_state, fno_eval_step = restore_model(fno_config)
# The other models can be restored the same way (currently disabled):
# vit_state, vit_eval_step = restore_model(vit_config)
# unet_state, unet_eval_step = restore_model(unet_config)

In [None]:
# Build both dataset splits; only the test split is consumed in this notebook.
train_dataset, test_dataset = create_dataset(fno_config)

# Unshuffled loader over the test split. The batch size is taken from the
# training setting of the dataset config.
dataset_cfg = fno_config.dataset
test_loader = create_dataloader(
    test_dataset,
    batch_size=dataset_cfg.train_batch_size,
    num_workers=dataset_cfg.num_workers,
    shuffle=False,
)

In [None]:
# Evaluate the FNO on corrupted test inputs: each batch is subsampled, perturbed
# with Gaussian noise, resized back to a fixed resolution and fed to the model.
rng_key = jax.random.PRNGKey(12345)

downsample_factor = 4    # stride used to subsample the two middle (spatial) axes
noise_level = 0.2
target_resolution = 256  # spatial size the corrupted input is resized to

input_list = []
pred_list = []
ref_list = []
for x in test_loader:
    # Derive a fresh subkey per batch. rng_key is only carried forward to the
    # next split and must never be used to draw samples itself.
    rng_key, subkey = jax.random.split(rng_key)
    x = jax.tree.map(jnp.array, x)
    y = x  # clean field; after resizing it is stored as the reference solution

    # Keep every `downsample_factor`-th point along axes 1 and 2.
    x_downsampled = x[:, ::downsample_factor, ::downsample_factor]

    # NOTE(review): the effective noise std is 0.2 * noise_level; confirm that
    # the extra 0.2 factor is intended.
    noise = jax.random.normal(subkey, x_downsampled.shape) * 0.2 * noise_level
    x_noise = x_downsampled + noise

    # Bilinear upsampling of the noisy coarse field; the first and last axes
    # keep their original sizes.
    x = jax.image.resize(
        x_noise,
        (x.shape[0], target_resolution, target_resolution, x.shape[-1]),
        method='bilinear',
    )
    # Bring the clean target to the same shape as the model input.
    y = jax.image.resize(y, x.shape, method='bilinear')

    batch = (x, y)

    # Evaluate model
    pred = fno_eval_step(fno_state.params, batch)

    pred_list.append(pred)
    input_list.append(x)
    ref_list.append(y)

# Stack all batches along the leading axis; squeeze() drops every size-1 axis.
u_pred = jnp.concatenate(pred_list, axis=0).squeeze()
u_ref = jnp.concatenate(ref_list, axis=0).squeeze()
u_input = jnp.concatenate(input_list, axis=0).squeeze()

In [None]:
def compute_error(pred, y):
    """Return the relative L2 error ``||pred - y|| / ||y||`` over all entries.

    Both inputs are flattened first, so the norms run over every element
    regardless of the original array shape.

    Args:
        pred: Predicted array.
        y: Reference array with the same number of elements as ``pred``.

    Returns:
        Scalar array holding the relative error.
    """
    reference = y.flatten()
    residual = pred.flatten() - reference
    return jnp.linalg.norm(residual) / jnp.linalg.norm(reference)

# Per-sample relative L2 error, mapped over the leading axis of both arrays.
error = vmap(compute_error)(u_pred, u_ref)

# Report mean ± standard deviation over the samples, in percent.
print(f"FNO Relative L2 Error: {jnp.mean(error) * 100:.2f} % ± {jnp.std(error) * 100:.2f} %")

In [None]:
# Visualization of some examples: input, reference, prediction and absolute
# error for one sample, side by side.
import matplotlib.pyplot as plt

k = 0  # index of the sample to display

# (title, 2-D field) for each panel, in left-to-right order.
panels = (
    ('Input', u_input[k, :, :]),
    ('Reference', u_ref[k, :, :]),
    ('FNO Prediction', u_pred[k, :, :]),
    ('FNO Error', jnp.abs(u_pred[k, :, :] - u_ref[k, :, :])),
)

fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))
for ax, (title, field) in zip(axes, panels):
    ax.set_title(title)
    image = ax.pcolor(field, cmap='jet')
    # Each panel gets its own colorbar, attached to its own axes.
    fig.colorbar(image, ax=ax)

fig.tight_layout()
plt.show()