In [17]:
dataset_dir = "./data_en"
predict_path = "./infer_results/SFT_finqa_test_en_v2.json"

In [18]:
import json
import os

def load_split(split_name):
    file_path = os.path.join(dataset_dir, f"{split_name}.json")

    if not os.path.exists(file_path):
        return []

    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict):
        data = [data]

    result = []
    for item in data:
        example = {
            "id": item.get("id", ""),
            "pre_text": item.get("pre_text", ""),
            "post_text": item.get("post_text", ""),
            "table": item.get("table", {}),
            "question": item.get("question", ""),
            "program": item.get("program", ""),
            "exe_ans": item.get("exe_ans", ""),
        }
        result.append(example)

    print(f"{len(result)} sample {split_name}")
    return result

# Load t·ª´ng t·∫≠p ri√™ng bi·ªát
test_data = load_split("test")

1147 sample test


In [19]:
import json

with open(predict_path, 'r') as f:
    predictions = json.load(f)
len(predictions)

1147

In [20]:
import re
from thefuzz import fuzz

def extract_llm_output(response: str):
    pattern = re.compile(
        r"```(?:\w*\n)?(.*?)```",
        re.IGNORECASE | re.DOTALL
    )

    match = pattern.search(response)

    if match:
        return match.group(1).strip()
    elif response:
        return response.strip()
    return None

def extract_first_argument_from_table_operation(table_operation):
    matches = re.findall(r'\(([^,]+),', table_operation)
    return matches[0] if matches else None

def remove_multiply_100(major_output, end_operation):
    if end_operation.startswith('multiply') and end_operation.endswith('100)'):
        major_output = major_output.replace(', ' + end_operation, '')
    return major_output

def remove_special_characters(text):
    return text.replace('"', '').replace("'", '').replace('`', '')

def has_nested_function(expr):
    func_depth = 0
    n = len(expr)
    i = 0

    while i < n:
        if expr[i] == '(':
            j = i - 1
            while j >= 0 and (expr[j].isalnum() or expr[j] == '_'):
                j -= 1
            name = expr[j+1:i]

            is_func = bool(name and (name[0].isalpha() or name[0] == '_'))

            if is_func:
                if func_depth > 0:
                    return True
                func_depth += 1
            i += 1
        elif expr[i] == ')':
            if func_depth > 0:
                func_depth -= 1
            i += 1
        else:
            i += 1

    return False


def correct_table_column(sequence_of_operations, table):
    if 'table' not in sequence_of_operations:
        return sequence_of_operations
    else:
        operations = re.findall(r'\b\w+\([^()]*\([^()]*\)[^()]*\)|\b\w+\([^()]*\)', sequence_of_operations)
        for i in range(len(operations)):
            operation = operations[i]

            if ',' not in operation:
                continue
            elif 'table' not in operation:
                continue
            elif has_nested_function(operation):
                continue
            else:
                first_arg = extract_first_argument_from_table_operation(operation)
                possible_headers = table[0] + [i[0] for i in table[1:]]
                if first_arg in possible_headers:
                    continue
                else:
                    scores = [fuzz.partial_ratio(first_arg, header) for header in possible_headers]
                    max_score_index = scores.index(max(scores)) if scores else -1
                    correct_arg = possible_headers[max_score_index] if max_score_index != -1 else first_arg
                    operations[i] = operation.replace(first_arg, correct_arg)
        sequence_of_operations = ", ".join(operations)
    return sequence_of_operations

def parse_operations(expr):
    pattern = r"(?<![a-zA-Z])(add|multiply)\(([^)]*)\)"
    matches = re.findall(pattern, expr)

    ops = []
    for func, args_str in matches:
        args = [a.strip() for a in args_str.split(",") if a.strip()]
        try:
            args = sorted([float(a) if re.match(r"^-?\d+(\.\d+)?$", a) else a for a in args])
        except Exception:
            args = sorted(args)
        ops.append((func, tuple(args)))
    return ops

