In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt

import json
import os
import copy
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Callable, Any
import imageio.v2 as imageio
from PIL import Image, ImageDraw, ImageFont
from gymnasium.wrappers import RecordVideo

import gymnasium as gym
import warnings
warnings.filterwarnings('ignore')

In [None]:
# %% [markdown]
# ## Emergent CA (discrete space, continuous time) — minimal notebook
# Configure parameters & seed here, then run the following cells.
# The model integrates with RK4 and plots a 1-D CA-style spacetime diagram
# for a(t) and m(t): x-axis = space; y-axis (down) = time steps.

# %%


# ---- Simulation configuration (edit these values) ----
# N: number of cells (space); T: number of integration steps (time depth)
N, T = 125, 125
# dt: step size of the continuous-time integrator
dt = 0.05

# alpha_m: memory leak rate
alpha_m = 0.1

In [None]:
class CEM():
    """Two-field cellular model: an activation ``a`` coupled to a leaky memory ``m``.

    The ``N`` cells live on a periodic ring (neighbours come from ``np.roll``)
    and evolve according to

        dot a_i = theta1 + theta_c*c + theta_l*l + theta_r*r + theta_m*m
                  + theta_cl*c*l + theta_cr*c*r + theta_cm*c*m - 5*alpha_m*c
        dot m_i = alpha_m * (a_i - m_i)

    with ``c = a_i``, ``l = a_{i-1}``, ``r = a_{i+1}``. The system is integrated
    with classic RK4 for ``T`` steps of size ``dt``; the full histories are kept
    in ``self.A`` and ``self.M`` (both of shape ``(T+1, N)``).

    Four cells (``input_locs``) receive the normalised CartPole observation at
    t=0 and one cell (``output_locs``) is read back as the output signal.
    """

    def __init__(self, N, T, dt, alpha_m, persist = False):
        """Allocate the history buffers and draw the initial state.

        Args:
            N: number of cells.
            T: number of integration steps per call to ``simulate``.
            dt: RK4 step size.
            alpha_m: memory leak rate (also scales the extra ``-5*alpha_m*c`` damping).
            persist: when True the activation mask is zero at the input columns,
                so ``simulate`` writes 0 into those cells for every step t >= 1.
                NOTE(review): the name suggests the inputs should be *held*
                rather than zeroed -- confirm the intended semantics.
        """
        # Local polynomial coefficients (theta_*), see the class docstring.
        self.theta = {'theta1': 1.4376621267715537, 'theta_c': -7.8766253635181105, 'theta_l': -2.722349618817911, 'theta_r': 2.0576759135197538, 'theta_m': -8.45005113559998, 'theta_cl': -0.12868716237756087, 'theta_cr': 0.30220266691459485, 'theta_cm': -0.1061964674000701}
        self.N = N
        self.T = T
        self.dt = dt
        self.alpha_m = alpha_m

        # Initial state is drawn from a fixed-seed RNG, so it is reproducible.
        self.A, self.M = self.__init_state__()

        # Inputs and output are spread evenly over the ring: the middle fifth is
        # the output cell, the other four multiples of `spacing` are the inputs.
        spacing = N//5
        self.input_locs = [0*spacing, 1*spacing, 3*spacing, 4*spacing]
        self.output_locs = [2*spacing]

        self.A_mask = np.ones_like(self.A)
        if persist:
            self.A_mask[:, self.input_locs] = 0

    def __init_state__(self, state = None):
        """Create the history buffers, or overwrite row 0 with a given state.

        Args:
            state: ``None`` to allocate fresh ``(T+1, N)`` buffers whose first row
                is random normal noise (seed 42, scale 0.3), or a pair
                ``(a0, m0)`` to write into row 0 of the existing buffers.

        Returns:
            ``(A, M)``. With ``state=None`` these are newly allocated arrays that
            the caller must store; otherwise they are ``self.A`` and ``self.M``.
            (Previously the ``state`` branch returned ``None``.)
        """
        if state is None:
            init_scale = 3e-1  # amplitude of random initialization
            rng = np.random.default_rng(42)
            a0 = init_scale * rng.standard_normal(self.N)
            m0 = init_scale * rng.standard_normal(self.N)

            # Rows 1..T are left uninitialised until `simulate` fills them.
            A = np.empty((self.T+1, self.N), dtype=float)
            M = np.empty((self.T+1, self.N), dtype=float)
            A[0] = a0
            M[0] = m0
            return A, M

        self.A[0], self.M[0] = state
        return self.A, self.M

    def norm_obs(self,x, dx, th, dth ):
        """Clip each observation component to a fixed range and scale it to [-1, 1]."""
        x  = np.clip(x,  -2.4,  2.4)   / 2.4
        dx = np.clip(dx, -3.0,  3.0)   / 3.0
        th = np.clip(th, -0.418, 0.418) / 0.418
        dth= np.clip(dth,-4.0,  4.0)   / 4.0
        return np.array([x, dx, th, dth], dtype=float)

    # ---- Dynamics (periodic boundary; RK4 integrator) ----

    def seed_inputs(self, x, dx , theta, dtheta):
        """Write the normalised observation into the input cells of row 0."""
        self.A[0,self.input_locs] = self.norm_obs(x, dx, theta, dtheta)

    def deriv(self,a, m):
        """Compute time-derivatives (dot a, dot m) given current (a, m)."""
        l = np.roll(a, 1)   # a_{i-1}
        r = np.roll(a, -1)  # a_{i+1}
        c = a
        t = self.theta
        dot_a = (t["theta1"]
                 + t["theta_c"]*c + t["theta_l"]*l + t["theta_r"]*r + t["theta_m"]*m
                 + t["theta_cl"]*c*l + t["theta_cr"]*c*r + t["theta_cm"]*c*m) - 5*self.alpha_m*c
        dot_m = self.alpha_m * (a - m)
        return dot_a, dot_m

    def rk4_step(self,a, m):
        """One RK4 step for the coupled ODEs."""
        k1_a, k1_m = self.deriv(a, m)
        k2_a, k2_m = self.deriv(a + 0.5*self.dt*k1_a, m + 0.5*self.dt*k1_m)
        k3_a, k3_m = self.deriv(a + 0.5*self.dt*k2_a, m + 0.5*self.dt*k2_m)
        k4_a, k4_m = self.deriv(a + self.dt*k3_a,     m + self.dt*k3_m)
        a_next = a + (self.dt/6.0)*(k1_a + 2*k2_a + 2*k3_a + k4_a)
        m_next = m + (self.dt/6.0)*(k1_m + 2*k2_m + 2*k3_m + k4_m)
        return a_next, m_next

    def simulate(self, x, dx , theta, dtheta):
        """Seed the inputs at t=0, integrate T steps and return the output trace.

        Returns:
            View of ``self.A`` restricted to ``output_locs``, shape ``(T+1, 1)``.
        """
        self.seed_inputs(x, dx, theta, dtheta)
        for t in range(1, self.T+1):
            a, m = self.rk4_step(self.A[t-1], self.M[t-1])
            # The mask is all ones unless `persist` was requested.
            self.A[t] = a * self.A_mask[t]
            self.M[t] = m
        return self.A[:, self.output_locs]

