In [4]:
import json
import numpy as np


# ---------------------------------------------------------
# BASIC LATEX HELPERS
# ---------------------------------------------------------

def cell(value, std=None):
    """Return a LaTeX table cell, appending ``\\tpms{std}`` when *std* is given.

    With ``std=None`` the *value* is passed through unchanged (not stringified).
    """
    if std is None:
        return value
    return "{} \\tpms{{{}}}".format(value, std)


def clean_num(x):
    """Format *x* with two decimals and drop a leading zero (0.55 -> ".55").

    Text that does not begin with "0" (e.g. "1.00" or negatives like "-0.50")
    is returned exactly as formatted.
    """
    text = "{:.2f}".format(x)
    if text[:1] == "0":
        return text[1:]
    return text


# ---------------------------------------------------------
# METRIC EXTRACTION FROM JSON
# ---------------------------------------------------------

def calculate_metrics_stats(data: dict) -> dict:
    """Average each metric's per-attribute mean and std, skipping ``"neutral"``.

    ``data`` maps attribute names to records; a record may contain any of the
    metric keys below, each holding a dict with optional ``"mean"``/``"std"``
    (missing entries count as 0.0).

    Returns a dict with keys ``overall_mean_<m>`` and ``overall_mean_std_<m>``
    for ``m`` in iou, serp, prag, diversity (in that order). Each value is
    rounded to 2 decimals. A metric absent from every attribute yields 0.0.
    """
    # JSON metric name -> suffix used in the returned keys (order matters for
    # the key order of the result).
    metric_keys = {
        "IOU Divergence": "iou",
        "SERP MS Divergence": "serp",
        "Pragmatic Divergence": "prag",
        "diversity": "diversity",
    }
    means = {suffix: [] for suffix in metric_keys.values()}
    stds = {suffix: [] for suffix in metric_keys.values()}

    for attribute_name, attribute_data in data.items():
        if attribute_name == "neutral":
            continue
        for metric_name, suffix in metric_keys.items():
            if metric_name in attribute_data:
                metric = attribute_data[metric_name]
                means[suffix].append(metric.get("mean", 0.0))
                stds[suffix].append(metric.get("std", 0.0))

    def _avg(values):
        # Unweighted mean over attributes; 0.0 when the metric never appeared.
        return sum(values) / len(values) if values else 0.0

    result = {}
    for suffix in metric_keys.values():
        result[f"overall_mean_{suffix}"] = round(_avg(means[suffix]), 2)
        result[f"overall_mean_std_{suffix}"] = round(_avg(stds[suffix]), 2)
    return result


# ---------------------------------------------------------
# BUILD 12-CELL ROW PER MODEL
# ---------------------------------------------------------

def build_rows(model_groups):
    """Build one table row per model from its per-dataset metric JSON files.

    model_groups: list of (model_name, [college.json, music.json, movie.json])
    Returns a list of dicts ``{"model": name, "cells": [...]}``. Every file
    contributes four cells (IOU, SERP, Pragmatic, diversity), each a dict
    ``{"mean", "std", "tag"}`` with ``tag`` initialised to None.
    """
    metric_suffixes = ("iou", "serp", "prag", "diversity")

    def _load_stats(path):
        # Parse the JSON and reduce it to the overall_* summary dict.
        with open(path, "r", encoding="utf-8") as handle:
            return calculate_metrics_stats(json.load(handle))

    table = []
    for name, paths in model_groups:
        row_cells = []
        for path in paths:
            stats = _load_stats(path)
            row_cells.extend(
                {
                    "mean": stats[f"overall_mean_{suffix}"],
                    "std": stats[f"overall_mean_std_{suffix}"],
                    "tag": None,
                }
                for suffix in metric_suffixes
            )
        table.append({"model": name, "cells": row_cells})
    return table


# ---------------------------------------------------------
# APPLY BEST + SECOND PER COLUMN ACROSS MODELS
# ---------------------------------------------------------

