In [None]:
%cd ..
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

tqdm.pandas()

In [None]:
from mol_gen_docking.data.pydantic_dataset import read_jsonl
from pathlib import Path

def load(path: str):
    """Read a JSONL prompt file and return one metadata dict per record.

    The ``meta`` dict of each record's first conversation is tagged in place
    with the record identifier under ``"prompt_id"`` and returned.

    Args:
        path: Location of the JSONL file, passed to ``read_jsonl`` as a ``Path``.

    Returns:
        List of the (mutated) ``meta`` dicts, in file order.
    """
    records = read_jsonl(Path(path))
    metas = [record.conversations[0].meta for record in records]
    # Second pass over the records: attach each identifier to its meta dict.
    for meta, record in zip(metas, records):
        meta["prompt_id"] = record.identifier
    return metas

def get_full_reaction(row: pd.Series):
    """Render a multi-step reaction as text, one reaction step per line.

    Each step is written as ``"R1 + R2 -> P1 + P2"``. Steps are separated by a
    newline and the result has no trailing newline.

    Args:
        row: Mapping with ``"products"`` and ``"reactants"`` entries, each
            holding one item per reaction step. An item may be a sequence of
            SMILES strings or a single SMILES string.

    Returns:
        The formatted reaction, or ``""`` when there are no steps.
    """

    def _join_species(species):
        # A bare string is a single molecule; joining it directly would put
        # " + " between its characters.
        if isinstance(species, str):
            return species
        return " + ".join(species)

    # zip() stops at the shorter of the two lists.
    return "\n".join(
        f"{_join_species(reacs)} -> {_join_species(prods)}"
        for prods, reacs in zip(row["products"], row["reactants"])
    )

def get_n_reactions_steps(row: pd.Series):
    """Return the number of reaction steps associated with a task row.

    For the objectives listed below the count is the number of lines in
    ``row["full_reaction"]``; for every other objective it is
    ``row["idx_chosen"] + 1``.
    """
    whole_route_objectives = (
        "full synthesis",
        "product prediction",
        "full synthesis (with inter.)",
    )
    if row["type of objective"] in whole_route_objectives:
        steps = row["full_reaction"].split("\n")
        return len(steps)
    return row["idx_chosen"] + 1

def runtime_get_reactants(row: pd.Series):
    """Parse ``row["full_reaction"]`` and return the reactants of every step.

    Returns:
        One list per newline-separated step, holding the " + "-separated
        entries found before the first " -> ".
    """
    reactants_per_step = []
    for step in row["full_reaction"].split("\n"):
        left_hand_side = step.split(" -> ")[0]
        reactants_per_step.append(left_hand_side.split(" + "))
    return reactants_per_step

def runtime_get_products(row: pd.Series):
    """Parse ``row["full_reaction"]`` and return the first product of every step.

    Returns:
        One string per newline-separated step: the first " + "-separated
        entry found after the first " -> ".
    """
    first_products = []
    for step in row["full_reaction"].split("\n"):
        right_hand_side = step.split(" -> ")[1]
        first_products.append(right_hand_side.split(" + ")[0])
    return first_products

def smarts_to_image(smarts):
    """Draw the reaction described by a reaction SMARTS string.

    Args:
        smarts: Reaction SMARTS to parse.

    Returns:
        The image produced by RDKit's ``Draw.ReactionToImage`` with
        300x300 sub-images.
    """
    # Imported locally: none of the cells visible here import AllChem, and the
    # drawing helpers live in the rdkit.Chem.Draw subpackage.
    from rdkit.Chem import AllChem, Draw

    rxn = AllChem.ReactionFromSmarts(smarts)
    rxn.Initialize()
    return Draw.ReactionToImage(rxn, subImgSize=(300, 300))

def get_obj_type(obj):
    """Map a raw objective name to a human-readable objective type.

    The checks are containment tests applied in priority order; an objective
    matching none of them is returned unchanged.
    """
    if "full_path" in obj:
        if "interm" in obj:
            return "full synthesis (with inter.)"
        return "full synthesis"
    if "reactant" in obj:
        return "reactant prediction"
    if "products" in obj or "product" in obj:
        return "product prediction"
    return obj

