In [None]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Move the working directory one level up; the relative paths and project imports
# below are resolved from there.
# NOTE(review): re-running this cell moves up one more level each time.
%cd ..

In [None]:
# Directory in which the plotting cells below save their PDF figures.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
FIG_PATH = "/home/philippe/-Philippe-MolGenDocking/Figures/results"

In [None]:
from pathlib import Path
import json
import re

import pandas as pd
import numpy as np
from tqdm import tqdm

import seaborn as sns
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit import DataStructs

from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, fcluster

from notebooks.utils import PandasTableFormatter
from mol_gen_docking.reward.diversity_aware_top_k import diversity_aware_top_k

In [None]:
# Root folder holding one sub-directory per evaluation run.
MOLSTRAL_PATH = Path("MolGenOutput/test_ood")

# Collect every "*scored.jsonl" file one level below the root, skipping any path
# that mentions "error". Entries of the root that are not directories (stray
# files such as a README or .DS_Store) are ignored instead of crashing iterdir().
files = sorted(
    f
    for d in MOLSTRAL_PATH.iterdir()
    if d.is_dir()
    for f in d.iterdir()
    if "error" not in str(f) and str(f).endswith("scored.jsonl")
)

In [None]:
# Flatten every scored generation (one JSON object per line) into a record.
generations = []
for f in tqdm(files):
    # The model name is encoded in the file name and does not depend on the line:
    # take what precedes "eval" (minus the separator character); if that prefix
    # still contains "scored", take what precedes "scored" minus two characters.
    file_name = str(f).split("/")[-1]
    model_name = file_name.split("eval")[0][:-1]
    if "scored" in model_name:
        model_name = file_name.split("scored")[0][:-2]

    with f.open("r") as fd:
        for line in fd:
            g = json.loads(line)
            reward_meta = g["reward_meta"]
            meta = g["metadata"]
            all_smis = reward_meta.get("all_smi", [""])

            if len(all_smis) == 0:
                # Nothing could be extracted: invalid generation, zero reward,
                # and a human-readable failure reason.
                valid, smiles, reward = 0, "", 0.
                fail_reason = (
                    reward_meta.get("smiles_extraction_failure", "unknown")
                    .replace("_", " ")
                    .replace("smiles", "SMILES")
                )
            else:
                valid, fail_reason = 1, "valid"
                if len(all_smis) > 1:
                    # Several SMILES: keep the last one together with its own reward.
                    smiles = all_smis[-1]
                    reward = float(reward_meta["all_smi_rewards"][-1])
                else:
                    smiles = all_smis[0]
                    reward = float(g["reward"])

            generations.append(
                {
                    "prompt_id": meta["prompt_id"],
                    "reward": reward,
                    "model": model_name,
                    "n_props": len(meta["properties"]),
                    "properties": ",".join(meta["properties"]),
                    "objectives": ",".join(meta["objectives"]),
                    "smiles": smiles,
                    "validity": valid,  # 1 when a SMILES was retained, else 0
                    "valid": fail_reason,  # "valid" or the reason extraction failed
                }
            )

df = pd.DataFrame(generations)
# Short display name: drop the last character of the raw name, then strip the
# "-<digits>B" size suffix and abbreviate the remaining qualifiers.
df["Model"] = df["model"].apply(
    lambda x: re.sub(r"-\d+(B|b)", "", x[:-1])
    .replace("-2507", "")
    .replace("Distill", "D.")
    .replace("-it", "")
    .replace("Thinking", "Think.")
)

df

In [None]:
# Colour of each generation outcome; the key order is also used as the hue order.
cmap = {
    "valid": "seagreen",
    "no valid SMILES": "gold",
    "multiple SMILES": "darkorange",
    "no SMILES": "brown",
    "no answer": "red"
}

# Count that corresponds to 100% on the y axis.
max_count = 128_000
yticks = np.linspace(0, max_count, 6)