def mark_best_second(rows):
    r"""Tag the best and second-best mean of every column, modifying *rows* in place.

    rows: list from build_rows(); every row should have the same number of cells.
    Columns repeat in blocks of four (IOU, SERP, Pragmatic, diversity): the
    first three are lower-is-better, the fourth is higher-is-better.
      - every cell equal to the best value gets tag "best" (printed as \best{})
      - if exactly one cell is best, every cell equal to the next value gets
        tag "second" (printed as \second{})
      - if several cells tie for best, no "second" is assigned in that column
    Returns the same list; an empty list is returned unchanged.
    """
    if not rows:
        return rows

    num_models = len(rows)
    num_cols = len(rows[0]["cells"])

    # Numeric matrix of means: one row per model, one column per metric cell.
    means = np.zeros((num_models, num_cols))
    for i, row in enumerate(rows):
        for j, cell_info in enumerate(row["cells"]):
            means[i, j] = cell_info["mean"]

    for col in range(num_cols):
        col_vals = means[:, col]

        # Position within the 4-column dataset block: 0,1,2 = min ; 3 = max.
        minimize = col % 4 != 3
        pick = np.min if minimize else np.max

        best_val = pick(col_vals)
        best_idxs = np.where(col_vals == best_val)[0]
        for r in best_idxs:
            rows[r]["cells"][col]["tag"] = "best"

        # Multiple bests -> no second place for this column.
        if len(best_idxs) > 1:
            continue

        others = col_vals[col_vals != best_val]
        if others.size == 0:
            continue

        second_val = pick(others)
        for r in np.where(col_vals == second_val)[0]:
            if rows[r]["cells"][col]["tag"] is None:
                rows[r]["cells"][col]["tag"] = "second"

    return rows


# ---------------------------------------------------------
# PRINT EXACT LATEX FORMAT (3 LINES PER MODEL)
# ---------------------------------------------------------

def print_rows(rows):
    r"""Print each model's row in LaTeX, one source line per dataset block.

    For every row this prints:
      - ``\textbf{<model>}``;
      - one line per group of 4 cells (one dataset: IOU, SERP, Pragmatic,
        diversity), formatted ``" & mean \tpms{std} & ..."``;
      - a blank line.
    Means tagged "best"/"second" are wrapped in ``\best{}`` / ``\second{}``.
    The last block line closes the table row with ``\\``. The number of
    blocks follows ``len(cells)`` instead of being fixed at three.
    """
    cells_per_block = 4  # IOU, SERP, Pragmatic, diversity per dataset
    tag_macros = {"best": "\\best", "second": "\\second"}

    for row in rows:
        print(f"\\textbf{{{row['model']}}}")

        cells = row["cells"]
        blocks = [
            cells[i:i + cells_per_block]
            for i in range(0, len(cells), cells_per_block)
        ]

        for i, block in enumerate(blocks):
            parts = []
            for c in block:
                mean_str = clean_num(c["mean"])
                std_str = clean_num(c["std"])

                macro = tag_macros.get(c["tag"])
                if macro is not None:
                    mean_str = f"{macro}{{{mean_str}}}"

                parts.append(cell(mean_str, std_str))

            line = " & " + " & ".join(parts)
            # Only the final dataset block ends the LaTeX table row.
            if i == len(blocks) - 1:
                line += " \\\\"
            print(line)

        print("")   # blank line after each model


# ---------------------------------------------------------
# MODEL GROUPS (example: 4B row comparison)
# ---------------------------------------------------------

# (display name, metric-file prefix) for every model in the 4B-class table.
MODELS_4B = [
    ("Gemma 3", "gemma3_4b"),
    ("Llama 3.2", "llama3.2_3b"),
    ("Qwen3", "qwen3_4b"),
    ("MiniCPM3", "minicpm3_4b"),
    ("SmolLM3", "smollm3_3b"),
]

# Dataset order defines the column-block order in the printed table.
DATASETS = ("college", "music", "movie")

# Persona suffix used in the metric file name for each dataset.
PERSONAS = {
    "college": "high_school_student",
    "music": "music_fan",
    "movie": "movie_fan",
}


def _metric_files(prefix, personas=None, metric_dir="data/metric_results"):
    """Return the metric JSON paths for one model, one per entry in DATASETS.

    Paths look like ``<metric_dir>/<prefix>_<dataset>_<persona>.json``; a
    dataset missing from *personas* (or ``personas=None``) uses an empty
    persona, e.g. ``gemma3_4b_college_.json``.
    """
    personas = personas or {}
    return [
        f"{metric_dir}/{prefix}_{dataset}_{personas.get(dataset, '')}.json"
        for dataset in DATASETS
    ]


# Files without a persona suffix, e.g. "gemma3_4b_college_.json".
model_groups_no_persona = [
    (name, _metric_files(prefix)) for name, prefix in MODELS_4B
]

# Files with the per-dataset persona, e.g. "gemma3_4b_college_high_school_student.json".
model_groups_persona = [
    (name, _metric_files(prefix, PERSONAS)) for name, prefix in MODELS_4B
]

# Configuration that is tabulated below; switch to model_groups_no_persona
# to produce the table for the files without a persona suffix.
model_groups = model_groups_persona

# ---------------------------------------------------------
# RUN PIPELINE
# ---------------------------------------------------------

rows = build_rows(model_groups)
rows = mark_best_second(rows)
print_rows(rows)


