In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from binn import BINN, BINNDataLoader, BINNTrainer

# Configuration: every value a reader might want to tune lives here instead of
# being buried as a literal inside the calls below.
DATA_DIR = Path("../data")
SEED = 42
N_LAYERS = 4
DROPOUT = 0.2
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2
NUM_EPOCHS = 50

# Seed the Python, NumPy and PyTorch RNGs so a Restart & Run All gives the same
# result each time.
# NOTE(review): which of these RNGs the binn package actually draws from (for
# the validation split, weight init, dropout) is not visible here, so all
# three are seeded; confirm against the binn sources.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load your data
data_matrix = pd.read_csv(DATA_DIR / "test_qm.csv")
design_matrix = pd.read_csv(DATA_DIR / "design_matrix.tsv", sep="\t")

# Initialize BINN
binn = BINN(
    data_matrix=data_matrix,
    network_source="reactome",
    n_layers=N_LAYERS,
    dropout=DROPOUT,
)

# Initialize DataLoader
binn_dataloader = BINNDataLoader(binn)

# Create DataLoaders
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# Train the model
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=NUM_EPOCHS)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.8969, Train Accuracy: 0.4791
[Epoch 1/50] Val Loss: 0.6932, Val Accuracy: 0.4688
[Epoch 2/50] Train Loss: 0.9213, Train Accuracy: 0.4909
[Epoch 2/50] Val Loss: 0.6933, Val Accuracy: 0.4688
[Epoch 3/50] Train Loss: 0.8488, Train Accuracy: 0.5341
[Epoch 3/50] Val Loss: 0.6934, Val Accuracy: 0.4688
[Epoch 4/50] Train Loss: 0.9623, Train Accuracy: 0.4134
[Epoch 4/50] Val Loss: 0.6938, Val Accuracy: 0.4688
[Epoch 5/50] Train Loss: 0.8799, Train Accuracy: 0.5386
[Epoch 5/50] Val Loss: 0.6940, Val Accuracy: 0.4688
[Epoch 6/50] Train Loss: 0.7955, Train Accuracy: 0.5416
[Epoch 6/50] Val Loss: 0.6942, Val Accuracy: 0.4688
[Epoch 7/50] Train Loss: 0.8800, Train Accuracy: 0.5353
[Epoch 7/50] Val Loss: 0.6946, Val Accuracy: 0.4688
[Epoch 8/50] Train Loss: 0.6925, Train Accuracy: 0.6297
[Epoch 8/50] Val Loss: 0.6945, Val Accuracy: 0.4688
[Epoch 9/50] Train Loss: 0.7421, Train Accuracy: 0

In [2]:
from binn import BINNExplainer

# Build an explainer around the BINN instance and compute explanations on the
# "val" split of the dataloaders.
explainer = BINNExplainer(binn)
single_explanations = explainer.explain_single(
    dataloaders,
    split="val",
)

# Normalize the importances with the "fan" method; the bare name on the last
# line makes the notebook render the resulting table.
normalized_single_explanations = explainer.normalize_importances(
    single_explanations,
    method="fan",
)
normalized_single_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,source_id,target_id,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.017386,nA0M8Q6_l0,nR-HSA-166663_l1,0.004346
1,0,1,A0M8Q6,R-HSA-166663,1,0.024437,nA0M8Q6_l0,nR-HSA-166663_l1,0.006109
2,0,1,A0M8Q6,R-HSA-198933,0,0.017386,nA0M8Q6_l0,nR-HSA-198933_l1,0.004346
3,0,1,A0M8Q6,R-HSA-198933,1,0.024437,nA0M8Q6_l0,nR-HSA-198933_l1,0.006109
4,0,1,A0M8Q6,R-HSA-2029481,0,0.017386,nA0M8Q6_l0,nR-HSA-2029481_l1,0.004346
...,...,...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.074729,nR-HSA-9612973_l4,noutput_node_l5,0.032184
7080,4,5,R-HSA-9709957,output_node,0,0.115902,nR-HSA-9709957_l4,noutput_node_l5,0.057951
7081,4,5,R-HSA-9709957,output_node,1,0.000118,nR-HSA-9709957_l4,noutput_node_l5,0.000059
7082,4,5,R-HSA-9748784,output_node,0,0.206383,nR-HSA-9748784_l4,noutput_node_l5,0.068794


