In [16]:
from binn import Network, BINN
import pandas as pd

input_data = pd.read_csv("../data/test_qm.csv")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")

network = Network(
    input_data=input_data,
    pathways=pathways,
    mapping=translation,
    source_column="child",
    target_column="parent",
)

binn = BINN(
    network=network,
    n_layers=4,
    dropout=0.2,
    validate=False,
    residual=False,
    device="cpu",
    learning_rate=0.001
)


BINN is on the device: cpu


In [17]:
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
from lightning.pytorch import Trainer

design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)

X, y = generate_data(protein_matrix, design_matrix=design_matrix)
dataset = torch.utils.data.TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=binn.device),
    torch.tensor(y, dtype=torch.int16, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

# You can train using the Lightning Trainer
trainer = Trainer(max_epochs=10, log_every_n_steps=10)
#trainer.fit(binn, dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Since training with the ```Lightning.Trainer``` is slow (since new workers are created for each epoch), we can implement our own training-loop in a standard PyTorch-fashion.

In [18]:
import torch.nn.functional as F

# You can also train with a standard PyTorch train loop 

optimizer = binn.configure_optimizers()[0][0]

num_epochs = 30

for epoch in range(num_epochs):
    binn.train() 
    total_loss = 0.0
    total_accuracy = 0

    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(binn.device)
        targets = targets.to(binn.device).type(torch.LongTensor)
        optimizer.zero_grad()
        outputs = binn(inputs).to(binn.device)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_accuracy += torch.sum(torch.argmax(outputs, axis=1) == targets) / len(targets)

    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    print(f'Epoch {epoch}, Average Accuracy {avg_accuracy}, Average Loss: {avg_loss}')



Epoch 0, Average Accuracy 0.5519999861717224, Average Loss: 0.7984879660606384
Epoch 1, Average Accuracy 0.7119999527931213, Average Loss: 0.5876554298400879
Epoch 2, Average Accuracy 0.7699999809265137, Average Loss: 0.5000520890951157
Epoch 3, Average Accuracy 0.8059999942779541, Average Loss: 0.4410453510284424
Epoch 4, Average Accuracy 0.7699999809265137, Average Loss: 0.48531828820705414
Epoch 5, Average Accuracy 0.8100000023841858, Average Loss: 0.38955733954906463
Epoch 6, Average Accuracy 0.7809999585151672, Average Loss: 0.4015652185678482
Epoch 7, Average Accuracy 0.7890000343322754, Average Loss: 0.4215572929382324
Epoch 8, Average Accuracy 0.8889999985694885, Average Loss: 0.31606161445379255
Epoch 9, Average Accuracy 0.8769999742507935, Average Loss: 0.3141865438222885
Epoch 10, Average Accuracy 0.8849999904632568, Average Loss: 0.31385430574417117
Epoch 11, Average Accuracy 0.871999979019165, Average Loss: 0.2942655232548714
Epoch 12, Average Accuracy 0.8399999737739563, 

In [19]:
from binn import BINNExplainer

explainer = BINNExplainer(binn)

In [20]:
test_data = torch.Tensor(X[5:10])
background_data = torch.Tensor(X[0:5])

importance_df = explainer.explain(test_data, background_data)
importance_df.head()

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
0,1,497,A0M8Q6,R-HSA-166663,0.0,0,0,1
1,1,497,A0M8Q6,R-HSA-166663,0.0,1,0,1
2,1,954,A0M8Q6,R-HSA-198933,0.0,0,0,1
3,1,954,A0M8Q6,R-HSA-198933,0.0,1,0,1
4,1,539,A0M8Q6,R-HSA-2029481,0.0,0,0,1


In [21]:
from binn import ImportanceNetwork

IG = ImportanceNetwork(importance_df, norm_method="fan")

In [22]:
IG.plot_complete_sankey(
    multiclass=False, savename="img/complete_sankey.png", node_cmap="Accent_r", edge_cmap="Accent_r"
)

In [23]:
query_node = "R-HSA-597592"

IG.plot_subgraph_sankey(
    query_node, upstream=True, savename="img/subgraph_sankey.png", cmap="coolwarm"
)