In [None]:
# Work from the parent directory; later cells read relative paths such as "data/properties.csv".
# NOTE(review): `%cd ..` is relative, so re-running this cell moves up one more level each time.
%cd ..
# Load IPython's autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
from rdkit.Chem import AllChem
import seaborn as sns

# Register tqdm's pandas integration.
tqdm.pandas()

# Output directory for every figure saved by this notebook; created if it does not exist yet.
FIG_PATH = "../-Philippe-MolGenDocking/Figures/reaction_data_comparison"
os.makedirs(FIG_PATH, exist_ok=True)

# Prompts

In [None]:
from mol_gen_docking.data.pydantic_dataset import read_jsonl
from pathlib import Path

def load(path:str):
    """Read a JSONL prompt file and return the metadata of each entry.

    Parameters
    ----------
    path : str
        Location of the JSONL file; wrapped in ``Path`` before being given to ``read_jsonl``.

    Returns
    -------
    list
        The ``meta`` attribute of the first conversation of every entry, in file order.
    """
    metas = []
    for entry in read_jsonl(Path(path)):
        # Only the first conversation of each entry is used.
        metas.append(entry.conversations[0].meta)
    return metas

def get_n_reactions_steps(row: pd.Series):
    """Return the number of reaction steps associated with one prompt row.

    For "full synthesis" and "product prediction" rows the count is the number
    of newline-separated lines in ``row["full_reaction"]``; for every other
    objective it is ``row["idx_chosen"] + 1``.
    """
    whole_route_objectives = ("full synthesis", "product prediction")
    if row["type of objective"] in whole_route_objectives:
        return len(row["full_reaction"].split("\n"))
    return row["idx_chosen"] + 1

def runtime_get_reactants(row: pd.Series):
    """Split ``row["full_reaction"]`` into the reactant lists of each step.

    Each newline-separated step is expected to look like ``"A + B -> C"``;
    the text before ``" -> "`` is split on ``" + "``.

    Returns
    -------
    list of list of str
        One list of reactant strings per step, in step order.
    """
    reactants = []
    for step in row["full_reaction"].split("\n"):
        left_hand_side = step.split(" -> ")[0]
        reactants.append(left_hand_side.split(" + "))
    return reactants

def runtime_get_products(row: pd.Series):
    """Extract the first product of every step in ``row["full_reaction"]``.

    Each newline-separated step is split on ``" -> "``; from the right-hand
    side only the part before the first ``" + "`` is kept.

    Returns
    -------
    list of str
        One product string per step, in step order.

    Raises
    ------
    IndexError
        If a step does not contain ``" -> "``.
    """
    products = []
    for step in row["full_reaction"].split("\n"):
        right_hand_side = step.split(" -> ")[1]
        products.append(right_hand_side.split(" + ")[0])
    return products

def get_df(data_d):
    """Build the per-prompt analysis frame from loaded metadata records.

    The records are turned into a DataFrame, the ``properties`` and
    ``objectives`` columns are exploded together, and these columns are added:

    - ``"last molecule passes filter"``: last element of ``pass_filters``
    - ``"prop. of molecules passing filter"``: mean of ``pass_filters``
    - ``"type of objective"``: coarse label derived from ``objectives``
    - ``"all_reactants"`` / ``"all_products"``: parsed from ``full_reaction``

    Parameters
    ----------
    data_d
        Anything ``pd.DataFrame`` accepts (presumably the list returned by
        ``load`` — verify against caller).

    Returns
    -------
    pd.DataFrame
    """

    def _objective_type(objective):
        # Checks run in this order; an objective matching none is kept as-is.
        if "full_path" in objective:
            return "full synthesis"
        if "reactant" in objective:
            return "reactant prediction"
        if "products" in objective:
            return "product prediction"
        if "product" in objective:
            return "product prediction"
        return objective

    frame = pd.DataFrame(data_d)
    frame = frame.explode(["properties", "objectives"]).reset_index(drop=True)

    frame["last molecule passes filter"] = frame.pass_filters.apply(lambda flags: flags[-1])
    frame["prop. of molecules passing filter"] = frame.pass_filters.apply(lambda flags: np.mean(flags))
    frame["type of objective"] = frame.objectives.apply(_objective_type)

    # Row-wise parsing of the reaction text into per-step reactants and products.
    frame["all_reactants"] = frame.apply(runtime_get_reactants, axis=1)
    frame["all_products"] = frame.apply(runtime_get_products, axis=1)
    return frame

