In [1]:
from binn import Network, BINN
import pandas as pd

input_data = pd.read_csv("../data/test_qm.csv")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")

network = Network(
    input_data=input_data,
    pathways=pathways,
    mapping=translation,
    source_column="child",
    target_column="parent",
)

binn = BINN(
    network=network,
    n_layers=4,
    dropout=0.2,
    validate=False,
    residual=False,
    device="cpu",
    learning_rate=0.001
)

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



BINN is on the device: cpu


In [2]:
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
from lightning.pytorch import Trainer

design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)

X, y = generate_data(protein_matrix, design_matrix=design_matrix)
dataset = torch.utils.data.TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=binn.device),
    torch.tensor(y, dtype=torch.int16, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

# You can train using the Lightning Trainer
trainer = Trainer(max_epochs=10, log_every_n_steps=10)
#trainer.fit(binn, dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


Since training with the ```Lightning.Trainer``` is slow (since new workers are created for each epoch), we can implement our own training-loop in a standard PyTorch-fashion.

In [3]:
import torch.nn.functional as F

# You can also train with a standard PyTorch train loop 

optimizer = binn.configure_optimizers()[0][0]

num_epochs = 30

for epoch in range(num_epochs):
    binn.train() 
    total_loss = 0.0
    total_accuracy = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(binn.device)
        targets = targets.to(binn.device).type(torch.LongTensor)
        optimizer.zero_grad()
        outputs = binn(inputs).to(binn.device)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_accuracy += torch.sum(torch.argmax(outputs, axis=1) == targets) / len(targets)

    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    print(f'Epoch {epoch}, Average Accuracy {avg_accuracy}, Average Loss: {avg_loss}')



Epoch 0, Average Accuracy 0.621999979019165, Average Loss: 0.7747332274913787
Epoch 1, Average Accuracy 0.7649999856948853, Average Loss: 0.5117802643775939
Epoch 2, Average Accuracy 0.7769999504089355, Average Loss: 0.46191721022129056
Epoch 3, Average Accuracy 0.8199999928474426, Average Loss: 0.3731368619203568
Epoch 4, Average Accuracy 0.8799999952316284, Average Loss: 0.3123727583885193
Epoch 5, Average Accuracy 0.8169999718666077, Average Loss: 0.37796577394008635
Epoch 6, Average Accuracy 0.8650000095367432, Average Loss: 0.3374636244773865
Epoch 7, Average Accuracy 0.8569999933242798, Average Loss: 0.35321935921907427
Epoch 8, Average Accuracy 0.8899999856948853, Average Loss: 0.2933012720942497
Epoch 9, Average Accuracy 0.9200000166893005, Average Loss: 0.22624582052230835
Epoch 10, Average Accuracy 0.9350000023841858, Average Loss: 0.21954303443431855
Epoch 11, Average Accuracy 0.8849999904632568, Average Loss: 0.24858492612838745
Epoch 12, Average Accuracy 0.9350000023841858

In [4]:
from binn import BINNExplainer

explainer = BINNExplainer(binn)

In [5]:
test_data = torch.Tensor(X)
background_data = torch.Tensor(X)

importance_df = explainer.explain(test_data, background_data)
importance_df

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
0,1,497,A0M8Q6,R-HSA-166663,0.033707,0,0,1
1,1,497,A0M8Q6,R-HSA-166663,0.033731,1,0,1
2,1,531,A0M8Q6,R-HSA-198933,0.033707,0,0,1
3,1,531,A0M8Q6,R-HSA-198933,0.033731,1,0,1
4,1,539,A0M8Q6,R-HSA-2029481,0.033707,0,0,1
...,...,...,...,...,...,...,...,...
6901,1319,0,R-HSA-9612973,root,0.035929,1,4,5
6902,1320,0,R-HSA-9709957,root,0.067488,0,4,5
6903,1320,0,R-HSA-9709957,root,0.255553,1,4,5
6904,1321,0,R-HSA-9748784,root,0.165022,0,4,5


In [6]:
from binn import ImportanceNetwork

IG = ImportanceNetwork(importance_df, norm_method="fan")

In [7]:
IG.plot_complete_sankey(
    show_top_n=5,
    multiclass=False, savename="img/complete_sankey.png", node_cmap="coolwarm", edge_cmap="coolwarm"
)

In [8]:
query_node = "R-HSA-597592"

IG.plot_subgraph_sankey(
    query_node, upstream=True, savename="img/subgraph_sankey.png", cmap="coolwarm"
)

{2: 'rgba(58.6004535, 76.17308133, 192.189204015, 0.5)', 3: 'rgba(86.286010418, 115.46874525199999, 224.22585460399998, 0.5)', 5: 'rgba(142.73541501999998, 177.16577066, 253.99920384, 0.5)', 14: 'rgba(66.76036582, 88.35350065800002, 202.89273265499997, 0.5)', 15: 'rgba(59.76615526, 77.91314123400001, 193.718279535, 0.5)', 22: 'rgba(71.54023731, 95.222908116, 208.592739544, 0.5)', 26: 'rgba(236.669582465, 209.66027296, 195.11104406, 0.5)', 30: 'rgba(72.74455827, 96.93300277200001, 209.984480568, 0.5)', 31: 'rgba(198.231270878, 214.434909812, 241.268071906, 0.5)', 39: 'rgba(70.33591634999999, 93.51281346, 207.20099852, 0.5)', 40: 'rgba(67.92727442999998, 90.09262414800001, 204.417516472, 0.5)', 44: 'rgba(112.486294712, 147.01562553600002, 243.409062563, 0.5)', 46: 'rgba(214.034617907, 219.59754054200002, 228.09590967100002, 0.5)', 47: 'rgba(97.833826772, 130.09677140800002, 234.046821674, 0.5)', 48: 'rgba(171.497177945, 199.629669015, 252.955384659, 0.5)', 82: 'rgba(240.14062210499998, 1