In [None]:
# ============================
# Cell 1 — Acrobot ONLY (Custom Continuous Acrobot Env)
#   - Continuous torque action u in [-1, +1]
#   - Observation: [cos(th1), sin(th1), cos(th2), sin(th2), thdot1, thdot2]
#   - State (for modeling): (th1, th2, thdot1, thdot2)
#   - Reward: -1 per step until terminal (Acrobot spirit)
#   - Terminal: (-cos(th1) - cos(th1+th2)) > 1.0
#   - TimeLimit wrapper supported via make_env(...)
# ============================

import os
import math
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers import TimeLimit
from gymnasium.utils import seeding

# ----------------------------
# TF GPU setup (sanity)
# ----------------------------
# Silence TensorFlow's C++ INFO/WARNING logs; must be set before TF is imported.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
print("TF built with CUDA:", tf.test.is_built_with_cuda())
print("GPUs visible:", gpus)

# Memory growth has to be configured BEFORE the GPUs are initialized.
# Listing logical devices initializes them, so that query is done last.
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (RuntimeError, ValueError) as exc:
        # Best effort: e.g. the cell was re-run after the runtime was already
        # initialized. Report it instead of silently ignoring the failure.
        print(f"WARNING: could not enable memory growth on {gpu}: {exc}")

print("Logical GPUs:", tf.config.list_logical_devices("GPU"))

# ----------------------------
# Dtypes
# ----------------------------
# NumPy dtype used for the state / feature arrays built by the helpers below.
DTYPE_NP = np.float64
# Lower / upper bound of the normalized torque command u.
U_MIN, U_MAX = -1.0, +1.0

# ----------------------------
# Helpers
# ----------------------------
def wrap_pi(x):
    """Wrap an angle (scalar or array, radians) into the interval [-pi, pi)."""
    full_turn = 2.0 * np.pi
    shifted = x + np.pi
    return shifted % full_turn - np.pi

def obs_to_state(obs):
    """
    Recover the angular state from an Acrobot observation.

    obs is expected to hold six values:
      [cos(th1), sin(th1), cos(th2), sin(th2), thdot1, thdot2]
    Returns a length-4 array (th1, th2, thdot1, thdot2) of dtype DTYPE_NP,
    with both angles passed through wrap_pi.
    """
    cos1, sin1, cos2, sin2, vel1, vel2 = (float(v) for v in obs)
    angle1 = wrap_pi(math.atan2(sin1, cos1))
    angle2 = wrap_pi(math.atan2(sin2, cos2))
    return np.array([angle1, angle2, vel1, vel2], dtype=DTYPE_NP)

def state_to_features(th1, th2, thdot1, thdot2, u,
                      w1_scale=8.0, w2_scale=10.0,
                      dtype=DTYPE_NP):
    """
    Map one (state, action) pair to the bounded 7-D feature vector:
      [sin(th1), cos(th1), sin(th2), cos(th2),
       tanh(thdot1/w1_scale), tanh(thdot2/w2_scale), u]

    The velocities are squashed with tanh so every state feature lies in [-1, 1].
    """
    angle_terms = [np.sin(th1), np.cos(th1), np.sin(th2), np.cos(th2)]
    rate_terms = [np.tanh(thdot1 / w1_scale), np.tanh(thdot2 / w2_scale)]
    return np.array(angle_terms + rate_terms + [float(u)], dtype=dtype)

