In [None]:
# import libraries
import BioSimSpace as BSS
import os
import sys
import glob
import csv
import numpy as np
import networkx as nx
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from sklearn.preprocessing import minmax_scale
import itertools

# Make the local `pipeline` package importable by adding its parent folder to sys.path.
print("adding code to the pythonpath...")
# NOTE(review): machine-specific absolute path; it must be edited before running elsewhere.
code = '/home/anna/Documents/code/python'
# Only insert the folder once, so re-running this cell does not grow sys.path.
if code not in sys.path:
    sys.path.insert(1, code)
import pipeline

# NOTE(review): the star imports hide where later names come from; `initialise_pipeline`
# and `validate` are presumably provided by these two modules — confirm.
from pipeline.prep import *
from pipeline.utils import *

# Display the location of the imported package, to check which copy is in use.
pipeline.__file__

In [None]:
# Initialise the pipeline object. It is configured below with the folder
# locations and is then used by the following cells via `pl`.

pl = initialise_pipeline()
# Folder containing the ligands for the pipeline. These should all be in the same folder, in sdf format.
# NOTE(review): machine-specific absolute path.
pl.ligands_folder(f"/home/anna/Documents/benchmark/inputs/tyk2/ligands")
# Folder in which the pipeline should be created.
# NOTE(review): machine-specific absolute path.
pl.main_folder("/home/anna/Documents/code/write/test")


In [None]:
# Create the default protocols on the pipeline object.
pl.setup_protocols()

# Edit the protocols here as needed, before the ligands and network are set up.

# Set up the ligands and then the perturbation network.
# NOTE(review): presumably setup_network() relies on setup_ligands() having run first — confirm.
pl.setup_ligands()
pl.setup_network()

##### <span style="color:teal">Protein parameterisation</span>  

This needs to be carried out carefully.

The protein can be parameterised using:
```python
prot = BSS.IO.readPDB(path_to_protein, pdb4amber=False)[0]
prot_p = BSS.Parameters.parameterise(prot, protocol.protein_forcefield()).getMolecule()
BSS.IO.saveMolecules("inputs/protein", prot_p, ["PRM7","RST7"])
```

tleap may fail, so it is best to prepare and parameterise the protein carefully beforehand, and to consider whether crystal waters should be included.

The parameterised protein can be viewed using:
```python
BSS.Notebook.View(f"{input_dir}/{protein}/protein/{protein}_parameterised.pdb").system()
```


In [None]:
# Add the location of the parameterised protein files to the pipeline setup object.
# NOTE(review): the path has no file extension, so it is presumably the common prefix
# of the parameterised protein files — confirm. It is also a machine-specific absolute path.
pl.protein_path(f"/home/anna/Documents/benchmark/inputs/tyk2/tyk2_parameterised")


In [None]:
# Write the run_all script. According to the original author's note, this call
# also performs a final write of the ligands and the network.
pl.write_run_all()

##### <span style="color:teal">Generating the RBFENN</span>  


First, generate the inputs that are needed to build the RBFENN network.

In [None]:
# Locations of the SEM predictions for the perturbations.
# NOTE(review): `protein` and `main_folder` are not defined in any cell shown above, so
# this cell fails on a fresh kernel unless they are defined elsewhere — confirm.
# File-name prefix of the prediction files to read. Per the original note, the alternatives
# are f"{protein}_me" or f"{protein}_rename", with the latter used for tyk2 and p38.
tgt_to_run = f"{protein}_rename" #f"{protein}_me" f"{protein}_rename" for tyk2 and p38
# Folder that holds the per-series prediction files.
cats_files_path = f"{main_folder}/scripts/RBFENN/ANALYSIS/perturbation_networks/output/series_predictions"

In [None]:
def scaleArray(arr):
    """Return the reciprocals of `arr`, min-max scaled to the range [0, 1].

    Every element is first inverted (1 / x) and the inverted values are then
    rescaled with `minmax_scale`, so for positive inputs the largest value
    ends up at 0 and the smallest at 1.
    """
    reciprocals = 1 / arr
    scaled = minmax_scale(reciprocals, feature_range=(0, 1))
    return scaled


In [None]:
# Get the FEPNN SEM prediction(s) per perturbation.
# Every matching prediction file is read as CSV: the first row is skipped
# (presumably a header), column 0 is the perturbation name and column 1 the
# predicted SEM. The files are processed in sorted order so that the result
# does not depend on the arbitrary order returned by glob.
sem_predictions = {}
for cats_file in sorted(glob.glob(f"{cats_files_path}/{tgt_to_run}_*")):

    with open(cats_file, "r", newline="") as readfile:
        reader = csv.reader(readfile)
        # Skip the first row; the default avoids a StopIteration on an empty file.
        next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields an empty list for a blank line.
                continue
            sem_predictions.setdefault(row[0], []).append(float(row[1]))

