In [1]:
import os

# An empty CUDA_VISIBLE_DEVICES leaves no GPU visible to CUDA-based libraries
# loaded afterwards, so the rest of the notebook runs on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# Step up one directory so the `src.` imports below resolve (assumes the
# kernel starts in a sub-folder of the repository -- TODO confirm).
# The starting directory is remembered the first time this cell runs, so
# re-executing the cell in a live kernel lands in the same place instead of
# climbing one more level on every run.
if "INITIAL_CWD" not in globals():
    INITIAL_CWD = os.getcwd()
os.chdir(os.path.join(INITIAL_CWD, ".."))

# import sys
# sys.path.append("./vrevals")

import re
import ast
import json
import numpy as np
import pandas as pd
import swifter
import yaml
import time
import glob

import matplotlib.pyplot as plt

from collections import Counter
from src.utils.math_equivalence import is_math_equiv, extract_boxed_answer, extract_code_answer
from src.utils.pass_k_utils import estimate_pass_at_k

In [2]:
def load_all_samples(sample_dir, split):
    """Load every generation CSV of ``split`` and score each sampled response.

    All files matching ``{sample_dir}/{split}.generations*.csv`` are read and
    stacked into one frame in sorted file-name order. The CSVs must provide at
    least the ``response`` and ``gt_answer`` columns. Three columns are added:

    * ``pred_answer`` -- answer extracted from ``response`` with
      ``extract_boxed_answer`` when the text contains ``\\boxed``, otherwise
      with ``extract_code_answer``.
    * ``is_valid`` -- whether ``pred_answer`` is non-empty.
    * ``is_correct`` -- whether ``gt_answer`` and ``pred_answer`` parse to the
      same float; a value that cannot be parsed never counts as correct.

    Args:
        sample_dir: Directory holding the generation CSV files.
        split: Split name used as the file-name prefix (e.g. ``"test"``).

    Returns:
        pandas.DataFrame with the concatenated rows plus the added columns.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    pattern = f"{sample_dir}/{split}.generations*.csv"

    # glob returns paths in arbitrary order. Sorting keeps the row order
    # reproducible between runs, which matters because compute_passk only
    # keeps the first `cut_off_size` rows of each group.
    all_generation_csv = sorted(glob.glob(pattern))
    if not all_generation_csv:
        # Without this guard pd.concat fails with "No objects to concatenate",
        # which says nothing about the path that was searched.
        raise FileNotFoundError(f"No generation CSV files match {pattern!r}")

    all_generation_df = [pd.read_csv(p) for p in all_generation_csv]

    # Concatenate all dataframes in all_generation_df into a single dataframe
    generation_df = pd.concat(all_generation_df, ignore_index=True)
    # Missing responses are read as NaN; casting to str keeps the `in` test
    # and the extractors working on every row.
    generation_df['response'] = generation_df['response'].astype(str)

    def _extract_answer(response):
        if "\\boxed" in response:
            return extract_boxed_answer(response)
        return extract_code_answer(response)

    generation_df['pred_answer'] = generation_df.response.apply(_extract_answer)
    generation_df['is_valid'] = generation_df.pred_answer.apply(lambda x: len(x) > 0)

    def _as_float(value):
        # NaN compares unequal to everything (itself included), so any value
        # that cannot be parsed makes the row incorrect.
        try:
            return float(value)
        except (TypeError, ValueError):
            return float('nan')

    generation_df['is_correct'] = generation_df.swifter.progress_bar(True).apply(
        lambda row: _as_float(row['gt_answer']) == _as_float(row['pred_answer']),
        axis=1,
    )
    return generation_df


def compute_passk(generation_df, cut_off_size = 256):
    valid_generation_df = generation_df
    grouped_df = valid_generation_df.groupby(['question_id', 'prompt_id'])
    corrects_grouped = grouped_df.is_correct.apply(list).reset_index(name='corrects')
    corrects_grouped.corrects = corrects_grouped.corrects.apply(lambda x: x[:cut_off_size])
    corrects_grouped['num_samples'] = corrects_grouped.corrects.apply(len)
    corrects_grouped['num_math_equal'] = corrects_grouped.corrects.apply(sum)

    corrects_grouped = corrects_grouped.sort_values(by=['question_id'])

    min_num_samples = corrects_grouped['num_samples'].min()

    k_list = range(1,2000)
    detail_pass_at_k = {
        f"pass@{k}": estimate_pass_at_k(corrects_grouped['num_samples'].values, 
                                        corrects_grouped['num_math_equal'].values, k)
        for k in k_list
        if (min_num_samples >= k).all()
    }
    pass_at_k = {k: detail_pass_at_k[k].mean() for k in detail_pass_at_k}
    return detail_pass_at_k, pass_at_k

In [33]:
class Args:
    """Settings for the evaluation run inspected by this notebook."""
    dataset_name = "gsm8k"   # prefix of the job directory name
    split = "test"           # prefix of the prompt / generation CSV file names
    # The remaining fields are not read by the cells shown here.
    k_list = [1,4,8,32]
    subset_num = None
    step_by_step_prompt = True
    n_threads = 1
args = Args()

# Root folder holding all evaluation runs. The VREVAL_RUNS_DIR environment
# variable overrides it so the notebook can be used on another machine; when
# it is unset or empty the original location is used.
# root_dir = "runs"
root_dir = os.environ.get("VREVAL_RUNS_DIR") or "/home/nlp/hnn5071/vreval_runs"

# Other job directories; uncomment one (and comment out the active line) to switch.
# job_dir = f"runs/{args.dataset_name}.qwen2.5-0.5b-metaninstruct1"
# job_dir = f"{root_dir}/{args.dataset_name}.evolm-1b-exclusive_styles"
# job_dir = f"{root_dir}/{args.dataset_name}.qwen2.5-0.5b-metaninstruct1"
job_dir = f"{root_dir}/{args.dataset_name}.evolm-1b-mutual_styles"

# Folder with the sampled generations to score, and the prompt file of the split.
sampler_config_dir = f'{job_dir}/sft_ep8.direct/sample_1'
prompt_csv_path = f'{job_dir}/{args.split}.prompts.csv'
print(sampler_config_dir)

/home/nlp/hnn5071/vreval_runs/gsm8k.evolm-1b-mutual_styles/sft_ep8.direct/sample_1


In [34]:
# Read every generation CSV of the chosen split and mark each sample as correct or not.
generation_df = load_all_samples(sample_dir=sampler_config_dir, split=args.split)

Pandas Apply:   0%|          | 0/337664 [00:00<?, ?it/s]

In [35]:
# Score at most 256 samples per (question_id, prompt_id) group.
detail_pass_at_k, pass_at_k = compute_passk(generation_df, cut_off_size=256)

# The per-group results are converted with `.tolist()` so that the whole
# structure can later be written out with `json.dump`.
overall_results = dict(
    detail_pass_at_k={name: scores.tolist() for name, scores in detail_pass_at_k.items()},
    pass_at_k=pass_at_k,
)
final_metrics = dict(overall=overall_results)


In [36]:
# Display the mean pass@k for every k that compute_passk evaluated.
pass_at_k

{'pass@1': np.float64(0.10306399260803631),
 'pass@2': np.float64(0.17281395405015668),
 'pass@3': np.float64(0.22575716275620042),
 'pass@4': np.float64(0.2683814846882837),
 'pass@5': np.float64(0.30397630587864916),
 'pass@6': np.float64(0.33445581647116707),
 'pass@7': np.float64(0.3610369669283338),
 'pass@8': np.float64(0.38454418421219755),
 'pass@9': np.float64(0.4055635583271011),
 'pass@10': np.float64(0.42452772526729504),
 'pass@11': np.float64(0.44176577255491445),
 'pass@12': np.float64(0.4575341965759602),
 'pass@13': np.float64(0.47203698063343696),
 'pass@14': np.float64(0.48543911205239876),
 'pass@15': np.float64(0.49787597320322463),
 'pass@16': np.float64(0.509460042845325),
 'pass@17': np.float64(0.520285789210604),
 'pass@18': np.float64(0.5304333145733101),
 'pass@19': np.float64(0.5399711175695646),
 'pass@20': np.float64(0.5489582192790327),
 'pass@21': np.float64(0.5574458221555201),
 'pass@22': np.float64(0.5654786204124198),
 'pass@23': np.float64(0.5730958

In [37]:
# Name the file after the current local month, day, hour and minute (none of
# them zero-padded) and store it next to the samples that were scored.
# NOTE(review): the ':' in the name is not a valid file-name character on
# Windows file systems.
t = time.localtime()
metrics_json_name = f'metrics.{t.tm_mon}.{t.tm_mday},{t.tm_hour}:{t.tm_min}.json'
metrics_json_path = os.path.join(sampler_config_dir, metrics_json_name)

with open(metrics_json_path, 'w', encoding='utf-8') as json_file:
    json.dump(final_metrics, json_file, ensure_ascii=False, indent=4)