In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from binn import BINN, BINNDataLoader, BINNTrainer

# --- Configuration -----------------------------------------------------------
# Paths are relative to the notebook's working directory; change DATA_DIR
# instead of editing the individual read calls below.
DATA_DIR = Path("../data")

# Training is stochastic (dropout, train/val split, batch shuffling), so seed
# every RNG the pipeline may draw from *before* the model and DataLoaders are
# built. Without this, Restart & Run All cannot reproduce the logged metrics.
# NOTE(review): assumes binn draws randomness from torch / numpy / stdlib
# `random` only — confirm against the binn source.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model / training hyperparameters (previously inline magic numbers).
N_LAYERS = 4
DROPOUT = 0.2
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2
NUM_EPOCHS = 50

# --- Load data ---------------------------------------------------------------
data_matrix = pd.read_csv(DATA_DIR / "sample_datamatrix.csv")
# The design matrix is a TSV file, hence the explicit tab separator.
design_matrix = pd.read_csv(DATA_DIR / "sample_design_matrix.tsv", sep="\t")

# --- Build the network -------------------------------------------------------
# Note: the model variable `binn` shadows the `binn` package name in this
# namespace. Later `from binn import ...` statements still work (imports are
# resolved by the import system, not by this variable), but a bare
# `import binn` would rebind the name and overwrite the model.
binn = BINN(
    data_matrix=data_matrix,
    network_source="reactome",
    n_layers=N_LAYERS,
    dropout=DROPOUT,
)

# --- Initialize DataLoader ---------------------------------------------------
binn_dataloader = BINNDataLoader(binn)

