In [None]:
# Path of the training result file read by the loading cell below.
# NOTE: despite its name, DIR holds a *file* path, and it is specific to one
# machine; adjust it before running the notebook elsewhere.
DIR = (
    "/Users/philipp/Documents/Studium/Informatik/Semester_3/RL/IntrinsicAttention"
    "/data/Umbrella"
    "/Umbrella_intrinsic_2025-08-11_11-58-04"
    "/IntrinsicAttentionPPO_seed42_length50"
    "/IntrinsicAttentionPPO_Umbrella_acd0b_00000_0_2025-08-11_11-58-05"
    "/result.json"
)

In [None]:
import json

import torch

# Read the result file as JSON lines: every non-blank line is parsed as one
# JSON object and collected in file order.
data = []
with open(DIR, 'r', encoding='utf-8') as f:
    data.extend(json.loads(record) for record in map(str.strip, f) if record)

# masks has shape: [Total Batches, Minibatches per Batch, Intrinsic Reward Output, Source Input]
# One tensor is built per parsed row and the rows are stacked along a new
# leading axis.
masks = torch.stack([
    torch.Tensor(row["learners"]["IntrinsicAttentionMask"]) for row in data
])

In [None]:
# Pick index -1 on axis 1 and index 0 on axis 2 of `masks`, then swap the two
# remaining axes so the former last axis becomes the rows. With the axis order
# noted in the loading cell this is (source input step, training batch).
used_mask = masks[:, -1, 0, :].t()

# File stem and figure title shared by the plotting cells below.
save_name = "attention-first_reward"
title = "First Intr. Rew. Attention on Episode Timesteps"

In [None]:
import torch


def equalized_to_onehot_with_yfade(
        inp: torch.Tensor,
        cutoff: float = 0.75,  # after this x-fraction: exactly one-hot per column
        hot_row: str | int = "bottom",  # "bottom", "top", or an integer row index
        p: float = 2.0,  # column L^p norm target (e.g., 2.0=L2, 1.0=L1)
        sharpness: float = 1.6,  # >1 = steeper x-fade (smoothstep exponent)
        sigma_min: float = 0.25,  # minimal vertical spread (rows) just before cutoff
        sigma_max: float | None = None,  # maximal vertical spread near x=0; None -> 0.5*M
        alpha: float = 2.0,  # 2.0=Gaussian, 1.0=Laplacian-like tails
) -> torch.Tensor:
    """
    Build an (M x N) matrix:
      • At x=0: equalized columns (row-invariant), column ‖·‖_p = 1.
      • 0 < x < cutoff: blend toward a vertical 'bump' centered at hot_row,
        with sigma shrinking as x increases (y-fade).
      • x >= cutoff: exactly one-hot at hot_row (others = 0).
      • Every column is normalized to L^p norm = 1.

    Notes
    -----
    - y index 0 is the top row; 'bottom' means row M-1.
    - sigma_* are in **row units** (e.g., 1.0 ~ one row).
    """
    assert inp.ndim == 2, "inp must be 2D (M x N)"
    M, N = inp.shape
    device = inp.device
    f32 = torch.float32

    # hot row index
    if isinstance(hot_row, str):
        row0 = 0 if hot_row.lower().startswith("top") else (M - 1)
    else:
        row0 = int(hot_row)
        if not (0 <= row0 < M):
            raise ValueError(f"hot_row index {row0} out of range [0, {M - 1}]")

    if sigma_max is None:
        sigma_max = max(0.5 * M, 4.0)  # large spread ≈ 'flat' early on

    # equalized column so that L^p norm = 1 → each entry = M^{-1/p}
    v = float(M) ** (-1.0 / p)
    equal_col = torch.full((M, N), v, device=device, dtype=f32)

    # one-hot template (norm 1 for any p)
    onehot = torch.zeros((M, N), device=device, dtype=f32)
    onehot[row0, :] = 1.0

    # x progression 0..1
    x = torch.linspace(0, 1, N, device=device, dtype=f32)

    # smooth progress s(x) from 0→1 over [0, cutoff], then clamp
    t = (x / max(1e-8, cutoff)).clamp(0, 1)
    t = t.pow(sharpness)
    s = t * t * (3 - 2 * t)  # smoothstep

    # sigma(x): large → small as x increases (y-fade width shrinks)
    sigma_x = (1 - s) * sigma_max + s * sigma_min  # shape (N,)

    # vertical distances to hot row
    y_idx = torch.arange(M, device=device, dtype=f32).unsqueeze(1)  # (M,1)
    d = (y_idx - float(row0)).abs()  # (M,1)

    # column-wise vertical bumps (Gauss/Laplace-like with alpha)
    # broadcast sigma_x to (1,N)
    denom = sigma_x.clamp_min(1e-6).unsqueeze(0)  # (1,N)
    bump = torch.exp(-0.5 * (d / denom) ** alpha)  # (M,N)

    # Blend equalized ↔ bump for x < cutoff
    mat = (1 - s).unsqueeze(0) * equal_col + s.unsqueeze(0) * bump

    # Hard switch to exact one-hot for x >= cutoff
    hard_mask = (x >= cutoff).unsqueeze(0)  # (1,N)
    mat = torch.where(hard_mask, onehot, mat)

    # Column-wise L^p normalization to exactly 1
    if abs(p - 1.0) < 1e-12:
        norms = mat.abs().sum(dim=0, keepdim=True)
    else:
        norms = mat.abs().pow(p).sum(dim=0, keepdim=True).pow(1.0 / p)
    mat = mat / norms.clamp_min(1e-12)

    return mat.to(dtype=inp.dtype, device=device)