# ----------------------------
# Custom Continuous Acrobot Env
# ----------------------------
class ContinuousAcrobotEnv(gym.Env):
    """
    Two-link Acrobot with a continuous torque command at the second joint.

    Action:      Box(1,) with u in [-1, +1]; the applied torque is u * torque_mag.
    Observation: [cos(th1), sin(th1), cos(th2), sin(th2), thdot1, thdot2] (float32).
    State:       (th1, th2, thdot1, thdot2), angles wrapped with wrap_pi.
    Reward:      -1 per step, 0 on the step that reaches the goal.
    Terminal:    tip height proxy -cos(th1) - cos(th1 + th2) > 1.0.

    Truncation is never signalled by the env itself; wrap it in TimeLimit
    (see make_env) to bound the episode length.
    """

    metadata = {"render_modes": ["rgb_array", "human"], "render_fps": 15}

    def __init__(self, render_mode=None, torque_mag=1.0, dt=0.2,
                 max_vel1=4*np.pi, max_vel2=9*np.pi):
        """
        Args:
            render_mode: stored on the instance; render() currently returns None.
            torque_mag:  scale applied to the normalized action to get the torque.
            dt:          integration step used in step().
            max_vel1:    clamp for |thdot1|.
            max_vel2:    clamp for |thdot2|.
        """
        super().__init__()
        self.render_mode = render_mode

        # Standard-ish Acrobot params (close to classic-control)
        self.LINK_LENGTH_1 = 1.0
        self.LINK_LENGTH_2 = 1.0
        self.LINK_MASS_1 = 1.0
        self.LINK_MASS_2 = 1.0
        self.LINK_COM_POS_1 = 0.5
        self.LINK_COM_POS_2 = 0.5
        self.LINK_MOI = 1.0
        self.g = 9.8

        self.dt = float(dt)
        self.torque_mag = float(torque_mag)
        self.max_vel1 = float(max_vel1)
        self.max_vel2 = float(max_vel2)

        # Continuous action u in [-1,1], scaled by torque_mag
        self.action_space = spaces.Box(
            low=np.array([-1.0], dtype=np.float32),
            high=np.array([+1.0], dtype=np.float32),
            shape=(1,),
            dtype=np.float32
        )

        # Observation: [cos(th1), sin(th1), cos(th2), sin(th2), thdot1, thdot2]
        high = np.array([1.0, 1.0, 1.0, 1.0, self.max_vel1, self.max_vel2], dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.np_random = None
        self.state = None  # (th1, th2, thdot1, thdot2)
        self.seed()

    def seed(self, seed=None):
        """(Re)create the env RNG from `seed` and return the seed actually used, in a list."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _terminal(self, th1, th2):
        """Return True when the tip height proxy exceeds 1.0 (goal reached)."""
        # Same idea as classic Acrobot-v1: tip above threshold
        # y = -cos(th1) - cos(th1+th2); terminal if y > 1
        return (-math.cos(th1) - math.cos(th1 + th2)) > 1.0

    def _get_obs(self):
        """Build the float32 observation vector from the current state."""
        th1, th2, thdot1, thdot2 = self.state
        return np.array([
            math.cos(th1), math.sin(th1),
            math.cos(th2), math.sin(th2),
            thdot1, thdot2
        ], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """
        Start a new episode near the hanging-down configuration.

        All four state components are drawn uniformly from [-0.1, 0.1].
        `options` is accepted for API compatibility and not used.
        Returns (observation, info) with an empty info dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.seed(seed)

        # Small random init near downward-ish
        th1 = float(self.np_random.uniform(low=-0.1, high=0.1))
        th2 = float(self.np_random.uniform(low=-0.1, high=0.1))
        thdot1 = float(self.np_random.uniform(low=-0.1, high=0.1))
        thdot2 = float(self.np_random.uniform(low=-0.1, high=0.1))

        self.state = (wrap_pi(th1), wrap_pi(th2), thdot1, thdot2)
        return self._get_obs(), {}

    def step(self, action):
        """
        Advance the dynamics by one step of length dt.

        Args:
            action: scalar or length-1 array with the normalized torque in [-1, 1].

        Returns:
            (obs, reward, terminated, truncated, info) where `terminated` is True
            once the tip height proxy exceeds 1.0, `truncated` is always False
            here, and `info` carries the applied torque and the new state.
        """
        action = np.asarray(action, dtype=np.float32).reshape(1,)
        assert self.action_space.contains(action), f"{action} invalid"

        u = float(np.clip(action[0], -1.0, 1.0)) * self.torque_mag

        th1, th2, thdot1, thdot2 = self.state

        # --- Dynamics (standard acrobot equations; torque applied at joint 2) ---
        m1 = self.LINK_MASS_1
        m2 = self.LINK_MASS_2
        l1 = self.LINK_LENGTH_1
        l2 = self.LINK_LENGTH_2
        lc1 = self.LINK_COM_POS_1
        lc2 = self.LINK_COM_POS_2
        I1 = self.LINK_MOI
        I2 = self.LINK_MOI
        g = self.g

        d1 = m1*lc1**2 + m2*(l1**2 + lc2**2 + 2*l1*lc2*math.cos(th2)) + I1 + I2
        d2 = m2*(lc2**2 + l1*lc2*math.cos(th2)) + I2

        phi2 = m2*lc2*g*math.cos(th1 + th2 - math.pi/2.0)
        phi1 = (-m2*l1*lc2*(thdot2**2)*math.sin(th2)
                - 2*m2*l1*lc2*thdot2*thdot1*math.sin(th2)
                + (m1*lc1 + m2*l1)*g*math.cos(th1 - math.pi/2.0)
                + phi2)

        # Guard against division by a vanishing denominator.
        denom = (m2*lc2**2 + I2 - (d2**2)/d1)
        if abs(denom) < 1e-10:
            denom = 1e-10

        thddot2 = (u + (d2/d1)*phi1 - m2*l1*lc2*(thdot1**2)*math.sin(th2) - phi2) / denom
        thddot1 = -(d2*thddot2 + phi1) / d1

        # Integrate velocities first, then use the updated (clamped) velocities
        # to advance the angles.
        thdot1 = thdot1 + self.dt * thddot1
        thdot2 = thdot2 + self.dt * thddot2

        # Clamp velocities
        thdot1 = float(np.clip(thdot1, -self.max_vel1, self.max_vel1))
        thdot2 = float(np.clip(thdot2, -self.max_vel2, self.max_vel2))

        th1 = wrap_pi(th1 + self.dt * thdot1)
        th2 = wrap_pi(th2 + self.dt * thdot2)

        self.state = (th1, th2, thdot1, thdot2)

        # Goal check on the NEW state (previously hard-coded to False, which
        # made _terminal() dead code and the 0-reward branch unreachable).
        terminated = bool(self._terminal(th1, th2))
        truncated = False  # TimeLimit wrapper will handle truncation

        # Reward: -1 per step until success; 0 at terminal (common Acrobot convention)
        reward = 0.0 if terminated else -1.0

        obs = self._get_obs()
        info = {"u": float(u), "th1": th1, "th2": th2, "thdot1": thdot1, "thdot2": thdot2}
        return obs, float(reward), terminated, truncated, info

    def render(self):
        """No built-in rendering; always returns None."""
        return None

    def close(self):
        """Nothing to release; always returns None."""
        return None

# ----------------------------
# Factory (keep the same name make_env for later cells)
# ----------------------------
MAX_EPISODE_STEPS = 2000

def make_env(render_mode=None, seed=0, max_episode_steps=MAX_EPISODE_STEPS,
             torque_mag=1.0, dt=0.2):
    """
    Create a ContinuousAcrobotEnv wrapped in TimeLimit and reset it once with `seed`.

    The returned object is the TimeLimit wrapper, which truncates episodes
    after `max_episode_steps` steps.
    """
    base_env = ContinuousAcrobotEnv(
        render_mode=render_mode,
        torque_mag=torque_mag,
        dt=dt,
    )
    limited_env = TimeLimit(base_env, max_episode_steps=max_episode_steps)
    limited_env.reset(seed=seed)
    return limited_env

# ----------------------------
# Quick sanity test
# ----------------------------
# Smoke test: build the env, reset it with a fixed seed, and convert the first
# observation back to the (th1, th2, thdot1, thdot2) state.
env = make_env(render_mode=None, seed=0)
obs, _ = env.reset(seed=0)
s = obs_to_state(obs)

print("✅ ContinuousAcrobot ready")
print("obs shape:", obs.shape, "obs:", obs)
print("state (th1, th2, thdot1, thdot2):", s)
print("action_space:", env.action_space, "sample:", env.action_space.sample())

# step a few times
# NOTE: action_space.sample() is not explicitly seeded in this cell, so the
# sampled actions are not tied to the env seed above.
for i in range(3):
    a = env.action_space.sample()
    obs2, r, term, trunc, info = env.step(a)
    print(f"step {i}: r={r}, term={term}, info_u={info['u']:.3f}")

env.close()


In [None]:
# ============================
# Cell 2 — Random collection path + render + collect (X,Y)  ✅ Acrobot version
#
# Standard Acrobot conventions (as in the classic-control task):
#   - obs = [cos(th1), sin(th1), cos(th2), sin(th2), thdot1, thdot2]
#   - success proxy height: y_tip = -cos(th1) - cos(th1+th2)
#   - terminal when y_tip > 1.0 (env already uses this)
#
# What you get:
#   - Random actions for n_steps
#   - Collect executed transitions:
#       X0: (N,7) = [sin th1, cos th1, sin th2, cos th2, tanh(w1/s1), tanh(w2/s2), u]
#       Ydth1, Ydth2, Ydw1, Ydw2  (each (N,1))
#   - Simple stick-figure rendering (pure PIL) + inline animation
#   - Plots:
#       y_tip(t), theta1/2(t), w1/2(t), u(t), phase plots
# ============================

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
from PIL import Image, ImageDraw

# ----------------------------
# Safety: require Cell 1 symbols
# ----------------------------
# Names that Cell 1 must already have defined in the notebook namespace.
required = [
    "make_env",
    "obs_to_state",
    "wrap_pi",
    "state_to_features",
    "U_MIN",
    "U_MAX",
]
missing = [name for name in required if name not in globals()]
if missing:
    # Fail early with one clear message instead of a NameError deep inside the cell.
    raise NameError(f"Cell 2 missing required symbols from Cell 1: {missing}")

# ----------------------------
# Simple stick-figure renderer for Acrobot (pure PIL)
# ----------------------------
def render_acrobot_frame_from_state(
    th1, th2,
    W=720, H=450,
    link1_px=150,
    link2_px=150,
    joint_r=10,
    link_w=10
):
    """
    Draw a stick-figure frame of the 2-link acrobot and return it as an
    (H, W, 3) uint8 RGB array.

    th1 is the angle of link 1 measured from the downward vertical (th1 = 0
    points straight down); th2 is the angle of link 2 relative to link 1.
    The base joint sits at the horizontal centre, 28% of the way down the image.
    """
    canvas = Image.new("RGB", (W, H), (245, 245, 245))
    pen = ImageDraw.Draw(canvas)

    # Base (pivot) position in pixels.
    base_x = int(W * 0.5)
    base_y = int(H * 0.28)

    # Image y grows downward, so +cos(angle) moves a link end down the screen.
    elbow_x = base_x + int(link1_px * np.sin(th1))
    elbow_y = base_y + int(link1_px * np.cos(th1))

    abs_angle2 = th1 + th2
    tip_x = elbow_x + int(link2_px * np.sin(abs_angle2))
    tip_y = elbow_y + int(link2_px * np.cos(abs_angle2))

    # Horizontal reference line through the base joint.
    pen.line([(0, base_y), (W, base_y)], fill=(220, 220, 220), width=3)

    # Links: (start, end, colour), drawn in order link 1 then link 2.
    segments = (
        ((base_x, base_y), (elbow_x, elbow_y), (60, 90, 160)),
        ((elbow_x, elbow_y), (tip_x, tip_y), (180, 50, 50)),
    )
    for start, end, colour in segments:
        pen.line([start, end], fill=colour, width=link_w)

    # Joints are drawn last so they sit on top of the links.
    for jx, jy in ((base_x, base_y), (elbow_x, elbow_y), (tip_x, tip_y)):
        pen.ellipse(
            [jx - joint_r, jy - joint_r, jx + joint_r, jy + joint_r],
            fill=(30, 30, 30), outline=(25, 25, 25), width=2
        )

    return np.asarray(canvas, dtype=np.uint8)

# ----------------------------
# Standard Acrobot progress metric: tip-height proxy
# ----------------------------
def acrobot_tip_height_proxy(th1, th2):
    """Return the tip height proxy y = -cos(th1) - cos(th1 + th2) as a Python float."""
    cos_link1 = np.cos(th1)
    cos_link2 = np.cos(th1 + th2)
    return float(-cos_link1 - cos_link2)

# ----------------------------
# Random collection (rendered)
# ----------------------------
def collect_random_transitions_rendered_acrobot(
    n_steps=2000,
    seed=0,
    max_episode_steps=2000,
    # env params
    torque_mag=1.0,
    dt=0.2,
    # rendering
    record_rgb=True,
    frame_stride=3,
    resize=(720, 450),
    fps=15,
    # reset behavior
    reset_on_done=True,
    verbose=True,
):
    """
    Drive the Acrobot env with uniformly random torque commands, collect the
    executed transitions, then show an animation and diagnostic plots.

    Args:
        n_steps:           number of env steps to execute (one transition each).
        seed:              seeds the action RNG and the env; resets after a done
                           episode use seed + 1000 + t.
        max_episode_steps: forwarded to make_env (TimeLimit length).
        torque_mag, dt:    forwarded to make_env.
        record_rgb:        if True, render a frame every `frame_stride` steps.
        frame_stride:      step interval between rendered frames.
        resize:            (W, H) in pixels for rendered frames and the animation figure.
        fps:               animation playback rate.
        reset_on_done:     reset the env when a step reports terminated or truncated.
        verbose:           print a line for every reset.

    Returns:
        X0     (N, 7) features of the pre-step state and the action,
        Ydth1, Ydth2 (N, 1) wrapped angle deltas,
        Ydw1, Ydw2   (N, 1) angular-velocity deltas,
        frames list of rendered RGB frames (post-step states),
        traj   dict with per-step th1/th2/w1/w2/u/y_tip arrays (pre-step
               values) and the counters resets/terminals/steps/kept.

    Side effects: displays the animation (or prints a warning when there are
    no frames), shows six matplotlib figures, and prints the collected shapes.
    """
    rng = np.random.default_rng(seed)

    env = make_env(
        render_mode=None,
        seed=seed,
        max_episode_steps=max_episode_steps,
        torque_mag=torque_mag,
        dt=dt
    )

    obs, info = env.reset(seed=seed)
    s = obs_to_state(obs)  # (th1, th2, w1, w2)
    th1, th2, w1, w2 = map(float, s)

    X_list, Ydth1_list, Ydth2_list, Ydw1_list, Ydw2_list = [], [], [], [], []

    traj_th1, traj_th2, traj_w1, traj_w2, traj_u, traj_y = [], [], [], [], [], []
    frames = []
    resets = 0
    terminals = 0

    for t in range(n_steps):
        # Random normalized torque command for this step.
        u = float(rng.uniform(U_MIN, U_MAX))

        obs2, reward, terminated, truncated, info = env.step(np.array([u], dtype=np.float32))
        s2 = obs_to_state(obs2)
        th1_2, th2_2, w1_2, w2_2 = map(float, s2)

        # collect GP transition (delta targets)
        # Angle deltas are wrapped so a crossing of +/-pi does not produce a ~2*pi jump.
        X_list.append(state_to_features(th1, th2, w1, w2, u))
        Ydth1_list.append([wrap_pi(th1_2 - th1)])
        Ydth2_list.append([wrap_pi(th2_2 - th2)])
        Ydw1_list.append([w1_2 - w1])
        Ydw2_list.append([w2_2 - w2])

        # log traj (values BEFORE the step)
        traj_th1.append(th1); traj_th2.append(th2)
        traj_w1.append(w1);   traj_w2.append(w2)
        traj_u.append(u)
        traj_y.append(acrobot_tip_height_proxy(th1, th2))

        # render (state AFTER the step)
        if record_rgb and ((t % frame_stride) == 0):
            W, H = int(resize[0]), int(resize[1])
            frames.append(render_acrobot_frame_from_state(th1_2, th2_2, W=W, H=H))

        # advance
        th1, th2, w1, w2 = th1_2, th2_2, w1_2, w2_2

        if terminated or truncated:
            terminals += int(terminated)
            if reset_on_done:
                resets += 1
                # Fresh seed per reset so successive episodes start differently.
                obs, info = env.reset(seed=int(seed + 1000 + t))
                s = obs_to_state(obs)
                th1, th2, w1, w2 = map(float, s)
                if verbose:
                    print(f"[t={t:04d}] reset (terminated={terminated}, truncated={truncated})")

    env.close()

    # Stack the per-step lists into float64 arrays.
    X0 = np.asarray(X_list, dtype=np.float64)
    Ydth1 = np.asarray(Ydth1_list, dtype=np.float64).reshape(-1, 1)
    Ydth2 = np.asarray(Ydth2_list, dtype=np.float64).reshape(-1, 1)
    Ydw1  = np.asarray(Ydw1_list, dtype=np.float64).reshape(-1, 1)
    Ydw2  = np.asarray(Ydw2_list, dtype=np.float64).reshape(-1, 1)

    traj = dict(
        th1=np.asarray(traj_th1, dtype=np.float64),
        th2=np.asarray(traj_th2, dtype=np.float64),
        w1=np.asarray(traj_w1, dtype=np.float64),
        w2=np.asarray(traj_w2, dtype=np.float64),
        u=np.asarray(traj_u, dtype=np.float64),
        y_tip=np.asarray(traj_y, dtype=np.float64),
        resets=int(resets),
        terminals=int(terminals),
        steps=int(n_steps),
        kept=int(X0.shape[0]),
    )

    # ----------------------------
    # 1) animation
    # ----------------------------
    if record_rgb and (len(frames) > 0):
        # Figure size in inches chosen so that, at dpi=100, it matches `resize` in pixels.
        fig = plt.figure(figsize=(resize[0] / 100, resize[1] / 100), dpi=100)
        plt.axis("off")
        im = plt.imshow(frames[0])

        def animate_fn(i):
            im.set_data(frames[i])
            return [im]

        ani = animation.FuncAnimation(
            fig, animate_fn,
            frames=len(frames),
            interval=1000 / float(fps),
            blit=True
        )
        # Close the static figure so only the JS animation is shown inline.
        plt.close(fig)
        display(HTML(ani.to_jshtml()))
    else:
        print("⚠️ No frames collected (record_rgb=False or frame_stride too large).")

    # ----------------------------
    # 2) plots (Acrobot standard)
    # ----------------------------
    T = len(traj["th1"])
    tgrid = np.arange(T)

    # Tip height proxy over time, with the terminal threshold at 1.0.
    plt.figure(figsize=(9, 3.2))
    plt.plot(tgrid, traj["y_tip"], linewidth=2)
    plt.axhline(1.0, linestyle="--", linewidth=2)
    plt.xlabel("t"); plt.ylabel("y_tip = -cos(th1) - cos(th1+th2)")
    plt.title("Acrobot swing-up progress (tip height proxy)")
    plt.grid(True, alpha=0.25); plt.tight_layout(); plt.show()

    # Joint angles over time.
    plt.figure(figsize=(9, 3.2))
    plt.plot(tgrid, traj["th1"], linewidth=2, label="th1")
    plt.plot(tgrid, traj["th2"], linewidth=2, label="th2")
    plt.xlabel("t"); plt.ylabel("angle (rad)")
    plt.title("Angles (wrapped)")
    plt.grid(True, alpha=0.25); plt.legend(); plt.tight_layout(); plt.show()

    # Angular velocities over time.
    plt.figure(figsize=(9, 3.2))
    plt.plot(tgrid, traj["w1"], linewidth=2, label="w1")
    plt.plot(tgrid, traj["w2"], linewidth=2, label="w2")
    plt.xlabel("t"); plt.ylabel("angular velocity (rad/s)")
    plt.title("Angular velocities")
    plt.grid(True, alpha=0.25); plt.legend(); plt.tight_layout(); plt.show()

    # Applied random commands over time.
    plt.figure(figsize=(9, 3.0))
    plt.plot(tgrid, traj["u"], linewidth=2)
    plt.xlabel("t"); plt.ylabel("u (torque command)")
    plt.title("Random continuous actions u(t)")
    plt.grid(True, alpha=0.25); plt.tight_layout(); plt.show()

    # Phase portrait of joint 1.
    plt.figure(figsize=(6.2, 5.2))
    plt.scatter(traj["th1"], traj["w1"], s=10, alpha=0.5)
    plt.xlabel("th1"); plt.ylabel("w1")
    plt.title("Phase: th1 vs w1")
    plt.grid(True, alpha=0.25); plt.tight_layout(); plt.show()

    # Phase portrait of joint 2.
    plt.figure(figsize=(6.2, 5.2))
    plt.scatter(traj["th2"], traj["w2"], s=10, alpha=0.5)
    plt.xlabel("th2"); plt.ylabel("w2")
    plt.title("Phase: th2 vs w2")
    plt.grid(True, alpha=0.25); plt.tight_layout(); plt.show()

    print("Collected X0 shape:", X0.shape, " (features)")
    print("Targets shapes:", Ydth1.shape, Ydth2.shape, Ydw1.shape, Ydw2.shape)
    print(f"Kept={traj['kept']}  resets={traj['resets']}  terminals={traj['terminals']}")

    return X0, Ydth1, Ydth2, Ydw1, Ydw2, frames, traj


# ---- run it ----
# Collect the initial random dataset: 6000 steps with episodes capped at 2000
# steps, rendering every 3rd step. The results are kept as notebook globals
# (X0, Y*_0, frames0, traj0).
SEED = 0
X0, Ydth1_0, Ydth2_0, Ydw1_0, Ydw2_0, frames0, traj0 = collect_random_transitions_rendered_acrobot(
    n_steps=6000,
    seed=SEED,
    max_episode_steps=2000,
    torque_mag=1.0,
    dt=0.2,
    record_rgb=True,
    frame_stride=3,
    resize=(720, 450),
    fps=15,
    reset_on_done=True,
    verbose=False,
)


In [None]:
# ===========================
# Cell 3 — OSGPR-VFE core (Streaming Sparse GP) + training + summaries + anchors  ✅ ACROBOT
#
# Same OSGPR/VFE machinery as your CartPole pipeline, but with the ACROBOT feature map:
#   - State S is (B,4) = [th1, th2, w1, w2]
#   - Features are 7D: [sin(th1), cos(th1), sin(th2), cos(th2), tanh(w1/s1), tanh(w2/s2), u]
#
# Provides:
#   - batch_state_to_features(): (B,4)+(B,) -> (B,7)   ✅ Acrobot
#   - OSGPR_VFE (single-output)
#   - train_osgpr()
#   - prior_summary(), extract_summary_from_model()
#   - greedy_dopt_anchors_from_K()
#   - rebuild_osgpr_from_old_summary(): returns (model_new, train_time, neg_obj)
# ===========================

import time
import copy
import numpy as np
import tensorflow as tf
import gpflow

from gpflow.inducing_variables import InducingPoints
from gpflow.models import GPModel, InternalDataTrainingLossMixin
from gpflow import covariances

# ---- numerics ----
# Use float64 everywhere (GPflow and Keras) and a 1e-6 default jitter.
gpflow.config.set_default_float(np.float64)
gpflow.config.set_default_jitter(1e-6)
tf.keras.backend.set_floatx("float64")

# Informational only: report CUDA build status and visible GPUs.
print("TF built with CUDA:", tf.test.is_built_with_cuda())
try:
    print("GPUs visible:", tf.config.list_physical_devices("GPU"))
except Exception as e:
    print("GPU query failed:", e)

# TensorFlow dtype used for the model's non-trainable variables and identity matrices.
DTYPE = gpflow.default_float()

# ---------------------------
# helpers
# ---------------------------
def sym_jitter(A, jitter=1e-6):
    """Return 0.5 * (A + A.T) + jitter * I as a new float64 numpy array."""
    mat = np.asarray(A, dtype=np.float64)
    symmetric = 0.5 * (mat + mat.T)
    ridge = float(jitter) * np.eye(symmetric.shape[0], dtype=np.float64)
    return symmetric + ridge

def finite_mask(*arrs):
    """
    Combine finiteness masks across several arrays with logical AND.

    2-D inputs contribute one flag per row (True when the whole row is finite);
    other inputs contribute their element-wise isfinite mask.
    Returns None when called with no arrays.
    """
    combined = None
    for arr in arrs:
        values = np.asarray(arr)
        ok = np.isfinite(values)
        if values.ndim == 2:
            ok = ok.all(axis=1)
        combined = ok if combined is None else (combined & ok)
    return combined

def clone_kernel(kernel):
    """
    Return an independent copy of a GPflow kernel so models do not share variables.

    Tries gpflow.utilities.deepcopy first; if that is unavailable or fails for
    any reason, falls back to copy.deepcopy.
    """
    try:
        import gpflow.utilities as gf_utils
        cloned = gf_utils.deepcopy(kernel)
    except Exception:
        cloned = copy.deepcopy(kernel)
    return cloned

# ------------------------------------------------------------
# Batch feature map (FAST) — used by MPPI later  ✅ ACROBOT
# ------------------------------------------------------------
def batch_state_to_features(S, U, w1_scale=8.0, w2_scale=10.0):
    """
    Vectorized Acrobot (state, action) -> 7-D GP feature map.

    S: (B,4)  rows of [th1, th2, w1, w2]
    U: (B,)   one action per row (anything that flattens to length B)
    Returns:
      Xfeat: (B,7) float64 with columns
             [sin(th1), cos(th1), sin(th2), cos(th2),
              tanh(w1/w1_scale), tanh(w2/w2_scale), u]
    """
    states = np.asarray(S, dtype=np.float64)
    actions = np.asarray(U, dtype=np.float64).reshape(-1)
    assert states.ndim == 2 and states.shape[1] == 4, "S must be (B,4)=[th1, th2, w1, w2]"
    assert actions.shape[0] == states.shape[0], "U must match batch size"

    # One row of the transposed state per state component.
    th1, th2, w1, w2 = states.T

    columns = (
        np.sin(th1),
        np.cos(th1),
        np.sin(th2),
        np.cos(th2),
        np.tanh(w1 / w1_scale),
        np.tanh(w2 / w2_scale),
        actions,
    )
    return np.column_stack(columns)

# ============================================================
# OSGPR-VFE model — regression-only, single-output
# ============================================================
class OSGPR_VFE(GPModel, InternalDataTrainingLossMixin):
    """
    Online Sparse Variational GP Regression (VFE), SINGLE-OUTPUT.

    Provide:
      - current batch data (X, Y)
      - old summary q_old(u)=N(mu_old, Su_old) at Z_old
      - Kaa_old = K(Z_old,Z_old) from old step
      - new inducing Z (usually Z_GLOBAL; you MAY refresh Z over time, but size should be capped)

    Includes:
      - predict_f (correct but slower)
      - build_predict_cache + predict_f_cached (FAST diag predictions)

    Notes:
      - The likelihood is always Gaussian and the inducing points are held fixed
        (marked non-trainable).
      - The old summary is stored in non-trainable tf.Variables.
    """
    def __init__(self, data, kernel, mu_old, Su_old, Kaa_old, Z_old, Z, mean_function=None):
        """
        Args:
            data:          (X, Y) for the current batch; converted to tensors.
            kernel:        GPflow kernel (used as-is; clone it beforehand if it
                           must not be shared with another model).
            mu_old:        old posterior mean at Z_old; reshaped to (Ma, 1).
            Su_old:        old posterior covariance at Z_old; symmetrized with 1e-6 jitter.
            Kaa_old:       old prior covariance K(Z_old, Z_old); symmetrized with 1e-6 jitter.
            Z_old:         old inducing inputs.
            Z:             new inducing inputs, must be 2-D (M, D).
            mean_function: optional GPflow mean function; Zero() when None.
        """
        X, Y = gpflow.models.util.data_input_to_tensor(data)
        self.X, self.Y = X, Y

        likelihood = gpflow.likelihoods.Gaussian()
        num_latent_gps = GPModel.calc_num_latent_gps_from_data(data, kernel, likelihood)
        super().__init__(kernel, likelihood, mean_function, num_latent_gps)

        Z = np.asarray(Z, dtype=np.float64)
        assert Z.ndim == 2, "Z must be (M, D)"
        self.inducing_variable = InducingPoints(Z)
        # Inducing locations are fixed; only kernel/likelihood parameters train.
        gpflow.set_trainable(self.inducing_variable, False)

        mu_old  = np.asarray(mu_old, dtype=np.float64).reshape(-1, 1)
        Su_old  = sym_jitter(Su_old, 1e-6)
        Kaa_old = sym_jitter(Kaa_old, 1e-6)
        Z_old   = np.asarray(Z_old, dtype=np.float64)

        self.mu_old  = tf.Variable(mu_old,  trainable=False, dtype=DTYPE)
        self.Su_old  = tf.Variable(Su_old,  trainable=False, dtype=DTYPE)
        self.Kaa_old = tf.Variable(Kaa_old, trainable=False, dtype=DTYPE)
        self.Z_old   = tf.Variable(Z_old,   trainable=False, dtype=DTYPE)

        if self.mean_function is None:
            self.mean_function = gpflow.mean_functions.Zero()

        # cache for fast predict
        self._cache_ready = False
        self._cache_Lb = None
        self._cache_LD = None
        self._cache_rhs = None

    def _common_terms(self):
        """
        Build common matrices used by both ELBO and prediction.

        Z   : new inducing (Mb)
        Za  : old inducing (Ma) == self.Z_old
        X   : current batch inputs

        Kbf = K(Z, X)    [Mb, N]
        Kbb = K(Z, Z)    [Mb, Mb]
        Kba = K(Z, Za)   [Mb, Ma]

        Returns the tuple
          (Kbf, Kba, Kaa, Kaa_cur, La, Kbb, Lb, D, LD, Lbinv_Kba, rhs, err, Qff_diag)
        where Lb/La/LD are lower Cholesky factors of Kbb/Kaa/D,
        rhs = LD^{-1} Lb^{-1} c, err = Y - mean_function(X), and Qff_diag is
        the column-wise sum of squares of Lb^{-1} Kbf / sqrt(sigma2).
        Raises (via tf.debugging) if Lb, LD or rhs contain NaN/Inf.
        """
        jitter = gpflow.utilities.to_default_float(1e-6)
        sigma2 = self.likelihood.variance

        Saa = self.Su_old  # [Ma,Ma]
        ma  = self.mu_old  # [Ma,1]

        Kbf = covariances.Kuf(self.inducing_variable, self.kernel, self.X)           # [Mb, N]
        Kbb = covariances.Kuu(self.inducing_variable, self.kernel, jitter=jitter)   # [Mb, Mb]
        Kba = covariances.Kuf(self.inducing_variable, self.kernel, self.Z_old)      # [Mb, Ma]

        # Kaa_cur uses the CURRENT kernel hyperparameters; Kaa is the stored old-step matrix.
        Kaa_cur = gpflow.utilities.add_noise_cov(self.kernel(self.Z_old), jitter)   # [Ma,Ma]
        Kaa     = gpflow.utilities.add_noise_cov(self.Kaa_old, jitter)              # [Ma,Ma]

        err = self.Y - self.mean_function(self.X)  # [N,1]

        # c = Kbf*(Y/sigma2) + Kba*(Saa^{-1} ma)
        # NOTE(review): c uses the raw Y while `err` subtracts the mean function;
        # the two agree for the default Zero mean — confirm before using another mean.
        Sainv_ma = tf.linalg.solve(Saa, ma)                                # [Ma,1]
        c = tf.matmul(Kbf, self.Y / sigma2) + tf.matmul(Kba, Sainv_ma)     # [Mb,1]

        # Cholesky(Kbb)
        Lb = tf.linalg.cholesky(Kbb)                                       # [Mb,Mb]
        Lbinv_c   = tf.linalg.triangular_solve(Lb, c,   lower=True)        # [Mb,1]
        Lbinv_Kba = tf.linalg.triangular_solve(Lb, Kba, lower=True)        # [Mb,Ma]
        Lbinv_Kbf = tf.linalg.triangular_solve(Lb, Kbf, lower=True) / tf.sqrt(sigma2)  # [Mb,N]

        d1 = tf.matmul(Lbinv_Kbf, Lbinv_Kbf, transpose_b=True)             # [Mb,Mb]

        # T = (Lb^{-1}Kba)^T  => [Ma,Mb]
        T = tf.linalg.matrix_transpose(Lbinv_Kba)

        # d2
        LSa = tf.linalg.cholesky(Saa)
        LSainv_T = tf.linalg.triangular_solve(LSa, T, lower=True)
        d2 = tf.matmul(LSainv_T, LSainv_T, transpose_a=True)               # [Mb,Mb]

        # d3
        La = tf.linalg.cholesky(Kaa)
        Lainv_T = tf.linalg.triangular_solve(La, T, lower=True)
        d3 = tf.matmul(Lainv_T, Lainv_T, transpose_a=True)                 # [Mb,Mb]

        Mb = self.inducing_variable.num_inducing
        D = tf.eye(Mb, dtype=DTYPE) + d1 + d2 - d3
        D = gpflow.utilities.add_noise_cov(D, jitter)
        LD = tf.linalg.cholesky(D)

        rhs = tf.linalg.triangular_solve(LD, Lbinv_c, lower=True)          # [Mb,1]

        Qff_diag = tf.reduce_sum(tf.square(Lbinv_Kbf), axis=0)             # [N]

        # Fail loudly on numerical breakdown instead of propagating NaNs.
        tf.debugging.assert_all_finite(Lb,  "Lb has NaN/Inf")
        tf.debugging.assert_all_finite(LD,  "LD has NaN/Inf")
        tf.debugging.assert_all_finite(rhs, "rhs has NaN/Inf")

        return (Kbf, Kba, Kaa, Kaa_cur, La, Kbb, Lb, D, LD, Lbinv_Kba, rhs, err, Qff_diag)

    def maximum_log_likelihood_objective(self):
        """
        Return the scalar objective to maximize (train_osgpr minimizes its negative).

        The bound is accumulated from a data-fit part, log-determinant terms,
        a trace term (Kfdiag vs. Qff_diag), and terms tying the old summary
        (Su_old, Kaa_old) to the current kernel.
        NOTE(review): presumably the streaming sparse-GP VFE bound — confirm
        against the reference derivation.
        """
        sigma2 = self.likelihood.variance
        N = tf.cast(tf.shape(self.X)[0], DTYPE)

        Saa = self.Su_old
        ma  = self.mu_old
        Kfdiag = self.kernel(self.X, full_cov=False)

        (Kbf, Kba, Kaa, Kaa_cur, La, Kbb, Lb, D, LD,
         Lbinv_Kba, rhs, err, Qff_diag) = self._common_terms()

        # Despite its name, Lainv_ma is LSa^{-1} ma (Cholesky factor of Saa).
        LSa = tf.linalg.cholesky(Saa)
        Lainv_ma = tf.linalg.triangular_solve(LSa, ma, lower=True)

        # constant + quadratic terms
        bound = -0.5 * N * np.log(2.0 * np.pi)
        bound += -0.5 * tf.reduce_sum(tf.square(err)) / sigma2
        bound += -0.5 * tf.reduce_sum(tf.square(Lainv_ma))
        bound +=  0.5 * tf.reduce_sum(tf.square(rhs))

        # log-determinant terms
        bound += -0.5 * N * tf.math.log(sigma2)
        bound += -tf.reduce_sum(tf.math.log(tf.linalg.diag_part(LD)))

        # trace term
        bound += -0.5 * tf.reduce_sum(Kfdiag) / sigma2
        bound +=  0.5 * tf.reduce_sum(Qff_diag)

        # old-summary log-determinants
        bound += tf.reduce_sum(tf.math.log(tf.linalg.diag_part(La)))
        bound += -tf.reduce_sum(tf.math.log(tf.linalg.diag_part(LSa)))

        # correction term involving Kaa_cur - Qaa
        Kaadiff = Kaa_cur - tf.matmul(Lbinv_Kba, Lbinv_Kba, transpose_a=True)
        Sainv_Kaadiff = tf.linalg.solve(Saa, Kaadiff)
        Kainv_Kaadiff = tf.linalg.solve(Kaa, Kaadiff)

        bound += -0.5 * tf.reduce_sum(
            tf.linalg.diag_part(Sainv_Kaadiff) - tf.linalg.diag_part(Kainv_Kaadiff)
        )
        return bound

    def predict_f(self, Xnew, full_cov=False):
        """
        Predict the latent function at Xnew, recomputing _common_terms() on every call.

        Returns (mean, var): mean is [Nnew, 1] with the mean function added;
        var is the [Nnew, Nnew] covariance (plus 1e-6 on the diagonal) when
        full_cov=True, otherwise the [Nnew] marginal variances floored at 1e-12.
        """
        jitter = gpflow.utilities.to_default_float(1e-6)

        Kbs = covariances.Kuf(self.inducing_variable, self.kernel, Xnew)  # [Mb, Nnew]
        (_, _, _, _, _, _, Lb, _, LD, _, rhs, _, _) = self._common_terms()

        Lbinv_Kbs = tf.linalg.triangular_solve(Lb, Kbs, lower=True)
        LDinv_Lbinv_Kbs = tf.linalg.triangular_solve(LD, Lbinv_Kbs, lower=True)
        mean = tf.matmul(LDinv_Lbinv_Kbs, rhs, transpose_a=True)  # [Nnew,1]

        if full_cov:
            Kss = self.kernel(Xnew) + jitter * tf.eye(tf.shape(Xnew)[0], dtype=DTYPE)
            var = (
                Kss
                - tf.matmul(Lbinv_Kbs, Lbinv_Kbs, transpose_a=True)
                + tf.matmul(LDinv_Lbinv_Kbs, LDinv_Lbinv_Kbs, transpose_a=True)
            )
            return mean + self.mean_function(Xnew), var
        else:
            var = (
                self.kernel(Xnew, full_cov=False)
                - tf.reduce_sum(tf.square(Lbinv_Kbs), axis=0)
                + tf.reduce_sum(tf.square(LDinv_Lbinv_Kbs), axis=0)
            )
            # Guard against tiny negative variances from round-off.
            var = tf.maximum(var, tf.cast(1e-12, var.dtype))
            return mean + self.mean_function(Xnew), var

    def build_predict_cache(self):
        """
        Build cached matrices for fast predict_f_cached(). Call after training / after each update.

        The cache stores Lb, LD and rhs as computed at call time; it is not
        refreshed automatically if the parameters change afterwards.
        """
        (_, _, _, _, _, _, Lb, _, LD, _, rhs, _, _) = self._common_terms()
        self._cache_Lb = Lb
        self._cache_LD = LD
        self._cache_rhs = rhs
        self._cache_ready = True

    def predict_f_cached(self, Xnew, full_cov=False):
        """
        Fast diag prediction using cached Lb, LD, rhs.

        Same outputs as predict_f, but reuses the factors stored by
        build_predict_cache() instead of recomputing them. Falls back to
        predict_f when the cache has not been built.
        """
        if not self._cache_ready:
            return self.predict_f(Xnew, full_cov=full_cov)

        jitter = gpflow.utilities.to_default_float(1e-6)
        Lb  = self._cache_Lb
        LD  = self._cache_LD
        rhs = self._cache_rhs

        Kbs = covariances.Kuf(self.inducing_variable, self.kernel, Xnew)  # [Mb,Nnew]
        Lbinv_Kbs = tf.linalg.triangular_solve(Lb, Kbs, lower=True)
        LDinv_Lbinv_Kbs = tf.linalg.triangular_solve(LD, Lbinv_Kbs, lower=True)
        mean = tf.matmul(LDinv_Lbinv_Kbs, rhs, transpose_a=True)

        if full_cov:
            Kss = self.kernel(Xnew) + jitter * tf.eye(tf.shape(Xnew)[0], dtype=DTYPE)
            var = (
                Kss
                - tf.matmul(Lbinv_Kbs, Lbinv_Kbs, transpose_a=True)
                + tf.matmul(LDinv_Lbinv_Kbs, LDinv_Lbinv_Kbs, transpose_a=True)
            )
            return mean + self.mean_function(Xnew), var
        else:
            var = (
                self.kernel(Xnew, full_cov=False)
                - tf.reduce_sum(tf.square(Lbinv_Kbs), axis=0)
                + tf.reduce_sum(tf.square(LDinv_Lbinv_Kbs), axis=0)
            )
            # Guard against tiny negative variances from round-off.
            var = tf.maximum(var, tf.cast(1e-12, var.dtype))
            return mean + self.mean_function(Xnew), var

# ----------------------------
# training helper
# ----------------------------
def train_osgpr(model, iters=250, lr=0.02, clip_norm=10.0):
    """
    Adam optimize the negative ELBO.

    Args:
        model:     object exposing maximum_log_likelihood_objective() and
                   trainable_variables.
        iters:     number of Adam steps; 0 (or negative) performs no update.
        lr:        Adam learning rate.
        clip_norm: per-gradient norm clip; None disables clipping.

    Returns:
        (train_time_seconds, loss) where loss is the negative objective
        evaluated at the start of the last step (i.e. before that step's
        parameter update), or at the current parameters when no step was taken.
    """
    opt = tf.keras.optimizers.Adam(lr)

    @tf.function
    def step():
        with tf.GradientTape() as tape:
            loss = -model.maximum_log_likelihood_objective()
        variables = model.trainable_variables
        grads = tape.gradient(loss, variables)
        if clip_norm is not None:
            grads = [tf.clip_by_norm(g, clip_norm) if g is not None else None for g in grads]
        # Variables without a gradient are skipped rather than passed as None.
        opt.apply_gradients([(g, v) for g, v in zip(grads, variables) if g is not None])
        return loss

    # perf_counter is monotonic, so the elapsed time cannot go negative on clock changes.
    t0 = time.perf_counter()
    last = None
    for _ in range(int(iters)):
        last = step()
    elapsed = float(time.perf_counter() - t0)

    if last is None:
        # No optimization step was requested: report the current loss instead
        # of failing on `None.numpy()`.
        last = -model.maximum_log_likelihood_objective()
    return elapsed, float(last.numpy())

# ----------------------------
# summaries (to chain online)
# ----------------------------
def prior_summary(kernel, Z):
    """
    Build the summary used to start the online chain (no data seen yet):
      mean       = zeros (M,1)
      covariance = K(Z,Z), symmetrized with 1e-6 jitter
      Kaa        = the same matrix (same array object as the covariance)
    Returns (mu0, Su0, Kaa0, Z) with Z as a float64 array.
    """
    inducing = np.asarray(Z, dtype=np.float64)
    prior_cov = sym_jitter(kernel.K(inducing).numpy(), 1e-6)
    zero_mean = np.zeros((inducing.shape[0], 1), dtype=np.float64)
    return zero_mean, prior_cov, prior_cov, inducing

def extract_summary_from_model(model):
    """
    Summarize a model at its own inducing inputs Z.

    Returns (mu, Su, Kaa, Z): the full-covariance prediction of the model at Z
    (mean as (M,1), covariance symmetrized with 1e-6 jitter), the kernel
    matrix K(Z,Z) treated the same way, and Z as a float64 array.
    """
    inducing = model.inducing_variable.Z.numpy().astype(np.float64)

    post_mean, post_cov = model.predict_f(inducing, full_cov=True)
    mu = post_mean.numpy().reshape(-1, 1)

    cov = post_cov.numpy()
    # A leading axis of size one may be present; keep only the (M,M) block.
    cov = cov[0] if cov.ndim == 3 else cov
    Su = sym_jitter(cov, 1e-6)

    Kaa = sym_jitter(model.kernel.K(inducing).numpy(), 1e-6)
    return mu, Su, Kaa, inducing

# ============================================================
# Anchors: greedy D-opt (log-det) on Kzz
# ============================================================
def greedy_dopt_anchors_from_K(Kzz, m_anchors=24, lam=1e-6):
    """
    Greedy log-det (D-optimal) anchor selection on a PSD matrix Kzz.

    At every step the not-yet-chosen index with the largest residual variance
    ``K[i, i] - ||L^{-1} K[S, i]||^2`` (S = already chosen set) is added. The
    factor rows ``V = L^{-1} K[S, :]`` are extended by one row per step
    (pivoted-Cholesky update), so the cost is O(m_anchors^2 * M) instead of
    re-solving a growing dense system against all M columns on every step.

    Args:
        Kzz: (M, M) kernel matrix. It is symmetrised and regularised with
            ``sym_jitter(K, lam)`` before selection.
        m_anchors: Number of indices to select (capped at M).
        lam: Jitter forwarded to ``sym_jitter``.

    Returns:
        int64 array of selected indices, in selection order.
    """
    K = np.asarray(Kzz, dtype=np.float64)
    M = K.shape[0]
    assert K.shape == (M, M)
    K = sym_jitter(K, lam)

    n_select = min(int(m_anchors), M)
    diag = np.clip(np.diag(K).copy(), 1e-12, None)
    remaining = np.ones(M, dtype=bool)
    chosen = []

    # V[j] is row j of L^{-1} K[S, :]; vn2 accumulates its squared column norms.
    V = np.zeros((max(n_select, 0), M), dtype=np.float64)
    vn2 = np.zeros(M, dtype=np.float64)

    for k in range(n_select):
        # Residual variance of every candidate given the current selection.
        s2 = np.where(remaining, diag - vn2, -np.inf)

        i = int(np.argmax(s2))
        if not np.isfinite(s2[i]) or s2[i] <= 1e-12:
            # Numerically exhausted: fall back to the largest prior variance
            # among the remaining candidates.
            cand = np.where(remaining)[0]
            if len(cand) == 0:
                break
            i = int(cand[np.argmax(diag[cand])])
            s2_i = max(diag[i], 1e-12)
        else:
            s2_i = float(s2[i])

        chosen.append(i)
        remaining[i] = False

        alpha = np.sqrt(max(s2_i, 1e-12))
        # New factor row by forward substitution: the leading k entries of the
        # new Cholesky row equal V[:k, i], so only one row has to be computed.
        new_row = (K[i, :] - V[:k, i] @ V[:k, :]) / alpha
        V[k] = new_row
        vn2 += new_row * new_row

    return np.array(chosen, dtype=np.int64)

# ============================================================
# Online update builder (GLOBAL update step)
# ============================================================
def rebuild_osgpr_from_old_summary(
    model_old,
    X_new,
    Y_new,
    Z_new=None,
    iters=120,
    lr=0.02,
    noise=1e-4,
    freeze_kernel=False,
    clip_norm=10.0,
):
    """
    Build and train a NEW OSGPR_VFE model from:
      - the old posterior summary extracted from model_old at its inducing Z_old
      - the new executed batch (X_new, Y_new)
      - inducing set Z_new (defaults to model_old's inducing inputs)

    Args:
        model_old: Previously trained model; summarised via
            extract_summary_from_model and its kernel is cloned.
        X_new, Y_new: New data batch, converted to float64 arrays.
        Z_new: Optional inducing inputs for the new model.
        iters, lr, clip_norm: Forwarded to train_osgpr.
        noise: Value assigned to the likelihood variance before training.
        freeze_kernel: If True, all kernel parameters are made non-trainable.

    Returns:
      model_new, train_time_sec, last_neg_obj
    """
    # old summary
    mu_old, Su_old, Kaa_old, Z_old = extract_summary_from_model(model_old)

    # inducing set for the new model
    Z_use = Z_old if Z_new is None else np.asarray(Z_new, dtype=np.float64)

    # clone kernel to avoid variable-sharing surprises
    k_new = clone_kernel(model_old.kernel)

    m = OSGPR_VFE(
        data=(np.asarray(X_new, dtype=np.float64), np.asarray(Y_new, dtype=np.float64)),
        kernel=k_new,
        mu_old=mu_old, Su_old=Su_old, Kaa_old=Kaa_old, Z_old=Z_old,
        Z=Z_use,
    )
    m.likelihood.variance.assign(float(noise))

    if freeze_kernel:
        # Assigning ``param.trainable = False`` inside a swallow-all try/except
        # could fail silently and leave the kernel trainable. set_trainable is
        # GPflow's documented way to toggle trainability and works for any
        # kernel, not only ones with ``variance``/``lengthscales``.
        from gpflow.utilities import set_trainable
        set_trainable(m.kernel, False)

    t_sec, neg = train_osgpr(m, iters=iters, lr=lr, clip_norm=clip_norm)
    m.build_predict_cache()
    return m, float(t_sec), float(neg)

# Notebook progress marker printed once the definitions above have executed.
print("✅ OSGPR core + helpers ready (Cell 3 — ACROBOT feature map + pipeline)")


In [None]:
# ===========================
# Cell 4 — Train INITIAL GLOBAL OSGPR models (ACROBOT, 4 outputs)
#   + GLOBAL evaluation plots (mean + uncertainty)
#   + Anchor reselection utilities (for the new pipeline)
#
# Inputs expected from your CURRENT Cell 2 (Acrobot):
#   X0:      (N, D)  features (D=7 with the default Acrobot feature map)
#   Ydth1_0: (N,1)   Δtheta1
#   Ydth2_0: (N,1)   Δtheta2
#   Ydw1_0:  (N,1)   Δomega1
#   Ydw2_0:  (N,1)   Δomega2
#
# Produces:
#   m_dth1, m_dth2, m_dw1, m_dw2   (GLOBAL OSGPR models)
#   Z_GLOBAL  (fixed)
#   refresh_global_anchors() -> updates ANCHOR_IDX
#   Plot helpers: slice + surface (mean colored by std) for any model
# ===========================

import numpy as np
import gpflow
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import plotly.graph_objects as go

# Fixed seed: the generator below drives the train/test permutation and the
# inducing-point sampling later in this cell.
SEED = 0
rng = np.random.default_rng(SEED)

# ----------------------------
# Safety: require OSGPR core from Cell 3 + Acrobot data from Cell 2
# ----------------------------
# Names that must already exist in the notebook's global namespace.
required_syms = [
    "finite_mask", "sym_jitter",
    "prior_summary", "train_osgpr",
    "OSGPR_VFE", "greedy_dopt_anchors_from_K",
]
missing = [k for k in required_syms if k not in globals()]
if len(missing) > 0:
    raise NameError(f"Cell 4 missing required symbols from Cell 3: {missing}")

# Data arrays expected from the random-collection cell.
required_data = ["X0", "Ydth1_0", "Ydth2_0", "Ydw1_0", "Ydw2_0"]
missing_data = [k for k in required_data if k not in globals()]
if len(missing_data) > 0:
    raise NameError(
        "Cell 4 expects Acrobot random-collection outputs from your Cell 2:\n"
        f"Missing: {missing_data}\n"
        "Please ensure Cell 2 returns X0, Ydth1_0, Ydth2_0, Ydw1_0, Ydw2_0."
    )

# Alias to clean internal names (same objects, no copy)
Ydth1 = Ydth1_0
Ydth2 = Ydth2_0
Ydw1  = Ydw1_0
Ydw2  = Ydw2_0

# ----------------------------
# 0) Clean / finite filter
# ----------------------------
# One row mask (from finite_mask) is applied to the inputs and all four
# targets so that they stay row-aligned after filtering.
mask = finite_mask(X0, Ydth1, Ydth2, Ydw1, Ydw2)
X0f, Yth1f, Yth2f, Yw1f, Yw2f = (
    arr[mask] for arr in (X0, Ydth1, Ydth2, Ydw1, Ydw2)
)
print("Data kept:", X0f.shape[0], "/", X0.shape[0])

# ----------------------------
# 1) Train/test split
# ----------------------------
N = X0f.shape[0]
if N < 2:
    raise ValueError(
        f"Need at least 2 finite samples for a train/test split, got {N}."
    )
perm = rng.permutation(N)
# Hold out 20% (at least 50 rows), but never the whole dataset: with N <= 50
# the old rule left the training set empty and later sampling failed.
N_test = min(max(50, int(0.2 * N)), N - 1)
test_idx  = perm[:N_test]
train_idx = perm[N_test:]

Xtr = X0f[train_idx]
Xte = X0f[test_idx]

# Targets are split with the same index sets so rows stay aligned with Xtr/Xte.
Yth1_tr, Yth1_te = Yth1f[train_idx], Yth1f[test_idx]
Yth2_tr, Yth2_te = Yth2f[train_idx], Yth2f[test_idx]
Yw1_tr,  Yw1_te  = Yw1f[train_idx],  Yw1f[test_idx]
Yw2_tr,  Yw2_te  = Yw2f[train_idx],  Yw2f[test_idx]

print("Train:", Xtr.shape, " Test:", Xte.shape)

# ----------------------------
# 2) GLOBAL inducing set (fixed capacity)
# ----------------------------
# Inducing-point budget: reuse M_GLOBAL if it is already defined in the
# notebook namespace, otherwise default to 512.
M_GLOBAL = int(globals().get("M_GLOBAL", 512))
N_avail  = Xtr.shape[0]

if N_avail >= M_GLOBAL:
    # Enough training rows: pick M_GLOBAL distinct rows as inducing inputs.
    idxZ = rng.choice(N_avail, size=M_GLOBAL, replace=False)
    Z_GLOBAL = Xtr[idxZ].copy().astype(np.float64)
else:
    # Too few rows: rows are repeated, then perturbed by tiny Gaussian noise.
    print(f"⚠️ Warning: Train data ({N_avail}) < M_GLOBAL ({M_GLOBAL}). Sampling with replacement.")
    idxZ = rng.choice(N_avail, size=M_GLOBAL, replace=True)
    Z_GLOBAL = Xtr[idxZ].copy().astype(np.float64)
    Z_GLOBAL += rng.standard_normal(Z_GLOBAL.shape) * 1e-6  # avoid exact duplicates

print("Z_GLOBAL:", Z_GLOBAL.shape)

# ----------------------------
# 3) Kernels (ARD SE) with correct input dim
# ----------------------------
D_IN = int(Xtr.shape[1])

def make_kernel():
    """Return a fresh SquaredExponential kernel with one unit lengthscale per
    input dimension (D_IN of them) and unit variance."""
    unit_lengthscales = np.ones((D_IN,), dtype=np.float64)
    return gpflow.kernels.SquaredExponential(
        variance=1.0,
        lengthscales=unit_lengthscales,
    )

# One independent kernel per output head (created in the order th1, th2, w1, w2).
k_th1, k_th2, k_w1, k_w2 = (make_kernel() for _ in range(4))

# ----------------------------
# 4) Build + train helper (initial model uses prior summary at Z_GLOBAL)
# ----------------------------
def build_and_train_global_model(kernel, X, Y, Z_init, name,
                                 iters=300, lr=0.02, noise=1e-4):
    """
    Create an initial OSGPR_VFE model whose "old" summary is the prior at
    Z_init, train it with train_osgpr, and build its prediction cache.

    kernel / X / Y / Z_init : kernel object, inputs, targets, inducing inputs
    name                    : label used only in the progress printout
    iters / lr              : forwarded to train_osgpr (clip_norm fixed at 10.0)
    noise                   : value assigned to the likelihood variance before training

    Returns the trained model.
    """
    prior_mu, prior_Su, prior_Kaa, prior_Z = prior_summary(kernel, Z_init)

    inputs = np.asarray(X, dtype=np.float64)
    targets = np.asarray(Y, dtype=np.float64)
    model = OSGPR_VFE(
        data=(inputs, targets),
        kernel=kernel,
        mu_old=prior_mu, Su_old=prior_Su, Kaa_old=prior_Kaa, Z_old=prior_Z,
        Z=Z_init
    )
    model.likelihood.variance.assign(float(noise))

    print(f"\nTraining {name} ...")
    elapsed, neg_obj = train_osgpr(model, iters=iters, lr=lr, clip_norm=10.0)
    noise_after = float(model.likelihood.variance.numpy())
    print(f"{name} done | train={elapsed:.2f}s | neg_obj={neg_obj:.4f} | noise={noise_after:.2e}")

    # The cache is optional: only build it when the model class provides it.
    if hasattr(model, "build_predict_cache"):
        model.build_predict_cache()
    return model

# ----------------------------
# 5) Train 4 global models
# ----------------------------
# One independent single-output model per target; all four share the same
# training inputs Xtr and the same initial inducing set Z_GLOBAL.
m_dth1 = build_and_train_global_model(k_th1, Xtr, Yth1_tr, Z_GLOBAL, "dtheta1",  iters=300, lr=0.02, noise=1e-4)
m_dth2 = build_and_train_global_model(k_th2, Xtr, Yth2_tr, Z_GLOBAL, "dtheta2",  iters=300, lr=0.02, noise=1e-4)
m_dw1  = build_and_train_global_model(k_w1,  Xtr, Yw1_tr,  Z_GLOBAL, "domega1",  iters=300, lr=0.02, noise=1e-4)
m_dw2  = build_and_train_global_model(k_w2,  Xtr, Yw2_tr,  Z_GLOBAL, "domega2",  iters=300, lr=0.02, noise=1e-4)

print("\n✅ Global Acrobot OSGPR models trained + caches ready")

# ----------------------------
# 6) Numeric eval (RMSE on held-out)
# ----------------------------
def rmse_on_test(model, Xte, Yte):
    """Root-mean-square error of the model's predictive mean on (Xte, Yte).

    Uses ``model.predict_f_cached`` when the model has it, otherwise
    ``model.predict_f``. Predictions and targets are flattened before
    comparison. Returns a Python float.
    """
    inputs = np.asarray(Xte, dtype=np.float64)
    targets = np.asarray(Yte, dtype=np.float64).reshape(-1)
    predict = model.predict_f_cached if hasattr(model, "predict_f_cached") else model.predict_f
    mean_tf, _ = predict(inputs, full_cov=False)
    residual = mean_tf.numpy().reshape(-1) - targets
    return float(np.sqrt(np.mean(residual ** 2)))

# Report one RMSE per output head on the held-out split (Xte and *_te targets).
print("\nHeld-out RMSE:")
print("  dtheta1 :", rmse_on_test(m_dth1, Xte, Yth1_te))
print("  dtheta2 :", rmse_on_test(m_dth2, Xte, Yth2_te))
print("  domega1 :", rmse_on_test(m_dw1,  Xte, Yw1_te))
print("  domega2 :", rmse_on_test(m_dw2,  Xte, Yw2_te))

# ============================================================
# 7) Anchor reselection utilities (CRITICAL for new pipeline)
# ============================================================
def reselect_anchors_from_model(model, m_anchors=24, lam=1e-6):
    """
    Select anchor indices into the model's inducing set by greedy D-optimal
    selection on K(Z, Z), using the model's CURRENT kernel hyperparameters.
    Re-run after each global update if the kernel changes or you want the
    anchors to adapt.

    Args:
        model: Model exposing ``inducing_variable.Z`` and ``kernel.K``.
        m_anchors: Number of anchors requested (capped at the number of
            inducing points).
        lam: Jitter forwarded to greedy_dopt_anchors_from_K.

    Returns:
        Array of selected indices into Z.
    """
    Z = model.inducing_variable.Z.numpy().astype(np.float64)
    Kzz = model.kernel.K(Z).numpy()
    n_select = int(min(m_anchors, Z.shape[0]))
    # greedy_dopt_anchors_from_K already symmetrises the matrix and adds `lam`;
    # applying sym_jitter here as well doubled the effective jitter to 2*lam.
    return greedy_dopt_anchors_from_K(Kzz, m_anchors=n_select, lam=lam)

# Initial anchor set: at most 24 anchors, never more than the inducing points
# available. Only the domega2 head's kernel is used for the selection.
ANCHOR_M = min(24, int(Z_GLOBAL.shape[0]))
ANCHOR_IDX = reselect_anchors_from_model(m_dw2, m_anchors=ANCHOR_M, lam=1e-6)  # pick one head
print("ANCHOR_IDX:", ANCHOR_IDX.shape, f"(anchors={len(ANCHOR_IDX)})")

def refresh_global_anchors(m_anchors=24):
    """
    Recompute the module-level ANCHOR_IDX from the current global m_dw2 model
    and return it. Intended to be called after a global update, especially
    when kernel hyperparameters have moved. (Swap m_dw2 for another head here
    if a different kernel should drive the selection.)
    """
    global ANCHOR_IDX
    fresh_idx = reselect_anchors_from_model(m_dw2, m_anchors=m_anchors, lam=1e-6)
    ANCHOR_IDX = fresh_idx
    return fresh_idx

# ============================================================
# 8) Plot helpers (GLOBAL): mean + uncertainty
# ============================================================
def gp_predict_mu_std_fast(model, X):
    """Return flattened predictive mean and standard deviation at X.

    Calls ``model.predict_f_cached`` when available, else ``model.predict_f``,
    both with ``full_cov=False``. The variance is floored at 1e-12 before the
    square root, so the returned std is never below 1e-6.
    """
    queries = np.asarray(X, dtype=np.float64)
    predict = model.predict_f_cached if hasattr(model, "predict_f_cached") else model.predict_f
    mean_tf, var_tf = predict(queries, full_cov=False)
    mean = mean_tf.numpy().reshape(-1)
    variance = var_tf.numpy().reshape(-1)
    return mean, np.sqrt(np.maximum(variance, 1e-12))

def get_inducing_Z_np(model):
    """Return the model's inducing inputs as a new float64 NumPy array."""
    z_values = model.inducing_variable.Z.numpy()
    return z_values.astype(np.float64)

# ---- 2D Slice: pick one feature dim to sweep (generic, since Acrobot features are not 1:1 physical x) ----
def plot_slice_feature_dim_two_actions(
    model,
    X_train, y_train,
    feat_dim=0,
    feat_min=-1.0, feat_max=1.0,
    n_grid=260,
    u_dim=-1,                 # action is last dim by default
    u_list=(+1.0, -1.0),
    title="Slice over feature dim",
    y_label="Δy",
    show_data=True,
    show_inducing=True,
):
    """
    Plot predictive mean ±2σ along one feature dimension for several fixed
    action values.

    All other feature dimensions are held at the per-column nanmedian of
    X_train; column ``feat_dim`` is swept over ``[feat_min, feat_max]`` with
    ``n_grid`` points and column ``u_dim`` is set to each value in ``u_list``.

    Args:
        model: Model accepted by gp_predict_mu_std_fast.
        X_train, y_train: Training inputs/targets; used for the base point and
            (if ``show_data``) scattered where |u - u_fixed| < 0.15.
        show_inducing: If True, overlay the predictive mean at inducing inputs
            whose action column is within 0.15 of each u_fixed. Inducing
            inputs are only read from the model when this is True.
        title, y_label: Plot labels.
    """
    X_train = np.asarray(X_train, dtype=np.float64)
    base = np.nanmedian(X_train, axis=0)

    # Predict every curve first so a prediction failure leaves no empty figure.
    curves = []
    for u_fixed in u_list:
        Xq = np.tile(base[None, :], (n_grid, 1))
        Xq[:, feat_dim] = np.linspace(feat_min, feat_max, n_grid)
        Xq[:, u_dim] = float(u_fixed)

        mu, std = gp_predict_mu_std_fast(model, Xq)
        curves.append((u_fixed, Xq[:, feat_dim].copy(), mu, mu - 2 * std, mu + 2 * std))

    plt.figure(figsize=(9, 5))
    for u_fixed, xgrid, mu, lo, hi in curves:
        plt.plot(xgrid, mu, lw=2.5, label=f"mean (u={u_fixed:+.1f})")
        plt.fill_between(xgrid, lo, hi, alpha=0.18, label=f"±2σ (u={u_fixed:+.1f})")

    if show_data:
        y_train = np.asarray(y_train, dtype=np.float64).reshape(-1)
        for u_fixed in u_list:
            mask = np.abs(X_train[:, u_dim] - float(u_fixed)) < 0.15
            if np.sum(mask) > 0:
                plt.scatter(X_train[mask, feat_dim], y_train[mask], s=18, alpha=0.35,
                            label=f"data (u≈{u_fixed:+.1f}, n={np.sum(mask)})")

    if show_inducing:
        Z = get_inducing_Z_np(model)
        for u_fixed in u_list:
            maskZ = np.abs(Z[:, u_dim] - float(u_fixed)) < 0.15
            if np.sum(maskZ) > 0:
                Zsel = Z[maskZ]
                muZ, _ = gp_predict_mu_std_fast(model, Zsel)
                plt.scatter(Zsel[:, feat_dim], muZ, marker="x", s=70, linewidths=2.0,
                            label=f"Z (u≈{u_fixed:+.1f}, M={np.sum(maskZ)})")

    plt.xlabel(f"feature[{feat_dim}]")
    plt.ylabel(y_label)
    plt.title(title + f"  (median other dims; action dim={u_dim})")
    plt.grid(True, alpha=0.25)
    plt.legend(loc="best")
    plt.tight_layout()
    plt.show()

# ---- 3D Surface: pick two feature dims to visualize, color by std ----
def plot_surface_two_feature_dims_mean_colored_by_std(
    model,
    X_ref,
    dim_x=0,
    dim_y=1,
    x_min=-1.0, x_max=1.0,
    y_min=-1.0, y_max=1.0,
    n_grid=70,
    u_fixed=+1.0,
    u_dim=-1,
    title="3D surface over feature dims",
    z_label="Δy",
    show_inducing=True,
):
    """
    Show a Plotly surface of the predictive mean over two feature dimensions,
    coloured by the predictive standard deviation.

    Remaining feature columns are fixed at the per-column nanmedian of X_ref
    and column ``u_dim`` is fixed at ``u_fixed``. With ``show_inducing`` the
    predictive mean at inducing inputs whose action column lies within 0.15
    of ``u_fixed`` is overlaid as red markers.
    """
    reference = np.asarray(X_ref, dtype=np.float64)
    median_row = np.nanmedian(reference, axis=0)

    Xg, Yg = np.meshgrid(
        np.linspace(x_min, x_max, n_grid),
        np.linspace(y_min, y_max, n_grid),
    )

    queries = np.tile(median_row[None, :], (n_grid * n_grid, 1))
    queries[:, dim_x] = Xg.ravel()
    queries[:, dim_y] = Yg.ravel()
    queries[:, u_dim] = float(u_fixed)

    mean_flat, std_flat = gp_predict_mu_std_fast(model, queries)

    traces = [
        go.Surface(
            x=Xg, y=Yg, z=mean_flat.reshape(Xg.shape),
            surfacecolor=std_flat.reshape(Xg.shape),
            colorscale="Viridis",
            colorbar=dict(title="Std"),
            opacity=0.95,
            showscale=True,
            name="surface"
        )
    ]

    inducing = get_inducing_Z_np(model) if show_inducing else None
    if inducing is not None:
        near_u = np.abs(inducing[:, u_dim] - float(u_fixed)) < 0.15
        n_near = np.sum(near_u)
        if n_near > 0:
            Z_near = inducing[near_u]
            mean_at_Z, _ = gp_predict_mu_std_fast(model, Z_near)
            traces.append(
                go.Scatter3d(
                    x=Z_near[:, dim_x], y=Z_near[:, dim_y], z=mean_at_Z,
                    mode="markers",
                    marker=dict(size=3, color="red", opacity=0.9),
                    name=f"inducing Z (u≈{u_fixed:+.1f}) | M={n_near}"
                )
            )

    scene = dict(
        xaxis=dict(title=f"feature[{dim_x}]", range=[x_min, x_max]),
        yaxis=dict(title=f"feature[{dim_y}]", range=[y_min, y_max]),
        zaxis=dict(title=z_label),
    )
    fig = go.Figure(data=traces)
    fig.update_layout(
        title=f"{title} | fixed u={u_fixed:+.1f}",
        scene=scene,
        margin=dict(l=0, r=0, b=0, t=50),
        height=650
    )
    fig.show()

# Notebook progress marker printed once everything above in this cell has run.
print("\n✅ Cell 4 complete (Acrobot global models + anchor refresh + plot helpers).")

# OPTIONAL quick demo plots (uncomment if you want):
# plot_slice_feature_dim_two_actions(m_dw2, Xtr, Yw2_tr, feat_dim=0, title="domega2 slice", y_label="Δw2")
# plot_surface_two_feature_dims_mean_colored_by_std(m_dw2, Xtr, dim_x=0, dim_y=1, title="domega2 surface", z_label="Δw2")


In [None]:
# ============================
# Cell 5 — Evaluate + Visualize GLOBAL GP (Reusable) ✅ ACROBOT
#
# Provides:
#   ✅ RMSE for dtheta1, dtheta2, domega1, domega2 on Xte
#   ✅ Slice plot (mean ±2σ) for chosen output model (generic feature dim)
#   ✅ 3D Surface mean colored by std (over 2 chosen feature dims)
#   ✅ NEW: 3D Surface of std alone (uncertainty surface)
#
# Assumes you ran:
#   Cell 4 (trained global models + plot helpers already defined)
#
# Expected variables from your Cell 4:
#   m_dth1, m_dth2, m_dw1, m_dw2
#   Xtr, Xte
#   Yth1_tr, Yth1_te, Yth2_tr, Yth2_te, Yw1_tr, Yw1_te, Yw2_tr, Yw2_te
#   gp_predict_mu_std_fast()
#   plot_slice_feature_dim_two_actions()
#   plot_surface_two_feature_dims_mean_colored_by_std()
# ============================

import numpy as np
import plotly.graph_objects as go

# ----------------------------
# 0) Safety: require Cell 4 symbols
# ----------------------------
# Names that must already be defined in the notebook namespace: plot helpers,
# the four global models, and the train/test arrays.
required = [
    "gp_predict_mu_std_fast",
    "plot_slice_feature_dim_two_actions",
    "plot_surface_two_feature_dims_mean_colored_by_std",
    "m_dth1", "m_dth2", "m_dw1", "m_dw2",
    "Xtr", "Xte",
    "Yth1_tr", "Yth1_te",
    "Yth2_tr", "Yth2_te",
    "Yw1_tr", "Yw1_te",
    "Yw2_tr", "Yw2_te",
]
# Fail fast with the full list of missing names rather than a later NameError.
missing = [k for k in required if k not in globals()]
if len(missing) > 0:
    raise NameError(f"Cell 5 missing required symbols (run Cell 4 first): {missing}")

# ----------------------------
# 1) RMSE helper
# ----------------------------
def rmse_np(yhat, y):
    """Root-mean-square error between two array-likes, compared after
    flattening both to 1-D. Returns a Python float."""
    predicted = np.asarray(yhat).reshape(-1)
    observed = np.asarray(y).reshape(-1)
    squared_error = np.square(predicted - observed)
    return float(np.sqrt(np.mean(squared_error)))

def print_global_rmse_acrobot():
    """Print the test-set RMSE of the four global models (read from module
    globals m_dth1, m_dth2, m_dw1, m_dw2) on Xte against their *_te targets."""
    heads = (
        ("dtheta1", m_dth1, Yth1_te),
        ("dtheta2", m_dth2, Yth2_te),
        ("domega1", m_dw1, Yw1_te),
        ("domega2", m_dw2, Yw2_te),
    )
    # All predictions are made before anything is printed.
    means = [gp_predict_mu_std_fast(model, Xte)[0] for _, model, _ in heads]

    print("=== Test RMSE (global models, ACROBOT) ===")
    for (label, _, target), mean in zip(heads, means):
        print(f"{label}  RMSE: {rmse_np(mean, target):.6f}")

# ----------------------------
# 2) Std-only surface over 2 chosen feature dims (ACROBOT)
# ----------------------------
def plot_surface_two_feature_dims_std_only(
    model,
    X_ref,
    dim_x=0,
    dim_y=1,
    x_min=-1.0, x_max=1.0,
    y_min=-1.0, y_max=1.0,
    n_grid=70,
    u_fixed=+1.0,
    u_dim=-1,
    title="3D Std surface (uncertainty only)",
):
    """
    Show a Plotly surface whose height is the predictive standard deviation
    over two feature dimensions.

    Remaining feature columns are fixed at the per-column nanmedian of X_ref
    and column ``u_dim`` is fixed at ``u_fixed``.
    """
    reference = np.asarray(X_ref, dtype=np.float64)
    median_row = np.nanmedian(reference, axis=0)

    Xg, Yg = np.meshgrid(
        np.linspace(x_min, x_max, n_grid),
        np.linspace(y_min, y_max, n_grid),
    )

    queries = np.tile(median_row[None, :], (n_grid * n_grid, 1))
    queries[:, dim_x] = Xg.ravel()
    queries[:, dim_y] = Yg.ravel()
    queries[:, u_dim] = float(u_fixed)

    std_flat = gp_predict_mu_std_fast(model, queries)[1]

    std_surface = go.Surface(
        x=Xg, y=Yg, z=std_flat.reshape(Xg.shape),
        colorscale="Viridis",
        colorbar=dict(title="Std"),
        opacity=0.98,
        showscale=True,
        name="std surface"
    )
    scene = dict(
        xaxis=dict(title=f"feature[{dim_x}]", range=[x_min, x_max]),
        yaxis=dict(title=f"feature[{dim_y}]", range=[y_min, y_max]),
        zaxis=dict(title="std"),
    )

    fig = go.Figure(data=[std_surface])
    fig.update_layout(
        title=f"{title} | fixed u={u_fixed:+.1f}",
        scene=scene,
        margin=dict(l=0, r=0, b=0, t=50),
        height=650
    )
    fig.show()

# ----------------------------
# 3) One-call global evaluation bundle (REUSABLE after updates)
# ----------------------------
def eval_and_plot_global_acrobot(
    tag="GLOBAL (init)",
    model_for_plots=None,
    y_train_for_plots=None,
    # slice config
    slice_feat_dim=0,
    slice_feat_min=-1.0,
    slice_feat_max=1.0,
    # surface config
    surf_dim_x=0,
    surf_dim_y=1,
    surf_x_min=-1.0, surf_x_max=1.0,
    surf_y_min=-1.0, surf_y_max=1.0,
    # action config
    u_fixed=+1.0,
    u_dim=-1,
):
    """
    Print test RMSE for all four global heads, then draw three plots for one
    chosen model: a slice (mean ±2σ for u=+1/-1), a mean surface coloured by
    std, and a std-only surface.

    ``model_for_plots`` / ``y_train_for_plots`` default to the globals m_dw2 /
    Yw2_tr when left as None.

    Feature-dim legend given for the 7D Acrobot features:
      0: sin(th1)   1: cos(th1)   2: sin(th2)   3: cos(th2)
      4: tanh(w1/s1)   5: tanh(w2/s2)   6: u
    """
    plot_model = m_dw2 if model_for_plots is None else model_for_plots
    plot_targets = Yw2_tr if y_train_for_plots is None else y_train_for_plots

    banner = "=============================="
    print("\n" + banner)
    print(f"GLOBAL EVAL (ACROBOT): {tag}")
    print(banner)
    print_global_rmse_acrobot()

    # Slice: mean ±2σ, compare u=+1/-1
    plot_slice_feature_dim_two_actions(
        model=plot_model,
        X_train=Xtr,
        y_train=plot_targets,
        feat_dim=int(slice_feat_dim),
        feat_min=float(slice_feat_min),
        feat_max=float(slice_feat_max),
        n_grid=260,
        u_dim=int(u_dim),
        u_list=(+1.0, -1.0),
        title=f"{tag} slice: mean ±2σ (u=+1/-1) + inducing",
        y_label="Δy",
        show_data=True,
        show_inducing=True,
    )

    # Both surface plots share the same grid / action configuration.
    surface_kwargs = dict(
        model=plot_model,
        X_ref=Xtr,
        dim_x=int(surf_dim_x),
        dim_y=int(surf_dim_y),
        x_min=float(surf_x_min), x_max=float(surf_x_max),
        y_min=float(surf_y_min), y_max=float(surf_y_max),
        n_grid=70,
        u_fixed=float(u_fixed),
        u_dim=int(u_dim),
    )

    # Surface: mean colored by std
    plot_surface_two_feature_dims_mean_colored_by_std(
        title=f"{tag} surface: mean colored by std (+ inducing)",
        z_label="Δy",
        show_inducing=True,
        **surface_kwargs,
    )

    # Std-only surface
    plot_surface_two_feature_dims_std_only(
        title=f"{tag} surface: std only (uncertainty)",
        **surface_kwargs,
    )

# ----------------------------
# RUN ONCE FOR INITIAL GLOBAL (ACROBOT)
# ----------------------------
# Default visualization:
#   - slice over feature[0]=sin(th1)
#   - surface over feature[0]=sin(th1) vs feature[4]=tanh(w1/s1)
# Runs at cell execution time: prints RMSE for all heads and shows the three
# plots for the domega2 model (m_dw2) with its training targets.
eval_and_plot_global_acrobot(
    tag="GLOBAL (initial)",
    model_for_plots=m_dw2,
    y_train_for_plots=Yw2_tr,
    slice_feat_dim=0,
    surf_dim_x=0,
    surf_dim_y=4,
    u_fixed=+1.0,
    u_dim=-1
)


In [None]:
# ===========================
# Cell 6 — MPPI + Online OSGPR (fixed-size Z_GLOBAL) + Local subset (PALSGP-style) ✅ ACROBOT
#   ✅ Keeps the "NO RETRACING" fix:
#      - One compiled LocalSubsetPredictor per head
#      - Only assign() new tensors when subset changes
# ===========================

import time
import copy
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
import plotly.graph_objects as go
import gpflow
import tensorflow as tf
from scipy.stats import chi2
from gpflow.utilities import parameter_dict, multiple_assign

# ----------------------------
# Enforce float64 (per your pipeline going forward)
# ----------------------------
# Best-effort switch of GPflow's default float; a failure here is not fatal
# by itself because the assertion below verifies the resulting default.
try:
    gpflow.config.set_default_float(np.float64)
except Exception:
    pass
gpflow.config.set_default_jitter(1e-6)
tf.keras.backend.set_floatx("float64")

# Hard stop if GPflow's default float did not end up as float64.
# NOTE(review): `assert` is skipped when Python runs with -O.
DTYPE_TF = gpflow.default_float()
assert DTYPE_TF == tf.float64, f"Expected float64 default, got {DTYPE_TF}. Set gpflow default float to float64."

# ----------------------------
# GPU sanity / enforcement
# ----------------------------
# REQUIRE_GPU: abort this cell when TensorFlow sees no GPU.
# LOG_DEVICE_PLACEMENT: turn on TensorFlow's device-placement logging.
REQUIRE_GPU = True
LOG_DEVICE_PLACEMENT = False

print("TF built with CUDA:", tf.test.is_built_with_cuda())
print("GPUs visible:", tf.config.list_physical_devices("GPU"))
print("Logical GPUs:", tf.config.list_logical_devices("GPU"))

if LOG_DEVICE_PLACEMENT:
    tf.debugging.set_log_device_placement(True)

if REQUIRE_GPU:
    # Rebinds the module-level name `gpus` with the current physical GPU list.
    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) == 0:
        raise RuntimeError("REQUIRE_GPU=True but no GPU is visible to TensorFlow. Fix CUDA/TF install or set REQUIRE_GPU=False.")

