In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from binn import BINN, BINNDataLoader, BINNTrainer

# --- Configuration -----------------------------------------------------------
# Training involves dropout, a random train/validation split and random weight
# initialisation, so every RNG is seeded up front; without this the logged
# losses/accuracies and downstream explanations change on every re-run.
# NOTE(review): assumes binn draws its randomness from torch / numpy / random;
# confirm against the binn package if runs still differ between restarts.
SEED = 42
DATA_DIR = Path("../data")

N_LAYERS = 4            # depth of the pathway network built by BINN
DROPOUT = 0.2
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2  # fraction of samples held out for validation
NUM_EPOCHS = 50

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Load data ---------------------------------------------------------------
# Quantitative matrix plus the sample design table (tab-separated).
data_matrix = pd.read_csv(DATA_DIR / "test_qm.csv")
design_matrix = pd.read_csv(DATA_DIR / "design_matrix.tsv", sep="\t")

# --- Build the model ---------------------------------------------------------
# Network topology is derived from Reactome pathways for the features in data_matrix.
binn = BINN(
    data_matrix=data_matrix,
    network_source="reactome",
    n_layers=N_LAYERS,
    dropout=DROPOUT,
)

# --- Initialize DataLoader ---------------------------------------------------
binn_dataloader = BINNDataLoader(binn)

