In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Move the working directory one level up.
# Presumably the notebook lives in a sub-folder and the relative data path and
# the `notebooks.utils` import used later expect the parent directory — TODO confirm.
# NOTE(review): re-running this cell climbs one more level each time.
%cd ..

In [None]:
from pathlib import Path
import json
import re

import pandas as pd
import numpy as np
from tqdm import tqdm

import seaborn as sns
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit import DataStructs

from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, fcluster

from notebooks.utils import PandasTableFormatter
from mol_gen_docking.reward.diversity_aware_top_k import diversity_aware_top_k

# Directory that receives the figures saved by this notebook.
# NOTE(review): absolute, machine-specific path — adjust it before running elsewhere.
FIG_PATH = "/home/philippe/-Philippe-MolGenDocking/Figures/results"

In [None]:
MOLSTRAL_PATH = Path("MolGenOutput/polaris")

# Sorted list of the scored generation dumps found directly in MOLSTRAL_PATH;
# any entry whose path mentions "error" is left out.
files = sorted(
    candidate
    for candidate in MOLSTRAL_PATH.iterdir()
    if str(candidate).endswith("scored.jsonl") and "error" not in str(candidate)
)

# Per-model metadata.
# Each key is a substring looked up inside the model name derived from a result
# file; when several keys match, the first one in insertion order is used.
#   "size"     (optional) model size label; when absent, the loading cell parses
#              the size out of the model name instead.
#   "thinking" drives the check/cross symbol stored in the "Think." column.
#   "Chem."    presumably marks chemistry-specialised models — TODO confirm.
MODEL_META = {
    "ether0": {
        "size":"24B",
        "thinking": True,
        "Chem.": True,
    },
    "molstral": {
        "size":"24B",
        "thinking": True,
        "Chem.": True,
    },
    "ChemDFM-v2.0": {
        "thinking": False,
        "Chem.": True,
    },
    "ChemDFM-R": {
        "thinking": True,
        "Chem.": True,
    },
    "-R1-": {
        "thinking": True,
        "Chem.": False,
    },
    "Llama-3.3": {
        "thinking": False,
        "Chem.": False,
    },
    "Qwen3-30B-A3B-Thinking": {
        "thinking": True,
        "Chem.": False,
    },
    "gemma-3": {
        "thinking": False,
        "Chem.": False,
    },
}

# Short task label for each benchmark.
# Keys are compared against the comma-joined `properties` list of a generation
# record; values end up in the "Task" column. A record whose joined properties
# are not listed here raises a KeyError in the loading cell.
PROP_NAME = {
    # ASAP
    "asap-discovery/antiviral-potency-2025-unblinded": "antiviral-potency",

    # Biogen ADME (FANG)
    "biogen/adme-fang-hclint-reg-v1": "fang-hclint",
    "biogen/adme-fang-hppb-reg-v1": "fang-hppb",
    "biogen/adme-fang-perm-reg-v1": "fang-perm",
    "biogen/adme-fang-rclint-reg-v1": "fang-rclint",
    "biogen/adme-fang-rppb-reg-v1": "fang-rppb",
    "biogen/adme-fang-solu-reg-v1": "fang-solubility",

    # Novartis
    "novartis/novartis-cyp3a4-v1": "cyp3a4-novartis",

    # Polaris / AZ
    "polaris/az-logd-74-v1": "az-logd",
    "polaris/az-ppb-clearance-v1": "az-ppb-clearance",
    "polaris/drewry2017-pkis2-subset-v2": "pkis2-drewry",

    # Therapeutics Data Commons (TDC)
    "tdcommons/ames": "ames",
    "tdcommons/bbb-martins": "bbb",
    "tdcommons/caco2-wang": "caco2",
    "tdcommons/clearance-hepatocyte-az": "hep-clearance-az",
    "tdcommons/clearance-microsome-az": "mic-clearance-az",
    "tdcommons/cyp2c9-substrate-carbonmangels": "cyp2c9-substrate",
    "tdcommons/cyp2d6-substrate-carbonmangels": "cyp2d6-substrate",
    "tdcommons/cyp3a4-substrate-carbonmangels": "cyp3a4-substrate",
    "tdcommons/dili": "dili",
    "tdcommons/half-life-obach": "half-life",
    "tdcommons/herg": "herg",
    "tdcommons/ld50-zhu": "ld50",
    "tdcommons/lipophilicity-astrazeneca": "lipophilicity",
    "tdcommons/pgp-broccatelli": "pgp",
    "tdcommons/solubility-aqsoldb": "solubility",
    "tdcommons/vdss-lombardo": "vdss",
}


