In [3]:
# Python native
import os

# Switch to the project root before the remaining imports and the configuration are evaluated:
# the notebook refers to relative paths such as "models/BPI17/efg", and the project-local
# imports further down presumably resolve against this directory as well (TODO confirm).
# The root can be overridden with the OCPPM_ROOT environment variable so the notebook is not
# tied to one machine; when the variable is unset the original developer path is used.
os.chdir(os.environ.get("OCPPM_ROOT", "/home/tim/Development/OCPPM/"))
import pickle
import random
from copy import copy
from datetime import datetime
from statistics import median as median
from sys import platform
import functools
from typing import Any, Callable

# Data handling
import numpy as np
import pandas as pd

# PyG
import torch
from torch_geometric.loader import DataLoader

# PyTorch TensorBoard support
import torch.utils.tensorboard

# Object centric process mining
from ocpa.algo.predictive_monitoring.obj import Feature_Storage as FeatureStorage
import ocpa.algo.predictive_monitoring.factory as feature_factory

# Custom imports
from loan_application_experiment.feature_encodings.efg.efg import EFG
from loan_application_experiment.feature_encodings.efg.efg_sg import EFG_SG
from utilities import torch_utils
from utilities import data_utils
from utilities import training_utils
from utilities import evaluation_utils

# from importing_ocel import build_feature_storage, load_ocel, pickle_feature_storage
from loan_application_experiment.models.geometric_models import (
    AdamsGCN,
    GraphModel,
    HigherOrderGNN_EFG,
)

# Experiment configuration for the BPI17 data set with the EFG feature encoding.
# All paths are relative, so they depend on the current working directory.
# NOTE(review): "classification_task" is True and "target_dtype" is int64, while the target
# label is EVENT_REMAINING_TIME and the loss is L1Loss — confirm this combination is intended.
bpi_efg_config = dict(
    # --- storage locations ---
    model_output_path="models/BPI17/efg",
    STORAGE_PATH="data/BPI17/feature_encodings/EFG/efg",
    SPLIT_FEATURE_STORAGE_FILE="BPI_split_[C2_P2_P3_P5_O3_Action_EventOrigin_OrgResource].fs",
    # --- prediction target and tensor dtypes ---
    TARGET_LABEL=(feature_factory.EVENT_REMAINING_TIME, ()),
    graph_level_prediction=True,
    classification_task=True,
    features_dtype=torch.float32,
    target_dtype=torch.int64,
    # --- data loading ---
    SUBGRAPH_SIZE=4,
    BATCH_SIZE=64,
    RANDOM_SEED=42,
    # --- training schedule ---
    EPOCHS=30,
    early_stopping=5,
    # Keyword names match an Adam-style optimizer; presumably forwarded to one (TODO confirm).
    optimizer_settings=dict(
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        amsgrad=False,
    ),
    loss_fn=torch.nn.L1Loss(),
    # Use the first CUDA device when one is available, otherwise fall back to the CPU.
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    # --- miscellaneous switches ---
    verbose=True,
    skip_cache=False,
)

In [4]:
# Load the validation split and wrap it in a dataloader.
# Only the validation split is requested (`val=True`); the train/test counterparts
# (`train=True` / `test=True`, `ds_train` / `ds_test`) are intentionally left out here.
# Both helpers are unpacked as single-element results.
[ds_val] = data_utils.load_datasets(
    dataset_class=EFG,
    storage_path=bpi_efg_config["STORAGE_PATH"],
    split_feature_storage_file=bpi_efg_config["SPLIT_FEATURE_STORAGE_FILE"],
    target_label=bpi_efg_config["TARGET_LABEL"],
    graph_level_target=bpi_efg_config["graph_level_prediction"],
    features_dtype=bpi_efg_config["features_dtype"],
    target_dtype=bpi_efg_config["target_dtype"],
    subgraph_size=bpi_efg_config["SUBGRAPH_SIZE"],
    val=True,
    skip_cache=bpi_efg_config["skip_cache"],
)

# The worker seeding function and the generator are both derived from the configured RANDOM_SEED.
[val_loader] = data_utils.prepare_dataloaders(
    batch_size=bpi_efg_config["BATCH_SIZE"],
    ds_val=ds_val,
    seed_worker=functools.partial(
        torch_utils.seed_worker,
        state=bpi_efg_config["RANDOM_SEED"],
    ),
    generator=torch.Generator().manual_seed(bpi_efg_config["RANDOM_SEED"]),
)

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