class CEM1Layer:
    """Single-field cellular model with a 7-coefficient local polynomial rule.

    Each of the ``N`` cells on a periodic ring evolves as

        ds/dt = k0 + k1*l + k2*c + k3*r + k4*c*r + k5*c*l + k6*l*c*r - gamma*c

    with ``c`` the cell itself and ``l``/``r`` its left/right neighbours
    (``np.roll``). The state is integrated with RK4 for ``T`` steps of size
    ``dt`` and the history is stored in ``self.A`` (shape ``(T+1, N)``).
    """

    def __init__(self, N, T, dt, alpha_m,
                 theta=None, gamma=0.0, rng_seed=42, init_scale=1.0, persist=False):
        """Build the model and draw the initial state ``A[0]``.

        Args:
            N, T, dt: number of cells, integration steps and RK4 step size.
            alpha_m: stored but not used by the dynamics (kept for interface
                compatibility with the two-field model).
            theta: optional mapping with any of ``k0..k6``; missing keys default
                to 0.0. When ``None`` the coefficients are drawn with the
                global ``random`` module (not controlled by ``rng_seed``).
            gamma: linear damping coefficient.
            rng_seed, init_scale: seed and amplitude of the random ``A[0]``.
            persist: zeroes the input columns of ``self.A_mask``.
                NOTE(review): ``simulate`` never applies ``A_mask`` in this
                class, so this flag currently has no effect on the dynamics.
        """
        self.N = int(N)
        self.T = int(T)
        self.dt = float(dt)
        self.alpha_m = float(alpha_m)   # kept for interface compatibility
        self.gamma = float(gamma)

        # Use the int-coerced size so the locations are always valid indices.
        spacing = self.N // 5
        self.input_locs = [0*spacing, 1*spacing, 3*spacing, 4*spacing]
        self.output_locs = [2*spacing]

        if theta is None:
            self.theta = {
                "k0": random.uniform(-0.10,  0.10),
                "k1": random.uniform(-0.40,  0.40),
                "k2": random.uniform(-0.60,  0.60),
                "k3": random.uniform(-0.40,  0.40),
                "k4": random.uniform(-0.30,  0.30),
                "k5": random.uniform(-0.30,  0.30),
                "k6": random.uniform(-0.15,  0.15),
            }
        else:
            keys = ["k0","k1","k2","k3","k4","k5","k6"]
            self.theta = {k: float(theta.get(k, 0.0)) for k in keys}

        # State history buffer; rows 1..T stay uninitialised until `simulate`.
        self.A = np.empty((self.T + 1, self.N), dtype=float)
        self.A_mask = np.ones_like(self.A)
        if persist:
            self.A_mask[:, self.input_locs] = 0

        # Fitness bookkeeping
        self.fitness = -1e3
        self.fit_res = None

        # Initialize state (A[0])
        self.__init_state__(state=None, rng_seed=rng_seed, init_scale=init_scale)

    # ---- Initialization ----
    def __init_state__(self, state=None, rng_seed=42, init_scale=1.0):
        """Set ``A[0]`` and return ``self.A``.

        If state is None: random normal initial A[0].
        If state is a tuple/list (A0, M0) from old 2-layer code, we use only A0.
        If state is a 1D array of shape (N,), we use it as A[0].

        Raises:
            AssertionError: if the resulting initial state is not shape ``(N,)``.
        """
        if state is None:
            rng = np.random.default_rng(rng_seed)
            a0 = init_scale * rng.standard_normal(self.N)
        else:
            # Legacy (A0, M0) pair: a 2-element sequence whose first item is
            # itself a 1-D vector. A plain list of two scalars is NOT a pair.
            is_pair = (isinstance(state, (tuple, list))
                       and len(state) == 2
                       and np.ndim(state[0]) == 1)
            if is_pair:
                state = state[0]
            a0 = np.asarray(state, dtype=float)

        assert a0.shape == (self.N,), f"Initial state must be shape (N,), got {a0.shape}"
        self.A[0] = a0
        return self.A

    def norm_obs(self,x, dx, th, dth ):
        """Clip each observation component to a fixed range and scale it to [-1, 1]."""
        x  = np.clip(x,  -2.4,  2.4)   / 2.4
        dx = np.clip(dx, -3.0,  3.0)   / 3.0
        th = np.clip(th, -0.418, 0.418) / 0.418
        dth= np.clip(dth,-4.0,  4.0)   / 4.0
        return np.array([x, dx, th, dth], dtype=float)

    def seed_inputs(self, x, dx , theta, dtheta):
        """Write the raw (un-normalised) observation into the input cells of row 0."""
        #self.A[0,self.input_locs] = self.norm_obs(x, dx, theta, dtheta)
        self.A[0,self.input_locs] = x, dx, theta, dtheta

    # ---- Dynamics (periodic boundary; RK4 integrator) ----
    def deriv(self, a):
        """Compute ds/dt for the current state a (periodic neighborhood)."""
        l = np.roll(a,  1)
        r = np.roll(a, -1)
        c = a
        t = self.theta
        dot = (t["k0"]
               + t["k1"]*l + t["k2"]*c + t["k3"]*r
               + t["k4"]*(c*r) + t["k5"]*(c*l)
               + t["k6"]*(l*c*r)
               - self.gamma * c)
        return dot

    def rk4_step(self, a):
        """One RK4 step for ds/dt = f(a)."""
        k1 = self.deriv(a)
        k2 = self.deriv(a + 0.5*self.dt*k1)
        k3 = self.deriv(a + 0.5*self.dt*k2)
        k4 = self.deriv(a + self.dt*k3)
        a_next = a + (self.dt/6.0)*(k1 + 2*k2 + 2*k3 + k4)
        return a_next

    def simulate(self,x,dx, theta, dtheta):
        """Seed the inputs at t=0, run T integration steps, filling self.A[1:].

        Returns:
            View of ``self.A`` restricted to ``output_locs``, shape ``(T+1, 1)``.
        """
        self.seed_inputs(x, dx, theta, dtheta)
        for t in range(1, self.T + 1):
            self.A[t] = self.rk4_step(self.A[t - 1])
        return self.A[:,self.output_locs]

    # ---- Helpers ----
    def set_theta(self, **kwargs):
        """Update any subset of k0..k6 (e.g., set_theta(k2=0.1, k6=-0.05)).

        Raises:
            KeyError: if a keyword is not an existing coefficient name.
        """
        for k, v in kwargs.items():
            if k not in self.theta:
                raise KeyError(f"Unknown coefficient '{k}'. Valid keys: {list(self.theta.keys())}")
            self.theta[k] = float(v)

    def get_theta(self):
        """Return a shallow copy of the coefficient dictionary."""
        return dict(self.theta)