In [None]:
# LUH-style attention heatmap (matplotlib-only, research aesthetic)
# - API compatible with your neon_attention_heatmap (same args & returns).
# - Defaults: white background, LUH blue sequential colormap, subtle grid.
# - No neon/gloss; nice sans-serif font with sensible fallbacks.
#
# Usage (same as before)
# -----
# fig, paths = luh_attention_heatmap(
#     matrix,
#     title="Head 3 — Layer 8", xlabel="Key position →", ylabel="Query position →",
#     tokens_x=[...], tokens_y=[...],
#     figsize=(10, 8), dpi=220,
#     gamma=0.65, clip_percentile=(0, 100),
#     render_mode="auto",
#     edge_gloss=True,              # now means a subtle light-grey grid
#     gloss_strength=0.55,          # grid visibility (0..1), not glossy lines
#     annotate=False,
#     save_path="/mnt/data/attention_luh.png"
# )
#
from typing import Optional, Sequence, Tuple, Union
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, PowerNorm, Normalize
from matplotlib.ticker import MaxNLocator
import os, math

# torch is treated as optional by the plotting helpers: remember whether the
# import succeeded so tensor inputs can be detected without requiring torch.
try:
    import torch
except Exception:
    _HAS_TORCH = False
else:
    _HAS_TORCH = True

# Accepted matrix inputs. The torch type is given as a string so the alias can
# be built even when torch could not be imported.
ArrayLike = Union[np.ndarray, "torch.Tensor"]


def _to_numpy_2d(a: ArrayLike) -> np.ndarray:
    """Return ``a`` as a 2D NumPy array of Python-float dtype.

    Torch tensors (when torch is importable) are detached, moved to the CPU
    and converted to float before the NumPy conversion; anything else goes
    straight through ``np.asarray``.

    Raises:
        ValueError: if the converted array is not two-dimensional.
    """
    is_tensor = _HAS_TORCH and isinstance(a, torch.Tensor)
    source = a.detach().cpu().float().numpy() if is_tensor else a
    arr = np.asarray(source, dtype=float)
    if arr.ndim == 2:
        return arr
    raise ValueError(f"Expected 2D matrix, got shape {arr.shape}.")


def _luh_cmap(kind: str = "seq") -> LinearSegmentedColormap:
    """
    Build a colormap from the LUH corporate palette:
      - Uniblau  = #00509B (primary)
      - Unigrün  = #C8D317 (accent)
    Source: LUH Corporate Identity – Farben.

    ``kind == "div"`` gives the diverging map; every other value gives the
    sequential one.
    """
    # Blue → white → green (for data centered around a midpoint)
    diverging = (
        "luh_blue_white_green",
        [(0.00, "#00509B"), (0.50, "#FFFFFF"), (1.00, "#C8D317")],
    )
    # Sequential: white → Uniblau (research-friendly, high legibility)
    sequential = (
        "luh_white_to_uniblau",
        [(0.00, "#FFFFFF"), (1.00, "#00509B")],
    )
    name, stops = diverging if kind == "div" else sequential
    return LinearSegmentedColormap.from_list(name, stops)