# ----------------------------
# Config (Acrobot)
# ----------------------------
# Planner settings. NOTE(review): presumably MPPI horizon length, number of
# sampled rollouts, action-noise std and temperature — confirm against the
# planner code, which is not part of this section.
HORIZON    = 25
K_SAMPLES  = 512
SIGMA      = 0.6
LAMBDA     = 1.0

# Online model-update settings. NOTE(review): presumably update period (in
# env steps), optimiser iterations, learning rate and likelihood noise used
# for each streaming update — confirm against the control loop.
UPDATE_EVERY = 60
ITERS_UPDATE = 200
LR_UPDATE    = 0.02
NOISE_UPDATE = 1e-4

# Rebinds the global that Cell 4 reads via globals().get("M_GLOBAL", 512).
M_GLOBAL = 512

# Local-subset size and anchor settings. ANCHOR_M here replaces the value
# Cell 4 computed as min(24, number of inducing points).
M_SUB           = 72
ANCHOR_M   = 32
ANCHOR_LAM = 1e-6

# Local subset rebuild logic
LOCAL_REBUILD_EVERY   = 10
LOCAL_OVERLAP_THRESH  = 0.70

# Tube selection
TUBE_ALPHA      = 0.99
TUBE_COV_EPS    = 1e-6
TUBE_TIME_BINS  = 8
TUBE_FALLBACK_CANDIDATES = 400