In [None]:
# Matches a size token such as "70B" or "8b" delimited by "-", "_" or the string ends.
size_pattern = re.compile(r'(?i)(?:^|[-_])(\d+\s*[b])(?:$|[-_])')


def _model_name_from_file(path):
    """Derive the model name from the file name of a scored result file.

    The name is the part of the file name before "eval" minus its last
    character; when that still contains "scored" (no "eval" marker in the file
    name), it is the part before "scored" minus its last two characters.
    """
    file_name = Path(path).name
    model_name = file_name.split("eval")[0][:-1]
    if "scored" in model_name:
        model_name = file_name.split("scored")[0][:-2]
    return model_name


def _model_size_and_thinking(model_name):
    """Return the ``(size, thinking_symbol)`` pair for a model name.

    The metadata entry is the first ``MODEL_META`` key contained in
    ``model_name``. The size comes from that entry when it has a "size" field
    and is otherwise parsed from the name with ``size_pattern``.

    Raises:
        ValueError: if no ``MODEL_META`` key matches, or if the size is neither
            declared in the metadata nor present in the model name.
    """
    key = next((k for k in MODEL_META if k in model_name), None)
    if key is None:
        raise ValueError(f"No MODEL_META entry matches model {model_name}")
    model_metadata = MODEL_META[key]

    if "size" in model_metadata:
        size = model_metadata["size"]
    else:
        match = size_pattern.search(model_name)
        if match is None:
            raise ValueError(f"Size not found for model {model_name}")
        size = match.group(1).upper()

    # LaTeX symbol commands, written verbatim into the "Think." column.
    thinking = r"\CheckmarkBold" if model_metadata["thinking"] else r"\XSolidBrush"
    return size, thinking


generations = []

for f in tqdm(files):
    with f.open("r") as fd:
        # Blank lines (e.g. a trailing newline) are skipped rather than handed
        # to json.loads, which would raise on them.
        records = [json.loads(line) for line in fd if line.strip()]
    if not records:
        continue

    # The model information depends only on the file name: resolve it once per
    # file instead of once per record.
    model_name = _model_name_from_file(f)
    size, thinking = _model_size_and_thinking(model_name)

    for g in records:
        extracted = g["reward_meta"]["extracted_answer"]
        valid = extracted is not None
        properties = g["metadata"]["properties"]
        joined_properties = ",".join(properties)

        generations.append(
            {
                "prompt_id": g["metadata"]["prompt_id"],
                "reward": float(g["reward"]),
                "model": model_name,
                "n_props": len(properties),
                "properties": joined_properties,
                "objectives": ",".join(g["metadata"]["objectives"]),
                "validity": valid,
                # Records without an extracted answer are labelled "invalid".
                "extracted_answer": extracted if valid else "invalid",
                "Size": size,
                "Think.": thinking,
                "Task": PROP_NAME[joined_properties],
            }
        )

df = pd.DataFrame(generations)
# Row position modulo 3.
# NOTE(review): presumably assumes three consecutive generations per prompt — TODO confirm.
df["gen_id"] = df.index % 3

def find_valid_reward(values):
    """Average the entries of ``values`` that are not below zero.

    Args:
        values: iterable of numeric rewards.

    Returns:
        The mean of the retained entries, or ``0.`` when none is retained.
    """
    retained = []
    for value in values:
        if value < 0:
            continue
        retained.append(value)
    return np.mean(retained) if retained else 0.
# Subtracting the inverted validity mask lowers the reward of every invalid
# generation by 1; find_valid_reward then discards the entries that ended up
# negative when averaging within each (model, task, prompt) group.
df["agg_reward@5"] = df["reward"] - ~df["validity"]
df["agg_reward@5"] = (
    df.groupby(["model", "Task", "prompt_id"])["agg_reward@5"]
    .transform(find_valid_reward)
)

df

In [None]:
def _display_name(model_name):
    """Shorten a raw model name into the label shown in figures and tables."""
    label = re.sub(r"-\d+(B|b)", "", model_name)
    substitutions = (
        ("-2507", ""),
        ("Distill", "D."),
        ("-it", ""),
        ("Thinking", "Think."),
        ("_", ""),
    )
    for old, new in substitutions:
        label = label.replace(old, new)
    return label


# `table` is bound to the same object as `df` at this point (no copy), so the
# "Model" column added here is visible through `df` as well.
table = df
table["Model"] = table["model"].apply(_display_name)