ax = sns.histplot(
    data=df,
    x="Model",
    hue="valid",
    multiple="stack",
    stat="count",
    palette=cmap,
    hue_order=list(cmap),
)
ax.set_ylim(0, max_count)

# Replace the count ticks with percentages from 0 to 100
ax.set_yticks(yticks)
ax.set_yticklabels([f"{int(y / max_count * 100)}%" for y in yticks])

# Rotate x labels
_ = plt.xticks(rotation=45, ha='right')

plt.savefig(f"{FIG_PATH}/validity.pdf", bbox_inches='tight')

In [None]:
def agg_topk(k=100, n_rollout = 2):
    """Build an aggregator computing the mean of the top-k rewards.

    Args:
        k: Number of best rewards to average. When fewer than ``k`` rewards are
            available, the missing ones count as 0.
        n_rollout: Only the first ``n_rollout`` rewards of the group are considered.

    Returns:
        A function taking a pandas Series of rewards and returning a float.
    """
    def w_fn(x):
        best_first = x.iloc[:n_rollout].sort_values(ascending=False).to_numpy()
        # Pad with k zeros (not a fixed 100) so the mean always spans exactly k
        # entries, whatever the value of k.
        padded = np.pad(best_first, (0, k), 'constant')
        return padded[:k].mean()
    return w_fn

def uniqueness_topk(k=100):
    """Build an aggregator returning the share of distinct values among the first k.

    The returned function takes a pandas Series and divides the number of
    distinct entries in its first ``k`` elements by the number of elements kept.
    """
    def w_fn(x):
        head = x[:k]
        n_distinct = len(head.drop_duplicates())
        return n_distinct / len(head)
    return w_fn

def murcko_tanim_sim_topk(k=100):
    """Build an aggregator scoring the scaffold diversity of the first k SMILES.

    The returned function takes a pandas Series of SMILES, keeps the first ``k``,
    fingerprints the Murcko scaffold of each molecule (Morgan, radius 3, 2048
    bits) and returns the mean of the full pairwise Tanimoto-distance matrix.
    The matrix includes its zero diagonal. A single molecule yields 1.0.
    """
    def w_fn(x):
        head = x[:k]
        if len(head) == 1:
            return 1.0

        scaffolds = [
            MurckoScaffold.GetScaffoldForMol(Chem.MolFromSmiles(smi))
            for smi in head
        ]
        fps = [
            AllChem.GetMorganFingerprintAsBitVect(scaffold, 3, 2048)
            for scaffold in scaffolds
        ]
        # Condensed (upper-triangle) distances: each fingerprint against the later ones
        condensed = np.concatenate([
            1 - np.array(DataStructs.BulkTanimotoSimilarity(fps[i], fps[i + 1:]))
            for i in range(len(fps) - 1)
        ])
        # Expand to the square matrix before averaging
        return squareform(condensed).mean(0).mean()
    return w_fn

def sim_topk(k=100, div = 0.7, n_rollout = 100):
    """Build an aggregator computing a diversity-aware top-k reward.

    The returned function takes a DataFrame with "smiles" and "reward" columns,
    keeps its first ``n_rollout`` rows, selects candidates with
    ``diversity_aware_top_k`` (called with the pairwise Tanimoto distances of the
    Murcko-scaffold fingerprints, the rewards as weights, ``k`` and threshold
    ``t=div``) and returns the mean of the k best selected rewards. Missing
    entries count as 0. A single candidate is selected without clustering.
    """
    def w_fn(df):
        smiles = df["smiles"].to_numpy()[:n_rollout]
        rewards = df["reward"].to_numpy()[:n_rollout]

        if len(smiles) == 1:
            selected = [rewards[0]]
        else:
            scaffolds = [
                MurckoScaffold.GetScaffoldForMol(Chem.MolFromSmiles(smi))
                for smi in smiles
            ]
            fps = [
                AllChem.GetMorganFingerprintAsBitVect(scaffold, 3, 2048)
                for scaffold in scaffolds
            ]
            # Condensed pairwise Tanimoto distances (each fingerprint vs. the later ones)
            condensed = np.concatenate([
                1 - np.array(DataStructs.BulkTanimotoSimilarity(fps[i], fps[i + 1:]))
                for i in range(len(fps) - 1)
            ])
            keep = diversity_aware_top_k(
                dist=condensed, weights=rewards, k=k, t=div
            )
            selected = [rewards[i] for i in keep]

        # Best rewards first, zero-padded so the mean always spans k entries
        ranked = np.sort(np.array(selected))[::-1]
        return np.pad(ranked, (0, k), 'constant')[:k].mean()
    return w_fn



