In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch_geometric as pyg
import torch_geometric.datasets as pygd
import gpn.experiments.transductive_experiment as exp

2024-01-08 15:05:34.967987: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-08 15:05:35.137076: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-08 15:05:35.137110: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-08 15:05:35.165983: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-08 15:05:35.226876: I tensorflow/core/platform/cpu_feature_guar

In [55]:
from gpn.utils.config import (
    DataConfiguration,
    ModelConfiguration,
    RunConfiguration,
    TrainingConfiguration,
)


def create_experiment(model_name):
    """Assemble a ``TransductiveExperiment`` for a single model variant.

    All run, data and training settings are hard-coded below; ``model_name``
    is the only value that differs between experiments built by this helper,
    so the resulting experiments are configured identically apart from the
    model.

    Args:
        model_name: Forwarded unchanged as ``ModelConfiguration.model_name``
            (this notebook passes ``"GPN"`` and ``"GPN_LOP"``).

    Returns:
        Whatever ``exp.TransductiveExperiment`` returns when called with the
        run, data, model and training configurations, in that order.
    """
    # Run-level settings: training job, model saved under ./saved_experiments.
    run_settings = {
        "job": "train",
        "eval_mode": "default",
        "experiment_directory": "./saved_experiments",
        "save_model": True,
        "gpu": 0,
        "experiment_name": "ood_loc",
    }

    # Dataset settings: CoraML with a random 5% / 15% / 80% split and the
    # leave-out-classes OOD setup.
    data_settings = {
        "dataset": "CoraML",
        "split_no": 1,
        "root": "./data",
        "ood_flag": True,
        "train_samples_per_class": 0.05,
        "val_samples_per_class": 0.15,
        "test_samples_per_class": 0.8,
        "split": "random",
        "ood_setting": "poisoning",
        "ood_type": "leave_out_classes",
        "ood_num_left_out_classes": -1,
        "ood_leave_out_last_classes": True,
    }

    # Model hyper-parameters; only `model_name` comes from the caller.
    model_settings = {
        "model_name": model_name,
        "seed": 42,
        "init_no": 1,
        "dim_hidden": 64,
        "dropout_prob": 0.5,
        "K": 10,
        "add_self_loops": True,
        "maf_layers": 0,
        "gaussian_layers": 0,
        "use_batched_flow": True,
        "loss_reduction": "sum",
        "approximate_reg": True,
        "flow_weight_decay": 0.0,
        "pre_train_mode": "flow",
        "alpha_evidence_scale": "latent-new",
        "alpha_teleport": 0.1,
        "entropy_reg": 0.0001,
        "dim_latent": 16,
        "radial_layers": 10,
    }

    # Optimisation settings: the large epoch budget is paired with early
    # stopping on `val_CE` (patience 50, best weights restored).
    training_settings = {
        "epochs": 100000,
        "stopping_mode": "default",
        "stopping_patience": 50,
        "stopping_restore_best": True,
        "stopping_metric": "val_CE",
        "stopping_minimize": True,
        "finetune_epochs": 0,
        "warmup_epochs": 5,
        "lr": 0.01,
        "weight_decay": 0.001,
    }

    # Instantiate the configuration objects in run -> data -> model -> train
    # order, then hand them to the experiment in that same order.
    configurations = (
        RunConfiguration(**run_settings),
        DataConfiguration(**data_settings),
        ModelConfiguration(**model_settings),
        TrainingConfiguration(**training_settings),
    )
    return exp.TransductiveExperiment(*configurations)


# One experiment per model variant; create_experiment fixes every other setting,
# so the two differ only in the model name.
gpn_e = create_experiment(model_name="GPN")
lop_e = create_experiment(model_name="GPN_LOP")

In [56]:
# Execute both experiments in turn (GPN first, then GPN_LOP) and keep each
# returned result object; the comparison cell passes them to results_dict_to_df.
# NOTE(review): with job="train" this presumably trains each model from scratch
# on every execution of the cell — confirm before re-running.
res = gpn_e.run()
res_lop = lop_e.run()

