# Evaluation pendulum

In [1]:
import os
import yaml
import json

import torch

from matplotlib import pyplot as plt
import numpy as np

from src.nnets import MLPConditionalGenerator, get_mlp_discriminator
from src.nnets import NeuralPendulum
from src.forward_models.pendulum.pendulum_model import SimplePendulum, PendulumSolver

from src.data_loader.config_data_loader import (
    get_x_params_sampler,
)

from src.data_loader.config_data_loader import SimulatorDataset, SimulationSampler
from src.nnets.utils import load_model


from src.metrics.mmd import MMDLoss, RBF, estimate_mmd_bandwidth
from src.data_loader.data_loader import get_data_loader


In [2]:
def plot_traj(t, test_params, test_sims, pred, X_params, X_sims, X_init_conds,
              n_plot_samples=10, rnd_samples=True, save_path=None):
    """Plot incomplete-model, full-model and predicted trajectories for some samples.

    For each selected sample index ``idx`` one figure is drawn with:
      * the incomplete-model trajectory ``X_sims[idx]`` (gray, "Part"),
      * every full-model trajectory in ``test_sims[idx]`` (blue, "Full"),
      * every predicted trajectory in ``pred[idx]`` (orange, "Pred").

    Args:
        t: time grid used as the x-axis of every curve.
        test_params: per-sample full-model parameters; ``test_params[idx][0]`` is
            read for the title (entries 1..3 when present).
        test_sims: per-sample collections of full-model trajectories.
        pred: per-sample collections of predicted trajectories.
        X_params: incomplete-model parameters; ``X_params[idx][0]`` is shown as omega.
        X_sims: incomplete-model trajectories.
        X_init_conds: initial conditions; ``X_init_conds[idx][0]`` is shown as theta_0.
        n_plot_samples: number of figures to draw.
        rnd_samples: if True, pick sample indices uniformly at random (with
            replacement); otherwise plot the first ``n_plot_samples`` samples.
        save_path: optional directory; when given, each figure is saved as
            ``{save_path}/{idx}.png``. The directory must already exist.
    """
    n_total = len(test_params)
    # Sequential mode cannot plot more samples than exist; random mode samples
    # with replacement, so any count is fine.
    n_figs = n_plot_samples if rnd_samples else min(n_plot_samples, n_total)

    for i in range(n_figs):
        # Bug fix: the original re-drew a random idx right after this choice,
        # so rnd_samples=False was silently ignored.
        idx = np.random.randint(0, n_total) if rnd_samples else i
        print(f"Sample {idx}")

        fig, ax0 = plt.subplots(figsize=(6, 6))

        # incomplete trajectory
        ax0.plot(t, X_sims[idx].tolist(), color="gray", label="Part", alpha=0.4)

        # complete trajectories
        for j in range(len(test_sims[idx])):
            ax0.plot(t, test_sims[idx][j].tolist(), color="blue", label="Full", alpha=0.4)

        # predicted trajectories
        for j in range(len(pred[idx])):
            ax0.plot(t, pred[idx][j].tolist(), color="orange", label="Pred", alpha=0.6)

        full_params = test_params[idx][0]
        init_cond = X_init_conds[idx][0].item()
        param_1 = X_params[idx][0].item()
        param_2 = full_params[1].item()
        # Optional extra parameters (A, phi in the title) are shown as 0 when absent.
        param_3 = full_params[2].item() if full_params.shape[0] > 2 else 0
        param_4 = full_params[3].item() if full_params.shape[0] > 3 else 0

        title = r"Full: $\vartheta_0$={:.2f}, $\omega$={:.2f}, $\xi$={:.2f}, A={:.2f}, $\phi$={:.2f}".format(
            init_cond, param_1, param_2, param_3, param_4
        )

        # Many curves share a label: keep a single legend entry per label.
        handles, labels = ax0.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax0.legend(by_label.values(), by_label.keys(), loc='upper right')
        ax0.set_xlabel("Time")
        ax0.set_ylabel("Angle")
        ax0.set_title(title)
        if save_path is not None:
            fig.savefig(os.path.join(save_path, f"{idx}.png"))
        plt.show()
        # Release the figure so long loops do not accumulate open figures.
        plt.close(fig)

