## Pipeline for model explainability

In [None]:
# Notebook bootstrap: report the working directory, run the repository's
# path.py script, and set the KMP_DUPLICATE_LIB_OK environment variable.
import os

# Print the working directory so it can be checked by eye; the relative
# '..' paths used in this notebook are resolved against it.
print("Current working directory:", os.getcwd())
# Absolute path of path.py located one level above the working directory.
path = os.path.abspath(os.path.join(os.getcwd(), '..', 'path.py'))
# IPython magic that runs the script at `path`.
# NOTE(review): presumably path.py extends sys.path so the project modules
# imported in later cells can be found -- confirm.
%run $path
import os
# NOTE(review): this flag is commonly set to avoid the "duplicate OpenMP
# runtime" abort -- confirm that this is the intent here.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [None]:
# Shared path prefixes; all paths are relative to the notebook's
# working directory.
_data_dir = "../data"
_output_dir = "../output"
_models_dir = f"{_output_dir}/models"
_params_dir = f"{_output_dir}/params"
_explain_dir = f"{_output_dir}/explanations"

# Data loader configuration: input CSV files and the random seed.
in_vitro = f"{_data_dir}/in_vitro/dataset.csv"
features = f"{_data_dir}/info.csv"
seed = 42

# Tier 1 model: architecture tag, weights file, hyperparameter file.
tier1_type = "mpnn"
tier1_model = f"{_models_dir}/tier1.pth"
tier1_params = f"{_params_dir}/tier1.json"

# Tier 2 model: architecture tag, weights file, hyperparameter file.
tier2_type = "gat"
tier2_model = f"{_models_dir}/hershberger_gat.pth"
tier2_params = f"{_params_dir}/hershberger_gat.json"

# Output folders for the attention and counterfactual explanations.
out_attn = f"{_explain_dir}/attention/"
out_count = f"{_explain_dir}/counterfactual/"

##### Load pre-trained (supervised-only) model

In [None]:
from params import load_params, load_tier2_model

# Load the hyperparameters stored at `tier2_params`.
params = load_params(tier2_params)

# Fixed size arguments forwarded to the model loader.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained -- confirm.
tier2_dims = {
    "node_dim": 48,
    "edge_dim": 12,
    "node_pred": 1,
    "edge_pred": 2,
    "num_tasks": 1,
}

# Build the Tier 2 model from the checkpoint path, architecture tag,
# hyperparameters and the sizes above.
model = load_tier2_model(tier2_model, tier2_type, params, **tier2_dims)

##### Load data

In [None]:
from tier2_loader import graph_loader_inference
from utils import device

# Molecule to explain, given as a SMILES string.
smiles = "CC1=C(C=CC(=C1)C2=CC(=C(C=C2)N)C)N"

# Arguments for the inference loader: the molecule, the Tier 1 model
# files and architecture tag, the dataset/feature CSV paths, the device
# and the seed.
loader_kwargs = dict(
    smiles=smiles,
    batch_size=1,
    hyperparams_path=tier1_params,
    model_path=tier1_model,
    architecture_type=tier1_type,
    data_path=in_vitro,
    feature_path=features,
    device=device,
    seed=seed,
)

mol_loader = graph_loader_inference(**loader_kwargs)

##### Predict

In [None]:
from predictor import predict_tier2
from utils import device

# Run the Tier 2 prediction on the loader built above, with the model
# moved to `device`. The trailing semicolon keeps the notebook from
# echoing the call's return value.
predict_tier2(model.to(device), mol_loader, device, show_results=True);

##### Plot Cross-attention maps

In [None]:
from reveal import view_attention

# Render the attention maps for the loaded molecule; `out_attn` is the
# attention output folder configured earlier in this notebook.
# `device` comes from the earlier `from utils import device` cells.
view_attention(model, mol_loader, device, out_attn, pathway="hershberger")

##### Plot Counterfactual maps

In [None]:
from reveal import view_counterfactuality
from utils import device

# Keyword options for the counterfactual view.
# NOTE(review): presumably `perturbations` is the number of perturbed
# variants to generate and `task_index` selects the output task -- confirm.
counterfactual_options = {"perturbations": 15, "task_index": 0}

# Render the counterfactual maps for the loaded molecule; `out_count` is
# the counterfactual output folder configured earlier in this notebook.
view_counterfactuality(
    model, mol_loader, in_vitro, features, device, out_count,
    **counterfactual_options,
)