In [1]:
import os

def go_up_n_directories(path, n):
    """Return the absolute path ``n`` levels above the parent of ``path``.

    The parent directory of ``path`` is taken first (``os.path.dirname``),
    then ``n`` ``".."`` components are appended and the result is normalised
    with ``os.path.abspath``.

    Args:
        path: File-system path to start from.
        n: Number of ``".."`` components appended after the parent directory.

    Returns:
        The normalised absolute path as a string.
    """
    return os.path.abspath(os.path.join(os.path.dirname(path), *([".."] * n)))


# Resolve the target directory only on the first execution in this kernel.
# Re-running the cell then changes into the same directory again instead of
# climbing further up from wherever the previous run ended.
if "NOTEBOOK_WORKDIR" not in globals():
    NOTEBOOK_WORKDIR = go_up_n_directories(os.getcwd(), 3)
os.chdir(NOTEBOOK_WORKDIR)

In [2]:
# DEPENDENCIES
# Python native
import functools
import json
from datetime import datetime
from pprint import pprint
from typing import Callable, Union

# Data handling
import ocpa.algo.predictive_monitoring.factory as feature_factory

# PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as O
import torch_geometric.nn as pygnn
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

import utilities.torch_utils
from models.definitions.geometric_models import GraphModel, HigherOrderGNN
from utilities import hetero_data_utils, hetero_experiment_utils

# Print system info
utilities.torch_utils.print_system_info()
utilities.torch_utils.print_torch_info()

# INITIAL CONFIGURATION
# Built with keyword arguments; the resulting dict has exactly the keys named
# below, in this order.
cs_hoeg_config = dict(
    # Output location for models and input locations for the encoded data
    model_output_path="models/CS/hoeg",
    STORAGE_PATH="data/CS/feature_encodings/HOEG/hoeg",
    SPLIT_FEATURE_STORAGE_FILE="CS_split_[C2_P2_P3_O3_eas].fs",
    OBJECTS_DATA_DICT="cs_ofg+oi_graph+krs_node_map+krv_node_map+cv_node_map.pkl",
    # Prediction target
    events_target_label=(feature_factory.EVENT_REMAINING_TIME, ()),
    objects_target_label="@@object_lifecycle_duration",
    graph_level_target=False,
    regression_task=True,
    target_node_type="event",
    # Graph schema as (node types, edge types)
    object_types=["krs", "krv", "cv"],
    meta_data=(
        ["event", "krs", "krv", "cv"],
        [
            ("event", "follows", "event"),
            ("krs", "interacts", "event"),
            ("krv", "interacts", "event"),
            ("cv", "interacts", "event"),
        ],
    ),
    # Training loop
    BATCH_SIZE=16,
    RANDOM_SEED=42,
    EPOCHS=30,
    early_stopping=4,
    # Optimizer class plus the keyword arguments it is constructed with
    optimizer=O.Adam,
    optimizer_settings=dict(
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        amsgrad=False,
    ),
    loss_fn=torch.nn.L1Loss(),  # mean absolute error
    # Runtime behaviour
    verbose=True,
    skip_cache=False,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    squeeze=True,
    track_time=True,
)

# CONFIGURATION ADAPTATIONS may be set here
# cs_hoeg_config["device"] = torch.device("cpu")
# cs_hoeg_config['skip_cache'] = True

CPU: Intel(R) Core(TM) i5-7500 CPU @ 3.40GHz (4x)
Total CPU memory: 46.93GB
Available CPU memory: 29.55GB
GPU: NVIDIA GeForce GTX 960
Total GPU memory: 4096.0MB
Available GPU memory: 3172.0MB
Platform: Linux-6.2.0-31-generic-x86_64-with-glibc2.35
Torch version: 1.13.1+cu117
Cuda available: True
Torch geometric version: 2.3.1


In [3]:
# DATA PREPARATION
# Transforms are composed in list order and handed to the dataset loader.
transformations = [
    # ("event", "follows", "event") edges are excluded from ToUndirected
    hetero_data_utils.ToUndirected(exclude_edge_types=[("event", "follows", "event")]),
    # Prepares object-object relations, which are filled when `T.AddSelfLoops()` is executed
    hetero_data_utils.AddObjectSelfLoops(),
    # Add self-loops to the graph
    T.AddSelfLoops(),
    # Normalize node features of the graph
    T.NormalizeFeatures(),
]

# Load the train / validation / test datasets in one call
hoeg_datasets = hetero_data_utils.load_hetero_datasets(
    storage_path=cs_hoeg_config["STORAGE_PATH"],
    split_feature_storage_file=cs_hoeg_config["SPLIT_FEATURE_STORAGE_FILE"],
    objects_data_file=cs_hoeg_config["OBJECTS_DATA_DICT"],
    event_node_label_key=cs_hoeg_config["events_target_label"],
    object_nodes_label_key=cs_hoeg_config["objects_target_label"],
    edge_types=cs_hoeg_config["meta_data"][1],
    object_node_types=cs_hoeg_config["object_types"],
    graph_level_target=cs_hoeg_config["graph_level_target"],
    transform=T.Compose(transformations),
    train=True,
    val=True,
    test=True,
    skip_cache=cs_hoeg_config["skip_cache"],
    debug=True,
)
ds_train, ds_val, ds_test = hoeg_datasets

