In [None]:
import torch, math, random, numpy as np
import torch.nn as nn

from utils_pde.utils_pde_2dpoisson import Poisson2D
from utils_uqmd.utils_uq_cp import CP
from utils_tools.utils_result_viz import plot_2D_comparison_with_coverage
from utils_tools.utils_result_metrics import cp_test_uncertainties, hmc_test_uncertainties
from utils_tools.utils_tuning import hyperparameter_tuning
from utils_uqmd.utils_uq_hmc import HMCBPINN

# Reproducibility: seed Python's `random`, NumPy's legacy global RNG and
# PyTorch's RNG with one shared value, plus every CUDA device when present,
# so subsequent random draws in this script are repeatable.
seed = 12345
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# PDE Definition
data_noise = 0.05  # noise level forwarded to pde.data_generation below
# (x-range, y-range); presumably in the axis order Poisson2D expects — matches
# true_solution, which reads x from column 0 and y from column 1.
domain = ((0.0, 1.0), (0.0, 1.0))


def true_solution(xy):
    """Return the exact solution u(x, y) = sin(pi * x) * sin(pi * y).

    ``xy`` carries coordinates in its last dimension (x at index 0, y at
    index 1). Any leading batch dimensions are preserved and the result keeps
    a trailing singleton dimension, e.g. an (N, 2) input gives an (N, 1) output.
    """
    # Slicing with 0:1 / 1:2 (instead of indexing) keeps the last dim of size 1.
    return torch.sin(math.pi * xy[..., 0:1]) * torch.sin(math.pi * xy[..., 1:2])


pde = Poisson2D(domain, true_solution)

# Data Generation: train / test / calibration sets drawn from the PDE object
# with the shared noise level. The calls run in this order (500, 100, 200
# points); assuming data_generation draws from the global RNGs, keeping the
# order fixed keeps the seeded datasets identical across runs.
(X_train, Y_train), (X_test, Y_test), (X_calibration, Y_calibration) = (
    pde.data_generation(n_points, data_noise) for n_points in (500, 100, 200)
)
colloc_pt_num = 100  # forwarded to fit() as "coloc_pt_num"

# Alphas: the 8 interior points of a 10-point linspace on [0, 1]
# (i.e. k/9 for k = 1..8), arranged as an (8, 1) column tensor.
_alpha_grid = torch.linspace(0.0, 1.0, steps=10)
alphas = _alpha_grid[1:-1].unsqueeze(1)  # drop the 0 and 1 endpoints
del _alpha_grid


# Model: a 2 -> [16, 32, 64, 64, 64, 32, 16] -> 1 network with tanh
# activations, wrapped by HMCBPINN. step_size / leapfrog_steps presumably
# configure the HMC integrator — confirm against HMCBPINN.
hmc_model = HMCBPINN(
    pde_class=pde,
    input_dim=2,
    output_dim=1,
    hidden_dims=[16, 32, 64, 64, 64, 32, 16],
    act_func=nn.Tanh,
    prior_std=1.0,
    leapfrog_steps=5,
    step_size=1e-3,
)

# Fitting args: fixed inputs passed to every fit.
fit_args = {
    "coloc_pt_num": colloc_pt_num,
    "X_train": X_train,
    "Y_train": Y_train,
}

# Candidate values for each fit keyword. Every entry is a list for
# consistency; a single-element list pins that hyperparameter.
fit_kwargs_grid = {
    "λ_pde": [3.0],
    "λ_bc": [5.0],
    "λ_data": [1.0],
    "epochs": [5000],
    "lr": [1e-3],
    "hmc_samples": [12000],
    # Previously misspelled "brun_in". With that key, the run log reported
    # "Kept 11500 posterior samples" for 12000 draws, so the requested
    # 5000-step burn-in was being ignored.
    # NOTE(review): confirm the exact keyword name accepted by HMCBPINN.fit.
    "burn_in": [5000],
    "step_size": [5e-4],
    "leapfrog_steps": [13],
}

# Posterior draw count shared by baseline prediction and baseline coverage,
# defined once so the two settings cannot drift apart.
n_posterior_samples = 5000
baseline_pred_kwargs = {"n_samples": n_posterior_samples}

# Arguments common to CP inference and CP coverage evaluation.
cp_shared_args = {
    "X_train": X_train, "Y_train": Y_train,
    "X_cal": X_calibration, "Y_cal": Y_calibration,
    "heuristic_u": "raw_std", "k": 10,
}
cp_pred_kwargs = dict(cp_shared_args)  # independent copy, not an alias
cp_coverage_args = {
    "alphas": alphas, "X_test": X_test, "Y_test": Y_test,
    **cp_shared_args,
}
baseline_coverage_args = {
    "alphas": alphas,
    "X_test": X_test,
    "Y_test": Y_test,
    "n_samples": n_posterior_samples,
}