In [None]:
# pip install gymnasium numpy imageio imageio-ffmpeg


# ---- import your final CEM class here ----
# from your_file import CEM
# (Make sure CEM is top-level and pickleable for multiprocessing.)

# A parameter set maps coefficient names to float values.
Theta = Dict[str, float]
# Canonical coefficient order used when converting between dicts and vectors.
THETA_KEYS = [f"k{idx}" for idx in range(7)]
#THETA_KEYS = ["theta1","theta_c","theta_l","theta_r","theta_m","theta_cl","theta_cr","theta_cm"]

# ---------- EA helpers ----------
def theta_to_vec(theta: Theta) -> np.ndarray:
    """Flatten a parameter dict into a float vector ordered by ``THETA_KEYS``."""
    ordered_values = []
    for key in THETA_KEYS:
        ordered_values.append(theta[key])
    return np.array(ordered_values, dtype=float)

def vec_to_theta(vec: np.ndarray) -> Theta:
    """Rebuild a parameter dict by pairing ``THETA_KEYS`` with the entries of ``vec``."""
    theta = {}
    for key, value in zip(THETA_KEYS, vec):
        theta[key] = float(value)
    return theta

def clone_theta(theta: Theta) -> Theta:
    """Return an independent copy of ``theta`` with every value coerced to float."""
    copied = {}
    for name, value in theta.items():
        copied[name] = float(value)
    return copied

def random_theta(rng: np.random.Generator, low=-0.7, high=0.7) -> Theta:
    """Draw one uniform value in ``[low, high)`` per key of ``THETA_KEYS``."""
    draws = rng.uniform(low, high, len(THETA_KEYS))
    theta = {}
    for key, value in zip(THETA_KEYS, draws):
        theta[key] = float(value)
    return theta

def uniform_crossover(a: Theta, b: Theta, rng: np.random.Generator, swap_p=0.5) -> Theta:
    """Build a child by picking each coefficient from parent ``a`` or ``b``.

    One uniform number is drawn per key, in ``THETA_KEYS`` order; the value
    comes from ``a`` when the draw exceeds ``swap_p`` and from ``b`` otherwise.
    """
    child = {}
    for key in THETA_KEYS:
        take_from_a = rng.random() > swap_p
        child[key] = a[key] if take_from_a else b[key]
    return child

def mutate(theta: Theta, rng: np.random.Generator, pmut=0.25, sigma=0.12, clip=(-1.5,1.5)) -> Theta:
    """Return a mutated copy of ``theta``.

    Each coefficient is perturbed independently with probability ``pmut`` by
    Gaussian noise of standard deviation ``sigma``; the result is then clipped
    to the closed interval ``clip``.

    Args:
        theta: parent parameter set (not modified).
        rng: random generator; draws happen in ``THETA_KEYS`` order, one
            uniform per coefficient plus one normal for each mutated one.
        pmut: per-coefficient mutation probability.
        sigma: standard deviation of the additive noise.
        clip: ``(low, high)`` bounds applied after mutation.

    Returns:
        A new parameter dict.
    """
    v = theta_to_vec(theta)
    for i in range(v.size):
        if rng.random() < pmut:
            v[i] += rng.normal(0.0, sigma)
    # A former "push theta_cm negative" step was removed: it lived in a no-op
    # string literal and referred to a key that is not part of THETA_KEYS.
    v = np.clip(v, *clip)
    return vec_to_theta(v)

def tournament_select(pop_fit: List[Tuple[Theta,float]], rng: np.random.Generator, k=3) -> Theta:
    """Pick the fittest of ``k`` distinct random individuals and return a copy of its theta.

    Args:
        pop_fit: evaluated population; element ``[0]`` of each entry is the
            theta and element ``[1]`` its fitness (extra trailing fields are
            ignored).
        rng: random generator used to sample the contestants.
        k: tournament size. It is clamped to ``[1, len(pop_fit)]`` because
            sampling without replacement cannot draw more contestants than
            there are individuals.

    Raises:
        ValueError: if ``pop_fit`` is empty.
    """
    n = len(pop_fit)
    if n == 0:
        raise ValueError("tournament_select requires a non-empty population")
    k = max(1, min(int(k), n))
    idxs = rng.choice(n, size=k, replace=False)
    best = max(idxs, key=lambda i: pop_fit[i][1])
    return clone_theta(pop_fit[best][0])

