In [6]:
import random

import numpy as np
import pandas as pd
import torch

from binn import BINN, BINNDataLoader, BINNTrainer

# --- Configuration: the values a reader is most likely to tune ---
SEED = 42
BATCH_SIZE = 32
NUM_EPOCHS = 200
VALIDATION_SPLIT = 0.2

# Seed every global RNG before any data is split or any model weight is
# created, so that a fresh "Restart & Run All" reproduces the same run.
# NOTE(review): which of these generators binn actually draws from is not
# visible here; seeding all three is harmless if it seeds internally.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the quantitative matrix and the sample/group design table.
data_matrix = pd.read_csv("../data/test_qm.csv")
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

# Initialize BINN
binn = BINN(data_matrix=data_matrix, use_reactome=True, n_layers=4, dropout=0.2)

# Initialize DataLoader
binn_dataloader = BINNDataLoader()

# Align data to the BINN network
aligned_data = binn_dataloader.align_to_network(data_matrix, binn)

# Prepare training data
data_splits = binn_dataloader.prepare_training_data(
    aligned_data=aligned_data,
    design_matrix=design_matrix,
    group_column="group",
    sample_column="sample",
    validation_split=VALIDATION_SPLIT,
)

# Create PyTorch DataLoaders; both splits share one batch size.
train_loader = binn_dataloader.create_dataloader(
    *data_splits["train"], batch_size=BATCH_SIZE
)
# The validation loader stays None when no "val" split was produced.
val_loader = None
if "val" in data_splits:
    val_loader = binn_dataloader.create_dataloader(
        *data_splits["val"], batch_size=BATCH_SIZE
    )

# Train the model
trainer = BINNTrainer(binn)
trainer.train(
    dataloaders={"train": train_loader, "val": val_loader},
    num_epochs=NUM_EPOCHS,
)


[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/200] Train Loss: 0.7487, Train Accuracy: 0.5653
[Epoch 1/200] Val Loss: 0.6936, Val Accuracy: 0.3750
[Epoch 2/200] Train Loss: 0.7813, Train Accuracy: 0.5034
[Epoch 2/200] Val Loss: 0.6957, Val Accuracy: 0.3750
[Epoch 3/200] Train Loss: 0.8375, Train Accuracy: 0.5185
[Epoch 3/200] Val Loss: 0.6953, Val Accuracy: 0.3750
[Epoch 4/200] Train Loss: 0.7762, Train Accuracy: 0.5403
[Epoch 4/200] Val Loss: 0.6955, Val Accuracy: 0.3750
[Epoch 5/200] Train Loss: 0.8002, Train Accuracy: 0.5084
[Epoch 5/200] Val Loss: 0.6962, Val Accuracy: 0.3750
[Epoch 6/200] Train Loss: 0.7655, Train Accuracy: 0.5097
[Epoch 6/200] Val Loss: 0.6982, Val Accuracy: 0.3281
[Epoch 7/200] Train Loss: 0.6893, Train Accuracy: 0.5873
[Epoch 7/200] Val Loss: 0.6934, Val Accuracy: 0.3750
[Epoch 8/200] Train Loss: 0.6710, Train Accuracy: 0.6041
[Epoch 8/200] Val Loss: 0.6812, Val Accuracy: 0.5156
[Epoch 9/200] Train Loss: 0.7459, 

In [4]:
from binn import BINNExplainer

# Wrap the `binn` model created in the training cell; a later cell calls
# `explainer.explain(...)` on this object.
# NOTE(review): this cell's recorded execution count is lower than the
# training cell's, so it was last run out of order. Re-run the notebook top
# to bottom from a fresh kernel before trusting the outputs below.
explainer = BINNExplainer(binn)

In [5]:
import torch

# Row windows taken from `X` for the explanation call.
# NOTE(review): `X` is not defined in any visible cell of this notebook;
# it presumably leaked from a deleted or out-of-view cell. Confirm where it
# comes from, otherwise this cell fails on a fresh kernel.
BACKGROUND_ROWS = slice(0, 5)
TEST_ROWS = slice(5, 10)

background_data = torch.Tensor(X[BACKGROUND_ROWS])
test_data = torch.Tensor(X[TEST_ROWS])

# Explain the test rows against the background rows. The returned frame is
# left as the last expression so the notebook renders it.
importance_df = explainer.explain(test_data, background_data)
importance_df

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
0,1,497,A0M8Q6,R-HSA-166663,0.000000,0,0,1
1,1,497,A0M8Q6,R-HSA-166663,0.000000,1,0,1
2,1,954,A0M8Q6,R-HSA-198933,0.000000,0,0,1
3,1,954,A0M8Q6,R-HSA-198933,0.000000,1,0,1
4,1,539,A0M8Q6,R-HSA-2029481,0.000000,0,0,1
...,...,...,...,...,...,...,...,...
6901,1319,0,R-HSA-9612973,root,0.306590,1,4,5
6902,1320,0,R-HSA-9709957,root,0.044687,0,4,5
6903,1320,0,R-HSA-9709957,root,0.014266,1,4,5
6904,1321,0,R-HSA-9748784,root,0.028765,0,4,5


In [6]:
from binn import ImportanceNetwork

# Build the object that the Sankey plotting cells below call, from the
# importance table produced by `explainer.explain(...)`.
# NOTE(review): `norm_method="fan"` presumably selects a fan-in/fan-out
# normalisation of the importance values; confirm against the binn docs.
IG = ImportanceNetwork(importance_df, norm_method="fan")

In [7]:
# Sankey diagram of the complete network, saved under img/.
# NOTE(review): presumably `show_top_n` limits the nodes drawn per layer;
# confirm against the binn docs.
IG.plot_complete_sankey(
    show_top_n=5,
    multiclass=False,
    savename="img/complete_sankey.png",
    node_cmap="coolwarm",
    edge_cmap="coolwarm",
)

In [8]:
# Node whose subgraph is plotted; the ID has the same "R-HSA-..." form as the
# pathway names in the importance table.
query_node = "R-HSA-597592"

# Sankey diagram restricted to the query node's subgraph, saved under img/.
# NOTE(review): presumably `upstream=True` selects the upstream side of the
# node; confirm the direction against the binn docs.
IG.plot_subgraph_sankey(
    query_node,
    upstream=True,
    savename="img/subgraph_sankey.png",
    cmap="coolwarm",
)