In [1]:
import random

import numpy as np
import pandas as pd
import torch

from binn import BINN, BINNDataLoader, BINNTrainer

# Seed every common RNG before the model and the data split are created so
# that Restart & Run All reproduces the same training curve.
# NOTE(review): which RNGs BINN actually draws from is not visible in this
# notebook; the "[INFO] BINN is on device: cpu" log suggests a torch backend.
# Seeding all three is harmless if one of them turns out to be unused.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the quantification matrix and the sample design table.
data_matrix = pd.read_csv("../data/test_qm.csv")
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

# Initialize BINN
binn = BINN(data_matrix=data_matrix, network_source="reactome", n_layers=2, dropout=0.2)

# Initialize DataLoader
binn_dataloader = BINNDataLoader(binn)

# Create DataLoaders (20 % of the samples are held out for validation).
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=32,
    validation_split=0.2,
)

# Train the model; per-epoch train/val loss and accuracy are logged to stdout.
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=50)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.8677, Train Accuracy: 0.4666
[Epoch 1/50] Val Loss: 0.6884, Val Accuracy: 0.7500
[Epoch 2/50] Train Loss: 0.8854, Train Accuracy: 0.4515
[Epoch 2/50] Val Loss: 0.6847, Val Accuracy: 0.7500
[Epoch 3/50] Train Loss: 0.8106, Train Accuracy: 0.4847
[Epoch 3/50] Val Loss: 0.6790, Val Accuracy: 0.6719
[Epoch 4/50] Train Loss: 0.8729, Train Accuracy: 0.4534
[Epoch 4/50] Val Loss: 0.6693, Val Accuracy: 0.7031
[Epoch 5/50] Train Loss: 0.7880, Train Accuracy: 0.4791
[Epoch 5/50] Val Loss: 0.6555, Val Accuracy: 0.7188
[Epoch 6/50] Train Loss: 0.8106, Train Accuracy: 0.4672
[Epoch 6/50] Val Loss: 0.6377, Val Accuracy: 0.7188
[Epoch 7/50] Train Loss: 0.8297, Train Accuracy: 0.4659
[Epoch 7/50] Val Loss: 0.6165, Val Accuracy: 0.7500
[Epoch 8/50] Train Loss: 0.7668, Train Accuracy: 0.5534
[Epoch 8/50] Val Loss: 0.5947, Val Accuracy: 0.7188
[Epoch 9/50] Train Loss: 0.7428, Train Accuracy: 0

In [2]:
from binn import BINNExplainer

# Explain the model trained above, using the validation split only.
explainer = BINNExplainer(binn)
single_explanations = explainer.explain_single(
    dataloaders,
    split="val",
)

# Normalise the importances with the "fan" method; the displayed frame gains a
# `normalized_importance` column next to the raw `importance`.
normalized_single_explanations = explainer.normalize_importances(
    single_explanations,
    method="fan",
)
normalized_single_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,normalized_importance
0,0,1,A0M8Q6,R-HSA-1280218,0,0.008060,0.003118
1,0,1,A0M8Q6,R-HSA-1280218,1,0.009703,0.003754
2,0,1,A0M8Q6,R-HSA-168249,0,0.008060,0.003118
3,0,1,A0M8Q6,R-HSA-168249,1,0.009703,0.003754
4,0,1,A0M8Q6,R-HSA-202733,0,0.008060,0.003118
...,...,...,...,...,...,...,...
3755,2,3,R-HSA-9612973,output_node,1,0.157928,0.068016
3756,2,3,R-HSA-9709957,output_node,0,0.001029,0.000514
3757,2,3,R-HSA-9709957,output_node,1,0.266911,0.133455
3758,2,3,R-HSA-9748784,output_node,0,0.026517,0.008839


In [3]:
from binn.plot.sankey import SankeyPlotter

# Draw the explanation network as a Sankey diagram.
# NOTE(review): this plots the raw `importance` column of `single_explanations`,
# not the normalised frame computed earlier - confirm that is intended.
plotter = SankeyPlotter(
    explanations_data=single_explanations,
    value_col="importance",
    multiclass=False,
    show_top_n=3,
    node_cmap="Reds",
    edge_cmap="coolwarm",
)

# Build the figure, then render it.
fig = plotter.plot()
fig.show()

In [4]:
# Averaged explanations over several iterations. Judging by the logged output,
# each iteration runs its own training pass with `trainer` for `num_epochs`
# epochs before the importances are computed.
average_explanations = explainer.explain(
    dataloaders,
    nr_iterations=3,
    num_epochs=50,
    trainer=trainer,
)

[BINNExplainer] Iteration 1/3...
[Epoch 1/50] Train Loss: 0.8281, Train Accuracy: 0.5196
[Epoch 1/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.9235, Train Accuracy: 0.4683
[Epoch 2/50] Val Loss: 0.6927, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.8645, Train Accuracy: 0.4478
[Epoch 3/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.8600, Train Accuracy: 0.4341
[Epoch 4/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.8760, Train Accuracy: 0.4653
[Epoch 5/50] Val Loss: 0.6922, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.8581, Train Accuracy: 0.4728
[Epoch 6/50] Val Loss: 0.6920, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.8089, Train Accuracy: 0.5336
[Epoch 7/50] Val Loss: 0.6917, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.8094, Train Accuracy: 0.4821
[Epoch 8/50] Val Loss: 0.6904, Val Accuracy: 0.5312
[Epoch 9/50] Train Loss: 0.8025, Train Accuracy: 0.5216
[Epoch 9/50] Val Loss: 0.6870, Val Accuracy: 0.

In [5]:
# Apply the same "fan" normalisation to the multi-iteration result; the
# displayed frame keeps the per-iteration columns plus their mean and std.
normalized_average_explanations = explainer.normalize_importances(
    average_explanations,
    method="fan",
)
normalized_average_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,importance_0,importance_1,importance_2,importance_mean,importance_std,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.006274,0.001778,0.011938,0.005106,0.006274,0.004229,0.001569
1,0,1,A0M8Q6,R-HSA-166663,1,0.009529,0.006807,0.007827,0.013952,0.009529,0.003156,0.002382
2,0,1,A0M8Q6,R-HSA-198933,0,0.006274,0.001778,0.011938,0.005106,0.006274,0.004229,0.001569
3,0,1,A0M8Q6,R-HSA-198933,1,0.009529,0.006807,0.007827,0.013952,0.009529,0.003156,0.002382
4,0,1,A0M8Q6,R-HSA-2029481,0,0.006274,0.001778,0.011938,0.005106,0.006274,0.004229,0.001569
...,...,...,...,...,...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.130000,0.187380,0.052540,0.150081,0.130000,0.056850,0.055988
7080,4,5,R-HSA-9709957,output_node,0,0.211574,0.198006,0.219232,0.217486,0.211574,0.009621,0.105787
7081,4,5,R-HSA-9709957,output_node,1,0.052519,0.041629,0.059777,0.056152,0.052519,0.007842,0.026259
7082,4,5,R-HSA-9748784,output_node,0,0.164754,0.222715,0.025688,0.245859,0.164754,0.098787,0.054918