In [None]:
# First 100 distinct prompt ids (in order of appearance in df); the scaffold-based
# analyses below are restricted to these prompts.
sub_sample_prompts = df.prompt_id.unique()[:100]

In [None]:
# Close a progress bar left open by an interrupted run, if there is one.
# On a fresh kernel no bar exists yet, so this must not assume `pbar` is defined.
if "pbar" in globals():
    pbar.close()

In [None]:
import warnings
# NOTE: this silences every warning for the rest of the kernel session.
warnings.filterwarnings("ignore")

# Sweep grid: diversity threshold, number of rollouts considered, and k.
div_values = [0.01,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
rollouts = [50,75,100]
ks = [5, 10, 20, 30]

# The candidate set does not depend on the sweep parameters, so it is built once.
# Duplicates are removed per model: a molecule generated by two models for the
# same prompt must be kept for both (same convention as the results table,
# which deduplicates inside each model).
div_groups = df[
    (df.validity == 1) & (df.prompt_id.isin(sub_sample_prompts))
].drop_duplicates(subset=["model", "prompt_id", "smiles"]).groupby(
    ["model", "prompt_id"]
)

div_clus_frames = []
# The context manager closes the bar even if the sweep is interrupted.
with tqdm(total=len(div_values)*len(rollouts)*len(ks)) as pbar:
    for div in div_values:
        for n_rollout in rollouts:
            pbar.set_description(f"div: {div}, n_rollout: {n_rollout}")
            pbar.refresh()
            for k in ks:
                div_clus_df_single = div_groups.apply(
                    sim_topk(k, div, n_rollout)
                ).to_frame("value").reset_index()
                div_clus_df_single["k"] = k
                div_clus_df_single["n_rollout"] = n_rollout
                div_clus_df_single["div"] = div
                div_clus_frames.append(div_clus_df_single)
                pbar.update(1)

# Average the per-prompt scores for every (model, n_rollout, div, k) setting.
div_clus_df = pd.concat(div_clus_frames).reset_index()
div_clus_df = div_clus_df.groupby(["model", "n_rollout", "div", "k"])["value"].mean().reset_index()
div_clus_df

In [None]:
# Short display name: drop the last character of the raw name, then strip the
# "-<digits>B" size suffix and abbreviate the remaining qualifiers.
div_clus_df["Model"] = div_clus_df["model"].apply(
    lambda x: re.sub(r"-\d+(B|b)", "", x[:-1])
    .replace("-2507", "")
    .replace("Distill", "D.")
    .replace("-it", "")
    .replace("Thinking", "Think.")
)

In [None]:
# Show the distinct short model names produced by the renaming above.
div_clus_df["Model"].unique()

In [None]:
# Colour assigned to each short model name.
# NOTE(review): none of the plotting cells visible in this section passes this
# mapping as a palette; confirm whether it is still needed.
cmap_models = {
    "ChemDFM-R": "orange",
    "ChemDFM-v2.0": "goldenrod",
    "ether0": "chocolate",
    "DeepSeek-R1-D.-Llama": "darkorchid",
    "DeepSeek-R1-D.-Qwen": "orchid",
    "Llama-3.3-Instruct": "darkslategray",
    "Qwen3-A3B-Think.": "teal",
    "gemma-3": "crimson",
}

In [None]:
# Plot against the similarity threshold, i.e. 1 - diversity threshold.
div_clus_df["sim"] = 1 - div_clus_df["div"]

# One panel per (n_rollout, k) pair.
g = sns.FacetGrid(
    div_clus_df,
    row="n_rollout",
    col="k",
    margin_titles=True,
    height=2.,
    aspect=1.1,
)

def draw(data, **kwargs):
    """Draw one facet: score against similarity threshold, one line per model."""
    sns.lineplot(
        data,
        x = "sim",
        y = "value",
        hue = "Model",
        marker="o",
        sizes=1,
        alpha=0.8,
        **kwargs
    )

g.map_dataframe(
    draw,
)
# Shared legend, centred below the grid
g.add_legend(title="Model", loc="lower center", bbox_to_anchor=(0.28, -0.12), ncols = 4, fontsize = 8, title_fontsize = 10)
# Per-axis labels are blanked; figure-level labels are used instead.
# `FacetGrid.figure` replaces the deprecated `FacetGrid.fig` alias.
g.set_axis_labels("", "")
g.figure.supxlabel("Similarity threshold between candidate clusters", y=0., x = 0.28)
g.figure.supylabel("Diversity-Aware Top-k Score", x=0.02)
g.set_titles(row_template="$n_r$={row_name}", col_template="k={col_name}")

plt.savefig(f"{FIG_PATH}/diversity_reward.pdf", bbox_inches='tight')

In [None]:
k_values = [1,5,10,20, 30]
topk_dfs = []

# The candidate set does not depend on k, so it is filtered and grouped once.
# Duplicates are removed per model: a molecule generated by two models for the
# same prompt must be kept for both (same convention as the results table).
topk_groups = (
    df[df.validity == 1]
    .drop_duplicates(subset=["model", "prompt_id", "smiles"])
    .groupby(["model", "prompt_id"])
)

for k in tqdm(k_values):
    # Rollout budgets evaluated for this k: k, k+5, ... up to 100
    roll_values = list(range(k,101,5))
    roll_cols = [str(roll) for roll in roll_values]

    # One column per rollout budget, holding the per-prompt top-k score
    topk_k = topk_groups.agg(
        **{
            col: pd.NamedAgg(column="reward", aggfunc=agg_topk(k, roll))
            for col, roll in zip(roll_cols, roll_values)
        }
    ).reset_index()
    topk_k["k"] = k
    # Long format: one row per (model, prompt, rollout budget)
    topk_k = topk_k.melt(id_vars=["model", "k"], value_vars=roll_cols, var_name="n_rollout", value_name= "top-k")
    topk_k["n_rollout"] = topk_k["n_rollout"].apply(int)

    topk_dfs.append(topk_k)

topk_df = pd.concat(topk_dfs).reset_index()

In [None]:
# Rollout budgets at which uniqueness is measured: 1, 11, ..., 91
uniq_rollouts = list(range(1,100,10))
uniq_cols = [str(n) for n in uniq_rollouts]

# One column per budget: share of distinct SMILES among the first n valid generations
uniq_aggs = {
    col: pd.NamedAgg(column="smiles", aggfunc=uniqueness_topk(n))
    for col, n in zip(uniq_cols, uniq_rollouts)
}
uniq_df = (
    df[df.validity == 1]
    .groupby(["model", "prompt_id"])
    .agg(**uniq_aggs)
    .reset_index()
)
# Long format: one row per (model, prompt, budget)
uniq_df = uniq_df.melt(id_vars=["model"], value_vars=uniq_cols, var_name="n_rollout")

uniq_df["n_rollout"] = uniq_df["n_rollout"].apply(int)

In [None]:
# Rollout budgets at which scaffold diversity is measured: 1, 11, ..., 91
murcko_rollouts = list(range(1,100,10))

# Valid generations of the sub-sampled prompts. Duplicates are removed per model:
# a molecule generated by two models for the same prompt must be kept for both
# (same convention as the results table).
murcko_sim_df = df[
    (df.validity == 1) & (df.prompt_id.isin(sub_sample_prompts))
].drop_duplicates(subset=["model", "prompt_id", "smiles"]).groupby(["model", "prompt_id"]).agg(
    **{f"{n}":pd.NamedAgg(column="smiles", aggfunc=murcko_tanim_sim_topk(n)) for n in murcko_rollouts}
).reset_index()
# Long format: one row per (model, prompt, budget)
murcko_sim_df = murcko_sim_df.melt(id_vars=["model"], value_vars=[str(n) for n in murcko_rollouts], var_name="n_rollout")

murcko_sim_df["n_rollout"] = murcko_sim_df["n_rollout"].apply(int)
murcko_sim_df

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(6,3))

