In [13]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
import matplotlib.cm as cm
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# ─────────────────────────────────────────────────────────────────────────────
# 1. DATA LOADING
# ─────────────────────────────────────────────────────────────────────────────

DATA_DIR = "./valid_choreography"
OUTPUT_DIR = "./outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

files = {
    "flash-lite_0.01_base":     "gemini-2.5-flash-lite_temp_0.01_base_prompt.json",
    "flash-lite_0.01_mod":      "gemini-2.5-flash-lite_temp_0.01_modified_prompt.json",
    "flash-lite_0.1_mod":       "gemini-2.5-flash-lite_temp_0.1_modified_prompt.json",
    "flash-lite_0.7_mod":       "gemini-2.5-flash-lite_temp_0.7_modified_prompt.json",
    "flash_0.1_base":           "gemini-2.5-flash_temp_0.1_base_prompt.json",
    "flash_0.1_mod":            "gemini-2.5-flash_temp_0.1_modified_prompt.json",
    "flash_0.3_base":           "gemini-2.5-flash_temp_0.3_base_prompt.json",
    "flash_0.3_mod":            "gemini-2.5-flash_temp_0.3_modified_prompt.json",
    "flash_0.7_base":           "gemini-2.5-flash_temp_0.7_base_prompt.json",
    "flash_0.7_mod":            "gemini-2.5-flash_temp_0.7_modified_prompt.json",
    "pro_0.1_base":             "gemini-3-pro-preview_temp_0.1_base_prompt.json",
    "pro_0.1_mod":              "gemini-3-pro-preview_temp_0.1_modified_prompt.json",
    "pro_0.3_base":             "gemini-3-pro-preview_temp_0.3_base_prompt.json",
    "pro_0.3_mod":              "gemini-3-pro-preview_temp_0.3_modified_prompt.json",
    "pro_0.7_base":             "gemini-3-pro-preview_temp_0.7_base_prompt.json",
    "pro_0.7_mod":              "gemini-3-pro-preview_temp_0.7_modified_prompt.json",
    "pro_1.0_base":             "gemini-3-pro-preview_temp_1_base_prompt.json",
    "pro_1.0_mod":              "gemini-3-pro-preview_temp_1_modified_prompt.json",
    "flash_0.7_mod2":           "gemini-2.5-flash_temp_0.7_modified_prompt.json",
}

choreographies = {}
for label, fname in files.items():
    path = os.path.join(DATA_DIR, fname)
    if os.path.exists(path):
        with open(path) as f:
            choreographies[label] = json.load(f)

print(f"Loaded {len(choreographies)} choreographies.\n")

# ─────────────────────────────────────────────────────────────────────────────
# 2. FEATURE EXTRACTION
# ─────────────────────────────────────────────────────────────────────────────

def parse_label(label):
    parts = label.split("_")
    model = parts[0]          # flash-lite | flash | pro
    temp  = parts[1]          # 0.01 etc.
    prompt = parts[2] if len(parts) > 2 else "base"
    return model, temp, prompt