In [3]:
from binn.plot.sankey import plot_sankey

# Plot the explanations returned by explain_single; the trailing bare name
# lets the notebook display whatever plot_sankey returns.
sankey_figure = plot_sankey(single_explanations)
sankey_figure

In [4]:
# Run the explainer for 3 iterations of 50 epochs each, reusing the trainer
# created earlier in the notebook.
explain_settings = {
    "nr_iterations": 3,
    "num_epochs": 50,
    "trainer": trainer,
}
average_explanations = explainer.explain(dataloaders, **explain_settings)

[BINNExplainer] Iteration 1/3...
[Epoch 1/50] Train Loss: 0.6609, Train Accuracy: 0.6261
[Epoch 1/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.7479, Train Accuracy: 0.5498
[Epoch 2/50] Val Loss: 0.6927, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.7175, Train Accuracy: 0.5985
[Epoch 3/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.6994, Train Accuracy: 0.6179
[Epoch 4/50] Val Loss: 0.6920, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.7394, Train Accuracy: 0.6347
[Epoch 5/50] Val Loss: 0.6905, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.6687, Train Accuracy: 0.6291
[Epoch 6/50] Val Loss: 0.6872, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.6867, Train Accuracy: 0.6179
[Epoch 7/50] Val Loss: 0.6802, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.6190, Train Accuracy: 0.6491
[Epoch 8/50] Val Loss: 0.6657, Val Accuracy: 0.8750
[Epoch 9/50] Train Loss: 0.6091, Train Accuracy: 0.6961
[Epoch 9/50] Val Loss: 0.6385, Val Accuracy: 0.

In [5]:
# Normalize the multi-iteration explanations with the "fan" method and show
# the resulting table as the cell output.
normalized_average_explanations = explainer.normalize_importances(
    average_explanations,
    method="fan",
)
normalized_average_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,source_id,target_id,importance_0,importance_1,importance_2,importance_mean,importance_std,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.043231,nA0M8Q6_l0,nR-HSA-166663_l1,0.023816,0.054591,0.051286,0.043231,0.013794,0.010808
1,0,1,A0M8Q6,R-HSA-166663,1,0.020888,nA0M8Q6_l0,nR-HSA-166663_l1,0.043425,0.013309,0.005929,0.020888,0.016218,0.005222
2,0,1,A0M8Q6,R-HSA-198933,0,0.043231,nA0M8Q6_l0,nR-HSA-198933_l1,0.023816,0.054591,0.051286,0.043231,0.013794,0.010808
3,0,1,A0M8Q6,R-HSA-198933,1,0.020888,nA0M8Q6_l0,nR-HSA-198933_l1,0.043425,0.013309,0.005929,0.020888,0.016218,0.005222
4,0,1,A0M8Q6,R-HSA-2029481,0,0.043231,nA0M8Q6_l0,nR-HSA-2029481_l1,0.023816,0.054591,0.051286,0.043231,0.013794,0.010808
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.089820,nR-HSA-9612973_l4,noutput_node_l5,0.108489,0.015467,0.145504,0.089820,0.054704,0.038683
7080,4,5,R-HSA-9709957,output_node,0,0.184374,nR-HSA-9709957_l4,noutput_node_l5,0.262943,0.206544,0.083635,0.184374,0.074862,0.092187
7081,4,5,R-HSA-9709957,output_node,1,0.132082,nR-HSA-9709957_l4,noutput_node_l5,0.247671,0.118751,0.029824,0.132082,0.089434,0.066041
7082,4,5,R-HSA-9748784,output_node,0,0.178142,nR-HSA-9748784_l4,noutput_node_l5,0.034630,0.246579,0.253217,0.178142,0.101515,0.059381