def compare_expr(expr1, expr2):
    if len(expr1) != len(expr2):
        return False
    ops1 = parse_operations(expr1)
    ops2 = parse_operations(expr2)
    if len(ops1) == 0:
        return False
    return set(ops1) == set(ops2)

In [21]:
from tqdm import tqdm

def preprocess(predictions, groundtruth_data):
    # T·∫°o dictionary ƒë·ªÉ tra c·ª©u ground truth theo id
    gt_dict = {item["id"]: str(item["program"]).strip() for item in groundtruth_data}

    for pred in tqdm(predictions):
        ex_id = pred["id"]
        # L·∫•y output m√¥ h√¨nh v√† x·ª≠ l√Ω ban ƒë·∫ßu
        pred_program = extract_llm_output(str(pred.get("model_answer", "")).strip())

        if not pred_program:
            pred['predicted'] = pred_program
            continue

        # Chu·∫©n h√≥a gi√° tr·ªã v√† lo·∫°i b·ªè k√Ω t·ª± ƒë·∫∑c bi·ªát
        pred_program = pred_program.replace('100.00', '100')
        pred_program = remove_special_characters(pred_program)
        try:
            pred_program = correct_table_column(pred_program, pred['table'])
        except Exception:
            pass

        # N·∫øu c√≥ ground truth, √°p d·ª•ng c√°c ph√©p so s√°nh n√¢ng cao
        if ex_id in gt_dict:
            gt_program = gt_dict[ex_id]

            if (pred_program == gt_program or 
                compare_expr(pred_program, gt_program) or
                pred_program == remove_multiply_100(gt_program, gt_program.split('), ')[-1].strip()) or
                remove_multiply_100(pred_program, pred_program.split('), ')[-1].strip()) == gt_program):
                pred['predicted'] = gt_program
            else:
                pred['predicted'] = pred_program

    return predictions


In [22]:
def program_tokenization(original_program):
    program = []
    cur_tok = ''
    bracket_level = 0
    i = 0
    while i < len(original_program):
        c = original_program[i]
        if c == '(':
            bracket_level += 1
            if bracket_level == 1:
                cur_tok += c
                program.append(cur_tok.strip())
                cur_tok = ''
            else:
                cur_tok += c
        elif c == ')':
            if bracket_level == 1:
                if cur_tok.strip() != '':
                    program.append(cur_tok.strip())
                    cur_tok = ''
                program.append(')')
            else:
                cur_tok += c
            bracket_level -= 1
        elif c == ',':
            if bracket_level == 0:
                if cur_tok.strip() != '':
                    program.append(cur_tok.strip())
                    cur_tok = ''
            elif bracket_level == 1:
                if cur_tok.strip() != '':
                    program.append(cur_tok.strip())
                    cur_tok = ''
            else:
                cur_tok += c
        else:
            cur_tok += c
        i += 1
    if cur_tok.strip() != '':
        program.append(cur_tok.strip())
    program.append('EOF')
    return program

In [24]:
predictions = preprocess(predictions, test_data)
for item in predictions:
    item['predicted'] = program_tokenization(item['predicted'])
    item['gold_program'] = program_tokenization(item['program'])

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1147/1147 [00:00<00:00, 47463.17it/s]


In [26]:
from sympy import simplify
import re
from collections import defaultdict

all_ops = ["add", "subtract", "multiply", "divide", "exp", "greater",
           "table_max", "table_min", "table_sum", "table_average"]

def str_to_num(text):
    """Convert string to number, handling percentages and special formats"""
    text = text.replace(",", "").strip()
    try:
        num = float(text)
    except ValueError:
        if "%" in text:
            text = text.replace("%", "")
            try:
                num = float(text)
                num = num / 100.0
            except ValueError:
                num = "n/a"
        elif "const" in text:
            text = text.replace("const_", "")
            if text == "m1":
                text = "-1"
            try:
                num = float(text)
            except ValueError:
                num = "n/a"
        else:
            num = "n/a"
    return num