In [3]:
# Multidimensional root mean square error
def rmse(y_true, y_pred, dim=1):
    """Root mean square error between ``y_true`` and ``y_pred``, reduced over ``dim``.

    Inputs may be tensors, arrays or nested lists. They are converted with
    ``torch.as_tensor``; ``torch.tensor`` would copy an existing tensor and emit
    a UserWarning. Integer inputs are promoted to float because ``torch.mean``
    rejects integer dtypes.

    Returns:
        Tensor with ``dim`` reduced (default ``dim=1``: one RMSE per row).
    """
    diff = torch.as_tensor(y_true) - torch.as_tensor(y_pred)
    if not diff.is_floating_point():
        diff = diff.float()
    return torch.sqrt(torch.mean(diff ** 2, dim=dim))

# Multidimensional relative root mean square error
def relative_rmse(y_true, y_pred, dim=1):
    """Relative RMSE: RMSE of ``(y_true - y_pred) / y_true``, reduced over ``dim``.

    Inputs may be tensors, arrays or nested lists. They are converted with
    ``torch.as_tensor``; ``torch.tensor`` would copy an existing tensor and emit
    a UserWarning.

    Note:
        The error is normalised element-wise by ``y_true``, so entries where
        ``y_true == 0`` produce inf/nan.

    Returns:
        Tensor with ``dim`` reduced (default ``dim=1``: one value per row).
    """
    y_true_t = torch.as_tensor(y_true)
    # True division: integer inputs yield a floating-point ratio.
    rel_err = (y_true_t - torch.as_tensor(y_pred)) / y_true_t
    return torch.sqrt(torch.mean(rel_err ** 2, dim=dim))

In [4]:
# Experiment configuration: task name and hash of the test dataset to evaluate on.
task_name = "pendulum"
# Other known test-dataset hashes:
# d9f46b9158a67a005d7ab3f3b99924d3 many modes 
#3c3163b57df8772092d4f07e4b1d3884 many modes many samples z=80
# d717946223abddcb61ade980a7d7d3da one mode

test_exp_name = "ab3c7ed9be53f93b31ea13deba72659a"
# Relative path: presumably assumes the notebook runs from its own directory,
# three levels below the repository root — TODO confirm.
dataset_file_path = f"../../../datasets/forward_models/pendulum/one_to_many/many_modes/testing/data_{test_exp_name}"
# Integration scheme passed to PendulumSolver in the incomplete-model cell.
method="rk4"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Held-out test set, loaded on `device`.
test_dataset = SimulatorDataset(
    name_dataset=task_name,
    data_file_path=dataset_file_path,
    testing_set=True,
    device=device,
)
test_loader = get_data_loader(test_dataset, batch_size=1024)

In [None]:
# Draw test data from the loader. Keys "params", "x" and "init_conds" presumably
# hold the full-model parameters, trajectories and initial conditions — verify
# against SimulatorDataset.
test_samples = test_loader.sample()
test_params = test_samples["params"]
test_sims = test_samples["x"]
test_init_conds = test_samples["init_conds"]
print(test_params.shape)
print(test_sims.shape)

In [None]:
# Retrieve the parameters and initial conditions for the incomplete model
# The incomplete model only uses the first `param_dim` parameters
# (shown as omega in the plot titles).
param_dim = 1
noisy_samples = test_params.shape[
    1
]  # number of stochastic samples