# Reference property table; its columns are renamed to the names used for the
# descriptor columns plotted further down (e.g. "qed", "ExactMolWt").
df_zinc = pd.read_csv("data/properties.csv").rename(
    columns={
        "QED": "qed",
        "CalcExactMolWt": "ExactMolWt",
        "CalcTPSA": "TPSA",
        "CalcNumHBA": "NumHAcceptors",
        "CalcNumHBD": "NumHDonors",
        "CalcNumRotatableBonds": "NumRotatableBonds",
        "CalcNumAromaticRings": "NumAromaticRings",
        "CalcHallKierAlpha": "HallKierAlpha",
    }
)

# Load every synthesis prompt file. The part of the file name before the first
# "." must end in "_<n_reaction_retry>_<n_bb_retry>"; both values are stored as
# integer columns on the loaded frame.
all_dfs = []
# os.listdir returns entries in arbitrary order; sorting keeps the row order of
# `df` reproducible between runs and machines.
for jsonl_file in sorted(os.listdir("data/synthesis")):
    # Skip the training prompts and any entry that is not a JSONL file (for
    # example a checkpoint folder), which would break the integer parsing below.
    if jsonl_file == "train_prompts.jsonl" or not jsonl_file.endswith(".jsonl"):
        continue
    print(jsonl_file)
    name_parts = jsonl_file.split(".")[0].split("_")
    n_reaction_retry = name_parts[-2]
    n_bb_retry = name_parts[-1]
    df_file = get_df(load("data/synthesis/" + jsonl_file))
    print(df_file.shape)
    df_file["n_reaction_retry"] = int(n_reaction_retry)
    df_file["n_bb_retry"] = int(n_bb_retry)
    all_dfs.append(df_file)

df = pd.concat(all_dfs).reset_index(drop=True)
# Number of reaction steps per prompt, stored as an ordered categorical.
n_steps_per_prompt = df.apply(get_n_reactions_steps, axis=1)
df["reaction_steps"] = pd.Categorical(n_steps_per_prompt, ordered=True)

# Long-format view with one row per reaction step: the per-prompt list columns
# are exploded together, and "n_step" holds the 1-based position of the step.
df_per_step = (
    df[["all_products", "all_reactants", "or_smarts", "pass_filters", "n_reaction_retry", "n_bb_retry"]]
    .rename(columns={"all_reactants": "reactants", "all_products": "products"})
)
df_per_step["n_step"] = df_per_step.products.apply(
    lambda step_products: list(range(1, len(step_products) + 1))
)
df_per_step = df_per_step.explode(
    ["reactants", "products", "or_smarts", "pass_filters", "n_step"]
).reset_index(drop=True)
# Number of reactants of each step, as an ordered categorical.
df_per_step["n_reactants"] = pd.Categorical(df_per_step.reactants.apply(len), ordered=True)

# With default notebook settings only the final expression of the cell is rendered.
df_per_step


df

In [None]:
import matplotlib.patches as mpatches
# One fixed colour per objective type, shared by all facets and by the legend.
objective_types = df["type of objective"].unique()
objective_colors = sns.color_palette("colorblind", n_colors=df["type of objective"].nunique())
palette = dict(zip(objective_types, objective_colors))

# Grid of histograms: rows are n_bb_retry values, columns are n_reaction_retry values.
bb_retry_levels = sorted(df["n_bb_retry"].unique())
reaction_retry_levels = sorted(df["n_reaction_retry"].unique())
g = sns.FacetGrid(
    df,
    row="n_bb_retry",
    col="n_reaction_retry",
    row_order=bb_retry_levels,
    col_order=reaction_retry_levels,
    palette=palette,
    sharex=True,
    sharey=True,
    margin_titles=True,
)

