In [1]:
import pandas as pd

In [2]:
import pandas as pd

def t14_calculate_metrics(true_labels: pd.Series, predictions: pd.Series) -> dict:
    """
    Calculates precision, recall, F1-score, and support for T-stage predictions.
    Includes per-label, micro-average, macro-average, and weighted scores.

    Matching is substring based: a label counts as "predicted" for a row when
    it occurs anywhere in the upper-cased prediction text, so one prediction
    may mention several labels. False positives are only tracked for labels
    that occur in ``true_labels``.

    Args:
        true_labels: A pandas Series of true labels; every value must be
            convertible with ``int()`` and value ``x`` is mapped to "T{x+1}".
        predictions: A pandas Series of predicted labels (strings).

    Returns:
        A dictionary with one entry per true label (precision, recall, f1,
        support, tp, fp, fn, num_errors) plus an "overall" entry holding the
        micro/macro/weighted aggregates. Ratios are rounded to 3 decimals.

    Raises:
        TypeError: If either argument is not a pandas Series.
        ValueError: If the lengths differ or a prediction is not a string.
    """
    # Validate the types first so that non-Series input is always reported as
    # a TypeError instead of a misleading length mismatch.
    if not isinstance(true_labels, pd.Series) or not isinstance(predictions, pd.Series):
        raise TypeError("true_labels and predictions must be pandas Series.")

    if len(true_labels) != len(predictions):
        raise ValueError("The length of true_labels and predictions must be the same.")

    # Missing values (None / NaN / pd.NA) are not str instances, so this one
    # check rejects both nulls and non-string entries.
    if not all(isinstance(pred, str) for pred in predictions):
        raise ValueError("All predictions must be non-null strings.")

    def safe_div(numerator, denominator):
        # Undefined ratios (no positives / no predictions) are reported as 0.0.
        return numerator / denominator if denominator > 0 else 0.0

    # Standardize true labels to "T{x+1}" format
    true_labels = true_labels.apply(lambda x: f"T{int(x) + 1}")
    unique_true_labels = sorted(set(true_labels))  # Ensure consistent order

    metrics = {label: {"tp": 0, "fp": 0, "fn": 0} for label in unique_true_labels}
    label_counts = dict.fromkeys(unique_true_labels, 0)

    for true_label, prediction in zip(true_labels, predictions):
        prediction_str = prediction.upper()
        label_counts[true_label] += 1

        # Every known label mentioned in the prediction text.
        mentioned = [label for label in unique_true_labels if label in prediction_str]

        if true_label in mentioned:
            metrics[true_label]["tp"] += 1
        else:
            metrics[true_label]["fn"] += 1

        # Any mentioned label other than the true one is a false positive.
        for label in mentioned:
            if label != true_label:
                metrics[label]["fp"] += 1

    results = {}
    raw_precision, raw_recall, raw_f1 = {}, {}, {}

    for label in unique_true_labels:  # Iterate in a defined order
        tp = metrics[label]["tp"]
        fp = metrics[label]["fp"]
        fn = metrics[label]["fn"]

        precision = safe_div(tp, tp + fp)
        recall = safe_div(tp, tp + fn)
        f1 = safe_div(2 * precision * recall, precision + recall)

        # Keep the unrounded values for the aggregate scores below.
        raw_precision[label] = precision
        raw_recall[label] = recall
        raw_f1[label] = f1

        results[label] = {
            "precision": round(precision, 3),
            "recall": round(recall, 3),
            "f1": round(f1, 3),
            "support": label_counts[label],
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "num_errors": fp + fn,  # False positives plus false negatives for this class
        }

    total_instances = len(true_labels)
    num_labels = len(unique_true_labels)

    # Macro averages: unweighted mean of the per-label scores.
    macro_precision = safe_div(sum(raw_precision[label] for label in unique_true_labels), num_labels)
    macro_recall = safe_div(sum(raw_recall[label] for label in unique_true_labels), num_labels)
    macro_f1 = safe_div(sum(raw_f1[label] for label in unique_true_labels), num_labels)

    # Micro averages: pool the counts over all labels, then compute the ratios.
    total_tp_micro = sum(metrics[label]["tp"] for label in unique_true_labels)
    total_fp_micro = sum(metrics[label]["fp"] for label in unique_true_labels)
    total_fn_micro = sum(metrics[label]["fn"] for label in unique_true_labels)

    micro_precision = safe_div(total_tp_micro, total_tp_micro + total_fp_micro)
    micro_recall = safe_div(total_tp_micro, total_tp_micro + total_fn_micro)
    micro_f1 = safe_div(2 * micro_precision * micro_recall, micro_precision + micro_recall)

    # Weighted F1: support-weighted mean of the *unrounded* per-label F1, so
    # that per-label rounding error is not carried into the aggregate.
    weighted_f1 = safe_div(
        sum(raw_f1[label] * label_counts[label] for label in unique_true_labels),
        total_instances,
    )

    results["overall"] = {
        "micro_precision": round(micro_precision, 3),
        "micro_recall": round(micro_recall, 3),
        "micro_f1": round(micro_f1, 3),
        "macro_precision": round(macro_precision, 3),
        "macro_recall": round(macro_recall, 3),
        "macro_f1": round(macro_f1, 3),
        "weighted_f1": round(weighted_f1, 3),
        "support": total_instances,
        "total_tp": total_tp_micro,
        "total_fp": total_fp_micro,
        "total_fn": total_fn_micro,
        "num_errors": total_fp_micro + total_fn_micro,  # All false positives plus all false negatives
    }
    return results