# Replace the configured metadata with that of the first validation graph
# whose metadata differs from the configuration (if there is one).
for data in ds_val:
    data: hetero_data_utils.HeteroData
    graph_metadata = data.metadata()
    if graph_metadata == cs_hoeg_config["meta_data"]:
        continue
    cs_hoeg_config["meta_data"] = graph_metadata
    break

# Seeded data loaders for the three splits
random_seed = cs_hoeg_config["RANDOM_SEED"]
train_loader, val_loader, test_loader = hetero_data_utils.hetero_dataloaders_from_datasets(
    batch_size=cs_hoeg_config["BATCH_SIZE"],
    ds_train=ds_train,
    ds_val=ds_val,
    ds_test=ds_test,
    num_workers=3,
    seed_worker=functools.partial(utilities.torch_utils.seed_worker, state=random_seed),
    generator=torch.Generator().manual_seed(random_seed),
)

In [5]:
# Print per-split summary tables (#nodes / #edges statistics) for the three datasets
hetero_data_utils.print_hetero_dataset_summaries(ds_train, ds_val, ds_test)

Train set


  0%|          | 0/15638 [00:00<?, ?it/s]

100%|██████████| 15638/15638 [00:18<00:00, 843.70it/s]


HOEG (#graphs=15638):
+------------+----------+----------+
|            |   #nodes |   #edges |
|------------+----------+----------|
| mean       |     25.3 |     70.9 |
| std        |     13.1 |     39.4 |
| min        |      7   |     16   |
| quantile25 |     16   |     43   |
| median     |     23   |     64   |
| quantile75 |     32   |     91   |
| max        |    161   |    478   |
+------------+----------+----------+ 

Validation set
HOEG (#graphs=6255):
+------------+----------+----------+
|            |   #nodes |   #edges |
|------------+----------+----------|
| mean       |     25.4 |     71.1 |
| std        |     13.2 |     39.5 |
| min        |      7   |     16   |
| quantile25 |     16   |     43   |
| median     |     23   |     64   |
| quantile75 |     32   |     91   |
| max        |    146   |    433   |
+------------+----------+----------+ 

Test set
HOEG (#graphs=9384):
+------------+----------+----------+
|            |   #nodes |   #edges |
|------------+------

In [7]:
# Validate the train, validation and test datasets (in that order) and keep
# each call's return value under its own name.
invalid_train_graphs, invalid_val_graphs, invalid_test_graphs = (
    hetero_data_utils.validate_cs_hoeg_dataset(dataset_split, verbose=False)
    for dataset_split in (ds_train, ds_val, ds_test)
)

Event node indices valid in all batches for edge type event-event:  True
KRS node indices valid in all batches for edge type krs-event:  True
KRV node indices valid in all batches for edge type krv-event:  True
CV node indices valid in all batches for edge type cv-event:  True
Event node indices valid in all batches for edge type krs-event:  True
Event node indices valid in all batches for edge type krv-event:  True
Event node indices valid in all batches for edge type cv-event:  True
KRS node indices valid in all batches for edge type krs-krs:  True
KRV node indices valid in all batches for edge type krv-krv:  True
CV node indices valid in all batches for edge type cv-cv:  True
HOEG dataset valid:  True


In [4]:
# Settings for this experiment: verbose output, squeezed predictions, CPU
# execution and a dedicated model output directory.
cs_hoeg_config.update(
    {
        "verbose": True,
        "squeeze": True,
        "model_output_path": "models/CS/hoeg/exp_information_leakage",
        "device": torch.device("cpu"),
    }
)

# lr_range = [0.01, 0.001]
# hidden_dim_range = [8, 16, 24, 32, 48, 64, 128, 256]
lr_range = [0.001]
hidden_dim_range = [64]

# Every (learning rate, hidden dimension) pair, learning rate varying slowest
search_grid = [(lr, hidden_dim) for lr in lr_range for hidden_dim in hidden_dim_range]
for lr, hidden_dim in search_grid:
    hetero_experiment_utils.run_hoeg_experiment_configuration(
        HigherOrderGNN,
        lr=lr,
        hidden_dim=hidden_dim,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        hoeg_config=cs_hoeg_config,
    )


lr=0.001, hidden_dim=64:
Training started, progress available in Tensorboard
EPOCH 0:


100it [00:54,  5.57it/s]

  batch 100 loss: 0.7125392013788223


200it [01:13,  3.54it/s]

  batch 200 loss: 0.7612828561663627


300it [01:34,  6.64it/s]

  batch 300 loss: 0.7427035737037658


400it [02:02,  4.95it/s]

  batch 400 loss: 0.7265765404701233


499it [02:33,  2.84it/s]

  batch 500 loss: 0.7559178838133812


601it [02:54,  7.49it/s]

  batch 600 loss: 0.7314078605175018


700it [03:22,  5.87it/s]

  batch 700 loss: 0.7047266858816147


800it [04:15,  1.80it/s]

  batch 800 loss: 0.6987791302800178


900it [05:38,  1.49s/it]

  batch 900 loss: 0.6885481190681457


1000it [07:10,  1.01s/it]

  batch 1000 loss: 0.7327316254377365


1100it [08:30,  2.64it/s]

  batch 1100 loss: 0.7191672164201737


1200it [09:49,  1.37it/s]

  batch 1200 loss: 0.6891599503159523


1300it [11:27,  1.36s/it]

  batch 1300 loss: 0.7218409541249275


1369it [12:38,  1.81it/s]


Epoch loss -> train: 0.7218409541249275 valid: 0.6995423436164856
EPOCH 1:


101it [01:10,  1.22it/s]

  batch 100 loss: 0.7166473907232285


203it [01:42, 19.38it/s]

  batch 200 loss: 0.7149485379457474


303it [01:48, 19.68it/s]

  batch 300 loss: 0.7112326177954674


401it [01:53, 16.64it/s]

  batch 400 loss: 0.7008974525332451


501it [01:59, 16.10it/s]

  batch 500 loss: 0.7022838288545609


603it [02:05, 17.05it/s]

  batch 600 loss: 0.6942991617321969


703it [02:10, 18.87it/s]

  batch 700 loss: 0.7235970465838909


803it [02:16, 19.84it/s]

  batch 800 loss: 0.6753491899371147


900it [02:22, 14.35it/s]

  batch 900 loss: 0.7050062882900238


1003it [02:28, 17.25it/s]

  batch 1000 loss: 0.7175086337327957


1101it [02:33, 17.26it/s]

  batch 1100 loss: 0.6908260059356689


1202it [02:39, 16.20it/s]

  batch 1200 loss: 0.6985468453168869


1303it [02:45, 17.71it/s]

  batch 1300 loss: 0.725700671672821


1369it [02:48,  8.11it/s]


Epoch loss -> train: 0.725700671672821 valid: 0.7058811783790588
EPOCH 2:


100it [00:06, 16.43it/s]

  batch 100 loss: 0.725783534348011


200it [00:11, 17.20it/s]

  batch 200 loss: 0.7137933406233787


300it [00:17, 17.40it/s]

  batch 300 loss: 0.6879107141494751


400it [00:23, 17.33it/s]

  batch 400 loss: 0.6743107134103775


500it [00:29, 16.88it/s]

  batch 500 loss: 0.7331976807117462


600it [00:35, 17.19it/s]

  batch 600 loss: 0.7081045719981194


700it [00:40, 16.87it/s]

  batch 700 loss: 0.7028074851632118


800it [00:46, 17.03it/s]

  batch 800 loss: 0.722780488729477


900it [00:52, 16.89it/s]

  batch 900 loss: 0.6978469225764274


1000it [00:58, 16.47it/s]

  batch 1000 loss: 0.6759715043008327


1100it [01:04, 17.21it/s]

  batch 1100 loss: 0.6991880583763123


1200it [01:10, 16.71it/s]

  batch 1200 loss: 0.7118703511357307


1300it [01:16, 17.10it/s]

  batch 1300 loss: 0.7202801504731178


1369it [01:20, 16.99it/s]


Epoch loss -> train: 0.7202801504731178 valid: 0.7147231101989746
EPOCH 3:


100it [00:06, 16.13it/s]

  batch 100 loss: 0.7377018070220948


200it [00:12, 17.00it/s]

  batch 200 loss: 0.7240927717089654


300it [00:18, 16.53it/s]

  batch 300 loss: 0.7011372384428978


400it [00:24, 16.59it/s]

  batch 400 loss: 0.6729877126216889


500it [00:30, 17.05it/s]

  batch 500 loss: 0.6596620774269104


600it [00:36, 16.92it/s]

  batch 600 loss: 0.6746262875199318


700it [00:42, 16.52it/s]

  batch 700 loss: 0.69354766741395


800it [00:48, 16.64it/s]

  batch 800 loss: 0.6787648555636406


900it [00:53, 16.89it/s]

  batch 900 loss: 0.7147163194417954


1000it [01:00, 16.66it/s]

  batch 1000 loss: 0.7334641972184182


1100it [01:06, 16.39it/s]

  batch 1100 loss: 0.7063703691959381


1200it [01:12, 16.86it/s]

  batch 1200 loss: 0.6891968715190887


1300it [01:18, 16.41it/s]

  batch 1300 loss: 0.6943979361653327


1369it [01:22, 16.55it/s]


Epoch loss -> train: 0.6943979361653327 valid: 0.6948264241218567
EPOCH 4:


100it [00:09, 10.82it/s]

  batch 100 loss: 0.7024273338913918


200it [00:17, 12.07it/s]

  batch 200 loss: 0.6831239432096481


300it [00:26, 11.73it/s]

  batch 300 loss: 0.7294775426387787


400it [00:34, 12.29it/s]

  batch 400 loss: 0.7315550059080124


500it [00:42, 12.19it/s]

  batch 500 loss: 0.7282255923748017


600it [00:50, 11.92it/s]

  batch 600 loss: 0.6757489088177681


700it [00:58, 12.36it/s]

  batch 700 loss: 0.6764918678998947


800it [01:08, 11.52it/s]

  batch 800 loss: 0.7001362469792366


900it [01:15, 13.05it/s]

  batch 900 loss: 0.6987373575568199


1000it [01:23, 12.38it/s]

  batch 1000 loss: 0.6692554694414139


1100it [01:31, 13.53it/s]

  batch 1100 loss: 0.6792632532119751


1202it [02:12,  5.50it/s]

  batch 1200 loss: 0.6861340251564979


1300it [02:36, 10.20it/s]

  batch 1300 loss: 0.7154281041026116


1369it [02:42,  8.43it/s]


Epoch loss -> train: 0.7154281041026116 valid: 0.6996454000473022
EPOCH 5:


100it [00:18,  5.84it/s]

  batch 100 loss: 0.7135171085596085


201it [00:49, 14.34it/s]

  batch 200 loss: 0.6958935567736626


300it [01:25, 10.34it/s]

  batch 300 loss: 0.6956495660543441


400it [01:41,  2.25it/s]

  batch 400 loss: 0.7126031064987183


501it [02:07,  2.77it/s]

  batch 500 loss: 0.6593978184461594


601it [02:41, 11.18it/s]

  batch 600 loss: 0.686978123486042


702it [02:57, 12.85it/s]

  batch 700 loss: 0.6958428448438645


803it [03:05, 14.16it/s]

  batch 800 loss: 0.6944805815815925


903it [03:13, 14.26it/s]

  batch 900 loss: 0.7097709149122238


1002it [03:19, 14.15it/s]

  batch 1000 loss: 0.7041550728678704


1101it [03:26, 18.11it/s]

  batch 1100 loss: 0.6817671290040016


1202it [03:34, 13.41it/s]

  batch 1200 loss: 0.6976127058267594


1301it [03:51, 14.84it/s]

  batch 1300 loss: 0.6901265352964401


1369it [03:56,  5.78it/s]


Epoch loss -> train: 0.6901265352964401 valid: 0.6942901015281677
EPOCH 6:


100it [00:08, 11.62it/s]

  batch 100 loss: 0.7086201372742653


200it [00:16, 11.99it/s]

  batch 200 loss: 0.6934648871421814


300it [00:26, 10.50it/s]

  batch 300 loss: 0.6799608394503593


400it [00:44,  7.27it/s]

  batch 400 loss: 0.7129393523931503


500it [00:59,  5.52it/s]

  batch 500 loss: 0.703176434636116


600it [01:35,  3.22it/s]

  batch 600 loss: 0.6886884632706642


700it [01:59,  3.45it/s]

  batch 700 loss: 0.7098058053851127


800it [02:14,  6.37it/s]

  batch 800 loss: 0.6920783007144928


900it [02:46,  3.32it/s]

  batch 900 loss: 0.7018960559368134


1000it [03:14,  3.63it/s]

  batch 1000 loss: 0.6941087731719017


1100it [03:23,  7.69it/s]

  batch 1100 loss: 0.6750176054239273


1200it [03:33,  8.94it/s]

  batch 1200 loss: 0.6869314855337143


1300it [03:41, 11.94it/s]

  batch 1300 loss: 0.686413628757


1369it [03:47,  6.02it/s]


Epoch loss -> train: 0.686413628757 valid: 0.6925864815711975
EPOCH 7:


100it [00:44,  2.03it/s]

  batch 100 loss: 0.7069042780995369


203it [01:15, 10.82it/s]

  batch 200 loss: 0.6653010520339012


300it [01:37,  1.21it/s]

  batch 300 loss: 0.6724893927574158


400it [02:19,  2.76it/s]

  batch 400 loss: 0.7224690619111062


502it [02:53, 14.64it/s]

  batch 500 loss: 0.7105796760320664


600it [04:03,  1.26it/s]

  batch 600 loss: 0.6744599649310112


700it [05:07,  1.49s/it]

  batch 700 loss: 0.7018418005108833


800it [06:23,  2.27it/s]

  batch 800 loss: 0.6948260042071343


900it [07:42,  1.35it/s]

  batch 900 loss: 0.6912414741516113


1000it [08:59,  1.42s/it]

  batch 1000 loss: 0.7173481172323227


1100it [10:52,  2.78it/s]

  batch 1100 loss: 0.701351921260357


1200it [12:31,  1.58s/it]

  batch 1200 loss: 0.6432544440031052


1300it [16:09,  2.54s/it]

  batch 1300 loss: 0.6978758531808853


1369it [18:37,  1.23it/s]


Epoch loss -> train: 0.6978758531808853 valid: 0.6949520111083984
EPOCH 8:


100it [02:51,  1.56s/it]

  batch 100 loss: 0.6974577471613884


200it [05:25,  1.18it/s]

  batch 200 loss: 0.7147616466879845


300it [07:34,  1.58s/it]

  batch 300 loss: 0.6753260526061058


400it [09:51,  1.01s/it]

  batch 400 loss: 0.7151787036657333


502it [11:41,  6.30it/s]

  batch 500 loss: 0.6894723924994469


600it [12:04,  5.25it/s]

  batch 600 loss: 0.6764670222997665


701it [12:23,  6.82it/s]

  batch 700 loss: 0.7092979285120964


800it [12:47,  4.50it/s]

  batch 800 loss: 0.6922963237762452


903it [13:04, 16.40it/s]

  batch 900 loss: 0.6773119708895683


1001it [13:11, 11.31it/s]

  batch 1000 loss: 0.6750611880421639


1101it [13:19, 11.58it/s]

  batch 1100 loss: 0.6919186487793922


1200it [14:26,  1.41it/s]

  batch 1200 loss: 0.6936344295740128


1301it [15:07,  6.51it/s]

  batch 1300 loss: 0.6839825430512428


1369it [15:25,  1.48it/s]


Epoch loss -> train: 0.6839825430512428 valid: 0.6898749470710754
EPOCH 9:


100it [00:22,  4.36it/s]

  batch 100 loss: 0.6847296550869941


200it [00:55,  3.31it/s]

  batch 200 loss: 0.6924620285630226


300it [01:16,  4.19it/s]

  batch 300 loss: 0.672406189441681


400it [01:39,  4.17it/s]

  batch 400 loss: 0.6638946026563645


500it [01:59,  4.75it/s]

  batch 500 loss: 0.6922143957018853


600it [02:33,  3.50it/s]

  batch 600 loss: 0.7128493344783783


700it [02:58,  4.13it/s]

  batch 700 loss: 0.6804750326275826


800it [03:20,  4.19it/s]

  batch 800 loss: 0.7123719984292984


900it [03:55,  3.10it/s]

  batch 900 loss: 0.6984399688243866


1000it [04:28,  2.96it/s]

  batch 1000 loss: 0.697119140625


1100it [04:55,  3.55it/s]

  batch 1100 loss: 0.6704696840047837


1200it [05:28,  2.90it/s]

  batch 1200 loss: 0.7066310107707977


1301it [06:07,  5.03it/s]

  batch 1300 loss: 0.6828699290752411


1369it [06:20,  3.59it/s]


Epoch loss -> train: 0.6828699290752411 valid: 0.6916883587837219
EPOCH 10:


100it [00:15,  6.61it/s]

  batch 100 loss: 0.7070996534824371


200it [00:29,  6.89it/s]

  batch 200 loss: 0.6866196891665459


300it [00:45,  6.41it/s]

  batch 300 loss: 0.665055173933506


400it [01:04,  5.08it/s]

  batch 400 loss: 0.6924335059523582


500it [01:22,  5.83it/s]

  batch 500 loss: 0.6703286221623421


600it [01:35,  7.08it/s]

  batch 600 loss: 0.692269332408905


700it [01:50,  6.75it/s]

  batch 700 loss: 0.7073194572329521


800it [02:07,  6.05it/s]

  batch 800 loss: 0.7237568473815919


900it [02:41,  3.90it/s]

  batch 900 loss: 0.6794071289896965


1000it [02:59,  5.09it/s]

  batch 1000 loss: 0.7228882652521134


1100it [03:17,  5.35it/s]

  batch 1100 loss: 0.6769801479578018


1200it [03:38,  4.73it/s]

  batch 1200 loss: 0.6864669358730316


1301it [04:23,  5.42it/s]

  batch 1300 loss: 0.6610392931103707


1369it [04:35,  4.96it/s]


Epoch loss -> train: 0.6610392931103707 valid: 0.6924932599067688
EPOCH 11:


100it [00:28,  3.32it/s]

  batch 100 loss: 0.6742844846844673


200it [00:48,  4.87it/s]

  batch 200 loss: 0.6845008294284344


300it [01:16,  3.58it/s]

  batch 300 loss: 0.7168490758538246


400it [01:41,  4.25it/s]

  batch 400 loss: 0.6980521394312382


500it [02:55,  1.04s/it]

  batch 500 loss: 0.7070310118794442


600it [05:05,  1.51s/it]

  batch 600 loss: 0.7083250427246094


700it [07:20,  1.08s/it]

  batch 700 loss: 0.6962837305665016


800it [09:10,  1.30s/it]

  batch 800 loss: 0.7051129883527756


900it [11:10,  1.17s/it]

  batch 900 loss: 0.6908095210790635


1000it [13:21,  1.19s/it]

  batch 1000 loss: 0.6714005497097969


1100it [15:31,  1.44s/it]

  batch 1100 loss: 0.6537109047174454


1200it [17:38,  1.38s/it]

  batch 1200 loss: 0.6836677196621895


1300it [19:47,  1.45s/it]

  batch 1300 loss: 0.663544702231884


1369it [21:29,  1.06it/s]


Epoch loss -> train: 0.663544702231884 valid: 0.6950995326042175
EPOCH 12:


100it [00:07, 13.80it/s]

  batch 100 loss: 0.6656770497560501


200it [00:15, 12.73it/s]

  batch 200 loss: 0.7083825233578682


300it [00:22, 13.12it/s]

  batch 300 loss: 0.7133370867371559


400it [00:30, 13.23it/s]

  batch 400 loss: 0.717566454410553


500it [00:38, 12.59it/s]

  batch 500 loss: 0.698592569231987


600it [00:46, 13.22it/s]

  batch 600 loss: 0.6739394447207451


700it [00:54, 12.76it/s]

  batch 700 loss: 0.6809292188286782


800it [01:02, 12.83it/s]

  batch 800 loss: 0.6711753073334694


900it [01:09, 13.46it/s]

  batch 900 loss: 0.6710220941901207


1000it [01:17, 12.79it/s]

  batch 1000 loss: 0.672483344078064


1100it [01:25, 12.85it/s]

  batch 1100 loss: 0.6977643075585366


1200it [01:33, 12.49it/s]

  batch 1200 loss: 0.6867687714099884


1300it [01:41, 13.22it/s]

  batch 1300 loss: 0.6854468557238579


1369it [01:46, 12.82it/s]


Epoch loss -> train: 0.6854468557238579 valid: 0.6951386332511902
Early stopping after 13 epochs.


100%|██████████| 1369/1369 [01:27<00:00, 15.62it/s]
100%|██████████| 294/294 [00:19<00:00, 14.87it/s]
100%|██████████| 294/294 [00:19<00:00, 15.43it/s]

lr=0.001, hidden_dim=64:
    73868 parameters
    1:48:19.088563 H:m:s
    0.6951





In [None]:
# Scratch check: reshape a random (5, 18) integer tensor into a single column
x = torch.randint(0, 9, size=(5, 18))
print(x)
xsize = x.shape
print("xsize:", xsize)
x = x.view(x.numel(), 1)  # same result as view(-1, 1)
# print(x)
x.size()

tensor([[5, 0, 4, 4, 0, 1, 5, 1, 2, 0, 0, 6, 7, 5, 2, 2, 6, 0],
        [3, 1, 2, 0, 1, 2, 4, 5, 4, 3, 4, 7, 1, 5, 5, 8, 4, 2],
        [3, 1, 0, 1, 8, 4, 1, 2, 3, 1, 2, 0, 2, 4, 3, 4, 4, 1],
        [8, 8, 5, 4, 8, 1, 6, 4, 7, 5, 7, 1, 2, 3, 0, 2, 2, 5],
        [4, 7, 8, 0, 6, 0, 1, 2, 1, 1, 2, 1, 5, 1, 5, 0, 4, 0]])
xsize: torch.Size([5, 18])


torch.Size([90, 1])

In [55]:
import torch.nn as nn
import torch_geometric.nn as pygnn
from models.definitions.geometric_models import GraphModel

# Settings for the "exp_v2" runs: quiet output, no squeezing of predictions,
# timing enabled, CPU execution.
cs_hoeg_config.update(
    {
        "verbose": False,
        "squeeze": False,
        "pre_forward_view": True,
        "track_time": True,
        "model_output_path": "models/CS/hoeg/exp_v2",
        "device": torch.device("cpu"),
    }
)
# cs_hoeg_config["device"] = torch.device('cuda')


class GNN(GraphModel):
    """Single graph convolution followed by ``PReLU`` and a linear output layer.

    Args:
        hidden_channels: Output width of the ``GraphConv`` layer.
        out_channels: Output width of the final linear layer.
        pre_forward_view: Accepted by the constructor but not read anywhere
            in this class.
        squeeze: When true, ``forward`` passes its result through
            ``torch.squeeze`` before returning it.
    """

    def __init__(
        self,
        hidden_channels: int = 64,
        out_channels: int = 1,
        pre_forward_view: bool = False,
        squeeze: bool = True,
    ):
        super().__init__()
        # Both layers are created with an input width of -1
        self.conv1 = pygnn.GraphConv(-1, hidden_channels)
        self.act1 = nn.PReLU()
        self.lin_out = pygnn.Linear(-1, out_channels)
        self.squeeze = squeeze

    def forward(self, x, edge_index, batch=None):
        """Run conv -> activation -> linear; ``batch`` is accepted but unused."""
        hidden = self.act1(self.conv1(x, edge_index))
        prediction = self.lin_out(hidden)
        if self.squeeze:
            return torch.squeeze(prediction)
        return prediction


# Grid of learning rates and hidden dimensions to run the `GNN` model with
lr_range = [0.01, 0.001]
hidden_dim_range = [8, 16, 24, 32, 48, 64, 128, 256]

# Every (learning rate, hidden dimension) pair, learning rate varying slowest
search_grid = [(lr, hidden_dim) for lr in lr_range for hidden_dim in hidden_dim_range]
for lr, hidden_dim in search_grid:
    hetero_experiment_utils.run_hoeg_experiment_configuration(
        GNN,
        lr=lr,
        hidden_dim=hidden_dim,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        hoeg_config=cs_hoeg_config,
    )


lr=0.01, hidden_dim=8:


  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)
  return F.l1_loss(input, target, reduction=self.reduction)


IndexError: Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 15] (got interval [0, 25])

In [None]:
def evaluate_hetero_model(
    target_node_type: str,
    model: GraphModel,
    dataloader: DataLoader,
    metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: torch.device = torch.device("cpu"),
    verbose: bool = False,
    squeeze_required: bool = True,
    prediction_group_size: int = 25,
) -> torch.Tensor:
    """Compute ``metric`` between a model's predictions and the labels of a loader.

    For every batch the model is called with ``batch.x_dict`` and
    ``batch.edge_index_dict``. The output for ``target_node_type`` is reshaped
    to rows of ``prediction_group_size`` values and averaged per row. These
    per-row means and the ``batch[target_node_type].y`` labels are collected
    over all batches and passed to ``metric``.

    Args:
        target_node_type: Node type whose predictions and labels are compared.
        model: Model to evaluate; it is switched to eval mode and moved to
            ``device``.
        dataloader: Iterable of batches exposing ``x_dict``,
            ``edge_index_dict``, item access by node type, and ``to(device)``.
        metric: Callable applied as ``metric(predictions, labels)``.
        device: Device the model, batches and collected tensors are moved to.
        verbose: Show a ``tqdm`` progress bar when true.
        squeeze_required: Apply ``torch.squeeze`` to the collected predictions.
        prediction_group_size: Number of consecutive prediction values that
            are averaged into one prediction. Defaults to 25, the value that
            was previously hard-coded. NOTE(review): presumably the number of
            values the model emits per labelled node; confirm against the
            model and data in use.

    Returns:
        Whatever ``metric`` returns for the collected predictions and labels.

    Note:
        ``tqdm`` must be importable as a name in the calling kernel; no import
        for it is visible in this notebook.
    """
    model.eval()  # identical to model.train(False), so one call is enough
    model.to(device)

    # Start each list with an empty tensor: this keeps the dtype promotion and
    # the empty-dataloader result of the former "concatenate onto an empty
    # tensor" approach, while concatenating only once at the end.
    pred_chunks = [torch.tensor([]).to(device)]
    label_chunks = [torch.tensor([]).to(device)]

    with torch.no_grad():
        for batch in tqdm(dataloader, disable=not verbose):
            batch.to(device)
            batch_labels = batch[target_node_type].y
            batch_y_preds = model(
                batch.x_dict,
                edge_index=batch.edge_index_dict
                # , batch=batch[target_node_type].batch,
            )
            pred_chunks.append(
                batch_y_preds[target_node_type]
                .view(-1, prediction_group_size)
                .mean(dim=1)
            )
            label_chunks.append(batch_labels)

        y_preds = torch.cat(pred_chunks)
        y_true = torch.cat(label_chunks)
        if squeeze_required:
            y_preds = torch.squeeze(y_preds)
    return metric(y_preds.to(device), y_true.to(device))


def evaluate_best_model(
    target_node_type: str,
    model_state_dict_path: str,
    model: GraphModel,
    metric: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: torch.device,
    train_loader: Union[DataLoader, None] = None,
    val_loader: Union[DataLoader, None] = None,
    test_loader: Union[DataLoader, None] = None,
    verbose: bool = True,
    squeeze_required: bool = True,
) -> dict[str, torch.Tensor]:
    """Load a saved state dict into ``model`` and evaluate it on the given loaders.

    Args:
        target_node_type: Node type forwarded to ``evaluate_hetero_model``.
        model_state_dict_path: Path passed to ``torch.load`` (mapped to
            ``device``); the result is loaded into ``model``.
        model: Model instance that receives the state dict.
        metric: Metric forwarded to ``evaluate_hetero_model``; its string
            representation is also used in the result keys.
        device: Device used for loading and evaluation.
        train_loader: Optional loader, reported under ``"Train <metric>"``.
        val_loader: Optional loader, reported under ``"Val <metric>"``.
        test_loader: Optional loader, reported under ``"Test <metric>"``.
        verbose: Forwarded to ``evaluate_hetero_model``.
        squeeze_required: Forwarded to ``evaluate_hetero_model``.

    Returns:
        Dict with one entry per loader that evaluates as truthy, in the order
        train, validation, test.
    """
    model.load_state_dict(torch.load(model_state_dict_path, map_location=device))
    model.eval()

    # Arguments shared by every per-split evaluation call
    shared_kwargs = dict(
        target_node_type=target_node_type,
        model=model,
        metric=metric,
        device=device,
        verbose=verbose,
        squeeze_required=squeeze_required,
    )
    splits = (("Train", train_loader), ("Val", val_loader), ("Test", test_loader))

    evaluation = {}
    for split_name, loader in splits:
        if not loader:
            continue
        evaluation |= {
            f"{split_name} {metric}": evaluate_hetero_model(
                dataloader=loader, **shared_kwargs
            )
        }
    return evaluation

In [None]:
# Single-epoch runs with timing, verbose output and squeezed predictions.
# `HeteroHigherOrderGNN` is defined in a cell further down this notebook, so
# that cell has to be executed before this one.
cs_hoeg_config.update(
    {
        "track_time": True,
        "verbose": True,
        "squeeze": True,
        "EPOCHS": 1,
    }
)
# cs_hoeg_config["device"] = torch.device("cpu")

lr_range = [0.01, 0.001]
hidden_dim_range = [8, 16, 24, 32, 48, 64, 128, 256]

# Every (learning rate, hidden dimension) pair, learning rate varying slowest
search_grid = [(lr, hidden_dim) for lr in lr_range for hidden_dim in hidden_dim_range]
for lr, hidden_dim in search_grid:
    hetero_experiment_utils.run_hoeg_experiment_configuration(
        HeteroHigherOrderGNN,
        lr=lr,
        hidden_dim=hidden_dim,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        hoeg_config=cs_hoeg_config,
    )

## legacy

In [None]:
# MODEL INITIATION
# TODO: try custom Heterogeneous GNN Architecture (without to_hetero())
class HeteroHigherOrderGNN(GraphModel):
    """Two ``GraphConv`` + ``PReLU`` stages followed by a linear output layer.

    Args:
        hidden_channels: Output width of both ``GraphConv`` layers.
        out_channels: Output width of the final linear layer.
        regression_target: Accepted by the constructor but not read anywhere
            in this class (a softmax output for the non-regression case was
            sketched earlier and is not active).
    """

    def __init__(
        self,
        hidden_channels: int = 8,
        out_channels: int = 1,
        regression_target: bool = True,
    ):
        super().__init__()
        # All layers are created with an input width of -1
        self.conv1 = pygnn.GraphConv(-1, hidden_channels)
        self.act1 = nn.PReLU()
        self.conv2 = pygnn.GraphConv(-1, hidden_channels, add_self_loops=False)
        self.act2 = nn.PReLU()
        self.lin_out = pygnn.Linear(-1, out_channels)

    def forward(self, x, edge_index, batch=None):
        """Reshape ``x`` to one column, then apply both stages and the linear layer.

        ``batch`` is accepted but unused.
        """
        column = x.view(-1, 1)
        first_stage = self.act1(self.conv1(column, edge_index))
        second_stage = self.act2(self.conv2(first_stage, edge_index))
        return self.lin_out(second_stage)


# Build the base model and convert it with `to_hetero` using the configured metadata
model = pygnn.to_hetero(
    HeteroHigherOrderGNN(32, 1, cs_hoeg_config["regression_task"]),
    cs_hoeg_config["meta_data"],
)

# Print summary of data and model
cs_hoeg_config["verbose"] = True
if cs_hoeg_config["verbose"]:
    # print(model)
    # One forward pass without gradients initializes the lazy modules,
    # s.t. we can count its parameters.
    with torch.no_grad():
        batch = next(iter(train_loader))
        batch.to(cs_hoeg_config["device"])
        model.to(cs_hoeg_config["device"])
        out = model(batch.x_dict, batch.edge_index_dict)
        parameter_count = utilities.torch_utils.count_parameters(model)
        print(f"Number of parameters: {parameter_count}")



  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)