g.map_dataframe(
    sns.histplot,
    x="reaction_steps",
    hue="type of objective",
    multiple="stack",
    palette=palette,
    hue_order=list(palette),
)

# Build a single figure-level legend by hand from the palette.
handles = [
    mpatches.Patch(color=objective_color, label=objective)
    for objective, objective_color in palette.items()
]
fig = g.figure
fig.legend(handles=handles, title="type of objective", loc="upper right", bbox_to_anchor=(1.5, 0.5))
fig.tight_layout()
fig.savefig(os.path.join(FIG_PATH, "reaction_steps_histogram.pdf"), bbox_inches="tight")

In [None]:
# Histogram of the number of reactants per reaction step, one facet per retry
# setting, with bars stacked by step index.
# The facet orders are taken from the plotted frame itself. No palette is given
# to FacetGrid: the grid has no hue variable, and the objective-type palette
# from the previous cell does not describe "n_step". Colours are set by the
# histplot call below.
g = sns.FacetGrid(
    df_per_step,
    row="n_bb_retry",
    col="n_reaction_retry",
    row_order=sorted(df_per_step["n_bb_retry"].unique()),
    col_order=sorted(df_per_step["n_reaction_retry"].unique()),
    sharex=True,
    sharey=True,
    margin_titles=True,
)

g.map_dataframe(sns.histplot, x="n_reactants", hue="n_step", multiple="stack", palette="viridis")


g.figure.tight_layout()
# Export is currently disabled:
# g.fig.savefig(os.path.join(FIG_PATH,"n_reactants_histogram.pdf"), bbox_inches='tight')

In [None]:
def plot_descriptor_distributions(df_descriptors, df_zinc, col="ExactMolWt", xlim=None):
    """Plot a descriptor's distribution per retry setting against a reference.

    Builds a grid with one row per ``n_bb_retry`` value and one column per
    ``n_reaction_retry`` value (both sorted). Each facet shows a stacked density
    histogram of ``col`` coloured by ``reaction_steps``, with a black KDE of
    ``df_zinc[col]`` mapped onto the grid.

    Parameters
    ----------
    df_descriptors : pd.DataFrame
        Must contain ``col``, ``reaction_steps``, ``n_bb_retry`` and
        ``n_reaction_retry``.
    df_zinc : pd.DataFrame
        Reference frame; must contain ``col``.
    col : str
        Descriptor column to plot.
    xlim : tuple or None
        Forwarded to ``sns.FacetGrid``.

    Returns
    -------
    sns.FacetGrid
        The populated grid, so the caller can save or tweak it.
    """
    bb_retry_levels = sorted(df_descriptors["n_bb_retry"].unique())
    reaction_retry_levels = sorted(df_descriptors["n_reaction_retry"].unique())

    grid = sns.FacetGrid(
        df_descriptors,
        row="n_bb_retry",
        col="n_reaction_retry",
        row_order=bb_retry_levels,
        col_order=reaction_retry_levels,
        sharex=True,
        sharey=True,
        xlim=xlim,
        margin_titles=True,
    )

    histogram_options = dict(
        x=col,
        hue="reaction_steps",
        palette="viridis",
        stat="density",
        multiple="stack",
    )
    grid.map_dataframe(sns.histplot, **histogram_options)

    # The reference curve uses df_zinc, not the facet's subset of df_descriptors.
    grid.map(sns.kdeplot, data=df_zinc, x=col, color="black")
    return grid

In [None]:
from rdkit.Chem import Descriptors
from multiprocessing import Pool

def get_desc(smi):
    """Compute the RDKit descriptors of one molecule given as SMILES.

    Parameters
    ----------
    smi : str
        SMILES string of the molecule.

    Returns
    -------
    The result of ``Descriptors.CalcMolDescriptors`` for the parsed molecule;
    the calling cells collect these into the rows of a DataFrame.
    """
    # Imported locally: at module level `Chem` is only imported in a later cell,
    # so on a fresh kernel this function would otherwise raise NameError.
    from rdkit import Chem

    # NOTE(review): the parse result is passed on unchecked — confirm how
    # unparsable SMILES should be treated.
    mol = Chem.MolFromSmiles(smi)
    return Descriptors.CalcMolDescriptors(mol)