def process_row(row_in):
    """Process a table row to extract numerical values"""
    row_out = []
    invalid_flag = 0

    for num in row_in:
        num = num.replace("$", "").strip()
        num = num.split("(")[0].strip()

        num = str_to_num(num)

        if num == "n/a":
            invalid_flag = 1
            break

        row_out.append(num)

    if invalid_flag:
        return "n/a"

    return row_out

def safe_parse_list(values):
    nums = []
    for v in values:
        try:
            if v is None:
                continue
            v = str(v).strip()
            if v == "" or v.lower() in ["none", "nan", "n/a", "-", "na"]:
                continue

            matches = re.findall(r"[-+]?\d+(?:[\.,]\d+)?", v)
            if not matches:
                continue

            for m in matches:
                m = m.replace(",", ".")
                try:
                    nums.append(float(m))
                except ValueError:
                    continue
        except Exception:
            continue
    return nums

def eval_program(program, table):
    """
    Calculate the numerical results of the program
    Returns: (invalid_flag, result)
    """
    invalid_flag = 0
    this_res = "n/a"

    try:
        # Make a copy to avoid modifying original
        program = list(program)

        # Check if program ends with EOF
        if not program or program[-1] != "EOF":
            return 1, "n/a"

        program = program[:-1]  # remove EOF

        # Check structure validity
        if len(program) % 4 != 0:
            return 1, "n/a"

        for ind, token in enumerate(program):
            if ind % 4 == 0:
                if token.strip("(") not in all_ops:
                    return 1, "n/a"
            elif ind % 4 == 3:
                if token != ")":
                    return 1, "n/a"

        # Parse operations directly from token array
        # Structure: [op(, arg1, arg2, ), op(, arg1, arg2, ), ...]
        res_dict = {}

        for step_idx in range(0, len(program), 4):
            ind = step_idx // 4

            # Extract operation and arguments from tokens
            op = program[step_idx].strip("(")
            arg1 = program[step_idx + 1].strip()
            arg2 = program[step_idx + 2].strip()
            # program[step_idx + 3] should be ")"

            if op in ["add", "subtract", "multiply", "divide", "exp", "greater"]:
                # Resolve arg1
                if "#" in arg1:
                    arg1_ind = int(arg1.replace("#", ""))
                    if arg1_ind not in res_dict or arg1_ind >= ind:
                        invalid_flag = 1
                        break
                    arg1 = res_dict[arg1_ind]
                else:
                    arg1 = str_to_num(arg1)
                    if arg1 == "n/a":
                        invalid_flag = 1
                        break

                # Resolve arg2
                if "#" in arg2:
                    arg2_ind = int(arg2.replace("#", ""))
                    if arg2_ind not in res_dict or arg2_ind >= ind:
                        invalid_flag = 1
                        break
                    arg2 = res_dict[arg2_ind]
                else:
                    arg2 = str_to_num(arg2)
                    if arg2 == "n/a":
                        invalid_flag = 1
                        break

                # Execute operation
                if op == "add":
                    this_res = arg1 + arg2
                elif op == "subtract":
                    this_res = arg1 - arg2
                elif op == "multiply":
                    this_res = arg1 * arg2
                elif op == "divide":
                    if arg2 == 0:
                        invalid_flag = 1
                        break
                    this_res = arg1 / arg2
                elif op == "exp":
                    this_res = arg1 ** arg2
                elif op == "greater":
                    this_res = "yes" if arg1 > arg2 else "no"

                res_dict[ind] = this_res

            elif "table" in op:
                # --- 1. Build row dictionary ---
                table_dict = {}
                for row in table:
                    if len(row) > 0:
                        table_dict[row[0]] = row[1:]

                # --- 2. Build header (column names) ---
                header = table[0][1:] if len(table) > 0 else []

                # arg1 l√† t√™n c·ªôt/h√†ng (c√≥ th·ªÉ ch·ª©a d·∫•u ngo·∫∑c)
                target_name = arg1.strip() if arg1 else ""
                cal_values = None

                # --- 3. N·∫øu target l√† h√†ng ---
                if target_name in table_dict:
                    cal_values = safe_parse_list(table_dict[target_name])

                # --- 4. N·∫øu target l√† c·ªôt ---
                elif target_name in header:
                    col_index = header.index(target_name) + 1  # v√¨ c·ªôt 0 l√† t√™n h√†ng
                    col_values = []
                    for i in range(1, len(table)):
                        if len(table[i]) > col_index:
                            val = table[i][col_index]
                            col_values.append(val)
                    cal_values = safe_parse_list(col_values)

                # --- 5. N·∫øu kh√¥ng t√¨m th·∫•y ---
                else:
                    invalid_flag = 1
                    break

                # --- 6. N·∫øu kh√¥ng c√≥ gi√° tr·ªã h·ª£p l·ªá ---
                if not cal_values:
                    invalid_flag = 1
                    break

                # --- 7. T√≠nh to√°n k·∫øt qu·∫£ ---
                if op == "table_max":
                    this_res = max(cal_values)
                elif op == "table_min":
                    this_res = min(cal_values)
                elif op == "table_sum":
                    this_res = sum(cal_values)
                elif op == "table_average":
                    this_res = sum(cal_values) / len(cal_values)
                else:
                    invalid_flag = 1
                    break

                res_dict[ind] = this_res

        if invalid_flag:
            return 1, "n/a"

        # Round numerical results
        if this_res != "yes" and this_res != "no" and this_res != "n/a":
            # Don't round here - keep full precision for comparison
            pass

    except Exception as e:
        invalid_flag = 1
        this_res = "n/a"

    return invalid_flag, this_res