# Fail early with a clear message instead of an opaque scaling error further down.
if not sem_predictions:
    raise ValueError(
        f"No SEM predictions were read from files matching {cats_files_path}/{tgt_to_run}_*"
    )

# Compute the mean SEM prediction per perturbation.
pert_names = list(sem_predictions)
pert_sems = [float(np.mean(sems)) for sems in sem_predictions.values()]

# Invert the mean SEMs and scale them to [0-1].
pert_sems = scaleArray(np.array(pert_sems))

# Final mapping used by the following cells: perturbation name -> scaled value.
perts = dict(zip(pert_names, pert_sems))

In [None]:
# Make the folder for the RBFENN network.
validate.folder_path(f"{pl.exec_folder()}/RBFENN", create=True)


def _match_ligand_stem(ligand, stems):
    """Return the ligand file stem that corresponds to `ligand`, or None.

    An exact match on the stem is preferred, so that e.g. "lig_1" is never
    confused with "lig_10". Only when there is no exact match does this fall
    back to the first stem (in the order given) that contains `ligand` as a
    substring.
    """
    if ligand in stems:
        return ligand
    return next((stem for stem in stems if ligand in stem), None)


# Names of the ligand files without folder or extension, in sorted order so
# that the substring fallback is deterministic.
ligand_stems = sorted(
    os.path.basename(path).split(".")[0]
    for path in glob.glob(f"{pl.ligands_folder()}/*.sdf")
)

# Pairs that were already written, to avoid duplicate rows in the links file.
written = []
with open(f"{pl.exec_folder()}/RBFENN/links_file.in", "w", newline="") as writefile:
    writer = csv.writer(writefile, delimiter=" ")

    for pert_name, value in perts.items():
        # Find the lomap (file) name for both ligands of the perturbation.
        liga_lomap_name = _match_ligand_stem(pert_name.split("~")[0], ligand_stems)
        ligb_lomap_name = _match_ligand_stem(pert_name.split("~")[1], ligand_stems)

        if not liga_lomap_name or not ligb_lomap_name:
            # No row is written when a ligand file cannot be found; report it
            # so that the missing edge does not go unnoticed.
            print(f"Warning: no ligand file found for one or both ligands of {pert_name}; skipped.")
            continue

        # Exactly one row per ligand pair: ligand A, ligand B, link weight.
        if [liga_lomap_name, ligb_lomap_name] not in written:
            writer.writerow([liga_lomap_name, ligb_lomap_name, value])
            written.append([liga_lomap_name, ligb_lomap_name])

In [None]:
# Set up the network again, this time from the RBFENN links file written above.
# Per the original note, the ligands and ligand names already exist from the earlier
# (lomap) network setup.
# NOTE(review): the original note also says that if the folder name is changed it will be
# put in the execution model as the default — confirm what `folder` controls.
pl.setup_network(folder="RBFENN", links_file=f"{pl.exec_folder()}/RBFENN/links_file.in")

# Write the network to its own file. Per the original note, the setup call above updates
# the existing network held by the pipeline object.
pl.write_network(file_path=f"{pl.exec_folder()}/rbfenn_network.dat")

In [None]:
# write the rbfenn to a different network file

##### <span style="color:teal">Comparing lomap and the rbfenn</span>  


In [None]:
def _reverse_pert(pert):
    """Return the perturbation "a~b" with its two ligand names swapped ("b~a")."""
    names = pert.split("~")
    return names[1] + "~" + names[0]


def _in_network(pert, network_perts):
    """Return True if `pert` is in `network_perts` in either direction."""
    inv_pert = _reverse_pert(pert)
    return pert in network_perts or inv_pert in network_perts


# Read the perturbations of each network as "ligA~ligB" strings, built from the
# first two columns of every row of the score files.
with open(f"{exec_folder}/network_rbfenn_scores.dat", "r") as fepnn_file, \
        open(f"{exec_folder}/network_lomap_scores.dat", "r") as lomap_file:
    perts_rbfenn = [f"{row[0]}~{row[1]}" for row in csv.reader(fepnn_file)]
    perts_lomap = [f"{row[0]}~{row[1]}" for row in csv.reader(lomap_file)]

# All perturbations together: rbfenn entries first, then lomap, repeats included.
perts = perts_rbfenn + perts_lomap

# Combined perturbations: exact duplicates removed, first occurrence kept,
# and the two directions of an edge still counted as distinct.
combined_perts = list(dict.fromkeys(perts))
filtered_out = len(perts) - len(combined_perts)
print(f"Removed {filtered_out} duplicate perts between lomap and rbfenn to give {len(combined_perts)} combined perts.")