# Test grid: n_grid x n_grid evaluation points spanning the PDE domain,
# flattened to an (n_grid**2, 2) tensor of (x, y) rows with x varying fastest
# (indexing='xy'). The bounds come from `domain` rather than hard-coded 0/1,
# so the grid stays aligned with the problem setup if the domain changes.
n_grid = 100
(x_min, x_max), (y_min, y_max) = domain
x = torch.linspace(x_min, x_max, n_grid)
y = torch.linspace(y_min, y_max, n_grid)
X, Y = torch.meshgrid(x, y, indexing='xy')
grid_test = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], dim=1)

# Run hyperparameter tuning. Per the cell output, each candidate is trained
# (MAP then HMC), run through baseline and CP inference and coverage, and
# scored by a data loss; plots presumably go to save_dir.
# NOTE(review): X_test is the n_grid x n_grid grid, but Y_test holds the 100
# noisy held-out targets, so the two do not correspond point-by-point.
# Confirm how hyperparameter_tuning uses Y_test; if it needs targets on the
# grid, pass true_solution(grid_test) instead.
hyperparameter_tuning(
    # model and training configuration
    uqmodel=hmc_model,
    fit_args=fit_args,
    fit_kwargs_grid=fit_kwargs_grid,
    # evaluation data
    alpha=0.05,
    X_test=grid_test,
    Y_test=Y_test,
    X_validation=X_test,
    Y_validation=Y_test,
    true_solution=pde.true_solution,
    # prediction / coverage settings
    baseline_pred_kwargs=baseline_pred_kwargs,
    cp_pred_kwargs=cp_pred_kwargs,
    baseline_testing_args=baseline_coverage_args,
    cp_testing_args=cp_coverage_args,
    baseline_test_uncertainties=hmc_test_uncertainties,
    # output
    plotting_func=plot_2D_comparison_with_coverage,
    plot_title="HMC CP Model",
    save_dir="2dpoisson_hmc_cp",
)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Using device: cpu

[🔎] Trying: {'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_data': 1.0, 'epochs': 5000, 'lr': 0.001, 'hmc_samples': 12000, 'brun_in': 5000, 'step_size': 0.0005, 'leapfrog_steps': 13}

[🟠] Training...


MAP:  10%|█         | 515/5000 [00:03<00:31, 141.80it/s, loss=2.74e+01]

[MAP] epoch    500 −logPost=2.926e+01  Data=8.55e-01  PDE=1.34e+01  IC=0.00e+00  BC=1.79e+00


MAP:  20%|██        | 1015/5000 [00:07<00:28, 139.70it/s, loss=2.10e+01]

[MAP] epoch   1000 −logPost=2.066e+01  Data=7.57e-01  PDE=6.95e+00  IC=0.00e+00  BC=1.67e+00


MAP:  30%|███       | 1525/5000 [00:11<00:25, 134.86it/s, loss=1.89e+01]

[MAP] epoch   1500 −logPost=2.031e+01  Data=7.87e-01  PDE=7.21e+00  IC=0.00e+00  BC=1.78e+00


MAP:  41%|████      | 2027/5000 [00:14<00:21, 136.65it/s, loss=1.98e+01]

[MAP] epoch   2000 −logPost=1.865e+01  Data=7.24e-01  PDE=6.22e+00  IC=0.00e+00  BC=1.60e+00


MAP:  50%|█████     | 2514/5000 [00:18<00:18, 133.89it/s, loss=1.77e+01]

[MAP] epoch   2500 −logPost=1.845e+01  Data=6.94e-01  PDE=6.21e+00  IC=0.00e+00  BC=1.54e+00


MAP:  60%|██████    | 3019/5000 [00:22<00:14, 136.51it/s, loss=1.79e+01]

[MAP] epoch   3000 −logPost=1.729e+01  Data=6.51e-01  PDE=5.19e+00  IC=0.00e+00  BC=1.45e+00


MAP:  70%|███████   | 3511/5000 [00:25<00:11, 130.39it/s, loss=1.70e+01]

[MAP] epoch   3500 −logPost=1.739e+01  Data=5.14e-01  PDE=5.78e+00  IC=0.00e+00  BC=1.13e+00


MAP:  80%|████████  | 4019/5000 [00:29<00:07, 138.15it/s, loss=1.28e+01]

[MAP] epoch   4000 −logPost=1.275e+01  Data=1.42e-01  PDE=1.38e+00  IC=0.00e+00  BC=2.22e-01


