In [None]:
import pandas as pd

# === EDIT THESE PATHS ===
pkl_path = "slake_df_captions"    # pickled DataFrame that holds the extra columns
csv2_path = "Results/qwen7b_oneword_filtered.csv"  # base CSV; it is overwritten in place below
# ========================

# Read the pickled DataFrame, then the base CSV.
# NOTE(review): unpickling can execute arbitrary code -- only load a pickle you trust.
df1 = pd.read_pickle(pkl_path)
df2 = pd.read_csv(csv2_path)

# Names of the df1 columns the CSV does not have yet (df1's column order is kept).
unique_cols_df1 = list(filter(lambda name: name not in df2.columns, df1.columns))
df1_unique = df1[unique_cols_df1]

# Put the new columns to the right of the CSV's own columns.
# pd.concat(axis=1) matches rows by index label, so a row that exists in only
# one of the two frames gets NaN in the other frame's columns.
combined = pd.concat([df2, df1_unique], axis=1)

# Replace the base CSV with the widened table (the index is not written).
combined.to_csv(csv2_path, index=False)
print(f"Combined file saved to: {csv2_path}")


In [None]:
# ============================================================
# Install dependencies (run this once in Colab, before the imports below):
#   !pip install transformers sentence-transformers bert-score
# ============================================================

# ============================================================
# Imports
# ============================================================
import pandas as pd
import numpy as np
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
from bert_score import score as bertscore_score

# ============================================================
# Load your CSV
# ============================================================
csv_path = "Results/qwen7b_oneword_filtered.csv"  # <-- EDIT THIS
df = pd.read_csv(csv_path)

# Validate the required columns with an explicit exception: an `assert`
# would be stripped when Python runs with -O and the check would vanish.
for required_col in ("qwen_answer", "gt_answer"):
    if required_col not in df.columns:
        raise ValueError(f"Column '{required_col}' not found in CSV")

# Model predictions and ground-truth answers as plain Python string lists.
# NOTE: astype(str) turns a missing value (NaN) into the literal string
# "nan", so empty cells are scored as the word "nan" further down.
short_answers = df["qwen_answer"].astype(str).tolist()
gt_answers    = df["gt_answer"].astype(str).tolist()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ============================================================
# 1. BERTScore (F1 only) using RoBERTa
# ============================================================
print("Computing BERTScore F1...")

# Predictions are the candidates, ground-truth answers the references.
bertscore_options = {"lang": "en", "model_type": "roberta-large"}
P, R, F1 = bertscore_score(short_answers, gt_answers, **bertscore_options)

# Only the F1 tensor is stored; precision (P) and recall (R) are not used.
df["bertscore_f1"] = F1.cpu().numpy()

# ============================================================
# 2. Sentence-BERT similarity (single score, called sbert_f1)
# ============================================================
print("Computing Sentence-BERT similarity...")

sbert_model_name = "all-MiniLM-L6-v2"
sbert_model = SentenceTransformer(sbert_model_name, device=str(device))

# Embed predictions and ground-truth answers with identical settings.
emb_short = sbert_model.encode(
    short_answers,
    convert_to_tensor=True,
    batch_size=32,
    show_progress_bar=True,
)
emb_gt = sbert_model.encode(
    gt_answers,
    convert_to_tensor=True,
    batch_size=32,
    show_progress_bar=True,
)

# Only the similarity of prediction i with ground truth i is needed.
# A row-wise cosine gives exactly those values in O(n) time and memory;
# building the full n x n similarity matrix just to read its diagonal
# is O(n^2) and can exhaust memory on larger CSVs.
cosine_diag = (
    torch.nn.functional.cosine_similarity(emb_short, emb_gt, dim=1)
    .cpu()
    .numpy()
)

# Treat cosine similarity as the SBERT "F1-like" score
df["sbert_f1"] = cosine_diag

# ============================================================
# 3. NLI with RoBERTa (entailment & contradiction means only)
# ============================================================
print("Loading RoBERTa NLI model...")

# Tokenizer and classifier are loaded from the same checkpoint name.
nli_model_name = "roberta-large-mnli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name)
nli_model = nli_model.to(device)  # move the weights to the selected device

