In [24]:
import os
import re
import glob
import json
import itertools
from tqdm import tqdm
import response_extractor
from collections import defaultdict
import multiprocessing as mp
from functools import partial

In [10]:
# Root folder of prediction dumps. process_model reads
# <result_dir>/<model_name>/<data_name>/<ckpt>/*add_gpt4_verification.jsonl.
# Set HIDDEN_COT_PREDICTION_DIR to point elsewhere instead of editing the
# hardcoded cluster path below (kept as the default).
prediction_folder = os.environ.get(
    "HIDDEN_COT_PREDICTION_DIR",
    "/mnt/pfs/zitao_team/tianqiaoliu/Project/teammates/hidden_cot/prediction",
)
result_dir = prediction_folder

In [17]:
def num_match(predictions, references, tol=1e-3):
    """Return True if both inputs evaluate to numbers within ``tol`` of each other.

    Each argument is formatted into ``"(...)"`` and passed to ``eval``, so
    arithmetic strings such as ``"1/2"`` compare equal to ``"0.5"``. Any failure
    to evaluate, subtract, or compare the two values (syntax errors, names that
    do not exist, non-numeric results, ...) yields False.

    Args:
        predictions: candidate answer, a number or an expression string.
        references: gold answer, a number or an expression string.
        tol: absolute tolerance (default 1e-3).

    Returns:
        bool: whether ``abs(eval(predictions) - eval(references)) < tol``.
    """
    # SECURITY: eval executes arbitrary Python. If predictions come from model
    # output (presumably they do — confirm), only run this on trusted data or
    # swap in a safe arithmetic parser.
    try:
        diff = eval(f"({predictions})") - eval(f"({references})")
        # bool() forces a plain True/False; array-like results raise here
        # (ambiguous truth value) and are treated as a mismatch.
        return bool(abs(diff) < tol)
    except Exception:
        # Catch Exception, not a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.
        return False


# Alias for the prediction root folder.
# NOTE(review): not read by any code visible in this notebook (process_model
# uses result_dir instead) — likely dead; confirm before removing.
dir_name = prediction_folder


def unique_verify_dicts(dicts, key):
    """Drop verification records whose inner ``input`` was already seen.

    Each record's ``meta`` field is a JSON string whose own ``meta`` field is
    another JSON string; the ``input`` value of that innermost object is the
    de-duplication key. The first record per key is kept, in input order.

    Args:
        dicts: iterable of verification records.
        key: accepted for signature compatibility but NOT used — the key is
            always the nested ``input`` field.

    Returns:
        list of the retained records.
    """
    seen_inputs = set()
    kept = []
    for record in dicts:
        outer_meta = json.loads(record["meta"])
        inner_meta = json.loads(outer_meta["meta"])
        dedup_value = inner_meta["input"]
        if dedup_value in seen_inputs:
            continue
        seen_inputs.add(dedup_value)
        kept.append(record)
    return kept


def unique_dicts(dicts, key):
    """Return the records of ``dicts`` with duplicate ``record[key]`` values removed.

    The first record for each distinct value is kept and input order is
    preserved. Values must be hashable.
    """
    seen_values = set()
    deduped = []
    for item in dicts:
        value = item[key]
        if value in seen_values:
            continue
        seen_values.add(value)
        deduped.append(item)
    return deduped


def unique_conversation_dicts(dicts):
    """De-duplicate records by their prompt text, keeping the first occurrence.

    ``record["meta"]`` is a JSON string. If it contains ``conversations``, the
    ``value`` of the last conversation turn is the key; otherwise its
    ``question`` field is. Both kinds of key share a single ``seen`` set, so a
    conversation and a plain question with identical text count as duplicates.
    """

    def _dedup_key(record):
        # Choose the text that identifies this record's prompt.
        meta = json.loads(record["meta"])
        if "conversations" in meta:
            return meta["conversations"][-1]["value"]
        return meta["question"]

    seen_keys = set()
    kept = []
    for record in dicts:
        dedup_key = _dedup_key(record)
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        kept.append(record)
    return kept


def _load_jsonl_records(pattern):
    """Load every JSON line from every file matching ``pattern``.

    Paths are read in sorted order: ``glob`` order is arbitrary, and callers
    de-duplicate keeping the first occurrence, so an unstable order would make
    the metrics non-reproducible. Blank lines are skipped and every file is
    closed after reading.
    """
    records = []
    for path in sorted(glob.glob(pattern)):
        with open(path, "r") as fh:
            records.extend(json.loads(line) for line in fh if line.strip())
    return records