MAP:  90%|█████████ | 4525/5000 [00:33<00:03, 139.60it/s, loss=1.21e+01]

[MAP] epoch   4500 −logPost=1.199e+01  Data=8.82e-02  PDE=4.26e-01  IC=0.00e+00  BC=1.92e-01


                                                                        

[MAP] epoch   5000 −logPost=1.172e+01  Data=4.99e-02  PDE=3.02e-01  IC=0.00e+00  BC=1.08e-01


HMC:   4%|▍         | 502/12000 [00:41<15:32, 12.34it/s, acc=0.76]

[HMC] iter    500  acc-rate=0.76


HMC:   8%|▊         | 1002/12000 [01:21<13:30, 13.57it/s, acc=0.75]

[HMC] iter   1000  acc-rate=0.75


HMC:  13%|█▎        | 1502/12000 [01:58<13:00, 13.46it/s, acc=0.71]

[HMC] iter   1500  acc-rate=0.71


HMC:  17%|█▋        | 2002/12000 [02:36<12:35, 13.23it/s, acc=0.69]

[HMC] iter   2000  acc-rate=0.69


HMC:  21%|██        | 2502/12000 [03:13<11:32, 13.72it/s, acc=0.67]

[HMC] iter   2500  acc-rate=0.67


HMC:  25%|██▌       | 3002/12000 [03:50<11:13, 13.36it/s, acc=0.67]

[HMC] iter   3000  acc-rate=0.67


HMC:  29%|██▉       | 3502/12000 [04:27<10:17, 13.75it/s, acc=0.66]

[HMC] iter   3500  acc-rate=0.66


HMC:  33%|███▎      | 4002/12000 [05:03<09:41, 13.75it/s, acc=0.66]

[HMC] iter   4000  acc-rate=0.66


HMC:  38%|███▊      | 4502/12000 [05:39<09:09, 13.65it/s, acc=0.65]

[HMC] iter   4500  acc-rate=0.65


HMC:  42%|████▏     | 5000/12000 [06:16<09:28, 12.32it/s, acc=0.65]

[HMC] iter   5000  acc-rate=0.65


HMC:  46%|████▌     | 5502/12000 [06:52<07:44, 13.98it/s, acc=0.65]

[HMC] iter   5500  acc-rate=0.65


HMC:  50%|█████     | 6002/12000 [07:27<06:52, 14.54it/s, acc=0.65]

[HMC] iter   6000  acc-rate=0.65


HMC:  54%|█████▍    | 6502/12000 [08:02<06:13, 14.72it/s, acc=0.64]

[HMC] iter   6500  acc-rate=0.64


HMC:  58%|█████▊    | 7002/12000 [08:36<05:41, 14.62it/s, acc=0.64]

[HMC] iter   7000  acc-rate=0.64


HMC:  63%|██████▎   | 7502/12000 [09:10<05:04, 14.77it/s, acc=0.64]

[HMC] iter   7500  acc-rate=0.64


HMC:  67%|██████▋   | 8002/12000 [09:44<04:33, 14.61it/s, acc=0.63]

[HMC] iter   8000  acc-rate=0.63


HMC:  71%|███████   | 8502/12000 [10:18<03:56, 14.80it/s, acc=0.63]

[HMC] iter   8500  acc-rate=0.63


HMC:  75%|███████▌  | 9002/12000 [10:52<03:28, 14.35it/s, acc=0.63]

[HMC] iter   9000  acc-rate=0.63


HMC:  79%|███████▉  | 9502/12000 [11:27<02:51, 14.53it/s, acc=0.63]

[HMC] iter   9500  acc-rate=0.63


HMC:  83%|████████▎ | 10002/12000 [12:01<02:20, 14.26it/s, acc=0.63]

[HMC] iter  10000  acc-rate=0.63


HMC:  88%|████████▊ | 10502/12000 [12:36<01:42, 14.64it/s, acc=0.62]

[HMC] iter  10500  acc-rate=0.62


HMC:  92%|█████████▏| 11002/12000 [13:10<01:08, 14.60it/s, acc=0.62]

[HMC] iter  11000  acc-rate=0.62


HMC:  96%|█████████▌| 11502/12000 [13:45<00:34, 14.61it/s, acc=0.62]

[HMC] iter  11500  acc-rate=0.62


                                                                    

[HMC] iter  12000  acc-rate=0.62
Finished HMC: avg acceptance 0.616
Kept 11500 posterior samples

[🟠] Base Model Inferencing...

[🟠] CP Model Inferencing...

[🟠] Computing Coverage...


100%|██████████| 8/8 [00:15<00:00,  2.00s/it]
  return func(*args, **kwargs)
100%|██████████| 8/8 [00:47<00:00,  5.93s/it]