def n03_calculate_metrics(true_labels: pd.Series, predictions: pd.Series) -> dict:
    """
    Calculates precision, recall, F1-score, and support for N-stage predictions.
    Includes per-label, micro-average, macro-average, and weighted scores.
    Handles specific label replacements: "NO" to "N0", "NL" to "N1".

    Matching is substring based: a label counts as "predicted" for a row when
    it occurs anywhere in the upper-cased, replaced prediction text, so one
    prediction may mention several labels. False positives are only tracked
    for labels that occur in ``true_labels``.

    Args:
        true_labels: A pandas Series of true labels; every value must be
            convertible with ``int()`` and value ``x`` is mapped to "N{x}".
        predictions: A pandas Series of predicted labels (strings).

    Returns:
        A dictionary with one entry per true label (precision, recall, f1,
        support, tp, fp, fn, num_errors) plus an "overall" entry holding the
        micro/macro/weighted aggregates. Ratios are rounded to 3 decimals.

    Raises:
        TypeError: If either argument is not a pandas Series.
        ValueError: If the lengths differ or a prediction is not a string.
    """
    # Validate the types first so that non-Series input is always reported as
    # a TypeError instead of a misleading length mismatch.
    if not isinstance(true_labels, pd.Series) or not isinstance(predictions, pd.Series):
        raise TypeError("true_labels and predictions must be pandas Series.")

    if len(true_labels) != len(predictions):
        raise ValueError("The length of true_labels and predictions must be the same.")

    # Missing values (None / NaN / pd.NA) are not str instances, so this one
    # check rejects both nulls and non-string entries.
    if not all(isinstance(pred, str) for pred in predictions):
        raise ValueError("All predictions must be non-null strings.")

    def safe_div(numerator, denominator):
        # Undefined ratios (no positives / no predictions) are reported as 0.0.
        return numerator / denominator if denominator > 0 else 0.0

    # Standardize true labels to "N{x}" format
    true_labels = true_labels.apply(lambda x: f"N{int(x)}")
    unique_true_labels = sorted(set(true_labels))  # Ensure consistent order

    metrics = {label: {"tp": 0, "fp": 0, "fn": 0} for label in unique_true_labels}
    label_counts = dict.fromkeys(unique_true_labels, 0)

    for true_label, prediction in zip(true_labels, predictions):
        # Upper-case, then repair the letter-for-digit confusions O->0 and L->1.
        # NOTE(review): the replacement is a plain substring replace, so any
        # text containing "NO" is affected too (e.g. "NOT" becomes "N0T",
        # which then matches "N0") - confirm predictions are short labels.
        prediction_str = prediction.upper().replace("NO", "N0").replace("NL", "N1")
        label_counts[true_label] += 1

        # Every known label mentioned in the prediction text.
        mentioned = [label for label in unique_true_labels if label in prediction_str]

        if true_label in mentioned:
            metrics[true_label]["tp"] += 1
        else:
            metrics[true_label]["fn"] += 1

        # Any mentioned label other than the true one is a false positive.
        for label in mentioned:
            if label != true_label:
                metrics[label]["fp"] += 1

    results = {}
    raw_precision, raw_recall, raw_f1 = {}, {}, {}

    for label in unique_true_labels:  # Iterate in a defined order
        tp = metrics[label]["tp"]
        fp = metrics[label]["fp"]
        fn = metrics[label]["fn"]

        precision = safe_div(tp, tp + fp)
        recall = safe_div(tp, tp + fn)
        f1 = safe_div(2 * precision * recall, precision + recall)

        # Keep the unrounded values for the aggregate scores below.
        raw_precision[label] = precision
        raw_recall[label] = recall
        raw_f1[label] = f1

        results[label] = {
            "precision": round(precision, 3),
            "recall": round(recall, 3),
            "f1": round(f1, 3),
            "support": label_counts[label],
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "num_errors": fp + fn,  # False positives plus false negatives for this class
        }

    total_instances = len(true_labels)
    num_labels = len(unique_true_labels)

    # Macro averages: unweighted mean of the per-label scores.
    macro_precision = safe_div(sum(raw_precision[label] for label in unique_true_labels), num_labels)
    macro_recall = safe_div(sum(raw_recall[label] for label in unique_true_labels), num_labels)
    macro_f1 = safe_div(sum(raw_f1[label] for label in unique_true_labels), num_labels)

    # Micro averages: pool the counts over all labels, then compute the ratios.
    total_tp_micro = sum(metrics[label]["tp"] for label in unique_true_labels)
    total_fp_micro = sum(metrics[label]["fp"] for label in unique_true_labels)
    total_fn_micro = sum(metrics[label]["fn"] for label in unique_true_labels)

    micro_precision = safe_div(total_tp_micro, total_tp_micro + total_fp_micro)
    micro_recall = safe_div(total_tp_micro, total_tp_micro + total_fn_micro)
    micro_f1 = safe_div(2 * micro_precision * micro_recall, micro_precision + micro_recall)

    # Weighted F1: support-weighted mean of the *unrounded* per-label F1, so
    # that per-label rounding error is not carried into the aggregate.
    weighted_f1 = safe_div(
        sum(raw_f1[label] * label_counts[label] for label in unique_true_labels),
        total_instances,
    )

    results["overall"] = {
        "micro_precision": round(micro_precision, 3),
        "micro_recall": round(micro_recall, 3),
        "micro_f1": round(micro_f1, 3),
        "macro_precision": round(macro_precision, 3),
        "macro_recall": round(macro_recall, 3),
        "macro_f1": round(macro_f1, 3),
        "weighted_f1": round(weighted_f1, 3),
        "support": total_instances,
        "total_tp": total_tp_micro,
        "total_fp": total_fp_micro,
        "total_fn": total_fn_micro,
        "num_errors": total_fp_micro + total_fn_micro,  # All false positives plus all false negatives
    }

    return results