# Recording (RECORD_RGB_DEFAULT is switched off further down when no frame
# renderer is available in the namespace)
RECORD_RGB_DEFAULT = True
RESIZE       = (720, 450)
FPS          = 10
FRAME_STRIDE = 2

# Acrobot success + cost
Y_TIP_GOAL = 1.0
HOLD_STEPS = 60  # smaller than CartPole; tweak as you like

# Exploration schedule (uncertainty bonus)
EXPLORE_STEPS = 1000
UNC_W_MAX     = 15.0
UNC_W_MIN     = 0.0

# Cost weights
TIP_W      = 3.0     # encourage large y_tip
VEL_W      = 0.05    # penalize angular velocities
U_W        = 0.00    # penalize control (weight is 0.00 as configured here)
TERM_BONUS = 5.0     # terminal bonus when y_tip high

# Action bounds: keep values defined by an earlier cell, else default to [-1, +1].
if "U_MIN" not in globals(): U_MIN = -1.0
if "U_MAX" not in globals(): U_MAX = +1.0

PLOT_EACH_UPDATE = False

# Std scale unify (optional; used only for consistent plotting if you store surfaces)
# update_unified_std_scale() returns (STD_CMIN_FIXED, STD_CMAX_FIXED) and
# fills in STD_CMAX_FIXED while it is still None.
STD_CMIN_FIXED = 0.0
STD_FIXED_Q    = 0.5
STD_MODE       = "fixed"
STD_CMAX_FIXED = None