# Short display names: drop the last character of the raw name, then strip the
# "-<digits>B" size suffix and abbreviate the remaining qualifiers.
uniq_df["Model"] = uniq_df["model"].apply(
    lambda x: re.sub(r"-\d+(B|b)", "", x[:-1])
    .replace("-2507", "")
    .replace("Distill", "D.")
    .replace("-it", "")
    .replace("Thinking", "Think.")
)
murcko_sim_df["Model"] = murcko_sim_df["model"].apply(
    lambda x: re.sub(r"-\d+(B|b)", "", x[:-1])
    .replace("-2507", "")
    .replace("Distill", "D.")
    .replace("-it", "")
    .replace("Thinking", "Think.")
)

# (data, y label, show legend) for the left and right panels
panels = [
    (uniq_df, "Uniqueness", True),
    (murcko_sim_df, "Murcko-Diversity", False),
]
for ax, (panel_df, panel_label, with_legend) in zip(axes, panels):
    sns.lineplot(panel_df, x="n_rollout", y="value", hue="Model", ax=ax, legend=with_legend)
    ax.set_ylabel(panel_label)
    ax.set_xlabel("$n_r$")
    ax.set_ylim(0,1)
    if with_legend:
        # Reposition the legend of the panel that carries it
        ax.legend(bbox_to_anchor=(0.1, 0.), loc='lower center', ncols=1)

