In [1]:
import random

import numpy as np
import pandas as pd
import torch
from binn import BINN, BINNDataLoader, BINNTrainer

# Fix the global RNG seeds *before* the model is constructed so that a
# Restart Kernel -> Run All reproduces the same training run.
# NOTE(review): assumes BINN / BINNDataLoader / BINNTrainer draw from the
# global Python, NumPy and PyTorch generators (weight init, dropout, batch
# shuffling, validation split) -- confirm against the binn library; if it uses
# its own generators, additional seeding is required.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the data matrix (comma-separated) and the design matrix (tab-separated).
data_matrix = pd.read_csv("../data/test_qm.csv")
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

# Initialize the BINN model from the data matrix, using the "reactome" network
# source with 4 layers and a dropout of 0.2.
binn = BINN(data_matrix=data_matrix, network_source="reactome", n_layers=4, dropout=0.2)

# Initialize the DataLoader helper for this model.
binn_dataloader = BINNDataLoader(binn)

# Create the DataLoaders. The column names tell the helper where to find the
# feature identifiers in the data matrix and the group / sample labels in the
# design matrix; 20% of the samples are held out for validation.
# NOTE(review): which matrix each column name refers to is inferred from the
# argument names -- verify against the BINNDataLoader documentation.
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=32,
    validation_split=0.2,
)

# Train the model for 50 epochs; per-epoch train / validation loss and
# accuracy are printed to the cell output.
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=50)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.7664, Train Accuracy: 0.5722
[Epoch 1/50] Val Loss: 0.6931, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.7555, Train Accuracy: 0.5741
[Epoch 2/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.7466, Train Accuracy: 0.5978
[Epoch 3/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.7772, Train Accuracy: 0.5935
[Epoch 4/50] Val Loss: 0.6930, Val Accuracy: 0.6250
[Epoch 5/50] Train Loss: 0.7238, Train Accuracy: 0.6129
[Epoch 5/50] Val Loss: 0.6920, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.7126, Train Accuracy: 0.6054
[Epoch 6/50] Val Loss: 0.6914, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.6633, Train Accuracy: 0.6347
[Epoch 7/50] Val Loss: 0.6891, Val Accuracy: 0.6250
[Epoch 8/50] Train Loss: 0.7122, Train Accuracy: 0.6060
[Epoch 8/50] Val Loss: 0.6842, Val Accuracy: 0.5938
[Epoch 9/50] Train Loss: 0.7094, Train Accuracy: 0

In [2]:
from binn import BINNExplainer

# Wrap the BINN model from the training cell in an explainer.
explainer = BINNExplainer(binn)
# Compute explanations on the validation split of the dataloaders.
# As the displayed output shows, the result is a table with one row per
# (source_node, target_node, class_idx) combination and the columns
# source_layer, target_layer, source_node, target_node, class_idx, importance.
explanations_df = explainer.explain(dataloaders, split="val")
# Bare last expression so the notebook renders the (truncated) table.
explanations_df

Unnamed: 0,source_layer,target_layer,source_node,target_node,class_idx,importance
0,0,1,A0M8Q6,R-HSA-166663,0,0.041094
1,0,1,A0M8Q6,R-HSA-166663,1,0.011406
2,0,1,A0M8Q6,R-HSA-198933,0,0.041094
3,0,1,A0M8Q6,R-HSA-198933,1,0.011406
4,0,1,A0M8Q6,R-HSA-2029481,0,0.041094
...,...,...,...,...,...,...
7079,4,5,R-HSA-9612973,output_node,1,0.250316
7080,4,5,R-HSA-9709957,output_node,0,0.204158
7081,4,5,R-HSA-9709957,output_node,1,0.031585
7082,4,5,R-HSA-9748784,output_node,0,0.244457


In [5]:
# Average the explanations over 10 iterations; as the log output shows, every
# iteration runs a fresh 50-epoch training pass through `trainer`, so this cell
# is expensive. Bind the result to a name so it can be reused by later cells
# instead of being reachable only through `Out[n]` / `_`.
# NOTE(review): presumably returns a DataFrame shaped like the output of
# `explainer.explain(...)` -- confirm against the binn documentation.
average_explanations_df = explainer.explain_average(
    dataloaders,
    nr_iterations=10,
    num_epochs=50,
    trainer=trainer,
)
# Bare last expression keeps the original cell's displayed output.
average_explanations_df

[BINNExplainer] Iteration 1/10...
[Epoch 1/50] Train Loss: 0.8816, Train Accuracy: 0.4778
[Epoch 1/50] Val Loss: 0.6930, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.8271, Train Accuracy: 0.5034
[Epoch 2/50] Val Loss: 0.6928, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.8011, Train Accuracy: 0.4716
[Epoch 3/50] Val Loss: 0.6926, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.8002, Train Accuracy: 0.5159
[Epoch 4/50] Val Loss: 0.6925, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.7465, Train Accuracy: 0.5916
[Epoch 5/50] Val Loss: 0.6924, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.7251, Train Accuracy: 0.5985
[Epoch 6/50] Val Loss: 0.6928, Val Accuracy: 0.5312
[Epoch 7/50] Train Loss: 0.7403, Train Accuracy: 0.5597
[Epoch 7/50] Val Loss: 0.6931, Val Accuracy: 0.5312
[Epoch 8/50] Train Loss: 0.6668, Train Accuracy: 0.5929
[Epoch 8/50] Val Loss: 0.6931, Val Accuracy: 0.5312
[Epoch 9/50] Train Loss: 0.6969, Train Accuracy: 0.6284
[Epoch 9/50] Val Loss: 0.6916, Val Accuracy: 0