In [8]:
from binn import BINN
import pandas as pd
import torch

# Seed PyTorch's global RNG before the network is built so a fresh
# "Restart Kernel -> Run All" starts from the same random state every time.
# Anything downstream that draws from this RNG (presumably weight
# initialisation, dropout and batch shuffling -- confirm against binn) then
# becomes repeatable between runs.
SEED = 42
torch.manual_seed(SEED)

# Load the quantification matrix and the design matrix (tab-separated) that
# is later handed to the handler together with the "group"/"sample" columns.
input_data = pd.read_csv("../data/test_qm.csv")
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

# Construct the network from the input data; all settings are unchanged.
binn = BINN(
    data_matrix=input_data,
    use_reactome=True,
    n_layers=4,
    dropout=0.2,
    validate=False,
)


BINN is on the device: cpu


In [None]:
from binn import BINNHandler

# Tunable training settings, gathered in one place for easy tweaking.
BATCH_SIZE = 8
NUM_EPOCHS = 30

# The handler is bound to the network and given "./logs" as its save directory.
data_handler = BINNHandler(network=binn, save_dir="./logs")

# Step 1 -- align the raw input matrix to the network.
aligned_data = data_handler.align_to_network(data_matrix=input_data)

# Step 2 -- build the (X, y) training pair from the aligned matrix and the
# design matrix. "group" names the column holding the groups/classes and
# "sample" names the column holding the sample identifiers.
X, y = data_handler.prepare_training_data(
    data_matrix=aligned_data,
    design_matrix=design_matrix,
    group_column="group",
    sample_column="sample",
)

# Step 3 -- wrap (X, y) in a shuffled PyTorch DataLoader.
dataloader = data_handler.create_dataloader(
    X,
    y,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Step 4 -- run the handler's training loop for the configured epoch count.
print("Training BINN with PyTorch...")
data_handler.train_binn(dataloader, num_epochs=NUM_EPOCHS)

Training BINN with PyTorch...
Epoch 0: Avg Accuracy: 0.9300, Avg Loss: 0.1858
Epoch 1: Avg Accuracy: 0.9650, Avg Loss: 0.1211
Epoch 2: Avg Accuracy: 0.9550, Avg Loss: 0.1255
Epoch 3: Avg Accuracy: 0.9650, Avg Loss: 0.1163




Epoch 4: Avg Accuracy: 0.9750, Avg Loss: 0.0901
Epoch 5: Avg Accuracy: 0.9600, Avg Loss: 0.1104
Epoch 6: Avg Accuracy: 0.9690, Avg Loss: 0.1038
Epoch 7: Avg Accuracy: 0.9800, Avg Loss: 0.0860
Epoch 8: Avg Accuracy: 0.9750, Avg Loss: 0.1150
Epoch 9: Avg Accuracy: 0.9600, Avg Loss: 0.1154
Epoch 10: Avg Accuracy: 0.9450, Avg Loss: 0.1223
Epoch 11: Avg Accuracy: 0.9500, Avg Loss: 0.1165
Epoch 12: Avg Accuracy: 0.9900, Avg Loss: 0.1081
Epoch 13: Avg Accuracy: 0.9370, Avg Loss: 0.1333
Epoch 14: Avg Accuracy: 0.9450, Avg Loss: 0.1222
Epoch 15: Avg Accuracy: 0.9250, Avg Loss: 0.1555
Epoch 16: Avg Accuracy: 0.9800, Avg Loss: 0.1020
Epoch 17: Avg Accuracy: 0.9750, Avg Loss: 0.0930
Epoch 18: Avg Accuracy: 0.9390, Avg Loss: 0.1448
Epoch 19: Avg Accuracy: 0.9450, Avg Loss: 0.1250
Epoch 20: Avg Accuracy: 0.9490, Avg Loss: 0.1416
Epoch 21: Avg Accuracy: 0.9750, Avg Loss: 0.0954
Epoch 22: Avg Accuracy: 0.9800, Avg Loss: 0.0873
Epoch 23: Avg Accuracy: 0.9300, Avg Loss: 0.1828
Epoch 24: Avg Accuracy: 0.

In [4]:
from binn import BINNExplainer

# Create an explainer for the network that was built and trained in the
# cells above; it is used below to compute the importance table.
explainer = BINNExplainer(binn)

In [5]:
import torch

# Rows 5-9 of X are the samples to explain; rows 0-4 are passed as the
# background set.
# NOTE(review): both slices are taken from X, the same data the network was
# trained on above -- confirm that is intended for this demo.
#
# torch.as_tensor with an explicit dtype replaces the legacy
# torch.Tensor(data) constructor; the result is still a float32 tensor.
test_data = torch.as_tensor(X[5:10], dtype=torch.float32)
background_data = torch.as_tensor(X[0:5], dtype=torch.float32)

# Compute the importance table and display it as the cell's output.
importance_df = explainer.explain(test_data, background_data)
importance_df

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
0,1,497,A0M8Q6,R-HSA-166663,0.000000,0,0,1
1,1,497,A0M8Q6,R-HSA-166663,0.000000,1,0,1
2,1,954,A0M8Q6,R-HSA-198933,0.000000,0,0,1
3,1,954,A0M8Q6,R-HSA-198933,0.000000,1,0,1
4,1,539,A0M8Q6,R-HSA-2029481,0.000000,0,0,1
...,...,...,...,...,...,...,...,...
6901,1319,0,R-HSA-9612973,root,0.306590,1,4,5
6902,1320,0,R-HSA-9709957,root,0.044687,0,4,5
6903,1320,0,R-HSA-9709957,root,0.014266,1,4,5
6904,1321,0,R-HSA-9748784,root,0.028765,0,4,5


In [6]:
from binn import ImportanceNetwork

# Build the importance network from the table returned by explainer.explain().
# norm_method="fan" picks the normalisation scheme -- its exact meaning is
# not visible in this notebook; see the binn documentation.
IG = ImportanceNetwork(importance_df, norm_method="fan")

In [7]:
# Draw the Sankey diagram for the complete importance network and write it
# to img/complete_sankey.png, using the "coolwarm" colormap for both nodes
# and edges.
IG.plot_complete_sankey(
    show_top_n=5,
    multiclass=False,
    savename="img/complete_sankey.png",
    node_cmap="coolwarm",
    edge_cmap="coolwarm",
)

In [8]:
# Node identifier around which the subgraph plot is built.
query_node = "R-HSA-597592"

# Draw the Sankey diagram for the subgraph of the query node (upstream=True)
# and write it to img/subgraph_sankey.png.
IG.plot_subgraph_sankey(
    query_node,
    upstream=True,
    savename="img/subgraph_sankey.png",
    cmap="coolwarm",
)