In [None]:
import numpy as np
from matplotlib import pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
from openfe_benchmarks import hif2a
from kartograf import KartografAtomMapper
from konnektor.visualization import draw_ligand_network
from openfe.setup.atom_mapping.lomap_scorers import default_lomap_score

# Load the HIF2A benchmark system and keep every ligand whose name is not excluded.
system = hif2a.get_system()
excluded_ligands = ("lig_2", "lig_3", "lig_4", "lig_7")
compounds = [c for c in system.ligand_components if c.name not in excluded_ligands]

# Show the kept ligands as a grid image. `Draw` is imported explicitly in the
# import cell: `from rdkit import Chem` alone does not load the Draw submodule.
Draw.MolsToGridImage([c.to_rdkit() for c in compounds])

In [None]:
from kartograf.atom_mapper import KartografAtomMapper

# Take the first suggested Kartograf mapping for every ordered ligand pair
# (each ligand is also paired with itself).
mapper = KartografAtomMapper()

mappings = [
    next(mapper.suggest_mappings(ligand_a, ligand_b))
    for ligand_a in compounds
    for ligand_b in compounds
]

mappings



# Radial Network Layout

In [None]:
from konnektor.network_planners import RadialLigandNetworkPlanner
# Radial planner: fresh Kartograf mapper, edges scored with default_lomap_score.
ligand_network_planner = RadialLigandNetworkPlanner(
    mapper=KartografAtomMapper(),
    scorer=default_lomap_score,
)

In [None]:
# Run the currently bound planner on the filtered ligands and label the result.
radial_network = ligand_network_planner(compounds)
radial_network.name = "Radial"
radial_network

In [None]:
# Draw the radial network, save it as "<network name>_Network.png", then display it.
fig = draw_ligand_network(radial_network, title="Radial Graph")
fig.savefig(f"{radial_network.name}_Network.png")
fig.show()

## Starry Sky Network Layout

In [None]:
from konnektor.network_planners import StarrySkyLigandNetworkPlanner
# Starry-sky planner: fresh Kartograf mapper, default_lomap_score scorer,
# and a target node connectivity of 3.
ligand_network_planner = StarrySkyLigandNetworkPlanner(
    mapper=KartografAtomMapper(),
    scorer=default_lomap_score,
    target_node_connectivity=3,
)

In [None]:
# Run the currently bound planner on the filtered ligands and label the result.
starry_sky_network = ligand_network_planner(compounds)
starry_sky_network.name = "Starry Sky"
starry_sky_network

# Minimal Spanning Tree

In [None]:
from konnektor.network_planners import MinimalSpanningTreeLigandNetworkPlanner

# Minimal-spanning-tree planner: fresh Kartograf mapper, default_lomap_score scorer.
ligand_network_planner = MinimalSpanningTreeLigandNetworkPlanner(
    mapper=KartografAtomMapper(),
    scorer=default_lomap_score,
)

In [None]:
# Run the currently bound planner on the filtered ligands and label the result.
mst_network = ligand_network_planner(compounds)
mst_network.name = "MST"
mst_network

In [None]:
# Draw the MST network, save it to "MST_Network.png", then display it.
fig = draw_ligand_network(mst_network, "MST")
fig.savefig("MST_Network.png")
fig.show()

# Cyclic Graphs

In [None]:
from konnektor.network_planners import CyclicLigandNetworkPlanner
# Cyclic planner: fresh Kartograf mapper, default_lomap_score scorer,
# configured with cycle_sizes=3 and node_present_in_cycles=2.
ligand_network_planner = CyclicLigandNetworkPlanner(
    mapper=KartografAtomMapper(),
    scorer=default_lomap_score,
    cycle_sizes=3,
    node_present_in_cycles=2,
)

In [None]:
# Run the currently bound planner on the filtered ligands and label the result.
cyclic_network = ligand_network_planner(compounds)
cyclic_network.name = "Cyclic"
cyclic_network

In [None]:
# Draw the cyclic network, save it as "<network name>_Network.png", then display it.
fig = draw_ligand_network(cyclic_network)
fig.savefig(f"{cyclic_network.name}_Network.png")
fig.show()

# Maximally connected network

In [None]:
from konnektor.network_planners import MaximalNetworkPlanner
# Maximal planner: fresh Kartograf mapper, default_lomap_score scorer.
ligand_network_planner = MaximalNetworkPlanner(
    mapper=KartografAtomMapper(),
    scorer=default_lomap_score,
)

In [None]:
# Run the currently bound `ligand_network_planner` (assigned in the preceding
# cell) on the filtered ligands and label the result "Max"; the label is reused
# for the saved file name and the summary panel title.
max_network = ligand_network_planner(compounds)
max_network.name = "Max"
max_network

In [None]:
# Draw the maximal network, save it as "<network name>_Network.png", then display it.
fig = draw_ligand_network(max_network)
fig.savefig(f"{max_network.name}_Network.png")
fig.show()

# Diversity Cluster Network

In [None]:
from konnektor.network_planners import DiversityNetworkPlanner
from sklearn.cluster import KMeans
# Diversity planner: fresh Kartograf mapper, default_lomap_score scorer.
ligand_network_planner = DiversityNetworkPlanner(
    mapper=KartografAtomMapper(),
    scorer=default_lomap_score,
)

