In [1]:
# Import library
import time
import logging
from util import create_parser, set_seed, logger_setup
from data_loading import get_data
from training import train_gnn
from inference import infer_gnn
import json

# copied from inference.py
import torch
import pandas as pd
from train_util import AddEgoIds, extract_param, add_arange_ids, get_loaders, evaluate_homo
from training import get_model
from torch_geometric.nn import summary
import wandb
import os
import sys
import time

# coppied from train_util.py
import tqdm
from torch_geometric.transforms import BaseTransform
from typing import Union
from torch_geometric.data import Data
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import f1_score
import json

# Torch related library
from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer


script_start = time.time()

In [2]:
#create parcer
debug_flags = [
            "--data", "Small_HI",
            "--model", "dir_gin",
            "--ego",
            "--unique_name", "directed",
            "--tqdm",
            "--ports",
            "--emlps",
        ]
parser = create_parser()
args = parser.parse_args(debug_flags)


In [3]:
def prep_explanation_homo(loader, inds, model, data, device, args):
    """
    Code is created based on evaluate_homo function in train_util.py
    """
    for batch in tqdm.tqdm(loader, disable=not args.tqdm):
        # Select the seed edges from which the batch was created
        inds = inds.detach().cpu()
        batch_edge_inds = inds[batch.input_id.detach().cpu()]
        batch_edge_ids = loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0]
        mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids)

        #add the seed edges that have not been sampled to the batch
        missing = ~torch.isin(batch_edge_ids, batch.edge_attr[:, 0].detach().cpu())

        if missing.sum() != 0 and (args.data == 'Small_J' or args.data == 'Small_Q'):
            missing_ids = batch_edge_ids[missing].int()
            n_ids = batch.n_id
            add_edge_index = data.edge_index[:, missing_ids].detach().clone()
            node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)}
            add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index])
            add_edge_attr = data.edge_attr[missing_ids, :].detach().clone()
            add_y = data.y[missing_ids].detach().clone()
        
            batch.edge_index = torch.cat((batch.edge_index, add_edge_index), 1)
            batch.edge_attr = torch.cat((batch.edge_attr, add_edge_attr), 0)
            batch.y = torch.cat((batch.y, add_y), 0)

            mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool)))

        #remove the unique edge id from the edge features, as it's no longer needed
        batch.edge_attr = batch.edge_attr[:, 1:]

        with torch.no_grad():
            batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr)
            out = out[mask]

    return loader, inds, model, data, device, args, batch_edge_ids, batch_edge_inds, batch

In [4]:
def load_trained_model(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config):
    """
    Code is created based on inference.py
    """
    # set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # define a model config dictionary and wandb logging at the same time
    wandb.init(
        mode="disabled" if args.testing else "online",
        project="explainability",

        config={
            "epochs": args.n_epochs,
            "batch_size": args.batch_size,
            "model": args.model,
            "data": args.data,
            "num_neighbors": args.num_neighs,
            "lr": extract_param("lr", args),
            "n_hidden": extract_param("n_hidden", args),
            "n_gnn_layers": extract_param("n_gnn_layers", args),
            "loss": "ce",
            "w_ce1": extract_param("w_ce1", args),
            "w_ce2": extract_param("w_ce2", args),
            "dropout": extract_param("dropout", args),
            "final_dropout": extract_param("final_dropout", args),
            "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None
        }
    )

    config = wandb.config

    # set the transform if ego ids should be used
    if args.ego:
        transform = AddEgoIds()
    else:
        transform = None

    # add the unique ids to later find the seed edges
    add_arange_ids([tr_data, val_data, te_data])

    tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform,
                                                   args)

    # get the model
    sample_batch = next(iter(tr_loader))
    model = get_model(sample_batch, config, args)

    if args.reverse_mp:
        model = to_hetero(model, te_data.metadata(), aggr='mean')

    logging.info("=> loading model checkpoint")
    #todo: to avoid issue: hardcoding unique name as directed
    checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_directed.tar') 
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    logging.info("=> loaded checkpoint (epoch {})".format(start_epoch))

    
    if not args.reverse_mp:
        te_loader, te_inds, model, te_data, device, args, batch_edge_ids, batch_edge_inds, batch = prep_explanation_homo(te_loader, te_inds, model, te_data, device, args)
    else:
        te_loader, te_inds, model, te_data, device, args, batch_edge_ids, batch_edge_inds, batch = prep_explanation(te_loader, te_inds, model, te_data, device, args)

    wandb.finish()

    return te_loader, te_inds, model, te_data, device, args, batch_edge_ids, batch_edge_inds, batch