Number of parameters: 808


In [None]:
# MODEL TRAINING
# NOTE(review): `hetero_training_utils` and `evaluation_utils` are not imported
# anywhere in the visible notebook; they must be in scope before this cell runs.
print("Training started, progress available in Tensorboard")
torch.cuda.empty_cache()

# Output directory: <model_output_path>/<model name>_<timestamp>
timestamp = datetime.now().strftime("%Y%m%d_%Hh%Mm")
model_path_base = (
    f"{cs_hoeg_config['model_output_path']}/{str(model).split('(')[0]}_{timestamp}"
)

# The configuration in this notebook stores the flag under "squeeze". The
# former key "squeeze_required" is still honoured when present, so the lookup
# no longer raises KeyError with the current configuration.
squeeze_required = cs_hoeg_config.get(
    "squeeze_required", cs_hoeg_config.get("squeeze", True)
)

best_state_dict_path = hetero_training_utils.run_training_hetero(
    target_node_type=cs_hoeg_config["target_node_type"],
    num_epochs=cs_hoeg_config["EPOCHS"],
    model=model,
    train_loader=train_loader,
    validation_loader=val_loader,
    optimizer=O.Adam(model.parameters(), **cs_hoeg_config["optimizer_settings"]),
    loss_fn=cs_hoeg_config["loss_fn"],
    early_stopping_criterion=cs_hoeg_config["early_stopping"],
    model_path_base=model_path_base,
    device=cs_hoeg_config["device"],
    verbose=False,
    squeeze_required=squeeze_required,
)