Epoch 1/5:
[2K [Elapsed 0:00:00 | 5.77 it/s] REG: 0.09641, UCE: 243.87514, train_CE: 2.57203, train_ECE: 0.39081, train_accuracy: 0.27160, train_average_entropy: 0.87890, train_avg_prediction_confidence_aleatoric: 0.66241, train_avg_prediction_confidence_epistemic: 594.50482, train_avg_sample_confidence_aleatoric: 0.66241, train_avg_sample_confidence_epistemic: 897.29657, train_avg_sample_confidence_epistemic_entropy: 11.70835, train_avg_sample_confidence_features: 998.38892, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.93140, train_confidence_aleatoric_apr: 0.32734, train_confidence_aleatoric_auroc: 0.47227, train_confidence_epistemic_apr: 0.30802, train_confidence_epistemic_auroc: 0.44068, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 1.93415, val_ECE: 0.66249, val_accuracy: 0.00000, val_average_entropy: 0.87863, val_avg_prediction_confidence_aleatoric: 0.00000, val_avg_prediction_confidence_epistemic: 0.00000, val_avg_sam



[2K [Elapsed 0:00:00 | 11.35 it/s] REG: 0.08337, UCE: 183.93839, train_CE: 1.72700, train_ECE: 0.33839, train_accuracy: 0.27160, train_average_entropy: 1.03268, train_avg_prediction_confidence_aleatoric: 0.61000, train_avg_prediction_confidence_epistemic: 223.71484, train_avg_sample_confidence_aleatoric: 0.61000, train_avg_sample_confidence_epistemic: 366.55356, train_avg_sample_confidence_epistemic_entropy: 8.35763, train_avg_sample_confidence_features: 407.55859, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.91344, train_confidence_aleatoric_apr: 0.32906, train_confidence_aleatoric_auroc: 0.55932, train_confidence_epistemic_apr: 0.32725, train_confidence_epistemic_auroc: 0.55855, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 1.72836, val_ECE: 0.34087, val_accuracy: 0.26923, val_average_entropy: 1.03244, val_avg_prediction_confidence_aleatoric: 0.61010, val_avg_prediction_confidence_epistemic: 228.00153, val_avg_sample_confi



[2K [Elapsed 0:00:00 | 26.25 it/s] REG: 0.11866, UCE: 13.60091, train_CE: 0.09328, train_ECE: 0.08384, train_accuracy: 1.00000, train_average_entropy: 0.32951, train_avg_prediction_confidence_aleatoric: 0.91616, train_avg_prediction_confidence_epistemic: 6832.29004, train_avg_sample_confidence_aleatoric: 0.91616, train_avg_sample_confidence_epistemic: 7327.96045, train_avg_sample_confidence_epistemic_entropy: 18.45831, train_avg_sample_confidence_features: 15708.31641, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.10093, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.39328, val_ECE: 0.08333, val_accuracy: 0.90171, val_average_entropy: 0.54715, val_avg_prediction_confidence_aleatoric: 0.81838, val_avg_prediction_confidence_epistemic: 3837.19653, val_avg_sample_confiden



[2K [Elapsed 0:00:00 | 24.77 it/s] REG: 0.13104, UCE: 11.29425, train_CE: 0.08354, train_ECE: 0.07349, train_accuracy: 1.00000, train_average_entropy: 0.27238, train_avg_prediction_confidence_aleatoric: 0.92651, train_avg_prediction_confidence_epistemic: 14439.50293, train_avg_sample_confidence_aleatoric: 0.92651, train_avg_sample_confidence_epistemic: 15212.35938, train_avg_sample_confidence_epistemic_entropy: 20.62238, train_avg_sample_confidence_features: 36926.83203, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.09114, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.39970, val_ECE: 0.05796, val_accuracy: 0.88462, val_average_entropy: 0.49765, val_avg_prediction_confidence_aleatoric: 0.83340, val_avg_prediction_confidence_epistemic: 7159.29541, val_avg_sample_confid



[2K [Elapsed 0:00:00 | 25.70 it/s] REG: 0.13806, UCE: 9.35850, train_CE: 0.06101, train_ECE: 0.05683, train_accuracy: 1.00000, train_average_entropy: 0.24032, train_avg_prediction_confidence_aleatoric: 0.94317, train_avg_prediction_confidence_epistemic: 19257.99219, train_avg_sample_confidence_aleatoric: 0.94317, train_avg_sample_confidence_epistemic: 20156.29883, train_avg_sample_confidence_epistemic_entropy: 21.78674, train_avg_sample_confidence_features: 53325.06641, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.06897, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.43873, val_ECE: 0.06734, val_accuracy: 0.88034, val_average_entropy: 0.49188, val_avg_prediction_confidence_aleatoric: 0.83181, val_avg_prediction_confidence_epistemic: 9029.19531, val_avg_sample_confide



[2K [Elapsed 0:00:00 | 23.35 it/s] REG: 0.14960, UCE: 8.42258, train_CE: 0.07416, train_ECE: 0.06520, train_accuracy: 1.00000, train_average_entropy: 0.25692, train_avg_prediction_confidence_aleatoric: 0.93480, train_avg_prediction_confidence_epistemic: 24443.14258, train_avg_sample_confidence_aleatoric: 0.93480, train_avg_sample_confidence_epistemic: 25451.14844, train_avg_sample_confidence_epistemic_entropy: 21.94646, train_avg_sample_confidence_features: 75642.04688, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.07792, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.45334, val_ECE: 0.05365, val_accuracy: 0.86752, val_average_entropy: 0.51936, val_avg_prediction_confidence_aleatoric: 0.82271, val_avg_prediction_confidence_epistemic: 9320.96777, val_avg_sample_confide



[2K [Elapsed 0:00:00 | 25.71 it/s] REG: 0.15141, UCE: 7.47583, train_CE: 0.06167, train_ECE: 0.05674, train_accuracy: 1.00000, train_average_entropy: 0.22256, train_avg_prediction_confidence_aleatoric: 0.94326, train_avg_prediction_confidence_epistemic: 38008.21094, train_avg_sample_confidence_aleatoric: 0.94326, train_avg_sample_confidence_epistemic: 39605.09766, train_avg_sample_confidence_epistemic_entropy: 23.41898, train_avg_sample_confidence_features: 111668.74219, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.07229, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.42290, val_ECE: 0.05916, val_accuracy: 0.88462, val_average_entropy: 0.48989, val_avg_prediction_confidence_aleatoric: 0.82608, val_avg_prediction_confidence_epistemic: 13434.61719, val_avg_sample_confi



[2K [Elapsed 0:00:00 | 24.63 it/s] REG: 0.15344, UCE: 6.04379, train_CE: 0.06397, train_ECE: 0.05299, train_accuracy: 0.98765, train_average_entropy: 0.22235, train_avg_prediction_confidence_aleatoric: 0.94596, train_avg_prediction_confidence_epistemic: 27087.54102, train_avg_sample_confidence_aleatoric: 0.94596, train_avg_sample_confidence_epistemic: 28150.82031, train_avg_sample_confidence_epistemic_entropy: 22.52510, train_avg_sample_confidence_features: 81402.82812, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.06750, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: 1.00000, train_confidence_epistemic_apr: 0.99984, train_confidence_epistemic_auroc: 0.98750, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.44322, val_ECE: 0.08908, val_accuracy: 0.88889, val_average_entropy: 0.52647, val_avg_prediction_confidence_aleatoric: 0.81512, val_avg_prediction_confidence_epistemic: 9387.24316, val_avg_sample



[2K [Elapsed 0:00:00 | 26.04 it/s] REG: 0.15888, UCE: 7.91593, train_CE: 0.05581, train_ECE: 0.05169, train_accuracy: 1.00000, train_average_entropy: 0.21454, train_avg_prediction_confidence_aleatoric: 0.94831, train_avg_prediction_confidence_epistemic: 32293.94727, train_avg_sample_confidence_aleatoric: 0.94831, train_avg_sample_confidence_epistemic: 33517.10938, train_avg_sample_confidence_epistemic_entropy: 23.09034, train_avg_sample_confidence_features: 106469.92969, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.06406, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.42477, val_ECE: 0.07300, val_accuracy: 0.88462, val_average_entropy: 0.53059, val_avg_prediction_confidence_aleatoric: 0.81161, val_avg_prediction_confidence_epistemic: 9490.39746, val_avg_sample_confid



[2K [Elapsed 0:00:00 | 24.83 it/s] REG: 0.15435, UCE: 6.95597, train_CE: 0.05149, train_ECE: 0.04626, train_accuracy: 1.00000, train_average_entropy: 0.18892, train_avg_prediction_confidence_aleatoric: 0.95374, train_avg_prediction_confidence_epistemic: 57643.08594, train_avg_sample_confidence_aleatoric: 0.95374, train_avg_sample_confidence_epistemic: 59164.88281, train_avg_sample_confidence_epistemic_entropy: 24.84934, train_avg_sample_confidence_features: 152444.96875, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.05626, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.43783, val_ECE: 0.05750, val_accuracy: 0.86325, val_average_entropy: 0.46400, val_avg_prediction_confidence_aleatoric: 0.83784, val_avg_prediction_confidence_epistemic: 21536.26758, val_avg_sample_confi



[2K [Elapsed 0:00:00 | 26.20 it/s] REG: 0.16046, UCE: 5.68627, train_CE: 0.04424, train_ECE: 0.04168, train_accuracy: 1.00000, train_average_entropy: 0.18240, train_avg_prediction_confidence_aleatoric: 0.95832, train_avg_prediction_confidence_epistemic: 48014.83984, train_avg_sample_confidence_aleatoric: 0.95832, train_avg_sample_confidence_epistemic: 49596.59375, train_avg_sample_confidence_epistemic_entropy: 24.62757, train_avg_sample_confidence_features: 140164.34375, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.05109, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.45542, val_ECE: 0.08993, val_accuracy: 0.88889, val_average_entropy: 0.50084, val_avg_prediction_confidence_aleatoric: 0.81905, val_avg_prediction_confidence_epistemic: 17287.01367, val_avg_sample_confi



Epoch 94/100000:
[2K [Elapsed 0:00:00 | 25.11 it/s] REG: 0.16070, UCE: 4.95118, train_CE: 0.04921, train_ECE: 0.04624, train_accuracy: 1.00000, train_average_entropy: 0.20051, train_avg_prediction_confidence_aleatoric: 0.95376, train_avg_prediction_confidence_epistemic: 46962.58984, train_avg_sample_confidence_aleatoric: 0.95376, train_avg_sample_confidence_epistemic: 48287.36328, train_avg_sample_confidence_epistemic_entropy: 23.84192, train_avg_sample_confidence_features: 134335.59375, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.05639, train_confidence_aleatoric_apr: 1.00000, train_confidence_aleatoric_auroc: nan, train_confidence_epistemic_apr: 1.00000, train_confidence_epistemic_auroc: nan, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 0.39742, val_ECE: 0.07427, val_accuracy: 0.89316, val_average_entropy: 0.51060, val_avg_prediction_confidence_aleatoric: 0.82233, val_avg_prediction_confidence_epistemic: 13661.36328, val



Epoch 1/5:
[2K [Elapsed 0:00:00 | 8.05 it/s] REG: 0.04871, UCE: 222.49030, train_CE: 2.74902, train_ECE: 0.31406, train_accuracy: 0.27160, train_average_entropy: 0.87750, train_avg_prediction_confidence_aleatoric: 0.59541, train_avg_prediction_confidence_epistemic: 594.40326, train_avg_sample_confidence_aleatoric: 0.59541, train_avg_sample_confidence_epistemic: 896.89032, train_avg_sample_confidence_epistemic_entropy: 5.83342, train_avg_sample_confidence_features: 998.38892, train_avg_sample_confidence_neighborhood: nan, train_brier_score: 0.94404, train_confidence_aleatoric_apr: 0.30826, train_confidence_aleatoric_auroc: 0.44145, train_confidence_epistemic_apr: 0.30802, train_confidence_epistemic_auroc: 0.44068, train_confidence_structure_apr: nan, train_confidence_structure_auroc: nan, val_CE: 2.70493, val_ECE: 0.32492, val_accuracy: 0.27350, val_average_entropy: 0.87747, val_avg_prediction_confidence_aleatoric: 0.60697, val_avg_prediction_confidence_epistemic: 605.92322, val_avg_sa

In [54]:
import pandas as pd
from gpn.utils.utils import results_dict_to_df

# Convert each experiment's result object into a DataFrame; the "val" and
# "test" columns are the ones compared below.
df = results_dict_to_df(res)
df_lop = results_dict_to_df(res_lop)

# Side-by-side table of validation/test columns for GPN and GPN_LOP.
# Left as the cell's last expression so the notebook renders it.
pd.DataFrame(
    {
        "val_gpn": df["val"],
        "test_gpn": df["test"],
        "val_lop": df_lop["val"],
        "test_lop": df_lop["test"],
    }
)

Unnamed: 0,val_gpn,test_gpn,val_lop,test_lop
accuracy,0.884615,0.889306,0.897436,0.891182
brier_score,0.280006,0.253266,0.409283,0.40185
ECE,0.047523,0.051982,0.196041,0.200875
confidence_aleatoric_apr,0.964445,0.982131,0.950391,0.979077
confidence_epistemic_apr,0.957138,0.967346,0.950205,0.98323
confidence_structure_apr,,,,
confidence_aleatoric_auroc,0.778136,0.871925,0.717262,0.857967
confidence_epistemic_auroc,0.745393,0.79526,0.716071,0.88353
confidence_structure_auroc,,,,
CE,0.420443,0.339508,0.528332,0.497391


In [None]:
# Toy example: three rows of two values each, stored as float32.
a = torch.tensor([[2, 8], [20, 80], [200, 800]]).float()

# Mixing weights; every row sums to 1, so `w @ a` replaces each row of `a`
# with a weighted average of rows of `a`.
w = torch.tensor(
    [
        [0.5, 0.5, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.5, 0.5],
    ]
)

# One column index per row, used to pick a single entry from each mixed row.
y = torch.tensor([1, 0, 1])

# Mix the per-row totals of `a` (matrix-vector product).
mix_sum = torch.mv(w, a.sum(dim=-1))

# Displayed result: the mixed matrix, and the entry of each mixed row selected by `y`.
torch.mm(w, a), torch.mm(w, a).gather(dim=-1, index=y.unsqueeze(-1)).squeeze(-1)