STORE_GLOBAL_SURFACES = False
GLOBAL_SURF_HISTORY = []   # reset per RUN

# ----------------------------
# NEW: Cross-method registry for unified evaluation plots
# ----------------------------
# Create the registry only once so that re-running this cell keeps results
# that were registered earlier.
if "EVAL_REGISTRY" not in globals():
    EVAL_REGISTRY = dict()

def register_method_results(method_name, payload):
    """Store ``payload`` in EVAL_REGISTRY under ``str(method_name)``,
    overwriting any previous entry with the same key."""
    key = str(method_name)
    EVAL_REGISTRY[key] = payload

# ----------------------------
# Safety: require dependencies from earlier cells
# ----------------------------
# Names from earlier cells that this cell relies on (environment helpers,
# feature maps, OSGPR core, the four trained heads and the inducing set).
required = [
    "make_env", "obs_to_state", "wrap_pi", "state_to_features",
    "batch_state_to_features", "OSGPR_VFE", "train_osgpr",
    "extract_summary_from_model", "greedy_dopt_anchors_from_K",
    "clone_kernel",
    "m_dth1", "m_dth2", "m_dw1", "m_dw2",
    "Z_GLOBAL",
]
missing = [k for k in required if k not in globals()]
if len(missing) > 0:
    raise NameError(f"Cell 6 missing required symbols from earlier cells: {missing}")

# The frame renderer is optional: without it only RGB recording is disabled.
if "render_acrobot_frame_from_state" not in globals():
    print("⚠️ render_acrobot_frame_from_state not found. RGB recording will be disabled.")
    RECORD_RGB_DEFAULT = False

# ----------------------------
# Acrobot tip height proxy
# ----------------------------
def acrobot_tip_height_proxy_np(th1, th2):
    """Return -cos(th1) - cos(th1 + th2) as a Python float (scalar inputs)."""
    first_link = np.cos(th1)
    second_link = np.cos(th1 + th2)
    return float(-first_link - second_link)

@tf.function
def acrobot_tip_height_proxy_tf(th1, th2):
    """TensorFlow counterpart of the tip-height proxy:
    -(cos(th1) + cos(th1 + th2)), evaluated element-wise."""
    first_link = tf.cos(th1)
    second_link = tf.cos(th1 + th2)
    return -(first_link + second_link)

# ----------------------------
# Helper: sym_jitter (numpy)
# ----------------------------
def sym_jitter(A, jitter=1e-6):
    """Return the symmetric part of A, (A + A^T) / 2, plus ``jitter`` times
    the identity, as a float64 array. The input is not modified."""
    mat = np.asarray(A, dtype=np.float64)
    sym = (mat + mat.T) * 0.5
    ridge = float(jitter) * np.eye(mat.shape[0], dtype=np.float64)
    return sym + ridge

# ----------------------------
# Global predict helpers (mean/std)
# ----------------------------
def gp_predict_mu_std_fast(model, X):
    """Predict at X and return ``(mean, std)`` as flat NumPy arrays.

    ``predict_f_cached`` is preferred when the model provides it; otherwise
    ``predict_f`` is used. Variances are clamped to at least 1e-12 before
    taking the square root.
    """
    points = np.asarray(X, dtype=np.float64)
    method_name = "predict_f_cached" if hasattr(model, "predict_f_cached") else "predict_f"
    mean_tf, var_tf = getattr(model, method_name)(points, full_cov=False)

    mean = mean_tf.numpy().reshape(-1)
    clamped_var = np.maximum(var_tf.numpy().reshape(-1), 1e-12)
    return mean, np.sqrt(clamped_var)

def update_unified_std_scale(std_list, q=0.99, mode="fixed"):
    """Update and return the shared colour range for std plots.

    The upper bound is the module-level ``STD_CMAX_FIXED``; the lower bound is
    ``STD_CMIN_FIXED``.

    Args:
        std_list: Iterable of std arrays; ``None`` entries are ignored.
        q: Quantile of the pooled std values used as the candidate upper bound
            (floored at 1e-8).
        mode: With ``"grow_only"`` an existing bound is raised to the candidate
            when the candidate is larger. With any other value the bound is
            set only once (while it is still ``None``) and then kept.

    Returns:
        ``(STD_CMIN_FIXED, STD_CMAX_FIXED)``. When no std values are supplied
        an unset upper bound falls back to 1.0.
    """
    global STD_CMAX_FIXED
    chunks = [np.asarray(s, dtype=np.float64).reshape(-1) for s in std_list if s is not None]
    # np.concatenate raises on an empty sequence, so the "no data" fallback
    # below was unreachable for an empty / all-None std_list.
    if chunks:
        all_std = np.concatenate(chunks, axis=0)
    else:
        all_std = np.zeros((0,), dtype=np.float64)

    if all_std.size == 0:
        if STD_CMAX_FIXED is None:
            STD_CMAX_FIXED = 1.0
        return (STD_CMIN_FIXED, STD_CMAX_FIXED)

    cand = float(np.quantile(all_std, q))
    cand = max(cand, 1e-8)

    if STD_CMAX_FIXED is None:
        STD_CMAX_FIXED = cand
    elif mode == "grow_only":
        STD_CMAX_FIXED = max(STD_CMAX_FIXED, cand)

    return (STD_CMIN_FIXED, STD_CMAX_FIXED)

# ============================================================
# CLEAN multi-head Z_GLOBAL refit (ACROBOT): aggregate kernels across 4 heads
# ============================================================
def refit_Z_global_multihead(Z_global, Xnew, M_global, lam=1e-6, mode="mean", normalize_traces=False):
    """
    Re-select a fixed-size global inducing set from the current inducing
    inputs plus newly observed inputs.

    The candidate pool is ``vstack([Z_global, Xnew])``. The kernels of the
    four global heads (m_dth1, m_dth2, m_dw1, m_dw2) are evaluated on the
    pool, optionally divided by their trace, and summed (``mode == "sum"``)
    or averaged (any other mode). greedy_dopt_anchors_from_K then picks up
    to ``M_global`` rows from the aggregate.

    Returns the selected candidate rows as a float64 array (at most
    ``M_global`` rows).
    """
    current_Z = np.asarray(Z_global, dtype=np.float64)
    fresh_X = np.asarray(Xnew, dtype=np.float64)

    candidates = np.vstack([current_Z, fresh_X]).astype(np.float64, copy=False)
    cand_tf = tf.convert_to_tensor(candidates, dtype=tf.float64)

    grams = [
        head.kernel.K(cand_tf).numpy().astype(np.float64, copy=False)
        for head in (m_dth1, m_dth2, m_dw1, m_dw2)
    ]
    if normalize_traces:
        # Put the heads on a comparable scale before aggregating.
        grams = [G / max(float(np.trace(G)), 1e-12) for G in grams]

    K_agg = grams[0] + grams[1] + grams[2] + grams[3]
    if mode != "sum":
        K_agg = K_agg / 4.0

    budget = int(M_global)
    picked = greedy_dopt_anchors_from_K(K_agg, m_anchors=budget, lam=float(lam))
    Z_new = np.asarray(candidates[np.asarray(picked, dtype=np.int64)], dtype=np.float64)

    if Z_new.shape[0] != budget:
        Z_new = Z_new[:budget].copy()
    return Z_new

# ============================================================
# Online OSGPR update (unchanged)
# ============================================================
def osgpr_stream_update(model_old, X_new, Y_new, Z_new,
                        iters=100, lr=0.02, noise=1e-4,
                        freeze_kernel=False, clip_norm=10.0):
    """
    Streaming update: build and train a new OSGPR_VFE model from the old
    model's posterior summary plus a new data batch.

    Args:
        model_old: Previously trained model; summarised with
            extract_summary_from_model and its kernel is cloned.
        X_new: New inputs (converted to float64).
        Y_new: New targets (converted to float64 and reshaped to (N, 1)).
        Z_new: Inducing inputs for the new model.
        iters, lr, clip_norm: Forwarded to train_osgpr.
        noise: Value assigned to the likelihood variance before training.
        freeze_kernel: If True, all kernel parameters are made non-trainable.

    Returns:
        ``(model_new, info)`` where ``info`` holds ``train_seconds``,
        ``last_neg_objective`` and ``M_new`` (number of inducing points).
    """
    X_new = np.asarray(X_new, dtype=np.float64)
    Y_new = np.asarray(Y_new, dtype=np.float64).reshape(-1, 1)

    mu_old, Su_old, Kaa_old, Z_old = extract_summary_from_model(model_old)
    Z_new = np.asarray(Z_new, dtype=np.float64)

    # Clone so the new model does not share kernel variables with the old one.
    k_new = clone_kernel(model_old.kernel)

    m = OSGPR_VFE(
        data=(X_new, Y_new),
        kernel=k_new,
        mu_old=mu_old, Su_old=Su_old, Kaa_old=Kaa_old, Z_old=Z_old,
        Z=Z_new
    )
    m.likelihood.variance.assign(float(noise))

    if freeze_kernel:
        # Assigning ``param.trainable = False`` inside a swallow-all try/except
        # could fail silently and leave the kernel trainable. set_trainable is
        # GPflow's documented way to toggle trainability and works for any
        # kernel, not only ones with ``variance``/``lengthscales``.
        from gpflow.utilities import set_trainable
        set_trainable(m.kernel, False)

    t_train, last_loss = train_osgpr(m, iters=iters, lr=lr, clip_norm=clip_norm)
    m.build_predict_cache()

    info = dict(
        train_seconds=float(t_train),
        last_neg_objective=float(last_loss),
        M_new=int(m.inducing_variable.num_inducing),
    )
    return m, info

# ============================================================
# Anchors via aggregated kernel on Z_GLOBAL (ACROBOT)
# ============================================================
def compute_anchor_idx_dopt_from_Zglobal_multihead(Z_global, m_anchors=18, lam=1e-6, normalize_traces=True):
    """
    Pick anchor indices into Z_global using the kernels of all four heads.

    Each head's kernel matrix on Z_global is optionally divided by its trace,
    the four matrices are averaged, and greedy_dopt_anchors_from_K selects
    ``m_anchors`` indices from the average.
    """
    inducing = np.asarray(Z_global, dtype=np.float64)
    inducing_tf = tf.convert_to_tensor(inducing, dtype=tf.float64)

    grams = [
        head.kernel.K(inducing_tf).numpy().astype(np.float64, copy=False)
        for head in (m_dth1, m_dth2, m_dw1, m_dw2)
    ]
    if normalize_traces:
        # Put the heads on a comparable scale before averaging.
        grams = [G / max(float(np.trace(G)), 1e-12) for G in grams]

    K_mean = (grams[0] + grams[1] + grams[2] + grams[3]) / 4.0
    return greedy_dopt_anchors_from_K(K_mean, m_anchors=int(m_anchors), lam=float(lam))

# ============================================================
# Chi-square Mahalanobis tube selection (unchanged)
# ============================================================
def normalize_nonnegative_weights(w, eps=1e-12):
    """
    Flatten `w`, clip negatives to zero and rescale so the entries sum to one.

    If the clipped sum is below `eps`, a uniform weight vector of the same
    length is returned instead.
    """
    clipped = np.maximum(np.asarray(w, dtype=np.float64).reshape(-1), 0.0)
    total = float(np.sum(clipped))
    if total < eps:
        # Degenerate weights: fall back to equal weighting.
        uniform = np.ones_like(clipped)
        return uniform / max(len(clipped), 1)
    return clipped / total

def compute_weighted_moments(rollout_inputs, rollout_weights):
    """
    Weighted per-timestep mean and covariance of a batch of rollouts.

    rollout_inputs is unpacked as a 3-D array (K, H, D); rollout_weights is
    normalized with normalize_nonnegative_weights and reshaped to (K, 1, 1).
    Returns (tube_mean, tube_cov) with shapes (H, D) and (H, D, D).
    """
    samples = np.asarray(rollout_inputs, dtype=np.float64)  # (K,H,D)
    num_rollouts, _, _ = samples.shape
    weights = normalize_nonnegative_weights(rollout_weights).reshape(num_rollouts, 1, 1)

    tube_mean = (weights * samples).sum(axis=0)  # (H,D)
    centered = samples - tube_mean[None, :, :]
    # Weighted outer products summed over the rollout axis -> (H,D,D).
    tube_cov = np.einsum("khd,khe->hde", weights * centered, centered)
    return tube_mean, tube_cov