In [5]:
# Print summary statistics for the validation set; the saved output lists per-graph
# node and edge counts (mean, std, min, quartiles, max).
data_utils.print_dataset_summaries(ds_val=ds_val)

Validation set
EFG (#graphs=4411):
+------------+----------+----------+
|            |   #nodes |   #edges |
|------------+----------+----------|
| mean       |     12.5 |     13.7 |
| std        |      3.6 |      4.5 |
| min        |      6   |      5   |
| quantile25 |     10   |     11   |
| median     |     12   |     13   |
| quantile75 |     14   |     16   |
| max        |     41   |     50   |
+------------+----------+----------+ 



In [38]:
# Pull the first batch from the validation loader so it can be inspected interactively.
b0 = next(iter(val_loader))


def eval_batch(batch, model):
    """Run ``model`` on a single batch and return its output together with the labels.

    Args:
        batch: Object exposing ``x``, ``edge_index`` and ``y`` attributes, where ``x`` and
            ``y`` support ``.float()`` (presumably a PyG batch of tensors — TODO confirm).
        model: Callable invoked as ``model(x, edge_index)``.

    Returns:
        A tuple ``(model_output, labels)`` in which ``labels`` is ``batch.y.float()``.
    """
    node_features = batch.x.float()
    edge_index = batch.edge_index
    # Labels are cast to float before the forward pass is executed.
    targets = batch.y.float()
    predictions = model(node_features, edge_index)
    return predictions, targets


# Per-column summary statistics of the inspected batch's `x` attribute.
# NOTE(review): in the saved output, columns 24-26 have magnitudes of roughly 1e3-1e5,
# whereas the other displayed columns have near-zero means and stay within about [-1.2, 12] —
# looks like those three features are not normalized; confirm this is intended.
pd.DataFrame(b0.x).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
count,256.0,256.0,256.0,256.0,256.0,256.0,256.0,256.0,256.0,256.0,...,256.0,256.0,256.0,256.0,256.0,256.0,256.0,256.0,256.0,256.0
mean,-0.064036,-0.023751,-0.058362,0.108887,0.220502,-0.067831,0.039286,0.03304,-0.107734,0.065532,...,-0.038242,-0.024978,-0.003184,-0.01388,-0.189062,-0.055596,0.007124,175868.53125,135573.34375,8739.429688
std,0.894721,0.952523,0.641111,1.153635,1.118303,0.0,1.06257,1.035689,0.435029,1.065598,...,0.74324,0.0,0.0,0.0,0.8232,0.651811,0.79469,80577.125,14970.163086,15889.636719
min,-0.2946,-0.234576,-0.098432,-0.2946,-0.52515,-0.067831,-0.285884,-0.398522,-0.134923,-0.329228,...,-0.084695,-0.024978,-0.003184,-0.01388,-0.709011,-0.162898,-1.161597,31509.0,114852.0,557.0
25%,-0.2946,-0.234576,-0.098432,-0.2946,-0.52515,-0.067831,-0.285884,-0.398522,-0.134923,-0.329228,...,-0.084695,-0.024978,-0.003184,-0.01388,-0.698642,-0.162898,-0.057565,86364.0,128226.0,2521.0
50%,-0.2946,-0.234576,-0.098432,-0.2946,-0.52515,-0.067831,-0.285884,-0.398522,-0.134923,-0.329228,...,-0.084695,-0.024978,-0.003184,-0.01388,-0.60803,-0.162898,-0.057565,234196.0,128226.0,3832.0
75%,-0.2946,-0.234576,-0.098432,-0.2946,1.891142,-0.067831,-0.285884,-0.398522,-0.134923,-0.329228,...,-0.084695,-0.024978,-0.003184,-0.01388,-0.118018,-0.162898,-0.057565,234196.0,150853.0,5061.0
max,3.39443,4.263014,10.159348,3.39443,1.891142,-0.067831,3.497917,2.508842,6.825533,2.930726,...,11.807139,-0.024978,-0.003184,-0.01388,2.441271,7.360688,3.254531,234196.0,150853.0,58953.0


In [1]:
# NOTE(review): scratch expression (displays 0.002). Its execution count (In [1]) is out of
# order relative to the preceding cells, so it is not part of a top-to-bottom run —
# consider removing it.
2e-3

0.002