# ---------- configs ----------
@dataclass
class CEMConfig:
    """Size and integration settings handed to the cellular model."""
    # number of cells
    N: int = 16
    # micro-steps per env step
    T: int = 32
    # integrator step size
    dt: float = 0.02
    # forwarded to the model constructor as ``alpha_m``
    alpha_m: float = 0.2

@dataclass
class EvalConfig:
    env_id: str = "CartPole-v1"
    max_steps: int = 500
    eval_episodes: int = 1
    render_mode: str = None  # use "human" to watch interactively

@dataclass
class EAConfig:
    """Hyper-parameters of the (mu + lambda) evolutionary algorithm."""
    # number of parents kept after truncation
    mu: int = 30
    # number of offspring produced per generation
    lam: int = 60
    # number of generations to run
    generations: int = 60
    # tournament size for parent selection
    tourney_k: int = 3
    # crossover threshold passed to uniform_crossover
    swap_p: float = 0.5
    # per-coefficient mutation probability
    pmut: float = 0.25
    # standard deviation of the mutation noise
    sigma: float = 0.12
    # (low, high) bounds applied to coefficients after mutation
    theta_clip: Tuple[float,float] = (-1.5, 1.5)
    # seed for the EA random generator and the evaluation seeds
    seed: int = 123

def save_micro_state(A_last: np.ndarray, M_last: np.ndarray, cem_cfg: CEMConfig, path: str):
    """
    Save final micro-state as .npz (portable, precise).

    The archive holds ``A_last`` and ``M_last`` plus the model settings
    ``N``, ``T``, ``dt`` and ``alpha_m`` taken from ``cem_cfg``.

    Args:
        A_last, M_last: state arrays to store.
        cem_cfg: configuration whose fields are stored alongside the state.
        path: destination file; parent directories are created if missing.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises, so fall
    # back to the current directory in that case.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    np.savez(path, A_last=A_last, M_last=M_last, N=cem_cfg.N, T=cem_cfg.T,
             dt=cem_cfg.dt, alpha_m=cem_cfg.alpha_m)

# ---------- Persistent-state micro-simulation ----------
def cem_force_step_persistent(model, obs: np.ndarray) -> float:
    """
    Advance one persistent model by a single environment step.

    The last micro-state row is copied into row 0 so that internal state is
    carried over, then the four observation components are passed to
    ``model.simulate`` and the final entry of the returned output trace is
    reported as the force.
    """
    # Carry the previous final micro-state forward as the new initial state.
    model.A[0] = model.A[-1]

    # Unpack the observation -> (x, dx, theta, dtheta) as plain floats.
    x, dx, th, dth = (float(component) for component in obs)

    # Run the micro-steps; the trace has one row per micro-step.
    trace = model.simulate(x, dx, th, dth)
    final_output = trace[-1, 0]
    return float(final_output)



def run_episode_return_state(theta: Dict[str, float],
                             cem_cfg: CEMConfig,
                             env_id: str,
                             max_steps: int,
                             seed: int) -> Tuple[float, np.ndarray, np.ndarray]:
    """
    Run one episode with a persistent-state ``CEM1Layer`` controller.

    Args:
        theta: coefficients; must contain every key of the model's ``theta``.
        cem_cfg: model size / integration settings.
        env_id: Gymnasium environment id.
        max_steps: maximum number of environment steps.
        seed: seed passed to ``env.reset``.

    Returns:
        ``(return_sum, A_last, M_last)`` from the final env step. The
        single-field model has no memory field, so ``M_last`` is the same
        array as ``A_last``.
    """
    import gymnasium as gym
    env = gym.make(env_id, render_mode=None, max_episode_steps=max_steps)
    try:
        obs, _ = env.reset(seed=seed)
        model = CEM1Layer(N=cem_cfg.N, T=cem_cfg.T, dt=cem_cfg.dt, alpha_m=cem_cfg.alpha_m)
        # body params only
        model.theta.update({k: theta[k] for k in model.theta.keys()})

        # The history buffer comes from np.empty, so before the first
        # simulation its last row is uninitialised memory. Prime it with the
        # seeded initial state so the first carry below is a no-op instead of
        # overwriting A[0] with garbage.
        model.A[-1] = model.A[0]

        total = 0.0
        for _ in range(max_steps):
            # carry persistent micro-state
            model.A[0] = model.A[-1]

            x, dx, th, dth = map(float, obs)   # plug in normalization if you use it
            force_series = model.simulate(x, dx, th, dth)  # (T+1, 1)
            force = float(force_series[-1, 0])  # or your smoothed/out_head version
            action = 0 if force < 0.0 else 1

            obs, reward, terminated, truncated, _ = env.step(action)
            total += reward
            if terminated or truncated:
                break

        A_last = model.A[-1].copy()

        return float(total), A_last, A_last
    finally:
        env.close()

def evaluate_theta_with_state(theta: Dict[str, float],
                              cem_cfg: CEMConfig,
                              eval_cfg: EvalConfig,
                              base_seed: int) -> Tuple[float, np.ndarray, np.ndarray]:
    """
    Evaluate ``theta`` over ``eval_cfg.eval_episodes`` episodes.

    Episode ``i`` is seeded with ``base_seed + i*9973``.

    Returns:
      (mean_fitness, A_last, M_last)
    where the two state arrays belong to the *last episode evaluated* (they
    stay ``None`` if no episode was run).
    """
    episode_returns = []
    A_last = M_last = None
    for episode in range(eval_cfg.eval_episodes):
        episode_seed = base_seed + episode*9973
        ret, A_last, M_last = run_episode_return_state(
            theta, cem_cfg, eval_cfg.env_id, eval_cfg.max_steps, episode_seed
        )
        episode_returns.append(ret)
    mean_fitness = float(np.mean(episode_returns))
    return mean_fitness, A_last, M_last

# top-level fn for ProcessPoolExecutor
def _eval_one(args):
    """Evaluate a single ``(theta, cem_cfg, eval_cfg, base_seed)`` work item.

    Returns:
        ``(theta, fitness, A_last, M_last)`` -- the parameters together with
        their fitness and the final micro-state of the evaluation.
    """
    theta, cem_cfg, eval_cfg, base_seed = args
    evaluation = evaluate_theta_with_state(theta, cem_cfg, eval_cfg, base_seed)
    fitness, A_last, M_last = evaluation
    return (theta, float(fitness), A_last, M_last)





# ---------- IO helpers ----------
def save_theta(theta: Theta, path: str):
    """Write ``theta`` to ``path`` as indented JSON (the file is overwritten)."""
    with open(path, "w") as handle:
        json.dump(theta, handle, indent=2)

def load_theta(path: str) -> Theta:
    """Read a JSON object from ``path`` and return it with all values as floats."""
    with open(path, "r") as handle:
        raw = json.load(handle)
    theta = {}
    for key, value in raw.items():
        theta[key] = float(value)
    return theta

def save_checkpoint(gen: int, pop_fit: List[Tuple[Theta, float]], out_dir="checkpoints"):
    """Write the best theta and a top-5 leaderboard for generation ``gen``.

    Args:
        gen: generation number, used in the output file names.
        pop_fit: evaluated population. Only element ``[0]`` (theta) and ``[1]``
            (fitness) of each entry are read, so entries carrying extra
            trailing fields (such as the 4-tuples built by ``_eval_one``) are
            accepted as well.
        out_dir: directory for the checkpoint files (created if missing).

    Raises:
        ValueError: if ``pop_fit`` is empty (raised by ``np.argmax``).
    """
    os.makedirs(out_dir, exist_ok=True)
    # save top-1
    best_idx = int(np.argmax([entry[1] for entry in pop_fit]))
    best_theta = pop_fit[best_idx][0]
    save_theta(best_theta, os.path.join(out_dir, f"best_gen{gen:04d}.json"))
    # Also save a small leaderboard. Rank by fitness explicitly: the incoming
    # population is not guaranteed to be sorted.
    ranked = sorted(pop_fit, key=lambda entry: entry[1], reverse=True)
    board = [{"rank": i+1, "fitness": float(entry[1]), "theta": entry[0]}
             for i, entry in enumerate(ranked[:5])]
    with open(os.path.join(out_dir, f"top5_gen{gen:04d}.json"), "w") as f:
        json.dump(board, f, indent=2)

# ---------- (μ + λ) EA (evaluation currently runs serially) ----------

def _eval_many_serial(args_list):
    """Evaluate every work item one after another in this process (no multiprocessing)."""
    results = []
    for work_item in args_list:
        results.append(_eval_one(work_item))
    return results

def evolve_cartpole_with_cem(cem_cfg=CEMConfig(), eval_cfg=EvalConfig(), ea_cfg=EAConfig(),
                             out_dir="runs/cem_ea"):
    """Run a (mu + lambda) evolutionary search over the model coefficients.

    Each generation logs fitness statistics, saves the current best theta and
    its micro-state, breeds ``lam`` offspring (tournament selection, uniform
    crossover, Gaussian mutation), evaluates them serially and keeps the best
    ``mu`` of parents plus offspring.

    Args:
        cem_cfg: model configuration used for every evaluation.
        eval_cfg: environment / episode configuration.
        ea_cfg: evolutionary hyper-parameters and seed.
        out_dir: directory receiving ``best_theta.json``,
            ``best_micro_state.npz`` and ``history.json``.

    Returns:
        dict with ``best_theta``, ``best_fitness``, ``history`` (per-generation
        statistics), ``best_over_time`` and ``run_dir``.
    """
    rng = np.random.default_rng(ea_cfg.seed)
    os.makedirs(out_dir, exist_ok=True)
    theta_path = os.path.join(out_dir, "best_theta.json")
    micro_path = os.path.join(out_dir, "best_micro_state.npz")

    # init μ parents
    population: List[Theta] = [random_theta(rng) for _ in range(ea_cfg.mu)]
    base_seed = ea_cfg.seed * 11

    args = [(clone_theta(th), cem_cfg, eval_cfg, base_seed + i*1000)
            for i, th in enumerate(population)]
    # Entries are (theta, fitness, A_last, M_last).
    pop_fit = _eval_many_serial(args)

    history = []
    best_over_time: List[Tuple[Theta, float]] = []

    for gen in range(ea_cfg.generations):
        fits = [entry[1] for entry in pop_fit]
        best_theta, best_fit, best_A_last, best_M_last = pop_fit[int(np.argmax(fits))]
        save_theta(best_theta, theta_path)
        save_micro_state(best_A_last, best_M_last, cem_cfg, micro_path)
        history.append({
            "generation": gen,
            "mean_fitness": float(np.mean(fits)),
            "std_fitness": float(np.std(fits)),
            "max_fitness": float(np.max(fits)),
            "min_fitness": float(np.min(fits)),
        })
        best_over_time.append((clone_theta(best_theta), float(best_fit)))
        print(f"[Gen {gen:03d}] mean={np.mean(fits):.2f} best={best_fit:.2f}")

        # Stop once the (truncated) population mean reaches the step budget.
        if int(np.mean(fits)) == eval_cfg.max_steps:
            print("Stopping Early")
            break

        # save rolling best + small leaderboard
        #save_checkpoint(gen, pop_fit, out_dir=os.path.join(out_dir, "checkpoints"))

        # λ offspring
        offspring: List[Theta] = []
        for _ in range(ea_cfg.lam):
            p1 = tournament_select(pop_fit, rng, k=ea_cfg.tourney_k)
            p2 = tournament_select(pop_fit, rng, k=ea_cfg.tourney_k)
            child = uniform_crossover(p1, p2, rng, swap_p=ea_cfg.swap_p)
            child = mutate(child, rng, pmut=ea_cfg.pmut, sigma=ea_cfg.sigma, clip=ea_cfg.theta_clip)
            offspring.append(child)

        # evaluate offspring with fresh seeds (reduce overfit to single RNG stream)
        gen_seed_base = base_seed + (gen+1)*100_000
        args = [(clone_theta(ch), cem_cfg, eval_cfg, gen_seed_base + i*1000)
                for i, ch in enumerate(offspring)]
        off_fit = _eval_many_serial(args)

        # μ + λ truncation
        pool = pop_fit + off_fit
        pool.sort(key=lambda tf: tf[1], reverse=True)
        pop_fit = pool[:ea_cfg.mu]

    # Final best. The files written inside the loop describe the population
    # *before* the last round of offspring, so save again here; otherwise the
    # files on disk can lag behind the best individual returned below.
    fits = [entry[1] for entry in pop_fit]
    best_theta, best_fit, best_A_last, best_M_last = pop_fit[int(np.argmax(fits))]
    save_theta(best_theta, theta_path)
    save_micro_state(best_A_last, best_M_last, cem_cfg, micro_path)

    with open(os.path.join(out_dir, "history.json"), "w") as f:
        json.dump(history, f, indent=2)

    return {
        "best_theta": clone_theta(best_theta),
        "best_fitness": float(best_fit),
        "history": history,
        "best_over_time": best_over_time,
        "run_dir": out_dir,
    }

# ---------- Video rendering ----------
def render_video_with_theta(theta: Dict[str, float], cem_cfg: CEMConfig,
                            env_id="CartPole-v1", max_steps=500,
                            video_dir="videos", name_prefix="cem_cartpole",
                            microstate_path: str | None = None,
                            seed: int = 2025):
    """Record one controlled episode to ``video_dir`` using ``RecordVideo``.

    Args:
        theta: coefficients; must contain every key of the model's ``theta``.
        cem_cfg: model size / integration settings.
        env_id: Gymnasium environment id.
        max_steps: maximum number of environment steps to run.
        video_dir: output folder (created if missing).
        name_prefix: file-name prefix handed to ``RecordVideo``.
        microstate_path: optional ``.npz`` file whose ``A_last`` array becomes
            the initial micro-state; ignored if the file does not exist.
        seed: seed passed to ``env.reset``.
    """
    os.makedirs(video_dir, exist_ok=True)
    env = gym.make(env_id, render_mode="rgb_array")
    try:
        # Wrap inside the try block so the base env is closed even if the
        # wrapper fails to construct.
        env = RecordVideo(env, video_folder=video_dir, name_prefix=name_prefix,
                          episode_trigger=lambda i: True)
        obs, _ = env.reset(seed=seed)

        model = CEM1Layer(N=cem_cfg.N, T=cem_cfg.T, dt=cem_cfg.dt, alpha_m=cem_cfg.alpha_m)
        # body params only
        body_keys = list(model.theta.keys())
        model.theta.update({k: theta[k] for k in body_keys})

        # The loop below starts every step with A[0] = A[-1], so whatever is
        # stored in the last row becomes the initial micro-state.
        if microstate_path is not None and os.path.exists(microstate_path):
            print("Loading microstate from {}".format(microstate_path))
            with np.load(microstate_path) as data:
                model.A[-1] = data["A_last"]
        else:
            # No saved state: the last row of the np.empty buffer is still
            # uninitialised, so prime it with the model's default A[0].
            model.A[-1] = model.A[0]

        total = 0.0
        for _ in range(max_steps):
            # persistent carryover for each env step
            model.A[0] = model.A[-1]

            x, dx, th, dth = obs  # or normalized if you use norm_obs
            force_series = model.simulate(float(x), float(dx), float(th), float(dth))
            force = float(force_series[-1, 0])

            action = 0 if force < 0.0 else 1
            obs, reward, terminated, truncated, _ = env.step(action)
            total += reward
            if terminated or truncated:
                break

        print(f"Episode return (video): {total:.1f}")
    finally:
        env.close()
    print(f"Video saved to: {os.path.abspath(video_dir)}  (look for {name_prefix}-episode-*.mp4)")


# --- helper: draw A overlay (top-right) ---
def _draw_A_overlay(frame: np.ndarray, A: np.ndarray,
                    panel_w: int = 260, panel_h: int = 100,
                    margin: int = 12, bar_pad: int = 6) -> np.ndarray:
    """Draw a semi-transparent panel summarising ``A`` in the top-right of ``frame``.

    The panel shows a text line with the min / mean / max of ``A`` and a bar
    chart with one bar per entry, drawn above or below a horizontal zero line
    depending on the sign of the value. Bars are scaled by the largest
    absolute value of ``A`` in this call.

    Args:
        frame: image to annotate; non-uint8 input is clipped to [0, 255] and
            cast, 2-D input is replicated to three channels.
        A: values to plot. # assumes A is 1-D (each element is used as a
            scalar bar height) -- TODO confirm against callers
        panel_w, panel_h: requested panel size in pixels.
        margin: distance of the panel from the top and right frame edges.
        bar_pad: not used by the current implementation.

    Returns:
        A new ``uint8`` RGB array. If the panel rectangle would be degenerate,
        the frame is returned without any drawing.
    """
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    # defensive: ensure HxWx3 uint8
    frame = np.asarray(frame)
    if frame.dtype != np.uint8:
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    if frame.ndim == 2:
        frame = np.repeat(frame[..., None], 3, axis=2)

    h, w, _ = frame.shape
    im = Image.fromarray(frame).convert("RGBA")
    draw = ImageDraw.Draw(im)

    # panel rect (top-right), clamped to the frame bounds
    x1 = max(0, w - panel_w - margin)
    y1 = max(0, margin)
    x2 = min(w - 1, w - margin)
    y2 = min(h - 1, margin + panel_h)

    # ensure valid panel rect
    if x2 <= x1 + 2 or y2 <= y1 + 2:
        return np.array(im.convert("RGB"), dtype=np.uint8)

    # background panel; fall back to PIL's implicit font if loading fails
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
    draw.rounded_rectangle([x1, y1, x2, y2], radius=10,
                           fill=(0, 0, 0, 160), outline=(255, 255, 255, 180), width=1)

    # stats (an empty A only gets a placeholder label)
    if A.size == 0:
        text = "A  (empty)"
        draw.text((x1 + 10, y1 + 8), text, fill=(255, 255, 255, 230), font=font)
        return np.array(im.convert("RGB"), dtype=np.uint8)

    A_min = float(np.min(A))
    A_mean = float(np.mean(A))
    A_max = float(np.max(A))
    draw.text((x1 + 10, y1 + 8),
              f"A  min {A_min:+.2f}  mean {A_mean:+.2f}  max {A_max:+.2f}",
              fill=(255, 255, 255, 230), font=font)

    # bar area (inset from the panel, below the text line)
    bars_x1 = x1 + 10
    bars_y1 = y1 + 28
    bars_x2 = x2 - 10
    bars_y2 = y2 - 10
    bars_w = max(1, bars_x2 - bars_x1)
    bars_h = max(2, bars_y2 - bars_y1)  # need >=2 for zero line + 1px bars

    # zero line through the vertical middle of the bar area
    zero_y = bars_y1 + bars_h // 2
    zero_y = int(np.clip(zero_y, bars_y1, bars_y2))  # within area
    draw.line([(bars_x1, zero_y), (bars_x2, zero_y)], fill=(200, 200, 200, 180), width=1)

    # scale A per-frame to [-1, 1] (the floor of 1e-6 avoids division by zero)
    denom = float(max(1e-6, np.max(np.abs(A))))
    A_unit = (A / denom).astype(float)

    n = int(len(A_unit))
    if n > 0:
        # bar width + gap
        bw = max(1, bars_w // n)
        gap = max(0, int(0.15 * bw))
        bw_eff = max(1, bw - gap)

        for i, val in enumerate(A_unit):
            bx1 = bars_x1 + i * bw
            bx2 = min(bars_x2, bx1 + bw_eff)
            if bx2 <= bx1:
                continue  # nothing to draw

            height = int((bars_h / 2) * min(1.0, abs(val)))
            if height < 1:
                height = 1  # ensure at least 1px visible

            # non-negative values grow upward from the zero line, negative
            # values grow downward; each sign has its own colour
            if val >= 0:
                y_top = int(np.clip(zero_y - height, bars_y1, bars_y2))
                y_bot = int(np.clip(zero_y - 1, bars_y1, bars_y2))
                color = (120, 200, 255, 220)
            else:
                y_top = int(np.clip(zero_y + 1, bars_y1, bars_y2))
                y_bot = int(np.clip(zero_y + height, bars_y1, bars_y2))
                color = (255, 160, 120, 220)

            # enforce y_top <= y_bot to satisfy PIL
            if y_bot < y_top:
                y_top, y_bot = y_bot, y_top
            if y_bot == y_top:  # still a line -> extend 1px
                y_bot = min(bars_y2, y_bot + 1)

            draw.rectangle([bx1, y_top, bx2, y_bot], fill=color)

    return np.array(im.convert("RGB"), dtype=np.uint8)



def _bwr_rgb_from_unit(x: np.ndarray) -> np.ndarray:
    """
    x in [-1, 1] -> RGB (uint8), blue-white-red
    """
    x = np.clip(x, -1.0, 1.0)
    R = np.where(x >= 0, 255.0, (1.0 + x) * 255.0)         # x<0: fade red down
    G = np.where(x >= 0, (1.0 - x) * 255.0, (1.0 + x) * 255.0)
    B = np.where(x >= 0, (1.0 - x) * 255.0, 255.0)         # x<0: strong blue
    rgb = np.stack([R, G, B], axis=-1)
    return np.clip(rgb, 0, 255).astype(np.uint8)

def _unrolled_A_to_frame(A_roll: np.ndarray, scale: int = 2) -> np.ndarray:
    """
    Turn one micro-rollout history into an RGB image.

    A_roll: (T+1, N) float -> ndarray (H, W, 3), uint8.
    Values are divided by the largest absolute value of this rollout (left
    unscaled when that maximum is not above 1e-12), coloured with the
    blue-white-red map, and enlarged by `scale` using nearest-neighbour
    resampling unless `scale` is 1.
    """
    values = np.asarray(A_roll, dtype=float)

    # normalize per micro-rollout
    peak = np.max(np.abs(values))
    denom = float(peak) if peak > 1e-12 else 1.0

    rgb = _bwr_rgb_from_unit(values / denom)  # (T+1, N, 3)
    if scale == 1:
        return rgb

    rows, cols = rgb.shape[:2]
    enlarged = Image.fromarray(rgb, mode="RGB").resize((cols*scale, rows*scale), resample=Image.NEAREST)
    return np.array(enlarged, dtype=np.uint8)


def render_video_with_theta_and_A_overlay(
    theta: dict,
    cem_cfg,
    env_id: str = "CartPole-v1",
    max_steps: int = 500,
    out_path: str = "videos/cem_cartpole_best_with_A.mp4",
    microstate_path: str | None = None,
    seed: int | None = None,
    # these must match training if you used them:
    use_norm: bool = False,
    smooth_k: int = 1,
    out_head: dict | None = None,
    fps: int = 30,
    unrolled_out_path: str | None = "videos/cem_cartpole_A_unrolled.mp4",
    unrolled_scale: int = 2,
):
    """
    Records a video with the CEM-controlled episode and overlays the current A state
    at the top-right in every frame. Optionally writes a second video in which
    every frame is the full (T+1, N) micro-rollout of that environment step.

    Args:
        theta: coefficients; must contain every key of the model's ``theta``.
        cem_cfg: object providing ``N``, ``T``, ``dt`` and ``alpha_m``.
        env_id: Gymnasium environment id.
        max_steps: maximum number of environment steps to run.
        out_path: destination of the overlay video.
        microstate_path: optional ``.npz`` file; its ``A_last`` array becomes
            the initial micro-state and, when ``seed`` is None and the file
            has a ``seed`` entry, that seed is used. Ignored if the file does
            not exist.
        seed: environment seed; defaults to 2025 when nothing else provides it.
        use_norm: accepted for interface compatibility; currently unused.
        smooth_k: when > 1, the readout is the mean of the last ``smooth_k``
            output samples instead of the final one.
        out_head: optional dict with ``out_gain`` / ``out_bias`` applied to the
            readout.
        fps: frame rate of both videos.
        unrolled_out_path: destination of the micro-rollout video; a falsy
            value disables it.
        unrolled_scale: integer up-scaling factor for the micro-rollout frames.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    if unrolled_out_path is not None:
        os.makedirs(os.path.dirname(unrolled_out_path) or ".", exist_ok=True)

    # Load micro-state (and seed) if provided. The archive is closed again as
    # soon as the needed arrays have been read into memory.
    A0 = None
    if microstate_path and os.path.exists(microstate_path):
        with np.load(microstate_path) as data:
            A0 = data["A_last"]
            if seed is None and "seed" in data:
                seed = int(data["seed"])

    if seed is None:
        seed = 2025

    # Everything is acquired inside the try block so that a failure while
    # opening a later resource still releases the ones already open.
    env = None
    writer = None
    unrolled_writer = None
    total = 0.0
    try:
        env = gym.make(env_id, render_mode="rgb_array")
        writer = imageio.get_writer(out_path, fps=fps, codec="libx264", quality=8)
        if unrolled_out_path:
            unrolled_writer = imageio.get_writer(unrolled_out_path, fps=fps, codec="libx264", quality=8)

        obs, _ = env.reset(seed=seed)

        # build model (persistent across env steps)
        model = CEM1Layer(N=cem_cfg.N, T=cem_cfg.T, dt=cem_cfg.dt, alpha_m=cem_cfg.alpha_m)
        # body params only
        model.theta.update({k: theta[k] for k in model.theta.keys()})

        # inject saved micro-state before first step
        if A0 is not None:
            model.A[0] = A0

        first_step = True
        for _ in range(max_steps):
            # carry persistent micro-state (skip overwrite on the very first
            # step: A[-1] has not been computed yet at that point)
            if not first_step:
                model.A[0] = model.A[-1]
            first_step = False

            x, dx, th, dth = obs

            force_series = model.simulate(float(x), float(dx), float(th), float(dth))  # (T+1,1)
            if unrolled_writer is not None:
                A_unrolled = model.A.copy()          # shape: (T+1, N)
                unrolled_writer.append_data(_unrolled_A_to_frame(A_unrolled, scale=unrolled_scale))

            # readout (match training)
            if smooth_k and smooth_k > 1:
                k = min(smooth_k, force_series.shape[0])
                base_force = float(force_series[-k:, 0].mean())
            else:
                base_force = float(force_series[-1, 0])

            if out_head is not None:
                force = out_head.get("out_gain", 1.0) * base_force + out_head.get("out_bias", 0.0)
            else:
                force = base_force

            action = 0 if force < 0.0 else 1
            obs, reward, terminated, truncated, _ = env.step(action)
            total += reward

            # grab a frame and overlay the final micro-state of this step
            frame = env.render()  # rgb_array
            frame = _draw_A_overlay(frame, model.A[-1])
            writer.append_data(frame)

            if terminated or truncated:
                break
    finally:
        if writer is not None:
            writer.close()
        if unrolled_writer is not None:
            unrolled_writer.close()
        if env is not None:
            env.close()

    print(f"Episode return (video): {total:.1f}")
    print(f"Saved video with A overlay to: {os.path.abspath(out_path)}")