def normalize(tok):
    if tok is None:
        return ""
    if tok in ["none", "None"]:
        tok = "none"
    if tok.startswith("const_"):
        tok = tok.replace("const_", "")
    return tok

def equal_program(program1, program2):
    """
    Check if two programs are symbolically equivalent
    program1: gold program
    program2: predicted program
    """
    try:
        # Make copies
        program1 = list(program1)
        program2 = list(program2)

        # Remove EOF
        if program1 and program1[-1] == "EOF":
            program1 = program1[:-1]
        if program2 and program2[-1] == "EOF":
            program2 = program2[:-1]

        # Nomalize
        program1 = [normalize(tok) for tok in program1]
        program2 = [normalize(tok) for tok in program2]

        # Quick exact match check first
        if program1 == program2:
            return True

        # Check structure of program2
        if len(program2) % 4 != 0:
            return False

        for ind, token in enumerate(program2):
            if ind % 4 == 0:
                if token.strip("(") not in all_ops:
                    return False
            elif ind % 4 == 3:
                if token != ")":
                    return False

        # Build symbolic map from program1
        sym_map = {}
        program1_str = "|".join(program1)
        steps1 = program1_str.split(")")[:-1]

        sym_ind = 0
        step_dict_1 = {}

        for ind, step in enumerate(steps1):
            step = step.strip()

            if len(step.split("(")) > 2:
                continue

            op = step.split("(")[0].strip("|").strip()
            args = step.split("(")[1].strip("|").strip()

            arg_list = args.split("|")
            if len(arg_list) != 2:
                continue

            arg1 = arg_list[0].strip()
            arg2 = arg_list[1].strip()

            step_dict_1[ind] = step

            if "table" in op:
                # For table operations, treat the entire step as a single variable
                if step not in sym_map:
                    sym_map[step] = "a" + str(sym_ind)
                    sym_ind += 1
            else:
                if "#" not in arg1:
                    if arg1 not in sym_map:
                        sym_map[arg1] = "a" + str(sym_ind)
                        sym_ind += 1

                if "#" not in arg2:
                    if arg2 not in sym_map:
                        sym_map[arg2] = "a" + str(sym_ind)
                        sym_ind += 1

        # Validate program2 against symbolic map
        program2_str = "|".join(program2)
        steps2 = program2_str.split(")")[:-1]
        step_dict_2 = {}

        for ind, step in enumerate(steps2):
            step = step.strip()

            if len(step.split("(")) > 2:
                return False

            op = step.split("(")[0].strip("|").strip()
            args = step.split("(")[1].strip("|").strip()

            arg_list = args.split("|")
            if len(arg_list) != 2:
                return False

            arg1 = arg_list[0].strip()
            arg2 = arg_list[1].strip()

            step_dict_2[ind] = step

            if "table" in op:
                # For table operations, must match exact step
                if step not in sym_map:
                    return False
            else:
                if "#" not in arg1:
                    if arg1 not in sym_map:
                        return False
                else:
                    arg1_ind = int(arg1.strip("#"))
                    if arg1_ind >= ind:
                        return False

                if "#" not in arg2:
                    if arg2 not in sym_map:
                        return False
                else:
                    arg2_ind = int(arg2.strip("#"))
                    if arg2_ind >= ind:
                        return False

        # If both programs have only table operations (single step), they're equal if steps match
        if len(steps1) == 1 and len(steps2) == 1:
            return steps1[0].strip() == steps2[0].strip()

        # Recursive function to build symbolic expression
        def symbol_recur(step, step_dict):
            if len(step.split("(")) > 2:
                return ""

            step = step.strip()
            op = step.split("(")[0].strip("|").strip()
            args = step.split("(")[1].strip("|").strip()

            arg_list = args.split("|")
            if len(arg_list) != 2:
                return ""

            arg1 = arg_list[0].strip()
            arg2 = arg_list[1].strip()

            if "table" in op:
                return sym_map.get(step, "")

            # Resolve arg1
            if "#" in arg1:
                arg1_ind = int(arg1.replace("#", ""))
                arg1_part = symbol_recur(step_dict[arg1_ind], step_dict)
            else:
                arg1_part = sym_map.get(arg1, "")

            # Resolve arg2
            if "#" in arg2:
                arg2_ind = int(arg2.replace("#", ""))
                arg2_part = symbol_recur(step_dict[arg2_ind], step_dict)
            else:
                arg2_part = sym_map.get(arg2, "")

            if not arg1_part or not arg2_part:
                return ""

            if op == "add":
                return f"( {arg1_part} + {arg2_part} )"
            elif op == "subtract":
                return f"( {arg1_part} - {arg2_part} )"
            elif op == "multiply":
                return f"( {arg1_part} * {arg2_part} )"
            elif op == "divide":
                return f"( {arg1_part} / {arg2_part} )"
            elif op == "exp":
                return f"( {arg1_part} ** {arg2_part} )"
            elif op == "greater":
                return f"( {arg1_part} > {arg2_part} )"

            return ""

        # Build and compare symbolic programs
        sym_prog1 = symbol_recur(steps1[-1], step_dict_1)
        sym_prog2 = symbol_recur(steps2[-1], step_dict_2)

        if not sym_prog1 or not sym_prog2:
            return False

        sym_prog1 = simplify(sym_prog1, evaluate=False)
        sym_prog2 = simplify(sym_prog2, evaluate=False)

        return sym_prog1 == sym_prog2

    except Exception as e:
        return False