def extract_features(label, data):
    events = data.get("events", [])
    duration = data.get("duration", 90)
    dancers  = data.get("dancers", [])
    n_dancers = len(dancers)
    n_events  = len(events)

    # spatial centroid usage
    xs, ys = [], []
    for e in events:
        for key in ("from", "to"):
            pt = e.get(key, {})
            if pt:
                xs.append(pt.get("x", 12))
                ys.append(pt.get("y", 12))

    center_dist = []
    for x, y in zip(xs, ys):
        center_dist.append(np.sqrt((x - 12)**2 + (y - 12)**2))

    # movement speed (distance / duration of event)
    speeds = []
    for e in events:
        dt = e.get("end", 0) - e.get("start", 0)
        fx, fy = e.get("from", {}).get("x", 12), e.get("from", {}).get("y", 12)
        tx, ty = e.get("to",   {}).get("x", 12), e.get("to",   {}).get("y", 12)
        dist = np.sqrt((tx - fx)**2 + (ty - fy)**2)
        if dt > 0:
            speeds.append(dist / dt)

    # partner / collaborative events
    partner_events = sum(1 for e in events if e.get("partner"))
    partner_ratio  = partner_events / max(n_events, 1)

    # unique action vocabulary
    actions = [e.get("action", "").lower() for e in events]
    unique_actions = len(set(actions))

    # spatial spread (std of positions)
    spatial_spread = np.std(xs) + np.std(ys) if xs else 0

    # corner usage: fraction of points near corners (x<6 or x>18, y<6 or y>18)
    corner_pts = sum(1 for x, y in zip(xs, ys)
                     if (x < 6 or x > 18) and (y < 6 or y > 18))
    corner_ratio = corner_pts / max(len(xs), 1)

    # center usage: fraction of points in center 8x8 box
    center_pts = sum(1 for x, y in zip(xs, ys)
                     if 8 <= x <= 16 and 8 <= y <= 16)
    center_ratio = center_pts / max(len(xs), 1)

    return {
        "label": label,
        "model": parse_label(label)[0],
        "temp": float(parse_label(label)[1]),
        "prompt": parse_label(label)[2],
        "duration": duration,
        "n_dancers": n_dancers,
        "n_events": n_events,
        "events_per_sec": n_events / max(duration, 1),
        "mean_speed": np.mean(speeds) if speeds else 0,
        "max_speed": np.max(speeds) if speeds else 0,
        "mean_center_dist": np.mean(center_dist) if center_dist else 0,
        "spatial_spread": spatial_spread,
        "corner_ratio": corner_ratio,
        "center_ratio": center_ratio,
        "partner_ratio": partner_ratio,
        "unique_actions": unique_actions,
        "xs": xs,
        "ys": ys,
        "speeds": speeds,
        "actions": actions,
    }

features = {label: extract_features(label, data)
            for label, data in choreographies.items()}

# ─────────────────────────────────────────────────────────────────────────────
# 3. PRINT STATISTICS TABLE
# ─────────────────────────────────────────────────────────────────────────────

print("=" * 100)
print(f"{'Label':<25} {'Model':<10} {'Temp':>5} {'Prompt':>8} {'Dur':>5} {'Dancers':>7} "
      f"{'Events':>7} {'Ev/s':>6} {'AvgSpd':>7} {'CtrDst':>7} "
      f"{'Sprd':>6} {'Cornr%':>7} {'Cntr%':>6} {'Prtnr%':>7} {'UniAct':>7}")
print("-" * 100)
for label, f in sorted(features.items()):
    print(f"{label:<25} {f['model']:<10} {f['temp']:>5.2f} {f['prompt']:>8} "
          f"{f['duration']:>5.0f} {f['n_dancers']:>7} {f['n_events']:>7} "
          f"{f['events_per_sec']:>6.2f} {f['mean_speed']:>7.2f} "
          f"{f['mean_center_dist']:>7.2f} {f['spatial_spread']:>6.2f} "
          f"{f['corner_ratio']*100:>7.1f} {f['center_ratio']*100:>6.1f} "
          f"{f['partner_ratio']*100:>7.1f} {f['unique_actions']:>7}")
print("=" * 100)

# group stats
print("\n── GROUP MEANS by MODEL ──")
for model in ("flash-lite", "flash", "pro"):
    subset = [f for f in features.values() if f["model"] == model]
    if not subset:
        continue
    print(f"\nModel: {model}  (n={len(subset)})")
    for key in ("duration", "n_dancers", "n_events", "mean_speed",
                "mean_center_dist", "spatial_spread", "partner_ratio", "unique_actions"):
        vals = [s[key] for s in subset]
        print(f"  {key:<22}: mean={np.mean(vals):.2f}  std={np.std(vals):.2f}")

print("\n── GROUP MEANS by PROMPT TYPE ──")
for ptype in ("base", "mod"):
    subset = [f for f in features.values() if f["prompt"] == ptype]
    if not subset:
        continue
    print(f"\nPrompt: {ptype}  (n={len(subset)})")
    for key in ("n_events", "mean_speed", "spatial_spread", "partner_ratio", "unique_actions"):
        vals = [s[key] for s in subset]
        print(f"  {key:<22}: mean={np.mean(vals):.2f}  std={np.std(vals):.2f}")

