In [1]:
from biological_fuzzy_logic_networks.DREAM_analysis.utils import (
    create_bfz,
    prepare_cell_line_data,
    cl_data_to_input,
    data_to_nodes_mapping,
)
import pandas as pd
from typing import List, Union, Sequence
from app_tunnel.apps import mlflow_tunnel
from sklearn.metrics import r2_score
import click
import json
import torch
import pickle as pickle
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Values of the data-to-nodes mapping; used further down to select the marker
# columns of the training frame.
markers = list(data_to_nodes_mapping().values())

In [3]:
# --- Input / output locations -------------------------------------------
pkn_sif = "/dccstor/ipc1/CAR/DREAM/DREAMdata/PKN_Alice.sif"
network_class = "DREAMBioFuzzNet"
data_file = "/dccstor/ipc1/CAR/DREAM/DREAMdata/Time_aligned_per_cell_line/CL_incl_test/MFM223.csv"
output_dir = "/dccstor/ipc1/CAR/DREAM/Model/Test/After_synthetic/"

# --- Data selection -----------------------------------------------------
time_point: int = 9
non_marker_cols: Sequence[str] = (
    "treatment",
    "cell_line",
    "time",
    "cellID",
    "fileID",
)
treatment_col_name = "treatment"
sample_n_cells = False
filter_starved_stim = True
# Selects only one treatment for training and evaluation.
sel_condition = "EGF"

# --- Scaling and network inputs -----------------------------------------
scaler_type = "minmax"
add_root_values = True
input_value = 1
root_nodes = ("EGF", "SERUM")
replace_zero_inputs = False

# --- Train / validation / test split selectors (None = not specified) ---
train_treatments = None
valid_treatments = None
train_cell_lines = None
valid_cell_lines = None
test_cell_lines = None

# --- Optimisation settings ----------------------------------------------
inhibition_value = 1.0
learning_rate = 1e-3
n_epochs = 2
batch_size = 800
# Passed to conduct_optimisation as checkpoint_path; no directory is configured here.
checkpoint_path = None
convergence_check = False
# NOTE(review): the original note on this line ("Is needed since it does early
# stopping and then reloads the saved checkpoint") reads as if it describes
# checkpoint_path rather than shuffle_nodes — confirm.
shuffle_nodes = False

In [4]:
# Build the network of the configured class from the PKN SIF file.
model = create_bfz(
    pkn_sif,
    network_class,
    shuffle_nodes=shuffle_nodes,
)


In [5]:
# Collect the loading/filtering options in one place, then load the cell-line data.
cell_line_kwargs = dict(
    data_file=data_file,
    time_point=time_point,
    non_marker_cols=non_marker_cols,
    treatment_col_name=treatment_col_name,
    filter_starved_stim=filter_starved_stim,
    sample_n_cells=sample_n_cells,
    sel_condition=sel_condition,
)
cl_data = prepare_cell_line_data(**cell_line_kwargs)

<class 'str'>
['MFM223']


In [6]:
# Display the DataFrame returned by prepare_cell_line_data for a visual sanity check.
cl_data