In [None]:
# Run the currently bound `ligand_network_planner` (assigned in the preceding
# cell) on the filtered ligands and label the result "Diversity Cluster"; the
# label is reused for the saved file name and the summary panel title.
div_network = ligand_network_planner(compounds)
div_network.name = "Diversity Cluster"
div_network

In [None]:
# Draw the diversity network, save it as "<network name>_Network.png", then display it.
fig = draw_ligand_network(div_network)
fig.savefig(f"{div_network.name}_Network.png")
fig.show()

# Summary

In [None]:
# Summary figure: one panel per planned network on a 3-row x 2-column grid.
# `plt` comes from the notebook's top import cell, so this cell runs on a fresh kernel.
fig, axes = plt.subplots(ncols=2, nrows=3, figsize=[9 * 2, 3 * 9])
axes = np.asarray(axes).ravel()

# Passed to draw_ligand_network as `fontsize`, as in the all-systems cell further down.
fs = 22
summary_networks = [max_network, radial_network, mst_network, cyclic_network, div_network]
for ax, net in zip(axes, summary_networks):
    draw_ligand_network(network=net, title=net.name, ax=ax, node_size=1500, fontsize=fs)

# Hide the axes of every panel, including the sixth one that has no network.
for ax in axes:
    ax.axis("off")



In [None]:
# Disabled export of the figure currently held in `fig` to
# ../.img/network_layouts.png at 400 dpi; uncomment the next line to write the file.
#fig.savefig("../.img/network_layouts.png", dpi=400)

In [None]:
# List the names exposed by the openfe_benchmarks package; the benchmark-system
# modules imported in the cell below are taken from this package.
import openfe_benchmarks
dir(openfe_benchmarks)

## All benchmark systems in one cell

In [None]:
import logging
import datetime
import numpy as np
from matplotlib import pyplot as plt
from openfe_benchmarks import benzenes, hif2a, tyk2, p38, ptp1b, tnsk2, thrombin, cmet
from kartograf import KartografAtomMapper
from kartograf.atom_mapper import log
log.setLevel(logging.ERROR)

from konnektor.network_connecting_algorithms.bipartite_MST_connect import log
log.setLevel(logging.ERROR)
from kartograf.atom_align import align_mol_shape
from konnektor.visualization import draw_ligand_network
from openfe.setup.atom_mapping.lomap_scorers import default_lomap_score

# Plan, time, score and plot five network layouts for each benchmark system.

# Planner classes are imported once here rather than on every loop iteration.
from konnektor.network_planners import (MaximalNetworkPlanner, RadialLigandNetworkPlanner,
                                        MinimalSpanningTreeLigandNetworkPlanner,
                                        CyclicLigandNetworkPlanner, DiversityNetworkPlanner)

# Benchmark-system modules; each is queried through its get_system().
tset = [benzenes, hif2a, tyk2, p38, ptp1b, tnsk2, thrombin, cmet]

# Ligand names removed from every system before planning.
excluded_ligands = ("lig_2", "lig_3", "lig_4", "lig_7")

# Planner classes and the short labels used in the printed table and panel titles.
networkers = [MaximalNetworkPlanner,
              RadialLigandNetworkPlanner,
              MinimalSpanningTreeLigandNetworkPlanner,
              CyclicLigandNetworkPlanner, DiversityNetworkPlanner]
network_names = ["Max", "Radial", "MST", "Cyclic", "Div"]

# Columns: planner label, system name, ligand count, planning time in seconds,
# summed edge score. The last two specs are "{:>6}" (right-align, width 6);
# the former "{:6>}" meant fill character "6" with no width, i.e. no padding.
row_format = "{:<6}\t{:>12}\t{:>4}\t\t{:>6}\t{:>6}"

for ts in tset:
    s = ts.get_system()
    compounds = [c for c in s.ligand_components if c.name not in excluded_ligands]

    networks = []
    for networker_cls, name in zip(networkers, network_names):
        try:
            start_time = datetime.datetime.now()
            networker = networker_cls(mapper=KartografAtomMapper(), scorer=default_lomap_score)
            # NOTE(review): presumably silences the planner's progress display - confirm.
            networker.progress = False
            network = networker.generate_ligand_network(compounds)
            end_time = datetime.datetime.now()
            network.name = name

            networks.append(network)
            duration = end_time - start_time
            total_score = sum(e.annotations["score"] for e in network.edges)

            # total_seconds() keeps fractions and days, which `.seconds` drops.
            print(row_format.format(name, s.system_name, str(len(compounds)),
                                    round(duration.total_seconds(), 1),
                                    np.round(total_score, 2)))
        except Exception as err:
            # Best effort: report which planner failed and why, then move on to
            # the next planner. KeyboardInterrupt is no longer swallowed.
            print("{:<6}\tFAILED: {!r}".format(name, err))

    # Visualize: one panel per successfully planned network.
    fig, axes = plt.subplots(ncols=2, nrows=3, figsize=[16, 3 * 9])
    axes = np.asarray(axes).ravel()
    fs = 22
    for ax, net in zip(axes, networks):
        draw_ligand_network(network=net, title=net.name, ax=ax, node_size=1500, fontsize=fs)

    # Hide the axes of every panel, so panels left empty by a failed planner
    # do not show a bare frame.
    for ax in axes:
        ax.axis("off")

    #fig.show()
    fig.savefig(s.system_name+"_networks.png", dpi=400)