In [None]:
# Hugging Face model id whose stored responses are re-scored below.
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"

# Input parquet with the original responses, and the destination for the re-scored copy.
filename_in, filename_out = (
    "../../data/out/cot_entropy/mmlu_qwen_3b.parquet",
    "../../data/out/cot_entropy/mmlu_qwen_3b_v2.parquet",
)

In [None]:
import pandas as pd

# Load the input frame whose responses fix_df re-scores.
df = pd.read_parquet(path=filename_in)

In [None]:
import json
import re

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from reasoning_fine_tune.entropy_estimation.estimate_cot_entropy import get_embeddings
from reasoning_fine_tune.utils.correctness import check_answer_correct_mmlu
from reasoning_fine_tune.utils.device import DEVICE_MAP

# Register tqdm's pandas integration so DataFrame.progress_apply (used in fix_df) shows a progress bar.
tqdm.pandas()


def fix_df(model_name, df):
    """Re-extract answers from stored responses and recompute correctness and answer embeddings.

    For each row, the text inside the first ``[[...]]`` of the model's response column is
    taken as the extracted answer. It is scored with ``check_answer_correct_mmlu`` and
    embedded with ``get_embeddings``.

    Args:
        model_name: Model id passed to ``AutoTokenizer`` / ``AutoModelForCausalLM.from_pretrained``.
        df: Frame containing a ``{model_type}_response`` column, where ``model_type`` is the
            loaded model's config type (e.g. ``qwen2``). The frame itself is not modified.

    Returns:
        A new DataFrame with ``{model_type}_ans_correct`` and ``{model_type}_answer_embeddings``
        recomputed. Rows whose response is not a string or contains no ``[[...]]`` answer get
        ``ans_correct=False`` and ``answer_embeddings=None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=DEVICE_MAP, torch_dtype=torch.bfloat16)

    # Column prefix comes from the config's model type (e.g. "qwen2"), not from the hub id.
    model_type = model.config.model_type
    print(model_type)

    field_response = f"{model_type}_response"
    field_ans_correct = f"{model_type}_ans_correct"
    field_answer_embeddings = f"{model_type}_answer_embeddings"

    # Non-greedy so the capture stops at the first closing "]]".
    answer_pattern = re.compile(r"\[\[(.+?)\]\]")

    def fix_row(row):
        # pandas does not support mutating the Series handed to apply(); work on a copy.
        row = row.copy()
        response = row[field_response]
        matched_group = answer_pattern.search(response) if isinstance(response, str) else None

        if matched_group is None:
            row[field_answer_embeddings] = None
            row[field_ans_correct] = False
            return row

        extracted_answer = matched_group.group(1)
        row[field_ans_correct] = check_answer_correct_mmlu(row, extracted_answer)

        answer_embeddings = get_embeddings(model, tokenizer, extracted_answer)
        # Clear any previously stored embedding rather than keeping a stale value.
        row[field_answer_embeddings] = None if answer_embeddings is None else json.dumps(answer_embeddings)
        return row

    fixed_df = df.progress_apply(fix_row, axis=1)
    # Before/after comparison of the correctness counts.
    print(df.value_counts(field_ans_correct))
    print(fixed_df.value_counts(field_ans_correct))
    return fixed_df


In [None]:
# Re-score every row. Only preview the result: the frame has ~12k rows plus a JSON embedding column.
df_fixed = fix_df(MODEL_NAME, df)
df_fixed.head()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

qwen2
qwen2_ans_correct
False    7482
True     4550
Name: count, dtype: int64


  8%|▊         | 1019/12032 [01:32<15:26, 11.88it/s]

In [None]:
# Persist the re-scored frame as a gzip-compressed parquet file.
df_fixed.to_parquet(path=filename_out, compression="gzip")