def min_mahalanobis_and_argmin(points_Z, tube_mean, tube_cov, cov_eps=1e-6):
    """
    For every point, find the smallest squared Mahalanobis distance to the tube.

    points_Z is (M, D), tube_mean is (H, D) and tube_cov is indexed per
    timestep as a (D, D) matrix. Each covariance gets `cov_eps` added on the
    diagonal before its Cholesky factorization.

    Returns (dmin, tmin): the minimum squared distance per point (float64, (M,))
    and the timestep index at which it is attained (int64, (M,)).
    """
    pts = np.asarray(points_Z, dtype=np.float64)
    means = np.asarray(tube_mean, dtype=np.float64)
    covs = np.asarray(tube_cov, dtype=np.float64)

    num_pts, dim = pts.shape
    ridge = float(cov_eps) * np.eye(dim, dtype=np.float64)

    best_d2 = np.full((num_pts,), np.inf, dtype=np.float64)
    best_t = np.zeros((num_pts,), dtype=np.int64)

    for step in range(means.shape[0]):
        chol = np.linalg.cholesky(covs[step] + ridge)

        resid = (pts - means[step:step + 1, :]).T  # (D,M)
        # Two solves against the factor apply the inverse covariance to resid.
        half = np.linalg.solve(chol, resid)
        whitened = np.linalg.solve(chol.T, half)
        d2 = np.sum(resid * whitened, axis=0)  # (M,)

        # Strict '<' keeps the earliest timestep on ties.
        improved = d2 < best_d2
        best_d2[improved] = d2[improved]
        best_t[improved] = step

    return best_d2, best_t

def split_even_quota(total, num_bins):
    """
    Split `total` into `num_bins` integer quotas that differ by at most one.

    `num_bins` is floored at 1. Any remainder is handed out one unit each to
    the leading bins. Returns an int64 array of length num_bins.
    """
    bins = int(max(1, num_bins))
    share, extra = divmod(total, bins)
    quotas = np.full((bins,), share, dtype=np.int64)
    if extra > 0:
        quotas[:extra] += 1
    return quotas

def chi2_radius_sq(alpha, D):
    """
    Chi-square quantile at probability `alpha` with `D` degrees of freedom.

    `alpha` is clamped to [1e-6, 1 - 1e-12] before the lookup, so the result
    is always finite. Returned as a Python float.
    """
    level = min(max(float(alpha), 1e-6), 1.0 - 1e-12)
    dof = int(D)
    return float(chi2.ppf(level, df=dof))

def greedy_dopt_select_from_kernel(K, k, jitter=1e-6, log_eps=1e-300):
    """
    Greedily choose up to `k` row indices of the Gram matrix `K`.

    The first pick is the index with the largest jittered diagonal. Every
    later pick maximizes the log of its residual (Schur-complement) variance
    given the indices already chosen, with residuals floored at `log_eps`.
    A lower-triangular factor of the selected sub-block is grown by one row
    per pick.

    Returns an int64 array of the chosen indices in selection order; it is
    empty when min(k, n) <= 0.
    """
    gram = np.asarray(K, dtype=np.float64)
    budget = int(min(k, gram.shape[0]))
    if budget <= 0:
        return np.zeros((0,), dtype=np.int64)

    prior_var = np.diag(gram).copy() + float(jitter)

    # Seed the selection with the largest prior variance.
    first = int(np.argmax(np.log(np.maximum(prior_var, log_eps))))
    picked = [first]
    chol = np.sqrt(max(prior_var[first], log_eps)).reshape(1, 1)

    for size in range(1, budget):
        sel = np.array(picked, dtype=np.int64)

        # Residual variance of every candidate given the current selection.
        proj = np.linalg.solve(chol, gram[:, sel].T)
        resid = prior_var - np.sum(proj * proj, axis=0)
        resid[sel] = -np.inf

        gains = np.log(np.maximum(resid, log_eps))
        gains[sel] = -np.inf  # already-chosen indices can never win again

        nxt = int(np.argmax(gains))
        picked.append(nxt)

        # Append one row to the factor for the newly selected index.
        row = np.linalg.solve(chol, gram[nxt, sel].reshape(1, -1).T)
        pivot_sq = max(prior_var[nxt] - float(np.sum(row * row)), log_eps)

        grown = np.zeros((size + 1, size + 1), dtype=np.float64)
        grown[:size, :size] = chol
        grown[size, :size] = row.reshape(-1)
        grown[size, size] = np.sqrt(pivot_sq)
        chol = grown

    return np.array(picked, dtype=np.int64)

def select_tube_subset(
    Z_global,
    kernel_for_dopt,
    rollout_inputs,
    rollout_weights,
    total_subset_size=64,
    anchor_idx=None,
    time_bins=16,
    cov_eps=1e-6,
    dopt_jitter=1e-6,
    fallback_candidates=400,
    alpha=0.99,
):
    """
    Choose a subset of rows of Z_global that lie close to the rollout "tube".

    Stages:
      1. Anchors (`anchor_idx`, deduplicated and range-checked) are always kept.
      2. Weighted per-timestep mean/covariance of the rollouts define the tube.
      3. Candidates are the points whose minimum squared Mahalanobis distance to
         any tube timestep is within the chi-square radius for level `alpha`.
         If there are fewer candidates than the remaining budget, the nearest
         max(fallback_candidates, remaining_budget) points (capped at M) are
         used instead.
      4. Candidates are grouped by their closest timestep into at most
         `time_bins` bins; each bin gets an even share of the remaining budget
         and its points are picked by greedy D-optimal selection on
         `kernel_for_dopt`.
      5. Any shortfall is topped up first from unused candidates (D-optimal),
         then from all remaining points in order of increasing distance.

    Returns:
      (subset_idx, tube_mean, tube_cov, tube_candidates_idx, chi2_thr) where
      subset_idx is a sorted int64 array truncated to `total_subset_size`.
    """
    Zg = np.asarray(Z_global, dtype=np.float64)
    M = Zg.shape[0]

    # Sanitize anchors: unique, and inside [0, M).
    if anchor_idx is None:
        anchor_idx = np.zeros((0,), dtype=np.int64)
    anchor_idx = np.unique(np.asarray(anchor_idx, dtype=np.int64))
    anchor_idx = anchor_idx[(anchor_idx >= 0) & (anchor_idx < M)]

    tube_mean, tube_cov = compute_weighted_moments(rollout_inputs, rollout_weights)
    H, D = tube_mean.shape

    chi2_thr = chi2_radius_sq(alpha, D)

    # dmin: closest squared distance to the tube; tmin: timestep where it occurs.
    dmin, tmin = min_mahalanobis_and_argmin(Zg, tube_mean, tube_cov, cov_eps=cov_eps)
    tube_candidates_idx = np.where(dmin <= chi2_thr)[0].astype(np.int64)

    remaining_budget = int(total_subset_size - anchor_idx.size)
    remaining_budget = max(0, remaining_budget)

    # Too few points inside the tube: fall back to the L nearest points overall.
    if tube_candidates_idx.size < remaining_budget:
        L = int(min(max(fallback_candidates, remaining_budget), M))
        tube_candidates_idx = np.argsort(dmin)[:L].astype(np.int64)

    # Anchors are already selected, so they are removed from the candidate pool.
    tube_pool = tube_candidates_idx[~np.isin(tube_candidates_idx, anchor_idx)]
    tmin_pool = tmin[tube_pool]

    selected = set(int(i) for i in anchor_idx)

    if remaining_budget > 0 and tube_pool.size > 0:
        B = int(max(1, min(time_bins, H)))
        edges = np.linspace(0, H, B + 1, dtype=np.int64)
        quotas = split_even_quota(remaining_budget, B)

        for b in range(B):
            if len(selected) >= total_subset_size:
                break

            # Half-open timestep range [a, c) covered by this bin.
            a, c = int(edges[b]), int(edges[b + 1])
            if c <= a:
                continue

            in_bin = (tmin_pool >= a) & (tmin_pool < c)
            bin_pool = tube_pool[in_bin]
            if bin_pool.size == 0:
                continue

            # Never exceed the bin quota, the bin population, or the global budget.
            k_b = int(min(quotas[b], bin_pool.size, total_subset_size - len(selected)))
            if k_b <= 0:
                continue

            Zb = Zg[bin_pool]
            Kb = kernel_for_dopt.K(tf.convert_to_tensor(Zb, dtype=tf.float64)).numpy().astype(np.float64)
            pick = greedy_dopt_select_from_kernel(Kb, k=k_b, jitter=dopt_jitter)

            # `pick` indexes into bin_pool; map back to global indices.
            for gidx in bin_pool[pick]:
                selected.add(int(gidx))

    # Top-up 1: unused quota (e.g. from empty bins) is spent on leftover candidates.
    if len(selected) < total_subset_size:
        need = int(total_subset_size - len(selected))
        leftover = tube_pool[~np.isin(tube_pool, np.array(list(selected), dtype=np.int64))]
        if leftover.size > 0 and need > 0:
            Zl = Zg[leftover]
            Kl = kernel_for_dopt.K(tf.convert_to_tensor(Zl, dtype=tf.float64)).numpy().astype(np.float64)
            pick = greedy_dopt_select_from_kernel(Kl, k=min(need, leftover.size), jitter=dopt_jitter)
            for gidx in leftover[pick]:
                selected.add(int(gidx))

    # Top-up 2: still short -> take the nearest not-yet-selected points of Z_global.
    if len(selected) < total_subset_size:
        need = int(total_subset_size - len(selected))
        remaining = np.setdiff1d(np.arange(M, dtype=np.int64), np.array(list(selected), dtype=np.int64), assume_unique=False)
        if remaining.size > 0:
            order = remaining[np.argsort(dmin[remaining])]
            for gidx in order[:need]:
                selected.add(int(gidx))

    # NOTE(review): if there are more anchors than total_subset_size, this
    # truncation drops the largest indices after sorting — confirm that is intended.
    subset_idx = np.array(sorted(selected), dtype=np.int64)[:total_subset_size]
    return subset_idx, tube_mean, tube_cov, tube_candidates_idx, chi2_thr

# ============================================================
# TF local predictor pack (NO RETRACING) ✅ ACROBOT
# ============================================================
@tf.function
def wrap_pi_tf(theta):
    """Wrap angles into [-pi, pi) using a floor-modulo, in DTYPE_TF precision."""
    half_turn = tf.constant(np.pi, dtype=DTYPE_TF)
    full_turn = tf.constant(2.0 * np.pi, dtype=DTYPE_TF)
    shifted = theta + half_turn
    return tf.math.floormod(shifted, full_turn) - half_turn

@tf.function
def batch_state_to_features_tf(S, U):
    """
    Map a batch of states and actions to the 7-D bounded feature vector.

    S: (B,4) = [th1, th2, w1, w2]
    U: (B,)
    -> (B,7) = [sin(th1), cos(th1), sin(th2), cos(th2), tanh(w1/8), tanh(w2/10), u]
    """
    angle1, angle2 = S[:, 0], S[:, 1]
    rate1, rate2 = S[:, 2], S[:, 3]

    # Velocity squashing scales.
    rate1_scale = tf.constant(8.0, dtype=DTYPE_TF)
    rate2_scale = tf.constant(10.0, dtype=DTYPE_TF)

    columns = [
        tf.sin(angle1),
        tf.cos(angle1),
        tf.sin(angle2),
        tf.cos(angle2),
        tf.tanh(rate1 / rate1_scale),
        tf.tanh(rate2 / rate2_scale),
        U,
    ]
    return tf.stack(columns, axis=1)

@tf.function
def se_ard_kernel_Kzx_tf(Z, X, lengthscales, variance):
    """
    Squared-exponential ARD cross-kernel between the rows of Z and the rows of X.

    Inputs are divided by the per-dimension lengthscales, squared distances
    are formed from the expanded quadratic and clamped at zero, and the result
    is variance * exp(-0.5 * r^2) with one row per Z point and one column per
    X point.
    """
    scale_row = tf.reshape(lengthscales, (1, -1))
    z_scaled = Z / scale_row
    x_scaled = X / scale_row

    z_sq = tf.reduce_sum(z_scaled * z_scaled, axis=1, keepdims=True)
    x_sq = tf.reduce_sum(x_scaled * x_scaled, axis=1, keepdims=True)
    cross = tf.matmul(z_scaled, x_scaled, transpose_b=True)

    # Round-off can push the expanded form slightly negative; clamp it.
    sq_dist = tf.maximum(z_sq + tf.transpose(x_sq) - 2.0 * cross, 0.0)
    amplitude = tf.cast(variance, DTYPE_TF)
    return amplitude * tf.exp(-0.5 * sq_dist)

class LocalSubsetPredictor(tf.Module):
    """
    Sparse-GP predictor over a fixed-size subset of inducing points.

    All quantities live in non-trainable float64 tf.Variables of fixed shape
    (set by Dfeat and Msub), so a new subset is loaded by assigning values via
    assign_pack rather than by building a new object.

    Variables:
      Z     (Msub, Dfeat)  inducing inputs
      L     (Msub, Msub)   lower-triangular factor used in the triangular solves
      alpha (Msub, 1)      weights such that the mean is Kzx^T @ alpha
      S     (Msub, Msub)   covariance term added back into the variance
      ls    (Dfeat,)       kernel lengthscales
      var   ()             kernel variance
    """

    def __init__(self, Dfeat, Msub, name=None):
        super().__init__(name=name)
        self.Dfeat = int(Dfeat)
        self.Msub = int(Msub)
        n_pts, n_dim = self.Msub, self.Dfeat

        def _frozen(initial_value):
            return tf.Variable(initial_value, trainable=False)

        self.Z = _frozen(tf.zeros([n_pts, n_dim], dtype=tf.float64))
        self.L = _frozen(tf.eye(n_pts, dtype=tf.float64))
        self.alpha = _frozen(tf.zeros([n_pts, 1], dtype=tf.float64))
        self.S = _frozen(tf.eye(n_pts, dtype=tf.float64))

        self.ls = _frozen(tf.ones([n_dim], dtype=tf.float64))
        self.var = tf.Variable(1.0, dtype=tf.float64, trainable=False)

    def assign_pack(self, Z, L, alpha, S, ls, var):
        """Overwrite every variable with the given values, cast to float64."""
        updates = (
            (self.Z, Z),
            (self.L, L),
            (self.alpha, alpha),
            (self.S, S),
            (self.ls, ls),
            (self.var, var),
        )
        for slot, value in updates:
            slot.assign(tf.cast(value, tf.float64))

    @tf.function(reduce_retracing=True)
    def predict_mu_var(self, Xfeat):
        """
        Predict mean and variance at the rows of Xfeat.

        Returns (mu, v), both of shape (N,). The variance is
        var - diag(Kxz W) + diag(W^T S W), floored at 1e-12, where W is the
        result of the two triangular solves of Kzx against L.
        """
        queries = tf.cast(Xfeat, tf.float64)
        k_zx = se_ard_kernel_Kzx_tf(self.Z, queries, self.ls, self.var)  # (M,N)

        # Forward then backward triangular solve against L.
        forward = tf.linalg.triangular_solve(self.L, k_zx, lower=True)
        solved = tf.linalg.triangular_solve(tf.transpose(self.L), forward, lower=False)

        mu = tf.reshape(tf.matmul(k_zx, self.alpha, transpose_a=True), (-1,))

        prior_diag = tf.fill((tf.shape(queries)[0],), tf.cast(self.var, tf.float64))
        explained_diag = tf.reduce_sum(k_zx * solved, axis=0)
        s_solved = tf.matmul(self.S, solved)
        posterior_diag = tf.reduce_sum(solved * s_solved, axis=0)

        v = tf.maximum(prior_diag - explained_diag + posterior_diag, tf.cast(1e-12, tf.float64))
        return mu, v

def build_local_pack_from_global_tf(model_global, Z_global, idx_sub):
    """
    Build the (Z, L, alpha, S, ls, var) pack for LocalSubsetPredictor.assign_pack.

    The global model's posterior mean and full covariance are evaluated at the
    selected rows of Z_global. Kzz is rebuilt from the model's kernel
    hyperparameters, symmetrized, jittered by 1e-6 and Cholesky-factorized;
    alpha is the posterior mean solved against that factor twice.
    """
    rows = np.asarray(idx_sub, dtype=np.int64)
    z_sub = np.asarray(Z_global, dtype=np.float64)[rows].copy()
    z_tensor = tf.convert_to_tensor(z_sub, dtype=tf.float64)

    mean_z, cov_z = model_global.predict_f(z_tensor, full_cov=True)
    if len(cov_z.shape) == 3:
        # Drop the leading axis when the covariance comes back 3-D.
        cov_z = cov_z[0]

    lengthscales = tf.cast(model_global.kernel.lengthscales, tf.float64)
    variance = tf.cast(model_global.kernel.variance, tf.float64)

    k_zz = se_ard_kernel_Kzx_tf(z_tensor, z_tensor, lengthscales, variance)
    ridge = tf.cast(1e-6, tf.float64)
    k_zz = 0.5 * (k_zz + tf.transpose(k_zz)) + ridge * tf.eye(tf.shape(k_zz)[0], dtype=tf.float64)
    chol = tf.linalg.cholesky(k_zz)

    forward = tf.linalg.triangular_solve(chol, mean_z, lower=True)
    alpha = tf.linalg.triangular_solve(tf.transpose(chol), forward, lower=False)

    return (z_tensor, chol, alpha, cov_z, tf.reshape(lengthscales, (-1,)), variance)

# ============================================================
# MPPI cost (Acrobot)
# ============================================================
@tf.function
def exploration_weight_tf(t):
    """
    Linearly decaying exploration weight.

    Equals UNC_W_MAX at t = 0 and decays to UNC_W_MIN once t reaches
    EXPLORE_STEPS (the divisor is floored at 1).
    """
    w_lo = tf.cast(UNC_W_MIN, DTYPE_TF)
    w_hi = tf.cast(UNC_W_MAX, DTYPE_TF)
    span = tf.maximum(tf.cast(EXPLORE_STEPS, DTYPE_TF), 1.0)
    remaining = tf.clip_by_value(1.0 - tf.cast(t, DTYPE_TF) / span, 0.0, 1.0)
    return w_lo + (w_hi - w_lo) * remaining

@tf.function
def stage_cost_acrobot_tf(S, U, unc_bonus=None, unc_w=0.0):
    """
    Per-sample running cost for a batch of states S (B,4) and actions U (B,).

    cost = -TIP_W * tip_height + VEL_W * (w1^2 + w2^2) + U_W * u^2
    and, when an uncertainty bonus is supplied with a positive weight,
    minus unc_w * unc_bonus.
    """
    tip = acrobot_tip_height_proxy_tf(S[:, 0], S[:, 1])
    speed_sq = tf.square(S[:, 2]) + tf.square(S[:, 3])

    height_term = -tf.cast(TIP_W, DTYPE_TF) * tip
    velocity_term = tf.cast(VEL_W, DTYPE_TF) * speed_sq
    effort_term = tf.cast(U_W, DTYPE_TF) * tf.square(U)
    cost = height_term + velocity_term + effort_term

    if (unc_bonus is not None) and (unc_w > 0.0):
        cost = cost - tf.cast(unc_w, DTYPE_TF) * tf.cast(unc_bonus, DTYPE_TF)
    return cost

@tf.function
def terminal_cost_acrobot_tf(S):
    """Terminal cost: -TERM_BONUS where the tip-height proxy reaches Y_TIP_GOAL, else 0."""
    tip = acrobot_tip_height_proxy_tf(S[:, 0], S[:, 1])
    reached = tip >= tf.cast(Y_TIP_GOAL, DTYPE_TF)
    reward = -tf.cast(TERM_BONUS, DTYPE_TF) * tf.ones_like(tip)
    return tf.where(reached, reward, tf.zeros_like(tip))

