In [None]:
# Enable automatic reloading of imported modules before each cell execution,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Move the working directory one level up; the relative paths used in later
# cells ("MolGenOutput", "data/...") are resolved from there — presumably the
# repository root (TODO confirm).
# NOTE: this is not idempotent — re-running this cell moves up one more level.
%cd ..

In [None]:
import pandas as pd
from mol_gen_docking.data.pydantic_dataset import read_jsonl, write_jsonl, Sample, Message, Conversation
from pathlib import Path
import jsonlines
from typing import Dict, Any, List, Tuple, Iterator
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from tqdm.auto import tqdm

from mol_gen_docking.evaluation.sft_extraction import SFTExtractionConfig, SFTExtractor, Completion

In [None]:
# --- Configuration ----------------------------------------------------------
SOURCE = "MiniMax-M2"
DATASET = "molgendata_train"
PATH = Path("MolGenOutput") / DATASET / SOURCE
PROMPT_PATH = Path(f"data/{DATASET.replace('_train', '')}/train_data/train_prompts_boxed.jsonl")
# Every run of this many consecutive completions is attributed to one prompt.
N_COMPLETIONS_PER_PROMPT = 16

prompt_dataset = read_jsonl(PROMPT_PATH)

# List the scored files once and in a fixed order: completions are matched to
# prompts purely by their running position, and ``Path.glob`` makes no
# ordering guarantee, so an unsorted listing is not reproducible.
# NOTE(review): this assumes lexicographic file order matches the order in
# which the prompts were processed — confirm against the file naming scheme.
scored_files = sorted(PATH.glob("*_scored.jsonl"))
n_files = len(scored_files)

comp_dataset = []
i = 0  # running index over all completions, across files