def evaluate_result(predictions, gold_data):
    """
    Evaluate predictions against gold data.
    Returns: (execution_accuracy, program_accuracy)
    """
    gold_dict = {item["id"]: item for item in gold_data}

    exe_correct = 0
    prog_correct = 0
    total = 0

    # --- Error buckets ---
    gold_predict = []
    false_predict = []
    exe_errors = []          # Execution-level errors
    prog_errors = []         # Program-level mismatches
    prog_diff_same_result = []  # Different programs, same execution result
    invalid_programs = []    # Execution crashed
    type_errors = []         # Type conversion / invalid data errors

    for pred in predictions:
        pred_id = pred["id"]
        if pred_id not in gold_dict:
            print(f"Warning: ID {pred_id} not found in gold data")
            continue

        gold_item = gold_dict[pred_id]
        table = gold_item["table"]
        gold_ans = gold_item["exe_ans"]
        gold_prog = program_tokenization(gold_item["program"])
        pred_prog = pred["predicted"]

        # ====== EXECUTION CHECK ======
        invalid_flag, pred_ans = eval_program(pred_prog, table)

        exec_correct_flag = False  # track if result is correct

        if invalid_flag == 1:
            invalid_programs.append({
                "id": pred_id,
                "pred_prog": pred_prog,
                "gold_prog": gold_prog,
                "reason": "Program execution failed"
            })
            false_predict.append(pred["id"])
        else:
            try:
                if pred_ans in ["yes", "no"]:
                    if pred_ans == gold_ans:
                        exe_correct += 1
                        gold_predict.append(pred["id"])
                        exec_correct_flag = True
                    else:
                        exe_errors.append({
                            "id": pred_id,
                            "error_type": "BOOLEAN_MISMATCH",
                            "gold_ans": gold_ans,
                            "pred_ans": pred_ans,
                            "gold_prog": gold_prog,
                            "pred_prog": pred_prog
                        })
                        false_predict.append(pred["id"])
                else:
                    pred_ans_float = float(pred_ans)
                    gold_ans_float = float(gold_ans)
                    abs_diff = abs(pred_ans_float - gold_ans_float)
                    rel_diff = abs_diff / (abs(gold_ans_float) + 1e-10)
                    pred_sign = 1 if pred_ans_float >= 0 else -1
                    gold_sign = 1 if gold_ans_float >= 0 else -1

                    if pred_sign != gold_sign and abs(abs(pred_ans_float) - abs(gold_ans_float)) < 0.1:
                        exe_errors.append({
                            "id": pred_id,
                            "error_type": "SIGN_MISMATCH",
                            "gold_ans": gold_ans,
                            "pred_ans": pred_ans,
                            "gold_prog": gold_prog,
                            "pred_prog": pred_prog
                        })
                        false_predict.append(pred["id"])
                    elif abs_diff < 0.1 or rel_diff < 1e-4:
                        exe_correct += 1
                        gold_predict.append(pred["id"])
                        exec_correct_flag = True
                    else:
                        ratio = abs(pred_ans_float) / (abs(gold_ans_float) + 1e-10)
                        if 90 < ratio < 110 or 0.009 < ratio < 0.011:
                            normalized_pred = pred_ans_float / 100 if ratio > 1 else pred_ans_float * 100
                            normalized_diff = abs(normalized_pred - gold_ans_float)
                            rel_diff = normalized_diff / (abs(gold_ans_float) + 1e-10)
                            if normalized_diff < 0.1 or rel_diff < 0.005:
                                exe_correct += 1
                                gold_predict.append(pred["id"])
                                exec_correct_flag = True
                            else:
                                exe_errors.append({
                                    "id": pred_id,
                                    "error_type": "PERCENTAGE_NORM_FAILED",
                                    "gold_ans": gold_ans,
                                    "pred_ans": pred_ans,
                                    "normalized_pred": normalized_pred,
                                    "abs_diff": normalized_diff,
                                    "rel_diff": rel_diff,
                                    "gold_prog": gold_prog,
                                    "pred_prog": pred_prog
                                })
                                false_predict.append(pred["id"])
                        else:
                            exe_errors.append({
                                "id": pred_id,
                                "error_type": "MAGNITUDE_MISMATCH",
                                "gold_ans": gold_ans,
                                "pred_ans": pred_ans,
                                "abs_diff": abs_diff,
                                "rel_diff": rel_diff,
                                "ratio": ratio,
                                "gold_prog": gold_prog,
                                "pred_prog": pred_prog
                            })
                            false_predict.append(pred["id"])
            except (ValueError, TypeError) as e:
                type_errors.append({
                    "id": pred_id,
                    "error": str(e),
                    "gold_ans": gold_ans,
                    "pred_ans": pred_ans,
                    "gold_prog": gold_prog,
                    "pred_prog": pred_prog
                })
                if str(pred_ans) == str(gold_ans):
                    exe_correct += 1
                    gold_predict.append(pred["id"])
                    exec_correct_flag = True
                else:
                    false_predict.append(pred["id"])

        # ====== PROGRAM CHECK ======
        if equal_program(gold_prog, pred_prog):
            prog_correct += 1
        else:
            # N·∫øu ch∆∞∆°ng tr√¨nh kh√°c nh∆∞ng cho c√πng k·∫øt qu·∫£
            if exec_correct_flag:
                prog_diff_same_result.append({
                    "id": pred_id,
                    "error_type": "PROGRAM_DIFFERENT_BUT_CORRECT_EXECUTION",
                    "gold_prog": gold_prog,
                    "pred_prog": pred_prog,
                    "gold_ans": gold_ans,
                    "pred_ans": pred_ans
                })
            else:
                prog_errors.append({
                    "id": pred_id,
                    "error_type": "PROGRAM_MISMATCH",
                    "gold_prog": gold_prog,
                    "pred_prog": pred_prog
                })

        total += 1

    # ====== SUMMARY ======
    exe_acc = exe_correct / total if total > 0 else 0
    prog_acc = prog_correct / total if total > 0 else 0

    print("\n" + "=" * 70)
    print("EVALUATION SUMMARY")
    print("=" * 70)
    print(f"Total examples: {total}")
    print(f"Execution Accuracy: {exe_acc:.4f} ({exe_correct}/{total})")
    print(f"Program Accuracy: {prog_acc:.4f} ({prog_correct}/{total})")

    # ====== EXECUTION ERRORS ======
    print("\n" + "=" * 70)
    print("EXECUTION-LEVEL ERRORS")
    print("=" * 70)
    print(f"‚ùå Invalid programs: {len(invalid_programs)}")
    print(f"‚ùå Result execution errors: {len(type_errors)}")
    print(f"‚ùå Logic execution errors: {len(exe_errors)}")

    if invalid_programs:
        print("\nüîπ Invalid Program Samples:")
        for e in invalid_programs[:5]:
            print(f"  - ID: {e['id']}")
            print(f"    Gold: {e['pred_prog']} | Pred: {e['gold_prog']}")
            print(f"    Reason: {e['reason']}")

    if type_errors:
        print("\nüîπ Result execution Error Samples:")
        for e in type_errors[:5]:
            print(f"  - ID: {e['id']}")
            print(f"    Error: {e['error']}")
            print(f"    Gold: {e['gold_ans']} | Pred: {e['pred_ans']}")

    if exe_errors:
        grouped = defaultdict(list)
        for e in exe_errors:
            grouped[e["error_type"]].append(e)
        print("\nüîπ Logic Execution Error Breakdown:")
        for t, lst in grouped.items():
            print(f"  ‚Ä¢ {t}: {len(lst)} cases")
            for e in lst[:5]:
                print(f"    - ID: {e['id']} | Gold: {e['gold_ans']} | Pred: {e['pred_ans']}")

    # ====== PROGRAM ERRORS ======
    print("\n" + "=" * 70)
    print("PROGRAM-LEVEL ERRORS")
    print("=" * 70)
    print(f"‚ùå Different program but same result: {len(prog_diff_same_result)}")
    print(f"‚ùå Different program and different result: {len(prog_errors)}")

    if prog_diff_same_result:
        print("\nüîπ Same Result but Different Program Samples:")
        for e in prog_diff_same_result[:5]:
            print(f"  - ID: {e['id']}")
            print(f"    Gold answer: {e['gold_ans']} | Pred answer: {e['pred_ans']}")
            print(f"    Gold program: {e['gold_prog']}")
            print(f"    Pred program: {e['pred_prog']}")

    if prog_errors:
        print("\nüîπ Different Program + Different Result Samples:")
        for e in prog_errors[:5]:
            print(f"  - ID: {e['id']}")
            print(f"    Gold program: {e['gold_prog']}")
            print(f"    Pred program: {e['pred_prog']}")

    print("\n" + "=" * 70)
    print("DONE")
    print("=" * 70)

    return gold_predict, false_predict


