In [5]:
import re
import os
import glob
import json
import pandas as pd
from typing import Union
from nnautobench.utils.metrics import calculate_field_metrics

In [6]:
results_root = "results"
# One DataFrame per model run, keyed by file stem (e.g. "results/gpt_logits.jsonl" -> "gpt_logits").
# Paths are sorted so models load (and are reported) in a reproducible order; glob's own
# order depends on the filesystem.
results_dict = {
    os.path.splitext(os.path.basename(path))[0]: pd.read_json(path, orient="records", lines=True)
    for path in sorted(glob.glob(os.path.join(results_root, "*.jsonl")))
}

In [7]:
def normalize_text(text: str) -> str:
    """Canonicalize free text for equality comparison.

    Collapses every whitespace run to a single space, trims both ends, then
    replaces " / " separators with a single space.

    Non-string input (e.g. ``None``) cannot be regex-processed. It is echoed to
    stdout for debugging and returned unchanged, as before, but without a bare
    ``except`` that would also swallow KeyboardInterrupt/SystemExit.
    """
    if not isinstance(text, str):
        print(text)
        return text
    # Single pass: equivalent to collapsing runs of 2+ whitespace and then mapping
    # each remaining whitespace character to a space.
    return re.sub(r"\s+", " ", text).strip().replace(" / ", " ")

def clean_string(value) -> str:
    """Strip currency markers and thousands separators from a monetary string.

    Removes currency symbols, commas, and the codes "USD", "EUR" and "RM"
    wherever they appear (prefix or suffix), then trims surrounding whitespace.
    Example: "EUR 1,234.50" -> "1234.50", "100.00 EUR" -> "100.00".

    Non-string input is converted with ``str`` first.
    """
    value = re.sub(r'[$€£¥₹₩₽₨₦₫₴]', '', str(value))
    # "EUR" is removed regardless of a trailing space. Previously only "EUR " was
    # stripped, so a suffixed code ("100.00 EUR") was left in place while
    # "100.00 USD" was cleaned.
    for token in (",", "USD", "EUR", "RM"):
        value = value.replace(token, "")
    return value.strip()

def is_zero_tax(tax_value: Union[str, float]) -> bool:
    """Return True when ``tax_value`` represents a numeric zero.

    Accepts ints and floats (ints were previously rejected even though the
    annotation admits them), and strings that parse with ``float``.
    Unparseable strings, booleans, and any other type yield False.
    """
    if isinstance(tax_value, bool):
        # bool is an int subclass; it was never treated as a tax amount.
        return False
    if isinstance(tax_value, (int, float)):
        return tax_value == 0
    if isinstance(tax_value, str):
        try:
            return float(tax_value) == 0.0
        except ValueError:
            return False
    return False

def _field_value_str(value) -> str:
    """Stringify a field value, mapping ``None`` to "" so a null means "absent" rather than the text "None"."""
    return "" if value is None else str(value)


def _canonical_field_value(field_name: str, value: str) -> str:
    """Apply the field-type-specific cleanup rules to one value and return its lowercase normalized form."""
    lowered_name = field_name.lower()
    if "amount" in lowered_name or "total_tax" in lowered_name:
        value = clean_string(value)
    if "date" in lowered_name:
        value = value.replace(" ", "").replace("/", "-").replace(".", "-").rstrip('-')
    if "currency" in lowered_name:
        value = value.replace("DEM", "DM").replace("U. S. DOLLARS", "$").replace("US Dollars", "$")
    if "total_tax" in lowered_name:
        # if the tax is 0, gt annotations does not have this sometimes
        value = "" if is_zero_tax(value) else value
    return normalize_text(value).lower()