fig.savefig(f"{FIG_PATH}/uniqueness_diversity.pdf")

In [None]:
# One panel per value of k, sharing the y axis.
n_panels = topk_df.k.nunique()
fig,axes = plt.subplots(
    1,
    n_panels,
    figsize=(3 * n_panels,3),
    sharey=True,
    gridspec_kw={
        "wspace":0.1
    }
)
# plt.subplots returns a bare Axes when there is a single panel; always iterate an array.
axes = np.atleast_1d(axes)

# Short display name: drop the last character of the raw name, then strip the
# "-<digits>B" size suffix and abbreviate the remaining qualifiers.
topk_df["Model"] = topk_df["model"].apply(lambda x: re.sub(r"-\d+(B|b)", "", x[:-1]).replace("-2507", "").replace("Distill", "D.").replace("-it", "").replace("Thinking", "Think."))

for k, ax in zip(topk_df.k.unique(), axes):
    sns.lineplot(
        topk_df[topk_df.k == k],
        x="n_rollout",
        y="top-k",
        hue="Model",
        ax=ax,
        # Only the last panel carries the legend, whatever its k value is.
        legend=ax is axes[-1],
    )
    ax.set_ylabel("top-k score")
    ax.set_title(f"k = {k}")
# Move legend of the last axis below plot
axes[-1].legend(bbox_to_anchor=(-1.8, -0.5), loc='lower center', ncols=4)

fig.savefig(f"{FIG_PATH}/topk_k_score.pdf")

