In [1]:
import pandas as pd
import numpy as np


###############################################################################
# T‑stage metrics                                                             #
###############################################################################
def t14_calculate_metrics(true_labels: pd.Series, predictions: pd.Series) -> dict:
    """
    Calculate per-class and overall precision / recall / F1 for T-stage
    predictions.

    True labels are converted to strings of the form ``"T{int(x) + 1}"`` and
    each prediction is upper-cased and searched for those strings.  A row is
    a true positive for its own class when the true label occurs as a
    substring of the prediction, otherwise it is a false negative.  Every
    *other* tracked class whose label occurs in the prediction receives a
    false positive, so a single row can contribute several false positives.
    Only classes that occur in ``true_labels`` are tracked.

    Args:
        true_labels: Series of values accepted by ``int()``.
        predictions: Series of non-null strings, same length as
            ``true_labels``.

    Returns:
        A dict with one entry per class label (``precision``, ``recall``,
        ``f1``, ``support``, ``tp``, ``fp``, ``fn``, ``num_errors`` = fp + fn)
        and an ``"overall"`` entry holding micro / macro / weighted scores,
        the summed counts, and ``num_errors`` = rows without a true positive
        (each wrong row counted exactly once).  For empty input only the
        ``"overall"`` entry is present, with every score 0.0 and every
        count 0.

    Raises:
        ValueError: If the lengths differ or a prediction is null or not a
            string.
        TypeError: If either argument is not a pandas Series.
    """
    # ------------------------------------------------------------------ #
    # Sanity checks                                                      #
    # ------------------------------------------------------------------ #
    if len(true_labels) != len(predictions):
        raise ValueError("true_labels and predictions must be the same length")
    if not isinstance(true_labels, pd.Series) or not isinstance(predictions, pd.Series):
        raise TypeError("true_labels and predictions must be pandas Series")
    if any(pd.isna(pred) or not isinstance(pred, str) for pred in predictions):
        raise ValueError("All predictions must be non‑null strings")

    # Shift by one so that a stored value of 0 becomes "T1", 1 becomes "T2", ...
    true_labels = true_labels.apply(lambda x: f"T{int(x) + 1}")
    unique_true_labels = sorted(set(true_labels))

    metrics = {lab: {"tp": 0, "fp": 0, "fn": 0} for lab in unique_true_labels}
    label_counts = {lab: 0 for lab in unique_true_labels}

    # ------------------------------------------------------------------ #
    # Count TP / FP / FN                                                 #
    # ------------------------------------------------------------------ #
    for true_lab, pred in zip(true_labels, predictions):
        pred_u = str(pred).upper()

        label_counts[true_lab] += 1
        if true_lab in pred_u:                                           # TP test
            metrics[true_lab]["tp"] += 1
        else:
            metrics[true_lab]["fn"] += 1

        # Any other tracked label mentioned in the prediction is a false
        # positive for that label, independent of the TP/FN outcome above.
        for lab in unique_true_labels:                                  # FP test
            if lab in pred_u and lab != true_lab:
                metrics[lab]["fp"] += 1

    # ------------------------------------------------------------------ #
    # Derive precision/recall/F1 per class                               #
    # ------------------------------------------------------------------ #
    results = {}
    total_tp = total_fp = total_fn = 0
    macro_prec = macro_rec = macro_f1 = 0.0
    N = len(true_labels)

    for lab in unique_true_labels:
        tp, fp, fn = (metrics[lab][k] for k in ("tp", "fp", "fn"))

        # A zero denominator yields a score of 0.0 instead of raising.
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec  = tp / (tp + fn) if tp + fn else 0.0
        f1   = 2 * prec * rec / (prec + rec) if prec + rec else 0.0

        results[lab] = {
            "precision": round(prec, 3),
            "recall":    round(rec, 3),
            "f1":        round(f1, 3),
            "support":   label_counts[lab],
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "num_errors": fp + fn,            # class-level: FP and FN both count
        }

        total_tp += tp
        total_fp += fp
        total_fn += fn
        macro_prec += prec
        macro_rec  += rec
        macro_f1   += f1

    # ------------------------------------------------------------------ #
    # Micro / macro + row-level error count                              #
    # ------------------------------------------------------------------ #
    micro_prec = total_tp / (total_tp + total_fp) if total_tp + total_fp else 0.0
    micro_rec  = total_tp / (total_tp + total_fn) if total_tp + total_fn else 0.0
    micro_f1   = 2 * micro_prec * micro_rec / (micro_prec + micro_rec) if micro_prec + micro_rec else 0.0

    # Empty input leaves no classes; keep the 0.0 sums rather than dividing
    # by zero.
    num_labels = len(unique_true_labels)
    if num_labels:
        macro_prec /= num_labels
        macro_rec  /= num_labels
        macro_f1   /= num_labels

    # Support-weighted mean of the (already rounded) per-class F1 values.
    weighted_f1 = (
        sum(results[lab]["f1"] * label_counts[lab] for lab in unique_true_labels) / N
        if N else 0.0
    )

    results["overall"] = {
        "micro_precision": round(micro_prec, 3),
        "micro_recall":    round(micro_rec, 3),
        "micro_f1":        round(micro_f1, 3),
        "macro_precision": round(macro_prec, 3),
        "macro_recall":    round(macro_rec, 3),
        "macro_f1":        round(macro_f1, 3),
        "weighted_f1":     round(weighted_f1, 3),
        "support":         N,
        "total_tp":        total_tp,
        "total_fp":        total_fp,
        "total_fn":        total_fn,
        # wrong rows counted once:
        "num_errors":      N - total_tp,
    }
    return results


###############################################################################
# N‑stage metrics (same changes as above)                                     #
###############################################################################
def n03_calculate_metrics(true_labels: pd.Series, predictions: pd.Series) -> dict:
    """
    Calculate per-class and overall precision / recall / F1 for N-stage
    predictions.

    True labels are converted to strings of the form ``"N{int(x)}"``.  Each
    prediction is upper-cased, then every ``"NO"`` is replaced by ``"N0"``
    and every ``"NL"`` by ``"N1"`` before matching.  The replacement applies
    anywhere in the string, so those letter pairs are also rewritten inside
    longer words.  A row is a true positive for its own class when the true
    label occurs as a substring of the normalised prediction, otherwise a
    false negative.  Every *other* tracked class whose label occurs in the
    prediction receives a false positive.  Only classes that occur in
    ``true_labels`` are tracked.

    Args:
        true_labels: Series of values accepted by ``int()``.
        predictions: Series of non-null strings, same length as
            ``true_labels``.

    Returns:
        A dict with one entry per class label (``precision``, ``recall``,
        ``f1``, ``support``, ``tp``, ``fp``, ``fn``, ``num_errors`` = fp + fn)
        and an ``"overall"`` entry holding micro / macro / weighted scores,
        the summed counts, and ``num_errors`` = rows without a true positive.
        For empty input only the ``"overall"`` entry is present, with every
        score 0.0 and every count 0.

    Raises:
        ValueError: If the lengths differ or a prediction is null or not a
            string.
        TypeError: If either argument is not a pandas Series.
    """
    if len(true_labels) != len(predictions):
        raise ValueError("true_labels and predictions must be the same length")
    if not isinstance(true_labels, pd.Series) or not isinstance(predictions, pd.Series):
        raise TypeError("true_labels and predictions must be pandas Series")
    if any(pd.isna(pred) or not isinstance(pred, str) for pred in predictions):
        raise ValueError("All predictions must be non‑null strings")

    true_labels = true_labels.apply(lambda x: f"N{int(x)}")
    unique_true_labels = sorted(set(true_labels))

    metrics = {lab: {"tp": 0, "fp": 0, "fn": 0} for lab in unique_true_labels}
    label_counts = {lab: 0 for lab in unique_true_labels}

    for true_lab, pred in zip(true_labels, predictions):
        # Normalise letter look-alikes for the digits: "NO" -> "N0", "NL" -> "N1".
        pred_u = str(pred).upper().replace("NO", "N0").replace("NL", "N1")

        label_counts[true_lab] += 1
        if true_lab in pred_u:
            metrics[true_lab]["tp"] += 1
        else:
            metrics[true_lab]["fn"] += 1

        # Any other tracked label mentioned in the prediction is a false
        # positive for that label, independent of the TP/FN outcome above.
        for lab in unique_true_labels:
            if lab in pred_u and lab != true_lab:
                metrics[lab]["fp"] += 1

    results = {}
    total_tp = total_fp = total_fn = 0
    macro_prec = macro_rec = macro_f1 = 0.0
    N = len(true_labels)

    for lab in unique_true_labels:
        tp, fp, fn = (metrics[lab][k] for k in ("tp", "fp", "fn"))
        # A zero denominator yields a score of 0.0 instead of raising.
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec  = tp / (tp + fn) if tp + fn else 0.0
        f1   = 2 * prec * rec / (prec + rec) if prec + rec else 0.0

        results[lab] = {
            "precision": round(prec, 3),
            "recall":    round(rec, 3),
            "f1":        round(f1, 3),
            "support":   label_counts[lab],
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "num_errors": fp + fn,
        }

        total_tp += tp
        total_fp += fp
        total_fn += fn
        macro_prec += prec
        macro_rec  += rec
        macro_f1   += f1

    micro_prec = total_tp / (total_tp + total_fp) if total_tp + total_fp else 0.0
    micro_rec  = total_tp / (total_tp + total_fn) if total_tp + total_fn else 0.0
    micro_f1   = 2 * micro_prec * micro_rec / (micro_prec + micro_rec) if micro_prec + micro_rec else 0.0

    # Empty input leaves no classes; keep the 0.0 sums rather than dividing
    # by zero.
    num_labels = len(unique_true_labels)
    if num_labels:
        macro_prec /= num_labels
        macro_rec  /= num_labels
        macro_f1   /= num_labels

    # Support-weighted mean of the (already rounded) per-class F1 values.
    weighted_f1 = (
        sum(results[lab]["f1"] * label_counts[lab] for lab in unique_true_labels) / N
        if N else 0.0
    )

    results["overall"] = {
        "micro_precision": round(micro_prec, 3),
        "micro_recall":    round(micro_rec, 3),
        "micro_f1":        round(micro_f1, 3),
        "macro_precision": round(macro_prec, 3),
        "macro_recall":    round(macro_rec, 3),
        "macro_f1":        round(macro_f1, 3),
        "weighted_f1":     round(weighted_f1, 3),
        "support":         N,
        "total_tp":        total_tp,
        "total_fp":        total_fp,
        "total_fn":        total_fn,
        # wrong rows counted once:
        "num_errors":      N - total_tp,
    }
    return results


