## Pipeline for model explainability

In [None]:
import os

# Check if we are in the correct directory
print("Current working directory:", os.getcwd())
path = os.path.abspath(os.path.join(os.getcwd(), '..', 'path.py'))
%run $path

In [None]:
# Dataset splits: all three CSV files live in the same directory, so the
# directory is named once and the split paths are derived from it.
data_dir = '../data/random_leish10'
train_file = f'{data_dir}/train.csv'
val_file = f'{data_dir}/val.csv'
test_file = f'{data_dir}/test.csv'

# Pre-trained model and its hyperparameters.
# Both files are derived from a single run name so that the hyperparameter
# JSON and the saved weights cannot drift apart when switching models.
architecture_type = 'gat'
model_name = 'leishmania_infantum_GAT_10uM'
hyperparams_path = f"../output/params/{model_name}.json"
model_path = f"../output/models/{model_name}.pth"

# Output directory for the contribution maps (counterfactual perturbations);
# passed as out_path to view_explanations.
cont_maps = "../output/explanations/similarities/"

# Output directory for the attention-score maps (only for attentive GNNs);
# passed as out_path to view_attentions.
attn_maps = "../output/explanations/attentions/"

##### Load data

In [None]:
from params import load_data

# Each split file is read into a pair of values.
# NOTE(review): judging by the names, the first is the SMILES strings and
# the second the labels - confirm against params.load_data.
train_smiles, y_train = load_data(train_file)
val_smiles, y_val = load_data(val_file)
test_smiles, y_test = load_data(test_file)

# Report the size of every split, one line per split.
split_sizes = (
    f"Training data: {len(train_smiles)} samples",
    f"Validation data: {len(val_smiles)} samples",
    f"Test data: {len(test_smiles)} samples",
)
print(*split_sizes, sep="\n")

##### Building molecular graphs in data loaders

In [None]:
from loaders import graph_loader, graph_info

# Build one data loader per split; inputs are passed as the three SMILES
# collections followed by the three label collections, in train/val/test order.
train_loader, val_loader, test_loader = graph_loader(
    train_smiles, val_smiles, test_smiles,
    y_train, y_val, y_test,
    batch_size=32,
)

# Dimensions are read from the training loader only; they are reused when
# the model is instantiated later in the notebook.
node_dim, edge_dim, num_tasks = graph_info(train_loader)
graph_summary = (
    f"Max number of atom features: {node_dim}",
    f"Max number of bond features: {edge_dim}",
    f"Number of tasks: {num_tasks}",
)
print(*graph_summary, sep="\n")

##### Loading pre-trained (supervised only) model

In [None]:
from params import load_params, load_model

# Hyperparameters saved alongside the pre-trained model.
params = load_params(hyperparams_path)

# load_model receives the checkpoint path, the architecture name, the
# hyperparameters and the graph dimensions obtained from the training loader.
model = load_model(
    model_path, architecture_type, params,
    node_dim, edge_dim, num_tasks,
)

##### Plot counterfactual contribution maps

In [None]:
from utils import device
from explanations import view_explanations

# Contribution maps for the test split, for the first task (index 0);
# cont_maps is handed over as the output location.
view_explanations(model, test_loader, device, out_path=cont_maps, task_idx=0)

##### Plot attention weights

In [None]:
from utils import device
from explanations import view_attentions

# Attention maps for the test split; attn_maps is handed over as the
# output location. Only meaningful for attentive architectures.
view_attentions(model, test_loader, device, out_path=attn_maps)