In [3]:
def calculate_mean_std(
        results: list[dict],
        cat: str | None,
        level: str = "label"     # "label", "micro", or "macro"
    ) -> dict:
    """
    Compute mean ± std for a single class ("label" level) or for the
    overall micro / macro aggregates produced by *t14_calculate_metrics*.

    Args
    ----
    results : list of metrics dicts (output of t14_calculate_metrics)
    cat     : class label (e.g. "T1") – ignored for micro/macro levels
    level   : "label" | "micro" | "macro"

    Returns
    -------
    dict with keys:
        mean_precision, mean_recall, mean_f1,
        std_precision,  std_recall,  std_f1,
        (plus sums and raw means used elsewhere)
    """
    # ------------------------------------------------------------------ #
    # Gather the three score lists                                       #
    # ------------------------------------------------------------------ #
    precision_list, recall_list, f1_list = [], [], []
    support_list, num_errors_list = [], []

    for res in results:
        if level == "label":
            src = res[cat]                                 # per‑class block
        elif level == "micro":
            src = {                                        # NEW ↓
                "precision": res["overall"]["micro_precision"],
                "recall":    res["overall"]["micro_recall"],
                "f1":        res["overall"]["micro_f1"],
                "support":   res["overall"]["support"],
                "num_errors": res["overall"]["num_errors"],
            }
        elif level == "macro":
            src = {                                        # NEW ↓
                "precision": res["overall"]["macro_precision"],
                "recall":    res["overall"]["macro_recall"],
                "f1":        res["overall"]["macro_f1"],
                "support":   res["overall"]["support"],
                "num_errors": res["overall"]["num_errors"],
            }
        else:
            raise ValueError(f"Unknown level: {level}")

        precision_list.append(src["precision"])
        recall_list.append(src["recall"])
        f1_list.append(src["f1"])
        support_list.append(src["support"])
        num_errors_list.append(src["num_errors"])

    # ------------------------------------------------------------------ #
    # Mean / std                                                         #
    # ------------------------------------------------------------------ #
    mean_p = sum(precision_list) / len(precision_list)
    mean_r = sum(recall_list)    / len(recall_list)
    mean_f = sum(f1_list)        / len(f1_list)

    std_p = (sum((x - mean_p) ** 2 for x in precision_list) / len(precision_list)) ** 0.5
    std_r = (sum((x - mean_r) ** 2 for x in recall_list)    / len(recall_list))    ** 0.5
    std_f = (sum((x - mean_f) ** 2 for x in f1_list)        / len(f1_list))        ** 0.5

    return {
        "mean_precision": round(mean_p, 3),
        "mean_recall":    round(mean_r, 3),
        "mean_f1":        round(mean_f, 3),
        "std_precision":  round(std_p, 3),
        "std_recall":     round(std_r, 3),
        "std_f1":         round(std_f, 3),
        "sum_support":    sum(support_list),
        "sum_num_errors": sum(num_errors_list),
        "raw_mean_precision": mean_p,   # keep raw for higher‑level macro
        "raw_mean_recall":    mean_r,
        "raw_mean_f1":        mean_f,
    }


