In [1]:
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json  # only needed if save_summary_json=True

def analyze_and_export_mistakes(
    pred_csv_path: str,
    out_mistakes_csv_path: str,
    charts_dir: str | None = None,
    truth_col: str = "class_label",
    pred_col: str = "predicted_label",
    id_col: str = "tweet_id",
    text_col: str = "tweet_text",
    save_summary_json: bool = True,
):
    """
    pandas-first evaluation:
      - loads predictions CSV
      - exports misclassified rows
      - computes accuracy, macro-F1, per-class metrics using crosstab
      - saves charts and CSVs
    Returns: (mistakes_df, summary_dict, per_class_df, conf_mat_df)
    """
    df = pd.read_csv(pred_csv_path)
    for c in (truth_col, pred_col):
        if c not in df.columns:
            raise ValueError(f"Column '{c}' not found in {pred_csv_path}")

    # restrict to rows with truth labels present
    df_eval = df[df[truth_col].astype(str).str.len() > 0].copy()

    # mistakes
    mistakes_df = df_eval.loc[df_eval[truth_col] != df_eval[pred_col]].copy()
    out_p = pathlib.Path(out_mistakes_csv_path)
    out_p.parent.mkdir(parents=True, exist_ok=True)
    mistakes_df.to_csv(out_p, index=False)

    # confusion matrix (rows=true, cols=pred)
    labels = sorted(set(df_eval[truth_col]) | set(df_eval[pred_col]))
    conf_mat_df = pd.crosstab(
        df_eval[truth_col],
        df_eval[pred_col],
        dropna=False
    ).reindex(index=labels, columns=labels, fill_value=0)

    # per-class metrics (vectorized)
    tp = np.diag(conf_mat_df.values)
    support_true = conf_mat_df.sum(axis=1).values
    support_pred = conf_mat_df.sum(axis=0).values
    fp = support_pred - tp
    fn = support_true - tp

    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp, dtype=float), where=(tp + fp) != 0)
    recall    = np.divide(tp, tp + fn, out=np.zeros_like(tp, dtype=float), where=(tp + fn) != 0)
    f1        = np.divide(2*precision*recall, precision+recall,
                          out=np.zeros_like(tp, dtype=float), where=(precision+recall) != 0)
    error_rate = np.divide(fn + fp, support_true + fp, out=np.zeros_like(tp, dtype=float),
                           where=(support_true + fp) != 0)

    per_class_df = pd.DataFrame({
        "label": labels,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "support": support_true.astype(int),
        "error_rate": error_rate
    }).sort_values("label")

    accuracy = (tp.sum() / conf_mat_df.values.sum()) if conf_mat_df.values.sum() else 0.0
    macro_f1 = float(per_class_df["f1"].mean()) if not per_class_df.empty else 0.0

    summary = {
        "num_total_with_truth": int(len(df_eval)),
        "num_correct": int(tp.sum()),
        "num_incorrect": int(len(mistakes_df)),
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "labels": labels,
    }

    # charts & tables
    if charts_dir:
        charts_dir = pathlib.Path(charts_dir)
        charts_dir.mkdir(parents=True, exist_ok=True)

        # confusion matrix (counts)
        plt.figure(figsize=(8 + 0.3*len(labels), 6 + 0.3*len(labels)))
        plt.imshow(conf_mat_df.values, interpolation="nearest")
        plt.title("Confusion Matrix (counts)")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.xticks(range(len(labels)), labels, rotation=90)
        plt.yticks(range(len(labels)), labels)
        plt.tight_layout()
        plt.savefig(charts_dir / "confusion_matrix.png", dpi=200)
        plt.close()

        # per-class F1
        plt.figure(figsize=(max(8, 0.6*len(labels)), 5))
        plt.bar(per_class_df["label"], per_class_df["f1"])
        plt.title("Per-class F1")
        plt.xticks(rotation=90)
        plt.ylim(0, 1)
        plt.tight_layout()
        plt.savefig(charts_dir / "per_class_f1.png", dpi=200)
        plt.close()

        # per-class error rate
        plt.figure(figsize=(max(8, 0.6*len(labels)), 5))
        plt.bar(per_class_df["label"], per_class_df["error_rate"])
        plt.title("Per-class Error Rate")
        plt.xticks(rotation=90)
        plt.ylim(0, 1)
        plt.tight_layout()
        plt.savefig(charts_dir / "per_class_error_rate.png", dpi=200)
        plt.close()

        # top confusions (off-diagonal)
        C = conf_mat_df.values
        pairs = [(labels[i], labels[j], int(C[i, j])) for i in range(len(labels)) for j in range(len(labels)) if i != j and C[i, j] > 0]
        pairs.sort(key=lambda x: x[2], reverse=True)
        top_k = pairs[:15]
        if top_k:
            plt.figure(figsize=(10, max(4, 0.4*len(top_k))))
            ylabels = [f"{t} → {p}" for t, p, _ in top_k]
            counts = [c for _, _, c in top_k]
            y = np.arange(len(top_k))
            plt.barh(y, counts)
            plt.yticks(y, ylabels)
            plt.gca().invert_yaxis()
            plt.title("Top Confusions (off-diagonal)")
            plt.tight_layout()
            plt.savefig(charts_dir / "top_confusions.png", dpi=200)
            plt.close()

        # save numeric summaries
        per_class_df.to_csv(charts_dir / "per_class_metrics.csv", index=False)
        conf_mat_df.to_csv(charts_dir / "confusion_matrix.csv")
        if save_summary_json:
            with open(charts_dir / "summary.json", "w") as f:
                json.dump(summary, f, indent=2)

    return mistakes_df, summary, per_class_df, conf_mat_df


In [2]:
# Evaluate one saved prediction run: export its misclassified rows, write the
# charts and metric tables, then show the summary and the first few mistakes.
pred_csv = "runs/california_wildfires_2018/dev/gpt-4o-mini/20251017-192014-modeS/predictions.csv"
mistakes_csv = "analysis/california_wildfires_2018/dev/gpt-4o-mini/mistakes.csv"
charts_out = "analysis/california_wildfires_2018/dev/gpt-4o-mini/charts"

# Column names in the predictions CSV; adjust if your CSV uses different names.
column_names = {
    "truth_col": "class_label",
    "pred_col": "predicted_label",
    "id_col": "tweet_id",
    "text_col": "tweet_text",
}

evaluation = analyze_and_export_mistakes(
    pred_csv_path=pred_csv,
    out_mistakes_csv_path=mistakes_csv,
    charts_dir=charts_out,
    **column_names,
)
mistakes_df, summary, per_cls, conf_df = evaluation

summary, mistakes_df.head()


({'num_total_with_truth': 752,
  'num_correct': 529,
  'num_incorrect': 223,
  'accuracy': np.float64(0.7034574468085106),
  'macro_f1': 0.590928846013498,
  'labels': ['caution_and_advice',
   'displaced_people_and_evacuations',
   'infrastructure_and_utility_damage',
   'injured_or_dead_people',
   'missing_or_found_people',
   'not_humanitarian',
   'other_relevant_information',
   'requests_or_urgent_needs',
   'rescue_volunteering_or_donation_effort',
   'sympathy_and_support']},
                tweet_id                                         tweet_text  \
 3   1062711111869333504  BBC News - California wildfires: Nine dead and...   
 10  1064454370731982848  @lucydragonn We are cooperating with the Count...   
 11  1064237337775767552  Wondering what steps you can take to make Cali...   
 12  1065590550726885379  Due to substrate Southern California (Malibu) ...   
 13  1062489445172166656  Nearly 9,000 firefighters battling the #Califo...   
 
                                cl