In [27]:
gold_predict, false_predict = evaluate_result(predictions, test_data)


EVALUATION SUMMARY
Total examples: 1147
Execution Accuracy: 0.7777 (892/1147)
Program Accuracy: 0.6800 (780/1147)

EXECUTION-LEVEL ERRORS
‚ùå Invalid programs: 15
‚ùå Result execution errors: 1
‚ùå Logic execution errors: 239

üîπ Invalid Program Samples:
  - ID: ETFC/2012/page_24.pdf-1
    Gold: ['multiply(', 'divide(165000, 254000)', '100', ')', 'EOF'] | Pred: ['divide(', '165000', '254000', ')', 'EOF']
    Reason: Program execution failed
  - ID: SLG/2011/page_91.pdf-1
    Gold: ['const(', '22825000', ')', 'EOF'] | Pred: ['add(', '22825000', '49250000', ')', 'EOF']
    Reason: Program execution failed
  - ID: HWM/2017/page_42.pdf-2
    Gold: ['divide(', '177.79', '100', ')', 'exp(', '#0', 'const_1/5', ')', 'subtract(', '#1', 'const_1', ')', 'EOF'] | Pred: ['subtract(', '177.79', '100', ')', 'divide(', 'const_1', 'const_5', ')', 'subtract(', '#1', 'const_1', ')', 'exp(', '#0', '#2', ')', 'subtract(', '#3', 'const_1', ')', 'EOF']
    Reason: Program execution failed
  - ID: CB/2008/