# Unique perturbations: only the first-seen direction of every edge is kept.
filtered_perts = []
kept_perts = set()
for pert in combined_perts:
    if _in_network(pert, kept_perts):
        continue
    kept_perts.add(pert)
    filtered_perts.append(pert)
filtered_out = len(combined_perts) - len(filtered_perts)
print(f"Removed {filtered_out} inverse perts to give {len(filtered_perts)} unique perts, one direction only.")

# Classify the perturbations as unique to one network or shared by both.
# An edge counts as present in a network when either direction is listed in it.
lomap_lookup = set(perts_lomap)
rbfenn_lookup = set(perts_rbfenn)

only_lomap = [pert for pert in perts_lomap if not _in_network(pert, rbfenn_lookup)]
only_rbfenn = [pert for pert in perts_rbfenn if not _in_network(pert, lomap_lookup)]
shared = [
    pert for pert in combined_perts
    if _in_network(pert, lomap_lookup) and _in_network(pert, rbfenn_lookup)
]

# (perturbation, origin) pairs: lomap-only first, then rbfenn-only, then shared.
unique_perts = (
    [(pert, "lomap") for pert in only_lomap]
    + [(pert, "rbfenn") for pert in only_rbfenn]
    + [(pert, "shared") for pert in shared]
)
unique_out_lomap = len(only_lomap)
unique_out_rbfenn = len(only_rbfenn)
shared_out = len(shared)

print(f"There are {unique_out_lomap} pert(s) unique to lomap and {unique_out_rbfenn} pert(s) unique to rbfenn.")
print(f"There are {shared_out} pert(s) shared between lomap and rbfenn")

In [None]:
# Write each perturbation list to its own file in the execution folder and
# report how many entries it holds.

# Combined perturbations, one per row.
with open(f"{exec_folder}/combined_perts.dat", "w") as combined_file:
    csv.writer(combined_file).writerows([pert] for pert in combined_perts)
print(f"Total number of combined perturbations: {len(combined_perts)}")

# Filtered (one direction only) perturbations, one per row.
with open(f"{exec_folder}/filtered_perts.dat", "w") as filtered_file:
    csv.writer(filtered_file).writerows([pert] for pert in filtered_perts)
print(f"Total number of filtered perturbations: {len(filtered_perts)}")

# Classified perturbations: the perturbation and the network label, per row.
with open(f"{exec_folder}/unique_perts.dat", "w") as unique_file:
    csv.writer(unique_file).writerows(
        [pert_name, network] for pert_name, network in unique_perts
    )
print(f"Total number of unique perturbations: {len(unique_perts)} (lomap: {unique_out_lomap}, rbfenn: {unique_out_rbfenn})")
print(f"Total number of shared perturbations: {shared_out}")

In [None]:
# Make a dictionary that labels every unique (one direction only) perturbation
# with the network(s) it belongs to: "both", "lomap" or "rbfenn". The labels
# are used to colour the edges when plotting the nx graph.
lomap_edges = set(perts_lomap)
rbfenn_edges = set(perts_rbfenn)

both_pert_networks_dict = {}

for pert in filtered_perts:

    inv_pert = pert.split("~")[1]+"~"+pert.split("~")[0]

    # An edge belongs to a network when either direction is listed in it. The
    # test is symmetric for both networks, so the label does not depend on
    # which direction of the edge happened to be kept in `filtered_perts`.
    in_lomap = pert in lomap_edges or inv_pert in lomap_edges
    in_rbfenn = pert in rbfenn_edges or inv_pert in rbfenn_edges

    if in_lomap and in_rbfenn:
        both_pert_networks_dict[pert] = "both"
    elif in_lomap:
        both_pert_networks_dict[pert] = "lomap"
    elif in_rbfenn:
        both_pert_networks_dict[pert] = "rbfenn"

# Build the lookup used to find the image of each ligand for the nx graph.
# Each input file name is split at its first underscore: the part before it is
# stored as the value, and the remainder, cut at its first ".", is the ligand
# name used as the key.
input_files_for_image = sorted(os.listdir(f"{exec_folder}/visualise_network_lomap/inputs"))

image_dict = {}
for in_file in input_files_for_image:
    lig_number, _sep, remainder = in_file.partition("_")
    lig_name = remainder.split(".")[0]
    image_dict[lig_name] = lig_number

In [None]:
# Display the perturbation -> network label mapping as a visual check.
both_pert_networks_dict

In [None]:
# Draw the lomap and rbfenn networks as one graph, with the edges coloured by
# the network(s) they belong to. Two images are written: a matplotlib drawing
# without ligand images and a pydot rendering of the same graph.
# NOTE(review): `ligand_names` and `exec_folder` are not defined in any cell shown
# above — confirm where they come from.