# label indices: 0 = contradiction, 1 = neutral, 2 = entailment
def nli_directional_scores(premises, hypotheses, batch_size=16, max_length=256):
    """Score premise -> hypothesis pairs with the module-level NLI model.

    Pair i is (premises[i], hypotheses[i]). The pairs are tokenized and run
    through ``nli_model`` in batches of ``batch_size`` with gradients
    disabled; each pair is truncated to ``max_length`` tokens.

    Args:
        premises: list of premise strings.
        hypotheses: list of hypothesis strings, paired by position.
        batch_size: number of pairs per forward pass.
        max_length: token limit handed to the tokenizer.

    Returns:
        ``(entail, contra)``: two NumPy arrays with one softmax probability
        per pair, taken from class index 2 and class index 0 respectively.
    """
    entail_scores, contra_scores = [], []

    nli_model.eval()  # inference mode for dropout etc.
    with torch.no_grad():
        for start in range(0, len(premises), batch_size):
            stop = start + batch_size

            encoded = nli_tokenizer(
                premises[start:stop],
                hypotheses[start:stop],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)

            # Class probabilities for every pair in the batch.
            class_probs = nli_model(**encoded).logits.softmax(dim=-1)

            contra_scores.extend(class_probs[:, 0].cpu().numpy())  # contradiction
            entail_scores.extend(class_probs[:, 2].cpu().numpy())  # entailment

    return np.array(entail_scores), np.array(contra_scores)

# Run the NLI model in both directions: prediction as premise, then
# ground truth as premise.
print("Computing NLI scores (short -> gt)...")
ent_s2g, contra_s2g = nli_directional_scores(short_answers, gt_answers)

print("Computing NLI scores (gt -> short)...")
ent_g2s, contra_g2s = nli_directional_scores(gt_answers, short_answers)

# Keep only the mean of the two directions for each probability.
direction_pairs = {
    "nli_entail_mean": (ent_s2g, ent_g2s),
    "nli_contra_mean": (contra_s2g, contra_g2s),
}
for mean_col, (forward, backward) in direction_pairs.items():
    df[mean_col] = (forward + backward) / 2.0

# ============================================================
# Save back to the original CSV (the file is rewritten with the new columns)
# ============================================================
df.to_csv(csv_path, index=False)

print("\nDone! Appended these columns to your CSV:")
for appended_col in ("bertscore_f1", "sbert_f1", "nli_entail_mean", "nli_contra_mean"):
    print(f"  - {appended_col}")

df.head()


In [None]:
#!/usr/bin/env python3
"""
Analyze multiple CSV files with score columns.

- Computes overall mean of each score per file.
- Computes cluster-averages by:
    - modality
    - location
    - content_type
- Produces plots (matplotlib) to compare scores across files and clusters.

Usage:
    Just run: python analyze_scores.py
"""

import os
from typing import List

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------
# YOUR CSV PATHS
# Input files read by main(). A path that does not exist is skipped
# with a warning instead of stopping the run.
# ---------------------------------------------------------------------
CSV_PATHS: List[str] = [
    "Results/qwen7b_oneword_inaccurate.csv",
    "Results/qwen7b_oneword_irrelevant.csv",
    "Results/qwen7b_oneword_missing.csv",
    "Results/qwen7b_oneword_noisy.csv",
    "Results/qwen7b_oneword_original.csv",
    "Results/qwen7b_oneword_severely_inaccurate.csv",
]

# ---------------------------------------------------------------------
# Desired order on the x-axis (severely inaccurate FIRST)
# Entries are file basenames, because load_and_tag_csv() stores only the
# basename in the 'source_file' column that this order is applied to.
# ---------------------------------------------------------------------
FILE_ORDER: List[str] = [
    "qwen7b_oneword_severely_inaccurate.csv",
    "qwen7b_oneword_inaccurate.csv",
    "qwen7b_oneword_irrelevant.csv",
    "qwen7b_oneword_missing.csv",
    "qwen7b_oneword_noisy.csv",
    "qwen7b_oneword_original.csv",
]

# ---------------------------------------------------------------------
# SCORE COLUMNS: your 4 scores
# main() stops with an error message if any of these columns is absent
# from the combined data.
# ---------------------------------------------------------------------
SCORE_COLUMNS: List[str] = [
    "bertscore_f1",
    "sbert_f1",
    "nli_entail_mean",
    "nli_contra_mean",
]

# Grouping columns: each one present in the data gets its own set of
# group means and plots; a missing one is skipped with a warning.
GROUP_COLUMNS: List[str] = ["modality", "location", "content_type"]


def load_and_tag_csv(path: str) -> pd.DataFrame:
    """Read the CSV at ``path`` and record where its rows came from.

    Returns the CSV's contents with a ``source_file`` column whose value on
    every row is the file's basename (directory part removed).
    """
    origin = os.path.basename(path)
    frame = pd.read_csv(path)
    frame["source_file"] = origin
    return frame