print(f"Test number of noisy samples per parameter: {noisy_samples}" )
# Keep the first param_dim parameters, flatten the (sample, noisy) axes and take
# every `noisy_samples`-th row, i.e. the first stochastic draw of each test sample.
X_params = test_params[:, :, :param_dim]
X_params = X_params.reshape(-1, param_dim)
X_params = X_params[::noisy_samples]
# Same slicing for the initial conditions; assumes test_init_conds is laid out
# as (N, noisy_samples, D) like test_params — TODO confirm.
X_init_conds = test_init_conds.reshape(
    -1, test_init_conds.shape[-1]
)
X_init_conds = X_init_conds[::noisy_samples]
# Incomplete (source) model: PendulumSolver driven only by the retained parameters.
phys_model = PendulumSolver(
    len_episode=test_dataset.conf_data["len_episode"],
    dt=test_dataset.conf_data["dt"],
    method=method)
res_sims = phys_model(init_conds=X_init_conds, params=X_params)
X_sims = res_sims["x"]

# Print shapes
print(X_params.shape)
print(X_init_conds.shape)
print(X_sims.shape)

# Time grid matching the trajectory length: T points spaced by dt, starting at 0.
t = torch.linspace(0.0, phys_model.dt * (test_sims.shape[-1] - 1), test_sims.shape[-1])
# Root output directory for the figures of this notebook.
save_path = f"../../../outputs/figures/pendulum/{test_exp_name}"


In [None]:
# Load T Generative Model
# Alternative checkpoints are kept commented out for reference.
#many mode 
from src.train.ot_physics.utils import freeze
#path_model_file = f"../../../outputs/best_models/pendulum/3c316/marginal_score/c2st/best_ot/bb98da8acfc24506d5ff553ea9ee1dd3/4XAj_score_0.8606_epoch_1000/model_041da91b5da0d6e000275a5809966dec_exp_bb98da8acfc24506d5ff553ea9ee1dd3_126ccdcd5e2a45a2aafb643e8a5956fa_salt_4XAj_final.pt"
#path_model_file = f"../../../outputs/best_models/pendulum/two-modes/marginal_score/mmd/ot/24_10_24/bb98da8acfc24506d5ff553ea9ee1dd3/tMla_score_0.0036_epoch_16800/model_7872786aacc0c1fbe240105e860b6283_exp_bb98da8acfc24506d5ff553ea9ee1dd3_126ccdcd5e2a45a2aafb643e8a5956fa_salt_tMla_final.pt"
path_model_file = f"../../../multirun/2024-11-15/16-13-20/2/checkpoints/gW50/model_e8530c60a7875371df4b99bc3e63916b_exp_b0af176f7b1d439550e605188cbc1138_0c4be06cb37e9cdb02c2f92076ff3bef_salt_gW50_final.pt"

# 2 modes model
# NOTE(review): the path below is truncated.
#path_model_file = f"../../../outputs/best_models/pendulum/d7179/marginal_score/bb98da8acfc24506d5ff553ea9ee1dd3_79f7b854121ce744e06e889d40b4e6d0/6L2l_score_0.0091_epoch_8000/model_82aafdf137d20aee5e42064fbc4dcd35_exp_bb98da8acfc24506d5ff

# Judging by the names: the transport model, its run directory, its training
# config, and the raw checkpoint dict.
T_model, parent_dir, config, dic_chkpt = load_model(
            path_model_file=path_model_file,
        )
# NOTE(review): f_state is not used anywhere else in this notebook.
f_state = dic_chkpt["f_model_state_dict"]

# Evaluation only: freeze the model (see src.train.ot_physics.utils.freeze).
freeze(T_model)
# Display the epoch stored in the checkpoint.
dic_chkpt["epoch"]

## OT ODE-Net

## Conditional samples

In [31]:
# Conditional samples: T_noisy_samples predictions per incomplete trajectory,
# conditioned on its incomplete-model parameters.
T_noisy_samples = noisy_samples
# Latent noise distribution from the model config; "gauss" when the key is absent.
z_dist = config["z_dist"] if "z_dist" in config else "gauss"
pred = T_model.predict(X_sims, context=X_params, z_samples=T_noisy_samples, z_type_dist=z_dist)