# Generate the graph.
graph = nx.Graph()

# Loop over the ligands and add each one as a node of the graph.
for lig in ligand_names:
    # Path of the ligand's image; raises KeyError if the ligand has no entry in image_dict.
    img = f"{exec_folder}/visualise_network_lomap/images/{image_dict[lig]}.png"
    # NOTE(review): the image/label/labelloc attributes are presumably picked up by the
    # pydot (Graphviz) rendering below, not by the matplotlib drawing — confirm.
    graph.add_node(lig, image=img, label=lig, labelloc="t")

# Edge colour for each network label.
# Other colours noted by the author: navy, teal, #CC00CC.
# A clear (transparent) colour is '#FF000000'.
colour_dict = {"both":'navy' ,"lomap":'teal' ,"rbfenn":'hotpink' }

# Loop over the perturbations in the dictionary and add each one as a coloured edge.
for edge in both_pert_networks_dict:
    graph.add_edge(edge.split("~")[0],edge.split("~")[1],
                    color=colour_dict[both_pert_networks_dict[edge]]
                    )

# Plot the networkX graph.
pos = nx.kamada_kawai_layout(graph)
# Edge colours in the order of the graph's edge attributes.
colours = nx.get_edge_attributes(graph,'color').values()

plt.figure(figsize=(12,12), dpi=150)
nx.draw(
    graph, pos, edge_color=colours, width=1, linewidths=5,
    node_size=2000, node_color='skyblue', font_size = 12,
    labels={node: node for node in graph.nodes()})

# Save the drawing without ligand images.
plt.savefig(f"{exec_folder}/compared_networks_no_images.png", dpi=300)
# plt.show()

# Convert to a dot graph.
dot_graph = nx.drawing.nx_pydot.to_pydot(graph)

# Write to a PNG.
# NOTE(review): presumably requires pydot and a Graphviz installation — confirm.
network_plot = f"{exec_folder}/compared_networks.png"
dot_graph.write_png(network_plot)

# Read the PNG back in and show it in the notebook, without axes.
img = mpimg.imread(network_plot)
plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(img)


In [None]:
# Compute a LOMAP score for every perturbation in the combined (unfiltered)
# list, stored under the key (ligand A, ligand B).
combined_pert_network_dict = {}

for pert in combined_perts:
    pert_ligands = pert.split("~")
    lig_a, lig_b = pert_ligands[0], pert_ligands[1]

    # Positions of the two ligands in the ligand list.
    lig_a_index = ligand_names.index(lig_a)
    lig_b_index = ligand_names.index(lig_b)
    pair_molecules = [ligands[lig_a_index], ligands[lig_b_index]]
    pair_names = [ligand_names[lig_a_index], ligand_names[lig_b_index]]

    # Generate a network of just these two ligands to get the score of this single edge.
    single_transformation, single_lomap_score = BSS.Align.generateNetwork(
        pair_molecules, names=pair_names, plot_network=False
    )
    print(f"LOMAP score for {lig_a} to {lig_b} is {single_lomap_score[0]} .")
    combined_pert_network_dict[(lig_a, lig_b)] = single_lomap_score[0]

In [None]:
# Write the combined network to its own network file. Each row is
#   ligand_A ligand_B number_of_lambda_windows lambda_values engine
# with one row per perturbation and engine. Only the direction stored in
# `combined_pert_network_dict` is written; the reverse direction is not.
# NOTE(review): `node`, `engine` and `exec_folder` are not defined in any cell shown
# above — confirm where they come from.

# Every perturbation gets the same number of lambda windows, so the lambda
# schedule is built once, outside the loop. (Allocating a different number of
# windows to perturbations below a LOMAP score threshold is currently disabled.)
num_lambda = node.getInput("LambdaWindows")

# Given the number of allocated lambda windows, generate the array of lambda values.
lam_array_np = np.around(np.linspace(0, 1, int(num_lambda)), decimals=5)

# Make the array into a format readable by bash: comma-separated values with
# no brackets or whitespace. The values are joined explicitly, because numpy's
# printed layout (padding, brackets, line wrapping) varies with the values.
lam_array = ",".join(str(float(lam)) for lam in lam_array_np)

# One row is written per engine when "ALL" engines are requested.
if engine == "ALL":
    engines_to_write = list(BSS.FreeEnergy.engines())
else:
    engines_to_write = [engine]

with open(f"{exec_folder}/network_combined.dat", "w", newline="") as network_file:

    writer = csv.writer(network_file, delimiter=" ")

    for lig_a, lig_b in combined_pert_network_dict:
        for eng in engines_to_write:
            writer.writerow([lig_a, lig_b, len(lam_array_np), lam_array, eng])