In [1]:
import copy

import pathlib
import datasets
from ast import literal_eval
import more_itertools
import pandas as pd
import numpy as np
import gc
import tqdm
import polars as pl
import re
import rich
import rich.table
import functools
import rich.console
import enum
import sys
import itertools
import json
import hashlib
import collections
import time
import numpy as np
import contextlib
import os
os.environ["OPENINSTRUCT_PARSE_LATEX_BACKEND"] = "lark" 

sys.path.append("/home/mila/g/gagnonju/marglicot/with_open-instruct/open-instruct")
from open_instruct.math_utils import (
    last_boxed_only_string,
    remove_boxed,
    get_unnormalized_answer,
    normalize_final_answer,
    is_equiv,
    hendrycks_is_equiv
)


  from .autonotebook import tqdm as notebook_tqdm


In [8]:
class Mode(enum.Enum):
    gsm8k = "gsm8k"
    math = "math"

class LearningType(enum.Enum):
    sft = "sft"
    rejection = "rejection"
    zero_shot = "zero_shot"
    few_shot = "few_shot"


def verify_math_sample(model_output, ground_truth_answer):
    ground_truth_answer = last_boxed_only_string(ground_truth_answer)
    if ground_truth_answer is not None:
        try:
            ground_truth_answer = remove_boxed(ground_truth_answer)
        except AssertionError:
            ground_truth_answer = None
    if ground_truth_answer is None:
        raise NotImplementedError(f"Bad ground truth: {ground_truth_answer}")

    raw_answer = model_output
    # for math, more complex. We will try a few different ways to extract the answer.
    # this roughly follows 'flex em' in oe-eval-internal
    all_answers = []
    # First, try find answer in \boxed{}.
    boxed_answer = last_boxed_only_string(raw_answer)
    if boxed_answer is not None:
        try:
            boxed_answer = remove_boxed(boxed_answer)
        except AssertionError:
            boxed_answer = None
    if boxed_answer is not None:
        all_answers.append(boxed_answer)
    # Second, try to extract via minerva format.
    minerva_answer = normalize_final_answer(get_unnormalized_answer(raw_answer))
    if minerva_answer is not None and minerva_answer != "[invalidanswer]":
        all_answers.append(minerva_answer)
    # If nothing still, try to find the last latex-formatted answer
    if len(all_answers) == 0:
        dollars = [m.start() for m in re.finditer("\\$", raw_answer)]
        if len(dollars) > 1:
            # Add the answer between the second to last and last dollar sign
            answer = normalize_final_answer(raw_answer[dollars[-2] + 1 : dollars[-1]])
            all_answers.append(answer)
    # otherwise, just take the full output. Probably wont work, bit of a yolo.
    if len(all_answers) == 0:
        all_answers.append(normalize_final_answer(model_output))
    # now, compare all answers to ground truth.
    matched = False

    for answer in all_answers:
        if is_equiv(answer, ground_truth_answer):
            matched = True
            break
        elif hendrycks_is_equiv(answer, ground_truth_answer):
            matched = True
            break
    # if we got any match, we are good.
    return matched, all_answers



def verify_gsm8k_sample(model_output, ground_truth_answer, verbose=False):
    # model_output = model_output.split("<|assistant|>\n")[-1].strip()
    # gsm is easy: extract numbers, and then just compare last number with answer.
    # matches how we do eval.
    predictions = None
    # replace numbers like `x,xxx` with `xxxx`
    response = re.sub(r"(\d),(\d)", r"\1\2", model_output)
    numbers = re.findall(r"[-+]?\d*\.\d+|[-+]?\d+", response)
    if numbers:
        predictions = numbers[-1]
    else:
        predictions = response
    if verbose:
        print(f"predictions: {predictions}, ground_truth_answer: {ground_truth_answer}")
    return str(predictions).lower() == str(ground_truth_answer).lower(), predictions


pattern = re.compile(r"-?(\d{1,3}(,\d{3})*|\d+)")


def extract_answer_predicted(text):
    found = more_itertools.last(pattern.findall(text), None)
    return found


def extract_answer_gold(text):
    return text.rsplit("####", 1)[-1].strip()


def get_of_expected(container, idx, expected_size: int):
        assert len(container) == expected_size
        return container[idx]


functools.lru_cache(maxsize=None)
def load_parquet(path):
    return pl.read_parquet(path)




