In [1]:
import random

import numpy as np
import pandas as pd
import torch

from binn import BINN, BINNDataLoader, BINNExplainer, BINNTrainer

# Configuration: every value a reader might want to tune for this run.
SEED = 42
NUM_EPOCHS = 50
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2

# Seed the common RNGs *before* the model is built so that a fresh
# "Restart Kernel -> Run All" is intended to reproduce the same run.
# NOTE(review): the training log reports a device ("BINN is on device: cpu"),
# which suggests a PyTorch backend; confirm which RNG binn uses for the
# train/validation split and for weight initialisation.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load your data
data_matrix = pd.read_csv("../data/test_qm.csv")
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

# Initialize BINN
binn = BINN(data_matrix=data_matrix, network_source="reactome", n_layers=4, dropout=0.2)

# Initialize DataLoader
binn_dataloader = BINNDataLoader(binn)

# Create DataLoaders
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

# Train the model
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=NUM_EPOCHS)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.8390, Train Accuracy: 0.5228
[Epoch 1/50] Val Loss: 0.6931, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.7539, Train Accuracy: 0.5498
[Epoch 2/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.7474, Train Accuracy: 0.6041
[Epoch 3/50] Val Loss: 0.6925, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.7478, Train Accuracy: 0.5817
[Epoch 4/50] Val Loss: 0.6921, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.7046, Train Accuracy: 0.6142
[Epoch 5/50] Val Loss: 0.6912, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.6743, Train Accuracy: 0.6547
[Epoch 6/50] Val Loss: 0.6895, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.7804, Train Accuracy: 0.5591
[Epoch 7/50] Val Loss: 0.6854, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.7200, Train Accuracy: 0.5985
[Epoch 8/50] Val Loss: 0.6767, Val Accuracy: 0.5938
[Epoch 9/50] Train Loss: 0.6761, Train Accuracy: 0

In [None]:
# Explain the already-trained model on the "val" split, then rescale the raw
# importances with the "fan" normalisation method. The displayed frame has one
# row per (source_node, target_node, class_idx) edge.
explainer = BINNExplainer(binn)
single_explanations = explainer.explain_single(dataloaders, split="val")
normalized_single_explanations = explainer.normalize_importances(
    single_explanations,
    method="fan",
)
normalized_single_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.013079,0.003270
1,0,1,A0M8Q6,R-HSA-166663,1,0.009960,0.002490
2,0,1,A0M8Q6,R-HSA-198933,0,0.013079,0.003270
3,0,1,A0M8Q6,R-HSA-198933,1,0.009960,0.002490
4,0,1,A0M8Q6,R-HSA-2029481,0,0.013079,0.003270
...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.155677,0.067047
7080,4,5,R-HSA-9709957,output_node,0,0.016762,0.008381
7081,4,5,R-HSA-9709957,output_node,1,0.012426,0.006213
7082,4,5,R-HSA-9748784,output_node,0,0.043275,0.014425


In [3]:
# Averaged explanation: the log below shows one full training run per
# iteration, and the result carries per-iteration importances plus their
# mean and standard deviation.
average_explanations = explainer.explain(
    dataloaders,
    nr_iterations=10,
    num_epochs=50,
    trainer=trainer,
)

[BINNExplainer] Iteration 1/10...
[Epoch 1/50] Train Loss: 1.0732, Train Accuracy: 0.3647
[Epoch 1/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.9623, Train Accuracy: 0.3877
[Epoch 2/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.9798, Train Accuracy: 0.3851
[Epoch 3/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.9234, Train Accuracy: 0.4216
[Epoch 4/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 1.1204, Train Accuracy: 0.3196
[Epoch 5/50] Val Loss: 0.6931, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.9897, Train Accuracy: 0.3640
[Epoch 6/50] Val Loss: 0.6940, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.9605, Train Accuracy: 0.4015
[Epoch 7/50] Val Loss: 0.6962, Val Accuracy: 0.3594
[Epoch 8/50] Train Loss: 0.8728, Train Accuracy: 0.4265
[Epoch 8/50] Val Loss: 0.7010, Val Accuracy: 0.2656
[Epoch 9/50] Train Loss: 0.9412, Train Accuracy: 0.4565
[Epoch 9/50] Val Loss: 0.7099, Val Accuracy: 0

In [4]:
# Apply the same "fan" normalisation to the averaged importances and display
# the resulting frame.
normalized_average_explanations = explainer.normalize_importances(
    average_explanations,
    method="fan",
)
normalized_average_explanations

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance,importance_0,importance_1,importance_2,importance_3,importance_4,importance_5,importance_6,importance_7,importance_8,importance_9,importance_mean,importance_std,normalized_importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.007869,0.011303,0.015724,0.015170,0.002462,0.002444,0.007732,0.005273,0.002714,0.008155,0.007710,0.007869,0.004679,0.001967
1,0,1,A0M8Q6,R-HSA-166663,1,0.011062,0.005682,0.006922,0.038007,0.013618,0.014674,0.011140,0.001846,0.002769,0.012297,0.003666,0.011062,0.010008,0.002766
2,0,1,A0M8Q6,R-HSA-198933,0,0.007869,0.011303,0.015724,0.015170,0.002462,0.002444,0.007732,0.005273,0.002714,0.008155,0.007710,0.007869,0.004679,0.001967
3,0,1,A0M8Q6,R-HSA-198933,1,0.011062,0.005682,0.006922,0.038007,0.013618,0.014674,0.011140,0.001846,0.002769,0.012297,0.003666,0.011062,0.010008,0.002766
4,0,1,A0M8Q6,R-HSA-2029481,0,0.007869,0.011303,0.015724,0.015170,0.002462,0.002444,0.007732,0.005273,0.002714,0.008155,0.007710,0.007869,0.004679,0.001967
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.157522,0.020262,0.218977,0.089351,0.091145,0.211061,0.191303,0.028900,0.265969,0.273848,0.184407,0.157522,0.088452,0.067841
7080,4,5,R-HSA-9709957,output_node,0,0.139244,0.166241,0.103599,0.069364,0.205423,0.158629,0.216232,0.080623,0.103972,0.177857,0.110506,0.139244,0.049561,0.069622
7081,4,5,R-HSA-9709957,output_node,1,0.162905,0.175816,0.255807,0.117508,0.115893,0.131525,0.243027,0.114926,0.095025,0.314999,0.064528,0.162905,0.077584,0.081453
7082,4,5,R-HSA-9748784,output_node,0,0.152818,0.040557,0.228328,0.218835,0.076408,0.128111,0.212472,0.166289,0.225683,0.156655,0.074847,0.152818,0.066524,0.050939