MODEL_ORDER = [
    "Qwen3-A3B-Think.",
    "DeepSeek-R1-D.-Llama",
    "DeepSeek-R1-D.-Qwen",
    "gemma-3",
    "Llama-3.3-Instruct",
    "ChemDFM-R",
    "ether0",
    "ChemDFM-v2.0",
]

# Tasks ranked by the highest per-model mean reward, best task first.
task_order = (
    table.groupby(["model", "Task"])["agg_reward@5"]
    .mean()
    .reset_index()
    .groupby("Task")["agg_reward@5"]
    .max()
    .sort_values(ascending=False)
    .index
)

# Reorder the rows to follow that ranking; from here on `table` is a new frame.
table = table.set_index(["Task"]).loc[task_order].reset_index()


table

In [None]:
def plot_perfs(table, axes, title_name, split_val, legend=False):
    """Draw reward bar plots for one objective type on a pair of axes.

    Only the rows of ``table`` whose ``objectives`` column equals ``split_val``
    are plotted, with one hue per "Model".

    Args:
        table: frame with "Task", "Model", "objectives" and "agg_reward@5" columns.
        axes: indexable holding two matplotlib axes; ``axes[0]`` gets the
            per-task bars and ``axes[1]`` the bars grouped on "objectives",
            whose single tick is relabelled "Avg.".
        title_name: title of the first panel, reused in the second panel's title.
        split_val: value of ``objectives`` selecting the rows to plot.
        legend: forwarded as ``legend`` to the second bar plot only; the first
            panel never shows a legend.
    """
    subset = table[table.objectives == split_val]

    def _bars(x_column, target_ax, show_legend):
        # Same bar plot for both panels; only the x variable, the target axes
        # and the legend flag differ.
        sns.barplot(
            data=subset,
            x=x_column,
            y="agg_reward@5",
            hue="Model",
            ax=target_ax,
            legend=show_legend,
            capsize=.02,
            err_kws={"linewidth": .8},
        )
        target_ax.set_xlabel("")

    per_task_ax = axes[0]
    _bars("Task", per_task_ax, False)
    per_task_ax.set_ylabel("Reward")
    per_task_ax.set_xticklabels(
        rotation=90, ha='center', labels=per_task_ax.get_xticklabels()
    )
    per_task_ax.set_title(title_name)

    average_ax = axes[1]
    _bars("objectives", average_ax, legend)
    average_ax.set_xticklabels(rotation=90, ha='center', labels=["Avg."])
    average_ax.set_title(f"Avg.\n{title_name}")


In [None]:
# Fraction of all tasks that are regression tasks; it sets the relative width
# of the regression panel, and one eighth of it the width of each "Avg." panel.
ax_size_ratio = (
    table.loc[table["objectives"] == "regression", "Task"].nunique()
    / table["Task"].nunique()
)
avg_ratio = ax_size_ratio/8

fig, axes = plt.subplots(
    nrows=1,
    ncols=4,
    figsize=(9, 3),
    sharey=True,
    gridspec_kw={
        "width_ratios": [ax_size_ratio, avg_ratio, 1-ax_size_ratio, avg_ratio],
    },
)

# Panels 0-1 show regression, panels 2-3 classification; only the last call
# asks for a legend, which is then repositioned below.
plot_perfs(table, [axes[0], axes[1]], "Regression", "regression")
plot_perfs(table, [axes[2], axes[3]], "Classification", "classification", legend=True)

axes[-1].legend(
    title="Model",
    loc="lower center",
    bbox_to_anchor=(-6.5, -.8),
    ncols = 8,
    fontsize = 8,
    title_fontsize = 10,
)

fig.savefig(f"{FIG_PATH}/molecular_proppred.pdf", bbox_inches="tight")

In [None]:
# Models sorted by ascending mean aggregated reward.
# `df` only carries the "Model" column once the cell that builds `table` has run.
MODEL_ORDER = (
    df.groupby("Model")["agg_reward@5"]
    .mean()
    .sort_values()
    .index
    .tolist()
)