# Create the train/validation DataLoaders from the data and design matrices.
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# --- Train the model ---------------------------------------------------------
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=NUM_EPOCHS)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.8787, Train Accuracy: 0.4922
[Epoch 1/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.7921, Train Accuracy: 0.5259
[Epoch 2/50] Val Loss: 0.6927, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.8105, Train Accuracy: 0.4890
[Epoch 3/50] Val Loss: 0.6923, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.7764, Train Accuracy: 0.5672
[Epoch 4/50] Val Loss: 0.6920, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.7429, Train Accuracy: 0.5416
[Epoch 5/50] Val Loss: 0.6917, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.7517, Train Accuracy: 0.5429
[Epoch 6/50] Val Loss: 0.6913, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.6517, Train Accuracy: 0.6336
[Epoch 7/50] Val Loss: 0.6903, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.6948, Train Accuracy: 0.6416
[Epoch 8/50] Val Loss: 0.6880, Val Accuracy: 0.5312
[Epoch 9/50] Train Loss: 0.6950, Train Accuracy: 0

In [2]:
from binn import BINNExplainer

# Explain the model trained above once, using the validation split of the
# DataLoaders built in the first cell.
explainer = BINNExplainer(binn)

single_explanations = explainer.explain_single(
    dataloaders,
    split="val",
)

# Rescale the raw importances with the "fan" method; the displayed table gains
# a `normalized_importance` column next to `importance`.
# NOTE(review): "fan" presumably compensates for node fan-in/fan-out in the
# pathway graph — confirm in the binn documentation.
normalized_single_explanations = explainer.normalize_importances(
    single_explanations,
    method="fan",
)

# Last expression is rendered as the cell's output.
normalized_single_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.003728,0.000932
1,0,1,A0M8Q6,R-HSA-166663,1,0.002787,0.000697
2,0,1,A0M8Q6,R-HSA-198933,0,0.003728,0.000932
3,0,1,A0M8Q6,R-HSA-198933,1,0.002787,0.000697
4,0,1,A0M8Q6,R-HSA-2029481,0,0.003728,0.000932
...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.204866,0.088231
7080,4,5,R-HSA-9709957,output_node,0,0.103269,0.051635
7081,4,5,R-HSA-9709957,output_node,1,0.174556,0.087278
7082,4,5,R-HSA-9748784,output_node,0,0.050014,0.016671


In [3]:
from binn.plot.sankey import SankeyPlotter

# Sankey diagram of the single-run explanations.
# NOTE(review): this plots the raw `importance` column of `single_explanations`,
# not the `normalized_importance` of `normalized_single_explanations` computed
# in the previous cell — confirm which one is intended.
plotter = SankeyPlotter(
    explanations_data=single_explanations,
    value_col="importance",
    show_top_n=5,
    edge_cmap="coolwarm",
    node_cmap="Reds",
)

fig = plotter.plot()
fig.show()

In [4]:
# Repeat training + explanation `nr_iterations` times and aggregate the runs.
# Per the logs and the next cell's output, each iteration retrains for
# `num_epochs` epochs, and the result carries per-run columns (importance_0..2)
# plus their mean/std. Expect roughly 3x the cost of the initial fit.
# NOTE(review): iteration 1 starts at a lower train loss than the first fit,
# so training may continue from the already-trained weights instead of
# reinitializing — confirm with the binn source.
average_explanations = explainer.explain(
    dataloaders,
    nr_iterations=3,
    num_epochs=50,
    trainer=trainer,
)

[BINNExplainer] Iteration 1/3...
[Epoch 1/50] Train Loss: 0.6708, Train Accuracy: 0.6629
[Epoch 1/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.6445, Train Accuracy: 0.6761
[Epoch 2/50] Val Loss: 0.6927, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.6418, Train Accuracy: 0.5991
[Epoch 3/50] Val Loss: 0.6923, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.6443, Train Accuracy: 0.6241
[Epoch 4/50] Val Loss: 0.6917, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.6390, Train Accuracy: 0.6235
[Epoch 5/50] Val Loss: 0.6905, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.6050, Train Accuracy: 0.6823
[Epoch 6/50] Val Loss: 0.6877, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.6678, Train Accuracy: 0.6610
[Epoch 7/50] Val Loss: 0.6815, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.6211, Train Accuracy: 0.6379
[Epoch 8/50] Val Loss: 0.6702, Val Accuracy: 0.7969
[Epoch 9/50] Train Loss: 0.6309, Train Accuracy: 0.6905
[Epoch 9/50] Val Loss: 0.6498, Val Accuracy: 0.

In [5]:
# Apply the same "fan" normalization used for the single-run explanations to
# the iteration-aggregated table; the displayed result includes a
# `normalized_importance` column.
normalized_average_explanations = explainer.normalize_importances(
    average_explanations,
    method="fan",
)

# Last expression is rendered as the cell's output.
normalized_average_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,importance_0,importance_1,importance_2,importance_mean,importance_std,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.020381,0.040898,0.006035,0.014211,0.020381,0.014886,0.005095
1,0,1,A0M8Q6,R-HSA-166663,1,0.025529,0.000998,0.018965,0.056623,0.025529,0.023178,0.006382
2,0,1,A0M8Q6,R-HSA-198933,0,0.020381,0.040898,0.006035,0.014211,0.020381,0.014886,0.005095
3,0,1,A0M8Q6,R-HSA-198933,1,0.025529,0.000998,0.018965,0.056623,0.025529,0.023178,0.006382
4,0,1,A0M8Q6,R-HSA-2029481,0,0.020381,0.040898,0.006035,0.014211,0.020381,0.014886,0.005095
...,...,...,...,...,...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.153014,0.186016,0.209957,0.063067,0.153014,0.064348,0.065899
7080,4,5,R-HSA-9709957,output_node,0,0.180209,0.189706,0.130507,0.220414,0.180209,0.037314,0.090104
7081,4,5,R-HSA-9709957,output_node,1,0.195636,0.129179,0.317847,0.139884,0.195636,0.086526,0.097818
7082,4,5,R-HSA-9748784,output_node,0,0.154748,0.156551,0.188176,0.119518,0.154748,0.028059,0.051583