###############################################################################
# Updated helper: output_tabular_performance                                  #
###############################################################################
def output_tabular_performance(
        results: list[dict],
        categories: list[str] | tuple[str, ...] = ("T1", "T2", "T3", "T4"),
        show_overall: bool = True
    ) -> None:
    """
    Print mean(std) precision/recall/F1 across runs for each class, followed by:
      • category‑macro average (plain mean of the per‑class means, no std)
      • micro‑average (overall)      – only when *show_overall* is true
      • macro‑average (overall)      – only when *show_overall* is true

    Args
    ----
    results      : list of per‑run metrics dicts, passed to calculate_mean_std
    categories   : class labels to report, in print order
    show_overall : whether to append the micro / macro summary rows

    The label column is padded to the longest label (at least 8 characters)
    so that the 9‑character summary labels line up with the class rows.
    The category‑macro row is skipped when *categories* is empty.
    """
    summary_labels = ("Cat‑Macro", "MicroAvg.", "MacroAvg.")
    width = max(8, *(len(name) for name in (*categories, *summary_labels)))

    def format_row(name, stats):
        # One "label mean(std) mean(std) mean(std)" line.
        return (f"{name:{width}s} "
                f"{stats['mean_precision']:.3f}({stats['std_precision']:.3f}) "
                f"{stats['mean_recall']:.3f}({stats['std_recall']:.3f}) "
                f"{stats['mean_f1']:.3f}({stats['std_f1']:.3f})")

    # ------------------------------------------------------------------ #
    # Per‑class lines                                                    #
    # ------------------------------------------------------------------ #
    label_means_p, label_means_r, label_means_f = [], [], []

    for cat in categories:
        stats = calculate_mean_std(results, cat, level="label")
        print(format_row(cat, stats))

        # Unrounded means, so the category‑macro row is not built on
        # already‑rounded numbers.
        label_means_p.append(stats["raw_mean_precision"])
        label_means_r.append(stats["raw_mean_recall"])
        label_means_f.append(stats["raw_mean_f1"])

    # ------------------------------------------------------------------ #
    # Category‑macro (average of label means)                            #
    # ------------------------------------------------------------------ #
    if label_means_p:
        print(f"{'Cat‑Macro':{width}s} "
              f"{sum(label_means_p)/len(label_means_p):.3f} "
              f"{sum(label_means_r)/len(label_means_r):.3f} "
              f"{sum(label_means_f)/len(label_means_f):.3f}")

    # ------------------------------------------------------------------ #
    # Overall micro / macro                                              #
    # ------------------------------------------------------------------ #
    if show_overall:
        micro = calculate_mean_std(results, None, level="micro")
        macro = calculate_mean_std(results, None, level="macro")

        print(format_row("MicroAvg.", micro))
        print(format_row("MacroAvg.", macro))