class FractionTimeSpent:
    def __init__(self):
        self._data = collections.defaultdict(list)
        self._start_time = collections.defaultdict(float)

    def start(self, key: str):
        self._start_time[key] = time.perf_counter()

    def stop(self, key: str):
        self._data[key].append(time.perf_counter() - self._start_time[key])
        self._start_time[key] = None

    @contextlib.contextmanager
    def time_block(self, key: str):
        self.start(key)
        yield
        self.stop(key)

    def get(self):
        sum_ = 0
        means = {}
        for key in self._data:
            means[key] = np.mean(self._data[key])
            sum_ += means[key]

        normalized_data = {}
        for key in self._data:
            normalized_data[key] = means[key] / sum_

        return normalized_data, means


def compute_score(path, mode: Mode, time_spent: FractionTimeSpent, compute_score: bool):

    if mode == Mode.gsm8k:
        verify = verify_gsm8k_sample
        extract = extract_answer_gold
    elif mode == Mode.math:
        verify = verify_math_sample
        extract = lambda x: x
    else:
        raise ValueError(f"Invalid mode: {MODE}")

    with time_spent.time_block("read_parquet"):
        ds = load_parquet(path)

    with time_spent.time_block("convert_to_series"):
        parsed_predictions = pl.Series([
            more_itertools.one(get_of_expected(more_itertools.one(literal_eval(pred)), 0, 2)) 
            for pred in ds["predictions"]
        ])
        gold = pl.Series([more_itertools.one(literal_eval(pred)) for pred in ds["gold"]])
        original_score = pl.Series([literal_eval(x)["qem"] for x in ds["metrics"]])


    results_root = path.parent.parent.parent.parent / "results" 
    assert results_root.exists(), results_root
    second_part = results_root / path.relative_to(results_root.parent / "details").parent.parent
    assert second_part.exists(), second_part
    assert (second_part / "meta_info.json").exists(), second_part / "meta_info.json"
    meta_info = json.loads((second_part / "meta_info.json").read_text())

    if compute_score:
        with time_spent.time_block("verify"):
            ongoing = []
            extracted_predictions = []  
            extracted_golds = []
            is_equal = []

            for i, (generated, gold_individual) in enumerate(more_itertools.zip_equal(parsed_predictions, gold)):
                if i % 1000 == 0:
                    print(f"i: {i / len(parsed_predictions):0.1%}")
                extracted_gold_i = extract(gold_individual).replace(",", "")
                
                # Check if we get the same answer for the gold from the verify function 
                # and from the reference gold. If not, this is a bug.
                # is_equal_golds, test_extracted_gold_as_pred = verify(
                #     model_output=extracted_gold_i, 
                #     ground_truth_answer=extracted_gold_i
                # )
                
                # if extracted_gold_i != test_extracted_gold_as_pred:
                #     print(f"extracted_gold_i: {extracted_gold_i}, test_extracted_gold_as_pred: {test_extracted_gold_as_pred}, is_equal: {is_equal_golds}")

                # Actually verify the prediction.
                verify_output, extracted_prediction_i = verify(model_output=generated, ground_truth_answer=extracted_gold_i)
                ongoing.append(extracted_prediction_i is not None and extracted_gold_i is not None and verify_output)
                extracted_predictions.append(extracted_prediction_i)
                extracted_golds.append(extracted_gold_i)
                is_equal.append(verify_output)
            score = np.mean(ongoing)
    
        main_output = pl.DataFrame(
            {
                "epoch": meta_info["epoch"], 
                "learning_rate": meta_info["cfg"]["learning_rate"],
                "score": score, 
                "original_score": original_score.mean(),
                "path": str(path), 
            }
        )
        predictions_output = pl.DataFrame({"predictions": parsed_predictions, "gold": gold, "extracted_predictions": extracted_predictions, "extracted_golds": extracted_golds, "is_equal": is_equal})
    else:
        main_output = pl.DataFrame({"original_score": original_score.mean(), "learning_rate": meta_info["cfg"]["learning_rate"], "epoch": meta_info["epoch"], "path": str(path)})
        predictions_output = None

    return main_output, predictions_output


In [9]:
MODE = Mode.math
LEARNING_TYPE = LearningType.sft
COMPUTE_SCORE = True

pl.Config(fmt_str_lengths=500, tbl_width_chars=10000, tbl_rows=100, tbl_cell_alignment="LEFT")


paths = [
    x for x in pathlib.Path("/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/").glob("**/*.parquet") 
    if f"|{MODE.value}|" in x.name and 
    "previous_outputs" not in str(x) and 
    LEARNING_TYPE.value in str(x)
]

print(f"Found {len(paths)} paths")

time_spent = FractionTimeSpent()
results = []
gen_gold = []
for i, path in enumerate(tqdm.tqdm(sorted(paths))):
    df, gen_gold_i = compute_score(path=path, mode=MODE, time_spent=time_spent, compute_score=COMPUTE_SCORE)
    results.append(df)
    gen_gold.append(gen_gold_i)
    print(results[-1])
    # rich.print(time_spent.get())
