In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt

# ==== Configuration ====
root = "attn_output"   # Directory containing raw txt files
out_dir = "3d_attn"    # Directory to save figures
os.makedirs(out_dir, exist_ok=True)

# One row per known model: (model name, display name, line color).
# Keeping both attributes in a single table guarantees that `labels`
# and `colors` always describe the same set of models, in the same order.
_model_styles = (
    ("pct",     "No-PE",        "#1f77b4"),  # blue
    ("pctxyz",  "Learnable-PE", "#2ca02c"),  # green
    ("pctgrid", "GridPE",       "#ff7f0e"),  # orange
    ("pctrope", "Rope-Axial",   "#d62728"),  # red
)

# Mapping model names to display names
labels = {name: shown for name, shown, _ in _model_styles}

# Mapping model names to line colors
colors = {name: shade for name, _, shade in _model_styles}

# ==== 1. Scan model names and collect raw data ====
_distance_suffix = "_attn_distance.txt"

# A model is any file-name prefix that has a "<prefix>_attn_distance.txt" file.
models = sorted(
    name[:-len(_distance_suffix)]
    for name in os.listdir(root)
    if name.endswith(_distance_suffix)
)


def _read_block_values(path):
    """Parse one statistics file into per-input-size lists of values.

    The file is read line by line; blank lines are ignored. A line starting
    with "Input points:" selects the current input size (the integer after
    the colon). A line starting with "Block" that appears after such a
    header appends the float following its first colon to the current
    size's list; its second whitespace-separated token, with trailing
    colons removed, is read as the block index. "Block" lines that appear
    before any header are skipped.

    Returns:
        A tuple ``(values_by_size, highest_block_index)``. The index is -1
        when no block line was consumed.
    """
    values_by_size = {}
    highest = -1
    current = None
    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            if text.startswith("Input points:"):
                current = int(text.split(":", 1)[1])
                values_by_size.setdefault(current, [])
                continue
            if current is None or not text.startswith("Block"):
                continue
            highest = max(highest, int(text.split()[1].rstrip(':')))
            values_by_size[current].append(float(text.split(":", 1)[1]))
    return values_by_size, highest


dist_raw = {m: {} for m in models}
ent_raw  = {m: {} for m in models}
max_block_idx = -1

for m in models:
    # Distance first, then entropy; both files share the same layout.
    dist_raw[m], _top_dist = _read_block_values(
        os.path.join(root, f"{m}_attn_distance.txt"))
    ent_raw[m], _top_ent = _read_block_values(
        os.path.join(root, f"{m}_attn_entropy.txt"))
    max_block_idx = max(max_block_idx, _top_dist, _top_ent)

# Total number of blocks (largest block index seen in any file, plus one)
B = max_block_idx + 1

# ==== 2. Compute means ====
dist_mean = {m: {} for m in models}
ent_mean  = {m: {} for m in models}

# (raw values, destination) table pairs; distance is processed before entropy.
_metric_tables = ((dist_raw, dist_mean), (ent_raw, ent_mean))

for m in models:
    # The input sizes to process are taken from the distance data.
    for num in sorted(dist_raw[m]):
        for _raw_by_model, _mean_by_model in _metric_tables:
            samples = _raw_by_model[m][num]
            n_rows = len(samples) // B
            # Lay the values out as rows of B consecutive entries, dropping any
            # trailing values that do not fill a whole row, then average each
            # column to get one mean per block position.
            grid = np.array(samples[:n_rows * B]).reshape(n_rows, B)
            _mean_by_model[m][num] = grid.mean(axis=0)

# ==== 3. Plot and save at 300 dpi ====
def _save_metric_plot(mean_by_model, num, ylabel, marker, filename):
    """Plot one per-block metric for every model at one input size and save it.

    Args:
        mean_by_model: Mapping of model name to ``{input size: per-block means}``.
        num: Input size whose curves are drawn.
        ylabel: Text for the y axis.
        marker: Matplotlib marker style applied to every curve.
        filename: Name of the PNG file written inside ``out_dir``.
    """
    x = np.arange(B)
    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
    try:
        for m in models:
            series = mean_by_model[m].get(num)
            if series is None:
                # This model has no data for this input size; omit its curve
                # rather than aborting the whole run with a KeyError.
                continue
            ax.plot(x, series, marker=marker,
                    label=labels.get(m, m), color=colors.get(m, 'gray'))
        ax.set_xlabel("Block Index", fontsize=14)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.set_xticks(x)
        ax.grid(True, linestyle="--", alpha=0.5)
        ax.legend(loc='best', fontsize=12)
        fig.savefig(os.path.join(out_dir, filename), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


# Every input size seen for any model. Taking the union (instead of only the
# first model's sizes) also yields an empty list when no model was found,
# where indexing models[0] would raise IndexError.
sizes = sorted(set().union(*(dist_mean[m] for m in models)))

for num in sizes:
    # --- Attention Distance ---
    _save_metric_plot(dist_mean, num, "Distance", 'o',
                      f"attention_distance_N{num}.png")

    # --- Attention Entropy ---
    _save_metric_plot(ent_mean, num, "Entropy", 's',
                      f"attention_entropy_N{num}.png")

if sizes:
    print(f"✅ Saved plots to '{out_dir}/' for input sizes: {sizes}")
else:
    print(f"No attention statistics found in '{root}/'; no plots were saved.")

✅ Saved plots to '3d_attn/' for input sizes: [256, 384, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048]