def get_df(data_d):
    """Build the analysis frame from the list of metadata dicts.

    The "properties" and "objectives" list columns are exploded together so
    each row holds one objective, then three derived columns are added:
    "full_reaction", "type of objective" and the ordered categorical
    "reaction_steps".
    """
    frame = pd.DataFrame(data_d)
    frame = frame.explode(["properties", "objectives"]).reset_index(drop=True)

    frame["full_reaction"] = frame.apply(get_full_reaction, axis=1)
    frame["type of objective"] = frame["objectives"].apply(get_obj_type)

    # Needs both columns created above.
    n_steps = frame.apply(get_n_reactions_steps, axis=1)
    frame["reaction_steps"] = pd.Categorical(n_steps, ordered=True)
    return frame

# ZINC reference properties, with the columns renamed to the descriptor names
# used for the reaction dataset.
df_zinc = pd.read_csv("data/properties.csv").rename(
    columns={
        "QED": "qed",
        "CalcExactMolWt": "ExactMolWt",
        "CalcTPSA": "TPSA",
        "CalcNumHBA": "NumHBAcceptors",
        "CalcNumHBD": "NumHBDonors",
        "CalcNumRotatableBonds": "NumRotatableBonds",
        "CalcNumAromaticRings": "NumAromaticRings",
        "CalcHallKierAlpha": "HallKierAlpha",
    }
)

# One row per (prompt, objective) of the uncleaned synthesis tasks.
df = get_df(load("data/synthesis_tasks/train_prompts_unclean_json.jsonl"))
df

In [None]:
# Stacked histogram of the number of reaction steps, coloured by objective type.
fig,ax = plt.subplots(figsize=(6,4))
sns.histplot(df, x="reaction_steps", hue="type of objective", multiple="stack", palette="deep",ax=ax)
ax.set_xlabel("Number of reaction steps")
ax.set_ylabel("Count")

# Find distribution and prompt ids for full_path and final_products

Compute the descriptors of the final product of each full-synthesis task and compare their distribution to that of the ZINC dataset.
Then subsample the reaction dataset so that the descriptor distribution of its products matches the ZINC distribution, to avoid biasing the comparison.

## Compute descriptors for the products of the reaction dataset

