In [1]:
import os
import torch
from ml.models.RGCNEdgeTypeTAG3VerticesDoubleHistory2Parametrized.model import (
    StateModelEncoder,
)
from torch_geometric.explain import (
    Explainer,
    ModelConfig,
    Explanation,
    CaptumExplainer,
)
from ml.inference import TORCH
import glob
from ml.training.experimental_utils.explain.interface import hetero_forward


# Directory of pre-serialized graph samples to be explained.
SAMPLE_DIR = os.path.join("report", "small_sample")

# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# sample order (and every per-sample result built from it) is reproducible.
sample_paths = sorted(glob.glob(os.path.join(SAMPLE_DIR, "*")))
assert sample_paths, f"no samples found in {SAMPLE_DIR!r}"

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. It is presumably needed because the samples
# are full graph-data objects rather than plain tensors. Only load sample files
# from a trusted source.
dataset_sample = [
    torch.load(path, weights_only=False, map_location="cpu")
    for path in sample_paths
]

# Trained encoder weights (a state dict). weights_only=True restricts
# unpickling to tensors and plain containers.
PATH_TO_MODEL = os.path.join("report", "models", "model_78.pth")
weights = torch.load(PATH_TO_MODEL, map_location="cpu", weights_only=True)

In [2]:
# Architecture hyperparameters for the encoder. These must match the ones the
# checkpoint in PATH_TO_MODEL was trained with: load_state_dict is strict by
# default and raises if parameter names or shapes disagree.
MODEL_HPARAMS = dict(
    hidden_channels=82,
    num_of_state_features=64,
    num_hops_1=8,
    num_hops_2=8,
    normalization=True,
)

# NOTE(review): hetero_forward presumably adapts StateModelEncoder's forward to
# the (x_dict, edge_index_dict, **kwargs) call made by the Explainer in the
# next cell. Confirm against ml.training.experimental_utils.explain.interface.
hetero_encoder = hetero_forward(StateModelEncoder)
model = hetero_encoder(**MODEL_HPARAMS)
model.load_state_dict(weights)
# The model is only used for inference here. Switch to eval mode so that any
# dropout / normalization layers (if the encoder has them) use inference
# behavior on every forward pass.
model.eval()

# Tells the Explainer how to interpret the model output: a multiclass
# prediction per node.
# NOTE(review): return_type="log_probs" assumes the encoder's output is already
# log-softmax'd. If it returns raw scores this should be "raw". Confirm.
model_config = ModelConfig(
    mode="multiclass_classification", task_level="node", return_type="log_probs"
)

In [4]:
# Integrated Gradients (through Captum) attributes the model's predictions to
# its inputs. node_mask_type="attributes" requests a mask over individual node
# feature entries rather than whole nodes.
captum_explainer = Explainer(
    model,
    algorithm=CaptumExplainer("IntegratedGradients"),
    explanation_type="model",
    model_config=model_config,
    node_mask_type="attributes",
)

# Edge types removed from every sample before it is explained.
# NOTE(review): presumably the explained encoder does not consume them. Confirm.
DROPPED_EDGE_TYPES = (
    TORCH.statevertex_history_gamevertex,
    TORCH.statevertex_in_gamevertex,
)

# One Explanation per sample, in dataset_sample order. The variable keeps its
# original spelling so existing references to it still resolve.
captum_explainations: list[Explanation] = []
for data in dataset_sample:
    # Prune a clone rather than the loaded sample. Deleting in place silently
    # rewrote dataset_sample, so it no longer matched what the loading cell
    # produced, and re-running this cell worked on already-pruned graphs.
    data = data.clone()
    for edge_type in DROPPED_EDGE_TYPES:
        del data[edge_type]
    captum_explainations.append(
        captum_explainer(
            data.x_dict,
            data.edge_index_dict,
            edge_type_dict=data.edge_type_dict,
            edge_attr_dict=data.edge_attr_dict,
        )
    )


  from .autonotebook import tqdm as notebook_tqdm