Unnamed: 0,treatment,cell_line,time,b-catenin,cleavedCas,CyclinB,GAPDH,IdU,Ki.67,4EBP1,...,PLCg2,RB,S6,p70S6K,SMAD23,SRC,STAT1,STAT3,STAT5,inhibitor
34047,EGF,MFM223,9.0,0.224784,2.795105,2.97808,2.392843,5.85840,1.59085,2.376591,...,1.496181,6.35094,4.44232,2.279256,2.575617,1.657935,1.863033,0.973149,1.792050,
34048,EGF,MFM223,9.0,1.696425,1.695905,2.37312,2.539773,4.96587,5.86276,5.128863,...,2.965710,2.53357,4.18838,2.476023,2.464410,2.482117,2.519847,0.652426,2.191470,
34049,EGF,MFM223,9.0,1.325228,2.634111,2.81751,3.953618,6.14365,2.74138,5.272276,...,4.207409,7.12715,4.04473,1.941882,2.263106,2.371338,2.778224,2.742117,2.284020,
34050,EGF,MFM223,9.0,0.224784,2.683894,2.89734,3.370827,4.51480,2.09219,3.982892,...,1.954332,4.95758,4.90296,3.205063,2.483510,2.737279,2.525305,0.652426,2.769220,
34051,EGF,MFM223,9.0,0.224784,0.898342,1.20489,0.565534,7.02619,1.59085,0.993163,...,1.190868,3.26278,2.39632,0.331558,0.551474,2.193704,-0.341181,0.652426,0.808508,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51844,EGF,MFM223,9.0,0.624565,2.646052,2.40503,2.978296,4.77824,1.59085,4.498697,...,1.992294,5.28350,4.93505,3.353234,0.719980,1.958217,2.078310,1.087982,2.429430,
51845,EGF,MFM223,9.0,0.782517,0.898342,5.64025,2.331753,5.62258,3.20836,2.033558,...,2.540564,3.78978,5.70999,2.889102,0.551474,2.849723,1.817126,0.697733,2.918340,
51846,EGF,MFM223,9.0,0.224784,0.898342,1.70719,2.525816,5.01940,1.59085,2.132341,...,0.986227,3.64621,2.39632,2.888444,0.551474,2.139019,1.864788,1.141459,1.632790,
51847,EGF,MFM223,9.0,0.224784,0.898342,1.47827,2.694557,5.00493,1.59085,0.993163,...,0.986227,3.52309,3.04512,1.598135,1.145315,2.397771,2.109824,0.652426,2.281400,


