In [None]:
# Enable autoreload so edits to imported project modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Move the working directory one level up; the relative paths used further
# down (MolGenOutput/..., data/...) are resolved from there.
# NOTE(review): this is not idempotent — re-running this cell climbs one more
# directory each time.
%cd ..

In [None]:
import pandas as pd
from mol_gen_docking.data.pydantic_dataset import read_jsonl, write_jsonl, Sample, Message, Conversation
from pathlib import Path
import jsonlines
from typing import Dict, Any, List, Tuple, Iterator
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from tqdm.auto import tqdm

from mol_gen_docking.evaluation.sft_extraction import SFTExtractionConfig, SFTExtractor, Completion

In [None]:
# --- Configuration -----------------------------------------------------------
SOURCE = "MiniMax-M2"
DATASET = "molgendata_train"
PATH = Path("MolGenOutput") / DATASET / SOURCE
PROMPT_PATH = Path(f"data/{DATASET.replace('_train', '')}/train_data/train_prompts_boxed.jsonl")
# Number of consecutive completions attributed to the same prompt.
# NOTE(review): presumably the number of samples generated per prompt — confirm.
N_COMPLETIONS_PER_PROMPT = 16

# --- Load prompts and scored completions -------------------------------------
prompt_dataset = read_jsonl(PROMPT_PATH)

comp_dataset = []

# Glob once (the directory was previously scanned twice) and sort the result:
# `Path.glob` yields files in arbitrary order, while the prompt lookup below is
# purely positional, so an unsorted listing makes the mapping irreproducible.
# NOTE(review): confirm that lexicographic file order matches generation order.
scored_files = sorted(PATH.glob("*_scored.jsonl"))
n_files = len(scored_files)
i = 0  # running completion index across all files


for path in tqdm(scored_files, total=n_files):
    with jsonlines.open(path) as reader:
        for obj in reader:
            # Tag each raw record with its source model and the identifier of
            # the prompt it belongs to before validating it as a Completion.
            obj["source"] = SOURCE
            obj["metadata"]["prompt_id"] = prompt_dataset[i // N_COMPLETIONS_PER_PROMPT].identifier
            obj = Completion.model_validate(obj)
            comp_dataset.append(obj)
            i += 1

print(f"Extracted {len(comp_dataset)} completions from {n_files} files.")

In [None]:
def plot_info(extractor: SFTExtractor) -> plt.Figure:
    """Draw four stacked reward histograms from an extractor's metadata.

    The metadata (``extractor.metadata.model_dump()``) is loaded into a
    DataFrame that must provide the columns ``rewards``, ``n_tokens`` and
    ``prompt_ids``. Every histogram is coloured by a binned version of
    ``n_tokens``.

    Panels, left to right:
        1. all rewards,
        2. per prompt, the first row after sorting rewards ascending (min),
        3. per prompt, the first row after sorting rewards descending (max),
        4. per prompt, the mean reward with the most frequent token bin.

    Args:
        extractor: Extractor whose ``metadata`` is plotted.

    Returns:
        The matplotlib figure holding the four panels.
    """
    token_edges = [0, 1000, 2000, 5000, 7000, 10000, 1e8]
    token_labels = ["<1k", "1k-2k", "2k-5k", "5k-7k", "7k-10k", ">10k"]

    df = pd.DataFrame(extractor.metadata.model_dump())
    # Left-closed bins: a value equal to an edge falls into the bin it opens.
    df["n_tokens_bins"] = pd.cut(
        df["n_tokens"], bins=token_edges, right=False, labels=token_labels
    )

    fig, axes = plt.subplots(1, 4, figsize=(16, 3))

    def extreme_per_prompt(ascending: bool) -> pd.DataFrame:
        # After sorting, the first row of each prompt group is its extreme.
        ordered = df.sort_values("rewards", ascending=ascending)
        picked = ordered.groupby("prompt_ids").agg(
            {"rewards": "first", "n_tokens_bins": "first"}
        )
        return picked.reset_index()

    mean_per_prompt = (
        df.groupby("prompt_ids")
        .agg({"rewards": "mean", "n_tokens_bins": lambda x: pd.Series.mode(x)[0]})
        .reset_index()
    )

    panels = [
        (df, "Reward"),
        (extreme_per_prompt(True), "Min Reward per Prompt"),
        (extreme_per_prompt(False), "Max Reward per Prompt"),
        (mean_per_prompt, "Mean Reward per Prompt"),
    ]

    for position, (ax, (frame, xlabel)) in enumerate(zip(axes.flatten(), panels)):
        sns.histplot(
            data=frame,
            x="rewards",
            hue="n_tokens_bins",
            palette="flare_r",
            multiple="stack",
            ax=ax,
            # Only the first panel carries the (shared) legend.
            legend=(position == 0),
        )
        ax.set_xlabel(xlabel)

    fig.tight_layout()
    return fig

In [None]:
def grid_translate(grid: Dict[str, List[Any]]) -> Iterator[SFTExtractionConfig]:
    """Yield one ``SFTExtractionConfig`` per point of a parameter grid.

    The Cartesian product of all value lists is enumerated. For each
    combination, every parameter whose value is the string ``"default"`` is
    left out of the constructor call, so the config's own default applies.

    Args:
        grid: Mapping from config field name to the list of values to sweep.
            An empty mapping yields a single config built with no arguments.

    Yields:
        A config for each combination, in ``itertools.product`` order.
    """
    keys = list(grid)
    # Iterating over `grid.values()` directly (rather than unpacking
    # `zip(*grid.items())`) keeps an empty grid from raising ValueError:
    # `itertools.product()` with no iterables yields one empty combination.
    for combination in itertools.product(*grid.values()):
        overrides = {
            key: value
            for key, value in zip(keys, combination)
            if value != "default"
        }
        yield SFTExtractionConfig(**overrides)


# Parameter grid swept below; each key is passed as a keyword argument to
# SFTExtractionConfig by `grid_translate`.
grid0 = {
    "min_reward_threshold": [None, 0.2, 0.5],
    "div_threshold": [None, 0.5],
    "reward_info_template": [{}],
    "source_info_template": [{}],
}

for config in grid_translate(grid0):
    extractor = SFTExtractor(config)
    new_dataset = extractor.extract(comp_dataset, prompt_dataset)
    config_values = config.model_dump()

    # Figure title: `name` was previously undefined (NameError on a fresh
    # kernel); build it from the same config dump that names the trace path.
    name = ", ".join(f"{key}={value}" for key, value in config_values.items())
    fig = plot_info(extractor)
    fig.suptitle(f"{name}")
    plt.show()
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(fig)

    # Save traces: one nested directory per config field, then the dataset file.
    trace_path = Path("data/traces")
    for key, value in config_values.items():
        trace_path /= f"{key}_{value}"
    trace_path /= f"{DATASET}.jsonl"
    trace_path.parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(trace_path, new_dataset)