def compute_conf_score_approval_and_precision(predicted_field_conf_scores: dict, gt_ans: dict, pred_ans: dict, threshold: float=0.99, print_incorrect: bool=False):
    """Count auto-approved fields for one document, split by correctness.

    A field is "approved" when its predicted confidence is strictly greater than
    ``threshold``. Fields without a score get confidence 0.0.

    Args:
        predicted_field_conf_scores: field name -> confidence (anything ``float`` accepts).
        gt_ans: ground truth; ``gt_ans["fields"]`` maps field name -> {"value": ...}.
        pred_ans: prediction; maps field name -> {"value": ...}.
        threshold: approval cut-off (exclusive).
        print_incorrect: print name and score of each incorrectly approved field.

    Returns:
        (correct_approved, incorrect_approved) over the union of GT and predicted
        field names. A field missing on one side (or with a null value) compares as "".
    """
    conf_scores = {key: float(value) for key, value in predicted_field_conf_scores.items()}
    gt_values = {key: _field_value_str(ans["value"]) for key, ans in gt_ans["fields"].items()}
    pred_values = {key: _field_value_str(ans["value"]) for key, ans in pred_ans.items()}

    total_correct_approved = 0
    total_incorrect_approved = 0
    for field_name in gt_values.keys() | pred_values.keys():
        predicted_conf_score = conf_scores.get(field_name, 0.0)
        # Written as `not >` (rather than `<=`) so a NaN score is never approved.
        if not predicted_conf_score > threshold:
            continue
        gt_norm = _canonical_field_value(field_name, gt_values.get(field_name, ""))
        pred_norm = _canonical_field_value(field_name, pred_values.get(field_name, ""))
        if gt_norm == pred_norm:  # correct and approved
            total_correct_approved += 1
        else:  # incorrect and approved
            if print_incorrect:
                print(field_name, predicted_conf_score)
            total_incorrect_approved += 1
    return total_correct_approved, total_incorrect_approved

def compute_approval_rate_precision(df, threshold: float=0.85):
    """Aggregate auto-approval statistics over every row of ``df``.

    Each row provides ``predicted_field_conf_scores`` (a dict or a JSON string),
    ``annotation`` and ``pred``. ``df.total_fields`` supplies the denominator of
    the approval rate.

    Returns:
        (approval_rate, precision) where
        approval_rate = approved fields / total fields (0.0 when there are no fields), and
        precision = correct approvals / approvals (-1 when nothing was approved).
    """
    import math

    total_correct_approved = 0
    total_incorrect_approved = 0
    for _, row in df.iterrows():
        conf_scores = row.predicted_field_conf_scores
        if isinstance(conf_scores, str):
            conf_scores = json.loads(conf_scores)
        elif conf_scores is None or (isinstance(conf_scores, float) and math.isnan(conf_scores)):
            # e.g. a JSONL record lacking the key: pandas yields NaN/None. Treat it as
            # "no scores" so none of the row's fields are approved, instead of crashing on .items().
            conf_scores = {}
        approved_c, approved_i = compute_conf_score_approval_and_precision(
            conf_scores, row.annotation, row.pred, threshold=threshold
        )
        total_correct_approved += approved_c
        total_incorrect_approved += approved_i

    total_approved = total_correct_approved + total_incorrect_approved
    total_fields = df.total_fields.sum()
    rate = total_approved / total_fields if total_fields > 0 else 0.0
    precision = total_correct_approved / total_approved if total_approved > 0 else -1
    return rate, precision

def compute_approval_rate_precision_grid(df, model_name: str, thresholds=(0.8, 0.85, 0.9, 0.95, 0.99)):
    """Evaluate approval rate and precision of ``df`` at each threshold in ``thresholds``.

    The default is an immutable tuple; callers may still pass any iterable of floats.

    Returns:
        One row per threshold with columns model_name, threshold, approval_rate,
        precision and weighted (= approval_rate * precision, or 0.0 when precision
        is the -1 "nothing approved" sentinel). The columns are present even when
        ``thresholds`` is empty.
    """
    columns = ["model_name", "threshold", "approval_rate", "precision", "weighted"]
    rows = []
    for threshold in thresholds:
        rate, precision = compute_approval_rate_precision(df, threshold)
        rows.append({
            "model_name": model_name,
            "threshold": threshold,
            "approval_rate": rate,
            "precision": precision,
            "weighted": rate * precision if precision >= 0 else 0.0,
        })
    return pd.DataFrame(rows, columns=columns)
    