def ensure_output_dir() -> str:
    """Make sure the output folder exists and return its (relative) name.

    The folder is ``score_analysis_outputs`` under the current working
    directory; an already existing folder is left as it is.
    """
    target = "score_analysis_outputs"
    os.makedirs(target, exist_ok=True)
    return target


def file_label(name: str, prefix: str = "qwen7b_oneword_") -> str:
    """Shorten a result-file name into an axis label.

    Turns 'qwen7b_oneword_severely_inaccurate.csv' into
    'severely_inaccurate':
      - strip a trailing '.csv' if present
      - strip a leading ``prefix`` if present

    Args:
        name: file name (basename) to shorten.
        prefix: leading text to remove; defaults to the prefix shared by
            the result files of this analysis.

    Returns:
        The shortened name; ``name`` unchanged when neither part matches.
    """
    suffix = ".csv"
    base = name
    if base.endswith(suffix):
        base = base[:-len(suffix)]
    if base.startswith(prefix):
        base = base[len(prefix):]
    return base


# ==============================
# PLOTTING HELPERS (matplotlib)
# ==============================

def plot_overall_means_by_file(overall_by_file: pd.DataFrame,
                               score_cols: List[str],
                               out_dir: str,
                               file_order: List[str]):
    """
    For each metric, make a bar chart of mean score per file
    using the specified file_order and short labels on x-axis.

    Args:
        overall_by_file: one row per file with a 'source_file' column and
            one column per metric.
        score_cols: metric columns to plot; one PNG is written per metric.
        out_dir: existing directory that receives the PNG files.
        file_order: file names in the left-to-right order of the bars.
    """
    # Reorder rows to match file_order; files without any data drop out.
    overall_by_file = overall_by_file.set_index("source_file").reindex(file_order).dropna(how="all")
    overall_by_file = overall_by_file.reset_index()

    labels = [file_label(f) for f in overall_by_file["source_file"]]
    # Bars are placed at explicit numeric positions so the tick locations
    # can be fixed with set_xticks() before the labels are assigned;
    # set_xticklabels() on its own warns and can mislabel the bars.
    positions = np.arange(len(labels))

    for metric in score_cols:
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.bar(positions, overall_by_file[metric])
        ax.set_title(f"Mean {metric} by file")
        ax.set_ylabel("Mean score")
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        fig.tight_layout()
        out_path = os.path.join(out_dir, f"overall_mean_{metric}_by_file.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)  # release the figure once it is on disk
        print("Saved plot:", out_path)


def plot_score_distributions_by_file(all_df: pd.DataFrame,
                                     score_cols: List[str],
                                     out_dir: str,
                                     file_order: List[str]):
    """
    Draw one boxplot figure per metric, with one box per file.

    Boxes follow file_order (files absent from all_df are left out) and
    carry the short labels produced by file_label(). Missing scores are
    dropped before plotting. Each figure is saved as a PNG in out_dir.
    """
    # Restrict the requested order to files that actually have rows.
    seen_files = all_df["source_file"].unique()
    files_present = [name for name in file_order if name in seen_files]
    labels_present = [file_label(name) for name in files_present]
    # Boxplot positions start at 1.
    tick_positions = np.arange(1, len(labels_present) + 1)

    for metric in score_cols:
        per_file_scores = []
        for name in files_present:
            rows_of_file = all_df["source_file"] == name
            per_file_scores.append(all_df.loc[rows_of_file, metric].dropna().values)

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.boxplot(per_file_scores)
        ax.set_title(f"Distribution of {metric} by file")
        ax.set_ylabel(metric)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(labels_present, rotation=45, ha="right")
        fig.tight_layout()

        out_path = os.path.join(out_dir, f"boxplot_{metric}_by_file.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


def plot_group_overall_bars(col: str,
                            by_group_overall: pd.DataFrame,
                            score_cols: List[str],
                            out_dir: str):
    """
    Draw one bar chart per metric showing the mean score of every value of
    the grouping column ``col`` (all files pooled).

    ``by_group_overall`` needs the column ``col`` plus one column per
    metric. Each chart is saved as a PNG in ``out_dir``.
    """
    tick_labels = by_group_overall[col].astype(str).tolist()
    positions = np.arange(len(tick_labels))

    for metric in score_cols:
        heights = by_group_overall[metric].values

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.bar(positions, heights)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")
        ax.set_title(f"Overall mean {metric} by {col}")
        ax.set_ylabel("Mean score")
        fig.tight_layout()

        out_path = os.path.join(out_dir, f"overall_mean_{metric}_by_{col}.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


def plot_group_heatmaps(col: str,
                        by_file_and_group: pd.DataFrame,
                        score_cols: List[str],
                        out_dir: str,
                        file_order: List[str]):
    """
    Draw one heatmap per metric of the mean score per (group value, file).

        rows (y-axis):    values of the grouping column ``col``
        columns (x-axis): files, ordered by file_order, with short labels

    ``by_file_and_group`` needs the columns 'source_file', ``col`` and one
    column per metric. Each heatmap is saved as a PNG in ``out_dir``.
    """
    for metric in score_cols:
        # Wide table: one row per group value, one column per file.
        grid = by_file_and_group.pivot(index=col, columns="source_file", values=metric)

        # Keep the files that exist in the table, in file_order order.
        grid = grid[[name for name in file_order if name in grid.columns]]

        fig, ax = plt.subplots(figsize=(8, 6))
        image = ax.imshow(grid.values, aspect="auto", origin="upper")

        # One tick per file on x and per group value on y.
        ax.set_xticks(np.arange(len(grid.columns)))
        ax.set_xticklabels(
            [file_label(name) for name in grid.columns], rotation=45, ha="right"
        )
        ax.set_yticks(np.arange(len(grid.index)))
        ax.set_yticklabels(grid.index.astype(str))

        ax.set_xlabel("file")
        ax.set_ylabel(col)
        ax.set_title(f"{metric} by {col} and file")

        # Colour scale legend.
        fig.colorbar(image, ax=ax).set_label("Mean score")

        fig.tight_layout()
        out_path = os.path.join(out_dir, f"heatmap_{metric}_by_{col}_and_file.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


# ==============================
# MAIN ANALYSIS
# ==============================

def main():
    """Run the whole analysis: load, aggregate, save tables, draw plots.

    Steps:
      1. Load every existing file in CSV_PATHS (missing files are skipped
         with a warning) and concatenate them, tagged by 'source_file'.
      2. Stop with an error message if a SCORE_COLUMNS entry is absent.
      3. Compute mean scores per file, and per (file, group) and per group
         for each GROUP_COLUMNS entry present in the data.
      4. Write the tables as CSV and the figures as PNG into the directory
         returned by ensure_output_dir().

    Problems are reported with print() followed by an early return.
    """
    if not CSV_PATHS:
        print("ERROR: Please provide at least one CSV path in CSV_PATHS.")
        return

    # Load and concatenate all CSVs
    dfs = []
    for path in CSV_PATHS:
        if not os.path.isfile(path):
            print(f"WARNING: File not found, skipping: {path}")
            continue
        dfs.append(load_and_tag_csv(path))

    if not dfs:
        print("ERROR: No valid CSV files loaded. Check CSV_PATHS.")
        return

    all_df = pd.concat(dfs, ignore_index=True)

    # Ensure score columns exist
    missing_scores = [c for c in SCORE_COLUMNS if c not in all_df.columns]
    if missing_scores:
        print(f"ERROR: These SCORE_COLUMNS are missing in the data: {missing_scores}")
        return

    score_cols = SCORE_COLUMNS

    # A value that is not among the Categorical's categories becomes NaN,
    # which would silently remove that file from every table and plot.
    # Loaded files missing from FILE_ORDER are therefore appended (in load
    # order) after the listed ones, and reported.
    loaded_names = all_df["source_file"].unique().tolist()
    unlisted = [name for name in loaded_names if name not in FILE_ORDER]
    if unlisted:
        print(f"WARNING: Files not listed in FILE_ORDER are placed last: {unlisted}")
    file_order = list(FILE_ORDER) + unlisted

    # Make source_file a categorical with desired order
    all_df["source_file"] = pd.Categorical(
        all_df["source_file"],
        categories=file_order,
        ordered=True,
    )

    # -----------------------------------------------------------------
    # 1. Overall average of scores per CSV file
    # -----------------------------------------------------------------
    overall_by_file = (
        all_df.groupby("source_file", observed=True)[score_cols]
        .mean()
        .reset_index()
    )

    # -----------------------------------------------------------------
    # 2. Cluster-averages by modality, location, content_type
    # -----------------------------------------------------------------
    summaries = {}

    for col in GROUP_COLUMNS:
        if col not in all_df.columns:
            print(f"WARNING: Column '{col}' not found in data. Skipping grouping by this column.")
            continue

        # Per file + group (e.g., file & modality)
        by_file_and_group = (
            all_df.groupby(["source_file", col], observed=True)[score_cols]
            .mean()
            .reset_index()
        )

        # Across all files combined, grouped only by that column
        by_group_overall = (
            all_df.groupby(col, observed=True)[score_cols]
            .mean()
            .reset_index()
        )

        summaries[col] = {
            "by_file_and_group": by_file_and_group,
            "by_group_overall": by_group_overall,
        }

    # -----------------------------------------------------------------
    # Save numeric outputs
    # -----------------------------------------------------------------
    out_dir = ensure_output_dir()

    overall_path = os.path.join(out_dir, "overall_means_by_file.csv")
    overall_by_file.to_csv(overall_path, index=False)
    print("\nSaved overall means per file to:", overall_path)

    for col, data in summaries.items():
        by_file_path = os.path.join(out_dir, f"means_by_file_and_{col}.csv")
        by_group_path = os.path.join(out_dir, f"means_by_{col}_overall.csv")

        data["by_file_and_group"].to_csv(by_file_path, index=False)
        data["by_group_overall"].to_csv(by_group_path, index=False)

        print(f"Saved means by file and {col} to:", by_file_path)
        print(f"Saved overall means by {col} to:", by_group_path)

    # -----------------------------------------------------------------
    # Plotting (matplotlib only)
    # -----------------------------------------------------------------
    # 1) Overall comparison of files
    plot_overall_means_by_file(overall_by_file, score_cols, out_dir, file_order)
    plot_score_distributions_by_file(all_df, score_cols, out_dir, file_order)

    # 2) Group-based plots
    for col, data in summaries.items():
        print(f"\nCreating plots for grouping by '{col}'...")
        plot_group_overall_bars(
            col=col,
            by_group_overall=data["by_group_overall"],
            score_cols=score_cols,
            out_dir=out_dir,
        )
        plot_group_heatmaps(
            col=col,
            by_file_and_group=data["by_file_and_group"],
            score_cols=score_cols,
            out_dir=out_dir,
            file_order=file_order,
        )

    # Quick sanity check in stdout
    print("\n=== Overall means per file (head) ===")
    print(overall_by_file.head())

    for col, data in summaries.items():
        print(f"\n=== Means by file and {col} (head) ===")
        print(data["by_file_and_group"].head())


if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
"""
Analyze multiple CSV files with score columns.

- Computes overall mean of each score per file.
- Computes cluster-averages by:
    - modality
    - location
    - content_type
- Produces plots (matplotlib) to compare scores across files and clusters.

Usage:
    Just run: python analyze_scores.py
"""

import os
from typing import List, Dict

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------
# PLOT STYLE (bigger fonts + less gaudy defaults)
# ---------------------------------------------------------------------
# Multiplier that main() hands to set_plot_style(); 2.0 doubles the base
# font size.
FONT_SCALE: float = 2.0  # "double the font size"


def set_plot_style(font_scale: float = 2.0) -> None:
    """Apply the shared plot style to Matplotlib's global rcParams.

    All font sizes are derived from Matplotlib's built-in default font
    size multiplied by ``font_scale``. Scaling from the built-in default,
    not from the current rcParams value, keeps the function idempotent:
    calling it again (for example when main() is re-run in the same
    notebook kernel) no longer multiplies the already enlarged size.

    Args:
        font_scale: multiplier for the default font size.
    """
    import matplotlib  # local import: the file only imports matplotlib.pyplot

    base = float(matplotlib.rcParamsDefault.get("font.size", 10.0))
    fs = base * font_scale

    plt.rcParams.update({
        "font.size": fs,
        "axes.titlesize": fs * 1.2,
        "axes.labelsize": fs,
        "xtick.labelsize": fs * 0.9,
        "ytick.labelsize": fs * 0.9,
        "legend.fontsize": fs * 0.9,
        "figure.titlesize": fs * 1.2,

        # Heatmap defaults (muted + readable)
        "image.cmap": "cividis",
        "image.interpolation": "nearest",
    })


# ---------------------------------------------------------------------
# YOUR CSV PATHS
# Input files read by main(). A path that does not exist is skipped
# with a warning instead of stopping the run.
# ---------------------------------------------------------------------
CSV_PATHS: List[str] = [
    "Results/qwen7b_oneword_inaccurate.csv",
    "Results/qwen7b_oneword_irrelevant.csv",
    "Results/qwen7b_oneword_missing.csv",
    "Results/qwen7b_oneword_noisy.csv",
    "Results/qwen7b_oneword_original.csv",
    "Results/qwen7b_oneword_severely_inaccurate.csv",
]

# ---------------------------------------------------------------------
# Desired order on the x-axis (severely inaccurate FIRST)
# Entries are file basenames, because load_and_tag_csv() stores only the
# basename in the 'source_file' column that this order is applied to.
# ---------------------------------------------------------------------
FILE_ORDER: List[str] = [
    "qwen7b_oneword_severely_inaccurate.csv",
    "qwen7b_oneword_inaccurate.csv",
    "qwen7b_oneword_irrelevant.csv",
    "qwen7b_oneword_missing.csv",
    "qwen7b_oneword_noisy.csv",
    "qwen7b_oneword_original.csv",
]

# ---------------------------------------------------------------------
# SCORE COLUMNS: your 4 scores
# main() stops with an error message if any of these columns is absent
# from the combined data.
# ---------------------------------------------------------------------
SCORE_COLUMNS: List[str] = [
    "bertscore_f1",
    "sbert_f1",
    "nli_entail_mean",
    "nli_contra_mean",
]

# Grouping columns: each one present in the data gets its own set of
# group means and plots; a missing one is skipped with a warning.
GROUP_COLUMNS: List[str] = ["modality", "location", "content_type"]

# ---------------------------------------------------------------------
# Pretty metric names for plots
# Looked up by metric_label(); a metric without an entry here is shown
# under its raw column name.
# ---------------------------------------------------------------------
METRIC_LABELS: Dict[str, str] = {
    "nli_entail_mean": "NLI Entailment",
    "nli_contra_mean": "NLI contradiction",
}


def metric_label(metric: str) -> str:
    """Return the display name of a metric column for plot text.

    Uses the entry in METRIC_LABELS when there is one; otherwise the raw
    column name is returned unchanged.
    """
    if metric in METRIC_LABELS:
        return METRIC_LABELS[metric]
    return metric


def load_and_tag_csv(path: str) -> pd.DataFrame:
    """Read one CSV and tag every row with the file it came from.

    The result has the CSV's columns plus ``source_file``, which holds the
    basename of ``path`` (no directory part).
    """
    return pd.read_csv(path).assign(source_file=os.path.basename(path))


def ensure_output_dir() -> str:
    """Create the output folder when needed and return its relative name.

    The folder is ``score_analysis_outputs`` in the current working
    directory; nothing happens if it already exists.
    """
    folder = "score_analysis_outputs"
    os.makedirs(folder, exist_ok=True)
    return folder


def file_label(name: str) -> str:
    """
    Shorten a result-file name into an axis label.

    'qwen7b_oneword_severely_inaccurate.csv' -> 'severely_inaccurate':
    a trailing '.csv' and a leading 'qwen7b_oneword_' are removed when
    present. The label 'severely_inaccurate' is returned with a line break
    in place of the underscore so it takes two lines on an axis.
    """
    suffix, prefix = ".csv", "qwen7b_oneword_"

    stem = name[:-len(suffix)] if name.endswith(suffix) else name
    short = stem[len(prefix):] if stem.startswith(prefix) else stem

    # Two-line variant for the longest label.
    return "severely\ninaccurate" if short == "severely_inaccurate" else short


# ==============================
# PLOTTING HELPERS (matplotlib)
# ==============================

def plot_overall_means_by_file(overall_by_file: pd.DataFrame,
                               score_cols: List[str],
                               out_dir: str,
                               file_order: List[str]) -> None:
    """
    Draw one bar chart per metric with the mean score of every dataset.

    Bars follow file_order (files without any data are left out) and use
    the short labels from file_label(). ``overall_by_file`` needs a
    'source_file' column plus one column per metric. Each chart is saved
    as a PNG in ``out_dir``.
    """
    ordered = (
        overall_by_file
        .set_index("source_file")
        .reindex(file_order)
        .dropna(how="all")
        .reset_index()
    )
    tick_labels = [file_label(name) for name in ordered["source_file"]]
    positions = np.arange(len(tick_labels))

    for metric in score_cols:
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.bar(positions, ordered[metric], color="0.35", alpha=0.9)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")
        ax.set_ylabel("Mean score")
        ax.set_title(f"Mean {metric_label(metric)} across data")
        fig.tight_layout()

        out_path = os.path.join(out_dir, f"overall_mean_{metric}_by_file.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


def plot_score_distributions_by_file(all_df: pd.DataFrame,
                                     score_cols: List[str],
                                     out_dir: str,
                                     file_order: List[str]) -> None:
    """
    Draw one boxplot figure per metric, with one box per dataset.

    Boxes follow file_order (files absent from all_df are left out) and
    carry the short labels from file_label(). Missing scores are dropped
    before plotting. Each figure is saved as a PNG in ``out_dir``.
    """
    # Restrict the requested order to files that actually have rows.
    seen_files = all_df["source_file"].unique()
    files_present = [name for name in file_order if name in seen_files]
    labels_present = [file_label(name) for name in files_present]
    # Boxplot positions start at 1.
    tick_positions = np.arange(1, len(labels_present) + 1)

    for metric in score_cols:
        per_file_scores = []
        for name in files_present:
            rows_of_file = all_df["source_file"] == name
            per_file_scores.append(all_df.loc[rows_of_file, metric].dropna().values)

        pretty = metric_label(metric)
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.boxplot(per_file_scores)
        ax.set_title(f"Distribution of {pretty} across data")
        ax.set_ylabel(pretty)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(labels_present, rotation=45, ha="right")
        fig.tight_layout()

        out_path = os.path.join(out_dir, f"boxplot_{metric}_by_file.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


def plot_group_overall_bars(col: str,
                            by_group_overall: pd.DataFrame,
                            score_cols: List[str],
                            out_dir: str) -> None:
    """
    Draw one bar chart per metric showing the mean score of every value of
    the grouping column ``col`` (all files pooled).

    ``by_group_overall`` needs the column ``col`` plus one column per
    metric. Each chart is saved as a PNG in ``out_dir``.
    """
    tick_labels = by_group_overall[col].astype(str).tolist()
    positions = np.arange(len(tick_labels))

    for metric in score_cols:
        heights = by_group_overall[metric].values

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.bar(positions, heights, color="0.35", alpha=0.9)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")
        ax.set_title(f"Overall mean {metric_label(metric)} by {col}")
        ax.set_ylabel("Mean score")
        fig.tight_layout()

        out_path = os.path.join(out_dir, f"overall_mean_{metric}_by_{col}.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


def plot_group_heatmaps(col: str,
                        by_file_and_group: pd.DataFrame,
                        score_cols: List[str],
                        out_dir: str,
                        file_order: List[str]) -> None:
    """
    Draw one annotated heatmap per metric of the mean score per
    (group value, dataset).

        rows (y-axis):    values of the grouping column ``col``
        columns (x-axis): datasets (source_file) ordered by file_order,
                          with short labels

    Every non-NaN cell is labelled with its value to three decimals, in
    black or white depending on the brightness of the cell colour. The
    'cividis' colormap is used. Each heatmap is saved as a PNG in
    ``out_dir``.
    """
    for metric in score_cols:
        # Wide table: one row per group value, one column per file.
        grid = by_file_and_group.pivot(index=col, columns="source_file", values=metric)
        grid = grid[[name for name in file_order if name in grid.columns]]

        vals = grid.values.astype(float)

        fig, ax = plt.subplots(figsize=(10, 8))
        image = ax.imshow(vals, aspect="auto", origin="upper", cmap="cividis")

        ax.set_xticks(np.arange(len(grid.columns)))
        ax.set_xticklabels(
            [file_label(name) for name in grid.columns], rotation=45, ha="right"
        )
        ax.set_yticks(np.arange(len(grid.index)))
        ax.set_yticklabels(grid.index.astype(str))

        ax.set_xlabel("captions")
        ax.set_ylabel(col)
        ax.set_title(f"{metric_label(metric)} by {col} across data")

        # Write each value into its cell; the text colour is chosen from
        # the weighted brightness of the cell's RGB colour.
        norm = image.norm
        cmap = image.get_cmap()
        for (row, column), value in np.ndenumerate(vals):
            if np.isnan(value):
                continue
            red, green, blue = cmap(norm(value))[:3]
            luminance = 0.299 * red + 0.587 * green + 0.114 * blue
            txt_color = "black" if luminance > 0.6 else "white"
            ax.text(column, row, f"{value:.3f}", ha="center", va="center", color=txt_color)

        fig.colorbar(image, ax=ax).set_label("Mean score")

        fig.tight_layout()
        out_path = os.path.join(out_dir, f"heatmap_{metric}_by_{col}_and_file.png")
        fig.savefig(out_path, bbox_inches="tight")
        plt.close(fig)
        print("Saved plot:", out_path)


# ==============================
# MAIN ANALYSIS
# ==============================

def main() -> None:
    """Run the whole analysis: load, aggregate, save tables, draw plots.

    Steps:
      1. Load every existing file in CSV_PATHS (missing files are skipped
         with a warning) and concatenate them, tagged by 'source_file'.
      2. Stop with an error message if a SCORE_COLUMNS entry is absent.
      3. Compute mean scores per file, and per (file, group) and per group
         for each GROUP_COLUMNS entry present in the data.
      4. Write the tables as CSV into the directory returned by
         ensure_output_dir(), apply the plot style, and save the figures
         as PNG into the same directory.

    Problems are reported with print() followed by an early return.
    """
    if not CSV_PATHS:
        print("ERROR: Please provide at least one CSV path in CSV_PATHS.")
        return

    # Load and concatenate all CSVs
    dfs = []
    for path in CSV_PATHS:
        if not os.path.isfile(path):
            print(f"WARNING: File not found, skipping: {path}")
            continue
        dfs.append(load_and_tag_csv(path))

    if not dfs:
        print("ERROR: No valid CSV files loaded. Check CSV_PATHS.")
        return

    all_df = pd.concat(dfs, ignore_index=True)

    # Ensure score columns exist
    missing_scores = [c for c in SCORE_COLUMNS if c not in all_df.columns]
    if missing_scores:
        print(f"ERROR: These SCORE_COLUMNS are missing in the data: {missing_scores}")
        return

    score_cols = SCORE_COLUMNS

    # A value that is not among the Categorical's categories becomes NaN,
    # which would silently remove that file from every table and plot.
    # Loaded files missing from FILE_ORDER are therefore appended (in load
    # order) after the listed ones, and reported.
    loaded_names = all_df["source_file"].unique().tolist()
    unlisted = [name for name in loaded_names if name not in FILE_ORDER]
    if unlisted:
        print(f"WARNING: Files not listed in FILE_ORDER are placed last: {unlisted}")
    file_order = list(FILE_ORDER) + unlisted

    # Make source_file a categorical with desired order
    all_df["source_file"] = pd.Categorical(
        all_df["source_file"],
        categories=file_order,
        ordered=True,
    )

    # -----------------------------------------------------------------
    # 1. Overall average of scores per CSV file
    # -----------------------------------------------------------------
    overall_by_file = (
        all_df.groupby("source_file", observed=True)[score_cols]
        .mean()
        .reset_index()
    )

    # -----------------------------------------------------------------
    # 2. Cluster-averages by modality, location, content_type
    # -----------------------------------------------------------------
    summaries = {}

    for col in GROUP_COLUMNS:
        if col not in all_df.columns:
            print(f"WARNING: Column '{col}' not found in data. Skipping grouping by this column.")
            continue

        # Per file + group (e.g., file & modality)
        by_file_and_group = (
            all_df.groupby(["source_file", col], observed=True)[score_cols]
            .mean()
            .reset_index()
        )

        # Across all files combined, grouped only by that column
        by_group_overall = (
            all_df.groupby(col, observed=True)[score_cols]
            .mean()
            .reset_index()
        )

        summaries[col] = {
            "by_file_and_group": by_file_and_group,
            "by_group_overall": by_group_overall,
        }

    # -----------------------------------------------------------------
    # Save numeric outputs
    # -----------------------------------------------------------------
    out_dir = ensure_output_dir()

    overall_path = os.path.join(out_dir, "overall_means_by_file.csv")
    overall_by_file.to_csv(overall_path, index=False)
    print("\nSaved overall means per file to:", overall_path)

    for col, data in summaries.items():
        by_file_path = os.path.join(out_dir, f"means_by_file_and_{col}.csv")
        by_group_path = os.path.join(out_dir, f"means_by_{col}_overall.csv")

        data["by_file_and_group"].to_csv(by_file_path, index=False)
        data["by_group_overall"].to_csv(by_group_path, index=False)

        print(f"Saved means by file and {col} to:", by_file_path)
        print(f"Saved overall means by {col} to:", by_group_path)

    # -----------------------------------------------------------------
    # Plotting (matplotlib only)
    # -----------------------------------------------------------------
    set_plot_style(FONT_SCALE)

    # 1) Overall comparison of datasets
    plot_overall_means_by_file(overall_by_file, score_cols, out_dir, file_order)
    plot_score_distributions_by_file(all_df, score_cols, out_dir, file_order)

    # 2) Group-based plots
    for col, data in summaries.items():
        print(f"\nCreating plots for grouping by '{col}'...")
        plot_group_overall_bars(
            col=col,
            by_group_overall=data["by_group_overall"],
            score_cols=score_cols,
            out_dir=out_dir,
        )
        plot_group_heatmaps(
            col=col,
            by_file_and_group=data["by_file_and_group"],
            score_cols=score_cols,
            out_dir=out_dir,
            file_order=file_order,
        )

    # Quick sanity check in stdout
    print("\n=== Overall means per file (head) ===")
    print(overall_by_file.head())

    for col, data in summaries.items():
        print(f"\n=== Means by file and {col} (head) ===")
        print(data["by_file_and_group"].head())


if __name__ == "__main__":
    main()