[✅] Data Loss = 1.201e-01

[🔎] Trying: {'λ_pde': 3.0, 'λ_bc': 5.0, 'λ_data': 2.0, 'epochs': 5000, 'lr': 0.001, 'hmc_samples': 12000, 'brun_in': 5000, 'step_size': 0.0005, 'leapfrog_steps': 13}

[🟠] Training...


MAP:  10%|█         | 510/5000 [00:05<00:42, 104.70it/s, loss=2.45e+01]

[MAP] epoch    500 −logPost=2.791e+01  Data=9.22e-01  PDE=1.19e+01  IC=0.00e+00  BC=1.95e+00


MAP:  20%|██        | 1014/5000 [00:10<00:39, 101.43it/s, loss=2.14e+01]

[MAP] epoch   1000 −logPost=2.099e+01  Data=8.71e-01  PDE=6.82e+00  IC=0.00e+00  BC=1.93e+00


MAP:  30%|███       | 1513/5000 [00:15<00:34, 100.43it/s, loss=1.90e+01]

[MAP] epoch   1500 −logPost=2.003e+01  Data=7.69e-01  PDE=7.11e+00  IC=0.00e+00  BC=1.73e+00


MAP:  40%|████      | 2015/5000 [00:20<00:30, 99.09it/s, loss=1.96e+01] 

[MAP] epoch   2000 −logPost=1.758e+01  Data=6.92e-01  PDE=5.32e+00  IC=0.00e+00  BC=1.55e+00


MAP:  50%|█████     | 2516/5000 [00:25<00:25, 97.70it/s, loss=1.87e+01] 

[MAP] epoch   2500 −logPost=1.669e+01  Data=6.69e-01  PDE=4.57e+00  IC=0.00e+00  BC=1.48e+00


MAP:  60%|██████    | 3014/5000 [00:30<00:19, 101.53it/s, loss=1.74e+01]

[MAP] epoch   3000 −logPost=1.764e+01  Data=5.78e-01  PDE=5.78e+00  IC=0.00e+00  BC=1.29e+00


MAP:  70%|███████   | 3519/5000 [00:35<00:15, 98.40it/s, loss=1.43e+01] 

[MAP] epoch   3500 −logPost=1.539e+01  Data=3.90e-01  PDE=4.02e+00  IC=0.00e+00  BC=5.90e-01


MAP:  80%|████████  | 4020/5000 [00:40<00:09, 99.47it/s, loss=1.19e+01] 

[MAP] epoch   4000 −logPost=1.197e+01  Data=6.45e-02  PDE=4.37e-01  IC=0.00e+00  BC=1.50e-01


MAP:  90%|█████████ | 4516/5000 [00:45<00:05, 92.83it/s, loss=1.17e+01] 

[MAP] epoch   4500 −logPost=1.164e+01  Data=4.83e-02  PDE=2.22e-01  IC=0.00e+00  BC=1.08e-01


                                                                        

[MAP] epoch   5000 −logPost=1.150e+01  Data=3.60e-02  PDE=2.44e-01  IC=0.00e+00  BC=7.35e-02


HMC:   4%|▍         | 501/12000 [01:01<22:39,  8.46it/s, acc=0.73]

[HMC] iter    500  acc-rate=0.73


HMC:   8%|▊         | 1001/12000 [02:06<25:20,  7.23it/s, acc=0.71]

[HMC] iter   1000  acc-rate=0.71


HMC:  13%|█▎        | 1501/12000 [03:06<21:07,  8.29it/s, acc=0.70]

[HMC] iter   1500  acc-rate=0.70


HMC:  17%|█▋        | 2001/12000 [04:05<18:59,  8.78it/s, acc=0.69]

[HMC] iter   2000  acc-rate=0.69


HMC:  21%|██        | 2501/12000 [05:03<17:41,  8.95it/s, acc=0.67]

[HMC] iter   2500  acc-rate=0.67


HMC:  25%|██▌       | 3001/12000 [05:55<15:20,  9.78it/s, acc=0.66]

[HMC] iter   3000  acc-rate=0.66


HMC:  29%|██▉       | 3501/12000 [06:45<14:06, 10.04it/s, acc=0.65]

[HMC] iter   3500  acc-rate=0.65


HMC:  33%|███▎      | 4001/12000 [07:34<13:26,  9.92it/s, acc=0.65]

[HMC] iter   4000  acc-rate=0.65


HMC:  38%|███▊      | 4501/12000 [08:20<11:56, 10.47it/s, acc=0.65]

[HMC] iter   4500  acc-rate=0.65


HMC:  39%|███▉      | 4663/12000 [08:35<11:28, 10.65it/s, acc=0.65]