def _resolve_norm(a: np.ndarray, gamma: Optional[float], vmin: Optional[float], vmax: Optional[float], clip_percentile):
    """Pick colour limits and the matching matplotlib normalisation.

    Args:
        a: data array the limits are derived from.
        gamma: exponent for ``PowerNorm``; ``None``, non-positive values or
            exactly 1.0 select a linear ``Normalize`` instead.
        vmin, vmax: explicit limits; each one that is ``None`` is derived
            from ``a``.
        clip_percentile: ``(lo, hi)`` percentiles used to derive missing
            limits, or ``None`` to use the minimum / maximum of ``a``.

    Returns:
        ``(norm, vmin, vmax)`` with the resolved limits.

    NaN entries are ignored when deriving limits. Non-finite limits fall back
    to 0.0 / 1.0, and equal limits are separated by 1e-9.
    """
    if clip_percentile is not None:
        lo, hi = clip_percentile
        # nanpercentile (not percentile): a single NaN cell would otherwise
        # turn both limits into NaN and silently force the 0..1 fallback,
        # unlike the nanmin/nanmax branch below.
        vmin = np.nanpercentile(a, lo) if vmin is None else vmin
        vmax = np.nanpercentile(a, hi) if vmax is None else vmax
    else:
        vmin = np.nanmin(a) if vmin is None else vmin
        vmax = np.nanmax(a) if vmax is None else vmax
    if not np.isfinite(vmin): vmin = 0.0
    if not np.isfinite(vmax): vmax = 1.0
    if vmin == vmax: vmax = vmin + 1e-9
    if gamma is not None and gamma > 0 and gamma != 1.0:
        norm = PowerNorm(gamma=gamma, vmin=vmin, vmax=vmax, clip=True)
    else:
        norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
    return norm, vmin, vmax


def _token_tick_indices(count: int, tick_max: int) -> list:
    """Return evenly strided indices in [0, count) that always end on count - 1."""
    if count <= 0:
        return []
    stride = max(1, int(math.ceil(count / max(1, tick_max))))
    picked = list(range(0, count, stride))
    if picked[-1] != count - 1:
        picked.append(count - 1)
    return picked