In [None]:
# Per-model metadata used when building the results table.
# Keys are matched as substrings of the raw model name; the first matching key
# (in the order written here) is used.
#   "size":     optional; when absent the size is parsed from the model name.
#   "thinking": selects the check / cross symbol of the "Think." column.
#   "Chem.":    read by the table cell but not written to the table at present.
MODEL_META = {
    "ether0": {
        "size":"24B",
        "thinking": True,
        "Chem.": True,
    },
    "molstral": {
        "size":"24B",
        "thinking": True,
        "Chem.": True,
    },
    "ChemDFM-v2.0": {
        "thinking": False,
        "Chem.": True,
    },
    "ChemDFM-R": {
        "thinking": True,
        "Chem.": True,
    },
    "-R1-": {
        "thinking": True,
        "Chem.": False,
    },
    "Llama-3.3": {
        "thinking": False,
        "Chem.": False,
    },
    "Qwen3-30B-A3B-Thinking": {
        "thinking": True,
        "Chem.": False,
    },
    "gemma-3": {
        "thinking": False,
        "Chem.": False,
    },
}

In [None]:
# Create table with: model_name, size, metric_name, value

# Rollout budgets reported for each k of the top-k metric.
ROLLOUTS_AT_K = {
    1: [10, 25, 50],
    10: [25, 50, 75],
    30: [50, 75, 100],
}

TABLE_COLUMNS = ["Model", "Size", "Think.", "Metric", r"$n_\text{rollouts}$", "Value"]
# Matches a "<digits>B" size token delimited by "-", "_" or the ends of the name.
size_pattern = re.compile(r'(?i)(?:^|[-_])(\d+\s*[b])(?:$|[-_])')
n_settings = sum(len(rolls) for rolls in ROLLOUTS_AT_K.values())

# Rows are accumulated in a list and turned into a DataFrame once at the end,
# instead of growing the DataFrame one row at a time.
table_rows = []
# The context manager closes the bar even if a model fails.
with tqdm(total=len(df.model.unique()) * n_settings) as pbar:
    for model_name in df.model.unique():
        matching_keys = [m for m in MODEL_META if m in model_name]
        assert matching_keys, f"No MODEL_META entry matches model {model_name}"
        metadata = MODEL_META[matching_keys[0]]
        pbar.set_description(f"{model_name}")
        pbar.refresh()

        # Model-level attributes do not depend on k / n_rollout.
        if "size" in metadata:
            size = metadata["size"]
        else:
            size_match = size_pattern.search(model_name)
            if size_match is None:
                raise ValueError(f"Size not found for model {model_name}")
            size = size_match.group(1).upper()
        thinking = r"\CheckmarkBold" if metadata["thinking"] else r"\XSolidBrush"

        # Generations of this model, deduplicated per prompt (computed once per model).
        sub_df = df[df.model == model_name].drop_duplicates(subset=["prompt_id", "smiles"])
        prompt_groups = sub_df.groupby("prompt_id")

        for k, rollouts_at_k in ROLLOUTS_AT_K.items():
            for n_rollout in rollouts_at_k:
                pass_k = prompt_groups.agg(
                    **{f"{k}": pd.NamedAgg(column="reward", aggfunc=agg_topk(k=k, n_rollout=n_rollout))}
                )
                # One table row per prompt
                table_rows.extend(
                    [model_name, size, thinking, f"top-{k}", n_rollout, value]
                    for value in pass_k[str(k)]
                )
                pbar.update(1)

table = pd.DataFrame(table_rows, columns=TABLE_COLUMNS)


In [None]:

# Row order of the final table (short model names).
MODEL_ORDER = [
    "Qwen3-A3B-Think.",
    "DeepSeek-R1-D.-Llama",
    "DeepSeek-R1-D.-Qwen",
    "gemma-3",
    "Llama-3.3-Instruct",
    "ChemDFM-R",
    "ether0",
    "ChemDFM-v2.0",
]