In [7]:
# Split the prepared cell-line data into training and validation sets and
# convert them to the inputs expected by the network.
(
    train_data,
    valid_data,
    train_inhibitors,
    valid_inhibitors,
    train_input,
    valid_input,
    train,
    valid,
    scaler,
) = cl_data_to_input(
    data=cl_data,
    model=model,
    # Split selectors
    train_treatments=train_treatments,
    valid_treatments=valid_treatments,
    train_cell_lines=train_cell_lines,
    valid_cell_lines=valid_cell_lines,
    # Scaler trained on training data, applied on valid data and returned.
    # Values smaller than zero are set to 0.
    scale_type=scaler_type,
    # If a value is given, replace all zeros (after scaling) with this value.
    replace_zero_inputs=replace_zero_inputs,
    # Root-node and inhibitor values
    add_root_values=add_root_values,
    input_value=input_value,
    root_nodes=root_nodes,
    inhibition_value=inhibition_value,
    # NOTE(review): the original author was unsure whether balance_data has
    # any effect — confirm in cl_data_to_input.
    balance_data=True,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  t[t < 0] = 0
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  t[t < 0] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  t[t > 1] = 1
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  t[t > 1] = 1


In [8]:
# Display the training targets (later passed to conduct_optimisation as ground_truth).
train_data

{'b-catenin': tensor([0.6911, 0.3427, 0.2131,  ..., 0.6284, 0.2530, 0.3357]),
 'cleavedCas': tensor([0.2968, 0.0848, 0.0000,  ..., 0.1466, 0.2656, 0.2578]),
 '4EBP1': tensor([0.7770, 0.5467, 0.6171,  ..., 0.6794, 0.6333, 0.5837]),
 'AKT_S473': tensor([0.6663, 0.8451, 0.7971,  ..., 0.7828, 0.6580, 0.4421]),
 'AKT_T308': tensor([0.6396, 0.4587, 0.2600,  ..., 0.4876, 0.4911, 0.1280]),
 'AMPK': tensor([0.6602, 0.4815, 0.5238,  ..., 0.4295, 0.4454, 0.3476]),
 'BTK': tensor([0.6997, 0.4730, 0.6174,  ..., 0.7426, 0.7524, 0.4959]),
 'CREB': tensor([1.0597e-01, 2.0643e-02, 9.1208e-02,  ..., 7.3788e-02, 2.2663e-07,
         8.2852e-02]),
 'ERK12': tensor([0.3463, 0.7947, 0.5966,  ..., 0.7500, 0.3729, 0.3788]),
 'FAK': tensor([0.4409, 0.4884, 0.4253,  ..., 0.6203, 0.5022, 0.2284]),
 'GSK3B': tensor([0.6965, 0.5561, 0.6473,  ..., 0.7010, 0.5085, 0.2315]),
 'H3': tensor([0.2154, 0.1656, 0.0936,  ..., 0.2297, 0.1114, 0.1640]),
 'JNK': tensor([0.3913, 0.3169, 0.1992,  ..., 0.4624, 0.5654, 0.3800]),
 

In [9]:
# Size of the training and validation splits, read off one marker's tensor.
for split in (train_data, valid_data):
    print(len(split['b-catenin']))

13351
4451


In [10]:
# Display the training inhibitor values (later passed to conduct_optimisation as train_inhibitors).
train_inhibitors

{'SERUM': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'NFkB': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'PLCg2': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'FAK': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'PI3K': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'BTK': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'PIP3': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'AKT_S473': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'p53': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'RB': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'GSK3B': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'AMPK': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'cleavedCas': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'SMAD23': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'ERK12': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'MSK12': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'MKK36': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'H3': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'p90RSK': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'mTOR': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'AKT':

In [11]:
# Display the training inputs (later passed to conduct_optimisation as input).
train_input

{'SERUM': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'EGF': tensor([1., 1., 1.,  ..., 1., 1., 1.])}

In [12]:
# Range check on the marker columns of the training frame: the overall maximum
# is printed, the overall minimum is shown as the cell's result.
marker_frame = train[markers]
print(marker_frame.max().max())
marker_frame.min().min()

1.0000000000000002


0.0

In [13]:
# Fit the network on the training split while monitoring the validation split.
optimisation_kwargs = dict(
    # Training split
    input=train_input,
    ground_truth=train_data,
    train_inhibitors=train_inhibitors,
    # Validation split
    valid_input=valid_input,
    valid_ground_truth=valid_data,
    valid_inhibitors=valid_inhibitors,
    # Optimiser settings from the configuration cell
    epochs=n_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
    checkpoint_path=checkpoint_path,
    convergence_check=convergence_check,
    logger=None,
)
loss, best_val_loss, loop_states = model.conduct_optimisation(**optimisation_kwargs)

print("loss: ", loss)
print("best loss: ", best_val_loss)

Loss:4.12e-01: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [08:29<00:00, 254.75s/it]


loss:                            time      loss  phase
0   2024-07-18 09:05:46.935119  0.408509  train
1   2024-07-18 09:06:02.688487  0.414254  train
2   2024-07-18 09:06:18.205300  0.408887  train
3   2024-07-18 09:06:33.355348  0.417535  train
4   2024-07-18 09:06:48.919473   0.41629  train
5   2024-07-18 09:07:04.587062  0.409635  train
6   2024-07-18 09:07:19.618888  0.421432  train
7   2024-07-18 09:07:34.299279  0.410241  train
8   2024-07-18 09:07:49.439169  0.405525  train
9   2024-07-18 09:08:04.517168  0.418758  train
10  2024-07-18 09:08:19.870923  0.408928  train
11  2024-07-18 09:08:34.840812   0.41087  train
12  2024-07-18 09:08:50.178127  0.409202  train
13  2024-07-18 09:09:04.886018  0.411294  train
14  2024-07-18 09:09:19.618957  0.412486  train
15  2024-07-18 09:09:34.580879   0.40973  train
16  2024-07-18 09:09:48.772012  0.413197  train
17  2024-07-18 09:09:49.794005  0.414137  valid
18  2024-07-18 09:10:06.472736  0.408509  train
19  2024-07-18 09:10:21.388435  0

In [14]:
# Run the network forward on the validation split without tracking gradients
# and collect every node's output state in a DataFrame.
with torch.no_grad():
    model.initialise_random_truth_and_output(len(valid))
    model.set_network_ground_truth(valid_data)
    model.sequential_update(model.root_nodes, valid_inhibitors)
    predicted_states = {
        name: state.numpy() for name, state in model.output_states.items()
    }
    val_output_states = pd.DataFrame(predicted_states)

# Validation performance: one R² score per node in valid_data, keyed "val_r2_<node>".
node_r2_scores = {
    f"val_r2_{node}": r2_score(valid[node], val_output_states[node])
    for node in valid_data
}

In [15]:
# Display the per-node validation R² scores (keys are "val_r2_<node>").
node_r2_scores

{'val_r2_b-catenin': -7.880609628654527,
 'val_r2_cleavedCas': -43.89535337848481,
 'val_r2_4EBP1': -6.132318341701278,
 'val_r2_AKT_S473': -4.745727955774396,
 'val_r2_AKT_T308': -2.3167636330180996,
 'val_r2_AMPK': -19.573970960329753,
 'val_r2_BTK': -6.897675906434983,
 'val_r2_CREB': -35.29303539254972,
 'val_r2_ERK12': -6.484608117647116,
 'val_r2_FAK': -12.89289128515462,
 'val_r2_GSK3B': -6.188378354926441,
 'val_r2_H3': -25.612846999697783,
 'val_r2_JNK': -12.265462570791025,
 'val_r2_MAP3Ks': -8.938303129187936,
 'val_r2_MAPKAPK2': -20.574088755142427,
 'val_r2_MEK12': -20.72854870882006,
 'val_r2_MKK36': -10.094849988229505,
 'val_r2_MKK4': -8.003472966327555,
 'val_r2_NFkB': -9.02511086305969,
 'val_r2_p38': -7.903707703419265,
 'val_r2_p53': -23.20125848529024,
 'val_r2_p90RSK': -6.801461239505861,
 'val_r2_PDPK1': -3.798775554880339,
 'val_r2_PLCg2': -13.244167649138964,
 'val_r2_RB': -9.690725262576162,
 'val_r2_S6': -12.793977633888508,
 'val_r2_p70S6K': -10.707531227245

In [None]:


# Optimize model
loss, best_val_loss, loop_states = model.conduct_optimisation(
    input=train_input,
    valid_input=valid_input,
    ground_truth=train_data,
    valid_ground_truth=valid_data,
    train_inhibitors=train_inhibitors,
    valid_inhibitors=valid_inhibitors,
    epochs=n_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
    checkpoint_path=checkpoint_path,
    convergence_check=convergence_check,
    logger=None,
)

print("loss: ", loss)
print("best loss: ", best_val_loss)

# Make sure the output directory exists before anything is written into it.
# An empty output_dir means "current directory", which needs no creation.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

if convergence_check:
    # Convert every recorded loop state from tensors to numpy arrays so the
    # states can be stacked into a single DataFrame keyed by their index.
    loop_state_arrays = {
        idx: {node: state.detach().numpy() for node, state in states.items()}
        for idx, states in loop_states.items()
    }
    loop_states_to_save = pd.concat(
        [pd.DataFrame(arrays) for arrays in loop_state_arrays.values()],
        keys=loop_state_arrays.keys(),
        names=["time", ""],
    ).reset_index("time", drop=False)
    loop_states_to_save.to_csv(os.path.join(output_dir, "loop_states.csv"))

# Load best model and evaluate.
# Only reload when a checkpoint directory was configured: with
# checkpoint_path=None the original code tried to open "None/model.pt".
# Without a checkpoint, the in-memory model from the optimisation above is
# evaluated instead.
if checkpoint_path is not None:
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # written by this pipeline.
    ckpt = torch.load(os.path.join(checkpoint_path, "model.pt"))
    model = create_bfz(pkn_sif, network_class)
    model.load_from_checkpoint(ckpt["model_state_dict"])

with torch.no_grad():
    model.initialise_random_truth_and_output(len(valid))
    model.set_network_ground_truth(valid_data)
    model.sequential_update(model.root_nodes, valid_inhibitors)
    val_output_states = pd.DataFrame(
        {k: v.numpy() for k, v in model.output_states.items()}
    )

# Validation performance: one R² score per node in valid_data.
node_r2_scores = {}
for node in valid_data.keys():
    node_r2_scores[f"val_r2_{node}"] = r2_score(
        valid[node], val_output_states[node]
    )


# Save outputs
with open(os.path.join(output_dir, "scaler.pkl"), "wb") as f:
    pickle.dump(scaler, f)
val_output_states.to_csv(os.path.join(output_dir, "valid_output_states.csv"))
loss.to_csv(os.path.join(output_dir, "loss.csv"))
train.to_csv(os.path.join(output_dir, "train_data.csv"))
valid.to_csv(os.path.join(output_dir, "valid_data.csv"))
pd.DataFrame(train_inhibitors).to_csv(
    os.path.join(output_dir, "train_inhibitors.csv")
)
pd.DataFrame(valid_inhibitors).to_csv(
    os.path.join(output_dir, "valid_inhibitors.csv")
)