def luh_attention_heatmap(
        matrix: ArrayLike,
        title: Optional[str] = "Attention Heatmap",
        xlabel: str = "Key position →",
        ylabel: str = "Query position →",
        tokens_x: Optional[Sequence[str]] = None,  # len == N
        tokens_y: Optional[Sequence[str]] = None,  # len == M
        figsize: Tuple[float, float] = (10, 8),
        dpi: int = 220,
        cmap: Optional[LinearSegmentedColormap] = None,  # default set to LUH below
        gamma: float = 0.65,  # keep same default as original for identical value mapping
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        clip_percentile: Optional[Tuple[float, float]] = (0.0, 100.0),
        render_mode: str = "auto",  # "auto" | "imshow" | "pcolormesh"
        edge_gloss: bool = True,  # now: subtle light-grey grid (no gloss)
        gloss_strength: float = 0.55,
        annotate: bool = False,
        save_path: Optional[str] = None,
        tick_max: int = 32,
        tick_labelsize: int = 8,
        rotate_xticks: int = 90,
        background: str = "#FFFFFF",  # white canvas (research style)
        diverging: bool = False,  # set True for blue↔green around mid
):
    """
    Draw a LUH-themed heatmap on a white canvas.
    Works for any 2D matrix shape (M x N).

    Row 0 is drawn at the top and every cell is centred on its integer
    (column, row) index in both render modes, so ticks, token labels and
    annotations line up identically whichever mode "auto" selects.

    Parameters
    ----------
    matrix : 2D array or tensor of shape (M, N); rows run along y, columns along x.
    title, xlabel, ylabel : figure texts; a falsy ``title`` draws no title.
    tokens_x, tokens_y : optional tick labels of length N / M. A list of any
        other length is ignored (with a warning) and numeric ticks are used.
    figsize, dpi : figure size in inches and resolution.
    cmap : colormap; ``None`` selects the LUH sequential map, or the
        diverging one when ``diverging`` is true.
    gamma, vmin, vmax, clip_percentile : value-to-colour mapping, resolved by
        ``_resolve_norm``.
    render_mode : "imshow", "pcolormesh" or "auto" (pcolormesh for at most
        20000 cells, imshow above). Any other value is drawn with pcolormesh.
    edge_gloss, gloss_strength : draw a light-grey cell grid / its opacity.
    annotate : print each finite value in its cell (only for <= 900 cells).
    save_path : when set, the figure is written there and additionally as a
        PDF next to it; a missing output directory is created.
    tick_max : upper bound used to thin out token tick labels.
    tick_labelsize, rotate_xticks, background : styling.

    Returns
    -------
    (fig, paths) : the matplotlib figure and either ``None`` or the tuple
        ``(save_path, pdf_path)``.
    """
    import warnings  # function-scope import: only needed for the token-length warning

    a = _to_numpy_2d(matrix)
    M, N = a.shape

    # Use LUH palette unless a custom cmap is supplied
    if cmap is None:
        cmap = _luh_cmap("div" if diverging else "seq")

    norm, vmin, vmax = _resolve_norm(a, gamma, vmin, vmax, clip_percentile)

    # Font & styling (contained to this function)
    rc = {
        "font.family": ["DejaVu Sans", "Arial", "Helvetica"],
        "axes.titleweight": "semibold",
        "axes.titlesize": 14,
        "axes.labelsize": 11,
        "xtick.labelsize": tick_labelsize,
        "ytick.labelsize": tick_labelsize,
    }

    with plt.rc_context(rc):
        fig = plt.figure(figsize=figsize, dpi=dpi)
        fig.patch.set_facecolor(background)
        # Explicit axes on this figure instead of relying on pyplot's "current axes".
        ax = fig.add_subplot(111)
        ax.set_facecolor(background)

        total_cells = M * N
        if render_mode == "auto":
            render_mode = "pcolormesh" if total_cells <= 20000 else "imshow"

        grid_color = (0.85, 0.85, 0.85, min(0.8, 0.25 + 0.6 * gloss_strength))  # subtle neutral grey

        # Main image
        if render_mode == "imshow":
            im = ax.imshow(a, origin="upper", aspect="equal", interpolation="nearest",
                           cmap=cmap, norm=norm)
            if edge_gloss:
                # Minor ticks on the cell borders carry the grid lines.
                ax.set_xticks(np.arange(-0.5, N, 1), minor=True)
                ax.set_yticks(np.arange(-0.5, M, 1), minor=True)
                ax.grid(which="minor", linestyle="-", linewidth=0.6, color=grid_color)
                ax.tick_params(axis="both", which="both", length=0)
        else:
            # Cell edges at k - 0.5 put every cell centre on its integer index,
            # matching imshow's coordinate convention.
            x_edges = np.arange(N + 1) - 0.5
            y_edges = np.arange(M + 1) - 0.5
            if edge_gloss:
                mesh_style = {"edgecolors": grid_color, "linewidth": 0.6}
            else:
                mesh_style = {"edgecolors": "none", "linewidth": 0.0}
            im = ax.pcolormesh(x_edges, y_edges, a, cmap=cmap, norm=norm,
                               shading="flat", **mesh_style)
            # Row 0 at the top, like imshow(origin="upper"); without this the
            # picture would flip vertically depending on the chosen mode.
            ax.invert_yaxis()

        # Labels & title (dark text for legibility)
        if title:
            ax.set_title(title, color="#111111", pad=12)
        ax.set_xlabel(xlabel, color="#111111", labelpad=6)
        ax.set_ylabel(ylabel, color="#111111", labelpad=6)

        # Tick labeling
        ax.tick_params(axis="both", labelcolor="#111111", length=0)
        # X ticks
        if tokens_x is not None and len(tokens_x) == N:
            idx = _token_tick_indices(N, tick_max)
            ax.set_xticks(idx)
            ax.set_xticklabels([str(tokens_x[i]) for i in idx], rotation=rotate_xticks, ha="center", va="top")
        else:
            if tokens_x is not None:
                warnings.warn(
                    f"tokens_x has {len(tokens_x)} labels but the matrix has {N} columns; "
                    "using numeric ticks instead.",
                    stacklevel=2,
                )
            ax.xaxis.set_major_locator(MaxNLocator(nbins=min(tick_max, 12), integer=True))
        # Y ticks
        if tokens_y is not None and len(tokens_y) == M:
            idy = _token_tick_indices(M, tick_max)
            ax.set_yticks(idy)
            ax.set_yticklabels([str(tokens_y[i]) for i in idy])
        else:
            if tokens_y is not None:
                warnings.warn(
                    f"tokens_y has {len(tokens_y)} labels but the matrix has {M} rows; "
                    "using numeric ticks instead.",
                    stacklevel=2,
                )
            ax.yaxis.set_major_locator(MaxNLocator(nbins=min(tick_max, 12), integer=True))

        # Remove spines for a clean research look
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Colorbar (neutral text & ticks)
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.outline.set_visible(False)
        cbar.ax.tick_params(labelsize=8, colors="#111111")
        cbar.ax.yaxis.set_tick_params(color="#111111")
        cbar.ax.set_ylabel("Intensity", rotation=270, labelpad=14, color="#111111", fontsize=9)

        # Optional per-cell annotations (small matrices); cell centres are the
        # integer indices in both render modes.
        if annotate and total_cells <= 900:
            for i in range(M):
                for j in range(N):
                    val = a[i, j]
                    if np.isfinite(val):
                        ax.text(j, i, f"{val:.2f}", ha="center", va="center",
                                fontsize=6, color="#111111", alpha=0.85)

        fig.tight_layout()

        paths = None
        if save_path:
            # savefig does not create directories; make sure the target exists.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight", facecolor=fig.get_facecolor(), dpi=max(dpi, 300))
            root, _ = os.path.splitext(save_path)
            pdf_path = root + ".pdf"
            fig.savefig(pdf_path, bbox_inches="tight", facecolor=fig.get_facecolor())
            paths = (save_path, pdf_path)

        return fig, paths