# Descriptors of the last entry of each prompt's "products" value, computed in
# 32 worker processes.
# NOTE(review): this reads the raw "products" column, not "all_products" —
# confirm its last element is the final product SMILES.
final_products = [x[-1] for x in df["products"]]
with Pool(32) as pool:
    descs = list(tqdm(pool.imap(get_desc, final_products), total=len(df)))

df_descriptors = pd.DataFrame(descs)
# Copy the grouping columns over from `df`.
for grouping_column in ("reaction_steps", "n_reaction_retry", "n_bb_retry"):
    df_descriptors[grouping_column] = df[grouping_column]


In [None]:
# Exact molecular weight of the final products, restricted to 0-600 on the x axis.
col = "ExactMolWt"
g = plot_descriptor_distributions(df_descriptors, df_zinc, col=col, xlim=(0, 600))
figure_file = os.path.join(FIG_PATH, f"{col}_FacetGrid.pdf")
g.figure.savefig(figure_file, bbox_inches="tight")

In [None]:
# QED of the final products, restricted to 0-1 on the x axis.
col = "qed"
g = plot_descriptor_distributions(df_descriptors, df_zinc, col=col, xlim=(0, 1))
figure_file = os.path.join(FIG_PATH, f"{col}_FacetGrid.pdf")
g.figure.savefig(figure_file, bbox_inches="tight")

In [None]:
# Descriptors of the product of every individual reaction step, computed in
# 32 worker processes.
with Pool(32) as pool:
    descs = list(tqdm(pool.imap(get_desc, df_per_step["products"]), total=len(df_per_step)))

df_descriptors_step = pd.DataFrame(descs)
# The step index is stored under "reaction_steps", the hue column the plotting helper expects.
for target_column, source_column in (
    ("reaction_steps", "n_step"),
    ("n_reaction_retry", "n_reaction_retry"),
    ("n_bb_retry", "n_bb_retry"),
):
    df_descriptors_step[target_column] = df_per_step[source_column]

In [None]:
# Exact molecular weight of every per-step product, restricted to 0-600 on the x axis.
col = "ExactMolWt"
# Keep the returned grid: saving through a stale `g` would export the figure of
# an earlier cell under the "_step" file name.
g = plot_descriptor_distributions(df_descriptors_step, df_zinc, col=col, xlim=(0, 600))
g.figure.savefig(os.path.join(FIG_PATH, f"{col}_FacetGrid_step.pdf"), bbox_inches="tight")

In [None]:
# QED of every per-step product, restricted to 0-1 on the x axis.
col = "qed"
# Keep the returned grid: saving through a stale `g` would export the figure of
# an earlier cell under the "_step" file name.
g = plot_descriptor_distributions(df_descriptors_step, df_zinc, col=col, xlim=(0, 1))
g.figure.savefig(os.path.join(FIG_PATH, f"{col}_FacetGrid_step.pdf"), bbox_inches="tight")

In [None]:
from rdkit import Chem
from rdkit import DataStructs

# Column-wise accumulator for the long-format similarity table; every key
# starts with its own empty list.
df_simi = {
    column: []
    for column in ("n_reaction_retry", "n_bb_retry", "n_steps", "similarity", "quantile")
}
# 100 fractions from 1e-1 down to 1e-4, evenly spaced in the exponent.
QUANTILES = [10**-exponent for exponent in np.linspace(1., 4, 100)]