In [None]:
# Coarse sanity check on trajectory magnitudes: per-trajectory L2 norm of the
# predictions minus that of the full-model test trajectories (mean / max / min).
# NOTE: trajectories are paired by position only; predictions and test samples are
# presumably independent stochastic draws, so this is a distribution-level check.
seq_len = test_sims.shape[-1]  # was hard-coded to 50
norm_diff = (
    torch.norm(pred.reshape(-1, seq_len), dim=1)
    - torch.norm(test_sims.reshape(-1, seq_len), dim=1)
)
print(norm_diff.mean())
print(norm_diff.max())
print(norm_diff.min())

In [None]:
# Plot random conditional samples and save them under <save_path>/lik.
os.makedirs(save_path+"/lik", exist_ok=True)
plot_traj(t, test_params, test_sims, pred, X_params, X_sims, X_init_conds, save_path=save_path+"/lik")

In [34]:
# Drop the conditional predictions before the marginal section recomputes `pred`.
del pred

## Marginal samples

In [35]:
# Marginal evaluation: n_samples random test indices (drawn with replacement).
# `rnd` is only used as the file name of the saved marginal figure.
n_samples = 200
rnd = np.random.randint(0, len(X_sims))
idx = np.random.randint(0, len(X_sims), n_samples)

In [36]:
# Only Marginals Plots
# One initial condition per test sample (same slicing as the incomplete-model
# cell; the init-cond dimension is read from the data instead of hard-coding 2).
# The selected ones are paired with fresh parameters drawn from the x-parameter
# sampler built from the model config, not with the test-set parameters.
X_init_conds = test_init_conds.reshape(-1, test_init_conds.shape[-1])[::noisy_samples]
x_params_sampler = get_x_params_sampler(
                        config, device=device
                    )
init_conds = X_init_conds[idx]
X_params = x_params_sampler.sample((init_conds.shape[0],))
res_x_sims = phys_model(init_conds=init_conds, params=X_params)
X_sims = res_x_sims["x"]

pred = T_model.predict(X_sims, context=X_params, z_samples=noisy_samples, z_type_dist=z_dist)

In [None]:
# Sanity check: column 0 (t = 0) of every prediction should match the test
# trajectories for the same indices, since both start from the same initial conditions.
torch.isclose(pred.reshape(-1, test_sims.shape[-1])[:, 0 ], test_sims[idx].reshape(-1, test_sims.shape[-1])[:, 0 ]).all()

In [38]:
# Time grid for the plots (same definition as in the incomplete-model cell).
t = torch.linspace(0.0, phys_model.dt * (test_sims.shape[-1] - 1), test_sims.shape[-1])

In [None]:
# Overlay every marginal prediction (orange) with the test trajectories for the
# same indices (blue) in a single figure.
seq_len = test_sims.shape[-1]  # was hard-coded to 50
fig, axes = plt.subplots(1, 1, figsize=(6, 6))
y_hat = pred.reshape(-1, seq_len)
y = test_sims[idx].reshape(-1, seq_len)
# Every curve shares the 1-D grid `t`. The original rebound the global `t` to a
# (n_curves, seq_len) repeat, which broke re-running this cell.
for i in range(len(y_hat)):
    axes.plot(t, y_hat[i].tolist(), color="orange", label="Pred", alpha=0.35)
    axes.plot(t, y[i].tolist(), color="blue", label="True", alpha=0.3)

# One legend entry per label.
handles, labels = axes.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
axes.legend(by_label.values(), by_label.keys(), loc="upper right")
axes.set_xlabel("Time")
axes.set_ylabel("Angle")
# Create the output directory only when saving (the original called makedirs
# before the None check and built a path with a double slash).
if save_path is not None:
    marginals_dir = os.path.join(save_path, "marginals")
    os.makedirs(marginals_dir, exist_ok=True)
    fig.savefig(os.path.join(marginals_dir, f"{rnd}.png"))