In [None]:
# ---------- Example usage ----------
# Runs ten independent evolutionary searches that differ only in the EA seed
# (i + 25). Every run writes to the same out_dir, so the files on disk -- and
# the variables c1, result, best_theta and micro_path used by later cells --
# always describe the most recently finished run.
for i in range(10):
    c1 =CEMConfig(N=30, T=150, dt=0.3, alpha_m=0.2)
    # 1) Train (candidates are evaluated serially, see _eval_many_serial)
    result = evolve_cartpole_with_cem(
        cem_cfg=c1,
        eval_cfg=EvalConfig(env_id="CartPole-v1", eval_episodes=1, max_steps=300),
        ea_cfg=EAConfig(mu=50, lam=100, generations=4, pmut=0.15,sigma=0.1,tourney_k=10, seed=i+25),
        out_dir="runs/cem_ea_parallel",
    )
    print("Best fitness:", result["best_fitness"])
    print("Best theta saved at:", os.path.join(result["run_dir"], "best_theta.json"))
    # Reload the saved artefacts so the rendering cells use exactly what is on disk.
    best_theta = load_theta(os.path.join(result["run_dir"], "best_theta.json"))
    micro_path = os.path.join(result["run_dir"], "best_micro_state.npz")

In [None]:
 # 2a) Record the best controller with gymnasium's RecordVideo wrapper.
 # NOTE(review): this statement carries a stray one-space indent. It presumably
 # only runs because the notebook front end dedents the cell; as plain Python
 # it would be an IndentationError -- confirm and remove the leading space.
 render_video_with_theta(best_theta, cem_cfg=c1, env_id="CartPole-v1",
                        max_steps=300, video_dir="videos",
                        name_prefix="cem_cartpole_best",
                        microstate_path=micro_path, seed=2025)

    # 2b) Record the same controller again, this time with the A-state overlay
    #     (note: max_steps is 500 here, whereas training above used 300).

render_video_with_theta_and_A_overlay(best_theta, cem_cfg=c1, env_id="CartPole-v1",
                        max_steps=500, out_path="videos/cem_cartpole_best_with_A.mp4",

                        microstate_path=micro_path, seed=2025)