# ─────────────────────────────────────────────────────────────────────────────
# 4. FIGURES
# ─────────────────────────────────────────────────────────────────────────────

MODEL_COLORS = {"flash-lite": "#E07B39", "flash": "#3A86FF", "pro": "#8338EC"}
PROMPT_MARKERS = {"base": "o", "mod": "s"}


# ── Fig: Event timeline for one choreography per model ──────────────────
model_examples = {
    "flash-lite": "flash-lite_0.01_base",
    "flash":      "flash_0.3_base",
    "pro":        "pro_0.7_base",
}
fig, axes = plt.subplots(3, 1, figsize=(14, 12))
fig.suptitle("Event Timelines per Model (representative choreography)", fontsize=13, fontweight='bold')

dancer_colors = ['#E63946', '#457B9D', '#2A9D8F', '#E9C46A']
for ax, (model, key) in zip(axes, model_examples.items()):
    data = choreographies.get(key, {})
    events = data.get("events", [])
    dancers = data.get("dancers", [])
    duration = data.get("duration", 90)

    for di, dancer in enumerate(dancers[:5]):
        d_events = [e for e in events if e.get("dancer") == dancer]
        for e in d_events:
            s, end = e.get("start", 0), e.get("end", 0)
            if end <= s:
                end = s + 0.5
            ax.barh(di, end - s, left=s, height=0.6,
                    color=dancer_colors[di % len(dancer_colors)], alpha=0.7,
                    edgecolor='white', lw=0.4)

    ax.set_yticks(range(len(dancers[:5])))
    ax.set_yticklabels([d if len(d) < 20 else f"Dancer {i+1}"
                        for i, d in enumerate(dancers[:5])])
    ax.set_xlim(0, duration)
    ax.set_xlabel("Time (s)")
    ax.set_title(f"Gemini {model}", fontsize=11)
    ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/fig1_timelines.png", dpi=150, bbox_inches='tight')
plt.close()
print("Saved fig1_timelines.png")


Loaded 19 choreographies.

Label                     Model       Temp   Prompt   Dur Dancers  Events   Ev/s  AvgSpd  CtrDst   Sprd  Cornr%  Cntr%  Prtnr%  UniAct
----------------------------------------------------------------------------------------------------
flash-lite_0.01_base      flash-lite  0.01     base    75       3      42   0.56    0.68    4.54   8.15     3.6   71.4     7.1      28
flash-lite_0.01_mod       flash-lite  0.01      mod    90       4      50   0.56    0.76   11.50  15.49    66.0   23.0     0.0      25
flash-lite_0.1_mod        flash-lite  0.10      mod    90       4      38   0.42    0.45    7.59  11.55    23.7   46.1     0.0      32
flash-lite_0.7_mod        flash-lite  0.70      mod    90       4      55   0.61    0.41    7.87  12.22    22.7   38.2     0.0      45
flash_0.1_base            flash       0.10     base    78       3      33   0.43    1.08    9.72  15.61    18.2   24.2     0.0      21
flash_0.1_mod             flash       0.10      mod    78     