## Check conditionals with many samples

In [39]:
# Conditional check: draw n_samples initial conditions and incomplete-model
# parameters. For each, the next cells simulate params_samples full-model
# trajectories whose second parameter is uniform in [param_1_lb, param_1_ub].
from src.sampler.distributions import BoxUniformSampler
n_samples = 50
params_samples = int(50)
t = torch.linspace(0.0, phys_model.dt * (test_sims.shape[-1] - 1), test_sims.shape[-1])

# NOTE(review): rnd and idx are not used by this section (idx is rebound by the plotting loop).
rnd = np.random.randint(0, len(test_init_conds))
idx = np.random.randint(0, len(test_init_conds), n_samples)
# Initial angle in [-1.57, 1.57] (about +/- pi/2); second component has both bounds at 0.
init_cond_sampler = BoxUniformSampler(torch.tensor([-1.57, 0], device=device), torch.tensor([1.57, 0], device=device), device=device)
init_conds = init_cond_sampler.sample((n_samples,))

x_params_sampler = get_x_params_sampler(
                        config, device=device
                    )
X_params = x_params_sampler.sample((init_conds.shape[0],))

# Range of the second, full-model-only parameter; also used for the red bound curves in the plots.
param_1_lb = 0.2
param_1_ub = 1.5
y_params_sampler = BoxUniformSampler(torch.tensor([param_1_lb], device=device), torch.tensor([param_1_ub], device=device), device=device)
y_params = y_params_sampler.sample((n_samples, params_samples))
y_params = y_params.reshape(n_samples, params_samples, 1)


In [None]:
print(X_params.shape, init_conds.shape, y_params.shape)
# Repeat each incomplete-model parameter over its params_samples draws and append
# the sampled second parameter. Assumes X_params is (n_samples, 1) — TODO confirm;
# the result is then (n_samples, params_samples, 2).
tmp_X = X_params.repeat((1, params_samples)).reshape(-1, params_samples, 1)
params = torch.cat([tmp_X, y_params], dim=-1)
print(params.shape)

In [41]:
# source samples (incomplete model): one trajectory per initial condition
X_sims = phys_model(
        init_conds=init_conds.view(-1, init_conds.shape[-1]),
        params=X_params.view(-1, X_params.shape[-1])
    )["x"].view(n_samples, -1)

# target samples (full model): params_samples trajectories per initial condition,
# one per concatenated [X_params, y_params] parameter vector.
init_dim = init_conds.shape[-1]  # was hard-coded to 2
tmp_init_conds = init_conds.repeat((1, params_samples)).reshape(-1, params_samples, init_dim)
y_sims = phys_model(
        init_conds=tmp_init_conds.view(-1, init_dim),
        params=params.view(-1, params.shape[-1])
    )["x"].view(n_samples, params_samples, -1)

In [None]:
# Shape check of the simulated source / target sets.
print(X_sims.shape, y_sims.shape)
print(X_params.shape)

In [None]:
# NOTE(review): this rebinds the global `noisy_samples` (originally the test set's
# number of stochastic samples); later cells reading it get params_samples instead.
noisy_samples=params_samples
# Draw params_samples predictions per source trajectory, using the config's latent distribution.
z_dist = config["z_dist"] if "z_dist" in config else "gauss"
pred = T_model.predict(
        X=X_sims.view(-1, X_sims.shape[-1]),
        context=X_params.view(-1, X_params.shape[-1]),
        z_samples=noisy_samples,
        z_type_dist=z_dist
    )
print(pred.shape)

In [None]:
# Source, target and predicted set shapes before plotting.
print(X_sims.shape, y_sims.shape, pred.shape)

In [None]:
# Directory the conditional figures below are written to.
save_path