# Short display name: strip the "-<digits>B" size suffix, abbreviate the
# qualifiers, then remove every remaining underscore.
table["Model"] = table["Model"].apply(
    lambda x: re.sub(r"-\d+(B|b)", "", x)
    .replace("-2507", "")
    .replace("Distill", "D.")
    .replace("-it", "")
    .replace("Thinking", "Think.")
    .replace("_", "")
)

# LaTeX column header for each rollout budget
table["N_rolls"] = table[r"$n_\text{rollouts}$"].apply(
    lambda x: r"$n_\text{r}=$" + str(x)
)
table

In [None]:
# Show the distinct short model names, to compare against MODEL_ORDER.
table.Model.unique()

In [None]:
# Directory in which the LaTeX tables below are written.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
LATEX_PATH = "/home/philippe/-Philippe-MolGenDocking/tables"

In [None]:
# Results table reporting the mean only, written to gen_table.tex.
formatter = PandasTableFormatter(
    n_decimals = 3, # Number of decimals to keep in the table
    aggregation_methods=["mean"], # Aggregation functions to apply to the data
    main_subset=0, # Subset of values to bold, here the first column will be bolded corresponding to the mean values, if [0,1] the first two columns will be bolded (independently)
    hide_agg_labels=True, # Hide the aggregation column names in the latex
    global_agg=False # Whether to compute global aggregation across all columns (True)
)

style = formatter.style(
    table, # Dataframe to format
    rows= ["Model", "Size", "Think."], # Rows
    cols=["Metric", r"N_rolls"], # Columns
    values= "Value", # Values
    highlight_fn= np.nanmax, # Function to use to highlight the values, here the maximum values will be highlighted
    props=["font-weight: bold; text-decoration: underline;", "text-decoration: underline"], # CSS for highlighted values: first entry is bold + underlined, second is underlined only (presumably best / runner-up; confirm against PandasTableFormatter)
    special_format_agg = {
        "std": lambda x: "\\tiny $\\pm$" + x, # Format to apply to the standard deviation values. NOTE(review): only "mean" is aggregated in this cell, so this entry is presumably unused
    },
    remove_col_names=True,
    row_order = MODEL_ORDER
)
formatter.save_to_latex(style, f"{LATEX_PATH}/gen_table.tex", 1, multicol_align="|c|", hrules=True, n_first_cols=2)
style

In [None]:
# Same table as gen_table.tex but reporting mean and standard deviation,
# written to gen_table_std.tex.
formatter = PandasTableFormatter(
    n_decimals = 3, # Number of decimals to keep in the table
    aggregation_methods=["mean", "std"], # Aggregation functions to apply to the data
    main_subset=0, # Subset of values to bold, here the first column will be bolded corresponding to the mean values, if [0,1] the first two columns will be bolded (independently)
    hide_agg_labels=True, # Hide the aggregation column names in the latex
    global_agg=False # Whether to compute global aggregation across all columns (True)
)

style = formatter.style(
    table, # Dataframe to format
    rows= ["Model", "Size", "Think."], # Rows
    cols=["Metric", r"N_rolls"], # Columns
    values= "Value", # Values
    highlight_fn= np.nanmax, # Function to use to highlight the values, here the maximum values will be highlighted
    props=["font-weight: bold; text-decoration: underline;", "text-decoration: underline;"], # CSS for highlighted values: first entry is bold + underlined, second is underlined only (presumably best / runner-up; confirm against PandasTableFormatter)
    special_format_agg = {
        "std": lambda x: "\\tiny $\\pm$" + x, # Format to apply to the standard deviation values
    },
    remove_col_names=True,
    row_order = MODEL_ORDER
)
formatter.save_to_latex(style, f"{LATEX_PATH}/gen_table_std.tex", 1, multicol_align="|c|", hrules=True, n_first_cols=2)