pl_results = pl.concat(results)
pl_results = pl_results.sort(["learning_rate", "epoch"])

Found 52 paths


  0%|          | 0/52 [00:00<?, ?it/s]

i: 0.0%
verify_output = False, extracted_prediction_i = ['Simplifythegivenexpression\\[\\frac{\\secx}{\\sinx}-\\frac{\\sinx}{\\cosx}.\\]']
verify_output = True, extracted_prediction_i = ['\\frac{\\sqrt{3}}{2}']
verify_output = False, extracted_prediction_i = ['9']
verify_output = False, extracted_prediction_i = ['\\sqrt{10} \\text{ inches if we pick properly}']
verify_output = False, extracted_prediction_i = ['1+2x\\gt1']
verify_output = True, extracted_prediction_i = ['\\frac{1}{3}']
verify_output = False, extracted_prediction_i = ['6']
verify_output = False, extracted_prediction_i = ['10']
verify_output = False, extracted_prediction_i = ['\\frac{1}{3}.']
verify_output = False, extracted_prediction_i = ['13k+x']
verify_output = False, extracted_prediction_i = ['f^{-1}(f(5))']
verify_output = False, extracted_prediction_i = ['t']
verify_output = False, extracted_prediction_i = ['2015']
verify_output = False, extracted_prediction_i = ['3141.75cubic\n\\]']
verify_output = False, extracte

  0%|          | 0/52 [00:18<?, ?it/s]


KeyboardInterrupt: 

In [8]:
results.sort(["learning_rate", "epoch"])