In [14]:
"""
Generate iteration analysis figures from the summary.docx run data.
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
import warnings
warnings.filterwarnings('ignore')

OUTPUT_DIR = "./outputs"

# BASE PROMPT table
base_runs = {
    # (model, temp): [run0, run1, run2, run3, run4, run5]
    ("pro",        1.0): ["2col+4spd", "valid",   None, None, None, None],
    ("pro",        0.7): ["5col+1spd", "1col",    "1col", "valid", None, None],
    ("pro",        0.3): ["2spd",      "valid",   None, None, None, None],
    ("pro",        0.1): ["3ent",      "valid",   None, None, None, None],
    ("flash",      1.0): ["1col+13spd","1col+12spd","2col+11spd","5spd","4spd","2spd"],
    ("flash",      0.7): ["4spd",      "4spd",    "4spd", "valid", None, None],
    ("flash",      0.3): ["1col+5spd", "1col+5spd","2col+4spd","1col","valid",None],
    ("flash",      0.1): ["7spd",      "valid",   None, None, None, None],
    ("flash-lite", 1.0): ["68col+spd", "45col",   "22col","56col","50col","50col"],
    ("flash-lite", 0.7): ["48col+spd", "20col+spd","20col+spd","19col+spd","19col+spd","20col+spd"],
    ("flash-lite", 0.3): ["30col",     "30col",   "30col","30col","30col","30col"],
    ("flash-lite", 0.1): ["46col+1spd","46col+1spd","46col+1spd","46col","46col","46col"],
    ("flash-lite", 0.01):["3spd",      "1col",    "1col", "valid",  None, None],
}

# MODIFIED PROMPT table
mod_runs = {
    ("pro",        1.0): ["7col",       "valid",   None, None, None, None],
    ("pro",        0.7): ["5spd",       "4spd",    "valid", None, None, None],
    ("pro",        0.3): ["4col+2spd",  "valid",   None, None, None, None],
    ("pro",        0.1): ["6spd",       "valid",   None, None, None, None],
    ("flash",      1.0): ["11spd",      "11spd",   "11spd","11spd","11spd","11spd"],
    ("flash",      0.7): ["4spd",       "4spd",    "valid",None, None, None],
    ("flash",      0.3): ["valid",      None,      None, None, None, None],
    ("flash",      0.1): ["1col+4oob+8spd","4oob+7spd","valid",None,None,None],
    ("flash-lite", 1.0): ["31col+spd",  "39col+spd","41col+spd","38col+spd","38col+spd","34col+spd"],
    ("flash-lite", 0.7): ["1ent",       "valid",   None, None, None, None],
    ("flash-lite", 0.3): ["10spd",      "8spd",    "1spd","1spd","1spd","1spd"],
    ("flash-lite", 0.1): ["1ent",       "valid",   None, None, None, None],
    ("flash-lite", 0.01):["2spd",       "valid",   None, None, None, None],
}

def runs_until_valid(run_list):
    """Return number of runs needed to get 'valid', or None if never valid."""
    for i, r in enumerate(run_list):
        if r is not None and r.lower() == "valid":
            return i + 1   # 1-indexed count of runs
    return None            # never reached valid in <=6 runs

def total_errors(run_list):
    """Count total error events across all failing runs."""
    total = 0
    for r in run_list:
        if r is None or r.lower() == "valid":
            continue
        # rough count: split by '+' and count segments
        parts = r.replace("+", " ").split()
        for p in parts:
            # extract leading integer if present
            n = ""
            for ch in p:
                if ch.isdigit():
                    n += ch
                else:
                    break
            total += int(n) if n else 1
    return total

def dominant_error(run_list):
    """Return the most frequent error type across failing runs."""
    counts = {"col": 0, "spd": 0, "ent": 0, "oob": 0}
    for r in run_list:
        if r is None or r.lower() == "valid":
            continue
        for k in counts:
            if k in r:
                counts[k] += 1
    if max(counts.values()) == 0:
        return "none"
    return max(counts, key=counts.get)


models = ["pro", "flash", "flash-lite"]
temps_pro   = [1.0, 0.7, 0.3, 0.1]
temps_flash = [1.0, 0.7, 0.3, 0.1]
temps_lite  = [1.0, 0.7, 0.3, 0.1, 0.01]

all_keys_ordered = [
    ("pro",        1.0), ("pro",        0.7), ("pro",        0.3), ("pro",        0.1),
    ("flash",      1.0), ("flash",      0.7), ("flash",      0.3), ("flash",      0.1),
    ("flash-lite", 1.0), ("flash-lite", 0.7), ("flash-lite", 0.3), ("flash-lite", 0.1), ("flash-lite", 0.01),
]

labels = [f"{m}\n{t}" for m, t in all_keys_ordered]

base_iters = [runs_until_valid(base_runs.get(k, [None]*6)) for k in all_keys_ordered]
mod_iters  = [runs_until_valid(mod_runs.get(k, [None]*6))  for k in all_keys_ordered]

# ─────────────────────────────────────────────────────────────────────────────
# FIGURE: Error type breakdown per model
# ─────────────────────────────────────────────────────────────────────────────
error_types = {"col": "Collision", "spd": "Speed", "ent": "Entry", "oob": "Out-of-Bounds"}
colors_err  = {"col": "#E53935", "spd": "#FB8C00", "ent": "#8E24AA", "oob": "#039BE5"}

fig, axes = plt.subplots(1, 2, figsize=(13, 5))
fig.suptitle("Error Type Distribution Across All Failing Runs", fontsize=13, fontweight='bold')

for ax, (run_dict, title) in zip(axes, [(base_runs, "Base Prompt"), (mod_runs, "Modified Prompt")]):
    model_counts = {}
    for m in models:
        counts = {k: 0 for k in error_types}
        for (mdl, tmp), runs in run_dict.items():
            if mdl != m:
                continue
            for r in runs:
                if r is None or r.lower() == "valid":
                    continue
                for k in error_types:
                    if k in r:
                        counts[k] += 1
        model_counts[m] = counts

    x = np.arange(len(models))
    w = 0.18
    offsets = [-1.5, -0.5, 0.5, 1.5]
    for i, (ekey, elabel) in enumerate(error_types.items()):
        vals = [model_counts[m][ekey] for m in models]
        ax.bar(x + offsets[i] * w, vals, w, label=elabel,
               color=colors_err[ekey], alpha=0.85, edgecolor='white')

    ax.set_xticks(x)
    ax.set_xticklabels([f"Gemini\n{m}" for m in models])
    ax.set_ylabel("Count of failing runs with this error type")
    ax.set_title(title, fontsize=11)
    ax.legend(fontsize=9)
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/fig2_error_breakdown.png", dpi=150, bbox_inches='tight')
plt.close()
print("Saved fig2_error_breakdown.png")

# ─────────────────────────────────────────────────────────────────────────────
# FIGURE: Convergence curves — cumulative valid rate across runs
# ─────────────────────────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
fig.suptitle("Cumulative Validation Success Rate by Run Number", fontsize=13, fontweight='bold')
model_colors = {"pro": "#8338EC", "flash": "#3A86FF", "flash-lite": "#E07B39"}

for ax, (run_dict, title) in zip(axes, [(base_runs, "Base Prompt"), (mod_runs, "Modified Prompt")]):
    for m in models:
        keys = [(mdl, tmp) for (mdl, tmp) in all_keys_ordered if mdl == m]
        cumulative = np.zeros(6)
        total = len(keys)
        for k in keys:
            runs = run_dict.get(k, [None]*6)
            for i, r in enumerate(runs):
                if r is not None and r.lower() == "valid":
                    cumulative[i:] += 1
                    break
        ax.plot(range(1, 7), cumulative / total * 100,
                marker='o', lw=2, color=model_colors[m], label=f"Gemini {m}")
        ax.fill_between(range(1, 7), cumulative / total * 100, alpha=0.1, color=model_colors[m])

    ax.set_xlabel("Run number")
    ax.set_ylabel("% of conditions with valid choreography")
    ax.set_title(title, fontsize=11)
    ax.set_xlim(0.5, 6.5)
    ax.set_ylim(-5, 105)
    ax.set_xticks(range(1, 7))
    ax.axhline(100, color='gray', ls='--', lw=1, alpha=0.5)
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/fig3_convergence.png", dpi=150, bbox_inches='tight')
plt.close()
print("Saved fig3_convergence.png")

# ─────────────────────────────────────────────────────────────────────────────
# PRINT SUMMARY STATS
# ─────────────────────────────────────────────────────────────────────────────
print("\n── ITERATION SUMMARY ──")
print(f"{'Condition':<30} {'Base runs':>10} {'Mod runs':>10}  {'Base valid?':>12}  {'Mod valid?':>12}")
print("-"*80)
for k in all_keys_ordered:
    b_runs = base_runs.get(k, [])
    m_runs = mod_runs.get(k, [])
    bi = runs_until_valid(b_runs)
    mi = runs_until_valid(m_runs)
    print(f"{str(k):<30} {str(bi) if bi else '>6':>10} {str(mi) if mi else '>6':>10}  "
          f"{'YES' if bi else 'NO':>12}  {'YES' if mi else 'NO':>12}")

print("\n── MODEL AVERAGES (runs to valid, excluding never-valid) ──")
for m in models:
    keys = [(mdl, tmp) for (mdl, tmp) in all_keys_ordered if mdl == m]
    b_vals = [runs_until_valid(base_runs.get(k, [])) for k in keys]
    m_vals = [runs_until_valid(mod_runs.get(k, [])) for k in keys]
    b_valid = [v for v in b_vals if v]
    m_valid = [v for v in m_vals if v]
    b_rate = len(b_valid) / len(keys) * 100
    m_rate = len(m_valid) / len(keys) * 100
    print(f"  {m:<12}  Base: {len(b_valid)}/{len(keys)} valid ({b_rate:.0f}%),"
          f" avg {np.mean(b_valid):.1f} runs  |  "
          f"Mod: {len(m_valid)}/{len(keys)} valid ({m_rate:.0f}%),"
          f" avg {np.mean(m_valid):.1f} runs")
print("\nDone.")

Saved fig2_error_breakdown.png
Saved fig3_convergence.png

── ITERATION SUMMARY ──
Condition                       Base runs   Mod runs   Base valid?    Mod valid?
--------------------------------------------------------------------------------
('pro', 1.0)                            2          2           YES           YES
('pro', 0.7)                            4          3           YES           YES
('pro', 0.3)                            2          2           YES           YES
('pro', 0.1)                            2          2           YES           YES
('flash', 1.0)                         >6         >6            NO            NO
('flash', 0.7)                          4          3           YES           YES
('flash', 0.3)                          5          1           YES           YES
('flash', 0.1)                          2          3           YES           YES
('flash-lite', 1.0)                    >6         >6            NO            NO
('flash-lite', 0.7)       

In [16]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os

UPLOAD = "./valid_choreography/"
OUTPUT_DIR = "./outputs/trajectories/"

os.makedirs(OUTPUT_DIR, exist_ok=True)

FILES = [
    ("flash",       "0.1",  "base",     "gemini-2.5-flash_temp_0.1_base_prompt.json"),
    ("flash",       "0.1",  "modified", "gemini-2.5-flash_temp_0.1_modified_prompt.json"),
    ("flash",       "0.3",  "base",     "gemini-2.5-flash_temp_0.3_base_prompt.json"),
    ("flash",       "0.3",  "modified", "gemini-2.5-flash_temp_0.3_modified_prompt.json"),
    ("flash",       "0.7",  "base",     "gemini-2.5-flash_temp_0.7_base_prompt.json"),
    ("flash",       "0.7",  "modified", "gemini-2.5-flash_temp_0.7_modified_prompt.json"),
    ("flash-lite",  "0.01", "base",     "gemini-2.5-flash-lite_temp_0.01_base_prompt.json"),
    ("flash-lite",  "0.01", "modified", "gemini-2.5-flash-lite_temp_0.01_modified_prompt.json"),
    ("flash-lite",  "0.1",  "modified", "gemini-2.5-flash-lite_temp_0.1_modified_prompt.json"),
    ("flash-lite",  "0.7",  "modified", "gemini-2.5-flash-lite_temp_0.7_modified_prompt.json"),
    ("pro-preview", "0.1",  "base",     "gemini-3-pro-preview_temp_0.1_base_prompt.json"),
    ("pro-preview", "0.1",  "modified", "gemini-3-pro-preview_temp_0.1_modified_prompt.json"),
    ("pro-preview", "0.3",  "base",     "gemini-3-pro-preview_temp_0.3_base_prompt.json"),
    ("pro-preview", "0.3",  "modified", "gemini-3-pro-preview_temp_0.3_modified_prompt.json"),
    ("pro-preview", "0.7",  "base",     "gemini-3-pro-preview_temp_0.7_base_prompt.json"),
    ("pro-preview", "0.7",  "modified", "gemini-3-pro-preview_temp_0.7_modified_prompt.json"),
    ("pro-preview", "1.0",  "base",     "gemini-3-pro-preview_temp_1_base_prompt.json"),
    ("pro-preview", "1.0",  "modified", "gemini-3-pro-preview_temp_1_modified_prompt.json"),
]

# ---------------------------------------------------------
# choreography loader
# ---------------------------------------------------------
def process_choreography(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)

    dancers = data["dancers"]
    grid_size = 25
    heatmap = np.zeros((grid_size, grid_size))
    trajectories = {d: {"x": [], "y": []} for d in dancers}

    for event in data["events"]:
        dancer = event.get("dancer")
        if dancer is None:
            continue

        start_pos = event.get("from", None)
        end_pos = event.get("to", None)

        # ENTER event
        if start_pos is None and end_pos is not None:
            x = end_pos["x"]
            y = end_pos["y"]
            trajectories[dancer]["x"].append(x)
            trajectories[dancer]["y"].append(y)

            gx = min(int(round(x)), grid_size - 1)
            gy = min(int(round(y)), grid_size - 1)
            heatmap[gy, gx] += 1
            continue

        if start_pos is None or end_pos is None:
            continue

        # Movement interpolation
        steps = 10
        for i in range(steps + 1):
            t = i / steps
            x = start_pos["x"] + (end_pos["x"] - start_pos["x"]) * t
            y = start_pos["y"] + (end_pos["y"] - start_pos["y"]) * t

            trajectories[dancer]["x"].append(x)
            trajectories[dancer]["y"].append(y)

            gx = min(int(round(x)), grid_size - 1)
            gy = min(int(round(y)), grid_size - 1)
            heatmap[gy, gx] += 1

    return dancers, trajectories, heatmap


# ---------------------------------------------------------
# Run all files + collect statistics
# ---------------------------------------------------------
stats_results = []

for model, temp, prompt_type, filename in FILES:

    file_path = os.path.join(UPLOAD, filename)

    if not os.path.exists(file_path):
        print(f"Warning: {file_path} not found.")
        continue

    title = f"{model} — temp {temp} — {prompt_type}"
    save_name = f"{model}_temp{temp}_{prompt_type}.png".replace(" ", "_")

    dancers, trajectories, heatmap = process_choreography(file_path)

    # -----------------------------------------------------
    # FIGURE: Trajectories + Heatmap (saved to file)
    # -----------------------------------------------------
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    fig.suptitle(title, fontsize=16)

    ax1, ax2 = axes

    # Trajectories
    ax1.set_title("Trajectories")
    ax1.set_xlim(-2, 26)
    ax1.set_ylim(-2, 26)
    ax1.grid(True, linestyle='--', alpha=0.5)
    ax1.set_aspect('equal')

    colors = plt.cm.rainbow(np.linspace(0, 1, len(dancers)))
    used_label_positions = []

    def get_offset_position(x, y, used_positions, offset=0.5):
        pos = np.array([x, y])
        for ux, uy in used_positions:
            if np.linalg.norm(pos - np.array([ux, uy])) < 0.6:
                pos += np.array([offset, offset])
        used_positions.append((pos[0], pos[1]))
        return pos[0], pos[1]

    for i, dancer in enumerate(dancers):
        x = np.array(trajectories[dancer]["x"])
        y = np.array(trajectories[dancer]["y"])
        color = colors[i]

        ax1.plot(x, y, marker='.', markersize=3, color=color, alpha=0.7, label=dancer)

        # Entrance
        ex, ey = x[0], y[0]
        lx, ly = get_offset_position(ex, ey, used_label_positions)
        ax1.scatter(ex, ey, color=color, s=120, edgecolor='black')
        ax1.text(lx, ly, f"{dancer} (enter)", fontsize=10, color=color)

        # Exit
        ex2, ey2 = x[-1], y[-1]
        lx2, ly2 = get_offset_position(ex2, ey2, used_label_positions)
        ax1.scatter(ex2, ey2, color=color, s=120, edgecolor='black')
        ax1.text(lx2, ly2, f"{dancer} (exit)", fontsize=10, color=color)

        # Direction arrows
        arrow_step = max(1, len(x) // 12)
        for k in range(0, len(x) - arrow_step, arrow_step):
            ax1.arrow(
                x[k], y[k],
                x[k + arrow_step] - x[k],
                y[k + arrow_step] - y[k],
                color=color,
                length_includes_head=True,
                head_width=0.4,
                head_length=0.6,
                alpha=0.8
            )

    ax1.legend()

    # Heatmap
    ax2.set_title("Heatmap")
    im = ax2.imshow(
        heatmap,
        cmap='hot_r',
        interpolation='nearest',
        origin='lower',
        extent=[0, 24, 0, 24]
    )
    ax2.set_aspect('equal')
    fig.colorbar(im, ax=ax2, label="Visit Count")

    plt.tight_layout()

    # SAVE instead of show
    save_path = os.path.join(OUTPUT_DIR, save_name)
    plt.savefig(save_path, dpi=300)
    plt.close()

    print(f"Saved: {save_path}")

    # -----------------------------------------------------
    # STATISTICS
    # -----------------------------------------------------
    coverage_count = np.count_nonzero(heatmap)
    total_cells = 25 * 25
    coverage_pct = (coverage_count / total_cells) * 100

    all_x = []
    all_y = []
    for d in dancers:
        all_x.extend(trajectories[d]["x"])
        all_y.extend(trajectories[d]["y"])

    dist_from_center = np.sqrt((np.array(all_x) - 12)**2 + (np.array(all_y) - 12)**2)
    avg_dist = np.mean(dist_from_center)

    stats_results.append({
        "File": title,
        "Coverage (%)": f"{coverage_pct:.1f}%",
        "Avg Dist from Center": f"{avg_dist:.2f} units"
    })

# ---------------------------------------------------------
# Print Summary Table
# ---------------------------------------------------------
print("\n--- Comparative Spatial Statistics ---")
print(f"{'Choreography':<40} | {'Grid Coverage':<15} | {'Avg Dist from Center':<20}")
print("-" * 90)
for stat in stats_results:
    print(f"{stat['File']:<40} | {stat['Coverage (%)']:<15} | {stat['Avg Dist from Center']:<20}")


Saved: ./outputs/trajectories/flash_temp0.1_base.png
Saved: ./outputs/trajectories/flash_temp0.1_modified.png
Saved: ./outputs/trajectories/flash_temp0.3_base.png
Saved: ./outputs/trajectories/flash_temp0.3_modified.png
Saved: ./outputs/trajectories/flash_temp0.7_base.png
Saved: ./outputs/trajectories/flash_temp0.7_modified.png
Saved: ./outputs/trajectories/flash-lite_temp0.01_base.png
Saved: ./outputs/trajectories/flash-lite_temp0.01_modified.png
Saved: ./outputs/trajectories/flash-lite_temp0.1_modified.png
Saved: ./outputs/trajectories/flash-lite_temp0.7_modified.png
Saved: ./outputs/trajectories/pro-preview_temp0.1_base.png
Saved: ./outputs/trajectories/pro-preview_temp0.1_modified.png
Saved: ./outputs/trajectories/pro-preview_temp0.3_base.png
Saved: ./outputs/trajectories/pro-preview_temp0.3_modified.png
Saved: ./outputs/trajectories/pro-preview_temp0.7_base.png
Saved: ./outputs/trajectories/pro-preview_temp0.7_modified.png
Saved: ./outputs/trajectories/pro-preview_temp1.0_base.png