# Create train/validation DataLoaders. Columns: "Protein" identifies features,
# "group" holds class labels and "sample" links samples to the design matrix.
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# --- Train the model ---------------------------------------------------------
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=NUM_EPOCHS)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.7530, Train Accuracy: 0.5616
[Epoch 1/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.7637, Train Accuracy: 0.5778
[Epoch 2/50] Val Loss: 0.6928, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.8138, Train Accuracy: 0.4334
[Epoch 3/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.7063, Train Accuracy: 0.5778
[Epoch 4/50] Val Loss: 0.6927, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.7464, Train Accuracy: 0.5603
[Epoch 5/50] Val Loss: 0.6923, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.7427, Train Accuracy: 0.6241
[Epoch 6/50] Val Loss: 0.6918, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.7326, Train Accuracy: 0.5379
[Epoch 7/50] Val Loss: 0.6915, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.7740, Train Accuracy: 0.5528
[Epoch 8/50] Val Loss: 0.6914, Val Accuracy: 0.4375
[Epoch 9/50] Train Loss: 0.7307, Train Accuracy: 0

In [2]:
from binn import BINNExplainer

# Explain the trained model once, on the validation split only, then rescale
# the raw per-edge importances with the "fan" normalisation.
# NOTE(review): "fan" presumably scales by node fan-in/fan-out; confirm in binn docs.
explainer = BINNExplainer(binn)

single_explanations = explainer.explain_single(
    dataloaders,
    split="val",
)

normalized_single_explanations = explainer.normalize_importances(
    single_explanations,
    method="fan",
)

# Display the edge-level importance table (one row per source/target/class).
normalized_single_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.004052,0.001013
1,0,1,A0M8Q6,R-HSA-166663,1,0.007045,0.001761
2,0,1,A0M8Q6,R-HSA-198933,0,0.004052,0.001013
3,0,1,A0M8Q6,R-HSA-198933,1,0.007045,0.001761
4,0,1,A0M8Q6,R-HSA-2029481,0,0.004052,0.001013
...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.220133,0.094806
7080,4,5,R-HSA-9709957,output_node,0,0.054674,0.027337
7081,4,5,R-HSA-9709957,output_node,1,0.247149,0.123574
7082,4,5,R-HSA-9748784,output_node,0,0.096502,0.032167


In [4]:
from binn.plot.sankey import SankeyPlotter

# Display options for the Sankey diagram of the single-run explanations.
# Judging from the printed layout, low-ranked nodes in each layer are merged
# into an "Other" node.
sankey_options = {
    "show_top_n": 10,             # nodes kept per layer before grouping into "Other"
    "value_col": "importance",    # raw (not normalised) importance drives widths/colours
    "node_cmap": "Reds",
    "edge_cmap": "coolwarm",
}

plotter = SankeyPlotter(explanations_data=single_explanations, **sankey_options)

fig = plotter.plot()
fig.show()

              node_id         x         y
0           nOther_l1  0.003636  0.950000
1          nP12814_l0  0.003636  0.799911
2          nP35555_l0  0.003636  0.711032
3          nP23141_l0  0.003636  0.622153
4          nP27797_l0  0.003636  0.533274
5          nP15291_l0  0.003636  0.444395
6          nP11142_l0  0.003636  0.355516
7          nP60900_l0  0.003636  0.266637
8          nP02452_l0  0.003636  0.177758
9          nP25788_l0  0.003636  0.088879
10         nP60709_l0  0.003636  0.000000
11          nOther_l2  0.185455  0.950000
12   nR-HSA-983170_l1  0.185455  0.799911
13    nR-HSA-73728_l1  0.185455  0.711032
14   nR-HSA-114452_l1  0.185455  0.622153
15  nR-HSA-2173789_l1  0.185455  0.533274
16    nR-HSA-69613_l1  0.185455  0.444395
17   nR-HSA-975634_l1  0.185455  0.355516
18   nR-HSA-446388_l1  0.185455  0.266637
19  nR-HSA-2173793_l1  0.185455  0.177758
20  nR-HSA-5693567_l1  0.185455  0.088879
21    nR-HSA-72163_l1  0.185455  0.000000
22          nOther_l3  0.367273  0

In [4]:
# Repeat training + explanation several times and aggregate the importances.
# The result carries per-iteration columns (importance_0..importance_2) plus
# importance_mean / importance_std.
# NOTE(review): this re-runs training through `trainer`, which wraps `binn`;
# confirm whether weights are re-initialised per iteration, because the model
# used by earlier cells is likely modified afterwards.
average_explanations = explainer.explain(
    dataloaders,
    nr_iterations=3,
    num_epochs=50,
    trainer=trainer,
)

[BINNExplainer] Iteration 1/3...
[Epoch 1/50] Train Loss: 0.8281, Train Accuracy: 0.5196
[Epoch 1/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.9235, Train Accuracy: 0.4683
[Epoch 2/50] Val Loss: 0.6927, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.8645, Train Accuracy: 0.4478
[Epoch 3/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.8600, Train Accuracy: 0.4341
[Epoch 4/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.8760, Train Accuracy: 0.4653
[Epoch 5/50] Val Loss: 0.6922, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.8581, Train Accuracy: 0.4728
[Epoch 6/50] Val Loss: 0.6920, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.8089, Train Accuracy: 0.5336
[Epoch 7/50] Val Loss: 0.6917, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.8094, Train Accuracy: 0.4821
[Epoch 8/50] Val Loss: 0.6904, Val Accuracy: 0.5312
[Epoch 9/50] Train Loss: 0.8025, Train Accuracy: 0.5216
[Epoch 9/50] Val Loss: 0.6870, Val Accuracy: 0.

In [5]:
# Apply the same "fan" normalisation used for the single-run explanations so
# the averaged and single-run tables are on a comparable scale.
normalized_average_explanations = explainer.normalize_importances(
    average_explanations,
    method="fan",
)

normalized_average_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,importance_0,importance_1,importance_2,importance_mean,importance_std,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.006274,0.001778,0.011938,0.005106,0.006274,0.004229,0.001569
1,0,1,A0M8Q6,R-HSA-166663,1,0.009529,0.006807,0.007827,0.013952,0.009529,0.003156,0.002382
2,0,1,A0M8Q6,R-HSA-198933,0,0.006274,0.001778,0.011938,0.005106,0.006274,0.004229,0.001569
3,0,1,A0M8Q6,R-HSA-198933,1,0.009529,0.006807,0.007827,0.013952,0.009529,0.003156,0.002382
4,0,1,A0M8Q6,R-HSA-2029481,0,0.006274,0.001778,0.011938,0.005106,0.006274,0.004229,0.001569
...,...,...,...,...,...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.130000,0.187380,0.052540,0.150081,0.130000,0.056850,0.055988
7080,4,5,R-HSA-9709957,output_node,0,0.211574,0.198006,0.219232,0.217486,0.211574,0.009621,0.105787
7081,4,5,R-HSA-9709957,output_node,1,0.052519,0.041629,0.059777,0.056152,0.052519,0.007842,0.026259
7082,4,5,R-HSA-9748784,output_node,0,0.164754,0.222715,0.025688,0.245859,0.164754,0.098787,0.054918