###############################################################################
# Helper: print error counts per split                                        #
###############################################################################
def print_error_counts(results):
    """
    Print a per-run table of error counts followed by mean ± std lines.

    For every metrics dict in *results* the ``"overall"`` entry is read for
    ``total_fp``, ``total_fn`` and ``num_errors``; one table row is printed
    per run (indexed from 0) and the three series are then summarised with
    ``np.mean`` / ``np.std``.
    """
    wrong_rows, false_negatives, false_positives = [], [], []

    print("Run  WrongRows  FN  FP  (FP+FN)")
    for run_idx, run in enumerate(results):
        overall = run["overall"]
        n_fp = overall["total_fp"]
        n_fn = overall["total_fn"]
        n_wrong = overall["num_errors"]          # each wrong row counted once

        false_positives.append(n_fp)
        false_negatives.append(n_fn)
        wrong_rows.append(n_wrong)

        combined = n_fp + n_fn
        print(f"{run_idx:<4} {n_wrong:<10} {n_fn:<3} {n_fp:<3} {combined}")

    print("\nMean ± Std across splits")
    wrong_mean, wrong_std = np.mean(wrong_rows), np.std(wrong_rows)
    print(f"Wrong rows: {wrong_mean:.2f} ± {wrong_std:.2f}")
    fn_mean, fn_std = np.mean(false_negatives), np.std(false_negatives)
    print(f"FN:         {fn_mean:.2f} ± {fn_std:.2f}")
    fp_mean, fp_std = np.mean(false_positives), np.std(false_positives)
    print(f"FP:         {fp_mean:.2f} ± {fp_std:.2f}")


###############################################################################
# calculate_mean_std and output_tabular_performance are defined in a later    #
# cell; calculate_mean_std reads overall['num_errors'] for micro/macro levels.#
###############################################################################


In [None]:
import pandas as pd

In [None]:
# import pandas as pd
# import numpy as np

# def t14_calculate_metrics(true_labels: pd.Series, predictions: pd.Series) -> dict:
#     """
#     Calculates precision, recall, F1-score, and support for T-stage predictions.
#     Includes both per-label, micro-average, and macro-average scores.

#     Args:
#         true_labels: A pandas Series of true labels (e.g., integers from 0 to N-1).
#         predictions: A pandas Series of predicted labels (strings).

#     Returns:
#         A dictionary containing the metrics for each label and overall scores.
#     """
#     # Check for valid inputs
#     if len(true_labels) != len(predictions):
#         raise ValueError("The length of true_labels and predictions must be the same.")

#     if not isinstance(true_labels, pd.Series) or not isinstance(predictions, pd.Series):
#         raise TypeError("true_labels and predictions must be pandas Series.")

#     if any(pd.isna(pred) or not isinstance(pred, str) for pred in predictions):
#         raise ValueError("All predictions must be non-null strings.")

#     # Standardize true labels to "T{x+1}" format
#     true_labels = true_labels.apply(lambda x: f"T{int(x)+1}")

#     metrics = {}
#     label_counts = {}
#     unique_true_labels = sorted(list(set(true_labels))) # Ensure consistent order

#     for label in unique_true_labels:
#         metrics[label] = {"tp": 0, "fp": 0, "fn": 0}
#         label_counts[label] = 0

#     for true_label, prediction in zip(true_labels, predictions):
#         # Ensure prediction is a string and convert to uppercase
#         prediction_str = str(prediction).upper()
        
#         label_counts[true_label] += 1
#         if true_label in prediction_str:
#             metrics[true_label]["tp"] += 1
#         else:
#             metrics[true_label]["fn"] += 1

#         # Calculate false positives
#         # A prediction is a false positive for a label if:
#         # 1. The label is present in the prediction string.
#         # 2. The label is NOT the true_label.
#         for label_to_check_fp in unique_true_labels:
#             if label_to_check_fp in prediction_str and label_to_check_fp != true_label:
#                 metrics[label_to_check_fp]["fp"] += 1
    
#     results = {}
#     # Variables for micro-averaging
#     total_tp_micro = 0
#     total_fp_micro = 0
#     total_fn_micro = 0
    
#     # Variables for macro-averaging
#     macro_precision_sum = 0.0
#     macro_recall_sum = 0.0
#     macro_f1_sum = 0.0
    
#     total_instances = len(true_labels)

#     for label in unique_true_labels: # Iterate in a defined order
#         counts = metrics[label]
#         tp = counts["tp"]
#         fp = counts["fp"]
#         fn = counts["fn"]

#         # Precision: TP / (TP + FP)
#         precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
#         # Recall: TP / (TP + FN)
#         recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
#         # F1-Score: 2 * (Precision * Recall) / (Precision + Recall)
#         f1 = (
#             2 * precision * recall / (precision + recall)
#             if (precision + recall) > 0
#             else 0.0
#         )
#         support = label_counts[label]

#         results[label] = {
#             "precision": round(precision, 3),
#             "recall": round(recall, 3),
#             "f1": round(f1, 3),
#             "support": support,
#             "tp": tp,
#             "fp": fp,
#             "fn": fn,
#             "num_errors": fp + fn, # Sum of false positives and false negatives for this class
#         }

#         # Accumulate for micro-averages
#         total_tp_micro += tp
#         total_fp_micro += fp
#         total_fn_micro += fn
        
#         # Accumulate for macro-averages
#         macro_precision_sum += precision
#         macro_recall_sum += recall
#         macro_f1_sum += f1

#     # Calculate macro-averaged metrics
#     num_labels = len(unique_true_labels)
#     macro_precision = macro_precision_sum / num_labels if num_labels > 0 else 0.0
#     macro_recall = macro_recall_sum / num_labels if num_labels > 0 else 0.0
#     macro_f1 = macro_f1_sum / num_labels if num_labels > 0 else 0.0 # Often calculated as harmonic mean of macro_precision and macro_recall

#     # Calculate micro-averaged (overall) precision, recall, and F1 score
#     # Micro-Precision: Sum of all TPs / (Sum of all TPs + Sum of all FPs)
#     micro_precision = (
#         total_tp_micro / (total_tp_micro + total_fp_micro) if (total_tp_micro + total_fp_micro) > 0 else 0.0
#     )
#     # Micro-Recall: Sum of all TPs / (Sum of all TPs + Sum of all FNs)
#     micro_recall = (
#         total_tp_micro / (total_tp_micro + total_fn_micro) if (total_tp_micro + total_fn_micro) > 0 else 0.0
#     )
#     # Micro-F1: 2 * (Micro-Precision * Micro-Recall) / (Micro-Precision + Micro-Recall)
#     micro_f1 = (
#         2 * micro_precision * micro_recall / (micro_precision + micro_recall)
#         if (micro_precision + micro_recall) > 0
#         else 0.0
#     )

