# Ensemble XAI on BINN
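This tutorial trains a BINN on a single simulation dataset and computes raw attributions with five explanation methods: **DeepLift**, **IntegratedGradients**, **GradientShap**, **Input×Gradient**, and **DeepLiftShap**. **SmoothGrad** was listed in an earlier version of this tutorial but is not included in the `methods` list run below.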

In [1]:
import sys
from pathlib import Path

# Make the local `openbinn` package importable whether the notebook is started
# from the repository root or from a directory one level below it.
# The working directory is checked first, then its parent; only the first hit
# is prepended to sys.path.
cwd = Path.cwd()
for candidate in (cwd, cwd.parent):
    if not (candidate / 'openbinn').exists():
        continue
    sys.path.insert(0, str(candidate))
    break


In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data.sampler import SubsetRandomSampler
from torch_geometric.loader import DataLoader as GeoLoader
import matplotlib.pyplot as plt

from openbinn.binn import PNet
from openbinn.binn.util import InMemoryLogger, get_roc
from openbinn.binn.data import PnetSimDataSet, PnetSimExpDataSet, ReactomeNetwork, get_layer_maps
from openbinn.explainer import Explainer
import openbinn.experiment_utils as utils
class ModelWrapper(torch.nn.Module):
    """Expose one output of a multi-output model as this module's ``forward``.

    The wrapped model is called once per forward pass and the output at
    1-based position ``target_layer`` is returned. Explainers can then treat
    an intermediate output as the quantity being explained.
    """

    def __init__(self, model, target_layer):
        super().__init__()
        self.model = model
        # 1-based index into the model's list of outputs. `print_layer` mirrors
        # it and is kept for compatibility; nothing in this notebook reads it.
        self.target_layer = self.print_layer = target_layer

    def forward(self, x):
        layer_outputs = self.model(x)
        position = self.target_layer - 1
        return layer_outputs[position]


def load_reactome_once(reactome_base_dir="../biological_knowledge/simulation"):
    """Build the ReactomeNetwork used for the simulation scenarios.

    Despite the name, nothing is cached: every call constructs a new network.
    Call it once and reuse the result, as the notebook does.

    Parameters
    ----------
    reactome_base_dir : str or Path, optional
        Directory holding the simulation pathway files. The default is relative
        to the current working directory, matching the notebook's original layout.

    Returns
    -------
    ReactomeNetwork
        Network built from the relations, pathway-name and GMT files in
        ``reactome_base_dir``.
    """
    return ReactomeNetwork(dict(
        reactome_base_dir=str(reactome_base_dir),
        relations_file_name="SimulationPathwaysRelation.txt",
        pathway_names_file_name="SimulationPathways.txt",
        pathway_genes_file_name="SimulationPathways.gmt",
    ))

def train_dataset(scen_dir, reactome, best_params=None, *, max_epochs=200, seed=None):
    """Train a PNet on one simulation scenario and save its weights.

    Parameters
    ----------
    scen_dir : str or Path
        Scenario directory. It must contain ``splits/training_set_0.csv``,
        ``splits/validation_set.csv`` and ``splits/test_set.csv``.
    reactome : ReactomeNetwork
        Pathway network used to build the layer maps.
    best_params : sequence, optional
        ``(learning_rate, batch_size)``. When None, ``(1e-3, 16)`` is used.
    max_epochs : int, keyword-only
        Upper bound on training epochs; early stopping may end training sooner.
    seed : int, keyword-only, optional
        If given, RNGs are seeded via ``pl.seed_everything`` before data
        sampling and model construction, so runs are reproducible.
        ``deterministic=True`` alone does not seed anything.

    Returns
    -------
    tuple
        ``(model, maps)``: the trained PNet and the layer maps it was built from.
        The state dict is also written to
        ``<scen_dir>/results/optimal/trained_model.pth``.
    """
    scen_dir = Path(scen_dir)  # allow plain string paths for the `/` joins below
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    ds = PnetSimDataSet(root=str(scen_dir), num_features=3)
    ds.split_index_by_file(
        train_fp=scen_dir / 'splits' / 'training_set_0.csv',
        valid_fp=scen_dir / 'splits' / 'validation_set.csv',
        test_fp=scen_dir / 'splits' / 'test_set.csv',
    )
    maps = get_layer_maps(genes=list(ds.node_index), reactome=reactome, n_levels=3,
                          direction='root_to_leaf', add_unk_genes=False)
    # Keep only genes present in the first layer map (its rows size the model input).
    ds.node_index = [g for g in ds.node_index if g in maps[0].index]

    if best_params is None:
        lr, bs = 1e-3, 16
    else:
        lr, bs = best_params[0], int(best_params[1])
    tr_loader = GeoLoader(ds, bs, sampler=SubsetRandomSampler(ds.train_idx), num_workers=0)
    va_loader = GeoLoader(ds, bs, sampler=SubsetRandomSampler(ds.valid_idx), num_workers=0)

    model = PNet(layers=maps, num_genes=maps[0].shape[0], lr=lr)
    trainer = pl.Trainer(
        accelerator='auto',
        deterministic=True,
        max_epochs=max_epochs,
        callbacks=[EarlyStopping('val_loss', patience=10, mode='min', verbose=False, min_delta=0.01)],
        logger=InMemoryLogger(),
        enable_progress_bar=False,
    )
    trainer.fit(model, tr_loader, va_loader)

    # NOTE: no checkpoint callback restores the best epoch, so the saved weights
    # are those at the epoch where training stopped, not the lowest val_loss.
    out_dir = scen_dir / 'results' / 'optimal'
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out_dir / 'trained_model.pth')
    return model, maps

def explain_dataset(scen_dir, reactome, method):
    """Compute and save per-layer attributions of the trained PNet on the test split.

    For each 1-based model output ``tgt``:

    1. The model is wrapped so that output ``tgt`` is the explained quantity.
    2. ``method`` is run through ``Explainer`` with an all-zeros baseline.
    3. One CSV per explained layer (at most ``len(maps)`` layers) is written to
       ``<scen_dir>/explanations/`` as ``PNet_{method}_L{tgt}_layer{idx}_test.csv``.
       Each CSV holds the attribution columns plus ``label``, ``prediction``
       (the wrapped output) and ``sample_id``.

    Parameters
    ----------
    scen_dir : str or Path
        Scenario directory. It must hold ``splits/`` and the weights saved by
        ``train_dataset`` at ``results/optimal/trained_model.pth``.
    reactome : ReactomeNetwork
        Pathway network used to rebuild the same layer maps as in training.
    method : str
        Explainer method key, e.g. ``'deeplift'`` or ``'ig'``.
    """
    scen_dir = Path(scen_dir)  # allow plain string paths for the `/` joins below
    ds = PnetSimExpDataSet(root=str(scen_dir), num_features=1)
    ds.split_index_by_file(
        train_fp=scen_dir / 'splits' / 'training_set_0.csv',
        valid_fp=scen_dir / 'splits' / 'validation_set.csv',
        test_fp=scen_dir / 'splits' / 'test_set.csv',
    )
    maps = get_layer_maps(genes=list(ds.node_index), reactome=reactome, n_levels=3,
                          direction='root_to_leaf', add_unk_genes=False)
    ds.node_index = [g for g in ds.node_index if g in maps[0].index]

    # Rebuild the architecture and load the weights written by train_dataset.
    model = PNet(layers=maps, num_genes=maps[0].shape[0], lr=0.001)
    state = torch.load(scen_dir / 'results' / 'optimal' / 'trained_model.pth', map_location='cpu')
    model.load_state_dict(state)
    model.eval()

    # batch_size == number of test samples, so the whole test split is one batch.
    loader = GeoLoader(ds, batch_size=len(ds.test_idx),
                       sampler=SubsetRandomSampler(ds.test_idx), num_workers=0)
    explain_root = scen_dir / 'explanations'
    explain_root.mkdir(parents=True, exist_ok=True)

    for tgt in range(1, len(maps) + 1):
        wrap = ModelWrapper(model, tgt)
        expl_acc, lab_acc, pred_acc, id_acc = {}, [], [], []
        for X, y, ids in loader:
            p_conf = {'baseline': torch.zeros_like(X), 'classification_type': 'binary'}
            explainer = Explainer(method, wrap, p_conf)
            exp_dict = explainer.get_layer_explanations(X, y)
            for lname, ten in exp_dict.items():
                expl_acc.setdefault(lname, []).append(ten.detach().cpu().numpy())
            lab_acc.append(y.cpu().numpy())
            # The prediction column is a plain forward pass; no autograd graph is needed.
            with torch.no_grad():
                pred_acc.append(wrap(X).detach().cpu().numpy())
            id_acc.append(ids)

        if not expl_acc:
            # Nothing was explained for this target, so there is nothing to write.
            continue

        # Labels, predictions and ids are identical for every explained layer:
        # concatenate them once per target rather than once per layer.
        labels = np.concatenate(lab_acc, axis=0)
        preds = np.concatenate(pred_acc, axis=0)
        all_ids = [sid for batch in id_acc for sid in batch]

        # zip stops after len(maps) layers; any extra explained layers are ignored.
        for idx, (cur_map, arrs) in enumerate(zip(maps, expl_acc.values())):
            arr = np.concatenate(arrs, axis=0)
            # Name columns after whichever axis of the layer map matches the width.
            cols = list(cur_map.index) if cur_map.shape[0] == arr.shape[1] else list(cur_map.columns)
            df = pd.DataFrame(arr, columns=cols)
            df['label'] = labels
            df['prediction'] = preds
            df['sample_id'] = all_ids
            out_fp = explain_root / f"PNet_{method}_L{tgt}_layer{idx}_test.csv"
            df.to_csv(out_fp, index=False)
    print('Saved raw importances for', method)
# Train one simulation scenario (beta=1.0, gamma=1.0) against the pathway network.
scenario = Path('./data/b1.0_g1.0/1')  # beta=1.0, gamma=1.0
reactome = load_reactome_once()
model, maps = train_dataset(scenario, reactome)


  from .autonotebook import tqdm as notebook_tqdm
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name              | Type       | Params | Mode 
---------------------------------------------------------
0 | network           | ModuleList | 111    | train
1 | intermediate_outs | ModuleList | 35     | train
---------------------------------------------------------
146       Trainable params
0         Non-trainable params
146       Total params
0.001     Total estimated model params size (MB)
29        Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=5` reached.


In [3]:
# Explainer method keys to run; each explain_dataset call writes its own CSVs.
methods = ['deeplift', 'ig', 'gradshap', 'itg', 'shap']
for method in methods:
    print('Running {} ...'.format(method))
    explain_dataset(scenario, reactome, method)




Running deeplift ...




Saved raw importances for deeplift
Running ig ...




Saved raw importances for ig
Running gradshap ...




Saved raw importances for gradshap
Running itg ...
Saved raw importances for itg
Running shap ...
Saved raw importances for shap
