In [1]:
# Standard library
import math
import random

# Third-party
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

# Project: PDE definitions
from utils_pde.utils_pde_2dpoisson import Poisson2D

# Project: visualisation
from utils_tools.utils_result_viz import plot_predictions_2D

# Project: base UQ models
from utils_uqmd.utils_uq_dropout import DropoutPINN
from utils_uqmd.utils_uq_mlp import MLPPINN
from utils_uqmd.utils_uq_vi import VIBPINN

# Project: conformal prediction (CP) wrapper
from utils_uqmd.utils_uq_cp import CP

# Cap the number of intra-op CPU threads torch may use.
torch.set_num_threads(4)

# ----------------------------------------------------------------------
# Reproducibility: seed every RNG the experiment may touch.
# `random` is the standard-library module; use `np.random` for NumPy's
# global RNG (a `from numpy import random` alias would be shadowed here).
seed = 12345
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# ----------------------------------------------------------------------
# Problem definition: 3-D Helmholtz equation on the unit cube.
from utils_pde.utils_pde_3dhelmholtz import Helmholtz3D

# Noise level forwarded to `pde.data_generation` when sampling targets
# (presumably a Gaussian std — confirm against Helmholtz3D).
data_noise = 0.05

# Axis-aligned box: one (min, max) pair per coordinate.
domain = (
    (0.0, 1.0),  # x
    (0.0, 1.0),  # y
    (0.0, 1.0),  # z
)

# Wave number passed to the Helmholtz operator.
k = math.pi

def true_solution(xyz):
    """Manufactured solution u(x, y, z) = sin(pi x) * sin(pi y) * sin(pi z).

    Args:
        xyz: tensor whose last dimension holds the (x, y, z) coordinates in
            its first three entries; any leading batch dimensions are allowed.

    Returns:
        Tensor of shape ``xyz.shape[:-1] + (1,)``. The ``i:i+1`` slices keep a
        trailing singleton dimension so the result is column-shaped.
    """
    x = xyz[..., 0:1]
    y = xyz[..., 1:2]
    z = xyz[..., 2:3]
    return torch.sin(math.pi * x) * torch.sin(math.pi * y) * torch.sin(math.pi * z)

# Sample sizes for the three data splits and for the PDE-residual points.
n_train, n_test, n_cal = 1_000, 200, 400
colloc_pt_num = 200  # interior collocation points for the PINN residual

pde = Helmholtz3D(k=k, domain=domain, true_solution=true_solution)

# Generate (inputs, targets) for training, testing and CP calibration,
# all at the configured `data_noise` level.
X_train, Y_train = pde.data_generation(n_train, data_noise)
X_test, Y_test = pde.data_generation(n_test, data_noise)
X_calibration, Y_calibration = pde.data_generation(n_cal, data_noise)

# ----------------------------------------------------------------------
# Alpha levels swept by both coverage tests (CP and baseline).
from utils_tools.utils_result_metrics import generating_alphas

alphas = generating_alphas(20)  # argument forwarded as-is; see generating_alphas

# ----------------------------------------------------------------------
# Base model: VIBPINN (variational-inference PINN, per its module utils_uq_vi).
vi_pinn = VIBPINN(
    pde_class=pde,                                # the Helmholtz3D problem instance
    input_dim=3,                                  # (x, y, z)
    hidden_dims=[32, 64, 128, 128, 128, 64, 32],  # hidden-layer widths
    output_dim=1,                                 # scalar field u
)

# ----------------------------------------------------------------------
# Base-model fitting configuration.

# Fixed arguments forwarded to every fit call.
fit_args = dict(
    coloc_pt_num=colloc_pt_num,
    X_train=X_train,
    Y_train=Y_train,
)

# Search space: trials are the Cartesian product of these lists, with the
# last key varying fastest (2 x 2 x 3 x 2 x 3 = 72 trials).
# NOTE(review): the logged run spent roughly 4-5 min per trial, so the full
# grid takes hours. Every stop_schedule value exceeds epochs=200 and the
# scheduler entries are commented out, so that axis may have no effect
# (tripling the grid) — TODO confirm against the fit implementation.
fit_kwargs_grid = {
    "epochs": [200],
    # Loss weights; adjust to the baseline model in use.
    "λ_pde": [3.0, 5.0],
    "λ_bc": [5.0, 10.0],
    "λ_elbo": [1.0, 2.0, 5.0],
    "lr": [1e-3, 3e-4],
    # "scheduler_cls": [StepLR], "scheduler_kwargs": [{'step_size': 5000, 'gamma': 0.5}],
    "stop_schedule": [10000, 20000, 40000],
}

