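# Evaluation: advection–diffusion

Compare the trained model's stochastic predictions with the complete-model test simulations (MMD score and visual checks), then test temporal extrapolation against the physics solver.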

In [1]:
import os
import yaml
import json

import torch

from matplotlib import pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import numpy as np

from src.nnets.utils import load_model
from src.data_loader.data_loader import get_data_loader
from src.data_loader.config_data_loader import SimulatorDataset
from src.forward_models.init_physics_model import init_physics_solver_model
from src.metrics.mmd import MMDLoss, RBF, estimate_mmd_bandwidth
from src.evaluation.evaluator import one_to_many_evaluation



In [2]:
# Task and held-out experiment identifiers.
task_name, test_exp_name = "advdiff", "ac9b06feb779781d54bdb5d8191edd00"

# Relative locations of the test dataset and of the checkpoint being evaluated.
dataset_file_path = f"../../../datasets/forward_models/advdiff/one_to_many/testing/data_{test_exp_name}"
path_model_file = (
    "../../../outputs/best_models/advdiff/ac9b/euler/blackbox/"
    "a8f4feaea3835b0f74c953fd537af5da/VDN8_score_0.54_epoch_9800/"
    "model_587bb13f417afaf415ead7ad0951e98d_exp_3ca6aaf4c2ce8bfff7b41801b2c141b2_"
    "a8f4feaea3835b0f74c953fd537af5da_salt_VDN8_best.pt"
)

In [None]:
# Device selection and seeding (torch + numpy) for reproducible sampling/plots.
SEED = 13
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the trained model; `config` is reused later to build the physics solver.
params_dim = 1  # how many leading parameter channels are kept as model context
T_model, parent_dir, config, _ = load_model(path_model_file=path_model_file, params_dim=params_dim)

# Held-out test set and its loader (batch size can be adjusted).
test_dataset = SimulatorDataset(
    name_dataset=task_name,
    data_file_path=dataset_file_path,
    testing_set=True,
    device=device,
)
test_loader = get_data_loader(test_dataset, batch_size=512)

# MMD metric with an RBF kernel; bandwidth from the median heuristic on the test data.
bandwidth = estimate_mmd_bandwidth(test_loader, median_heuristic=True)
rbf = RBF(bandwidth=bandwidth, n_kernels=6, device=device)
metrics = [MMDLoss(kernel=rbf)]

In [None]:
# Draw one batch from the test loader and unpack its fields.
test_samples = test_loader.sample()
test_params = test_samples["params"]
test_sims = test_samples["x"]
# Per-sample simulation shape: everything after the (case, sample) axes.
x_dim = test_sims.shape[2:]
test_init_conds = test_samples["init_conds"]

for _shape in (test_params.shape, test_init_conds.shape, test_sims.shape):
    print(_shape)

# Axis 1 of the parameter tensor enumerates the stochastic samples per parameter set.
noisy_samples = test_params.shape[1]
print(f"Test number of noisy samples per parameter: {noisy_samples}" )


In [None]:
# Inputs for the incomplete (simple) model. Slicing axis 1 with step
# `noisy_samples` (== its length) keeps only sample 0; squeeze(1) drops that axis.
X_params = test_params[:, ::noisy_samples, :, :params_dim].squeeze(1)
X_init_conds = test_init_conds[:, ::noisy_samples].squeeze(1)

# Simulate the simple model with the physics solver built from the training config.
phys_solver = init_physics_solver_model(config=config, device=device)
res_sims = phys_solver(init_conds=X_init_conds, params=X_params)
X_sims = res_sims["x"]

for _tensor in (X_params, X_init_conds, X_sims):
    print(_tensor.shape)

## OT OdeNET: predictions on the test set

In [6]:
# Draw as many stochastic model samples per case as the test set provides.
T_noisy_samples = noisy_samples
# Inference only: no autograd graph is needed (same as the extrapolation cell).
with torch.no_grad():
    pred = T_model.predict(X_sims, context=X_params, z_samples=T_noisy_samples)
# Regroup the flat output as (n_cases, n_samples, *x_dim). A single reshape suffices.
pred = pred.reshape((-1, T_noisy_samples) + tuple(x_dim))

In [None]:
# Score the predictions against the complete-model test simulations.
res_eval = one_to_many_evaluation(
    X_params, pred, test_sims, metrics=metrics, type_evals=["marginal_score"],
)
# NOTE(review): evaluation requested "marginal_score" but "lik_score" is read —
# confirm which key one_to_many_evaluation returns for this type_evals.
scores = res_eval["lik_score"]
print(scores)

In [None]:
# Display the shape of the simple-model simulations.
X_sims.size()

In [None]:
# For `plot_samples` random cases: column 0 shows the simple-model simulation
# (in both rows); top row shows complete-model targets, bottom row the model's
# stochastic predictions.
plot_samples = 10
xi = list(range(X_sims.shape[-1]))
x_labels = ["%.2f" % item for item in T_model.t_intg.detach().cpu().tolist()]
kwargs = {'vmin': 0.0, 'aspect': 3.5, 'cmap': 'magma', 'interpolation': 'none'}


def _format_time_axis(ax):
    """Label an image panel's x axis with integration times (about 5 ticks)."""
    ax.set_xticks(xi, x_labels, minor=False)
    ax.locator_params(axis='x', nbins=5)


for _ in range(plot_samples):
    random_idx = np.random.randint(0, X_sims.shape[0])
    print(random_idx)
    fig, axes = plt.subplots(2, noisy_samples + 1, figsize=(30, 15))

    # Simple-model coefficient(s) of this case (all grid/param entries flattened).
    param_0 = X_params[random_idx].unique()
    print(f"Simple Model parameters \n{param_0}")
    # Complete-model parameters at grid point 0, one row per sample, in sample
    # order. No .unique(): it sorts/dedups rows and breaks the row <-> sample link.
    c_params = test_params[random_idx, :, 0]
    print(f"Complete Model parameters \n{c_params}")

    simple_img = X_sims[random_idx].detach().cpu().numpy()
    for row in range(2):
        axes[row, 0].imshow(simple_img, **kwargs)
        _format_time_axis(axes[row, 0])
        axes[row, 0].set_title(f"Simple model, coeff: {param_0.item():.4f}")

    for i in range(noisy_samples):
        # Target (complete model)
        axes[0, i + 1].imshow(test_sims[random_idx, i].detach().cpu().numpy(), **kwargs)
        _format_time_axis(axes[0, i + 1])
        axes[0, i + 1].set_title(f"+dcoeff: {c_params[i][1]:.4f}")

        # Prediction (formatting now targets the prediction row, not the target row)
        axes[1, i + 1].imshow(pred[random_idx, i].detach().cpu().numpy(), **kwargs)
        _format_time_axis(axes[1, i + 1])
        axes[1, i + 1].set_title(f"Pred sample {i + 1}")

    plt.show()
    plt.close(fig)

In [10]:
# Free the in-distribution predictions before the longer extrapolation rollout.
# Cells above that read `pred` need the prediction cell re-run after this.
del pred

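## Extrapolation

Roll the model out for 4× its episode length and compare it with the complete physics solver run on the same time grid.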

In [None]:
# change model's integrator setting
T_model.len_episode = (4) * T_model.len_episode 
print(f"Extrapolatin of {T_model.len_episode} steps")
T_model.t_intg = torch.linspace(0.0, T_model.dt* T_model.len_episode, T_model.len_episode, device=T_model.device)
print(f"Max time: {T_model.t_intg.max()}")

with torch.no_grad():
    extr_pred = T_model.predict(X_sims, context=X_params, z_samples=T_noisy_samples)
    extr_pred = extr_pred.reshape((-1, T_noisy_samples) + (extr_pred.shape[1:]))
    #pred = pred.reshape((-1, T_noisy_samples) + (x_dim))
    print(extr_pred.shape)

In [None]:
# Flatten (cases, samples) so the physics solver gets one row per complete-model
# parameter set, and repeat each case's initial condition once per sample so the
# rows line up (case-major, sample-minor — same order as test_params.reshape).
print(test_params.shape)
x_grid_dim = test_params.shape[2]
n_complete_params = test_params.shape[-1]
params = test_params.reshape(-1, x_grid_dim, n_complete_params)
new_init_conds = X_init_conds.reshape(-1, 1, x_grid_dim).repeat(1, noisy_samples, 1)
new_init_conds = new_init_conds.reshape(-1, x_grid_dim)
print(params.shape, new_init_conds.shape)

In [13]:
# complete_model
phys_solver.len_episode = T_model.len_episode
phys_solver.dt = T_model.dt
phys_solver.t = T_model.t_intg


res_extr_Y = phys_solver(init_conds=new_init_conds, params=params)
Y_extr = res_extr_Y["x"].reshape(-1, noisy_samples, x_grid_dim, phys_solver.len_episode)
params = params.reshape(-1, noisy_samples, x_grid_dim, 2)


In [None]:
# Sanity check: the first time slice of the model rollout must match the solver's.
print(extr_pred.shape, Y_extr.shape)
extr_t0 = extr_pred.reshape(-1, x_grid_dim, extr_pred.shape[-1])[:, :, 0]
Y_t0 = Y_extr.reshape(-1, x_grid_dim, Y_extr.shape[-1])[:, :, 0]
# Explicit shape check: torch.isclose would otherwise broadcast mismatched shapes.
assert extr_t0.shape == Y_t0.shape, f"shape mismatch: {extr_t0.shape} vs {Y_t0.shape}"
assert torch.isclose(extr_t0, Y_t0).all(), "model and solver disagree at the first time step"

In [None]:
# For `plot_samples` random cases: left column shows each stochastic model
# prediction on the extended horizon, right column the physics-solver rollout.
plot_samples = 10
# Fixed colour range so prediction and ground truth panels are comparable.
kwargs = {'vmin': 0.0, 'vmax': 1.5, 'aspect': 3.5, 'cmap': 'magma', 'interpolation': 'none'}
for _ in range(plot_samples):
    random_idx = np.random.randint(0, extr_pred.shape[0])
    print(random_idx)

    # squeeze=False keeps `axes` 2-D even when noisy_samples == 1.
    fig, axes = plt.subplots(noisy_samples, 2, figsize=(15, 15), squeeze=False)

    # Simple-model coefficient(s) of this case (all grid/param entries flattened).
    param_0 = X_params[random_idx].unique()
    print(f"Simple Model parameters \n{param_0}")
    c_params = test_params[random_idx, :, 0].unique(dim=0)
    print(f"Complete Model parameters \n{c_params}")

    # Solver parameters at grid point 0, one row per sample in sample order, so
    # row i labels Y_extr[random_idx, i] (.unique() would sort/dedup the rows).
    y_params_1 = params[random_idx, :, 0]
    print(y_params_1.shape)

    for i in range(noisy_samples):
        # Model prediction
        axes[i, 0].imshow(extr_pred[random_idx, i].detach().cpu().numpy(), **kwargs)
        axes[i, 0].set_title(f"Pred sample {i+1}, dcoeff: {param_0.item():.4f}")

        # Ground truth from the physics solver
        axes[i, 1].imshow(Y_extr[random_idx, i].detach().cpu().numpy(), **kwargs)
        axes[i, 1].set_title(f"GD sample {i+1}, dcoeff: {param_0.item():.4f}, c-coeff: {y_params_1[i][1].tolist():.4f}")

    fig.subplots_adjust(hspace=0.5)
    plt.show()
    plt.close(fig)