#     # Calculate weighted F1 score
#     weighted_f1_sum = 0.0
#     for label in unique_true_labels:
#         weighted_f1_sum += results[label]["f1"] * label_counts[label]
#     weighted_f1 = weighted_f1_sum / total_instances if total_instances > 0 else 0.0

#     results["overall"] = {
#         "micro_precision": round(micro_precision, 3),
#         "micro_recall": round(micro_recall, 3),
#         "micro_f1": round(micro_f1, 3),
#         "macro_precision": round(macro_precision, 3),
#         "macro_recall": round(macro_recall, 3),
#         "macro_f1": round(macro_f1, 3),
#         "weighted_f1": round(weighted_f1, 3),
#         "support": total_instances,
#         "total_tp": total_tp_micro,
#         "total_fp": total_fp_micro,
#         "total_fn": total_fn_micro,
#         "num_errors": total_fp_micro + total_fn_micro, # Sum of all false positives and false negatives
#     }
#     return results


# def n03_calculate_metrics(true_labels: pd.Series, predictions: pd.Series) -> dict:
#     """
#     Calculates precision, recall, F1-score, and support for N-stage predictions.
#     Includes both per-label, micro-average, and macro-average scores.
#     Handles specific label replacements: "NO" to "N0", "NL" to "N1".

#     Args:
#         true_labels: A pandas Series of true labels (e.g., integers from 0 to N-1).
#         predictions: A pandas Series of predicted labels (strings).

#     Returns:
#         A dictionary containing the metrics for each label and overall scores.
#     """
#     # Check for valid inputs
#     if len(true_labels) != len(predictions):
#         raise ValueError("The length of true_labels and predictions must be the same.")

#     if not isinstance(true_labels, pd.Series) or not isinstance(predictions, pd.Series):
#         raise TypeError("true_labels and predictions must be pandas Series.")

#     if any(pd.isna(pred) or not isinstance(pred, str) for pred in predictions):
#         raise ValueError("All predictions must be non-null strings.")

#     # Standardize true labels to "N{x}" format
#     true_labels = true_labels.apply(lambda x: f"N{int(x)}")

#     metrics = {}
#     label_counts = {}
#     unique_true_labels = sorted(list(set(true_labels))) # Ensure consistent order

#     for label in unique_true_labels:
#         metrics[label] = {"tp": 0, "fp": 0, "fn": 0}
#         label_counts[label] = 0

#     for true_label, prediction in zip(true_labels, predictions):
#         # Ensure prediction is a string, convert to uppercase, and apply replacements
#         prediction_str = str(prediction).upper()
#         prediction_str = prediction_str.replace("NO", "N0").replace("NL", "N1")
        
#         label_counts[true_label] += 1
#         if true_label in prediction_str:
#             metrics[true_label]["tp"] += 1
#         else:
#             metrics[true_label]["fn"] += 1

#         # Calculate false positives
#         for label_to_check_fp in unique_true_labels:
#             if label_to_check_fp in prediction_str and label_to_check_fp != true_label:
#                 metrics[label_to_check_fp]["fp"] += 1

#     results = {}
#     # Variables for micro-averaging
#     total_tp_micro = 0
#     total_fp_micro = 0
#     total_fn_micro = 0
    
#     # Variables for macro-averaging
#     macro_precision_sum = 0.0
#     macro_recall_sum = 0.0
#     macro_f1_sum = 0.0

#     total_instances = len(true_labels)

#     for label in unique_true_labels: # Iterate in a defined order
#         counts = metrics[label]
#         tp = counts["tp"]
#         fp = counts["fp"]
#         fn = counts["fn"]

#         precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
#         recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
#         f1 = (
#             2 * precision * recall / (precision + recall)
#             if (precision + recall) > 0
#             else 0.0
#         )
#         support = label_counts[label]

#         results[label] = {
#             "precision": round(precision, 3),
#             "recall": round(recall, 3),
#             "f1": round(f1, 3),
#             "support": support,
#             "tp": tp,
#             "fp": fp,
#             "fn": fn,
#             "num_errors": fp + fn,
#         }

#         total_tp_micro += tp
#         total_fp_micro += fp
#         total_fn_micro += fn

#         macro_precision_sum += precision
#         macro_recall_sum += recall
#         macro_f1_sum += f1

#     # Calculate macro-averaged metrics
#     num_labels = len(unique_true_labels)
#     macro_precision = macro_precision_sum / num_labels if num_labels > 0 else 0.0
#     macro_recall = macro_recall_sum / num_labels if num_labels > 0 else 0.0
#     macro_f1 = macro_f1_sum / num_labels if num_labels > 0 else 0.0

#     # Calculate micro-averaged (overall) precision, recall, and F1 score
#     micro_precision = (
#         total_tp_micro / (total_tp_micro + total_fp_micro) if (total_tp_micro + total_fp_micro) > 0 else 0.0
#     )
#     micro_recall = (
#         total_tp_micro / (total_tp_micro + total_fn_micro) if (total_tp_micro + total_fn_micro) > 0 else 0.0
#     )
#     micro_f1 = (
#         2 * micro_precision * micro_recall / (micro_precision + micro_recall)
#         if (micro_precision + micro_recall) > 0
#         else 0.0
#     )

#     # Calculate weighted F1 score
#     weighted_f1_sum = 0.0
#     for label in unique_true_labels:
#         weighted_f1_sum += results[label]["f1"] * label_counts[label]
#     weighted_f1 = weighted_f1_sum / total_instances if total_instances > 0 else 0.0
    
#     results["overall"] = {
#         "micro_precision": round(micro_precision, 3),
#         "micro_recall": round(micro_recall, 3),
#         "micro_f1": round(micro_f1, 3),
#         "macro_precision": round(macro_precision, 3),
#         "macro_recall": round(macro_recall, 3),
#         "macro_f1": round(macro_f1, 3),
#         "weighted_f1": round(weighted_f1, 3),
#         "support": total_instances,
#         "total_tp": total_tp_micro,
#         "total_fp": total_fp_micro,
#         "total_fn": total_fn_micro,
#         "num_errors": total_fp_micro + total_fn_micro,
#     }

#     return results

In [2]:
# def print_error_counts(results):
#     fp_list = []
#     fn_list = []
#     num_errors_list = []

#     print("Run   FP    FN    FP+FN")
#     for i, res in enumerate(results):
#         fp = res["overall"]["total_fp"]
#         fn = res["overall"]["total_fn"]
#         num_errors = res["overall"]["num_errors"]
#         fp_list.append(fp)
#         fn_list.append(fn)
#         num_errors_list.append(num_errors)
#         print(f"{i:<5} {fp:<5} {fn:<5} {num_errors:<6}")

#     print("\nMean ± Std across splits:")
#     print(f"FP:    {np.mean(fp_list):.2f} ± {np.std(fp_list):.2f}")
#     print(f"FN:    {np.mean(fn_list):.2f} ± {np.std(fn_list):.2f}")
#     print(f"FP+FN: {np.mean(num_errors_list):.2f} ± {np.std(num_errors_list):.2f}")