\textbf{Gemma 3}
 & .63 \tpms{.15} & .46 \tpms{.16} & .50 \tpms{.15} & .88 \tpms{.09}
 & .39 \tpms{.14} & .34 \tpms{.12} & .37 \tpms{.12} & \second{.98} \tpms{.03}
 & .69 \tpms{.14} & .57 \tpms{.19} & .62 \tpms{.16} & .88 \tpms{.22} \\

\textbf{Llama 3.2}
 & .49 \tpms{.15} & .34 \tpms{.14} & .40 \tpms{.13} & .87 \tpms{.11}
 & .57 \tpms{.13} & .44 \tpms{.15} & .47 \tpms{.14} & .79 \tpms{.15}
 & .61 \tpms{.09} & .43 \tpms{.14} & .60 \tpms{.10} & \second{.91} \tpms{.09} \\

\textbf{Qwen3}
 & .52 \tpms{.09} & .40 \tpms{.06} & .49 \tpms{.07} & .80 \tpms{.11}
 & .50 \tpms{.11} & .48 \tpms{.11} & .50 \tpms{.10} & .92 \tpms{.09}
 & .51 \tpms{.11} & .48 \tpms{.15} & .66 \tpms{.09} & .75 \tpms{.14} \\

\textbf{MiniCPM3}
 & \second{.05} \tpms{.06} & \second{.03} \tpms{.04} & \second{.03} \tpms{.04} & \second{.98} \tpms{.01}
 & \best{.05} \tpms{.05} & \best{.05} \tpms{.04} & \best{.05} \tpms{.04} & \best{.99} \tpms{.01}
 & \second{.25} \tpms{.09} & \second{.23} \tpms{.06} & \second{.23} \tpms{.06}

In [25]:
import json
import os
import numpy as np


def fmt_value(value):
    """Format *value* to two decimals, dropping a leading "0" (0.55 -> ".55").

    Results that do not begin with "0" (e.g. "1.00", negatives) are unchanged.
    """
    formatted = format(value, ".2f")
    return formatted[1:] if formatted[0] == "0" else formatted


def cell(mean, std):
    """Return the LaTeX cell text ``<mean> \\tpms{<std>}``."""
    template = "{} \\tpms{{{}}}"
    return template.format(mean, std)


def calculate_metrics_stats(data: dict) -> dict:
    """Average per-attribute metric means/stds (skipping "neutral") as display strings.

    ``data`` maps attribute names to records; a record may contain any of the
    JSON metric keys below, each a dict with optional ``"mean"``/``"std"``
    (missing entries count as 0.0).

    Returns ``{"iou": (mean, std), "serp": ..., "prag": ..., "div": ...}``.
    Each value is rounded to 2 decimals and formatted by fmt_value (e.g. ".55").
    A metric absent from every attribute averages to 0 and is formatted as ".00".
    """
    # Short result key -> metric name in the JSON (order fixes the result order).
    metric_keys = {
        "iou": "IOU Divergence",
        "serp": "SERP MS Divergence",
        "prag": "Pragmatic Divergence",
        "div": "diversity",
    }
    collected = {short: ([], []) for short in metric_keys}

    for attr, record in data.items():
        if attr == "neutral":
            continue
        for short, json_key in metric_keys.items():
            if json_key in record:
                means, stds = collected[short]
                means.append(record[json_key].get("mean", 0.0))
                stds.append(record[json_key].get("std", 0.0))

    def _avg(values):
        # Unweighted mean over attributes; 0.0 when the metric never appeared.
        return sum(values) / len(values) if values else 0.0

    return {
        short: (fmt_value(round(_avg(means), 2)), fmt_value(round(_avg(stds), 2)))
        for short, (means, stds) in collected.items()
    }


# ------------------------------------------------------------
# AUTOMATIC BEST MARKING (with ties allowed)
# ------------------------------------------------------------

def apply_best_only(model_rows):
    r"""Wrap the best mean of every column in \best{}, modifying *model_rows* in place.

    model_rows: list of rows; each row is a list of (mean, std) string pairs
    (e.g. (".55", ".12")), four pairs per dataset file.
    Columns repeat in blocks of four:
      - positions 0-2 (IOU, SERP, Pragmatic divergence) are lower-is-better;
      - position 3 (diversity) is higher-is-better.
    Every cell tied with the best value is marked; no second place is marked.
    Means must still be plain numeric strings, so calling this twice on the
    same rows raises ValueError from float().
    Returns the same list; an empty list is returned unchanged.
    """
    if not model_rows:
        return model_rows

    num_metrics = len(model_rows[0])

    for col in range(num_metrics):
        # Means are strings such as ".55"; float() parses them for comparison.
        means = np.array([float(row[col][0]) for row in model_rows])

        # col % 4 < 3  => IOU/SERP/PRAG (minimize)
        # col % 4 == 3 => diversity (maximize)
        best_value = np.min(means) if col % 4 < 3 else np.max(means)

        # Mark ALL rows tied with the best value.
        for r, value in enumerate(means):
            if value == best_value:
                m, s = model_rows[r][col]
                model_rows[r][col] = (f"\\best{{{m}}}", s)

    return model_rows