In [None]:
# One figure per conditional: incomplete trajectory (gray), simulated full-model
# trajectories (blue), predictions (orange), and the full-model trajectories at
# the two bounds of the second-parameter range (red).
os.makedirs(save_path, exist_ok=True)  # hoisted out of the loop
for idx in range(n_samples):
    fig, ax0 = plt.subplots(1, 1, figsize=(12, 6))
    # incomplete trajectory
    ax0.plot(t, X_sims[idx].tolist(), color="gray", label="Part", alpha=0.2)

    # complete trajectories
    for k in range(len(y_sims[idx])):
        ax0.plot(t, y_sims[idx][k].tolist(), color="blue", label="Full", alpha=0.35)

    # predicted trajectories
    for j in range(len(pred[idx])):
        ax0.plot(t, pred[idx][j].tolist(), color="orange", label="Pred", alpha=0.35)

    # Ground-truth trajectories at the lower / upper bound of the second parameter
    x0 = init_conds[idx].repeat(2, 1)
    param_0 = torch.tensor(
        [
            [X_params[idx].item(), param_1_lb],
            [X_params[idx].item(), param_1_ub],
        ],
        device=device,
    )
    min_max_y = phys_model(init_conds=x0, params=param_0)["x"]
    for j in range(len(min_max_y)):
        ax0.plot(t, min_max_y[j].tolist(), color="red", label="Min/Max", alpha=1)

    # Bug fix: the title previously read X_init_conds (test-set initial conditions
    # from the marginal section) instead of the init_conds simulated here.
    init_cond = init_conds[idx][0].item()
    param_1 = X_params[idx][0].item()

    title = r"Idx{:} Full: $\vartheta_0$={:.2f}, $\omega$={:.2f}".format(
        idx, init_cond, param_1
    )

    # One legend entry per label.
    handles, labels = ax0.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax0.legend(by_label.values(), by_label.keys(), loc='upper right')
    ax0.set_xlabel("Time")
    ax0.set_ylabel("Angle")
    ax0.set_title(title)

    fig.savefig(os.path.join(save_path, f"{idx}.png"))
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)


## C2ST score for conditionals

In [None]:
# C2ST check of the conditionals: for a few (initial condition, incomplete-model
# parameter) pairs, simulate many full-model trajectories and draw the same number
# of model predictions, so the two sample sets can be compared by a classifier test.
from src.sampler.distributions import BoxUniformSampler
n_samples = 5
params_samples = int(1e4)
t = torch.linspace(0.0, phys_model.dt * (test_sims.shape[-1] - 1), test_sims.shape[-1])

# Initial angle in [-1.57, 1.57]; second component has both bounds at 0.
init_cond_sampler = BoxUniformSampler(torch.tensor([-1.57, 0], device=device), torch.tensor([1.57, 0], device=device), device=device)
init_conds = init_cond_sampler.sample((n_samples,))

x_params_sampler = get_x_params_sampler(
                        config, device=device
                    )
X_params = x_params_sampler.sample((init_conds.shape[0],))

# Range of the second, full-model-only parameter.
param_1_lb = 0.2
param_1_ub = 1.5
y_params_sampler = BoxUniformSampler(torch.tensor([param_1_lb], device=device), torch.tensor([param_1_ub], device=device), device=device)
# Explicit trailing parameter axis, as in the plotting section.
y_params = y_params_sampler.sample((n_samples, params_samples)).reshape(n_samples, params_samples, 1)

print(X_params.shape, init_conds.shape, y_params.shape)
tmp_X = X_params.repeat((1, params_samples)).reshape(-1, params_samples, 1)
params = torch.cat([tmp_X, y_params], dim=-1)
print(params.shape)

# source samples (incomplete model)
X_sims = phys_model(
        init_conds=init_conds.view(-1, init_conds.shape[-1]),
        params=X_params.view(-1, X_params.shape[-1])
    )["x"].view(n_samples, -1)

