In [1]:
from binn import Network, BINN
import pandas as pd

input_data = pd.read_csv("../data/test_qm.csv")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")

network = Network(
    input_data=input_data,
    pathways=pathways,
    mapping=translation,
    source_column="child",
    target_column="parent",
)

binn = BINN(
    network=network,
    n_layers=4,
    dropout=0.2,
    validate=False,
    residual=True,
    device="cpu",
)

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



BINN is on the device: cpu


In [2]:
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
from lightning.pytorch import Trainer

design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)

X, y = generate_data(protein_matrix, design_matrix=design_matrix)
dataset = torch.utils.data.TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=binn.device),
    torch.tensor(y, dtype=torch.int16, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

# You can train using the Lightning Trainer
trainer = Trainer(max_epochs=10, log_every_n_steps=10)
#trainer.fit(binn, dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [3]:
import torch.nn.functional as F

# You can also train with a standard PyTorch train loop 

optimizer = torch.optim.Adam(binn.parameters(), lr=0.001)

num_epochs = 30

for epoch in range(num_epochs):
    binn.train() 
    total_loss = 0.0
    total_accuracy = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(binn.device)
        targets = targets.to(binn.device).type(torch.LongTensor)
        optimizer.zero_grad()
        outputs = binn(inputs).to(binn.device)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_accuracy += torch.sum(torch.argmax(outputs, axis=1) == targets) / len(targets)

    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    print(f'Epoch {epoch}, Average Accuracy {avg_accuracy}, Average Loss: {avg_loss}')



Epoch 0, Average Accuracy 0.6669999957084656, Average Loss: 0.6535322880744934
Epoch 1, Average Accuracy 0.824999988079071, Average Loss: 0.5731604671478272
Epoch 2, Average Accuracy 0.8840000033378601, Average Loss: 0.5321515893936157
Epoch 3, Average Accuracy 0.9100000262260437, Average Loss: 0.5054576551914215
Epoch 4, Average Accuracy 0.9649999737739563, Average Loss: 0.4636109948158264
Epoch 5, Average Accuracy 0.909000039100647, Average Loss: 0.4785078537464142
Epoch 6, Average Accuracy 0.9519999623298645, Average Loss: 0.4421776020526886
Epoch 7, Average Accuracy 0.9449999928474426, Average Loss: 0.42987540245056155
Epoch 8, Average Accuracy 0.9649999737739563, Average Loss: 0.4235784935951233
Epoch 9, Average Accuracy 0.9350000023841858, Average Loss: 0.42061012744903564
Epoch 10, Average Accuracy 0.9449999928474426, Average Loss: 0.41523029446601867
Epoch 11, Average Accuracy 0.9599999785423279, Average Loss: 0.4051105892658234
Epoch 12, Average Accuracy 0.9449999928474426, Av

In [4]:
from binn import BINNExplainer

explainer = BINNExplainer(binn)

In [5]:
test_data = torch.Tensor(X[5:10])
background_data = torch.Tensor(X[0:5])

importance_df = explainer.explain(test_data, background_data)
importance_df.head()

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
0,1,497,A0M8Q6,R-HSA-166663,0.0,0,0,1
1,1,497,A0M8Q6,R-HSA-166663,0.0,1,0,1
2,1,954,A0M8Q6,R-HSA-198933,0.0,0,0,1
3,1,954,A0M8Q6,R-HSA-198933,0.0,1,0,1
4,1,539,A0M8Q6,R-HSA-2029481,0.0,0,0,1


In [6]:
from binn import ImportanceNetwork

IG = ImportanceNetwork(importance_df, norm_method="fan")

In [21]:
IG.plot_complete_sankey(
    multiclass=False, savename="img/complete_sankey.png", node_cmap="Accent_r", edge_cmap="Accent_r"
)

In [11]:
query_node = "R-HSA-597592"

IG.plot_subgraph_sankey(
    query_node, upstream=True, savename="img/subgraph_sankey.png", cmap="coolwarm"
)