# Write experiment settings as JSON into model path (of the model we've just trained).
# The directory is created if it does not exist yet, so the write cannot fail on a
# missing folder.
os.makedirs(model_path_base, exist_ok=True)
with open(os.path.join(model_path_base, "experiment_settings.json"), "w") as settings_file:
    json.dump(evaluation_utils.get_json_serializable_dict(cs_hoeg_config), settings_file)

Training started, progress available in Tensorboard


  0%|          | 0/5058 [00:00<?, ?it/s]

  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  return F.l1_loss(input, target, reduction=self.reduction)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  return F.l1_loss(input, target, reduction=self.reduction)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  return F.l1_loss(input, target, reduction=self.reduct

Early stopping after 9 epochs.


In [None]:
# MODEL EVALUATION
# NOTE(review): `evaluation_utils` is not imported anywhere in the visible
# notebook; it must be in scope before this cell runs.
# The state dict path is derived from the run directory so the two cannot drift apart.
model_path_base = f"{cs_hoeg_config['model_output_path']}/GraphModule_20230802_13h34m"
state_dict_path = f"{model_path_base}/state_dict_epoch2.pt"  # 0.5517 test mae | 15k params
# model_path_base = f"{cs_hoeg_config['model_output_path']}/GraphModule_20230802_16h49m"
# state_dict_path = f"{cs_hoeg_config['model_output_path']}/GraphModule_20230802_16h49m/state_dict_epoch4.pt"  # 0.5514 test mae | 808 params

# The configuration in this notebook stores the flag under "squeeze". The
# former key "squeeze_required" is still honoured when present, so the lookup
# no longer raises KeyError with the current configuration.
squeeze_required = cs_hoeg_config.get(
    "squeeze_required", cs_hoeg_config.get("squeeze", True)
)

# Get evaluation results
# evaluation_dict = hetero_evaluation_utils.evaluate_best_model(
evaluation_dict = evaluate_best_model(
    target_node_type=cs_hoeg_config["target_node_type"],
    model_state_dict_path=state_dict_path,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    model=model,
    metric=cs_hoeg_config["loss_fn"],
    device=cs_hoeg_config["device"],
    verbose=cs_hoeg_config["verbose"],
    squeeze_required=squeeze_required,
)

# Store model results as JSON into model path
with open(os.path.join(model_path_base, "evaluation_report.json"), "w") as report_file:
    json.dump(evaluation_utils.get_json_serializable_dict(evaluation_dict), report_file)

  0%|          | 0/5058 [00:00<?, ?it/s]

  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = 

In [None]:
# Print MAE results: first the directory of the evaluated model, then the
# dict of per-split metric values produced by the evaluation cell.
print(model_path_base)
pprint(evaluation_dict)

models/CS/hoeg/GraphModule_20230802_13h34m
{'Test L1Loss()': tensor(0.5517),
 'Train L1Loss()': tensor(0.5572),
 'Val L1Loss()': tensor(0.5563)}