# Keyword arguments for the base model's predictions.
baseline_pred_kwargs = dict(n_samples=10)


# ----------------------------------------------------------------------
# Conformal-prediction (CP) configuration.
# "heuristic_u": the original note says to choose it based on whether the
# base model provides its own uncertainty band.
# NOTE(review): the "k" key is a CP option, unrelated to the PDE wave-number
# global `k` — presumably a neighbour count; confirm in utils_uq_cp.
cp_pred_kwargs = {
    "X_train": X_train, "Y_train": Y_train,
    "X_cal": X_calibration, "Y_cal": Y_calibration,
    "heuristic_u": "raw_std",
    "k": 10,
}

# The coverage test reuses the exact CP settings above (spread in) so the
# prediction and testing configurations cannot drift apart.
cp_testing_args = {
    "alphas": alphas,
    "X_test": X_test, "Y_test": Y_test,
    **cp_pred_kwargs,
}

# Coverage-test arguments for the base model on its own (no CP).
baseline_testing_args = {
    "uqmodel": vi_pinn,
    "alphas": alphas,
    "X_test": X_test, "Y_test": Y_test,
}

# --------------------------------------------
# Evaluation / plotting grid: n_grid points per axis over `domain`,
# flattened into an (n_grid**3, 3) tensor of (x, y, z) coordinates.
# --------------------------------------------
n_grid = 20
x, y, z = (torch.linspace(lo, hi, n_grid) for lo, hi in domain[:3])
# indexing='xy' swaps the first two meshgrid axes (Cartesian convention);
# the flattened point set equals the 'ij' one, only the row order differs.
X, Y, Z = torch.meshgrid(x, y, z, indexing='xy')
grid_test = torch.stack((X.flatten(), Y.flatten(), Z.flatten()), dim=1)

# ----------------------------------------------------------------------
# Hyperparameter search over `fit_kwargs_grid`. Each trial reports a
# validation "Data Loss" (see the logged output); results presumably go
# under `save_dir`.
from utils_tools.utils_tuning import hyperparameter_tuning
from utils_tools.utils_result_viz import plot_2D_comparison_with_coverage, plot_metrics_table
from utils_tools.utils_result_metrics import cp_test_uncertainties, vi_test_uncertainties, do_test_uncertainties

# Template note: plot_title, uqmodel, baseline_test_uncertainties and
# save_dir are model-specific and must be swapped together when reusing
# this cell for another base model.
# NOTE(review): X_test is the n_grid**3-point plotting grid (grid_test) while
# Y_test holds the separate noisy test targets, so the two are not paired
# row-by-row. Presumably Y_test is unused or used independently of X_test
# here; confirm before trusting any metric computed from this pair.
hyperparameter_tuning(
    plot_title="VI CP Model",
    # Model fitting & predicting
    uqmodel=vi_pinn,
    alpha=0.05,
    X_test=grid_test, Y_test=Y_test,
    fit_args=fit_args,
    fit_kwargs_grid=fit_kwargs_grid,
    baseline_pred_kwargs=baseline_pred_kwargs,
    cp_pred_kwargs=cp_pred_kwargs,
    true_solution=pde.true_solution,
    # Coverage tests
    baseline_testing_args=baseline_testing_args,
    cp_testing_args=cp_testing_args,
    baseline_test_uncertainties=vi_test_uncertainties,
    # Plotting
    plotting_func=plot_metrics_table,
    save_dir="3dhelmholtz_vi",
    X_validation=X_test, Y_validation=Y_test,
)

Using device: cpu
Using device: cpu

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 1.0, 'lr': 0.001, 'stop_schedule': 10000}

[🟠] Training...
ep     0 | L=7.92e+03 | elbo=1.70e+02 | pde=2.58e+03  ic=0.00e+00  bc=3.35e-03 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     1 | L=6.94e+03 | elbo=1.91e+02 | pde=2.25e+03  ic=0.00e+00  bc=1.31e-03 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     2 | L=6.79e+03 | elbo=1.73e+02 | pde=2.21e+03  ic=0.00e+00  bc=3.08e-03 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     4 | L=7.07e+03 | elbo=1.91e+02 | pde=2.29e+03  ic=0.00e+00  bc=8.66e-04 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     6 | L=3.43e+03 | elbo=1.80e+02 | pde=1.08e+03  ic=0.00e+00  bc=3.94e-03 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     8 | L=9.80e+03 | elbo=1.75e+02 | pde=3.21e+03  ic=0.00e+00  bc=6.62e-06 | lr=1.00e-03 | learned noise_std=1.000e+00
ep    10 | L=5.29e+03 | elbo=1.94e+02 | pde=1.70e+03  ic=0.00e+00  bc=2.21e-03 | lr=1.00e-03 | learne