In [4]:
# Mixtral, T-stage: print the overall metrics for the ZSCOT, RAG and KEwRAG
# prediction columns, and the mean(std) table over several KEwLTM runs.
mixtral_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/mixtral_rag_result/0929_ltm_rag2.csv') # This is the one: all the Mixtral results are in this file.

print("Mixtral T-stage metrics:")

print("ZSCOT T-stage:")
print(t14_calculate_metrics(mixtral_df['t'], mixtral_df['zscot_t_stage'])['overall'])
print()
print("RAG T-stage:")
print(t14_calculate_metrics(mixtral_df['t'], mixtral_df['rag_raw_t_stage'])['overall'])
print()
# KEwLTM is evaluated once per run file; the per-run metrics are collected
# here and summarised as mean(std) below.
kewltm_t_results = []
t_label = 't'                                  # column holding the true T label
kewltm_t_stage = "cmem_t_40reports_ans_str"    # column holding the KEwLTM prediction
# NOTE(review): the files are named "..._outof_10runs" but only these 8 run
# indices are evaluated (7 and 9 are absent) - confirm the selection is intended.
run_lst = [0, 1, 2, 3, 4, 5, 6, 8]

for run in run_lst:
    # Labels and predictions come from the same frame, so the sort reorders
    # both together and does not change the computed counts.
    t_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_t_results.append(
        t14_calculate_metrics(t_test_df[t_label], t_test_df[kewltm_t_stage])
    )
print("KEwLTM T-stage:")
output_tabular_performance(kewltm_t_results)

print()
print("KEwRAG T-stage:")
print(t14_calculate_metrics(mixtral_df['t'], mixtral_df['ltm_rag1_t_stage'])['overall'])


Mixtral T-stage metrics:
ZSCOT T-stage:
{'micro_precision': 0.85, 'micro_recall': 0.863, 'micro_f1': 0.856, 'macro_precision': 0.831, 'macro_recall': 0.765, 'macro_f1': 0.792, 'weighted_f1': 0.854, 'support': 800, 'total_tp': 690, 'total_fp': 122, 'total_fn': 110, 'num_errors': 232}