In [None]:
from rdkit.Chem import Descriptors
from rdkit import Chem
# Descriptors compared between ZINC and the reaction dataset. The names match
# the renamed df_zinc columns; get_desc looks each one up on
# rdkit.Chem.Descriptors after replacing "NumHB" with "NumH".
DESCRIPTORS = [
    "qed",
    "ExactMolWt",
    "TPSA",
    "NumHBAcceptors",
    "NumHBDonors",
    "NumRotatableBonds",
    "NumAromaticRings",
    "HallKierAlpha",
]
def get_desc(smi):
    """Compute every descriptor listed in ``DESCRIPTORS`` for one molecule.

    Args:
        smi: SMILES string of the molecule.

    Returns:
        Dict keyed by the names in ``DESCRIPTORS`` with the value returned by
        the matching ``rdkit.Chem.Descriptors`` function.

    Raises:
        ValueError: If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here
        # with the offending SMILES instead of inside a descriptor call.
        raise ValueError(f"Could not parse SMILES: {smi!r}")
    # "NumHBAcceptors"/"NumHBDonors" are looked up as "NumHAcceptors"/"NumHDonors"
    # on the Descriptors module, but stored under the DESCRIPTORS name.
    return {
        desc: getattr(Descriptors, desc.replace("NumHB", "NumH"))(mol)
        for desc in DESCRIPTORS
    }

from multiprocessing import Pool

# Keep only the rows whose objective name starts with "f" -- presumably the
# "full_path*" objectives recognised by get_obj_type; TODO confirm that no
# other objective name starts with "f".
df_to_ana = df[df.objectives.apply(lambda x: x.startswith("f"))]
# Compute the descriptors of the last entry of each row's "products" list,
# spread over 8 worker processes.
with Pool(8) as p:
    descs = list(tqdm(p.imap(get_desc, [x[-1] for x in df_to_ana["products"]]), total=len(df_to_ana)))

df_descriptors = pd.DataFrame(descs)
# NOTE(review): get_desc keys its output by the DESCRIPTORS names
# ("NumHBAcceptors", "NumHBDonors"), so the source names in this mapping are
# not expected to be present and this rename looks like a no-op.
df_descriptors = df_descriptors.rename(columns={"NumHAcceptors":"NumHBAcceptors", "NumHDonors":"NumHBDonors"})

# Put task columns and descriptor columns side by side; the index of df_to_ana
# is reset so that both frames share a 0..n-1 index.
df_descriptors = pd.concat([df_to_ana.reset_index(drop=True), df_descriptors], axis=1)

In [None]:
# Descriptor distributions BEFORE subsampling: ZINC (black) against the
# products of the whole reaction dataset (blue), one panel per descriptor.
fig, axes = plt.subplots(2,4, figsize=(12,5))
axes = axes.flatten()
for i,col in enumerate(DESCRIPTORS):
    sns.histplot(df_zinc, x=col, ax=axes[i], color="black",label="ZINC", alpha=0.35, stat="density", bins=25)
    sns.histplot(
        df_descriptors, x=col, ax=axes[i],
        # No sampling has happened yet, so the legend must not say "(sampled)".
        color="blue", stat="density", label="Reaction dataset", bins= 25, alpha=0.35
    )
    axes[i].set_title(col)
    axes[i].legend()
fig.tight_layout()

## Subsample by defining sampling probabilities

In [None]:
# Temperature of the sampling weights: the ZINC/data log-probability difference
# is divided by this value before exponentiation, so larger values flatten the
# weights.
TEMPERATURE = 1.2

In [None]:
# Fit a normal distribution on every descriptor on the ZINC dataset:
# descriptor name -> (mean, std), read by get_zinc_logproba.
params_zinc = {
    desc: (df_zinc[desc].mean(), df_zinc[desc].std())
    for desc in DESCRIPTORS
}
# Fit a normal distribution on every descriptor on the products of the reaction dataset:
# descriptor name -> (mean, std), read by get_data_logproba.
params_data = {
    desc: (df_descriptors[desc].mean(), df_descriptors[desc].std())
    for desc in DESCRIPTORS
}

def _gaussian_logproba(value, descs, params):
    """Sum the normal log-densities of ``value[desc]`` over ``descs``.

    ``params[desc]`` must be a ``(mean, std)`` pair. Descriptors are treated
    as independent, so the per-descriptor log-densities are simply added.
    """
    logproba = 0
    for desc in descs:
        mean, std = params[desc]
        logproba += -0.5 * ((value[desc] - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)
    return logproba


def get_zinc_logproba(value, descs):
    """Log-probability of ``value`` under the normals fitted on ZINC.

    Args:
        value: Mapping (e.g. a DataFrame row) indexable by descriptor name.
        descs: Descriptor names to include in the sum.

    Returns:
        Sum of the per-descriptor normal log-densities, using the
        ``(mean, std)`` pairs stored in ``params_zinc``.
    """
    return _gaussian_logproba(value, descs, params_zinc)


def get_data_logproba(value, descs):
    """Log-probability of ``value`` under the normals fitted on the reaction dataset.

    Args:
        value: Mapping (e.g. a DataFrame row) indexable by descriptor name.
        descs: Descriptor names to include in the sum.

    Returns:
        Sum of the per-descriptor normal log-densities, using the
        ``(mean, std)`` pairs stored in ``params_data``.
    """
    return _gaussian_logproba(value, descs, params_data)

# Log-probability of every product under the ZINC fit and under the dataset fit.
df_descriptors["zinc_logproba"] = df_descriptors.progress_apply(lambda x: get_zinc_logproba(x, DESCRIPTORS), axis=1)
df_descriptors["data_logproba"] = df_descriptors.progress_apply(lambda x: get_data_logproba(x, DESCRIPTORS), axis=1)

# Tempered importance weights exp((log p_zinc - log p_data) / T), computed on
# whole columns instead of row by row.
log_weights = (df_descriptors["zinc_logproba"] - df_descriptors["data_logproba"]) / TEMPERATURE
# Subtracting the maximum before exponentiating leaves the normalised
# probabilities unchanged but keeps exp() from overflowing to inf, which would
# turn the normalised column into NaN.
df_descriptors["sample_proba"] = np.exp(log_weights - log_weights.max())
df_descriptors["sample_proba"] = df_descriptors["sample_proba"] / df_descriptors["sample_proba"].sum()
df_descriptors

In [None]:
# Draw the full-synthesis prompts to keep, without replacement, weighted by
# sample_proba. A seeded generator makes the selection (and therefore the
# dataset written at the end of the notebook) reproducible across runs.
SAMPLING_SEED = 0
N_FULL_SYNTHESIS_SAMPLES = 50000
rng = np.random.default_rng(SAMPLING_SEED)
chosen_pids = rng.choice(
    df_descriptors.prompt_id.to_numpy(),
    size=N_FULL_SYNTHESIS_SAMPLES,
    replace=False,
    p=df_descriptors.sample_proba.to_numpy(),
)

In [None]:
# Descriptor distributions AFTER subsampling: ZINC (black) against the products
# of the selected prompts (blue), one panel per descriptor.
# The selection does not depend on the descriptor, so filter once outside the loop.
df_descriptors_sampled = df_descriptors[df_descriptors.prompt_id.isin(chosen_pids)]

fig, axes = plt.subplots(2,4, figsize=(12,5))
axes = axes.flatten()
for i,col in enumerate(DESCRIPTORS):
    sns.histplot(df_zinc, x=col, ax=axes[i], color="black",label="ZINC", alpha=0.35, stat="density", bins=25)
    sns.histplot(
        df_descriptors_sampled, x=col, ax=axes[i],
        color="blue", stat="density", label="Reaction dataset (sampled)", bins= 25, alpha=0.35
    )
    axes[i].set_title(col)
    axes[i].legend()
fig.tight_layout()

# Apply modifications

In [None]:
# Drop the names of the large intermediate frames; only df and chosen_pids are
# used from here on. Cells above that read these names must be re-run first if
# they are executed again.
del df_descriptors
del df_zinc

In [None]:
# Seed both draws so the cleaned dataset is reproducible across runs.
SUBSAMPLE_SEED = 0

# Keep 70% of the rows whose objective does not start with "f" ...
other_tasks_sub_sample = (
    df[~df.objectives.apply(lambda x: x.startswith("f"))]
    .sample(frac=0.7, random_state=SUBSAMPLE_SEED)
    .prompt_id.tolist()
)
# ... merge them with the prompts selected by the weighted sampling, then draw
# 50000 rows from the union. Boolean indexing already returns a new frame, so
# no explicit copy of df is needed.
kept_pids = other_tasks_sub_sample + chosen_pids.tolist()
df_clean = df[df.prompt_id.isin(kept_pids)].sample(50000, random_state=SUBSAMPLE_SEED)
df_clean

In [None]:
# Share of each objective type among the rows of the cleaned dataset.
df_clean.groupby(["type of objective"]).size() / df_clean.shape[0]

In [None]:
from mol_gen_docking.data.pydantic_dataset import read_jsonl, write_jsonl
from pathlib import Path

# Re-read the unfiltered prompts and keep only those whose identifier was
# selected in df_clean (each identifier once).
prompt_ids = df_clean.prompt_id.unique().tolist()
current_data = read_jsonl(Path("data/synthesis_tasks/train_prompts_unclean_json.jsonl"))

# Index the prompts by identifier for direct lookup.
clean_data_dict = {line.identifier: line for line in tqdm(current_data)}
# Selected identifiers that are missing from the file are skipped silently.
clean_data = [clean_data_dict[pid] for pid in prompt_ids if pid in clean_data_dict]

In [None]:
# Write the selected prompts to the cleaned training file.
write_jsonl(Path("data/synthesis_tasks/train_prompts_json.jsonl"), clean_data)