def calculate_mean_std(
        results: list[dict],
        cat: str | None,
        level: str = "label"     # "label", "micro", or "macro"
    ) -> dict:
    """
    Compute mean ± std for a single class ("label" level) or for the
    overall micro / macro aggregates produced by *t14_calculate_metrics*.

    Args
    ----
    results : list of metrics dicts (output of t14_calculate_metrics)
    cat     : class label (e.g. "T1") – ignored for micro/macro levels
    level   : "label" | "micro" | "macro"

    Returns
    -------
    dict with keys:
        mean_precision, mean_recall, mean_f1,
        std_precision,  std_recall,  std_f1,
        (plus sums and raw means used elsewhere)
    """
    # ------------------------------------------------------------------ #
    # Gather the three score lists                                       #
    # ------------------------------------------------------------------ #
    precision_list, recall_list, f1_list = [], [], []
    support_list, num_errors_list = [], []

    for res in results:
        if level == "label":
            src = res[cat]                                 # per‑class block
        elif level == "micro":
            src = {                                        # NEW ↓
                "precision": res["overall"]["micro_precision"],
                "recall":    res["overall"]["micro_recall"],
                "f1":        res["overall"]["micro_f1"],
                "support":   res["overall"]["support"],
                "num_errors": res["overall"]["num_errors"],
            }
        elif level == "macro":
            src = {                                        # NEW ↓
                "precision": res["overall"]["macro_precision"],
                "recall":    res["overall"]["macro_recall"],
                "f1":        res["overall"]["macro_f1"],
                "support":   res["overall"]["support"],
                "num_errors": res["overall"]["num_errors"],
            }
        else:
            raise ValueError(f"Unknown level: {level}")

        precision_list.append(src["precision"])
        recall_list.append(src["recall"])
        f1_list.append(src["f1"])
        support_list.append(src["support"])
        num_errors_list.append(src["num_errors"])

    # ------------------------------------------------------------------ #
    # Mean / std                                                         #
    # ------------------------------------------------------------------ #
    mean_p = sum(precision_list) / len(precision_list)
    mean_r = sum(recall_list)    / len(recall_list)
    mean_f = sum(f1_list)        / len(f1_list)

    std_p = (sum((x - mean_p) ** 2 for x in precision_list) / len(precision_list)) ** 0.5
    std_r = (sum((x - mean_r) ** 2 for x in recall_list)    / len(recall_list))    ** 0.5
    std_f = (sum((x - mean_f) ** 2 for x in f1_list)        / len(f1_list))        ** 0.5

    return {
        "mean_precision": round(mean_p, 3),
        "mean_recall":    round(mean_r, 3),
        "mean_f1":        round(mean_f, 3),
        "std_precision":  round(std_p, 3),
        "std_recall":     round(std_r, 3),
        "std_f1":         round(std_f, 3),
        "sum_support":    sum(support_list),
        "sum_num_errors": sum(num_errors_list),
        "raw_mean_precision": mean_p,   # keep raw for higher‑level macro
        "raw_mean_recall":    mean_r,
        "raw_mean_f1":        mean_f,
    }


###############################################################################
# Updated helper: output_tabular_performance                                  #
###############################################################################
def output_tabular_performance(
        results: list[dict],
        categories: list[str] = ("T1", "T2", "T3", "T4"),
        show_overall: bool = True         # NEW flag
    ) -> None:
    """
    Print a mean(std) table of precision / recall / F1 across runs.

    One line is printed per entry of *categories*, then a "Cat-Macro" line
    holding the plain average of the unrounded per-class means, and, when
    *show_overall* is true, the micro- and macro-average lines taken from
    each run's "overall" entry.  All statistics come from
    ``calculate_mean_std``.
    """
    def _format_row(label, stats):
        # "<label padded to 8> P(std) R(std) F1(std)", three decimals each.
        return (f"{label:8s} "
                f"{stats['mean_precision']:.3f}({stats['std_precision']:.3f}) "
                f"{stats['mean_recall']:.3f}({stats['std_recall']:.3f}) "
                f"{stats['mean_f1']:.3f}({stats['std_f1']:.3f})")

    # ------------------------------------------------------------------ #
    # Per-class lines                                                    #
    # ------------------------------------------------------------------ #
    class_means = {"precision": [], "recall": [], "f1": []}

    for cat in categories:
        cat_stats = calculate_mean_std(results, cat, level="label")
        print(_format_row(cat, cat_stats))
        for score_name, bucket in class_means.items():
            bucket.append(cat_stats[f"raw_mean_{score_name}"])

    # ------------------------------------------------------------------ #
    # Category-macro (average of the per-class means)                    #
    # ------------------------------------------------------------------ #
    avg_p, avg_r, avg_f = (
        sum(bucket) / len(bucket) for bucket in class_means.values()
    )
    print(f"{'Cat‑Macro':8s} {avg_p:.3f} {avg_r:.3f} {avg_f:.3f}")

    # ------------------------------------------------------------------ #
    # Overall micro / macro                                              #
    # ------------------------------------------------------------------ #
    if not show_overall:
        return

    # Compute both aggregates before printing either line.
    overall_rows = [
        (row_label, calculate_mean_std(results, None, level=row_level))
        for row_label, row_level in (("MicroAvg.", "micro"), ("MacroAvg.", "macro"))
    ]
    for row_label, row_stats in overall_rows:
        print(_format_row(row_label, row_stats))

In [3]:
# Notebook cell: T-stage evaluation of the Mixtral runs.
# Load the combined Mixtral results table.  (Original note, translated from
# Korean: "This is the one - all the Mixtral results are in here.")
mixtral_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/mixtral_rag_result/0929_ltm_rag2.csv')

print("Mixtral T-stage metrics:")

# Single-table baselines: only the "overall" entry of the metrics dict is shown.
print("ZSCOT T-stage:")
print(t14_calculate_metrics(mixtral_df['t'], mixtral_df['zscot_t_stage'])['overall'])
print()
print("RAG T-stage:")
print(t14_calculate_metrics(mixtral_df['t'], mixtral_df['rag_raw_t_stage'])['overall'])
print()
# KEwLTM is scored once per run file and then aggregated across runs.
kewltm_t_results = []
t_label = 't'                                   # ground-truth column
kewltm_t_stage = "cmem_t_40reports_ans_str"     # prediction column
# NOTE(review): indices 7 and 9 are absent although the file names say
# "outof_10runs"; the reason is not visible here - confirm.
run_lst = [0, 1, 2, 3, 4, 5, 6, 8]

for run in run_lst:
    # One CSV per run, ordered by patient_filename before scoring.
    t_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_t_results.append(
        t14_calculate_metrics(t_test_df[t_label], t_test_df[kewltm_t_stage])
    )
print("KEwLTM T-stage:")
# Default categories ("T1".."T4") are used for the per-class table.
output_tabular_performance(kewltm_t_results)
print_error_counts(kewltm_t_results)

print()
print("KEwRAG T-stage:")
print(t14_calculate_metrics(mixtral_df['t'], mixtral_df['ltm_rag1_t_stage'])['overall'])


Mixtral T-stage metrics:
ZSCOT T-stage:
{'micro_precision': 0.85, 'micro_recall': 0.863, 'micro_f1': 0.856, 'macro_precision': 0.831, 'macro_recall': 0.765, 'macro_f1': 0.792, 'weighted_f1': 0.854, 'support': 800, 'total_tp': 690, 'total_fp': 122, 'total_fn': 110, 'num_errors': 110}

RAG T-stage:
{'micro_precision': 0.81, 'micro_recall': 0.815, 'micro_f1': 0.812, 'macro_precision': 0.771, 'macro_recall': 0.73, 'macro_f1': 0.743, 'weighted_f1': 0.812, 'support': 800, 'total_tp': 652, 'total_fp': 153, 'total_fn': 148, 'num_errors': 148}

KEwLTM T-stage:
T1       0.904(0.017) 0.812(0.040) 0.855(0.018)
T2       0.882(0.022) 0.938(0.018) 0.909(0.005)
T3       0.834(0.054) 0.810(0.058) 0.818(0.018)
T4       0.807(0.082) 0.634(0.038) 0.707(0.029)
Cat‑Macro 0.857 0.799 0.822
MicroAvg. 0.876(0.006) 0.878(0.007) 0.877(0.007)
MacroAvg. 0.857(0.022) 0.799(0.020) 0.822(0.010)
Run  WrongRows  FN  FP  (FP+FN)
0    79         79  81  160
1    88         88  89  177
2    95         95  95  190
3    83 

In [4]:
# Notebook cell: N-stage evaluation of the Mixtral runs.
# Relies on mixtral_df, which is loaded in the preceding T-stage cell.
print("Mixtral N-stage metrics:")

# Single-table baselines: only the "overall" entry of the metrics dict is shown.
print("ZSCOT N-stage:")
print(n03_calculate_metrics(mixtral_df['n'], mixtral_df['zscot_n_stage'])['overall'])
print()
print("RAG N-stage:")
print(n03_calculate_metrics(mixtral_df['n'], mixtral_df['rag_raw_n_stage'])['overall'])
print()
# KEwLTM is scored once per run file and then aggregated across runs.
kewltm_n_results = []
n_label = 'n'                                   # ground-truth column
kewltm_n_stage = "cmem_n_40reports_ans_str"     # prediction column
# NOTE(review): indices 2 and 8 are absent here (the T-stage cell uses a
# different subset); the reason is not visible here - confirm.
run_lst = [0, 1, 3, 4, 5, 6, 7, 9]
for run in run_lst:
    # One CSV per run, ordered by patient_filename before scoring.
    n_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_n_results.append(
        n03_calculate_metrics(n_test_df[n_label], n_test_df[kewltm_n_stage])
    )