100%|██████████| 19/19 [00:02<00:00,  7.72it/s]
100%|██████████| 19/19 [04:28<00:00, 14.13s/it]



[✅] Data Loss = 4.807e-01

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 1.0, 'lr': 0.001, 'stop_schedule': 20000}

[🟠] Training...
ep     0 | L=2.27e+03 | elbo=4.01e+02 | pde=6.21e+02  ic=0.00e+00  bc=2.64e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     1 | L=2.55e+03 | elbo=3.58e+02 | pde=7.29e+02  ic=0.00e+00  bc=1.39e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     2 | L=2.61e+03 | elbo=4.28e+02 | pde=7.28e+02  ic=0.00e+00  bc=3.21e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     4 | L=1.91e+03 | elbo=4.08e+02 | pde=5.01e+02  ic=0.00e+00  bc=2.24e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     6 | L=1.56e+03 | elbo=3.67e+02 | pde=3.98e+02  ic=0.00e+00  bc=2.70e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     8 | L=1.92e+03 | elbo=3.31e+02 | pde=5.29e+02  ic=0.00e+00  bc=9.22e-02 | lr=1.00e-03 | learned noise_std=1.000e+00
ep    10 | L=1.72e+03 | elbo=4.40e+02 | pde=4.28e+02  ic=0.00e+00  bc=1.65e-01 | lr=1.00e-03 | learned noise_s

100%|██████████| 19/19 [00:02<00:00,  7.51it/s]
100%|██████████| 19/19 [04:35<00:00, 14.47s/it]



[✅] Data Loss = 5.912e-01

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 1.0, 'lr': 0.001, 'stop_schedule': 40000}

[🟠] Training...
ep     0 | L=2.11e+03 | elbo=3.86e+02 | pde=5.74e+02  ic=0.00e+00  bc=2.85e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     1 | L=2.33e+03 | elbo=3.03e+02 | pde=6.75e+02  ic=0.00e+00  bc=1.02e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     2 | L=2.06e+03 | elbo=3.66e+02 | pde=5.64e+02  ic=0.00e+00  bc=2.95e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     4 | L=1.79e+03 | elbo=3.79e+02 | pde=4.69e+02  ic=0.00e+00  bc=2.29e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     6 | L=1.56e+03 | elbo=3.70e+02 | pde=3.96e+02  ic=0.00e+00  bc=2.90e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     8 | L=1.77e+03 | elbo=3.22e+02 | pde=4.81e+02  ic=0.00e+00  bc=9.20e-02 | lr=1.00e-03 | learned noise_std=1.000e+00
ep    10 | L=1.67e+03 | elbo=3.94e+02 | pde=4.27e+02  ic=0.00e+00  bc=1.55e-01 | lr=1.00e-03 | learned noise_s

100%|██████████| 19/19 [00:02<00:00,  6.33it/s]
100%|██████████| 19/19 [04:44<00:00, 15.00s/it]



[✅] Data Loss = 5.787e-01

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 1.0, 'lr': 0.0003, 'stop_schedule': 10000}

[🟠] Training...
ep     0 | L=2.18e+03 | elbo=3.82e+02 | pde=6.00e+02  ic=0.00e+00  bc=2.82e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     1 | L=2.67e+03 | elbo=3.30e+02 | pde=7.79e+02  ic=0.00e+00  bc=1.26e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     2 | L=2.17e+03 | elbo=3.76e+02 | pde=5.98e+02  ic=0.00e+00  bc=3.64e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     4 | L=1.79e+03 | elbo=3.90e+02 | pde=4.67e+02  ic=0.00e+00  bc=2.51e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     6 | L=1.58e+03 | elbo=3.71e+02 | pde=4.01e+02  ic=0.00e+00  bc=3.43e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     8 | L=1.77e+03 | elbo=3.15e+02 | pde=4.85e+02  ic=0.00e+00  bc=1.07e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep    10 | L=1.66e+03 | elbo=3.92e+02 | pde=4.24e+02  ic=0.00e+00  bc=1.87e-01 | lr=3.00e-04 | learned noise_

