# Attention Weight Analysis

**Ziel:** Untersuchen, warum Attention-Mechanismen keinen Mehrwert gegenüber der LSTM-Baseline liefern.

**Modelle:**
- M4: Small LSTM + Simple Attention (64 hidden, 3L)
- M6: Medium LSTM + Simple Attention (128 hidden, 5L)
- M7: Medium LSTM + Additive Attention (128 hidden, 5L)
- M8: Medium LSTM + Scaled Dot-Product Attention (128 hidden, 5L)

**Analysen:**
1. Entropie — Wie scharf/uniform ist die Attention?
2. Temporales Profil — Wohin fokussiert die Attention?
3. Konvergenz über Epochen — Lernt die Attention etwas?
4. Cross-Seed Konsistenz — Sind Patterns stabil?
5. Modellvergleich — Qualitative Unterschiede

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
from scipy.stats import entropy as scipy_entropy

# Global plot styling applied to every figure in this notebook.
_plot_style = {
    'figure.figsize': (14, 6),
    'font.size': 11,
    'axes.titlesize': 13,
    'axes.labelsize': 11,
    'legend.fontsize': 9,
}
plt.rcParams.update(_plot_style)

# Location of the exported attention weights and the experiment grid.
BASE_DIR = Path('../attention_weights')
SEEDS = [7, 42, 94, 123, 231]
N_TIMESTEPS = 50  # number of timesteps in the input window
# Entropy (natural log, i.e. nats) of a uniform distribution over N_TIMESTEPS
# steps -- the upper bound used to normalize entropies.
MAX_ENTROPY = np.log(N_TIMESTEPS)

# (key, run-directory prefix, plot label, plot colour) for each attention model.
_model_specs = (
    ('M4', 'M4_Small_Simple_Attention', 'M4 Small Simple', '#2196F3'),
    ('M6', 'M6_Medium_Simple_Attention', 'M6 Medium Simple', '#4CAF50'),
    ('M7', 'M7_Medium_Additive_Attention', 'M7 Medium Additive', '#FF9800'),
    ('M8', 'M8_Medium_Scaled_DP_Attention', 'M8 Medium Scaled DP', '#F44336'),
)
MODELS = {
    key: {'prefix': prefix, 'label': label, 'color': color}
    for key, prefix, label, color in _model_specs
}

print(f'Max entropy (uniform over {N_TIMESTEPS} steps): {MAX_ENTROPY:.4f}')

In [None]:
def load_all_epochs(model_key: str, seed: int) -> dict[int, np.ndarray]:
    """Load every per-epoch attention-weight file of one model/seed run.

    Files are looked up as
    ``BASE_DIR / f"{MODELS[model_key]['prefix']}_seed{seed}" / 'attention_epoch_*.npy'``.
    The epoch number is taken from the text after the last underscore of the
    file stem.

    Returns:
        Mapping epoch number -> loaded array. Empty dict if the run directory
        does not exist.
    """
    run_dir = BASE_DIR / f"{MODELS[model_key]['prefix']}_seed{seed}"
    if not run_dir.exists():
        return {}
    return {
        int(npy_path.stem.rsplit('_', 1)[-1]): np.load(npy_path)
        for npy_path in sorted(run_dir.glob('attention_epoch_*.npy'))
    }


def load_test_weights(model_key: str, seed: int) -> np.ndarray | None:
    """Load the test-time attention weights of one model/seed run.

    Reads ``attention_test.npy`` from the run directory
    ``BASE_DIR / f"{MODELS[model_key]['prefix']}_seed{seed}"``.

    Returns:
        The loaded array, or ``None`` if the file does not exist.
    """
    run_dir = BASE_DIR / f"{MODELS[model_key]['prefix']}_seed{seed}"
    test_file = run_dir / 'attention_test.npy'
    return np.load(test_file) if test_file.exists() else None


def get_final_weights(model_key: str, seed: int) -> np.ndarray | None:
    """Return the "final" attention weights of one model/seed run.

    Prefers the test-time weights. Falls back to the weights of the highest
    saved epoch.

    Returns:
        The chosen array, or ``None`` if neither test nor epoch weights exist.
    """
    test = load_test_weights(model_key, seed)
    if test is None:
        by_epoch = load_all_epochs(model_key, seed)
        if not by_epoch:
            return None
        return by_epoch[max(by_epoch)]
    return test


def normalized_entropy(weights: np.ndarray) -> float:
    """Compute the normalized Shannon entropy of a 1-D attention-weight vector.

    The weights are clipped to a floor of 1e-12, which keeps the logarithm
    finite for zero entries. They are then renormalized to sum to 1. The
    resulting entropy (nats) is divided by ``log(len(weights))``, the entropy
    of a uniform distribution over the same number of timesteps. Taking the
    maximum from the input length rather than the global window size keeps
    the result in [0, 1] for windows of any length.

    Args:
        weights: 1-D array of non-negative attention weights, one per timestep.

    Returns:
        0.0 when all mass is on a single timestep, 1.0 for a uniform
        distribution. A single-element input returns 0.0 because there is
        nothing to spread over.

    Raises:
        ValueError: if ``weights`` is empty.
    """
    w = np.asarray(weights, dtype=float)
    n = w.size
    if n == 0:
        raise ValueError('normalized_entropy() requires at least one weight')
    if n == 1:
        return 0.0
    w = np.clip(w, 1e-12, None)
    w = w / w.sum()
    return float(scipy_entropy(w) / np.log(n))


# Load every model/seed run once.
# 'final' = test-time weights if present, otherwise the highest saved epoch.
# It is taken from the arrays already loaded here. Calling get_final_weights()
# would read all of the same files from disk a second time.
all_data = {}
for mk in MODELS:
    all_data[mk] = {}
    for seed in SEEDS:
        epochs = load_all_epochs(mk, seed)
        test = load_test_weights(mk, seed)
        if test is not None:
            final = test
        elif epochs:
            final = epochs[max(epochs)]
        else:
            final = None
        all_data[mk][seed] = {'epochs': epochs, 'test': test, 'final': final}

# Summary: number of loaded epochs and whether test weights exist, per run
for mk in MODELS:
    for seed in SEEDS:
        run = all_data[mk][seed]
        n_ep = len(run['epochs'])
        has_test = run['test'] is not None
        print(f"{MODELS[mk]['label']:30s} seed={seed:3d}  epochs={n_ep:3d}  test={'yes' if has_test else 'no'}")

## 1. Entropie-Analyse

Normierte Entropie: 0 = die Attention konzentriert sich vollständig auf einen einzigen Zeitschritt, 1 = komplett uniform (die Attention unterscheidet nicht zwischen den Zeitschritten).

Wenn die Entropie nahe 1 liegt, gewichtet die Attention alle Zeitschritte gleich und liefert damit keinen Informationsgewinn gegenüber einem einfachen Mittelwert.

In [None]:
# Section 1: normalized entropy of the final attention weights per model.
# Defines `model_keys` and `entropies_per_model`, which later cells reuse.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

model_keys = list(MODELS.keys())
colors = [MODELS[mk]['color'] for mk in model_keys]
labels = [MODELS[mk]['label'] for mk in model_keys]

# model key -> normalized entropies, one per seed that has final weights
entropies_per_model = {}
for mk in model_keys:
    ents = []
    for seed in SEEDS:
        w = all_data[mk][seed]['final']
        if w is not None:
            ents.append(normalized_entropy(w))
    entropies_per_model[mk] = ents

# A model without any data gets NaN. Its bar is then simply not drawn, and
# np.mean/np.std are not called on an empty list.
means = [np.mean(entropies_per_model[mk]) if entropies_per_model[mk] else np.nan
         for mk in model_keys]
stds = [np.std(entropies_per_model[mk]) if entropies_per_model[mk] else np.nan
        for mk in model_keys]

# Left: bar chart of mean entropy per model (error bars = std over seeds)
ax = axes[0]
bars = ax.bar(labels, means, yerr=stds, color=colors, alpha=0.8, capsize=5, edgecolor='black', linewidth=0.5)
ax.axhline(y=1.0, color='gray', linestyle='--', linewidth=1, label='Uniform (max entropy)')
ax.set_ylabel('Normierte Entropie')
ax.set_title('Entropie der finalen Attention Weights')
ax.set_ylim(0, 1.1)
ax.legend()

# Value labels just above each error bar (skipped for models without data)
for bar, m, s in zip(bars, means, stds):
    if np.isnan(m):
        continue
    ax.text(bar.get_x() + bar.get_width() / 2, m + s + 0.02,
            f'{m:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

# Right: one point per seed, jittered horizontally and centred on the model's tick
ax = axes[1]
jitter_center = (len(SEEDS) - 1) / 2
for i, mk in enumerate(model_keys):
    for j, seed in enumerate(SEEDS):
        w = all_data[mk][seed]['final']
        if w is not None:
            ax.scatter(i + (j - jitter_center) * 0.08, normalized_entropy(w),
                       color=MODELS[mk]['color'], s=60, edgecolors='black',
                       linewidth=0.5, zorder=3)

ax.axhline(y=1.0, color='gray', linestyle='--', linewidth=1, label='Uniform')
ax.set_xticks(range(len(model_keys)))
ax.set_xticklabels(labels)
ax.set_ylabel('Normierte Entropie')
ax.set_title('Entropie pro Seed')
ax.set_ylim(0, 1.1)
ax.legend()

plt.tight_layout()
plt.show()

# Print table (np.min/np.max would raise on a model without data)
print(f"{'Modell':30s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s}")
print('-' * 64)
for mk in model_keys:
    ents = entropies_per_model[mk]
    if not ents:
        print(f"{MODELS[mk]['label']:30s} {'n/a':>8s}")
        continue
    print(f"{MODELS[mk]['label']:30s} {np.mean(ents):8.4f} {np.std(ents):8.4f} {np.min(ents):8.4f} {np.max(ents):8.4f}")

## 2. Temporales Profil

Wo im 50-Schritt-Window liegt der Attention-Fokus? Zeitschritt 0 = ältester, Zeitschritt 49 = aktuellster.

Hypothese: Wenn sich die Attention hauptsächlich auf die letzten Schritte fokussiert, ist sie redundant zum finalen LSTM-Hidden-State, der den jüngsten Zeitschritt ohnehin am stärksten repräsentiert.

In [None]:
# Section 2: temporal attention profile per model. Each panel shows every seed
# (thin lines), the mean over seeds (thick line) and a ±1 std band.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
last_step = N_TIMESTEPS - 1  # index of the most recent timestep

for idx, mk in enumerate(model_keys):
    ax = axes[idx // 2, idx % 2]
    color = MODELS[mk]['color']

    # Collect final weights across seeds
    all_weights = []
    for seed in SEEDS:
        w = all_data[mk][seed]['final']
        if w is not None:
            w_norm = w / w.sum()  # normalize to sum to 1
            all_weights.append(w_norm)
            ax.plot(range(N_TIMESTEPS), w_norm, alpha=0.3, color=color, linewidth=1)

    if all_weights:
        mean_w = np.mean(all_weights, axis=0)
        std_w = np.std(all_weights, axis=0)
        ax.plot(range(N_TIMESTEPS), mean_w, color=color, linewidth=2.5, label='Mean')
        ax.fill_between(range(N_TIMESTEPS), mean_w - std_w, mean_w + std_w,
                        alpha=0.2, color=color)

    # Uniform reference
    ax.axhline(y=1/N_TIMESTEPS, color='gray', linestyle='--', linewidth=1, label=f'Uniform (1/{N_TIMESTEPS})')

    ax.set_title(f"{MODELS[mk]['label']}")
    ax.set_xlabel(f'Zeitschritt (0=ältester, {last_step}=neuester)')
    ax.set_ylabel('Attention Weight')
    ax.legend(loc='upper left')
    ax.set_xlim(0, last_step)

plt.suptitle('Temporales Attention-Profil (finale Weights, alle Seeds)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Quantitative: share of attention mass in the most recent 5 / 10 / 20
# timesteps, plus the seed-averaged argmax position of the normalized weights.
header = f"{'Modell':30s} {'Last 5':>10s} {'Last 10':>10s} {'Last 20':>10s} {'Peak Pos':>10s}"
print(header)
print('-' * 75)

for mk in model_keys:
    tail_mass = {5: [], 10: [], 20: []}
    peak_steps = []
    for seed in SEEDS:
        final_w = all_data[mk][seed]['final']
        if final_w is None:
            continue
        probs = final_w / final_w.sum()
        for k, bucket in tail_mass.items():
            bucket.append(probs[-k:].sum())
        peak_steps.append(np.argmax(probs))

    cols = ''.join(f"{np.mean(tail_mass[k]) * 100:9.1f}% " for k in (5, 10, 20))
    print(f"{MODELS[mk]['label']:30s} " + cols + f"{np.mean(peak_steps):9.1f}")

## 3. Konvergenz über Epochen

Wie entwickelt sich die Attention über das Training? Konvergiert sie zu einem stabilen, informativen Pattern — oder kollabiert sie zu einer uniformen Verteilung?

In [None]:
# Section 3: normalized entropy of the attention weights after every saved
# epoch, one line per seed.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for idx, mk in enumerate(model_keys):
    ax = axes[idx // 2, idx % 2]

    for j, seed in enumerate(SEEDS):
        epochs_data = all_data[mk][seed]['epochs']
        if not epochs_data:
            continue
        ep_nums = sorted(epochs_data.keys())
        ent_vals = [normalized_entropy(epochs_data[e]) for e in ep_nums]
        # Colour is fixed by the seed's position in SEEDS ('C0', 'C1', ...).
        # A seed therefore keeps the same colour in every panel, even when
        # another seed has no data for this model.
        ax.plot(ep_nums, ent_vals, marker='.', markersize=3, alpha=0.7,
                color=f'C{j}', label=f'Seed {seed}', linewidth=1.2)

    ax.axhline(y=1.0, color='gray', linestyle='--', linewidth=1, label='Uniform')
    ax.set_title(f"{MODELS[mk]['label']}")
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Normierte Entropie')
    ax.set_ylim(0, 1.1)
    ax.legend(fontsize=8, ncol=2)

plt.suptitle('Entropie-Verlauf über Epochen (pro Seed)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Heatmaps: evolution of the attention profile over epochs (seed 42 as representative)
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
REPR_SEED = 42

for idx, mk in enumerate(model_keys):
    ax = axes[idx // 2, idx % 2]
    ax.set_title(f"{MODELS[mk]['label']} (seed {REPR_SEED})")
    epochs_data = all_data[mk][REPR_SEED]['epochs']

    if not epochs_data:
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
        continue

    ep_nums = sorted(epochs_data.keys())
    # One row per saved epoch, each row normalized to sum to 1
    matrix = np.array([epochs_data[e] / epochs_data[e].sum() for e in ep_nums])

    # imshow's `extent` sets the outer *edges* of the image. Padding by half a
    # cell centres column k on timestep k and each row on its epoch number. An
    # extent of [0, N-1] would instead squeeze N columns into N-1 units and
    # shift them off their timesteps. imshow spaces rows evenly, so the step
    # used is the mean spacing of the saved epochs. A single saved epoch uses
    # a step of 1, which avoids a zero-height image.
    n_rows = len(ep_nums)
    ep_step = (ep_nums[-1] - ep_nums[0]) / (n_rows - 1) if n_rows > 1 else 1.0
    extent = [-0.5, N_TIMESTEPS - 0.5,
              ep_nums[-1] + ep_step / 2, ep_nums[0] - ep_step / 2]

    im = ax.imshow(matrix, aspect='auto', cmap='hot', interpolation='nearest', extent=extent)
    ax.set_xlabel('Zeitschritt')
    ax.set_ylabel('Epoch')
    plt.colorbar(im, ax=ax, label='Weight', shrink=0.8)

plt.suptitle('Attention-Profil über Epochen (Heatmap)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Cross-Seed Konsistenz

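Wenn die Attention ein robustes Signal lernt, sollten die Patterns über verschiedene Seeds hinweg konsistent sein. Hohe Varianz = die Attention lernt keine stabile Repräsentation. Achtung: Bei nahezu uniformen Verteilungen liegt die Cosine Similarity automatisch nahe 1, weil der gemeinsame uniforme Anteil alle Vektoren dominiert. Hohe Werte sind daher nur zusammen mit niedriger Entropie aussagekräftig.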

In [None]:
# Cross-seed consistency: pairwise cosine similarity of the final, normalized
# attention weights between seeds.
#
# Caveat: for near-uniform distributions the raw cosine similarity is close to
# 1 however the small deviations are arranged, because the shared uniform
# component dominates every vector. The "centered" variant compares the
# deviations from uniform (w - 1/N) instead. Every normalized vector has mean
# 1/N, so this equals the Pearson correlation across timesteps. It is the more
# telling consistency measure when entropy is high.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

cosine_sims = {}           # model key -> raw pairwise cosine sims (upper triangle)
centered_cosine_sims = {}  # model key -> same, computed on deviations from uniform


def _pairwise_cosine(rows: np.ndarray) -> np.ndarray:
    """Return the cosine-similarity matrix between the rows of a 2-D array.

    Rows with zero norm yield NaN entries; no division warning is emitted.
    """
    norms = np.linalg.norm(rows, axis=1, keepdims=True)
    with np.errstate(invalid='ignore', divide='ignore'):
        unit = rows / norms
    return unit @ unit.T


for idx, mk in enumerate(model_keys):
    ax = axes[idx // 2, idx % 2]

    # Final weights of every seed with data, each normalized to sum to 1
    seeds_list = sorted(s for s in SEEDS if all_data[mk][s]['final'] is not None)
    if len(seeds_list) < 2:
        continue
    probs = np.array([all_data[mk][s]['final'] / all_data[mk][s]['final'].sum()
                      for s in seeds_list])
    n = len(seeds_list)

    sim_matrix = _pairwise_cosine(probs)
    np.fill_diagonal(sim_matrix, 1.0)

    # Store the off-diagonal pairs for the summary
    iu = np.triu_indices(n, k=1)
    upper_tri = sim_matrix[iu]
    cosine_sims[mk] = upper_tri
    centered_cosine_sims[mk] = _pairwise_cosine(probs - 1.0 / probs.shape[1])[iu]

    im = ax.imshow(sim_matrix, cmap='RdYlGn', vmin=0.5, vmax=1.0, interpolation='nearest')
    ax.set_xticks(range(n))
    ax.set_xticklabels([f'S{s}' for s in seeds_list])
    ax.set_yticks(range(n))
    ax.set_yticklabels([f'S{s}' for s in seeds_list])
    ax.set_title(f"{MODELS[mk]['label']}\nmean cos_sim={np.mean(upper_tri):.4f}")
    plt.colorbar(im, ax=ax, label='Cosine Similarity', shrink=0.8)

    # Annotate cells
    for i in range(n):
        for j in range(n):
            ax.text(j, i, f'{sim_matrix[i,j]:.3f}', ha='center', va='center', fontsize=9,
                    color='white' if sim_matrix[i,j] < 0.75 else 'black')

plt.suptitle('Cross-Seed Konsistenz (Cosine Similarity der finalen Weights)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Summary: raw cosine similarity plus the centered (Pearson) variant
print(f"\n{'Modell':30s} {'Mean CosSim':>12s} {'Std':>8s} {'Min':>8s} {'Centered':>10s}")
print('-' * 72)
for mk in model_keys:
    if mk in cosine_sims:
        sims = cosine_sims[mk]
        centered = centered_cosine_sims[mk]
        print(f"{MODELS[mk]['label']:30s} {np.mean(sims):12.4f} {np.std(sims):8.4f} "
              f"{np.min(sims):8.4f} {np.nanmean(centered):10.4f}")

In [None]:
# Per-timestep coefficient of variation (std / mean over seeds) of the final,
# normalized attention weights: high values mean the seeds disagree at that step.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for idx, mk in enumerate(model_keys):
    row, col = divmod(idx, 2)
    ax = axes[row, col]

    finals = [all_data[mk][seed]['final'] for seed in SEEDS]
    per_seed = [f / f.sum() for f in finals if f is not None]
    if len(per_seed) < 2:
        continue  # a spread needs at least two seeds

    stacked = np.array(per_seed)
    # 1e-12 guards against division by zero for steps with zero mean weight
    cv = stacked.std(axis=0) / (stacked.mean(axis=0) + 1e-12)

    tint = MODELS[mk]['color']
    ax.bar(range(N_TIMESTEPS), cv, color=tint, alpha=0.7, edgecolor='black', linewidth=0.3)
    ax.set_title(f"{MODELS[mk]['label']} — mean CV={np.mean(cv):.2f}")
    ax.set_xlabel('Zeitschritt')
    ax.set_ylabel('Variationskoeffizient (std/mean)')
    ax.set_xlim(-0.5, N_TIMESTEPS - 0.5)

plt.suptitle('Variabilität pro Zeitschritt über Seeds (Coefficient of Variation)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Modellvergleich

Direkter Vergleich der gemittelten Attention-Profile aller vier Modelle (drei Attention-Mechanismen: Simple, Additive, Scaled Dot-Product).

In [None]:
# Section 5: side-by-side comparison of all models.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
last_step = N_TIMESTEPS - 1  # index of the most recent timestep

# Left: overlay of the seed-averaged attention profiles (band = ±1 std over seeds)
ax = axes[0]
for mk in model_keys:
    weights = []
    for seed in SEEDS:
        w = all_data[mk][seed]['final']
        if w is not None:
            weights.append(w / w.sum())
    if weights:
        mean_w = np.mean(weights, axis=0)
        std_w = np.std(weights, axis=0)
        ax.plot(range(N_TIMESTEPS), mean_w, color=MODELS[mk]['color'],
                linewidth=2, label=MODELS[mk]['label'])
        ax.fill_between(range(N_TIMESTEPS), mean_w - std_w, mean_w + std_w,
                        alpha=0.1, color=MODELS[mk]['color'])

ax.axhline(y=1/N_TIMESTEPS, color='gray', linestyle='--', linewidth=1, label='Uniform')
ax.set_xlabel(f'Zeitschritt (0=ältester, {last_step}=neuester)')
ax.set_ylabel('Attention Weight')
ax.set_title(f'Mittleres Attention-Profil (über {len(SEEDS)} Seeds)')
ax.legend()
ax.set_xlim(0, last_step)

# Right: grouped bars of three summary metrics per model. NaN (= no bar)
# marks missing data. In particular, a model with < 2 seeds has no
# cross-seed similarity, and a bar at 0 would read as "completely
# dissimilar across seeds".
ax = axes[1]
metrics_data = []
metric_labels = ['Norm. Entropie', 'Gew. in letzten 10', 'Cross-Seed CosSim']

for mk in model_keys:
    ents = entropies_per_model[mk]

    last10_vals = []
    for seed in SEEDS:
        w = all_data[mk][seed]['final']
        if w is not None:
            w_norm = w / w.sum()
            last10_vals.append(w_norm[-10:].sum())

    metrics_data.append([
        np.mean(ents) if ents else np.nan,
        np.mean(last10_vals) if last10_vals else np.nan,
        np.mean(cosine_sims[mk]) if mk in cosine_sims else np.nan,
    ])

metrics_data = np.array(metrics_data)
x = np.arange(len(metric_labels))
width = 0.18
# Offset that centres each group of len(model_keys) bars on its tick
group_center = width * (len(model_keys) - 1) / 2

for i, mk in enumerate(model_keys):
    ax.bar(x + i * width, metrics_data[i], width, color=MODELS[mk]['color'],
           alpha=0.8, label=MODELS[mk]['label'], edgecolor='black', linewidth=0.5)

ax.set_xticks(x + group_center)
ax.set_xticklabels(metric_labels)
ax.set_ylabel('Wert')
ax.set_title('Zusammenfassung der Attention-Eigenschaften')
ax.legend()
ax.set_ylim(0, 1.1)

plt.tight_layout()
plt.show()

In [None]:
# KL divergence of each model's final attention from the uniform distribution
# over N_TIMESTEPS steps (larger = more focused), mean and std over seeds.
print('KL-Divergenz von Uniform-Verteilung (höher = fokussierter):')
print(f"{'Modell':30s} {'KL(attn||uniform)':>18s} {'Std':>8s}")
print('-' * 60)

uniform = np.full(N_TIMESTEPS, 1 / N_TIMESTEPS)

for mk in model_keys:
    available = [all_data[mk][seed]['final'] for seed in SEEDS]
    kl_vals = [
        # scipy's entropy(p, q) returns KL(p || q). p is clipped so that zero
        # weights do not produce log(0).
        scipy_entropy(np.clip(final / final.sum(), 1e-12, None), uniform)
        for final in available
        if final is not None
    ]
    print(f"{MODELS[mk]['label']:30s} {np.mean(kl_vals):18.6f} {np.std(kl_vals):8.6f}")

## 6. Zusammenfassung

In [None]:
print('=' * 80)
print('ZUSAMMENFASSUNG: Warum Attention nicht hilft')
print('=' * 80)

# Share of the last 10 steps under perfectly uniform attention (20% for 50 steps)
uniform_last10 = 10 / N_TIMESTEPS

for mk in model_keys:
    ents = entropies_per_model[mk]
    # NaN marks missing data: no seed with weights, or < 2 seeds for similarity
    ent_mean = np.mean(ents) if ents else float('nan')
    cos_mean = np.mean(cosine_sims[mk]) if mk in cosine_sims else float('nan')

    last10_vals = []
    peak_positions = []
    for seed in SEEDS:
        w = all_data[mk][seed]['final']
        if w is not None:
            w_norm = w / w.sum()
            last10_vals.append(w_norm[-10:].sum())
            # Plain int, so the list prints as [49, ...] rather than
            # [np.int64(49), ...] under NumPy >= 2.
            peak_positions.append(int(np.argmax(w_norm)))
    last10_mean = np.mean(last10_vals) if last10_vals else float('nan')

    print(f"\n--- {MODELS[mk]['label']} ---")
    print(f"  Entropie:         {ent_mean:.4f} (1.0 = uniform)")
    print(f"  Cross-Seed Sim:   {cos_mean:.4f}")
    print(f"  Gewicht Last 10:  {last10_mean*100:.1f}% (erwartbar uniform: {uniform_last10*100:.0f}%)")
    print(f"  Peak-Positionen:  {peak_positions}")

    # Diagnosis. Every comparison with NaN is False, so missing data must be
    # handled first. Otherwise it would fall through to the "Fokussiert" /
    # "Stabil" branches.
    if np.isnan(ent_mean):
        print("  >> DIAGNOSE: Keine Daten — Entropie nicht bestimmbar")
    elif ent_mean > 0.95:
        print("  >> DIAGNOSE: Nahezu uniform — Attention lernt keine Differenzierung")
    elif ent_mean > 0.85:
        print("  >> DIAGNOSE: Schwach fokussiert — minimaler Informationsgewinn")
    else:
        print("  >> DIAGNOSE: Fokussiert — Attention lernt ein Pattern")

    # Note: raw cosine similarity of near-uniform distributions is close to 1
    # by construction, so "Stabil" is only informative together with a low
    # entropy.
    if np.isnan(cos_mean):
        print("  >> DIAGNOSE: Weniger als 2 Seeds — Stabilität nicht bestimmbar")
    elif cos_mean < 0.9:
        print("  >> DIAGNOSE: Instabil über Seeds — kein robustes Signal")
    else:
        print("  >> DIAGNOSE: Stabil über Seeds")

# NOTE: the conclusion below is fixed text; it is not computed from the
# metrics printed above.
print(f"\n{'=' * 80}")
print('FAZIT:')
print('Die Attention-Mechanismen lernen keine informative Gewichtung der Zeitschritte.')
print('Mögliche Erklärungen:')
print('  1. Das LSTM kodiert die relevante temporale Information bereits im Hidden State')
print('     -> Attention über Hidden States ist redundant')
print('  2. Die Aufgabe (Steering Torque) hängt primär vom aktuellen Zustand ab,')
print('     nicht von komplexen zeitlichen Mustern -> kein Attention-Vorteil')
print('  3. Bei 50 Zeitschritten @ 10Hz (5s Window) gibt es wenig langreichweitige')
print('     Abhängigkeiten, die Attention besser als LSTM erfassen könnte')
print(f"{'=' * 80}")