print("KEwLTM N-stage:")
# N-stage class names must be passed explicitly; the default is T1..T4.
output_tabular_performance(kewltm_n_results, categories=["N0", "N1", "N2", "N3"])
print_error_counts(kewltm_n_results)

print()
print("KEwRAG N-stage:")
print(n03_calculate_metrics(mixtral_df['n'], mixtral_df['ltm_rag1_n_stage'])['overall'])


Mixtral N-stage metrics:
ZSCOT N-stage:
{'micro_precision': 0.874, 'micro_recall': 0.873, 'micro_f1': 0.873, 'macro_precision': 0.843, 'macro_recall': 0.822, 'macro_f1': 0.832, 'weighted_f1': 0.872, 'support': 800, 'total_tp': 698, 'total_fp': 101, 'total_fn': 102, 'num_errors': 102}

RAG N-stage:
{'micro_precision': 0.841, 'micro_recall': 0.835, 'micro_f1': 0.838, 'macro_precision': 0.803, 'macro_recall': 0.799, 'macro_f1': 0.797, 'weighted_f1': 0.84, 'support': 800, 'total_tp': 668, 'total_fp': 126, 'total_fn': 132, 'num_errors': 132}

KEwLTM N-stage:
N0       0.944(0.008) 0.952(0.018) 0.948(0.011)
N1       0.885(0.020) 0.883(0.026) 0.884(0.010)
N2       0.713(0.031) 0.745(0.054) 0.727(0.022)
N3       0.886(0.058) 0.784(0.042) 0.830(0.017)
Cat‑Macro 0.857 0.841 0.847
MicroAvg. 0.883(0.007) 0.883(0.007) 0.883(0.007)
MacroAvg. 0.857(0.011) 0.841(0.011) 0.847(0.008)
Run  WrongRows  FN  FP  (FP+FN)
0    85         85  85  170
1    78         78  78  156
2    85         85  84  169
3    8

In [5]:
# Notebook cell: T-stage evaluation of the Med42 runs.
kewltm_t_stage = 'kepa_t_ans_str'               # prediction column in the per-run files
# Each baseline lives in its own CSV; rows are ordered by patient_filename and
# only the id, ground-truth ('t') and prediction columns are kept.
zscot_t_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1118_t14_med42_v2_test_800.csv').sort_values(by="patient_filename")[["patient_filename", 't', 'zscot_t_ans_str']]
rag_t_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1120_t14_rag_raw_med42_v2_800.csv').sort_values(by="patient_filename")[["patient_filename", 't', 't14_rag_raw_t_pred']]
ltm_rag_t_df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1128_t14_ltm_rag1_med42_v2_800.csv').sort_values(by="patient_filename")[["patient_filename", 't', 't14_ltm_rag1_t_pred']]

print("Med42 T-stage metrics:")

# Single-table baselines: only the "overall" entry of the metrics dict is shown.
print("ZSCOT T-stage:")
print(t14_calculate_metrics(zscot_t_df['t'], zscot_t_df['zscot_t_ans_str'])['overall'])
print()
print("RAG T-stage:")
print(t14_calculate_metrics(rag_t_df['t'], rag_t_df['t14_rag_raw_t_pred'])['overall'])
print()
# KEwLTM is scored once per run file and then aggregated across runs.
kewltm_t_results = []
t_label = 't'                                   # ground-truth column
# NOTE(review): indices 7 and 9 are absent although the file names say
# "outof_10runs"; the reason is not visible here - confirm.
run_lst = [0, 1, 2, 3, 4, 5, 6, 8]

for run in run_lst:
    # One CSV per run, ordered by patient_filename before scoring.
    t_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/1114_t14_med42_v2_test_{run}_outof_10runs.csv"
    ).sort_values(by="patient_filename")

    kewltm_t_results.append(
        t14_calculate_metrics(t_test_df[t_label], t_test_df[kewltm_t_stage])
    )
print("KEwLTM T-stage:")
# Default categories ("T1".."T4") are used for the per-class table.
output_tabular_performance(kewltm_t_results)
print_error_counts(kewltm_t_results)

print()
print("KEwRAG T-stage:")
print(t14_calculate_metrics(ltm_rag_t_df['t'], ltm_rag_t_df['t14_ltm_rag1_t_pred'])['overall'])

Med42 T-stage metrics:
ZSCOT T-stage:
{'micro_precision': 0.769, 'micro_recall': 0.769, 'micro_f1': 0.769, 'macro_precision': 0.746, 'macro_recall': 0.678, 'macro_f1': 0.703, 'weighted_f1': 0.77, 'support': 800, 'total_tp': 615, 'total_fp': 185, 'total_fn': 185, 'num_errors': 185}

RAG T-stage:
{'micro_precision': 0.836, 'micro_recall': 0.836, 'micro_f1': 0.836, 'macro_precision': 0.786, 'macro_recall': 0.748, 'macro_f1': 0.764, 'weighted_f1': 0.834, 'support': 800, 'total_tp': 669, 'total_fp': 131, 'total_fn': 131, 'num_errors': 131}

KEwLTM T-stage:
T1       0.813(0.073) 0.759(0.076) 0.783(0.064)
T2       0.855(0.031) 0.913(0.023) 0.882(0.016)
T3       0.869(0.063) 0.703(0.099) 0.770(0.065)
T4       0.630(0.046) 0.615(0.057) 0.621(0.042)
Cat‑Macro 0.792 0.747 0.764
MicroAvg. 0.835(0.025) 0.835(0.025) 0.835(0.025)
MacroAvg. 0.792(0.032) 0.747(0.043) 0.764(0.034)
Run  WrongRows  FN  FP  (FP+FN)
0    92         92  92  184
1    143        143 143 286
2    127        127 127 254
3    115

In [6]:
# ---------------------------------------------------------------------- #
# Med42 N-stage evaluation: ZSCOT, RAG, KEwLTM (several runs) and KEwRAG #
# ---------------------------------------------------------------------- #
# Column names and the run indices aggregated for KEwLTM.
kewltm_n_stage = 'kepa_n_ans_str'
n_label = 'n'
run_lst = [0, 1, 3, 4, 5, 6, 7, 9]
kewltm_n_results = []

# Single-file baselines: sort by patient file name, keep id / truth / prediction.
zscot_n_df = (
    pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1118_n03_med42_v2_test_800.csv')
    .sort_values(by="patient_filename")
    .loc[:, ["patient_filename", 'n', 'zscot_n_ans_str']]
)
rag_n_df = (
    pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1120_n03_rag_raw_med42_v2_800.csv')
    .sort_values(by="patient_filename")
    .loc[:, ["patient_filename", 'n', 'n03_rag_raw_n_pred']]
)
ltm_rag_n_df = (
    pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/result/1128_n03_ltm_rag1_med42_v2_800.csv')
    .sort_values(by="patient_filename")
    .loc[:, ["patient_filename", 'n', 'n03_ltm_rag1_n_pred']]
)

print("Med42 N-stage metrics:")

print("ZSCOT N-stage:")
print(n03_calculate_metrics(zscot_n_df['n'], zscot_n_df['zscot_n_ans_str'])['overall'])
print()
print("RAG N-stage:")
print(n03_calculate_metrics(rag_n_df['n'], rag_n_df['n03_rag_raw_n_pred'])['overall'])
print()

# KEwLTM: one result file per run, one metrics dict per file.
for run in run_lst:
    n_test_df = pd.read_csv(
        f"/home/yl3427/cylab/selfCorrectionAgent/result/1114_n03_med42_v2_test_{run}_outof_10runs.csv"
    )
    n_test_df = n_test_df.sort_values(by="patient_filename")
    kewltm_n_results.append(
        n03_calculate_metrics(n_test_df[n_label], n_test_df[kewltm_n_stage])
    )

print("KEwLTM N-stage:")
output_tabular_performance(kewltm_n_results, categories=["N0", "N1", "N2", "N3"])
print_error_counts(kewltm_n_results)

print()
print("KEwRAG N-stage:")
print(n03_calculate_metrics(ltm_rag_n_df['n'], ltm_rag_n_df['n03_ltm_rag1_n_pred'])['overall'])

Med42 N-stage metrics:
ZSCOT N-stage:
{'micro_precision': 0.738, 'micro_recall': 0.738, 'micro_f1': 0.738, 'macro_precision': 0.748, 'macro_recall': 0.723, 'macro_f1': 0.724, 'weighted_f1': 0.742, 'support': 800, 'total_tp': 590, 'total_fp': 210, 'total_fn': 210, 'num_errors': 210}