def _score_verifications(records):
    """Extract the ``<answer>...</answer>`` verdict from each verification record.

    Returns ``(instances, correct)``: ``instances`` holds one
    ``{"correct": 0 or 1, "verification": str or None}`` dict per record, and
    ``correct`` counts records whose verdict is exactly ``"correct"``. Records
    whose ``response`` is not a string are appended to ``error.log`` and scored
    as having no verdict.
    """
    instances = []
    correct = 0
    for record in records:
        response = record["response"]
        if isinstance(response, str):
            # Only the text after the last verification marker is searched.
            # The lazy ``.*?`` stops at the first closing tag so several
            # answers on one line are not merged into one capture.
            match = re.findall(
                r"<answer>(.*?)</answer>",
                response.split("[assistant](#verification)")[-1],
            )
        else:
            with open("error.log", "a") as f:
                f.write(f"{record}\n")
            # Reset explicitly; otherwise `match` would be unbound or would
            # carry over the previous record's verdict.
            match = []
        verdict = match[0] if match else None
        is_correct = 1 if verdict == "correct" else 0
        instances.append({"correct": is_correct, "verification": verdict})
        correct += is_correct
    return instances, correct


def _write_metrics(model_name, ckpts, data_name, result):
    """Write ``result`` to ``metric_output/<model_name>/<ckpts>/<data_name>.json``."""
    out_dir = os.path.join("metric_output", model_name, ckpts)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{data_name}.json"), "w") as f:
        json.dump(result, f, indent=4, separators=(",", ": "), ensure_ascii=False)


def process_model(model_name, data_name):
    """Score GPT-4 verification outputs for every checkpoint of one model/dataset.

    For each entry ``<ckpt>`` in ``<result_dir>/<model_name>/<data_name>/``, reads
    all ``*add_gpt4_verification.jsonl`` files and de-duplicates them with
    ``unique_verify_dicts``. Records with a missing or falsy ``response`` are
    dropped. Accuracy metrics are then written to
    ``metric_output/<model_name>/<ckpt>/<data_name>.json``. A checkpoint with no
    usable records produces no output file.

    Args:
        model_name: sub-folder of ``result_dir`` naming the model.
        data_name: sub-folder naming the dataset (e.g. ``"GSM8K"``).

    Side effects:
        Writes metric JSON files and may append unparseable records to
        ``error.log`` in the current working directory.
    """
    model_data_folder = os.path.join(result_dir, model_name, data_name)
    for ckpts in os.listdir(model_data_folder):
        model_data_ckpt = os.path.join(model_data_folder, ckpts)
        records = _load_jsonl_records(
            os.path.join(model_data_ckpt, "*add_gpt4_verification.jsonl")
        )
        records = unique_verify_dicts(records, "question")
        before = len(records)
        # Records with a missing/None/empty response cannot be scored.
        records = [line for line in records if line.get("response", None)]
        after = len(records)
        if not records:
            continue
        instances, correct = _score_verifications(records)
        # Built fresh per checkpoint so one checkpoint's metrics are never
        # re-written under a later checkpoint that has no records.
        result = {
            "gpt4_verification": {
                "instances": instances,
                "acc": correct / after,
                "count": after,
                "response_null_count": before - after,
            }
        }
        _write_metrics(model_name, ckpts, data_name, result)

In [18]:
# Baseline llama2 fine-tunes: every size × CoT variant, in the order
# 7B-without, 7B-normal, 13B-without, 13B-normal.
model_list = [
    f"llama2-{size}-gsm8k-{variant}"
    for size in ("7B", "13B")
    for variant in ("without-cot", "normal-cot")
]
others_dataset_list = ["GSM8K"]

In [19]:
# Score every (model, dataset) pair; process_model writes its metrics to disk.
for one_model_name, one_data_name in itertools.product(
    model_list, others_dataset_list
):
    process_model(one_model_name, one_data_name)

In [25]:
# Second pass: a single HCOT checkpoint, scored on the same dataset.
model_list = ["HCOT_7b_mse_10_bs_128_ckpt_100"]
others_dataset_list = ["GSM8K"]

In [26]:
# Score the HCOT checkpoint(s) on each dataset listed above.
pairs = itertools.product(model_list, others_dataset_list)
for one_model_name, one_data_name in pairs:
    process_model(one_model_name, one_data_name)