# ------------------------------------------------------------
# BUILD MODEL BLOCK
# ------------------------------------------------------------

def build_model_block(model_name, sizes):
    r"""Print one LaTeX \multirow block comparing the sizes of a single model.

    sizes: list of (size_label, [file1, file2, file3]); each metric JSON file
    contributes four (mean, std) cells (IOU, SERP, Pragmatic, diversity).
    Best values per column within this block are wrapped in \best{}. Each size
    is printed as `` & \textbf{<size>} & <cells> \\``, followed by a \midrule.
    """
    # 1. Collect one row of (mean, std) string pairs per model size.
    rows = []
    for _size_label, files in sizes:
        row = []
        for fp in files:
            with open(fp, "r", encoding="utf-8") as f:
                stats = calculate_metrics_stats(json.load(f))
            row.extend(stats[key] for key in ("iou", "serp", "prag", "div"))
        rows.append(row)

    # 2. Apply BEST marking (no second, ties allowed).
    marked_rows = apply_best_only(rows)

    # 3. Print the LaTeX block; the \multirow spans one line per size.
    print(f"\\multirow{{{len(sizes)}}}{{*}}{{\\textbf{{{model_name}}}}}")

    for (size_label, _files), row in zip(sizes, marked_rows):
        metrics = " & ".join(cell(m, s) for m, s in row)
        print(f" & \\textbf{{{size_label}}} & {metrics} \\\\")

    print("\n\\midrule\n")


# ------------------------------------------------------------
# DEFINE MODEL STRUCTURE
# ------------------------------------------------------------

def _size_files(prefix):
    """Return the [college, music, movie] metric JSON paths (no persona) for a model size."""
    return [
        f"data/metric_results/{prefix}_{dataset}_.json"
        for dataset in ("college", "music", "movie")
    ]


# model name -> list of (size label, metric files). The label printed in the
# table must match the size encoded in the file prefix it reads.
model_structure = {
    "Gemma 3": [
        ("1B", _size_files("gemma3_1b")),
        ("4B", _size_files("gemma3_4b")),
        ("12B", _size_files("gemma3_12b")),
    ],

    "Llama 3.2": [
        ("1B", _size_files("llama3.2_1b")),
        ("3B", _size_files("llama3.2_3b")),
    ],

    "Qwen3": [
        ("1.7B", _size_files("qwen3_1.7b")),
        # Reads the qwen3_4b files, so it is labelled 4B (was mislabelled 3B).
        ("4B", _size_files("qwen3_4b")),
        ("14B", _size_files("qwen3_14b")),
    ],
}


# ------------------------------------------------------------
# RUN ALL MODELS
# ------------------------------------------------------------

for model_name, sizes in model_structure.items():
    build_model_block(model_name, sizes)


\multirow{3}{*}{\textbf{Gemma 3}}
 & \textbf{1B} & .59 \tpms{.12} & .43 \tpms{.10} & \best{.46} \tpms{.10} & .97 \tpms{.03} & .73 \tpms{.20} & .69 \tpms{.23} & .69 \tpms{.23} & \best{1.00} \tpms{.00} & .78 \tpms{.08} & .69 \tpms{.10} & .72 \tpms{.09} & .99 \tpms{.01} \\
 & \textbf{4B} & \best{.55} \tpms{.12} & \best{.42} \tpms{.09} & .47 \tpms{.09} & .98 \tpms{.02} & \best{.27} \tpms{.09} & \best{.18} \tpms{.06} & \best{.20} \tpms{.06} & \best{1.00} \tpms{.00} & \best{.55} \tpms{.09} & \best{.40} \tpms{.08} & \best{.46} \tpms{.08} & .99 \tpms{.01} \\
 & \textbf{12B} & .65 \tpms{.07} & .50 \tpms{.09} & .55 \tpms{.09} & \best{.99} \tpms{.01} & .66 \tpms{.08} & .47 \tpms{.09} & .48 \tpms{.09} & \best{1.00} \tpms{.00} & .75 \tpms{.07} & .64 \tpms{.09} & .67 \tpms{.07} & \best{1.00} \tpms{.00} \\

\midrule

\multirow{2}{*}{\textbf{Llama 3.2}}
 & \textbf{1B} & \best{.12} \tpms{.18} & \best{.06} \tpms{.12} & \best{.16} \tpms{.23} & .76 \tpms{.33} & \best{.10} \tpms{.12} & \best{.05} \tpms{.08