for path in tqdm(scored_files, total=n_files):
    with jsonlines.open(path) as reader:
        for obj in reader:
            obj["source"] = SOURCE
            obj["metadata"]["prompt_id"] = prompt_dataset[i // N_COMPLETIONS_PER_PROMPT].identifier
            comp_dataset.append(Completion.model_validate(obj))
            i += 1

print(f"Extracted {len(comp_dataset)} completions from {n_files} files.")

In [None]:
def grid_translate(grid: Dict[str, List[Any]]) -> Iterator[SFTExtractionConfig]:
    """Yield one ``SFTExtractionConfig`` per combination of the grid values.

    Args:
        grid: Maps a config field name to the list of candidate values to try.
            The string ``"default"`` is a sentinel: for that combination the
            field is left out of the constructor call, so the value is decided
            by ``SFTExtractionConfig`` itself.

    Yields:
        A config for every element of the Cartesian product of the value
        lists, in ``itertools.product`` order. An empty grid yields a single
        config built without arguments.
    """
    # Read keys and values separately rather than via ``zip(*grid.items())``,
    # which cannot be unpacked into two names when the grid is empty.
    keys = list(grid)
    for combo in itertools.product(*grid.values()):
        overrides = {key: value for key, value in zip(keys, combo) if value != "default"}
        yield SFTExtractionConfig(**overrides)


# Grid with both info templates explicitly set to empty dicts.
grid0 = {
    "min_reward_threshold": [None, 0.2, 0.5],
    "div_threshold": [None, 0.5, 0.2],
    "reward_info_template": [{}],
    "source_info_template": [{}],
    "system_prompt_path": ["system_prompts/vanilla_boxed.json"],
}

# Left-closed token-count buckets used to label every extracted row.
bins_tokens = [0, 1000, 2000, 5000, 7000, 10000, 1e8]
bins_lab = ["<1k", "1k-2k", "2k-5k", "5k-7k", "7k-10k", ">10k"]

dfs = []
for config in grid_translate(grid0):
    extractor = SFTExtractor(config)
    new_dataset = extractor.extract(comp_dataset, prompt_dataset)

    # Write the traces under one path component per config field.
    path_parts = [f"{key}_{value}" for key, value in config.model_dump().items()]
    trace_path = Path("data/traces").joinpath(*path_parts, f"{DATASET}.jsonl")
    trace_path.parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(trace_path, new_dataset)

    n_conversations = sum(len(sample.conversations) for sample in new_dataset)
    print(f"N samples extracted for config:\n === {config.model_dump()} ===\n**{n_conversations}**")

    # Collect the extractor metadata and tag it with the config it came from.
    df = pd.DataFrame(extractor.metadata.model_dump())
    bins = pd.cut(df["n_tokens"], bins=bins_tokens, right=False, labels=bins_lab)
    df["n_tokens_bins"] = bins
    # A threshold of None is recorded as 0.0 (reward) / 1.0 (diversity) so the
    # column stays numeric — presumably these mean "no filtering" (TODO confirm).
    if config.min_reward_threshold is None:
        df["min_reward_threshold"] = 0.
    else:
        df["min_reward_threshold"] = config.min_reward_threshold
    if config.div_threshold is None:
        df["div_threshold"] = 1.
    else:
        df["div_threshold"] = config.div_threshold

    dfs.append(df)

full_df = pd.concat(dfs).reset_index(drop=True)

In [None]:
# Per prompt: first non-null value of every column after sorting by ascending
# reward. Not used by the plot defined in this cell.
df_min = (
    full_df
    .sort_values("rewards", ascending=True)
    .groupby("prompt_ids")
    .first()
    .reset_index()
)

def prop_plot(data, **kwargs):
    """Plot, for each ``div_threshold``, the ranked number of rows per prompt.

    Rows of ``data`` are counted per (``prompt_ids``, ``div_threshold``) pair
    and ordered from the largest count to the smallest. ``x`` is the position
    of a pair within its ``div_threshold`` in that ordering and ``y`` is the
    count. One line per ``div_threshold`` is drawn on the current axes.

    Args:
        data: Frame with at least the columns ``prompt_ids`` and
            ``div_threshold`` (here: the per-facet subset supplied by
            ``FacetGrid.map_dataframe``).
        **kwargs: Forwarded unchanged to ``sns.lineplot``.
    """
    counts = (
        data.groupby(["prompt_ids", "div_threshold"])
        .size()
        .sort_values(ascending=False)
        .reset_index(name="y")
    )
    # Rank of each prompt within its div_threshold, following the sort above.
    counts["x"] = counts.groupby("div_threshold").cumcount()
    # ``data`` is passed by keyword: positional data is only accepted by
    # recent seaborn releases.
    sns.lineplot(data=counts, x="x", y="y", hue="div_threshold", **kwargs)


# One panel per minimum-reward threshold; prop_plot fills each panel.
facet = sns.FacetGrid(
    full_df,
    col="min_reward_threshold",
    margin_titles=True,
    sharex=True,
    sharey=True,
    height=2.5,
    aspect=1.8,
)
facet.map_dataframe(prop_plot)

facet.set_titles(
    col_template="$r_m$={col_name}",
    row_template="$t_d$={row_name}",
)
# Shared legend, anchored at the right-hand side of the figure.
facet.add_legend(
    title="div_threshold",
    bbox_to_anchor=(0.9, 0.5),
    loc="center left",
)

In [None]:
# Reward histograms over all extracted rows: one panel per
# (div_threshold, min_reward_threshold) pair, stacked by token-count bucket.
# NOTE(review): seaborn documents ``binwidth`` as overriding ``bins``, so
# ``bins=20`` presumably has no effect here — confirm and drop one of them.
facet = sns.FacetGrid(
    full_df,
    col="min_reward_threshold",
    row="div_threshold",
    margin_titles=True,
    sharex=True,
    sharey=True,
    height=1.5,
    aspect=1.8,
)
facet.map_dataframe(
    sns.histplot, x="rewards", bins=20, hue="n_tokens_bins",
    multiple="stack", palette="flare_r", stat="probability", binwidth=0.1,
)
facet.set_titles(
    col_template="$r_m$={col_name}",
    row_template="$t_d$={row_name}",
)


In [None]:
# Distribution of the best reward obtained for each prompt, per config.
# The grouping includes the two config columns: grouping on ``prompt_ids``
# alone keeps a single row per prompt across ALL configs, so every prompt
# would show up in only one (arbitrary) panel of the grid below.
# ``first()`` takes the first non-null value of each column after the sort.
df_max = (
    full_df
    .sort_values("rewards", ascending=False)
    .groupby(["prompt_ids", "min_reward_threshold", "div_threshold"])
    .first()
    .reset_index()
)

facet = sns.FacetGrid(
    df_max,
    col="min_reward_threshold",
    row="div_threshold",
    margin_titles=True,
    sharex=True,
    sharey=True,
    height=1.5,
    aspect=1.8,
)
# Only ``binwidth`` is given: seaborn documents it as overriding ``bins``,
# so passing both was redundant.
facet.map_dataframe(
    sns.histplot,
    x="rewards",
    hue="n_tokens_bins",
    multiple="stack",
    palette="flare_r",
    stat="probability",
    binwidth=0.1,
)
facet.set_titles(col_template="$r_m$={col_name}", row_template="$t_d$={row_name}")


In [None]:
# Distribution of the worst reward kept for each prompt, per config.
# The grouping includes the two config columns: grouping on ``prompt_ids``
# alone keeps a single row per prompt across ALL configs, so every prompt
# would show up in only one (arbitrary) panel of the grid below.
# ``first()`` takes the first non-null value of each column after the sort.
df_min = (
    full_df
    .sort_values("rewards", ascending=True)
    .groupby(["prompt_ids", "min_reward_threshold", "div_threshold"])
    .first()
    .reset_index()
)

facet = sns.FacetGrid(
    df_min,
    col="min_reward_threshold",
    row="div_threshold",
    margin_titles=True,
    sharex=True,
    sharey=True,
    height=1.5,
    aspect=1.8,
)
# Only ``binwidth`` is given: seaborn documents it as overriding ``bins``,
# so passing both was redundant.
facet.map_dataframe(
    sns.histplot,
    x="rewards",
    hue="n_tokens_bins",
    multiple="stack",
    palette="flare_r",
    stat="probability",
    binwidth=0.1,
)
facet.set_titles(col_template="$r_m$={col_name}", row_template="$t_d$={row_name}")


In [None]:
# Same thresholds as before, but the "default" sentinel leaves both info
# templates out of the config constructor call.
grid1 = {
    "min_reward_threshold": [None, 0.2, 0.5],
    "div_threshold": [None, 0.5, 0.2],
    "reward_info_template": ["default"],
    "source_info_template": ["default"],
    "system_prompt_path": ["system_prompts/vanilla_boxed.json"],
}
for config in grid_translate(grid1):
    extractor = SFTExtractor(config)
    new_dataset = extractor.extract(comp_dataset, prompt_dataset)

    # One path component per config field; dict-valued fields are written as
    # "default" instead of their full repr.
    path_parts = []
    for key, value in config.model_dump().items():
        label = "default" if isinstance(value, dict) else value
        path_parts.append(f"{key}_{label}")
    trace_path = Path("data/traces").joinpath(*path_parts, f"{DATASET}.jsonl")
    trace_path.parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(trace_path, new_dataset)

    n_conversations = sum(len(sample.conversations) for sample in new_dataset)
    print(f"N samples extracted for config:\n === {config.model_dump()} ===\n**{n_conversations}**")