def get_weighted_accuracy(df, name):
    """Micro-averaged field accuracy: summed true positives / summed queried fields.

    Each row supplies ``pred``, ``queried_labels`` and ``annotation["fields"]``,
    which are passed to ``calculate_field_metrics``.

    ``name`` is unused; it is kept so existing call sites keep working.
    Returns NaN when ``df`` contains no queried fields, instead of raising
    ZeroDivisionError.
    """
    total_tp = 0
    total_fields = 0
    for _, row in df.iterrows():
        queried_fields = row.queried_labels
        # NOTE(review): assumes calculate_field_metrics returns a mapping with a numeric "tp" count.
        acc_metrics = calculate_field_metrics(row.pred, row.annotation["fields"], queried_fields)
        total_tp += acc_metrics["tp"]
        total_fields += len(queried_fields)
    if total_fields == 0:
        return float("nan")
    return total_tp / total_fields


In [8]:
# For each model: sweep approval thresholds, keep the threshold with the best precision,
# and report it next to the model's overall field-level accuracy.
PRECISION_THRESHOLDS = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.9999]
NANONETS_THRESHOLDS = [0.99]  # model is calibrated on 0.99

for name, result_df in results_dict.items():
    print(f"#### {name} ####")
    precision_thresholds = NANONETS_THRESHOLDS if "nanonets" in name else PRECISION_THRESHOLDS
    prec_results = compute_approval_rate_precision_grid(result_df, name, precision_thresholds)
    # idxmax returns the FIRST row with maximal precision, i.e. the lowest tied threshold
    # (highest approval rate). The previous sort_values(...).iloc[0] used the default,
    # non-stable quicksort, so ties were broken arbitrarily.
    max_prec_results = prec_results.loc[prec_results["precision"].idxmax()]
    weighted_acc_results = get_weighted_accuracy(result_df, name)
    results_df = pd.DataFrame(
        {
            "approval_rate": max_prec_results["approval_rate"],
            "precision": max_prec_results["precision"],
            "weighted_acc": weighted_acc_results,
            "threshold": max_prec_results["threshold"],
        },
        index=[name],
    )
    display(results_df)

#### claude_37_bin ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
claude_37_bin,0.885931,0.751577,0.80475,0.2


#### qwen_bin ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
qwen_bin,0.72907,0.695573,0.781536,0.2


#### claude_35_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
claude_35_prob,0.008659,0.895105,0.809755,0.99


#### gemma3-27b_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gemma3-27b_prob,0.327363,0.816229,0.819358,0.95


#### nanonets ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
nanonets,0.769818,0.975876,0.84168,0.99


#### qwen_25_logits ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
qwen_25_logits,0.340465,0.934863,0.781276,0.9999


#### qwen_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
qwen_prob,0.016047,0.726592,0.781638,0.99


#### claude_37_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
claude_37_prob,0.006149,0.823529,0.804595,0.99


#### gemini_consistency ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gemini_consistency,0.816636,0.841551,0.8094,0.2


#### gpt_logits ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gpt_logits,0.339804,0.948532,0.801292,0.9999


#### pixtral_consistency ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
pixtral_consistency,0.523946,0.885072,0.762599,0.2


#### claude_35_bin ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
claude_35_bin,0.886949,0.770471,0.80899,0.2


#### gpt4o_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gpt4o_prob,0.032634,0.801105,0.801905,0.99


#### gpt_4o_consistency ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gpt_4o_consistency,0.74722,0.920856,0.798787,0.2


#### claude_35_consistency ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
claude_35_consistency,0.882812,0.797443,0.792009,0.2


#### pixtral_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
pixtral_prob,0.151992,0.710953,0.811425,0.99


#### gpt4o_bin ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gpt4o_bin,0.86736,0.760186,0.801798,0.2


#### pixtral_bin ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
pixtral_bin,0.559709,0.826157,0.810744,0.2


#### gemini_bin ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gemini_bin,0.699621,0.718839,0.80964,0.2


#### gemini_prob ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gemini_prob,0.080774,0.938244,0.807888,0.99


#### gemma3-27b_yesno ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gemma3-27b_yesno,0.537472,0.664318,0.819006,0.2


#### gemma3-27b_consistency ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
gemma3-27b_consistency,0.637414,0.832483,0.813889,0.2


#### qwen_25_consistency ####


Unnamed: 0,approval_rate,precision,weighted_acc,threshold
qwen_25_consistency,0.762065,0.893533,0.782299,0.2


#### ------------------- End of Code ------------------- ####