# ============================================================
# Global-local dynamics: uses persistent predictors (NO RETRACING)
# ============================================================
@tf.function(reduce_retracing=True)
def gp_dynamics_step_batch_local_tf(S, U, local_dth1, local_dth2, local_dw1, local_dw2):
    """
    Advance a batch of states by one step using the four local GP heads.

    Each head predicts the mean increment of one state component. The two
    angles are wrapped after the increment; the velocities are not.
    Returns (next_states (B,4), Xfeat) where Xfeat is the feature batch that
    was fed to the heads.
    """
    Xfeat = batch_state_to_features_tf(S, U)

    # Only the predictive means are used here; variances are discarded.
    heads = (local_dth1, local_dth2, local_dw1, local_dw2)
    d_th1, d_th2, d_w1, d_w2 = [head.predict_mu_var(Xfeat)[0] for head in heads]

    next_columns = [
        wrap_pi_tf(S[:, 0] + d_th1),
        wrap_pi_tf(S[:, 1] + d_th2),
        S[:, 2] + d_w1,
        S[:, 3] + d_w2,
    ]
    return tf.stack(next_columns, axis=1), Xfeat

@tf.function(reduce_retracing=True)
def rollout_tube_features_local_tf(state0, u_seq, local_dth1, local_dth2, local_dw1, local_dw2):
    """
    Roll a single state forward along u_seq and record the features visited.

    At each step the (state, action) feature vector is stored and the state is
    then advanced with the local GP mean dynamics. Returns the stacked
    features, one row per step of u_seq.
    """
    num_steps = tf.shape(u_seq)[0]

    def _not_done(step, state, feats):
        return step < num_steps

    def _advance(step, state, feats):
        # The dynamics step works on batches, so use a batch of one.
        state_row = tf.expand_dims(state, axis=0)
        action_row = tf.expand_dims(u_seq[step], axis=0)
        feat_row = batch_state_to_features_tf(state_row, action_row)[0]
        feats = feats.write(step, feat_row)
        next_row, _ = gp_dynamics_step_batch_local_tf(
            state_row, action_row,
            local_dth1, local_dth2, local_dw1, local_dw2
        )
        return step + 1, next_row[0], feats

    loop_vars = [
        tf.constant(0),
        tf.identity(state0),
        tf.TensorArray(dtype=DTYPE_TF, size=num_steps),
    ]
    _, _, feats = tf.while_loop(_not_done, _advance, loop_vars, parallel_iterations=1)
    return feats.stack()

@tf.function(reduce_retracing=True)
def mppi_plan_gpu_local_tf(state0, u_mean0, t_global,
                           local_dth1, local_dth2, local_dw1, local_dw2,
                           horizon=HORIZON, K=K_SAMPLES, sigma=SIGMA, lam=LAMBDA,
                           base_seed=0):
    """
    One MPPI planning iteration over the local GP dynamics.

    K action sequences of length `horizon` are sampled around `u_mean0` with
    Gaussian noise of std `sigma`, clipped to [U_MIN, U_MAX], rolled out through
    the four local heads, and scored with the stage and terminal costs. The mean
    sequence is then updated with exponentially weighted noise (temperature
    `lam`).

    Args:
      state0: initial state, indexed as (4,) = [th1, th2, w1, w2].
      u_mean0: current mean action sequence, broadcast against (K, H) noise.
      t_global: step counter; used both as the second seed component and as the
        input of the exploration-weight schedule.
      local_dth1..local_dw2: per-head predictors exposing predict_mu_var.
      horizon, K, sigma, lam: MPPI hyperparameters (defaults are bound to the
        module constants at definition time).
      base_seed: first component of the stateless RNG seed.

    Returns:
      (u_first, u_mean, tubeX, unc_w, Xroll, w):
        u_first  first action of the updated mean sequence
        u_mean   updated, clipped mean sequence (H,)
        tubeX    features along the rollout of u_mean (H, Dfeat)
        unc_w    exploration weight used for this plan
        Xroll    features of all sampled rollouts (K, H, Dfeat)
        w        unnormalized rollout weights (K,)
    """
    H = tf.cast(horizon, tf.int32)
    Kt = tf.cast(K, tf.int32)

    # Stateless RNG: the same (base_seed, t_global) pair always yields the same noise.
    seed = tf.stack([tf.cast(base_seed, tf.int32), tf.cast(t_global, tf.int32)], axis=0)
    eps = tf.random.stateless_normal((Kt, H), seed=seed, mean=0.0, stddev=tf.cast(sigma, DTYPE_TF), dtype=DTYPE_TF)
    U = tf.clip_by_value(u_mean0[None, :] + eps, tf.cast(U_MIN, DTYPE_TF), tf.cast(U_MAX, DTYPE_TF))

    # Every sample starts from the same initial state.
    S = tf.tile(state0[None, :], [Kt, 1])
    total_cost = tf.zeros((Kt,), dtype=DTYPE_TF)

    Xta = tf.TensorArray(dtype=DTYPE_TF, size=H)
    unc_w = exploration_weight_tf(t_global)

    t = tf.constant(0, dtype=tf.int32)
    def cond(t, S, total_cost, Xta): return t < H
    def body(t, S, total_cost, Xta):
        Ut = U[:, t]
        S2, Xfeat = gp_dynamics_step_batch_local_tf(S, Ut, local_dth1, local_dth2, local_dw1, local_dw2)
        Xta = Xta.write(t, Xfeat)

        # uncertainty bonus: use dw2 variance (can swap to dw1 if you like)
        _, unc_v = local_dw2.predict_mu_var(Xfeat)
        # Stage cost is evaluated on the pre-step state S, not on S2.
        total_cost = total_cost + stage_cost_acrobot_tf(S, Ut, unc_bonus=unc_v, unc_w=unc_w)
        return t+1, S2, total_cost, Xta

    _, S, total_cost, Xta = tf.while_loop(cond, body, [t, S, total_cost, Xta], parallel_iterations=1)
    total_cost = total_cost + terminal_cost_acrobot_tf(S)

    # Subtracting the minimum cost keeps the exponent <= 0 (best sample gets weight 1).
    cmin = tf.reduce_min(total_cost)
    w = tf.exp(-(total_cost - cmin) / tf.cast(lam, DTYPE_TF))
    wsum = tf.reduce_sum(w) + tf.cast(1e-12, DTYPE_TF)

    # NOTE(review): the update averages the raw noise `eps`, while the rollouts
    # used the clipped controls U — confirm this mismatch is intended.
    u_mean = u_mean0 + tf.reduce_sum(w[:, None] * eps, axis=0) / wsum
    u_mean = tf.clip_by_value(u_mean, tf.cast(U_MIN, DTYPE_TF), tf.cast(U_MAX, DTYPE_TF))

    tubeX = rollout_tube_features_local_tf(state0, u_mean, local_dth1, local_dth2, local_dw1, local_dw2)
    XHKD = Xta.stack()
    Xroll = tf.transpose(XHKD, perm=[1, 0, 2])  # (K,H,Dfeat)
    return u_mean[0], u_mean, tubeX, unc_w, Xroll, w

def mppi_plan_gpu_local(state, u_init, t_global,
                        local_dth1, local_dth2, local_dw1, local_dw2,
                        base_seed=0):
    """
    NumPy-facing wrapper around mppi_plan_gpu_local_tf.

    Converts the inputs to DTYPE_TF tensors, runs the planner on "/GPU:0" when
    a logical GPU is visible or REQUIRE_GPU is set (otherwise "/CPU:0"), and
    converts the six outputs back to Python floats / NumPy arrays:
    (u_first, u_mean, tubeX, unc_w, Xroll, w).
    """
    state_vec = np.asarray(state, dtype=np.float64).reshape(4,)
    action_seq = np.asarray(u_init, dtype=np.float64).reshape(-1,)
    state0 = tf.convert_to_tensor(state_vec, dtype=DTYPE_TF)
    u0 = tf.convert_to_tensor(action_seq, dtype=DTYPE_TF)

    gpu_visible = len(tf.config.list_logical_devices("GPU")) > 0
    device = "/GPU:0" if (REQUIRE_GPU or gpu_visible) else "/CPU:0"

    with tf.device(device):
        outputs = mppi_plan_gpu_local_tf(
            state0, u0, tf.convert_to_tensor(int(t_global), dtype=tf.int32),
            local_dth1, local_dth2, local_dw1, local_dw2,
            horizon=HORIZON, K=K_SAMPLES, sigma=SIGMA, lam=LAMBDA,
            base_seed=int(base_seed)
        )

    u_first, u_mean, tubeX, unc_w, Xroll, w = outputs
    return (
        float(u_first.numpy()),
        u_mean.numpy(),
        tubeX.numpy(),
        float(unc_w.numpy()),
        Xroll.numpy(),
        w.numpy(),
    )

# ----------------------------
# Visualization helpers
# ----------------------------
def display_run_animation_from_frames(frames, fps=FPS, resize=RESIZE):
    """
    Show a list of image frames as an inline JS/HTML animation.

    Prints a warning and returns None when there are no frames. Otherwise the
    HTML object is displayed and returned. The figure is sized as
    resize / 100 inches at 100 dpi.
    """
    if frames is None or len(frames) == 0:
        print("⚠️ No frames captured for this run.")
        return None

    width_in, height_in = resize[0] / 100, resize[1] / 100
    fig = plt.figure(figsize=(width_in, height_in), dpi=100)
    plt.axis("off")
    image = plt.imshow(frames[0])

    def _show_frame(index):
        image.set_data(frames[index])
        return [image]

    frame_interval_ms = 1000 / float(fps)
    anim = animation.FuncAnimation(
        fig,
        _show_frame,
        frames=len(frames),
        interval=frame_interval_ms,
        blit=True,
    )
    # Close the figure before rendering so only the animation is displayed.
    plt.close(fig)
    html = HTML(anim.to_jshtml())
    display(html)
    return html

def success_hold_update_acrobot(state, hold_count):
    """
    Update the consecutive-success counter for one state.

    The counter increments while the tip-height proxy is at or above
    Y_TIP_GOAL and resets to zero otherwise; success is declared once it
    reaches HOLD_STEPS. Returns (hold_count, success, y_tip, good).
    """
    th1, th2, _, _ = state
    y_tip = acrobot_tip_height_proxy_np(th1, th2)
    good = (y_tip >= float(Y_TIP_GOAL))
    if good:
        hold_count = hold_count + 1
    else:
        hold_count = 0
    success = (hold_count >= int(HOLD_STEPS))
    return hold_count, success, y_tip, good

# ============================================================
# Single EPISODE runner ✅ ACROBOT
# ============================================================
def run_one_episode_mppi_retrain_rgb_with_eval(
    max_steps=600, seed=0, verbose=True,
    use_gpu_mppi=True,
    warmup_mppi=True,
    record_rgb=True,
    t_offset_in_run=0,
    return_frames=False,
):
    """
    Run one Acrobot episode with MPPI control and periodic online GP updates.

    Per step: plan with MPPI on the local predictors, reselect the tube subset
    (rebuilding the local predictors when due), step the environment, buffer
    the observed transition, optionally render a frame, and every UPDATE_EVERY
    steps refit Z_GLOBAL and retrain the four global heads.

    Args:
      max_steps: episode step cap; also the size of the timing arrays.
      seed: seed for the environment and base seed of the MPPI noise.
      verbose: print a status line every 50 steps.
      use_gpu_mppi: must be True; the False path raises RuntimeError.
      warmup_mppi: run one untimed planning call before the loop.
      record_rgb: render frames (also gated by RECORD_RGB_DEFAULT and FRAME_STRIDE).
      t_offset_in_run: added to the episode step to form t_global, which feeds
        the exploration schedule and the MPPI seed.
      return_frames: return the captured frames instead of None.

    Returns:
      (stats, frames_out, last_idx_sub, pred_time_step, train_time_step,
       wall_excl_vis_cum); the three timing arrays are truncated to the number
      of executed steps.

    Raises:
      ValueError: if Z_GLOBAL has fewer than M_GLOBAL rows.
      RuntimeError: if use_gpu_mppi is False.

    Side effects:
      Rebinds the globals m_dth1, m_dth2, m_dw1, m_dw2 and Z_GLOBAL.
    """
    global m_dth1, m_dth2, m_dw1, m_dw2
    global Z_GLOBAL
    global STD_CMAX_FIXED

    env = make_env(render_mode=None, seed=seed, max_episode_steps=max_steps)
    obs, info = env.reset(seed=seed)
    s = np.array(obs_to_state(obs), dtype=np.float64)  # (th1, th2, w1, w2)

    # Ensure Z_GLOBAL fixed size
    Z_GLOBAL = np.asarray(Z_GLOBAL, dtype=np.float64)
    if Z_GLOBAL.shape[0] != int(M_GLOBAL):
        if Z_GLOBAL.shape[0] > int(M_GLOBAL):
            Z_GLOBAL = Z_GLOBAL[:int(M_GLOBAL)].copy()
        else:
            raise ValueError(f"Z_GLOBAL has {Z_GLOBAL.shape[0]} points but M_GLOBAL={M_GLOBAL}.")

    # anchors
    ANCHOR_IDX = compute_anchor_idx_dopt_from_Zglobal_multihead(
        Z_GLOBAL, m_anchors=ANCHOR_M, lam=ANCHOR_LAM, normalize_traces=True
    )

    # init local subset + mean action sequence
    u_mean = np.zeros((HORIZON,), dtype=np.float64)
    idx_sub = np.arange(min(M_SUB, M_GLOBAL), dtype=np.int64)

    # Persistent TF predictors per head (no retracing on rebuild)
    # NOTE(review): the predictors are sized M_SUB while idx_sub above has
    # min(M_SUB, M_GLOBAL) rows — presumably M_SUB <= M_GLOBAL always holds; confirm.
    Dfeat = int(Z_GLOBAL.shape[1])  # feature dim = 7
    local_dth1 = LocalSubsetPredictor(Dfeat, M_SUB, name="local_dth1")
    local_dth2 = LocalSubsetPredictor(Dfeat, M_SUB, name="local_dth2")
    local_dw1  = LocalSubsetPredictor(Dfeat, M_SUB, name="local_dw1")
    local_dw2  = LocalSubsetPredictor(Dfeat, M_SUB, name="local_dw2")

    # Reads the current global models and Z_GLOBAL at call time, so it picks up
    # the retrained models after each update.
    def refresh_locals(idx_sub_np):
        Z,L,a,S,ls,var = build_local_pack_from_global_tf(m_dth1, Z_GLOBAL, idx_sub_np); local_dth1.assign_pack(Z,L,a,S,ls,var)
        Z,L,a,S,ls,var = build_local_pack_from_global_tf(m_dth2, Z_GLOBAL, idx_sub_np); local_dth2.assign_pack(Z,L,a,S,ls,var)
        Z,L,a,S,ls,var = build_local_pack_from_global_tf(m_dw1,  Z_GLOBAL, idx_sub_np); local_dw1.assign_pack(Z,L,a,S,ls,var)
        Z,L,a,S,ls,var = build_local_pack_from_global_tf(m_dw2,  Z_GLOBAL, idx_sub_np); local_dw2.assign_pack(Z,L,a,S,ls,var)

    refresh_locals(idx_sub)

    last_idx_sub = np.array(idx_sub, dtype=np.int64)
    last_local_rebuild_t = 0

    # Per-step timing records; entries past the last executed step stay zero
    # and are sliced off on return.
    pred_time_step = np.zeros((max_steps,), dtype=np.float64)
    train_time_step = np.zeros((max_steps,), dtype=np.float64)
    wall_excl_vis_cum = np.zeros((max_steps,), dtype=np.float64)

    t_wall_start = time.perf_counter()
    vis_time_s = 0.0

    # warmup: compile once
    if use_gpu_mppi and warmup_mppi:
        _ = mppi_plan_gpu_local(
            state=s, u_init=u_mean, t_global=int(t_offset_in_run),
            local_dth1=local_dth1, local_dth2=local_dth2, local_dw1=local_dw1, local_dw2=local_dw2,
            base_seed=seed
        )

    frames = []
    total_reward = 0.0
    hold_count = 0
    updates = 0

    # Transition buffers (features and per-head delta targets) since the last update.
    Xbuf, ydth1_buf, ydth2_buf, ydw1_buf, ydw2_buf = [], [], [], [], []
    last_rollout_inputs = None
    last_rollout_weights = None

    for t in range(max_steps):
        t_global = int(t_offset_in_run + t)

        # ---------- MPPI plan ----------
        # NOTE(review): the returned u_mean is fed back as u_init on the next
        # step without being shifted by one action — confirm this is intended.
        t0 = time.perf_counter()
        if use_gpu_mppi:
            u0, u_mean, tubeX, unc_w, Xroll, wroll = mppi_plan_gpu_local(
                state=s, u_init=u_mean, t_global=t_global,
                local_dth1=local_dth1, local_dth2=local_dth2, local_dw1=local_dw1, local_dw2=local_dw2,
                base_seed=seed
            )
        else:
            raise RuntimeError("use_gpu_mppi=False path not wired in this cell (set use_gpu_mppi=True).")
        t1 = time.perf_counter()
        pred_time_step[t] = (t1 - t0)

        last_rollout_inputs = Xroll
        last_rollout_weights = wroll

        # ---------- tube subset selection ----------
        idx_sub_cand, _, _, _, _ = select_tube_subset(
            Z_global=Z_GLOBAL,
            kernel_for_dopt=m_dw2.kernel,   # choose one head for D-opt inside bins (consistent)
            rollout_inputs=last_rollout_inputs,
            rollout_weights=last_rollout_weights,
            total_subset_size=M_SUB,
            anchor_idx=ANCHOR_IDX,
            time_bins=TUBE_TIME_BINS,
            cov_eps=TUBE_COV_EPS,
            dopt_jitter=1e-6,
            fallback_candidates=TUBE_FALLBACK_CANDIDATES,
            alpha=TUBE_ALPHA,
        )

        # Fraction of the candidate subset already present in the active subset.
        inter = np.intersect1d(last_idx_sub, idx_sub_cand)
        overlap = float(len(inter)) / float(len(idx_sub_cand)) if len(idx_sub_cand) > 0 else 1.0

        # Rebuild when the active subset is stale (by age) or has drifted (by overlap).
        need_rebuild = ((t - last_local_rebuild_t) >= int(LOCAL_REBUILD_EVERY)) or (overlap < float(LOCAL_OVERLAP_THRESH))
        if need_rebuild:
            idx_sub = np.array(idx_sub_cand, dtype=np.int64)
            refresh_locals(idx_sub)  # no retracing; tensor assigns only
            last_idx_sub = np.array(idx_sub, dtype=np.int64)
            last_local_rebuild_t = int(t)
        else:
            idx_sub = last_idx_sub

        # ---------- env step ----------
        obs2, r, terminated, truncated, info = env.step(np.array([u0], dtype=np.float32))
        s2 = np.array(obs_to_state(obs2), dtype=np.float64)
        total_reward += float(r)

        # collect executed transition (delta targets)
        # Angle deltas are wrapped so a crossing of +/-pi is not seen as a jump.
        Xbuf.append(state_to_features(s[0], s[1], s[2], s[3], float(u0)))
        ydth1_buf.append([wrap_pi(s2[0] - s[0])])
        ydth2_buf.append([wrap_pi(s2[1] - s[1])])
        ydw1_buf.append([s2[2] - s[2]])
        ydw2_buf.append([s2[3] - s[3]])

        # ---------- render (excluded from wall) ----------
        if record_rgb and RECORD_RGB_DEFAULT and (t % FRAME_STRIDE == 0):
            tv0 = time.perf_counter()
            W, Hh = int(RESIZE[0]), int(RESIZE[1])
            frame = render_acrobot_frame_from_state(s2[0], s2[1], W=W, H=Hh)
            frames.append(frame)
            vis_time_s += (time.perf_counter() - tv0)

        # ---------- success tracking ----------
        hold_count, success, y_tip, good = success_hold_update_acrobot(s2, hold_count)

        if verbose and (t % 50 == 0):
            print(f"[t_global={t_global:04d}] u0={u0:+.2f} unc_w={unc_w:.2f}  y_tip={y_tip:+.3f}  hold={hold_count}/{HOLD_STEPS}")

        # ---------- UPDATE (global only) ----------
        if ((t + 1) % UPDATE_EVERY == 0) and (len(Xbuf) >= 10):
            updates += 1

            Xnew  = np.asarray(Xbuf, dtype=np.float64)
            y1    = np.asarray(ydth1_buf, dtype=np.float64)
            y2    = np.asarray(ydth2_buf, dtype=np.float64)
            y3    = np.asarray(ydw1_buf,  dtype=np.float64)
            y4    = np.asarray(ydw2_buf,  dtype=np.float64)

            # refit Z_GLOBAL (still fixed size)
            Z_GLOBAL = refit_Z_global_multihead(
                Z_global=Z_GLOBAL,
                Xnew=Xnew,
                M_global=M_GLOBAL,
                lam=ANCHOR_LAM,
                mode="mean",
                normalize_traces=False,
            )

            # train 4 heads (OSGPR)
            # Only the four head updates are timed; the Z_GLOBAL refit above is not.
            ttrain0 = time.perf_counter()
            m_dth1, _ = osgpr_stream_update(m_dth1, Xnew, y1, Z_GLOBAL, iters=ITERS_UPDATE, lr=LR_UPDATE, noise=NOISE_UPDATE, freeze_kernel=False)
            m_dth2, _ = osgpr_stream_update(m_dth2, Xnew, y2, Z_GLOBAL, iters=ITERS_UPDATE, lr=LR_UPDATE, noise=NOISE_UPDATE, freeze_kernel=False)
            m_dw1,  _ = osgpr_stream_update(m_dw1,  Xnew, y3, Z_GLOBAL, iters=ITERS_UPDATE, lr=LR_UPDATE, noise=NOISE_UPDATE, freeze_kernel=False)
            m_dw2,  _ = osgpr_stream_update(m_dw2,  Xnew, y4, Z_GLOBAL, iters=ITERS_UPDATE, lr=LR_UPDATE, noise=NOISE_UPDATE, freeze_kernel=False)
            ttrain1 = time.perf_counter()
            train_time_step[t] += float(ttrain1 - ttrain0)

            # clear buffers
            Xbuf, ydth1_buf, ydth2_buf, ydw1_buf, ydw2_buf = [], [], [], [], []

            # reselect anchors
            ANCHOR_IDX = compute_anchor_idx_dopt_from_Zglobal_multihead(
                Z_GLOBAL, m_anchors=ANCHOR_M, lam=ANCHOR_LAM, normalize_traces=True
            )

            # force rebuild subset after update
            idx_sub, _, _, _, _ = select_tube_subset(
                Z_global=Z_GLOBAL,
                kernel_for_dopt=m_dw2.kernel,
                rollout_inputs=last_rollout_inputs,
                rollout_weights=last_rollout_weights,
                total_subset_size=M_SUB,
                anchor_idx=ANCHOR_IDX,
                time_bins=TUBE_TIME_BINS,
                cov_eps=TUBE_COV_EPS,
                dopt_jitter=1e-6,
                fallback_candidates=TUBE_FALLBACK_CANDIDATES,
                alpha=TUBE_ALPHA,
            )
            refresh_locals(idx_sub)
            last_idx_sub = np.array(idx_sub, dtype=np.int64)
            last_local_rebuild_t = int(t)

        # ---------- wall time cumulative excl vis ----------
        wall_excl_vis_cum[t] = max((time.perf_counter() - t_wall_start) - vis_time_s, 0.0)

        s = s2
        if success or terminated or truncated:
            y_tip = -np.cos(s2[0]) - np.cos(s2[0] + s2[1])
            print(f"[EP END] t_global={t_global:04d} t_ep={t:04d} "
                  f"success={success} terminated={terminated} truncated={truncated} "
                  f"y_tip={y_tip:.3f} hold={hold_count}/{HOLD_STEPS}")
            break


    # NOTE(review): env.close() is skipped if the loop raises, and `t` is
    # undefined below when max_steps == 0 — confirm neither case matters here.
    env.close()
    steps = int(t + 1)

    stats = dict(
        total_reward=float(total_reward),
        steps=steps,
        updates=int(updates),
        frames=int(len(frames)),
        hold_steps=int(hold_count),
        Z_global_size=int(len(Z_GLOBAL)),
        std_cmax_fixed=float(STD_CMAX_FIXED) if STD_CMAX_FIXED is not None else None,
        vis_time_s=float(vis_time_s),
    )

    frames_out = frames if (return_frames and record_rgb and RECORD_RGB_DEFAULT) else None
    return stats, frames_out, last_idx_sub, pred_time_step[:steps], train_time_step[:steps], wall_excl_vis_cum[:steps]