original_score,learning_rate,epoch,path
f64,f64,i64,str
0.0474,5e-05,0,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_15-2025-04-06_02-03-21_0_model/2025-04-06T18-43-37.600892/details_custom|math|0_2025-04-06T18-43-37.600892.parquet"""
0.0516,5e-05,0,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_5-2025-04-06_02-03-21_0_model/2025-04-06T18-40-47.464082/details_custom|math|0_2025-04-06T18-40-47.464082.parquet"""
0.0598,5e-05,1,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_5-2025-04-06_02-03-21_1_model/2025-04-06T18-46-43.835252/details_custom|math|0_2025-04-06T18-46-43.835252.parquet"""
0.0532,5e-05,1,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_15-2025-04-06_02-03-21_1_model/2025-04-06T18-49-34.874454/details_custom|math|0_2025-04-06T18-49-34.874454.parquet"""
0.0636,5e-05,2,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_5-2025-04-06_02-03-21_2_model/2025-04-06T18-49-52.203110/details_custom|math|0_2025-04-06T18-49-52.203110.parquet"""
0.059,5e-05,2,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_15-2025-04-06_02-03-21_2_model/2025-04-06T18-52-32.727044/details_custom|math|0_2025-04-06T18-52-32.727044.parquet"""
0.0636,5e-05,3,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_15-2025-04-06_02-03-21_3_model/2025-04-06T18-55-46.106898/details_custom|math|0_2025-04-06T18-55-46.106898.parquet"""
0.0616,5e-05,3,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_5-2025-04-06_02-03-21_3_model/2025-04-06T18-55-50.759057/details_custom|math|0_2025-04-06T18-55-50.759057.parquet"""
0.0616,5e-05,4,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_15-2025-04-06_02-03-21_4_model/2025-04-06T18-58-56.977535/details_custom|math|0_2025-04-06T18-58-56.977535.parquet"""
0.0632,5e-05,5,"""/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/sft_outputs_math/0_shot/details/_home_mila_g_gagnonju_scratch_marglicot_saves_sft_saves_cot_math_smollm2_1.7B_0_00005_15-2025-04-06_02-03-21_5_model/2025-04-06T19-01-51.111235/details_custom|math|0_2025-04-06T19-01-51.111235.parquet"""


In [4]:
pl.Config(fmt_str_lengths=500, tbl_width_chars=10000, tbl_rows=100, tbl_cell_alignment="LEFT")

rich.print(", ".join([f"[bold]{c}:[/] {results[c].dtype}" for c in results.columns]))
formatted_results = copy.deepcopy(results)
root = pl.col("path").map_elements(lambda x: str(pathlib.Path(x).parent.parent.parent.parent.relative_to("/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important")), return_dtype=pl.Utf8)
path = pl.col("path").map_elements(lambda x: str(pathlib.Path(x).relative_to("/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important")), return_dtype=pl.Utf8)
formatted_results = formatted_results.with_columns(path, root.alias("root"))
print(formatted_results)


shape: (52, 5)
┌────────────────┬───────────────┬───────┬───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬─────────────────────────┐
│ original_score ┆ learning_rate ┆ epoch ┆ path                                                                                                                                                                                                                                      ┆ root                    │
│ ---            ┆ ---           ┆ ---   ┆ ---                                                                                                                                                                                                                                       ┆ ---                     │
│ f64            ┆ f64           ┆ i64   ┆ str                        

In [5]:
# Write code to format 
# /home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/rejection_sampling_outputs_gsm8k_8/8_shot/details/_home_mila_g_gagnonju_scratch_rejection_sampling_saves_gsm8k_8_2025-03-30_03-38-45_epoch_2/2025-03-30T23-10-24.672696/details_custom|gsm8k|8_2025-03-30T23-10-24.672696.parquet 

formatted_results = copy.deepcopy(results)

if not "file_hash" in locals():
    file_hash = {}

# Extract components from path
def extract_path_info(path_str):
    path = pathlib.Path(path_str)
    
    # Find indices of key components
    epoch = re.findall(r"epoch_\d+", path_str)[0].split("_")[1]
    shot_count = re.findall(r"\d+_shot", path_str)[0].split("_")[0]
    dataset = re.findall(r"\|\w+\|\w+", path_str)[0].split("|")[1]
    date_and_time = re.findall(r"\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}", path_str)[0]

    results_root = path.parent.parent.parent
    hydra_config = json.loads((results_root.parent / "results" / path.relative_to(results_root).parent.parent / "hydra_config.json").read_text())
    assert dataset == hydra_config["dataset"]["name"], f"Dataset mismatch: {dataset} != {hydra_config['dataset']['name']}"
    assert int(shot_count) == hydra_config["few_shot_qty"], f"Shot count mismatch: {shot_count} != {hydra_config['few_shot_qty']}"
    learning_rate = str(hydra_config["training"]["learning_rate"])
    model_path = pathlib.Path(hydra_config["output_dir"]) / f"epoch_{epoch}" / "model.safetensors"
    assert model_path.exists(), f"Model path does not exist: {model_path}"
    if model_path not in file_hash:
        file_hash[model_path] = hashlib.md5(model_path.read_bytes()).hexdigest()
    model_md5 = file_hash[model_path]

    return pl.Series([dataset, shot_count, epoch, date_and_time, learning_rate, model_md5])

# Add new columns with path information
path_info = pl.col("path").map_elements(extract_path_info, return_dtype=pl.List(pl.Utf8))
formatted_results = formatted_results.with_columns([
    path_info.list.get(0).alias("dataset"),
    path_info.list.get(1).alias("shot_count"),
    path_info.list.get(2).alias("epoch"),
    path_info.list.get(3).alias("date_and_time"),
    path_info.list.get(4).alias("learning_rate"),
    path_info.list.get(5).alias("model_md5")
])
formatted_results = formatted_results.drop("path")


print(formatted_results)


ComputeError: IndexError: list index out of range

In [None]:
print(results["original_score"].max())
print(results["score"].max())

In [None]:
HIGHLIGHT_CRITERIA = 0.05

console = rich.console.Console(width=2000)

table = rich.table.Table(title="LightEval Results")
table.add_column("Index")
table.add_column("Path")
table.add_column("Score")
table.add_column("Original Score")

for i, row in enumerate(results.iter_rows(named=True)):
    diff = abs(row["score"] - row["original_score"])
    highlight = "[red bold]" if diff > HIGHLIGHT_CRITERIA else ""
    table.add_row(str(i), highlight + str(pathlib.Path(row["path"]).relative_to("/home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important").parent.parent), highlight + str(row["score"]), highlight + str(row["original_score"]), f"{highlight}{diff}")

console.print(table)


In [127]:

table = rich.table.Table(title="LightEval Results", show_lines=True)
table.add_column("Predictions")
table.add_column("Gold")
table.add_column("Extracted Predictions")
table.add_column("Extracted Gold")
table.add_column("Is Equal")

SHOW_N = 15

for row in itertools.islice(gen_gold[0].iter_rows(named=True), SHOW_N):
    table.add_row(str(row["predictions"]), str(row["gold"]), str(row["extracted_predictions"]), str(row["extracted_golds"]), str(row["is_equal"]))
# rich.print(table)

In [None]:
gen_gold[0]["is_equal"].mean()

In [None]:
!find /home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/rejection_sampling_outputs_gsm8k_8/8_shot/ -iname "*.json"

In [None]:
cat /home/mila/g/gagnonju/marglicot/light_eval_tests/all_eval_outputs_important/rejection_sampling_outputs_gsm8k_8/8_shot/results/_home_mila_g_gagnonju_scratch_rejection_sampling_saves_gsm8k_8_2025-03-30_03-38-45_epoch_2/hydra_config.json