def agg_tanimoto_sim(sub_df):
    """Per-molecule upper quantiles of pairwise Tanimoto similarity.

    Every SMILES in ``sub_df.products`` is turned into a 1024-bit Morgan
    fingerprint (radius 2). A symmetric similarity matrix is built over all
    products, and for each product the ``1 - q`` quantile of its row is taken
    for every ``q`` in ``QUANTILES``.

    Parameters
    ----------
    sub_df : pd.DataFrame
        Frame with a ``products`` column of SMILES strings.

    Returns
    -------
    list of np.ndarray, or float
        One array per entry of ``QUANTILES``, in the same order; each array has
        one value per product. If ``sub_df`` holds at most one product the
        scalar ``1.0`` is returned instead, so callers must skip such groups
        before indexing the result.

    Notes
    -----
    The diagonal of the matrix is left at 0, so each row's quantile includes
    one zero entry for the molecule itself.
    """
    fps = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in sub_df.products]
    n = len(fps)
    if n <= 1:
        return 1.0
    sims = np.zeros((n, n))
    # Fill the upper triangle one row at a time with RDKit's bulk routine
    # (one call per row instead of one per pair) and mirror it below.
    for i in range(n - 1):
        row_sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[i + 1:])
        sims[i, i + 1:] = row_sims
        sims[i + 1:, i] = row_sims
    # A single quantile call for all levels; the result has shape
    # (len(QUANTILES), n) and is returned as a list of rows.
    levels = [1 - q for q in QUANTILES]
    return list(np.quantile(sims, levels, axis=1))

# One progress tick per (n_reaction_retry, n_bb_retry, n_step) combination.
n_combinations = (
    df_per_step.n_step.nunique()
    * df_per_step.n_reaction_retry.nunique()
    * df_per_step.n_bb_retry.nunique()
)
pbar = tqdm(total=n_combinations)

for n_reaction_retry in df_per_step["n_reaction_retry"].unique():
    for n_bb_retry in df_per_step["n_bb_retry"].unique():
        in_setting = (df_per_step.n_reaction_retry == n_reaction_retry) & (df_per_step.n_bb_retry == n_bb_retry)
        df_filtered = df_per_step[in_setting]
        for n_steps in df_per_step.n_step.unique():
            sub_df = df_filtered[df_filtered.n_step == n_steps]
            # Groups with fewer than two products have no pairwise similarity to report.
            if sub_df.shape[0] > 1:
                sims = agg_tanimoto_sim(sub_df)
                for q, quantile_sim in zip(QUANTILES, sims):
                    # One output row per molecule; the group labels are repeated for each.
                    n_values = len(quantile_sim)
                    df_simi["n_steps"].extend([n_steps] * n_values)
                    df_simi["similarity"].extend(quantile_sim)
                    df_simi["quantile"].extend([q] * n_values)
                    df_simi["n_bb_retry"].extend([n_bb_retry] * n_values)
                    df_simi["n_reaction_retry"].extend([n_reaction_retry] * n_values)
            pbar.update(1)
pbar.close()
df_simi = pd.DataFrame(df_simi)

In [None]:
# Similarity-versus-quantile curves, one facet per retry setting (facet order taken from `df`).
g = sns.FacetGrid(
    df_simi,
    row="n_bb_retry",
    col="n_reaction_retry",
    row_order=sorted(df["n_bb_retry"].unique()),
    col_order=sorted(df["n_reaction_retry"].unique()),
    sharex=True,
    sharey=True,
    margin_titles=True,
    xlim=(10**-4, 0.1),
)

# One viridis colour per step count, assigned in increasing step order.
step_levels = sorted(df_simi.n_steps.unique())
step_colors = sns.color_palette("viridis", n_colors=df_simi.n_steps.nunique())
palette = dict(zip(step_levels, step_colors))

g.map_dataframe(
    sns.lineplot,
    x="quantile",
    y="similarity",
    hue="n_steps",
    palette=palette,
    errorbar=None,
    alpha=0.9,
)

# g.map_dataframe(sns.lineplot, x="quantile", y="similarity",  color="black",errorbar=None)

g.set(xscale="log", xlabel="top-% of most similar molecules")

# Build a single figure-level legend by hand from the palette.
handles = [
    mpatches.Patch(color=step_color, label=step)
    for step, step_color in palette.items()
]
fig = g.figure
fig.legend(handles=handles, title="Step", loc="upper right", bbox_to_anchor=(1.5, 0.5))

fig.savefig(os.path.join(FIG_PATH, "tanimoto_similarity_per_reaction_step.pdf"), bbox_inches="tight")

In [None]:
# Representing all reactions