RAG N-stage:
{'micro_precision': 0.79, 'micro_recall': 0.79, 'micro_f1': 0.79, 'macro_precision': 0.76, 'macro_recall': 0.799, 'macro_f1': 0.759, 'weighted_f1': 0.796, 'support': 800, 'total_tp': 632, 'total_fp': 168, 'total_fn': 168, 'num_errors': 168}

KEwLTM N-stage:
N0       0.950(0.011) 0.821(0.059) 0.879(0.032)
N1       0.775(0.056) 0.821(0.032) 0.795(0.022)
N2       0.657(0.067) 0.711(0.076) 0.675(0.018)
N3       0.759(0.103) 0.858(0.029) 0.800(0.062)
Cat‑Macro 0.785 0.803 0.787
MicroAvg. 0.809(0.024) 0.809(0.024) 0.809(0.024)
MacroAvg. 0.785(0.021) 0.803(0.025) 0.788(0.026)
Run  WrongRows  FN  FP  (FP+FN)
0    142        142 142 284
1    140        140 140 280
2    93         93  93  186
3    148   

# Error analysis

In [None]:
import pandas as pd
# Quick inspection: load the CSV from mixtral_rag_result and show its column
# names (the bare expression below is what the notebook cell displays).
df = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/mixtral_rag_result/0929_ltm_rag2.csv')
df.columns
# Alternative inspection of a KEwLTM run file, currently disabled:
# new_df2 = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_0_outof_10runs.csv")
# new_df2.columns

In [9]:
import pandas as pd
from pathlib import Path
from typing import Tuple, Optional


###############################################################################
# Helper – map the integer label to the string used in the predictions
###############################################################################
def _canon_label(x, stage: str) -> str:
    """
    Convert the integer‑coded ground‑truth to the canonical string that appears
    in prediction strings.

        T‑stage: 0 → "T1", 1 → "T2", …
        N‑stage: 0 → "N0", 1 → "N1", …
    """
    stage = stage.lower()
    if stage == "t":
        return f"T{int(x) + 1}"
    elif stage == "n":
        return f"N{int(x)}"
    else:
        raise ValueError(f'Unknown stage "{stage}". Use "t" or "n".')