In [None]:
used_time_step = 49  # 0 indexed, 50 is the last "Add one T" Step
# The heatmap expects len(tokens_x) == number of columns (shape[1]) and
# len(tokens_y) == number of rows (shape[0]); with the indices swapped the
# labels are dropped for any non-square matrix.
fig, paths = luh_attention_heatmap(
    used_mask,
    title=f"{title} • Umbrella Env",
    xlabel="Training Rounds",
    ylabel="Episode Input Step",
    tokens_x=[f"k{j}" for j in range(used_mask.shape[1])],
    tokens_y=[f"q{i}" for i in range(used_mask.shape[0])],
    figsize=(5, 5),
    dpi=240,
    gamma=0.55,
    clip_percentile=(0.5, 99.5),
    render_mode="auto",
    edge_gloss=True,
    gloss_strength=0.7,
    annotate=False,
    save_path=f"../images/{save_name}.png",
    tick_max=36,
    rotate_xticks=90
)

In [None]:


# Reference pattern with the same shape as used_mask: equalized columns on the
# left that fade into a one-hot column on the chosen row.
template = equalized_to_onehot_with_yfade(
    used_mask,
    cutoff=0.85,  # columns are exactly one-hot from 85% of the x range on (last ~15%)
    hot_row="bottom",  # hot row is the last row index (M-1)
    p=2.0,  # L2 per-column norm
    sharpness=1.6,
    sigma_min=0.2,  # ~sub-row width as it approaches one-hot
    sigma_max=None,  # auto based on M
    alpha=2.0  # Gaussian; try 1.0 for heavier tails
)

# The heatmap expects len(tokens_x) == number of columns (shape[1]) and
# len(tokens_y) == number of rows (shape[0]); with the indices swapped the
# labels are dropped for any non-square matrix.
fig, paths = luh_attention_heatmap(
    template,
    title=f"Optimal • {title} • Umbrella Env",
    xlabel="Training Rounds",
    ylabel="Episode Input Step",
    tokens_x=[f"k{j}" for j in range(template.shape[1])],
    tokens_y=[f"q{i}" for i in range(template.shape[0])],
    figsize=(5, 5),
    dpi=240,
    gamma=0.55,
    clip_percentile=(0.5, 99.5),
    render_mode="auto",
    edge_gloss=True,
    gloss_strength=0.7,
    annotate=False,
    save_path=f"../images/OPTIMAL_{save_name}.png",
    tick_max=36,
    rotate_xticks=90
)