In [1]:
from binn import Network, BINN
import pandas as pd

# Read the three tables this example needs. The two .tsv files are
# tab-separated; the .csv file uses pandas' default separator.
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
input_data = pd.read_csv("../data/test_qm.csv")

# Build the Network from the tables. The `child` / `parent` columns are handed
# over as the source / target columns, and the translation table is passed as
# the `mapping` argument.
network = Network(
    source_column="child",
    target_column="parent",
    pathways=pathways,
    mapping=translation,
    input_data=input_data,
)

# Create the model on top of the network. It is placed on the CPU here; the
# later cells read `binn.device` when they create tensors.
binn = BINN(
    network=network,
    device="cpu",
    n_layers=4,
    dropout=0.2,
    learning_rate=0.001,
    validate=False,
)

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



BINN is on the device: cpu


In [2]:
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
from lightning.pytorch import Trainer

design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")

# Helper from util_for_examples; presumably aligns the data matrix with the
# network's input features (`network.inputs`) -- verify against the helper.
protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)

X, y = generate_data(protein_matrix, design_matrix=design_matrix)

# The targets are class indices that end up in F.cross_entropy, which requires
# int64 ("long") indices. Storing them as long up front (instead of int16)
# means every consumer of the dataloader gets usable targets without having to
# re-cast them on each batch.
dataset = torch.utils.data.TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=binn.device),
    torch.tensor(y, dtype=torch.long, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

# You can train using the Lightning Trainer.
# The fit call is left commented out because the notebook trains with a plain
# PyTorch loop instead; uncomment it to train through Lightning.
trainer = Trainer(max_epochs=10, log_every_n_steps=10)
# trainer.fit(binn, dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


Training with the Lightning `Trainer` can be slow because new workers are created for each epoch, so we can instead implement our own training loop in standard PyTorch fashion.

In [3]:
import torch.nn.functional as F

# You can also train with a standard PyTorch train loop 

# The result of configure_optimizers() is indexed with [0][0] to obtain a single
# optimizer -- presumably it returns ([optimizers], [schedulers]); confirm
# against the BINN implementation.
optimizer = binn.configure_optimizers()[0][0]

num_epochs = 30

for epoch in range(num_epochs):
    binn.train()  # training mode, so layers such as dropout are active
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(binn.device)
        # cross_entropy needs int64 class indices. Cast with `.to(dtype=...)`
        # rather than `.type(torch.LongTensor)`: the latter is a CPU tensor
        # type and would move the targets off the model's device.
        targets = targets.to(binn.device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = binn(inputs)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers weighted by batch size, so that a
        # smaller final batch does not skew the epoch averages and no tensors
        # are kept alive between iterations.
        batch_size = targets.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == targets).sum().item()
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_accuracy = total_correct / total_samples
    print(f'Epoch {epoch}, Average Accuracy {avg_accuracy}, Average Loss: {avg_loss}')



Epoch 0, Average Accuracy 0.5419999957084656, Average Loss: 0.8137147390842437
Epoch 1, Average Accuracy 0.6869999766349792, Average Loss: 0.5829368078708649
Epoch 2, Average Accuracy 0.7519999742507935, Average Loss: 0.4501472520828247
Epoch 3, Average Accuracy 0.8519999980926514, Average Loss: 0.3524850034713745
Epoch 4, Average Accuracy 0.8140000104904175, Average Loss: 0.36932810574769975
Epoch 5, Average Accuracy 0.8369999527931213, Average Loss: 0.33754108875989913
Epoch 6, Average Accuracy 0.8190000057220459, Average Loss: 0.3624538911879063
Epoch 7, Average Accuracy 0.8740000128746033, Average Loss: 0.3092288127541542
Epoch 8, Average Accuracy 0.859000027179718, Average Loss: 0.35472473621368406
Epoch 9, Average Accuracy 0.8869999647140503, Average Loss: 0.26175024479627607
Epoch 10, Average Accuracy 0.9399999976158142, Average Loss: 0.18241691201925278
Epoch 11, Average Accuracy 0.9019999504089355, Average Loss: 0.23924426399171353
Epoch 12, Average Accuracy 0.8899999856948853

In [4]:
from binn import BINNExplainer

# Wrap the model in a BINNExplainer; its `explain` method is called in the next
# cell to produce the importance table.
explainer = BINNExplainer(binn)

In [5]:
# The same samples are used both as the data to explain and as the background
# set. `torch.tensor` with an explicit dtype and device replaces the legacy
# `torch.Tensor(...)` constructor, so the inputs are float32 and live on the
# same device as the model (matching how the training tensors were created).
test_data = torch.tensor(X, dtype=torch.float32, device=binn.device)
background_data = torch.tensor(X, dtype=torch.float32, device=binn.device)

importance_df = explainer.explain(test_data, background_data)
importance_df

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
0,1,497,A0M8Q6,R-HSA-166663,0.041265,0,0,1
1,1,497,A0M8Q6,R-HSA-166663,0.049803,1,0,1
2,1,531,A0M8Q6,R-HSA-198933,0.041265,0,0,1
3,1,531,A0M8Q6,R-HSA-198933,0.049803,1,0,1
4,1,539,A0M8Q6,R-HSA-2029481,0.041265,0,0,1
...,...,...,...,...,...,...,...,...
6901,1319,0,R-HSA-9612973,root,0.172510,1,4,5
6902,1320,0,R-HSA-9709957,root,0.153581,0,4,5
6903,1320,0,R-HSA-9709957,root,0.187209,1,4,5
6904,1321,0,R-HSA-9748784,root,0.225962,0,4,5


In [6]:
# Work on a copy so the raw importance table stays untouched.
plot_df = importance_df.copy()

# id -> name table (no header row in the file); keep only the human entries.
id_to_name = pd.read_csv(
    "../data/id_to_name.txt", sep="\t", names=["id", "name", "species"]
)
id_to_name = id_to_name.loc[id_to_name["species"] == "Homo sapiens"]

# accession -> trivname lookup from the human proteome table.
human_proteome = pd.read_csv("../data/human_proteome.gz")
proteome_mapping = (
    human_proteome.set_index("accession").drop(columns=["seq"])["trivname"].to_dict()
)

# Merge the lookups; later entries win on key clashes, and "root" maps to
# itself so the root node keeps its name.
mapping = {
    **id_to_name.drop(columns="species").set_index("id")["name"].to_dict(),
    **proteome_mapping,
    "root": "root",
}

# Replace each identifier by "<readable name>_<layer>". Identifiers missing
# from `mapping` become NaN, and stay NaN after the string concatenation.
for name_col, layer_col in (
    ("source name", "source layer"),
    ("target name", "target layer"),
):
    layer_suffix = "_" + plot_df[layer_col].astype(str)
    plot_df[name_col] = plot_df[name_col].map(mapping) + layer_suffix

# Show the three rows with the largest importance value.
plot_df.sort_values(by="value", ascending=False).iloc[:3]

Unnamed: 0,source,target,source name,target name,value,type,source layer,target layer
1725,138,748,H2A1B_HUMAN_0,Mitotic Prophase_1,0.454296,1,0,1
1693,138,574,H2A1B_HUMAN_0,DNA Damage/Telomere Stress Induced Senescence_1,0.454296,1,0,1
1707,138,668,H2A1B_HUMAN_0,B-WICH complex positively regulates rRNA expre...,0.454296,1,0,1


In [7]:
from binn import ImportanceNetwork
def translate(x):
    """Return ``x`` unchanged.

    Identity hook applied to the node-name columns before plotting. The
    previous ``isinstance`` check returned ``x`` on both branches, so it had
    no effect and has been removed. Edit this function to customise how
    node names are displayed.

    Args:
        x: A node name (or a missing value such as NaN).

    Returns:
        The same object that was passed in.
    """
    return x
        
# Apply the translate hook element-wise to both name columns. The function is
# passed directly; wrapping it in a lambda added nothing.
plot_df["source name"] = plot_df["source name"].apply(translate)
plot_df["target name"] = plot_df["target name"].apply(translate)

# Drop edges whose source name is missing. Reassigning (rather than
# inplace=True) keeps the step explicit, and the list form of `subset` is
# accepted by every pandas version, unlike a bare string.
plot_df = plot_df.dropna(subset=["source name"])

IG = ImportanceNetwork(plot_df, norm_method="fan")

In [8]:
# Draw the Sankey diagram for the complete importance network.
# `savename` is the image path handed to the plotting call -- NOTE(review):
# confirm whether the `img/` directory has to exist beforehand.
IG.plot_complete_sankey(
    show_top_n=10,
    multiclass=False,
    savename="img/complete_sankey.png",
    node_cmap="coolwarm",
    edge_cmap="coolwarm",
)

In [9]:
# Node names carry a "_<layer>" suffix (appended when plot_df's name columns
# were built), so this selects the "Neutrophil degranulation" node of layer 2.
query_node = "Neutrophil degranulation_2"

# Draw the Sankey diagram for the subgraph around the query node, with
# upstream=True.
IG.plot_subgraph_sankey(
    query_node,
    upstream=True,
    savename="img/subgraph_sankey.png",
    cmap="coolwarm",
)