100%|██████████| 19/19 [00:02<00:00,  7.60it/s]
100%|██████████| 19/19 [04:50<00:00, 15.30s/it]



[✅] Data Loss = 5.582e-01

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 1.0, 'lr': 0.0003, 'stop_schedule': 20000}

[🟠] Training...
ep     0 | L=2.01e+03 | elbo=3.53e+02 | pde=5.52e+02  ic=0.00e+00  bc=2.63e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     1 | L=2.29e+03 | elbo=3.13e+02 | pde=6.58e+02  ic=0.00e+00  bc=1.17e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     2 | L=2.05e+03 | elbo=3.61e+02 | pde=5.64e+02  ic=0.00e+00  bc=3.32e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     4 | L=1.78e+03 | elbo=3.73e+02 | pde=4.70e+02  ic=0.00e+00  bc=2.50e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     6 | L=1.55e+03 | elbo=3.60e+02 | pde=3.95e+02  ic=0.00e+00  bc=3.12e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     8 | L=1.75e+03 | elbo=3.12e+02 | pde=4.80e+02  ic=0.00e+00  bc=1.04e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep    10 | L=1.66e+03 | elbo=3.97e+02 | pde=4.21e+02  ic=0.00e+00  bc=1.79e-01 | lr=3.00e-04 | learned noise_

100%|██████████| 19/19 [00:02<00:00,  7.94it/s]
100%|██████████| 19/19 [04:11<00:00, 13.24s/it]



[✅] Data Loss = 5.677e-01

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 1.0, 'lr': 0.0003, 'stop_schedule': 40000}

[🟠] Training...
ep     0 | L=1.94e+03 | elbo=3.19e+02 | pde=5.41e+02  ic=0.00e+00  bc=2.82e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     1 | L=2.23e+03 | elbo=2.85e+02 | pde=6.49e+02  ic=0.00e+00  bc=1.07e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     2 | L=2.01e+03 | elbo=3.29e+02 | pde=5.61e+02  ic=0.00e+00  bc=3.37e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     4 | L=1.75e+03 | elbo=3.45e+02 | pde=4.67e+02  ic=0.00e+00  bc=2.32e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     6 | L=1.50e+03 | elbo=3.17e+02 | pde=3.95e+02  ic=0.00e+00  bc=2.70e-01 | lr=3.00e-04 | learned noise_std=1.000e+00
ep     8 | L=1.73e+03 | elbo=2.92e+02 | pde=4.79e+02  ic=0.00e+00  bc=8.37e-02 | lr=3.00e-04 | learned noise_std=1.000e+00
ep    10 | L=1.64e+03 | elbo=3.79e+02 | pde=4.21e+02  ic=0.00e+00  bc=1.46e-01 | lr=3.00e-04 | learned noise_

100%|██████████| 19/19 [00:02<00:00,  7.14it/s]
100%|██████████| 19/19 [04:22<00:00, 13.83s/it]



[✅] Data Loss = 5.693e-01

[🔎] Trying: {'epochs': 200, 'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_elbo': 2.0, 'lr': 0.001, 'stop_schedule': 10000}

[🟠] Training...
ep     0 | L=2.14e+03 | elbo=2.58e+02 | pde=5.42e+02  ic=0.00e+00  bc=3.01e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     1 | L=2.43e+03 | elbo=2.38e+02 | pde=6.49e+02  ic=0.00e+00  bc=1.00e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     2 | L=2.21e+03 | elbo=2.55e+02 | pde=5.68e+02  ic=0.00e+00  bc=3.25e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     4 | L=2.19e+03 | elbo=2.62e+02 | pde=5.55e+02  ic=0.00e+00  bc=1.65e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     6 | L=1.69e+03 | elbo=2.33e+02 | pde=4.06e+02  ic=0.00e+00  bc=2.23e-01 | lr=1.00e-03 | learned noise_std=1.000e+00
ep     8 | L=1.98e+03 | elbo=2.62e+02 | pde=4.85e+02  ic=0.00e+00  bc=6.79e-02 | lr=1.00e-03 | learned noise_std=1.000e+00
ep    10 | L=1.97e+03 | elbo=3.54e+02 | pde=4.19e+02  ic=0.00e+00  bc=1.12e-01 | lr=1.00e-03 | learned noise_s

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>