In [5]:
with open('data_config.json', 'r') as config_file:
    data_config = json.load(config_file)

# Setup logging
logger_setup()

#set seed
set_seed(args.seed)

#get data
logging.info("Retrieving data")
t1 = time.perf_counter()

tr_data, val_data, te_data, tr_inds, val_inds, te_inds = get_data(args, data_config)

t2 = time.perf_counter()
logging.info(f"Retrieved data in {t2-t1:.2f}s")


logging.info(f"Running Explanation")
#todo: data, inds of tr, val needed?
te_loader, te_inds, model, te_data, device, args, batch_edge_ids, batch_edge_inds, batch = load_trained_model(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config)

2025-06-08 16:24:57,469 [INFO ] Random seed set as 1
2025-06-08 16:24:57,470 [INFO ] Retrieving data
🔍 inside of get_data: data_config[paths][aml_data] = kaggle-files
🔍 inside of get_data: --data passed in as        = Small_HI
   EdgeID  from_id  to_id  Timestamp  Amount Sent  Sent Currency  \
0       2        3      3         10     14675.57              0   
1      17       24     24         10       897.37              0   
2     158      163    163         10     99986.94              0   
3     218      215    215         10        16.08              0   
4     281      265    265         10        10.30              0   

   Amount Received  Received Currency  Payment Format  Is Laundering  
0         14675.57                  0               0              0  
1           897.37                  0               0              0  
2         99986.94                  0               0              0  
3            16.08                  0               0              0  
4        

wandb: Currently logged in as: kmishik (kmishik-university-of-bath) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


2025-06-08 16:37:47,323 [INFO ] => loading model checkpoint


  checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_directed.tar')


2025-06-08 16:37:47,874 [INFO ] => loaded checkpoint (epoch 22)


100%|██████████| 106/106 [02:33<00:00,  1.45s/it]


## Below code will raise an error as attribute for edge_mask_type is not currently supported.

In [6]:
import math
explainer = Explainer (model=model,
                       algorithm=GNNExplainer(
                         epochs       = 300,
                         num_hops     = 3,  
                         edge_size    = math.inf,        
                         entropy_reg  = 0.0),
                       explanation_type="model", # skip node_mask_type (will set as NA) as node feature is added only for processing purpose.
                       edge_mask_type="attribute",
                       model_config=dict(
                           mode="binary_classification",
                           task_level="edge",
                           return_type="raw"
                       ))

ValueError: 'attribute' is not a valid MaskType

In [7]:
import math
explainer = Explainer (model=model,
                       algorithm=GNNExplainer(
                         epochs       = 300,
                         num_hops     = 3,  
                         edge_size    = math.inf,        
                         entropy_reg  = 0.0),
                       explanation_type="model", # skip node_mask_type (will set as NA) as node feature is added only for processing purpose.
                       edge_mask_type="object",
                       model_config=dict(
                           mode="binary_classification",
                           task_level="edge",
                           return_type="raw"
                       ))

In [10]:

explanation = explainer(
    batch.x,
    batch.edge_index,
    edge_attr=batch.edge_attr,
)
print(f'Generated explanations in {explanation.available_explanations}')


Generated explanations in ['edge_mask']


In [11]:
from torch_geometric.explain import fidelity, fidelity_curve_auc
pos_fidelity, neg_fidelity = fidelity(explainer, explanation)

In [12]:
pos_fidelity, neg_fidelity

(0.501678466796875, 0.501678466796875)

In [13]:
batch

GraphData(x=[17844, 2], edge_index=[2, 162346], edge_attr=[162346, 6], y=[162346], readout='edge', loss_fn='ce', num_nodes=17844, timestamps=[162346], n_id=[17844], e_id=[162346], input_id=[3740], edge_label_index=[2, 3740], edge_label=[3740])

In [14]:
# How many laundering edges in this batch.
num_illicit = int(batch.y.sum())
print(f"{num_illicit} / {batch.y.numel()} edges ({num_illicit/batch.y.numel()}%) are labelled 'laundering'. ")


1481 / 162346 edges (0.009122491468838161%) are labelled 'laundering'. 