In [None]:
def plot_heatmap(table, axes, title_name, split_val):
    """Draw annotated reward heatmaps for one objective type on a pair of axes.

    Only the rows of ``table`` whose ``objectives`` column equals ``split_val``
    are used. Heatmap rows follow the module-level ``MODEL_ORDER`` list.

    Args:
        table: frame with "Task", "Model", "objectives" and "agg_reward@5" columns.
        axes: indexable holding two matplotlib axes; ``axes[0]`` gets the
            Model x Task heatmap (tasks sorted by decreasing column mean) and
            ``axes[1]`` the Model x objectives heatmap, whose single tick is
            relabelled "Avg.".
        title_name: title of the first panel, reused in the second panel's title.
        split_val: value of ``objectives`` selecting the rows to plot.
    """
    subset = table[table.objectives == split_val]

    def _heatmap(matrix, target_ax):
        # Both panels share the colour range, annotation format and font size.
        sns.heatmap(
            matrix,
            ax=target_ax,
            vmin=0,
            vmax=1,
            cbar=False,
            annot=True,
            fmt=".2f",
            annot_kws={"size": 8}
        )
        target_ax.set_xlabel("")

    per_task_ax = axes[0]
    per_task = pd.pivot_table(subset, "agg_reward@5", "Model", "Task")
    task_columns = per_task.mean().sort_values(ascending=False).index
    _heatmap(per_task.loc[MODEL_ORDER, task_columns], per_task_ax)
    per_task_ax.set_ylabel("")
    per_task_ax.set_xticklabels(
        rotation=90, ha='center', labels=per_task_ax.get_xticklabels()
    )
    per_task_ax.set_title(title_name)

    average_ax = axes[1]
    averaged = pd.pivot_table(subset, "agg_reward@5", "Model", "objectives")
    _heatmap(averaged.loc[MODEL_ORDER], average_ax)
    average_ax.set_xticklabels(rotation=90, ha='center', labels=["Avg."])
    average_ax.set_title(f"Avg.\n{title_name}")


In [None]:
# One heatmap column per task plus the two "Avg." columns; one row per model.
n_tot_cols = table["Task"].nunique() + 2
n_rows = table["Model"].nunique()

# Width share of each per-task panel, proportional to its number of tasks.
reg_ratio = table.loc[table["objectives"] == "regression", "Task"].nunique() / n_tot_cols
cls_ratio = table.loc[table["objectives"] == "classification", "Task"].nunique() / n_tot_cols

fig, axes = plt.subplots(
    nrows=1,
    ncols=4,
    figsize=(0.4*n_tot_cols, 0.4 * n_rows),
    sharey=True,
    gridspec_kw={
        "width_ratios": [reg_ratio, 1/n_tot_cols, cls_ratio, 1/n_tot_cols],
        "wspace": 0.01,
    },
)

plot_heatmap(table, [axes[0], axes[1]], "Regression", "regression")
plot_heatmap(table, [axes[2], axes[3]], "Classification", "classification")

# Clear the y-axis label on every panel but the first.
for ax in axes[1:]:
    ax.set_ylabel("")


In [None]:
# Display the distinct shortened model names present in the table.
table["Model"].unique()

In [None]:
# Directory that receives the exported LaTeX table.
# NOTE(review): absolute, machine-specific path — adjust it before running elsewhere.
LATEX_PATH = "/home/philippe/-Philippe-MolGenDocking/tables"

In [None]:
# Formatter settings:
#   n_decimals          number of decimals kept in the table
#   aggregation_methods aggregation functions applied to the data
#   main_subset         index (or list of indices) of the aggregated column(s) to highlight
#   hide_agg_labels     hide the aggregation column names in the LaTeX output
#   global_agg          whether to also aggregate across all columns
formatter = PandasTableFormatter(
    n_decimals=3,
    aggregation_methods=["mean"],
    main_subset=0,
    hide_agg_labels=True,
    global_agg=False,
)

# One row per (Model, Size, Think.) and one column per Task, filled with the
# aggregated reward. `highlight_fn` picks the value to highlight (here the
# maximum) and `props` holds the CSS applied to it (bold and underlined).
# `special_format_agg` gives the display format of the "std" aggregation.
style = formatter.style(
    table,
    rows=["Model", "Size", "Think."],
    cols="Task",
    values="agg_reward@5",
    highlight_fn=np.nanmax,
    props=["font-weight: bold; text-decoration: underline;"],
    special_format_agg={"std": lambda x: "\\tiny $\\pm$" + x},
    remove_col_names=False,
)
style

In [None]:

# Write the styled table to LATEX_PATH as gen_table.tex, then display it.
formatter.save_to_latex(
    style,
    f"{LATEX_PATH}/gen_table.tex",
    1,
    multicol_align="|c|",
    hrules=True,
    n_first_cols=2,
)
style