# ============================================================
# MULTI-RUN driver ✅ ACROBOT
# ============================================================
# Experiment size: number of independent runs, and episodes chained within each run.
N_RUNS = 1
N_EPISODES_PER_RUN = 4

# Per-episode step cap and driver switches passed to the episode runner.
MAX_STEPS_PER_EP = 2000
USE_GPU_MPPI = True
VERBOSE = True

# Snapshot of the inducing set taken before any run, restored at the start of each run.
Z_GLOBAL_INIT = np.asarray(Z_GLOBAL, dtype=np.float64).copy()

# Snapshots of the four heads' parameters, restored at the start of each run.
params_th1_init = parameter_dict(m_dth1)
params_th2_init = parameter_dict(m_dth2)
params_w1_init  = parameter_dict(m_dw1)
params_w2_init  = parameter_dict(m_dw2)

def reset_models_and_globals_for_fresh_run():
    """
    Restore the module-level state captured before the first run.

    Z_GLOBAL is reset from Z_GLOBAL_INIT (truncated to M_GLOBAL rows, or a
    ValueError if it has fewer), the saved parameter dictionaries are assigned
    back onto the four head models, their prediction caches are rebuilt on a
    best-effort basis, and GLOBAL_SURF_HISTORY / STD_CMAX_FIXED are cleared.
    """
    global m_dth1, m_dth2, m_dw1, m_dw2
    global Z_GLOBAL
    global GLOBAL_SURF_HISTORY
    global STD_CMAX_FIXED

    Z_GLOBAL = Z_GLOBAL_INIT.copy()
    n_rows = Z_GLOBAL.shape[0]
    if n_rows < int(M_GLOBAL):
        raise ValueError(f"Z_GLOBAL_INIT has {Z_GLOBAL.shape[0]} < M_GLOBAL={M_GLOBAL}.")
    if n_rows > int(M_GLOBAL):
        Z_GLOBAL = Z_GLOBAL[:int(M_GLOBAL)].copy()

    heads_with_params = (
        (m_dth1, params_th1_init),
        (m_dth2, params_th2_init),
        (m_dw1, params_w1_init),
        (m_dw2, params_w2_init),
    )
    for head, saved_params in heads_with_params:
        multiple_assign(head, saved_params)

    # Cache rebuild is best-effort: any failure is ignored.
    for head, _ in heads_with_params:
        try:
            head.build_predict_cache()
        except Exception:
            pass

    GLOBAL_SURF_HISTORY = []
    STD_CMAX_FIXED = None

# Per-run results, one entry appended per run.
run_pred_time = []
run_train_time = []
run_wall_cum = []
run_rewards = []
run_steps_total = []
run_updates_total = []

for run in range(N_RUNS):
    print(f"\n==================== RUN {run+1}/{N_RUNS} (fresh reset) ====================")
    reset_models_and_globals_for_fresh_run()

    # Per-episode arrays for this run, concatenated after the episode loop.
    pred_concat = []
    train_concat = []
    wall_concat = []
    # Steps executed so far in this run; passed to each episode as its t_global offset.
    t_offset = 0

    frames_run = []
    run_reward_sum = 0.0
    run_updates_sum = 0
    run_steps_sum = 0

    for ep in range(N_EPISODES_PER_RUN):
        # Only the first episode of a run records frames and performs the MPPI warmup call.
        record_rgb = (ep == 0)
        warmup_mppi = (ep == 0)

        print(f"\n--- RUN {run+1} EP {ep+1}/{N_EPISODES_PER_RUN} (t_offset={t_offset}) ---")

        stats_ep, frames_ep, _, pred_t, train_t, wall_cum_t = run_one_episode_mppi_retrain_rgb_with_eval(
            max_steps=MAX_STEPS_PER_EP,
            seed=1000*run + ep,
            verbose=VERBOSE,
            use_gpu_mppi=USE_GPU_MPPI,
            warmup_mppi=warmup_mppi,
            record_rgb=record_rgb,
            t_offset_in_run=t_offset,
            return_frames=True,
        )

        if record_rgb and (frames_ep is not None) and (len(frames_ep) > 0):
            frames_run.extend(frames_ep)

        pred_concat.append(pred_t)
        train_concat.append(train_t)
        wall_concat.append(wall_cum_t)

        run_reward_sum += stats_ep["total_reward"]
        run_updates_sum += stats_ep["updates"]
        run_steps_sum += stats_ep["steps"]
        t_offset += stats_ep["steps"]

    pred_run = np.concatenate(pred_concat, axis=0)
    train_run = np.concatenate(train_concat, axis=0)
    wall_run = np.concatenate(wall_concat, axis=0)

    # recompute wall cumulative cleanly
    # Each episode's cumulative clock restarts at zero, so convert every episode
    # to per-step increments and re-accumulate across the whole run.
    wall_incr = np.zeros_like(wall_run)
    i0 = 0
    for ep_arr in wall_concat:
        ep_arr = np.asarray(ep_arr, dtype=np.float64)
        if ep_arr.size > 0:
            d = np.diff(np.concatenate([[0.0], ep_arr]))
            wall_incr[i0:i0+len(d)] = d
        i0 += len(ep_arr)
    wall_cum_run = np.cumsum(np.maximum(wall_incr, 0.0))

    run_pred_time.append(pred_run)
    run_train_time.append(train_run)
    run_wall_cum.append(wall_cum_run)
    run_rewards.append(run_reward_sum)
    run_steps_total.append(len(pred_run))
    run_updates_total.append(run_updates_sum)

    print(f"\n=== RUN {run+1}: single animation (episode 0 only) ===")
    _ = display_run_animation_from_frames(frames_run, fps=FPS, resize=RESIZE)

# ============================================================
# Aggregate across runs: mean + band (std)  (pad with NaN)
# ============================================================
# The longest run (in steps) defines the common time axis for all runs.
maxT = int(max(run_steps_total))
def pad_to(arr, T):
    """Return a float64 vector of length T holding `arr` followed by NaN padding.

    Args:
        arr: 1-D sequence of numbers with at most T entries.
        T: length of the returned vector.

    Returns:
        np.ndarray of shape (T,), dtype float64; the first len(arr) entries are
        copied from `arr`, the remainder are NaN.
    """
    padded = np.empty((T,), dtype=np.float64)
    padded.fill(np.nan)
    n_given = len(arr)
    padded[:n_given] = arr
    return padded

# Stack the per-run series into (runs x maxT) matrices. Runs shorter than maxT
# are NaN-padded by pad_to, so the nan-aware reductions below skip the padding.
pred_mat, train_mat, wall_mat = (
    np.vstack([pad_to(series, maxT) for series in per_run_series])
    for per_run_series in (run_pred_time, run_train_time, run_wall_cum)
)

t_axis = np.arange(maxT, dtype=np.int64)

# Per-timestep mean and std across runs.
pred_mean, pred_std = np.nanmean(pred_mat, axis=0), np.nanstd(pred_mat, axis=0)
wall_cum_mean, wall_cum_std = np.nanmean(wall_mat, axis=0), np.nanstd(wall_mat, axis=0)

# training: ignore zeros -> update-only
# Only finite, strictly positive entries count as an update; everything else
# becomes NaN so it drops out of the statistics.
train_mat_upd = np.where(
    np.isfinite(train_mat) & (train_mat > 0.0),
    train_mat,
    np.nan,
)

train_update_mean = np.nanmean(train_mat_upd, axis=0)
train_update_std = np.nanstd(train_mat_upd, axis=0)

# cumulative update-only training
# For each run, the running total of update time is reported at update steps
# and NaN elsewhere. nancumsum adds 0 for NaN entries, which leaves the running
# total unchanged at non-update steps.
train_cum_upd_mat = np.full_like(train_mat_upd, np.nan)
for r in range(N_RUNS):
    row = train_mat_upd[r]
    train_cum_upd_mat[r] = np.where(np.isfinite(row), np.nancumsum(row), np.nan)

train_update_cum_mean = np.nanmean(train_cum_upd_mat, axis=0)
train_update_cum_std = np.nanstd(train_cum_upd_mat, axis=0)

def moving_average_1d(y, window=7):
    """Centered moving average that always returns an array as long as `y`.

    The window is aligned the same way as ``np.convolve(..., mode="same")``
    (for an even window the extra sample is taken on the left). Near the
    boundaries the average is taken over the samples that actually fall inside
    the window, rather than over zero-padded values, so the ends of the curve
    are not pulled toward zero.

    Unlike a plain ``np.convolve(y, kernel, mode="same")``, the output length
    is ``len(y)`` even when ``window > len(y)``; ``mode="same"`` would return
    ``max(len(y), window)`` samples, which no longer matches the time axis.

    Args:
        y: 1-D sequence of numbers.
        window: window length in samples; values below 1 are treated as 1.

    Returns:
        np.ndarray (float64) with the same number of samples as `y`. A copy of
        the input is returned when it is empty or the window is 1.
    """
    y = np.asarray(y, dtype=np.float64)
    w = int(max(1, window))
    n = y.size
    if n == 0 or w == 1:
        return y.copy()

    box = np.ones((w,), dtype=np.float64)
    # Windowed sums and the number of real samples contributing to each sum.
    window_sum = np.convolve(y, box, mode="full")
    window_cnt = np.convolve(np.ones((n,), dtype=np.float64), box, mode="full")

    # Offset into the full convolution that reproduces the "same" alignment.
    start = (w - 1) // 2
    stop = start + n
    return window_sum[start:stop] / window_cnt[start:stop]

SMOOTH_UPD_WIN = 7

# Per-update training time: keep only the timesteps whose cross-run mean is
# finite (i.e. at least one run performed an update there), then smooth the
# mean over that compressed sequence of update steps.
mask_upd = np.isfinite(train_update_mean)
t_upd, y_upd_mean, y_upd_std = (
    series[mask_upd] for series in (t_axis, train_update_mean, train_update_std)
)
y_upd_mean_s = moving_average_1d(y_upd_mean, window=SMOOTH_UPD_WIN)

# Cumulative training time: same selection and smoothing on the cumulative series.
mask_upd_cum = np.isfinite(train_update_cum_mean)
t_upd_cum, y_upd_cum_mean, y_upd_cum_std = (
    series[mask_upd_cum]
    for series in (t_axis, train_update_cum_mean, train_update_cum_std)
)
y_upd_cum_mean_s = moving_average_1d(y_upd_cum_mean, window=SMOOTH_UPD_WIN)

# ============================================================
# Plots with variability band (mean ± 1 std)
# ============================================================
def _show_mean_band(x, line_y, band_center, band_std, line_label, xlabel, ylabel, title):
    """Draw one figure: a mean curve plus a shaded band of band_center +/- band_std.

    `line_y` is the curve that is drawn; `band_center` is the series the band
    is centred on (the two differ when the drawn curve is a smoothed version).
    """
    plt.figure(figsize=(10, 3.2))
    plt.plot(x, line_y, linewidth=2.0, label=line_label)
    plt.fill_between(
        x,
        band_center - band_std,
        band_center + band_std,
        alpha=0.2,
        label="±1 std (across runs)",
    )
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, alpha=0.25)
    plt.legend()
    plt.tight_layout()
    plt.show()


# Training time per update: smoothed mean curve, band around the raw mean.
_show_mean_band(
    t_upd,
    y_upd_mean_s,
    y_upd_mean,
    y_upd_std,
    line_label=f"mean (smoothed over updates, win={SMOOTH_UPD_WIN})",
    xlabel="timestep (only update steps)",
    ylabel="training time per update (s)",
    title="Training time per UPDATE (zeros ignored) — mean ± std across runs",
)

# Cumulative training time over update steps.
_show_mean_band(
    t_upd_cum,
    y_upd_cum_mean_s,
    y_upd_cum_mean,
    y_upd_cum_std,
    line_label=f"mean cum (smoothed, win={SMOOTH_UPD_WIN})",
    xlabel="timestep (only update steps)",
    ylabel="cumulative training time (s)",
    title="Cumulative training time over UPDATEs (zeros ignored) — mean ± std across runs",
)

# MPPI planning time per timestep.
_show_mean_band(
    t_axis,
    pred_mean,
    pred_mean,
    pred_std,
    line_label="mean",
    xlabel="timestep (within run, concatenated episodes)",
    ylabel="prediction time (s) per step (MPPI planning)",
    title="Prediction (MPPI) time per timestep — mean ± std across runs",
)

# Cumulative wall time.
_show_mean_band(
    t_axis,
    wall_cum_mean,
    wall_cum_mean,
    wall_cum_std,
    line_label="mean",
    xlabel="timestep (within run, concatenated episodes)",
    ylabel="wall time cumulative (s) (EXCLUDING visualization)",
    title="Wall time cumulative (no visualization) — mean ± std across runs",
)

# ============================================================
# Register this method (for cross-method overlay)
# ============================================================
METHOD_NAME = "PALSGP_OSGPR_fixedZglobal_localSubset_ACROBOT"

# Everything another cell needs to overlay this method against others: the
# time axes, the aggregated curves, and the configuration that produced them.
# The registered training means are the smoothed curves; the std series are
# the unsmoothed across-run values.
_method_results = {
    "t_axis": t_axis,
    "pred_mean": pred_mean,
    "wall_cum_mean": wall_cum_mean,
    "train_update_t": t_upd,
    "train_update_mean": y_upd_mean_s,
    "train_update_std": y_upd_std,
    "train_update_cum_t": t_upd_cum,
    "train_update_cum_mean": y_upd_cum_mean_s,
    "train_update_cum_std": y_upd_cum_std,
    "meta": {
        "N_RUNS": N_RUNS,
        "N_EPISODES_PER_RUN": N_EPISODES_PER_RUN,
        "UPDATE_EVERY": UPDATE_EVERY,
        "HORIZON": HORIZON,
        "K_SAMPLES": K_SAMPLES,
        "M_GLOBAL": M_GLOBAL,
        "M_SUB": M_SUB,
    },
}
register_method_results(METHOD_NAME, _method_results)

print(f"✅ Registered results into EVAL_REGISTRY under key: {METHOD_NAME}")
print("Current methods in registry:", list(EVAL_REGISTRY.keys()))

# One line per run with its step count, summed reward and number of updates.
print("\n==================== SUMMARY ====================")
for r in range(N_RUNS):
    print(f"RUN {r+1}: total_steps={run_steps_total[r]}, total_reward_sum={run_rewards[r]:.3f}, updates_sum={run_updates_total[r]}")
print("================================================\n")