###############################################################################
# Core – compare two methods and write three CSVs                             #
###############################################################################
def compare_error_cases(
    df_baseline: pd.DataFrame,
    df_method: pd.DataFrame,
    *,
    stage: str,
    true_col: str,
    base_pred_col: str,
    base_reason_col: Optional[str],
    meth_pred_col: str,
    meth_reason_col: Optional[str],
    patient_id_col: str = "patient_filename",
    out_common_csv: str | Path,
    out_base_only_csv: str | Path,
    out_meth_only_csv: str | Path,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split the error cases of two methods into three groups and save each as CSV.

    A row counts as an error when the canonical ground-truth string (see
    ``_canon_label``) does not occur as a substring of the upper-cased
    prediction.  Rows are matched between the two frames by ``patient_id_col``.

    Parameters
    ----------
    df_baseline, df_method
        Result tables of the baseline and of the method under comparison.
        Both must contain ``patient_id_col`` and ``true_col``.
    stage
        "t" or "n"; forwarded to ``_canon_label``.
    true_col
        Column with the integer-coded ground truth.
    base_pred_col, meth_pred_col
        Prediction column of the baseline / method frame.
    base_reason_col, meth_reason_col
        Optional reasoning column; included in the output only when it is not
        None and exists in the respective frame.
    patient_id_col
        Column identifying a patient in both frames.
    out_common_csv, out_base_only_csv, out_meth_only_csv
        Output paths.  Missing parent directories are created for all three.

    Returns
    -------
    (common_df, base_only_df, meth_only_df)
        ``common_df`` holds baseline columns left-merged with method columns
        (clashing names get ``_base`` / ``_meth`` suffixes); the other two hold
        rows of the baseline / method frame only.

    Notes
    -----
    The "only" groups are plain set differences of error ids, so an id that is
    wrong in one frame and simply absent from the other lands in that frame's
    "only" group.  Restrict both frames to the same patients beforehand if
    that is not wanted.
    """

    # -------- flag errors row-wise ----------------------------------- #
    def _flag(df: pd.DataFrame, pred_col: str) -> pd.DataFrame:
        # Truth and prediction are paired by position instead of by index
        # label, so frames with duplicate index labels or zero rows work too.
        canon = [_canon_label(v, stage) for v in df[true_col]]
        wrong = [
            truth not in str(pred).upper()
            for truth, pred in zip(canon, df[pred_col])
        ]
        out = df.copy()
        out["canon_truth"] = canon
        # Explicit bool dtype keeps the column usable as a mask when empty.
        out["is_error"] = pd.Series(wrong, dtype=bool).to_numpy()
        return out

    base_f = _flag(df_baseline, base_pred_col)
    meth_f = _flag(df_method, meth_pred_col)

    # -------- build id sets ------------------------------------------ #
    base_err_ids = set(base_f.loc[base_f["is_error"], patient_id_col])
    meth_err_ids = set(meth_f.loc[meth_f["is_error"], patient_id_col])

    common_ids = base_err_ids & meth_err_ids
    base_only_ids = base_err_ids - meth_err_ids
    meth_only_ids = meth_err_ids - base_err_ids

    # -------- helpers to pick & merge -------------------------------- #
    def _subset(df, ids, keep_pred, keep_reason):
        cols = [patient_id_col, true_col, "canon_truth", keep_pred, "is_error"]
        if keep_reason is not None and keep_reason in df.columns:
            # Place the reasoning right before the prediction column.
            cols.insert(3, keep_reason)
        return df.loc[df[patient_id_col].isin(ids), cols].copy()

    common_df = _subset(base_f, common_ids, base_pred_col, base_reason_col).merge(
        _subset(meth_f, common_ids, meth_pred_col, meth_reason_col),
        on=[patient_id_col],
        suffixes=("_base", "_meth"),
        how="left",
    )
    base_only_df = _subset(base_f, base_only_ids, base_pred_col, base_reason_col)
    meth_only_df = _subset(meth_f, meth_only_ids, meth_pred_col, meth_reason_col)

    # -------- write to disk ------------------------------------------ #
    # Each output may live in its own directory, so create every parent
    # (previously only the directory of the first file was created).
    for frame, target in (
        (common_df, out_common_csv),
        (base_only_df, out_base_only_csv),
        (meth_only_df, out_meth_only_csv),
    ):
        Path(target).parent.mkdir(parents=True, exist_ok=True)
        frame.to_csv(target, index=False)

    print(f"[saved] {out_common_csv}   ({len(common_df)} rows)")
    print(f"[saved] {out_base_only_csv} ({len(base_only_df)} rows)")
    print(f"[saved] {out_meth_only_csv} ({len(meth_only_df)} rows)")

    return common_df, base_only_df, meth_only_df


In [11]:
# ------------------------------------------------------------------ #
# Inputs: Mixtral baseline table and one KEwLTM run (run 0)          #
# ------------------------------------------------------------------ #
base_df = pd.read_csv(
    "/home/yl3427/cylab/selfCorrectionAgent/mixtral_rag_result/0929_ltm_rag2.csv"
)
ltm_df = pd.read_csv(
    "/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_0_outof_10runs.csv"
)

# ------------------------------------------------------------------ #
# 1)  ZSCOT  vs  KEwLTM                                              #
# ------------------------------------------------------------------ #
print("\n--- Comparing ZSCOT vs. KEwLTM (T‑stage) ---")
# Keep only the baseline rows whose patient also appears in the KEwLTM run,
# so both sides of the comparison cover the same patients.
df_zscot_t = base_df.loc[
    base_df["patient_filename"].isin(ltm_df["patient_filename"]),
    ["patient_filename", "t", "text", "zscot_t_stage", "zscot_t_reasoning"],
]
df_kewltm_t = ltm_df.loc[
    :,
    ["patient_filename", "t", "text", "cmem_t_40reports_ans_str", "cmem_t_40reasoning"],
]

compare_error_cases(
    df_baseline=df_zscot_t,
    df_method=df_kewltm_t,
    stage="t",
    true_col="t",
    base_pred_col="zscot_t_stage",
    base_reason_col="zscot_t_reasoning",
    meth_pred_col="cmem_t_40reports_ans_str",
    meth_reason_col="cmem_t_40reasoning",
    out_common_csv="t_common_errors_zscot_kewltm.csv",
    out_base_only_csv="t_zscot_only_errors_vs_kewltm.csv",
    out_meth_only_csv="t_kewltm_only_errors_vs_zscot.csv",
)

# ------------------------------------------------------------------ #
# 2)  RAG  vs  KEwRAG                                                #
# ------------------------------------------------------------------ #
print("\n--- Comparing RAG vs. KEwRAG (T‑stage) ---")
# Both methods come from the same table, so no patient filtering is needed.
df_rag_t = base_df.loc[
    :,
    ["patient_filename", "t", "text", "rag_raw_t_stage", "rag_raw_t_reasoning"],
]
df_kewrag_t = base_df.loc[
    :,
    ["patient_filename", "t", "text", "ltm_rag1_t_stage", "ltm_rag1_t_reasoning"],
]

# Last expression of the cell: the returned tuple is what the notebook shows.
compare_error_cases(
    df_baseline=df_rag_t,
    df_method=df_kewrag_t,
    stage="t",
    true_col="t",
    base_pred_col="rag_raw_t_stage",
    base_reason_col="rag_raw_t_reasoning",
    meth_pred_col="ltm_rag1_t_stage",
    meth_reason_col="ltm_rag1_t_reasoning",
    out_common_csv="t_common_errors_rag_kewrag.csv",
    out_base_only_csv="t_rag_only_errors_vs_kewrag.csv",
    out_meth_only_csv="t_kewrag_only_errors_vs_rag.csv",
)



--- Comparing ZSCOT vs. KEwLTM (T‑stage) ---
[saved] t_common_errors_zscot_kewltm.csv   (53 rows)
[saved] t_zscot_only_errors_vs_kewltm.csv (46 rows)
[saved] t_kewltm_only_errors_vs_zscot.csv (26 rows)

--- Comparing RAG vs. KEwRAG (T‑stage) ---
[saved] t_common_errors_rag_kewrag.csv   (67 rows)
[saved] t_rag_only_errors_vs_kewrag.csv (81 rows)
[saved] t_kewrag_only_errors_vs_rag.csv (55 rows)


(                                     patient_filename  t_base  \
 0   TCGA-5L-AAT0.F9B6971F-23C0-465F-BFEC-778BF228A1AE       1   
 1   TCGA-5L-AAT1.B5CA42BB-9514-42C6-9FB0-C8889C1DC51A       1   
 2   TCGA-A2-A0CK.B065FC65-CD33-4878-AE2C-7E8C04F5ECAB       2   
 3   TCGA-A2-A0ET.E9D3FFF1-5FB2-4F17-9C1D-D9775E3CC5AC       1   
 4   TCGA-A2-A0SV.161E2817-7DB2-46F8-BFEB-256DBBEFE633       1   
 ..                                                ...     ...   
 62  TCGA-PL-A8LV.D35DBECD-5241-4562-85CC-2822BB338279       3   
 63  TCGA-PL-A8LX.E6DD0840-4D71-4EEC-B559-F6BFC9E7E68B       3   
 64  TCGA-PL-A8LY.8C97B391-96B4-468D-AAA3-24E196DE03CA       2   
 65  TCGA-PL-A8LZ.436F3280-98C7-4FA9-BD6F-1B02CAF1D262       3   
 66  TCGA-S3-AA15.DD2B9E47-8C67-4599-B0B6-0D30DE727B55       1   
 
    canon_truth_base                                rag_raw_t_reasoning  \
 0                T2  According to the pathology report, the patient...   
 1                T2  According to the pathology report,

In [None]:
# import pandas as pd
# from typing import List, Tuple

# def _error_mask(
#         df: pd.DataFrame,
#         stage_prefix: str,
#         true_col: str,
#         pred_col: str,
#         labels: List[int] | None,
#     ) -> pd.Series:

#     def canon(gt: int) -> str:
#         # Correctly formats ground truth labels (e.g., T1, N0)
#         return f"{stage_prefix}{int(gt)+1}" if stage_prefix == "T" else f"{stage_prefix}{int(gt)}"

#     # Normalise predictions (e.g., "NO" to "N0")
#     def norm_pred(p: str) -> str:
#         p_str = str(p).upper() # Ensure p is string before upper()
#         if stage_prefix == "N":
#             p_str = p_str.replace("NO", "N0").replace("NL", "N1") # Common variations
#         return p_str

#     df_tmp = df.copy()
#     # Ensure true labels are integers before applying canon
#     try:
#         df_tmp["__gt_int"] = df_tmp[true_col].astype(int)
#     except ValueError:
#         # Handle cases where true_col might already be like "T1", "N0" if not careful
#         # This function expects integer true_col that it will convert to T1/N0 format.
#         # For robustness, one might add more sophisticated parsing if true_col format varies.
#         raise ValueError(f"Column '{true_col}' must contain integer representations of stages for _error_mask.")

#     df_tmp["__gt"]   = df_tmp["__gt_int"].apply(canon)
#     df_tmp["__pred"] = df_tmp[pred_col].astype(str).apply(norm_pred) # ensure pred_col is str

#     # Determine the full set of possible canonical tags
#     if labels is None:
#         # Use the integer labels from __gt_int to form all_tags
#         unique_int_labels = sorted(df_tmp["__gt_int"].unique())
#     else:
#         unique_int_labels = sorted(list(set(labels))) # Use provided labels if available
    
#     all_tags = {canon(i) for i in unique_int_labels}


#     # Row‑wise error test
#     def row_is_error(row) -> bool:
#         pred_text = row["__pred"]
#         gt_text   = row["__gt"]

#         # An error if:
#         # 1. The ground truth tag is NOT in the prediction string (False Negative component)
#         # OR
#         # 2. Any OTHER tag (not the ground truth) IS in the prediction string (False Positive component)
#         contains_gt    = gt_text in pred_text
#         # Check if any tag from all_tags is in pred_text, AND that tag is not the gt_text
#         contains_other = any(tag in pred_text and tag != gt_text for tag in all_tags)
        
#         is_error = (not contains_gt) or contains_other
#         return is_error

#     return df_tmp.apply(row_is_error, axis=1)


# def compare_error_cases(
#         df_baseline: pd.DataFrame,
#         df_method:   pd.DataFrame,
#         *,
#         stage: str,                           # "t" or "n"
#         id_col: str           = "patient_filename",
#         true_col: str         = "t",          # or "n" (ensure this matches the stage and contains integer labels)
#         base_pred_col: str    = "prediction",
#         base_reason_col: str  = "reasoning",
#         meth_pred_col: str    = "prediction",
#         meth_reason_col: str  = "reasoning",
#         text_col: str         = "text",
#         labels: List[int] | None = None, # Integer labels for the stage
#         out_common_csv: str = "common_errors.csv", 
#         out_base_only_csv: str = "baseline_only_errors.csv",
#         out_meth_only_csv: str = "method_only_errors.csv",
#     ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: 

#     stage_prefix = "T" if stage.lower().startswith("t") else "N"

#     df_baseline[id_col] = df_baseline[id_col].astype(str)
#     df_method[id_col] = df_method[id_col].astype(str)
    
#     common_ids = set(df_baseline[id_col]) & set(df_method[id_col])
    
#     df_b = df_baseline[df_baseline[id_col].isin(common_ids)].sort_values(by=id_col).reset_index(drop=True)
#     df_m = df_method[df_method[id_col].isin(common_ids)].sort_values(by=id_col).reset_index(drop=True)

#     mask_b = _error_mask(df_b, stage_prefix, true_col, base_pred_col, labels)
#     mask_m = _error_mask(df_m, stage_prefix, true_col, meth_pred_col, labels)

#     common_error_ids = set(df_b.loc[mask_b & mask_m, id_col])
#     base_only_error_ids  = set(df_b.loc[mask_b & ~mask_m, id_col])
#     meth_only_error_ids  = set(df_m.loc[mask_m & ~mask_b, id_col])

#     output_cols_base = [id_col, true_col, base_pred_col, base_reason_col, text_col]
#     output_cols_meth = [id_col, true_col, meth_pred_col, meth_reason_col, text_col]
    
#     common_errors_df_b = df_b[df_b[id_col].isin(common_error_ids)][output_cols_base].rename(
#         columns={base_pred_col: "baseline_pred", base_reason_col: "baseline_reason", true_col: f"{true_col}_true"}
#     )
#     common_errors_df_m = df_m[df_m[id_col].isin(common_error_ids)][[id_col, meth_pred_col, meth_reason_col]].rename(
#         columns={meth_pred_col: "method_pred", meth_reason_col: "method_reason"}
#     )
#     common_errors_df = pd.merge(common_errors_df_b, common_errors_df_m, on=id_col, how="left")
#     # Ensure original true_col name is preserved if it was renamed
#     if f"{true_col}_true" in common_errors_df.columns:
#          common_errors_df.rename(columns={f"{true_col}_true": true_col}, inplace=True)

#     # Reorder columns for clarity
#     common_errors_df = common_errors_df[[
#         id_col, true_col, 
#         "baseline_pred", "baseline_reason", 
#         "method_pred", "method_reason", 
#         text_col
#     ]]

#     baseline_only_errors_df = (
#         df_b[df_b[id_col].isin(base_only_error_ids)]
#         [output_cols_base]
#         .rename(columns={
#             base_pred_col:  "baseline_pred",
#             base_reason_col:"baseline_reason"})
#         .reset_index(drop=True)
#     )

#     method_only_errors_df = (
#         df_m[df_m[id_col].isin(meth_only_error_ids)]
#         [output_cols_meth]
#         .rename(columns={
#             meth_pred_col:  "method_pred",
#             meth_reason_col:"method_reason"})
#         .reset_index(drop=True)
#     )

#     common_errors_df.to_csv(out_common_csv, index=False)
#     baseline_only_errors_df.to_csv(out_base_only_csv, index=False)
#     method_only_errors_df.to_csv(out_meth_only_csv, index=False)
    
#     print(f"Saved {len(common_errors_df)} common errors between baseline and method                 → {out_common_csv}")
#     print(f"Saved {len(baseline_only_errors_df)} errors unique to baseline ({base_pred_col})        → {out_base_only_csv}")
#     print(f"Saved {len(method_only_errors_df)} errors unique to method ({meth_pred_col})          → {out_meth_only_csv}")

#     return common_errors_df, baseline_only_errors_df, method_only_errors_df

# # --- Example Usage (Illustrative - replace with your actual DataFrames and filenames) ---
# if __name__ == '__main__':
#     base_df  = pd.read_csv('/home/yl3427/cylab/selfCorrectionAgent/mixtral_rag_result/0929_ltm_rag2.csv')
#     ltm_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_0_outof_10runs.csv")

#     print("\n--- Comparing ZSCOT vs. KEwLTM (T-stage) ---")
#     df_zscot_t = base_df[[
#         "patient_filename", "t", "text", 
#         "zscot_t_stage", "zscot_t_reasoning"
#     ]].copy()
#     df_kewltm_t = ltm_df[[
#         "patient_filename", "t", "text",
#         "cmem_t_40reports_ans_str", "cmem_t_40reasoning"
#     ]].copy()

#     common_z_k_t, z_only_t, k_only_t = compare_error_cases(
#         df_baseline=df_zscot_t,
#         df_method=df_kewltm_t,
#         stage="t",
#         true_col="t", 
#         base_pred_col="zscot_t_stage",
#         base_reason_col="zscot_t_reasoning",
#         meth_pred_col="cmem_t_40reports_ans_str",
#         meth_reason_col="cmem_t_40reasoning",
#         out_common_csv="t_common_errors_zscot_kewltm.csv",
#         out_base_only_csv="t_zscot_only_errors_vs_kewltm.csv",
#         out_meth_only_csv="t_kewltm_only_errors_vs_zscot.csv"
#     )
    
#     # --- Error Comparison: Pair 2: RAG vs KEwRAG (T-stage) ---
#     print("\n--- Comparing RAG vs. KEwRAG (T-stage) ---")
#     df_rag_t = base_df[[
#         "patient_filename", "t", "text",
#         "rag_raw_t_stage", "rag_raw_t_reasoning"
#     ]].copy()
#     df_kewrag_t = base_df[[
#         "patient_filename", "t", "text",
#         "ltm_rag1_t_stage", "ltm_rag1_t_reasoning"
#     ]].copy()
    
#     common_r_kr_t, r_only_t, kr_only_t = compare_error_cases(
#         df_baseline=df_rag_t,
#         df_method=df_kewrag_t,
#         stage="t",
#         true_col="t",
#         base_pred_col="rag_raw_t_stage",
#         base_reason_col="rag_raw_t_reasoning",
#         meth_pred_col="ltm_rag1_t_stage",
#         meth_reason_col="ltm_rag1_t_reasoning",
#         labels=t_stage_int_labels, # Pass integer labels
#         out_common_csv="t_common_errors_rag_kewrag.csv",
#         out_base_only_csv="t_rag_only_errors_vs_kewrag.csv",
#         out_meth_only_csv="t_kewrag_only_errors_vs_rag.csv"
#     )

# Plot (sensitive)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json

# Plot KEwLTM T-stage macro precision / recall / F1, averaged over several
# runs, against the number of training reports held in memory.

# Memory sizes on the x-axis: 10, 20, ..., 100 training reports.
x_axis = np.array(range(1, 11)) * 10

# Runs averaged in this plot.  The accumulators and the divisor below are
# derived from this list, so editing it cannot leave a stale "/ 8" or break
# the initialisation (which used to depend on the first run being run 0).
t_plot_runs = [0, 1, 2, 3, 4, 5, 6, 8]

# One running sum per memory size.
memory_precision_cumulative = [0.0] * len(x_axis)
memory_recall_cumulative = [0.0] * len(x_axis)
memory_f1_cumulative = [0.0] * len(x_axis)


for run in t_plot_runs:
    test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_t14_dynamic_test_{run}_outof_10runs.csv")

    for idx, n_reports in enumerate(x_axis):  # memory (10, 20, ..., 100)
        result = t14_calculate_metrics(test_df['t'], test_df[f'cmem_t_{n_reports}reports_ans_str'])['overall']
        memory_precision_cumulative[idx] += result['macro_precision']
        memory_recall_cumulative[idx] += result['macro_recall']
        memory_f1_cumulative[idx] += result['macro_f1']


# average over the runs actually loaded
n_runs = len(t_plot_runs)
precision_avg = [p / n_runs for p in memory_precision_cumulative]
recall_avg = [r / n_runs for r in memory_recall_cumulative]
f1_avg = [f / n_runs for f in memory_f1_cumulative]


plt.figure(figsize=(15, 10))

plt.plot(x_axis, precision_avg, label='Average KEwLTM Precision', color='blue', marker='o')
plt.plot(x_axis, recall_avg, label='Average KEwLTM Recall', color='green', marker='o')
plt.plot(x_axis, f1_avg, label='Average KEwLTM F1 Score', color='red', marker='o')

plt.xlabel('Number of Training Reports')
plt.ylabel('Scores')
# plt.title(f'The Average of 10 Results on 700 Test Reports (t14)')
plt.legend()
plt.grid(True)

plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json

# Plot KEwLTM N-stage macro precision / recall / F1, averaged over several
# runs, against the number of training reports held in memory.

# Memory sizes on the x-axis: 10, 20, ..., 100 training reports.
x_axis = np.array(range(1, 11)) * 10

# Runs averaged in this plot.  The accumulators and the divisor below are
# derived from this list, so editing it cannot leave a stale "/ 8" or break
# the initialisation (which used to depend on the first run being run 0).
n_plot_runs = [0, 1, 3, 4, 5, 6, 7, 9]

# One running sum per memory size.
memory_precision_cumulative = [0.0] * len(x_axis)
memory_recall_cumulative = [0.0] * len(x_axis)
memory_f1_cumulative = [0.0] * len(x_axis)


for run in n_plot_runs:
    test_df = pd.read_csv(f"/home/yl3427/cylab/selfCorrectionAgent/result/0718_n03_dynamic_test_{run}_outof_10runs.csv")

    for idx, n_reports in enumerate(x_axis):  # memory (10, 20, ..., 100)
        result = n03_calculate_metrics(test_df['n'], test_df[f'cmem_n_{n_reports}reports_ans_str'])['overall']
        memory_precision_cumulative[idx] += result['macro_precision']
        memory_recall_cumulative[idx] += result['macro_recall']
        memory_f1_cumulative[idx] += result['macro_f1']


# average over the runs actually loaded
n_runs = len(n_plot_runs)
precision_avg = [p / n_runs for p in memory_precision_cumulative]
recall_avg = [r / n_runs for r in memory_recall_cumulative]
f1_avg = [f / n_runs for f in memory_f1_cumulative]


plt.figure(figsize=(15, 10))

plt.plot(x_axis, precision_avg, label='Average KEwLTM Precision', color='blue', marker='o')
plt.plot(x_axis, recall_avg, label='Average KEwLTM Recall', color='green', marker='o')
plt.plot(x_axis, f1_avg, label='Average KEwLTM F1 Score', color='red', marker='o')

plt.xlabel('Number of Training Reports')
plt.ylabel('Scores')
# plt.title(f'The Average of 10 Results on 700 Test Reports (t14)')
plt.legend()
plt.grid(True)

plt.show()