RAG T-stage:
{'micro_precision': 0.81, 'micro_recall': 0.815, 'micro_f1': 0.812, 'macro_precision': 0.771, 'macro_recall': 0.73, 'macro_f1': 0.743, 'weighted_f1': 0.812, 'support': 800, 'total_tp': 652, 'total_fp': 153, 'total_fn': 148, 'num_errors': 301}

KEwLTM T-stage:
T1       0.904(0.017) 0.812(0.040) 0.855(0.018)
T2       0.882(0.022) 0.938(0.018) 0.909(0.005)
T3       0.834(0.054) 0.810(0.058) 0.818(0.018)
T4       0.807(0.082) 0.634(0.038) 0.707(0.029)
Cat‑Macro 0.857 0.799 0.822
MicroAvg. 0.876(0.006) 0.878(0.007) 0.877(0.007)
MacroAvg. 0.857(0.022) 0.799(0.020) 0.822(0.010)

KEwRAG T-stage:
{'micro_precision': 0.853, 'micro_recall': 0.848, 'micro_f1': 0.85, 'macro_precision': 0.792, 'macro_recall

In [5]:
# Mixtral, N-stage: same report as the T-stage one, using the N-stage columns
# of `mixtral_df` (which must already have been loaded).
print("Mixtral N-stage metrics:")

print("ZSCOT N-stage:")
print(n03_calculate_metrics(mixtral_df['n'], mixtral_df['zscot_n_stage'])['overall'])
print()
print("RAG N-stage:")
print(n03_calculate_metrics(mixtral_df['n'], mixtral_df['rag_raw_n_stage'])['overall'])
print()
# KEwLTM is evaluated once per run file; the per-run metrics are collected
# here and summarised as mean(std) below.
kewltm_n_results = []
n_label = 'n'                                  # column holding the true N label
kewltm_n_stage = "cmem_n_40reports_ans_str"    # column holding the KEwLTM prediction
# NOTE(review): the files are named "..._outof_10runs" but only these 8 run
# indices are evaluated (2 and 8 are absent), and the selection differs from
# the T-stage run list - confirm this is intended.
run_lst = [0, 1, 3, 4, 5, 6, 7, 9]
for run in run_lst:
    # Labels and predictions come from the same frame, so the sort reorders
    # both together and does not change the computed counts.
    n_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_n_results.append(
        n03_calculate_metrics(n_test_df[n_label], n_test_df[kewltm_n_stage])
    )
print("KEwLTM N-stage:")
output_tabular_performance(kewltm_n_results, categories=["N0", "N1", "N2", "N3"])
print()
print("KEwRAG N-stage:")
print(n03_calculate_metrics(mixtral_df['n'], mixtral_df['ltm_rag1_n_stage'])['overall'])


Mixtral N-stage metrics:
ZSCOT N-stage:
{'micro_precision': 0.874, 'micro_recall': 0.873, 'micro_f1': 0.873, 'macro_precision': 0.843, 'macro_recall': 0.822, 'macro_f1': 0.832, 'weighted_f1': 0.872, 'support': 800, 'total_tp': 698, 'total_fp': 101, 'total_fn': 102, 'num_errors': 203}

RAG N-stage:
{'micro_precision': 0.841, 'micro_recall': 0.835, 'micro_f1': 0.838, 'macro_precision': 0.803, 'macro_recall': 0.799, 'macro_f1': 0.797, 'weighted_f1': 0.84, 'support': 800, 'total_tp': 668, 'total_fp': 126, 'total_fn': 132, 'num_errors': 258}

KEwLTM N-stage:
N0       0.944(0.008) 0.952(0.018) 0.948(0.011)
N1       0.885(0.020) 0.883(0.026) 0.884(0.010)
N2       0.713(0.031) 0.745(0.054) 0.727(0.022)
N3       0.886(0.058) 0.784(0.042) 0.830(0.017)
Cat‑Macro 0.857 0.841 0.847
MicroAvg. 0.883(0.007) 0.883(0.007) 0.883(0.007)
MacroAvg. 0.857(0.011) 0.841(0.011) 0.847(0.008)

KEwRAG N-stage:
{'micro_precision': 0.853, 'micro_recall': 0.859, 'micro_f1': 0.856, 'macro_precision': 0.807, 'macro_rec

In [6]:
# Med42, T-stage: each baseline (ZSCOT, RAG, KEwRAG) is read from its own CSV,
# reduced to the filename, true-label and prediction columns; KEwLTM is read
# from one CSV per run.
kewltm_t_stage = 'kepa_t_ans_str'   # column holding the KEwLTM prediction in the run files
zscot_t_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1118_t14_med42_v2_test_800.csv').sort_values(by="patient_filename")[["patient_filename", 't', 'zscot_t_ans_str']]
rag_t_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1120_t14_rag_raw_med42_v2_800.csv').sort_values(by="patient_filename")[["patient_filename", 't', 't14_rag_raw_t_pred']]
ltm_rag_t_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1128_t14_ltm_rag1_med42_v2_800.csv').sort_values(by="patient_filename")[["patient_filename", 't', 't14_ltm_rag1_t_pred']]

print("Med42 T-stage metrics:")

print("ZSCOT T-stage:")
print(t14_calculate_metrics(zscot_t_df['t'], zscot_t_df['zscot_t_ans_str'])['overall'])
print()
print("RAG T-stage:")
print(t14_calculate_metrics(rag_t_df['t'], rag_t_df['t14_rag_raw_t_pred'])['overall'])
print()
# Per-run KEwLTM metrics, summarised as mean(std) below.
kewltm_t_results = []
t_label = 't'                       # column holding the true T label
# NOTE(review): the files are named "..._outof_10runs" but only these 8 run
# indices are evaluated (7 and 9 are absent) - confirm the selection is intended.
run_lst = [0, 1, 2, 3, 4, 5, 6, 8]

for run in run_lst:
    # Labels and predictions come from the same frame, so the sort reorders
    # both together and does not change the computed counts.
    t_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/1114_t14_med42_v2_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_t_results.append(
        t14_calculate_metrics(t_test_df[t_label], t_test_df[kewltm_t_stage])
    )
print("KEwLTM T-stage:")
output_tabular_performance(kewltm_t_results)

print()
print("KEwRAG T-stage:")
print(t14_calculate_metrics(ltm_rag_t_df['t'], ltm_rag_t_df['t14_ltm_rag1_t_pred'])['overall'])

Med42 T-stage metrics:
ZSCOT T-stage:
{'micro_precision': 0.769, 'micro_recall': 0.769, 'micro_f1': 0.769, 'macro_precision': 0.746, 'macro_recall': 0.678, 'macro_f1': 0.703, 'weighted_f1': 0.77, 'support': 800, 'total_tp': 615, 'total_fp': 185, 'total_fn': 185, 'num_errors': 370}

RAG T-stage:
{'micro_precision': 0.836, 'micro_recall': 0.836, 'micro_f1': 0.836, 'macro_precision': 0.786, 'macro_recall': 0.748, 'macro_f1': 0.764, 'weighted_f1': 0.834, 'support': 800, 'total_tp': 669, 'total_fp': 131, 'total_fn': 131, 'num_errors': 262}

KEwLTM T-stage:
T1       0.813(0.073) 0.759(0.076) 0.783(0.064)
T2       0.855(0.031) 0.913(0.023) 0.882(0.016)
T3       0.869(0.063) 0.703(0.099) 0.770(0.065)
T4       0.630(0.046) 0.615(0.057) 0.621(0.042)
Cat‑Macro 0.792 0.747 0.764
MicroAvg. 0.835(0.025) 0.835(0.025) 0.835(0.025)
MacroAvg. 0.792(0.032) 0.747(0.043) 0.764(0.034)

KEwRAG T-stage:
{'micro_precision': 0.879, 'micro_recall': 0.879, 'micro_f1': 0.879, 'macro_precision': 0.838, 'macro_recal

In [7]:
# Med42, N-stage: each baseline (ZSCOT, RAG, KEwRAG) is read from its own CSV,
# reduced to the filename, true-label and prediction columns; KEwLTM is read
# from one CSV per run.
kewltm_n_stage = 'kepa_n_ans_str'   # column holding the KEwLTM prediction in the run files
zscot_n_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1118_n03_med42_v2_test_800.csv').sort_values(by="patient_filename")[["patient_filename", 'n', 'zscot_n_ans_str']]
rag_n_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1120_n03_rag_raw_med42_v2_800.csv').sort_values(by="patient_filename")[["patient_filename", 'n', 'n03_rag_raw_n_pred']]
ltm_rag_n_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1128_n03_ltm_rag1_med42_v2_800.csv').sort_values(by="patient_filename")[["patient_filename", 'n', 'n03_ltm_rag1_n_pred']]

print("Med42 N-stage metrics:")

print("ZSCOT N-stage:")
print(n03_calculate_metrics(zscot_n_df['n'], zscot_n_df['zscot_n_ans_str'])['overall'])
print()
print("RAG N-stage:")
print(n03_calculate_metrics(rag_n_df['n'], rag_n_df['n03_rag_raw_n_pred'])['overall'])
print()
# Per-run KEwLTM metrics, summarised as mean(std) below.
kewltm_n_results = []
n_label = 'n'                       # column holding the true N label
# NOTE(review): the files are named "..._outof_10runs" but only these 8 run
# indices are evaluated (2 and 8 are absent), and the selection differs from
# the T-stage run list - confirm this is intended.
run_lst = [0, 1, 3, 4, 5, 6, 7, 9]

for run in run_lst:
    # Labels and predictions come from the same frame, so the sort reorders
    # both together and does not change the computed counts.
    n_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/1114_n03_med42_v2_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_n_results.append(
        n03_calculate_metrics(n_test_df[n_label], n_test_df[kewltm_n_stage])
    )
print("KEwLTM N-stage:")
output_tabular_performance(kewltm_n_results, categories=["N0", "N1", "N2", "N3"])

print()
print("KEwRAG N-stage:")
print(n03_calculate_metrics(ltm_rag_n_df['n'], ltm_rag_n_df['n03_ltm_rag1_n_pred'])['overall'])

Med42 N-stage metrics:
ZSCOT N-stage:
{'micro_precision': 0.738, 'micro_recall': 0.738, 'micro_f1': 0.738, 'macro_precision': 0.748, 'macro_recall': 0.723, 'macro_f1': 0.724, 'weighted_f1': 0.742, 'support': 800, 'total_tp': 590, 'total_fp': 210, 'total_fn': 210, 'num_errors': 420}

RAG N-stage:
{'micro_precision': 0.79, 'micro_recall': 0.79, 'micro_f1': 0.79, 'macro_precision': 0.76, 'macro_recall': 0.799, 'macro_f1': 0.759, 'weighted_f1': 0.796, 'support': 800, 'total_tp': 632, 'total_fp': 168, 'total_fn': 168, 'num_errors': 336}

KEwLTM N-stage:
N0       0.950(0.011) 0.821(0.059) 0.879(0.032)
N1       0.775(0.056) 0.821(0.032) 0.795(0.022)
N2       0.657(0.067) 0.711(0.076) 0.675(0.018)
N3       0.759(0.103) 0.858(0.029) 0.800(0.062)
Cat‑Macro 0.785 0.803 0.787
MicroAvg. 0.809(0.024) 0.809(0.024) 0.809(0.024)
MacroAvg. 0.785(0.021) 0.803(0.025) 0.788(0.026)

KEwRAG N-stage:
{'micro_precision': 0.88, 'micro_recall': 0.88, 'micro_f1': 0.88, 'macro_precision': 0.845, 'macro_recall': 0.