In [28]:
gold_data = []
false_data = []

for pre in predictions:
    if pre['id'] in gold_predict:
        gold_data.append(pre)
    elif pre['id'] in false_predict:
        false_data.append(pre)

In [29]:
from collections import Counter
from pprint import pprint

def insight_report(false_data):
    def count_ops(program_tokens):
        """ƒê·∫øm s·ªë ops trong m·ªôt ch∆∞∆°ng tr√¨nh (d·ª±a v√†o token k·∫øt th√∫c '(')."""
        return sum(1 for tk in program_tokens if tk.endswith("("))
    
    group_ids = {
        "1_ops": [],
        "2_ops": [],
        "3_ops": []
    }
    
    stats = {
        "1_ops": Counter(),
        "2_ops": Counter(),
        "3_ops": Counter()
    }
    
    for item in false_data:
        gold_program = item["gold_program"]
        predict_program = item["predicted"]
        
        n_ops_gold = count_ops(gold_program)
        n_ops_pred = count_ops(predict_program)
        
        if n_ops_gold == 1:
            group_key = "1_ops"
        elif n_ops_gold == 2:
            group_key = "2_ops"
        else:
            group_key = "3_ops"
        
        group_ids[group_key].append(item)
        stats[group_key]["gold_ops_count"] += 1
        stats[group_key][n_ops_pred] += 1  # l∆∞u key l√† s·ªë ops (int)
    
    print("===== Insight Report =====")
    for group, counter in stats.items():
        total = counter["gold_ops_count"]
        print(f"\nGroup {group}:")
        print(f"  Total gold samples: {total}")
        # sort theo s·ªë ops tƒÉng d·∫ßn, b·ªè key gold_ops_count
        for ops_count in sorted(k for k in counter if k != "gold_ops_count"):
            v = counter[ops_count]
            print(f"  Predicted with {ops_count} ops: {v} ({v/total:.2%})")



In [30]:
insight_report(false_data)

===== Insight Report =====

Group 1_ops:
  Total gold samples: 118
  Predicted with 0 ops: 4 (3.39%)
  Predicted with 1 ops: 82 (69.49%)
  Predicted with 2 ops: 22 (18.64%)
  Predicted with 3 ops: 5 (4.24%)
  Predicted with 4 ops: 2 (1.69%)
  Predicted with 5 ops: 2 (1.69%)
  Predicted with 6 ops: 1 (0.85%)

Group 2_ops:
  Total gold samples: 103
  Predicted with 1 ops: 51 (49.51%)
  Predicted with 2 ops: 44 (42.72%)
  Predicted with 3 ops: 8 (7.77%)

Group 3_ops:
  Total gold samples: 34
  Predicted with 1 ops: 9 (26.47%)
  Predicted with 2 ops: 13 (38.24%)
  Predicted with 3 ops: 8 (23.53%)
  Predicted with 4 ops: 4 (11.76%)