# target samples (full model): params_samples trajectories per initial condition
init_dim = init_conds.shape[-1]  # was hard-coded to 2
tmp_init_conds = init_conds.repeat((1, params_samples)).reshape(-1, params_samples, init_dim)
y_sims = phys_model(
        init_conds=tmp_init_conds.view(-1, init_dim),
        params=params.view(-1, params.shape[-1])
    )["x"].view(n_samples, params_samples, -1)

# Bug fix: sample the latent noise from the distribution in the model config, as
# the other prediction cells do, instead of relying on predict()'s default.
z_dist = config["z_dist"] if "z_dist" in config else "gauss"
z_samples = params_samples  # same size as the simulated set (1e4), for a balanced C2ST
pred = T_model.predict(
        X=X_sims.view(-1, X_sims.shape[-1]),
        context=X_params.view(-1, X_params.shape[-1]),
        z_samples=z_samples,
        z_type_dist=z_dist
    )
print(pred.shape)
print(y_sims.shape, pred.shape)

In [None]:
from src.metrics.c2st_torch import c2st

# C2ST score (predictions vs simulated full-model trajectories) for every
# generated conditional. The original loop only scored the first 2 of n_samples.
scores = []
for i in range(n_samples):
    c2st_score, _ = c2st(pred[i].to(device), y_sims[i].to(device))
    scores.append(c2st_score)

In [None]:
# One C2ST score per evaluated conditional.
scores

# Vanilla OT

## Vanilla OT MLP

In [None]:
# Load T Generative Model (vanilla OT MLP checkpoint)
#salt_model = "43IY"
#dir_model= "../../outputs/checkpoints/" + salt_model
t = torch.linspace(0.0, phys_model.dt * (test_sims.shape[-1] - 1), test_sims.shape[-1])
path_model_file = "../../../outputs/best_models/pendulum/0e992/bb/marginal_score/mmd/bb98da8acfc24506d5ff553ea9ee1dd3/uFsC_score_0.0362_epoch_43802/model_d5cb9a1c1bc2f77c518eaf33f81d1c83_exp_bb98da8acfc24506d5ff553ea9ee1dd3_0e99273f5187e7e8151a7929b3392bb4_salt_uFsC_best.pt"
#/ifhO/model_050d5bf7c16cad304eab5746e8da4e01_exp_99914b932bd37a50b983c5e7c90ae93b_8fea9929862bbc6572434bbefdcad037_salt_ifhO
#/SaTt/model_e4dd69ffa34b36665286da8d99887128_exp_99914b932bd37a50b983c5e7c90ae93b_8fea9929862bbc6572434bbefdcad037_salt_SaTt
# Bug fix: unpack load_model the same way as the first model-loading cell
# (model, parent_dir, config, checkpoint dict). The previous unpacking bound
# `phys_model` to the second return value, clobbering the PendulumSolver.
# NOTE(review): the dropped params_dim keyword and phys_model return value look
# like an older load_model API — confirm against src.nnets.utils.load_model.
T_model, parent_dir, config, dic_chkpt = load_model(
            path_model_file=path_model_file,
        )
# Evaluation only, as for the first model.
freeze(T_model)

In [None]:
# Rebuild the incomplete-model inputs from the test set. Earlier sections rebound
# X_params, X_init_conds, X_sims and noisy_samples to their own, much smaller,
# sample sets, so plot_traj would index them with out-of-range test-set indices.
noisy_samples = test_params.shape[1]  # stochastic samples per test sample
X_params = test_params[:, :, :param_dim].reshape(-1, param_dim)[::noisy_samples]
X_init_conds = test_init_conds.reshape(-1, test_init_conds.shape[-1])[::noisy_samples]
X_sims = phys_model(init_conds=X_init_conds, params=X_params)["x"]

T_noisy_samples = noisy_samples
pred = T_model.predict(X_sims, context=X_params, z_samples=T_noisy_samples)

In [None]:
# Plot random test samples: incomplete, full and vanilla-OT predicted trajectories.
plot_traj(t, test_params, test_sims, pred, X_params, X_sims, X_init_conds)