In [None]:
import torch, math, random, numpy as np
import torch.nn as nn

from utils_pde.utils_pde_2dpoisson import Poisson2D
from utils_uqmd.utils_uq_cp import CP
from utils_tools.utils_result_viz import plot_2D_comparison_with_coverage
from utils_tools.utils_result_metrics import cp_test_uncertainties, hmc_test_uncertainties
from utils_tools.utils_tuning import hyperparameter_tuning
from utils_uqmd.utils_uq_hmc import HMCBPINN

# Reproducibility
def _seed_all(value):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``value``.

    CUDA generators on every visible device are seeded too, but only when
    CUDA is available.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(value)


seed = 12345
_seed_all(seed)

# PDE Definition
data_noise = 0.05  # passed as the noise argument of pde.data_generation
domain = ((0.0, 1.0), (0.0, 1.0))  # presumably per-axis (min, max) bounds for x and y


def true_solution(xy: torch.Tensor) -> torch.Tensor:
    """Exact solution u(x, y) = sin(pi * x) * sin(pi * y).

    ``xy`` is a tensor whose last dimension holds the (x, y) coordinates.
    The result keeps the leading dimensions and has a trailing dimension of
    size 1 (the ``0:1`` / ``1:2`` slices keep that axis).
    """
    return torch.sin(math.pi * xy[..., 0:1]) * torch.sin(math.pi * xy[..., 1:2])


pde = Poisson2D(domain, true_solution)

# Data Generation: noisy (X, Y) sets for training, testing and CP calibration,
# drawn in that fixed order (500, 100 and 200 points respectively).
# NOTE(review): data_generation presumably samples from the RNGs seeded above,
# so changing this order would change the splits — confirm.
(X_train, Y_train), (X_test, Y_test), (X_calibration, Y_calibration) = (
    pde.data_generation(n_points, data_noise) for n_points in (500, 100, 200)
)
colloc_pt_num = 100  # forwarded to fit as "coloc_pt_num" (presumably PDE collocation points)

# Alphas: the 8 interior points of a 10-point uniform grid on [0, 1]
# (k/9 for k = 1..8; the endpoints 0 and 1 are dropped), as an (8, 1) column.
# Presumably the levels swept by the coverage tests — confirm.
alphas = torch.linspace(0.0, 1.0, 10)[1:-1].unsqueeze(1)


# Model
# HMC-sampled Bayesian PINN on the Poisson problem: 2 inputs (x, y), 1 output,
# tanh hidden layers with the widths listed below.
hmc_model = HMCBPINN(
    pde_class=pde,
    input_dim=2,
    hidden_dims=[16, 32, 64, 64, 64, 32, 16],
    output_dim=1,
    act_func=nn.Tanh,
    prior_std=1.0,
    # Sampler settings given at construction. fit_kwargs_grid also sweeps
    # step_size / leapfrog_steps; NOTE(review): presumably the per-trial
    # values take precedence during fitting — confirm.
    step_size=1e-3,
    leapfrog_steps=5,
)

# Fitting args
# Shared data pieces; every dict below is assembled from them so that key
# insertion order matches the explicit literals it replaces.
_train_data = {"X_train": X_train, "Y_train": Y_train}
_cal_data = {"X_cal": X_calibration, "Y_cal": Y_calibration}
_test_data = {"alphas": alphas, "X_test": X_test, "Y_test": Y_test}

# Arguments fixed across every trial ("coloc_pt_num" is the keyword as spelled
# by the fitting API; the local variable is colloc_pt_num).
fit_args = {"coloc_pt_num": colloc_pt_num, **_train_data}

# Search space: each key maps to the candidate values for one fit keyword.
# NOTE(review): "brun_in" looks like a misspelling of "burn_in". It is kept
# verbatim because it must match the keyword the fit routine reads. If that
# routine silently ignores unknown keys, no burn-in is being applied — confirm.
fit_kwargs_grid = {
    "λ_pde": [1.0, 3.0, 5.0],
    "λ_bc": [5.0, 10.0],
    "λ_data": [1.0, 2.0, 5.0],
    "epochs": [100],
    "lr": [1e-4],
    "hmc_samples": [1000],
    "brun_in": [100],
    "step_size": [1e-3, 5e-4, 1e-4],
    "leapfrog_steps": [5],
}

# Prediction / coverage settings. The model itself is handed to
# hyperparameter_tuning directly (uqmodel=...), not through these dicts.
# NOTE(review): CP prediction uses k=10 but CP coverage uses k=1, and the
# baseline draws 500 samples for prediction vs 1000 for coverage — confirm
# both differences are intentional.
baseline_pred_kwargs = {"n_samples": 500}
cp_pred_kwargs = {**_train_data, **_cal_data, "heuristic_u": "raw_std", "k": 10}
cp_coverage_args = {
    **_test_data,
    **_cal_data,
    **_train_data,
    "heuristic_u": "raw_std",
    "k": 1,
}
baseline_coverage_args = {**_test_data, "n_samples": 1000}

# Test grid
# Dense n_grid x n_grid grid spanning the PDE domain, with bounds taken from
# `domain` rather than hard-coded so the grid tracks the problem definition.
# With indexing='xy', X[i, j] = x[j] and Y[i, j] = y[i]. After flattening,
# grid_test has n_grid**2 rows of (x, y), with x varying fastest.
n_grid = 100
x = torch.linspace(*domain[0], n_grid)
y = torch.linspace(*domain[1], n_grid)
X, Y = torch.meshgrid(x, y, indexing='xy')
grid_test = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], dim=1)

# Run hyperparameter tuning
# X_test here is the dense evaluation grid (n_grid**2 rows), so Y_test must be
# aligned row-for-row with it. The 100-point noisy held-out Y_test has a
# different length. Use the exact solution evaluated on the grid instead.
# Held-out data still reaches the tuner via X_validation/Y_validation and the
# coverage-argument dicts.
Y_grid = pde.true_solution(grid_test)

hyperparameter_tuning(
    plot_title="HMC CP Model",
    uqmodel=hmc_model, alpha=0.05,
    X_test=grid_test, Y_test=Y_grid,
    fit_args=fit_args, fit_kwargs_grid=fit_kwargs_grid,
    baseline_pred_kwargs=baseline_pred_kwargs,
    cp_pred_kwargs=cp_pred_kwargs,
    true_solution=pde.true_solution,
    baseline_testing_args=baseline_coverage_args,
    cp_testing_args=cp_coverage_args,
    baseline_test_uncertainties=hmc_test_uncertainties,
    plotting_func=plot_2D_comparison_with_coverage,
    save_dir="2dpoisson_hmc_cp",
    X_validation=X_test, Y_validation=Y_test
)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Using device: cpu

[🔎] Trying: {'λ_pde': 1.0, 'λ_bc': 5.0, 'λ_data': 1.0, 'epochs': 100, 'lr': 0.0001, 'hmc_samples': 1000, 'brun_in': 100, 'step_size': 0.001, 'leapfrog_steps': 5}

[🟠] Training...


                                                                     

[MAP] epoch    100  −logPost=1.267e+02  Data=3.597e-01  PDE=8.387e+01  IC=0.000e+00  BC=1.347e-02


HMC:   9%|▉         | 92/1000 [00:02<00:24, 36.96it/s, acc=0.49]