In [None]:
# Install required libraries for Atari environment
!pip install -q gymnasium[atari,accept-rom-license] shimmy ale-py

In [None]:
import sys
import os
import time
import math
import argparse
import csv
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import types
import gymnasium as gym
from collections import deque
from torch.distributions import Normal
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass
import ale_py

gym.register_envs(ale_py)

# Module mapping: the checkpoints loaded below are unpickled
# (``weights_only=False``) and presumably reference classes from a
# ``src.models.*`` package.  Register empty placeholder modules under those
# names; ``load_shared_models`` attaches this notebook's classes to them.
sys.modules.update(
    {
        _module_name: types.ModuleType(_module_name)
        for _module_name in (
            "src",
            "src.models",
            "src.models.vlm",
            "src.models.metacontroller",
        )
    }
)


class TrainConfig:
    """Hyperparameters and file locations for internal-RL policy training.

    Every setting is a plain class attribute.  The notebook instantiates this
    once as ``cfg`` and reads the values from there.
    """

    # ---- checkpoint and log locations ----
    model_checkpoint = "/kaggle/input/best-models-montezuma/vlm_best.pt"
    meta_checkpoint = "/kaggle/input/best-models-montezuma/meta_simple_epoch_18.pt"
    save_path = "policy_best_.pt"
    log_path = "internal_rl_logs.csv"
    episode_log_path = "episode_logs.csv"

    # ---- training schedule ----
    epochs = 1000
    batch_size = 64  # trajectories per epoch
    num_envs = 8
    lr = 3e-4

    # ---- policy (Actor) sizes ----
    hidden_dim = 256
    input_dim = 256
    latent_dim = 8
    initial_std = 0.7

    # ---- hierarchy / switching ----
    fixed_rate = 0
    max_steps_without_reward = 500  # internal steps before timeout

    # ---- base VLM / environment geometry ----
    # These feed the VLM ``Config``; the VLM checkpoint is loaded strictly,
    # so they must agree with the architecture it was trained with.
    seq_len = 64
    img_size = 84
    patch_size = 14
    embed_dim = 256
    n_layers = 6
    n_heads = 8
    frame_stack = 4

    # ---- optimization ----
    entropy_coef = 0

    # ---- misc ----
    render = False
    device = "cuda" if torch.cuda.is_available() else "cpu"

# Shared configuration instance read by load_shared_models() and train_rl().
cfg = TrainConfig()

# ==========================================
# 2. MODELS (VLM & Metacontroller)
# ==========================================

@dataclass
class Config:
    """Architecture hyperparameters for ``MontezumaVLM`` (also handed to ``Metacontroller``).

    Field order is significant: instances may be built positionally.
    """

    img_size: int = 84             # side length of each square input image (pixels)
    patch_size: int = 14           # side length of each square patch
    frame_stack: int = 4           # frames stacked as channels per image
    embed_dim: int = 256           # transformer width
    n_heads: int = 8               # attention heads per block
    n_layers: int = 6              # number of transformer blocks
    dropout: float = 0.1           # dropout probability in attention and MLP
    seq_len: int = 64              # action-window length
    n_actions: int = 18            # primitive action-space size (action_head outputs)
    duration_vocab_size: int = 65  # NOTE(review): not read in this file; presumably used by training code elsewhere

class PatchEmbedding(nn.Module):
    """Cut images into non-overlapping square patches and embed each one.

    A convolution whose kernel size equals its stride acts as a per-patch
    linear projection: ``(N, in_chans, H, W) -> (N, n_patches, embed_dim)``.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        grid = self.proj(x)                         # (N, embed_dim, H', W')
        flat = torch.flatten(grid, start_dim=2)     # (N, embed_dim, H' * W')
        return flat.transpose(1, 2)                 # (N, H' * W', embed_dim)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention built on ``F.scaled_dot_product_attention``.

    Maps ``(B, T, C)`` to ``(B, T, C)``; position ``t`` attends only to
    positions ``<= t``.  ``C`` is expected to be divisible by ``n_heads``.
    """

    def __init__(self, embed_dim, n_heads, dropout):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout  # attention-dropout probability, applied only in training mode

    def forward(self, x):
        batch, length, width = x.shape
        # (B, T, 3C) -> (3, B, heads, T, head_dim), then split into q / k / v.
        packed = self.qkv(x).reshape(batch, length, 3, self.n_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        drop_p = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
        merged = attended.transpose(1, 2).reshape(batch, length, width)
        return self.proj(merged)

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal attention, then a 4x-wide GELU MLP.

    Both sub-layers sit on residual connections, so shape ``(B, T, D)`` is
    preserved.
    """

    def __init__(self, embed_dim, n_heads, dropout):
        super().__init__()
        hidden = 4 * embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, n_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

class MontezumaVLM(nn.Module):
    """Causal transformer over interleaved image-patch and action tokens.

    For ``NumImg`` images and ``L`` action slots the token layout is::

        [patches(img0), act[0:8], patches(img1), act[8:16], ...]

    Action slots beyond ``NumImg * actions_per_image`` are not placed in the
    sequence.  The model predicts the next primitive action from the last
    action-token position.  It can optionally apply a low-rank intervention
    to the residual stream after one transformer block (``adapter_params``).
    """

    # Action tokens placed after each image's patch tokens.  It is a class
    # attribute so it can be overridden.  ``compute_action_positions`` in this
    # notebook assumes the same value through its ``chunk=8`` default.
    actions_per_image = 8

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.embed_dim
        self.n_patches = (config.img_size // config.patch_size) ** 2

        self.patch_embed = PatchEmbedding(config.img_size, config.patch_size, config.frame_stack, config.embed_dim)
        # One extra embedding row (index n_actions) serves as the start token.
        self.action_embed = nn.Embedding(config.n_actions + 1, config.embed_dim)
        self.start_token_id = config.n_actions

        # Learned absolute positions; sequences longer than max_tokens are not supported.
        max_tokens = 512
        self.pos_embed = nn.Parameter(torch.zeros(1, max_tokens, config.embed_dim))
        # Token-type embedding: 0 = image patch, 1 = action (see forward).
        self.type_embed = nn.Embedding(2, config.embed_dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(config.embed_dim, config.n_heads, config.dropout)
            for _ in range(config.n_layers)
        ])

        self.ln_f = nn.LayerNorm(config.embed_dim)
        self.action_head = nn.Linear(config.embed_dim, config.n_actions)
        # NOTE(review): obs_head is not used by forward() in this file.  It is kept
        # so that checkpoints containing its weights still load strictly.
        self.obs_head = nn.Sequential(
            nn.Linear(config.embed_dim, config.embed_dim * 2),
            nn.GELU(),
            nn.Linear(config.embed_dim * 2, self.n_patches * config.embed_dim)
        )
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) init for positions, Linear and Embedding weights; zero biases."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None: nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.trunc_normal_(module.weight, std=0.02)

    def _apply_adapter(self, x, A, B_mat, pos_spec, action_positions, last_action_pos):
        """Return ``x`` with the low-rank update ``(x @ B_mat) @ A^T`` applied where ``pos_spec`` says.

        Position-restricted updates write into a *clone* of ``x``.  The einsums
        save their inputs for backward, and one of those inputs can be a view of
        ``x``.  An in-place write into ``x`` would therefore break backprop into
        ``A`` / ``B_mat`` with an "modified by an inplace operation" error.

        Raises:
            ValueError: on an unsupported ``pos_spec`` or adapter rank.
        """
        if pos_spec is None:
            # Legacy: every token is updated.
            if A.dim() == 3:
                term1 = torch.einsum("btd,bdr->btr", x, B_mat)
                return x + torch.einsum("btr,bdr->btd", term1, A)
            if A.dim() == 4:
                term1 = torch.einsum("btd,btdr->btr", x, B_mat)
                return x + torch.einsum("btr,btdr->btd", term1, A)
            raise ValueError(f"Bad adapter A shape: {A.shape}")

        if pos_spec == "action_tokens":
            # Tokenwise factors (B, T_tokens, D, R), used only at action-token indices.
            idx = torch.tensor(action_positions, device=x.device, dtype=torch.long)  # (T_act,)
            x_tok = x[:, idx, :]                                                   # (B, T_act, D)
            term1 = torch.einsum("btd,btdr->btr", x_tok, B_mat[:, idx, :, :])       # (B, T_act, R)
            term2 = torch.einsum("btr,btdr->btd", term1, A[:, idx, :, :])           # (B, T_act, D)
            out = x.clone()
            out[:, idx, :] = x_tok + term2
            return out

        if pos_spec == "last_action":
            pos = last_action_pos
        elif isinstance(pos_spec, int):
            pos = pos_spec
        else:
            raise ValueError(f"Unsupported pos_spec: {pos_spec}")

        # Single position: x_pos is a *view* of x, so x itself must not be mutated.
        x_pos = x[:, pos, :]                               # (B, D)
        term1 = torch.einsum("bd,bdr->br", x_pos, B_mat)   # (B, R) = x_pos @ B
        term2 = torch.einsum("br,bdr->bd", term1, A)       # (B, D) = term1 @ A^T
        out = x.clone()
        out[:, pos, :] = x_pos + term2
        return out

    def forward(self, frames, actions, return_residuals=False, adapter_params=None):
        """Predict the next action from a window of stacked frames and previous actions.

        Args:
            frames: ``(B, NumImg, C, H, W)`` stacked observations.  ``C`` must
                equal ``config.frame_stack``, the patch-embedding input channels.
            actions: ``(B, L)`` integer action ids.  A start token is prepended
                and the last column dropped, so ``actions[:, -1]`` is never
                consumed.  ``InternalRLWrapper`` passes a placeholder there.
            return_residuals: if True, also return a clone of the residual
                stream after every block (after any intervention at that block).
            adapter_params: ``None`` or ``(A, B_mat, layer_idx, pos_spec)``.  The
                update ``x <- x + (x @ B_mat) @ A^T`` is applied right after
                block ``layer_idx``.  ``pos_spec`` is one of:

                - ``"last_action"``: only at the last action-token position;
                  ``A``/``B_mat`` are ``(B, D, R)``.
                - ``int``: only at that absolute token index; ``(B, D, R)``.
                - ``"action_tokens"``: at every action-token position, using
                  tokenwise factors ``(B, T_tokens, D, R)`` indexed by position.
                - ``None``: at every token (legacy; not the paper-faithful
                  setting), with ``(B, D, R)`` or ``(B, T_tokens, D, R)`` factors.

        Returns:
            A dict with these keys:

            - ``"logits"``: ``(B, n_actions)`` at the last action token.
            - ``"final_embedding"``: ``(B, D)``, the post-``ln_f`` last token.
            - ``"last_action_pos"``: ``int``.
            - ``"residuals"``: list of ``(B, T, D)``, only when requested.

        Raises:
            ValueError: on an unsupported ``pos_spec`` or adapter rank.
        """
        B, NumImg, C, H, W = frames.shape
        device = frames.device
        chunk_size = self.actions_per_image

        # ---- vision tokens ----
        frames_flat = frames.view(B * NumImg, C, H, W)
        vis_tokens = self.patch_embed(frames_flat).view(B, NumImg, self.n_patches, self.embed_dim)

        # ---- action tokens: start token followed by actions[:, :-1] ----
        action_input = torch.cat(
            [torch.full((B, 1), self.start_token_id, device=device, dtype=torch.long), actions[:, :-1]],
            dim=1,
        )
        act_embeds = self.action_embed(action_input)

        # ---- interleave tokens and record absolute action-token positions ----
        token_list, type_list = [], []
        action_positions = []
        num_actions = act_embeds.shape[1]
        current_pos = 0

        for i in range(NumImg):
            token_list.append(vis_tokens[:, i])  # (B, n_patches, D)
            type_list.append(torch.zeros(self.n_patches, dtype=torch.long, device=device))
            current_pos += self.n_patches

            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, num_actions)
            if start_idx < num_actions:
                chunk = act_embeds[:, start_idx:end_idx]  # (B, k, D)
                k = chunk.shape[1]
                token_list.append(chunk)
                type_list.append(torch.ones(k, dtype=torch.long, device=device))
                action_positions.extend(range(current_pos, current_pos + k))
                current_pos += k

        tokens = torch.cat(token_list, dim=1)  # (B, T, D)
        seq_len_curr = tokens.shape[1]

        tokens = tokens + self.pos_embed[:, :seq_len_curr]
        tokens = tokens + self.type_embed(torch.cat(type_list))

        # Last action position is needed before the blocks so adapters can target it.
        if len(action_positions) == 0:
            # Only possible with NumImg == 0; guard anyway.
            last_action_pos = seq_len_curr - 1
        else:
            last_action_pos = int(action_positions[-1])

        if adapter_params is not None:
            A, B_mat, layer_idx, pos_spec = adapter_params

        residuals = []
        x = tokens

        # ---- transformer blocks (optionally intervene after one layer) ----
        for i, block in enumerate(self.blocks):
            x = block(x)

            if adapter_params is not None and i == layer_idx:
                x = self._apply_adapter(x, A, B_mat, pos_spec, action_positions, last_action_pos)

            if return_residuals:
                residuals.append(x.clone())

        x = self.ln_f(x)

        # Logits at every action-token position; only the last one is returned for sampling.
        action_positions_t = torch.tensor(action_positions, device=device, dtype=torch.long)
        action_embeddings = x[:, action_positions_t, :]      # (B, n_action_tokens, D)
        action_logits = self.action_head(action_embeddings)  # (B, n_action_tokens, n_actions)

        output = {
            "logits": action_logits[:, -1, :],  # last action-token logits
            "final_embedding": x[:, -1, :],
            "last_action_pos": last_action_pos,
        }
        if return_residuals:
            output["residuals"] = residuals
        return output



# --- Metacontroller Components ---

class LRU(nn.Module):
    """Diagonal gated linear recurrence with a learned skip connection.

    With ``lam = sigmoid(log_lambda)`` applied elementwise, each step computes::

        s_t = lam * s_{t-1} + (1 - lam) * W_in(u_t)
        y_t = W_out(s_t) + D * u_t

    ``forward`` scans a ``(B, T, d_input)`` sequence from a zero state; with
    ``reverse=True`` the scan runs right-to-left and the output is flipped
    back.  ``step`` advances a single ``(B, d_input)`` input.  ``n_heads`` is
    accepted but not used.
    """

    def __init__(self, d_input, d_state, n_heads=8):
        super().__init__()
        self.d_input = d_input
        self.d_state = d_state
        self.W_in = nn.Linear(d_input, d_state, bias=False)
        self.log_lambda = nn.Parameter(torch.full((d_state,), 2.0))
        self.W_out = nn.Linear(d_state, d_input, bias=False)
        self.D = nn.Parameter(torch.full((d_input,), 0.1))

    def forward(self, u, reverse=False):
        batch, steps, _ = u.shape
        seq = torch.flip(u, dims=[1]) if reverse else u
        decay = torch.sigmoid(self.log_lambda)
        gain = 1 - decay
        driven = self.W_in(seq)
        state = torch.zeros(batch, self.d_state, device=u.device)
        trajectory = []
        for t in range(steps):
            state = decay * state + gain * driven[:, t]
            trajectory.append(state)
        out = self.W_out(torch.stack(trajectory, dim=1)) + seq * self.D
        if reverse:
            return torch.flip(out, dims=[1])
        return out

    def step(self, u_t, state=None):
        decay = torch.sigmoid(self.log_lambda)
        driven = self.W_in(u_t)
        prev = torch.zeros_like(driven) if state is None else state
        new_state = decay * prev + (1 - decay) * driven
        return self.W_out(new_state) + u_t * self.D, new_state

class HawkBlock(nn.Module):
    """Pre-norm residual block: an ``LRU`` recurrence followed by a ReLU MLP.

    ``forward`` processes ``(B, T, d_model)`` sequences.  ``step`` processes a
    single ``(B, d_model)`` input and returns the updated LRU state.
    """

    def __init__(self, d_model, d_state=256, n_heads=8, mlp_ratio=2):
        super().__init__()
        wide = d_model * mlp_ratio
        self.norm1 = nn.LayerNorm(d_model)
        self.lru = LRU(d_model, d_state, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, wide),
            nn.ReLU(),
            nn.Linear(wide, d_model),
        )

    def _mlp_residual(self, h):
        """Apply the second (MLP) residual sub-layer."""
        return h + self.mlp(self.norm2(h))

    def forward(self, x, reverse=False):
        mixed = self.lru(self.norm1(x), reverse=reverse)
        return self._mlp_residual(x + mixed)

    def step(self, x_t, state=None):
        mixed, new_state = self.lru.step(self.norm1(x_t), state)
        return self._mlp_residual(x_t + mixed), new_state

class BidirectionalHawk(nn.Module):
    """Bidirectional Hawk SSM for acausal sequence embedding.

    One stack of ``HawkBlock``s runs left-to-right over ``(B, T, d_input)`` and
    another runs right-to-left.  The sequence summary concatenates the forward
    stack's last step with the backward stack's first step, and is then
    projected to ``(B, d_output)``.  ``dropout`` is accepted but not used.
    """

    def __init__(self, d_input, d_output, d_state=256, n_heads=8, n_layers=1, dropout=0.0):
        super().__init__()
        self.input_proj = nn.Identity() if d_input == d_output else nn.Linear(d_input, d_output)

        def make_stack():
            return nn.ModuleList(
                HawkBlock(d_output, d_state, n_heads, mlp_ratio=2) for _ in range(n_layers)
            )

        self.forward_layers = make_stack()
        self.backward_layers = make_stack()
        self.output_proj = nn.Linear(d_output * 2, d_output)

    @staticmethod
    def _run(layers, h, reverse):
        """Thread ``h`` through every block in ``layers``."""
        for layer in layers:
            h = layer(h, reverse=reverse)
        return h

    def forward(self, x):
        x = self.input_proj(x)
        fwd = self._run(self.forward_layers, x, reverse=False)
        bwd = self._run(self.backward_layers, x, reverse=True)
        summary = torch.cat([fwd[:, -1, :], bwd[:, 0, :]], dim=-1)
        return self.output_proj(summary)


class Metacontroller(nn.Module):
    """Latent-controller components used to steer the base VLM's residual stream.

    Submodule names must match the metacontroller checkpoint, which is loaded
    strictly:

    - ``history_gru``: GRU cell producing the switch history ``h`` (width ``n_h``).
    - ``seq_embedder``: bidirectional Hawk summary of an embedding sequence
      (width ``n_s``).
    - ``encoder``: maps ``[e, h, s]`` to ``2 * latent_dim`` values.  It is not
      used in this file; presumably it produces a Gaussian posterior over ``z``
      (TODO confirm).
    - ``switch_net``: switch logit from ``[e, h, z]``.
    - ``hypernet``: maps a latent ``z`` to two ``(base_embed_dim, rank)``
      low-rank factors.
    - ``position_predictor`` (only when ``aux_position_predictor``): a 2-value
      head on ``z``.
    """

    def __init__(self, config, base_embed_dim, aux_position_predictor=False):
        super().__init__()
        self.config = config
        self.base_embed_dim = base_embed_dim
        self.latent_dim = 8
        self.rank = 16
        self.n_h = 32  # switch-history (GRU) width
        self.n_s = 32  # acausal sequence-summary width
        self.aux_position_predictor = aux_position_predictor

        z_dim, h_dim, s_dim = self.latent_dim, self.n_h, self.n_s

        # Submodules are created in a fixed order so random initialization is reproducible.
        self.history_gru = nn.GRUCell(base_embed_dim, h_dim)
        self.seq_embedder = BidirectionalHawk(
            d_input=base_embed_dim,
            d_output=s_dim,
            d_state=64,
            n_heads=4,
            n_layers=1,
            dropout=0.0,
        )
        self.encoder = nn.Sequential(
            nn.Linear(base_embed_dim + h_dim + s_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * z_dim),
        )
        self.switch_net = nn.Sequential(
            nn.Linear(base_embed_dim + h_dim + z_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
        # Initial switch-logit bias of 1.0 (sigmoid ~= 0.73) before any checkpoint is loaded.
        with torch.no_grad():
            self.switch_net[-1].bias.fill_(1.0)
        self.hypernet = nn.Sequential(
            nn.Linear(z_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 2 * base_embed_dim * self.rank),
        )
        if aux_position_predictor:
            self.position_predictor = nn.Sequential(
                nn.Linear(z_dim, 32),
                nn.ReLU(),
                nn.Linear(32, 2),
            )

    def step_with_z(self, e_t, z, h_switch_prev):
        """Compute the switch decision and adapter factors for one step under latent ``z``.

        Args:
            e_t: ``(B, base_embed_dim)`` residual-stream embedding.
            z: ``(B, latent_dim)`` latent code.
            h_switch_prev: ``(B, n_h)`` switch history from the previous step.

        Returns:
            A tuple ``(beta_t, h_switch_t, (A, B), logit)``:

            - ``logit``: ``(B, 1)``, computed from the *previous* history.
            - ``beta_t``: ``sigmoid(logit)``.
            - ``h_switch_t``: the updated ``(B, n_h)`` history.
            - ``A``, ``B``: the ``(B, base_embed_dim, rank)`` factors produced
              by ``hypernet(z)``.
        """
        batch = e_t.shape[0]
        logit = self.switch_net(torch.cat([e_t, h_switch_prev, z], dim=1))  # (B, 1)
        beta_t = torch.sigmoid(logit)                                       # (B, 1)
        h_switch_t = self.history_gru(e_t, h_switch_prev)
        factors = self.hypernet(z).view(batch, 2, self.base_embed_dim, self.rank)
        return beta_t, h_switch_t, (factors[:, 0], factors[:, 1]), logit
        
# ==========================================
# 3. ENVIRONMENT WRAPPER
# ==========================================

class InternalRLWrapper(gym.Env):
    """Montezuma's Revenge environment whose observation is a base-VLM residual embedding.

    The wrapper keeps rolling buffers of the last ``num_images * frame_stack``
    preprocessed frames and the last ``seq_len - 1`` primitive actions.  It
    feeds them to the shared base VLM and exposes the residual-stream vector at
    the last action token, taken after block ``control_layer``
    (``n_layers // 2``).

    Args:
        device: torch device for frame buffers and model inputs.
        beta_threshold: stored on the instance; not read inside this class.
        seq_len: action-window length fed to the VLM.
        render_mode: forwarded to ``gym.make``.
        shared_models: required ``(base_model, metacontroller, config)`` tuple,
            as returned by ``load_shared_models``.

    Raises:
        TypeError: if ``shared_models`` is not provided.
    """

    def __init__(self, device="cpu", beta_threshold=0.5, seq_len=64, render_mode=None, shared_models=None):
        super().__init__()
        # Fail before creating the emulator so nothing is left unclosed.
        if shared_models is None:
            raise TypeError("InternalRLWrapper requires shared_models=(base_model, metacontroller, config)")
        self.device = device
        self.beta_threshold = beta_threshold
        self.env = gym.make("ALE/MontezumaRevenge-v5", render_mode=render_mode, frameskip=1)

        self.base_model, self.metacontroller, self.config = shared_models
        self.seq_len = seq_len
        # Residual stream is read after the middle transformer block.
        self.control_layer = self.config.n_layers // 2

        # The VLM consumes num_images images of frame_stack grayscale frames each.
        # frame_stack must match the patch-embedding input channels, so take it from config.
        self.num_images = 9
        self.frame_stack = self.config.frame_stack
        self.frame_buffer = deque(maxlen=self.num_images * self.frame_stack)

        self.n_patches = (self.config.img_size // self.config.patch_size) ** 2

        # Previous primitive actions; _get_model_inputs appends one placeholder slot.
        self.action_history = deque(maxlen=self.seq_len - 1)
        # switch-state must persist across primitive steps inside an option
        self.h_switch = None

    def reset(self, seed=None, options=None):
        """Reset the game, refill the buffers from the first frame, and clear the switch state.

        Returns:
            ``(embedding, info)``; ``embedding`` is a flat numpy array.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        self._reset_buffers(obs)
        self.h_switch = torch.zeros(1, self.metacontroller.n_h, device=self.device)
        return self._get_current_embedding(), info

    def primitive_step(self, action: int):
        """Execute one primitive action and push the new frame and action into the buffers.

        Returns:
            ``(reward, terminated, truncated, info)`` from the wrapped env.
        """
        obs, reward, terminated, truncated, info = self.env.step(int(action))
        self._update_frame_buffer(obs)
        self.action_history.append(int(action))
        return reward, terminated, truncated, info

    def close(self):
        """Close the wrapped gymnasium environment."""
        self.env.close()
        super().close()

    def _get_sequence_length(self):
        """Return the token count assumed for a full window: all patches plus ``seq_len`` action tokens."""
        return self.num_images * self.n_patches + self.seq_len

    def _reset_buffers(self, obs):
        """Fill the frame buffer with copies of ``obs`` and the action history with action 0."""
        self.frame_buffer.clear()
        self.action_history.clear()

        frame = self._process_frame(obs)
        for _ in range(self.num_images * self.frame_stack):
            self.frame_buffer.append(frame)

        for _ in range(self.seq_len - 1):
            self.action_history.append(0)

    def _update_frame_buffer(self, obs):
        """Append one preprocessed frame; the deque drops the oldest."""
        self.frame_buffer.append(self._process_frame(obs))

    def _process_frame(self, obs):
        """Resize an RGB frame to ``config.img_size`` square, convert to grayscale, scale to [0, 1]."""
        size = self.config.img_size
        img = cv2.resize(obs, (size, size))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return torch.tensor(img, dtype=torch.float32, device=self.device) / 255.0

    def _get_model_inputs(self):
        """Build ``frames`` ``(1, num_images, frame_stack, H, W)`` and ``actions`` ``(1, seq_len)``."""
        frames_list = list(self.frame_buffer)
        actions_list = list(self.action_history) + [0]  # placeholder for "next action"

        images = []
        for i in range(self.num_images):
            start = i * self.frame_stack
            images.append(torch.stack(frames_list[start : start + self.frame_stack], dim=0))
        frames = torch.stack(images, dim=0).unsqueeze(0)  # (1, NumImg, frame_stack, H, W)
        actions = torch.tensor(actions_list, dtype=torch.long, device=self.device).unsqueeze(0)
        return frames, actions

    def _get_embedding_tensor(self):
        """Return the clean (no adapter) residual ``(1, D)`` at the last action token after ``control_layer``."""
        frames, actions = self._get_model_inputs()
        with torch.no_grad():
            out = self.base_model(frames, actions, return_residuals=True)  # no adapter (clean)
            lap = out["last_action_pos"]
            e = out["residuals"][self.control_layer][:, lap, :]
        return e

    def _get_current_embedding(self):
        """Return the current embedding as a flat numpy array on the CPU."""
        return self._get_embedding_tensor().detach().cpu().numpy().flatten()

# ==========================================
# 4. AGENT & UTILS
# ==========================================

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer.weight`` orthogonally with gain ``std`` and fill its bias with ``bias_const``.

    Layers without a bias (e.g. ``nn.Linear(..., bias=False)``) are supported;
    the bias step is skipped for them.

    Returns:
        ``layer`` itself, so the call can wrap a constructor inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class Actor(nn.Module):
    """Gaussian policy over the controller latent ``z``.

    Inputs are projected to ``hidden_dim`` (skipped when the sizes already
    match) and passed through a ``HawkBlock`` recurrence.  A small Tanh MLP
    then gives the mean ``mu``.  The standard deviation ``exp(log_std)`` is
    learned and state-independent.
    """

    def __init__(self, input_dim=256, latent_dim=8, hidden_dim=256, initial_std=0.5):
        super().__init__()
        self.input_proj = nn.Identity() if input_dim == hidden_dim else nn.Linear(input_dim, hidden_dim)
        self.ssm = HawkBlock(hidden_dim, d_state=256, n_heads=8, mlp_ratio=2)

        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(hidden_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, latent_dim), std=0.01),
        )
        self.log_std = nn.Parameter(torch.ones(latent_dim) * np.log(initial_std))

    def forward(self, x):
        """Return ``(Normal(mu, std), None)`` for ``x`` shaped ``(B, D)`` or ``(B, T, D)``.

        ``mu`` has shape ``(B, Z)`` or ``(B, T, Z)`` respectively.  A 2-D input is
        run as a length-1 sequence from a zero recurrent state.

        Raises:
            ValueError: for any other input rank.
        """
        h = self.input_proj(x)
        if h.dim() == 3:
            h = self.ssm(h, reverse=False)                             # (B, T, D)
        elif h.dim() == 2:
            h = self.ssm(h.unsqueeze(1), reverse=False).squeeze(1)     # (B, D)
        else:
            raise ValueError(f"Actor.forward expected (B,D) or (B,T,D), got {tuple(h.shape)}")

        mu = self.actor_mean(h)
        # expand_as prepends the missing batch/time dims to the (Z,) std vector.
        std = torch.exp(self.log_std).expand_as(mu)
        return Normal(mu, std), None

    def step(self, x_t, state=None):
        """One recurrent step: ``x_t`` ``(B, D)`` -> ``(Normal over (B, Z), new_state)``."""
        h_t, new_state = self.ssm.step(self.input_proj(x_t), state)
        mu = self.actor_mean(h_t)                                       # (B, Z)
        std = torch.exp(self.log_std).unsqueeze(0).expand_as(mu)        # (B, Z)
        return Normal(mu, std), new_state


class CSVLogger:
    """Buffered CSV writer that appends dict rows to ``filename``.

    If the file already exists and is non-empty, its header row is reused and
    new rows are appended.  Otherwise the header is taken from the keys of the
    first logged row.  Rows are buffered and written whenever
    ``step % flush_every == 0``, and again on ``close``.  The logger can also be
    used as a context manager.
    """

    def __init__(self, filename):
        self.filename = filename
        self.writer = None
        self.fieldnames = None
        self.buffer = []  # rows waiting to be written

        file_exists = os.path.exists(filename) and os.path.getsize(filename) > 0

        if file_exists:
            # Reuse the existing header so that appended rows line up with it.
            with open(filename, 'r', newline='') as f:
                self.fieldnames = csv.DictReader(f).fieldnames
            self.file = open(filename, 'a', newline='')
            if self.fieldnames:
                self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames)
        else:
            self.file = open(filename, 'w', newline='')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _write_buffer(self):
        """Write every buffered row, flush the file, and empty the buffer."""
        if self.buffer and self.writer is not None:
            self.writer.writerows(self.buffer)
            self.file.flush()
        self.buffer = []

    def log(self, data, step, flush_every=5):
        """Buffer one row and write the buffer every ``flush_every`` steps.

        Args:
            data: mapping of column name to value.  A snapshot is taken, so
                later mutation by the caller does not change the logged row.
            step: counter used to decide when to flush.
            flush_every: flush when ``step % flush_every == 0``.  Values ``<= 1``
                flush on every call.
        """
        if self.writer is None:
            self.fieldnames = list(data.keys())
            self.writer = csv.DictWriter(self.file, fieldnames=self.fieldnames)
            self.writer.writeheader()

        self.buffer.append(dict(data))

        if flush_every <= 1 or step % flush_every == 0:
            self._write_buffer()

    def close(self):
        """Write any remaining buffered rows and close the file.  Safe to call repeatedly."""
        if self.file is None or self.file.closed:
            return
        self._write_buffer()
        self.file.close()

def load_shared_models(device):
    """Build the base VLM and metacontroller, load their checkpoints, and print diagnostics.

    The returned instances are shared by every ``InternalRLWrapper``.  A
    checkpoint that does not exist is skipped with a warning, and the model
    keeps its random initialization.

    Args:
        device: torch device (string or ``torch.device``) for both models.

    Returns:
        ``(base_model, metacontroller, config)``, with both models in eval mode.
    """
    # Checkpoints are unpickled and may reference classes from a ``src.models.*``
    # package.  Expose this notebook's definitions under those module names.
    vlm_module = sys.modules['src.models.vlm']
    for cls in (Config, MontezumaVLM, PatchEmbedding, TransformerBlock, CausalSelfAttention):
        setattr(vlm_module, cls.__name__, cls)
    meta_module = sys.modules['src.models.metacontroller']
    for cls in (Metacontroller, HawkBlock, LRU, BidirectionalHawk):
        setattr(meta_module, cls.__name__, cls)

    config = Config(
        img_size=cfg.img_size,
        patch_size=cfg.patch_size,
        embed_dim=cfg.embed_dim,
        n_layers=cfg.n_layers,
        n_heads=cfg.n_heads,
        seq_len=cfg.seq_len,
        frame_stack=cfg.frame_stack,
    )

    base_model = MontezumaVLM(config).to(device).eval()
    # The metacontroller acts on the VLM residual stream, so its width is the VLM width.
    metacontroller = Metacontroller(config, config.embed_dim, aux_position_predictor=False).to(device).eval()

    # SECURITY: weights_only=False runs the pickle loader, which can execute
    # arbitrary code.  Only point these paths at trusted checkpoints.
    if os.path.exists(cfg.model_checkpoint):
        print(f"Loading base model from {cfg.model_checkpoint}")
        ckpt = torch.load(cfg.model_checkpoint, map_location=device, weights_only=False)
        base_model.load_state_dict(ckpt['model_state_dict'])  # strict on purpose: architecture must match
    else:
        print(f"WARNING: base model checkpoint not found at {cfg.model_checkpoint}; "
              "using randomly initialized weights")

    if os.path.exists(cfg.meta_checkpoint):
        print(f"Loading metacontroller from {cfg.meta_checkpoint}")
        ckpt = torch.load(cfg.meta_checkpoint, map_location=device, weights_only=False)
        metacontroller.load_state_dict(ckpt["metacontroller_state_dict"])  # strict on purpose
    else:
        print(f"WARNING: metacontroller checkpoint not found at {cfg.meta_checkpoint}; "
              "using randomly initialized weights")

    # === Diagnostics: cheap fingerprints for comparing checkpoints across runs ===
    def get_weight_fingerprint(model):
        """Return the mean of all parameter values (0 for a parameterless model)."""
        total = 0.0
        count = 0
        for p in model.parameters():
            total += p.detach().float().sum().item()
            count += p.numel()
        return total / count if count > 0 else 0

    vlm_fp = get_weight_fingerprint(base_model)
    meta_fp = get_weight_fingerprint(metacontroller)

    print("=== CHECKPOINT FINGERPRINTS ===")
    print(f"VLM weight mean: {vlm_fp:.6f}")
    print(f"Metacontroller weight mean: {meta_fp:.6f}")
    print(f"Hypernet[0] weight sum: {metacontroller.hypernet[0].weight.sum().item():.6f}")
    print(f"Hypernet[-1] weight sum: {metacontroller.hypernet[-1].weight.sum().item():.6f}")

    # Sanity forward pass through the hypernet.  A private fixed-seed generator
    # makes the probe reproducible and leaves the global torch RNG untouched.
    with torch.no_grad():
        probe_rng = torch.Generator().manual_seed(0)
        test_z = torch.randn(1, metacontroller.latent_dim, generator=probe_rng).to(device)
        test_params = metacontroller.hypernet(test_z)
        print(f"Hypernet output norm: {test_params.norm().item():.4f}")

    return base_model, metacontroller, config

# ==========================================
# 5. TRAINING LOOP
# ==========================================
def compute_action_positions(num_images: int, n_patches: int, seq_len: int, chunk: int = 8):
    """Mirror ``MontezumaVLM``'s token interleaving and locate the action tokens.

    The layout is::

        [patches(img0), actions(0..chunk-1), patches(img1), actions(chunk..2*chunk-1), ...]

    The model feeds ``seq_len`` action tokens (start token + ``seq_len - 1``
    previous actions).  At most ``num_images * chunk`` of them get a slot; the
    rest are not placed, exactly as in ``MontezumaVLM.forward``.

    Args:
        num_images: number of stacked images in the window.
        n_patches: patch tokens per image.
        seq_len: number of action tokens the model is fed.
        chunk: action tokens placed after each image (the VLM uses 8).

    Returns:
        A tuple ``(action_positions, T_tokens)``:

        - ``action_positions``: ``list[int]`` of absolute token indices.
        - ``T_tokens``: total length of the interleaved sequence.  It counts
          only the action tokens actually placed, so it equals the model's
          sequence length even when ``seq_len > num_images * chunk``.
    """
    num_actions = seq_len  # start token + (seq_len-1) previous actions = seq_len tokens
    action_positions = []
    cur = 0
    for i in range(num_images):
        cur += n_patches
        start = i * chunk
        end = min((i + 1) * chunk, num_actions)
        if start < num_actions:
            k = end - start
            action_positions.extend(range(cur, cur + k))
            cur += k
    return action_positions, cur


def build_tokenwise_adapters(curr_z, metacontroller, config, num_images, seq_len, device):
    """Spread one per-sample low-rank adapter over every action-token position.

    The result is meant for ``MontezumaVLM.forward`` with a 4-D adapter.
    Patch-token rows stay zero, so those tokens receive no update.
    NOTE(review): the original comment says this matches how ``train_meta.py``
    builds per-token controller sequences; that file is not visible here.

    Args:
        curr_z: ``(N, latent_dim)`` latent codes fed to ``metacontroller.hypernet``.
        metacontroller: provides ``hypernet`` and ``rank``.
        config: VLM config (``embed_dim``, ``img_size``, ``patch_size``).
        num_images: images in the VLM window.
        seq_len: action tokens in the VLM window.
        device: device for the index and output tensors.

    Returns:
        ``(A_seq, B_seq)``, each ``(N, T_tokens, D, R)``.
    """
    batch = curr_z.shape[0]
    width = config.embed_dim
    rank = metacontroller.rank
    patches_per_image = (config.img_size // config.patch_size) ** 2

    positions, total_tokens = compute_action_positions(num_images, patches_per_image, seq_len)
    action_idx = torch.tensor(positions, device=device, dtype=torch.long)  # (T_act,)
    n_act = action_idx.numel()

    # Hypernet output viewed as (N, 2, D, R): slot 0 is A, slot 1 is B.
    factors = metacontroller.hypernet(curr_z).view(batch, 2, width, rank)
    a_single, b_single = factors[:, 0], factors[:, 1]  # each (N, D, R)

    A_seq = torch.zeros(batch, total_tokens, width, rank, device=device)
    B_seq = torch.zeros_like(A_seq)
    A_seq[:, action_idx] = a_single.unsqueeze(1).expand(batch, n_act, width, rank)
    B_seq[:, action_idx] = b_single.unsqueeze(1).expand(batch, n_act, width, rank)

    return A_seq, B_seq

def fixed_switch_and_beta(fixed_rate, option_prim_steps, device):
    """Decide controller switches from a fixed schedule instead of the learned switch net.

    ``fixed_rate`` selects one of two regimes:

    * ``fixed_rate < 1``: each of the ``N`` options switches independently
      with probability ``fixed_rate`` (draws from the global torch RNG).
    * otherwise: deterministic period ``k = max(1, int(fixed_rate))``.  Option
      ``i`` switches once ``option_prim_steps[i] + 1 >= k``.

    Args:
        fixed_rate: switch probability (``< 1``) or switch period (``>= 1``).
        option_prim_steps: length-``N`` sequence of primitive-step counts, one
            per current option.
        device: torch device for the returned tensors.

    Returns:
        will_switch: (N,) bool
        beta:        (N,) float, the nominal switch rate (for logging only)
        beta_logit:  (N,1) float, logit(beta) clamped away from +/-inf (for logging only)
    """
    n = len(option_prim_steps)
    rate = float(fixed_rate)

    if rate < 1.0:
        will_switch = torch.rand(n, device=device) < rate
        nominal = rate
    else:
        period = max(1, int(rate))
        flags = [steps + 1 >= period for steps in option_prim_steps]
        will_switch = torch.tensor(flags, device=device, dtype=torch.bool)
        nominal = 1.0 / float(period)

    beta = torch.full((n,), nominal, device=device)
    eps = 1e-6
    beta_logit = torch.logit(beta.clamp(eps, 1.0 - eps)).unsqueeze(-1)
    return will_switch, beta, beta_logit


def train_rl():
    """Train the latent-option policy (``Actor``) with PPO on the internal (option) timescale.

    Internal step (one option):
      * The policy samples a latent ``z`` from the current observation embedding.
      * ``build_tokenwise_adapters`` turns ``z`` into per-token adapter matrices (A, B).
      * These are injected into ``base_model`` at ``control_layer``.
      * ``base_model`` samples primitive actions from its logits.
      * The option runs until the switch signal fires or the env reports
        termination/truncation. The switch signal is either the learned switch
        net (``beta > beta_threshold``) or ``fixed_switch_and_beta`` when
        ``cfg.fixed_rate != 0``.
      * The option's reward is the sum of its primitive rewards.

    A trajectory ends when the environment terminates/truncates, or after
    ``cfg.max_steps_without_reward`` consecutive options without positive reward.

    Update: every option decision in a trajectory uses the trajectory's
    batch-normalised total return as its advantage. The recurrent policy is
    re-unrolled from a fresh hidden state to compute PPO ratios. Only the
    ``agent`` parameters are given to the optimizer.

    Side effects:
      * Writes ``cfg.log_path`` / ``cfg.episode_log_path`` through ``CSVLogger``.
      * Resumes from ``cfg.save_path`` if it exists.
      * Saves a checkpoint there every 5 epochs.
    """
    control_layer = cfg.n_layers // 2

    def _switch_stats(i):
        # Segment/switch summary for env i's current trajectory; safe even if no segments yet.
        segs = t_seg_lens[i]
        if len(segs) == 0:
            return f"sw={t_switches[i]} segs=0 prim={t_prim_steps[i]}"
        avg_len = float(np.mean(segs))
        med_len = float(np.median(segs))
        p90_len = float(np.percentile(segs, 90))
        sw_per_step = float(t_switches[i]) / max(1, t_prim_steps[i])
        return f"sw={t_switches[i]} segs={len(segs)} len(avg/med/p90)={avg_len:.1f}/{med_len:.0f}/{p90_len:.0f} sw/step={sw_per_step:.3f} prim={t_prim_steps[i]}"

    def _reset_traj_stats(i):
        # Clear env i's per-trajectory buffers and diagnostics.
        t_obs[i], t_acts[i], t_logps[i], t_rews[i] = [], [], [], []
        t_seg_lens[i] = []
        t_switches[i] = 0
        t_prim_steps[i] = 0
        curr_seg_len[i] = 0
        t_beta_max[i] = 0.0
        t_beta_sum[i] = 0.0
        t_beta_n[i] = 0
        t_logit_max[i] = -1e9
        t_logit_sum[i] = 0.0
        t_logit_n[i] = 0

    def _start_option(i, obs_np, h_prev):
        # Sample the next latent z for env i from obs_np and record (s, z, log pi(z|s)).
        # The matching reward is appended later, when this option ends.
        obs_t = torch.tensor(obs_np, dtype=torch.float32, device=cfg.device).unsqueeze(0)
        with torch.no_grad():
            dist, h_next_pol = agent.step(obs_t, h_prev)
            z = dist.sample()
            logp = dist.log_prob(z).sum(dim=-1)  # shape (1,)
        curr_obs[i] = obs_np
        curr_z[i] = z.squeeze(0)
        curr_h[i] = h_next_pol
        t_obs[i].append(obs_np)
        t_acts[i].append(z.squeeze(0).cpu())
        t_logps[i].append(float(logp.item()))

    print(f"Using device: {cfg.device}")

    # 1) Load shared models
    shared_models = load_shared_models(cfg.device)
    base_model, metacontroller, config = shared_models

    # 2) Create internal environments (internal step = execute until beta triggers)
    print(f"Initializing {cfg.num_envs} environments...")
    envs = [
        InternalRLWrapper(
            device=cfg.device,
            beta_threshold=0.75,
            seq_len=cfg.seq_len,
            render_mode=None,
            shared_models=shared_models,
        )
        for _ in range(cfg.num_envs)
    ]

    # 3) Initialize policy
    agent = Actor(
        input_dim=cfg.input_dim,
        latent_dim=cfg.latent_dim,
        hidden_dim=cfg.hidden_dim,
        initial_std=cfg.initial_std,
    ).to(cfg.device)
    optimizer = optim.Adam(agent.parameters(), lr=cfg.lr, eps=1e-5)

    logger = CSVLogger(cfg.log_path)
    ep_logger = CSVLogger(cfg.episode_log_path)
    global_trajs = 0
    global_eps = 0
    global_env_eps = 0
    start_epoch = 0

    if os.path.exists(cfg.save_path):
        print(f"Resuming from {cfg.save_path}...")
        ckpt = torch.load(cfg.save_path, map_location=cfg.device)
        agent.load_state_dict(ckpt["policy_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt["epoch"]

    # ------------------------------------------------------------
    # Rollout state per env (INTERNAL timescale)
    # ------------------------------------------------------------
    curr_obs = [None] * cfg.num_envs          # latest observation embedding per env
    curr_z = torch.zeros(cfg.num_envs, cfg.latent_dim, device=cfg.device)
    curr_h = [None] * cfg.num_envs            # recurrent state for policy

    # per-env stagnation counter (counts INTERNAL steps without positive reward)
    last_reward_timers = [0] * cfg.num_envs

    # per-env trajectory buffers (aligned: obs[t] -> act[t] -> rew[t])
    t_obs = [[] for _ in range(cfg.num_envs)]       # list of np arrays
    t_acts = [[] for _ in range(cfg.num_envs)]      # list of torch tensors (latent_dim,)
    t_logps = [[] for _ in range(cfg.num_envs)]     # list of floats
    t_rews = [[] for _ in range(cfg.num_envs)]      # list of floats
    t_seg_lens = [[] for _ in range(cfg.num_envs)]  # lengths (primitive steps) per option segment
    t_switches = [0 for _ in range(cfg.num_envs)]   # count beta-trigger terminations
    t_prim_steps = [0 for _ in range(cfg.num_envs)] # primitive steps elapsed in trajectory
    curr_seg_len = [0 for _ in range(cfg.num_envs)] # current segment primitive length accumulator
    t_beta_max = [0.0 for _ in range(cfg.num_envs)]
    t_beta_sum = [0.0 for _ in range(cfg.num_envs)]
    t_beta_n = [0 for _ in range(cfg.num_envs)]
    t_logit_max = [-1e9 for _ in range(cfg.num_envs)]
    t_logit_sum = [0.0 for _ in range(cfg.num_envs)]
    t_logit_n = [0 for _ in range(cfg.num_envs)]
    option_rewards = [0.0 for _ in range(cfg.num_envs)]  # accum primitive rewards inside current option
    option_prim_steps = [0 for _ in range(cfg.num_envs)] # primitive steps inside current option

    # Initial reset + sample initial z (store (s0, z0, logp0); reward comes after executing z0)
    print("Performing initial reset...")
    for i in range(cfg.num_envs):
        obs_np, _ = envs[i].reset()
        _start_option(i, obs_np, None)

    print(f"Starting internal RL (Max no-reward internal steps: {cfg.max_steps_without_reward})...")

    # ------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------
    for epoch in range(start_epoch, cfg.epochs):
        epoch_start = time.time()

        # batch buffers (variable-length trajectories)
        b_obs, b_acts, b_oldlogps, b_advs = [], [], [], []
        batch_total_rewards = []
        trajs_collected = 0

        # ----------------------------
        # COLLECT trajectories
        # ----------------------------
        with torch.no_grad():
            while trajs_collected < cfg.batch_size:

                # ---- gather model inputs (PRE-STEP) ----
                frames_list, actions_list = [], []
                for env in envs:
                    f, a = env._get_model_inputs()
                    frames_list.append(f)
                    actions_list.append(a)

                all_frames = torch.cat(frames_list, dim=0)    # (N, NumImg, 4, 84, 84)
                all_actions = torch.cat(actions_list, dim=0)  # (N, seq_len)

                if cfg.fixed_rate != 0.0:
                    will_switch, beta, beta_logit = fixed_switch_and_beta(
                        fixed_rate=cfg.fixed_rate,
                        option_prim_steps=option_prim_steps,
                        device=cfg.device,
                    )
                else:
                    # PASS 1 (CLEAN): compute e_t for beta + h_switch update
                    out_clean = base_model(
                        all_frames,
                        all_actions,
                        return_residuals=True,
                        adapter_params=None,                      # IMPORTANT: CLEAN
                    )

                    lap = out_clean["last_action_pos"]            # scalar int
                    e_t_clean = out_clean["residuals"][control_layer][:, lap, :]  # (N, D)

                    h_prev = torch.cat([env.h_switch for env in envs], dim=0)     # (N, n_h)
                    switch_inp = torch.cat([e_t_clean, h_prev, curr_z], dim=1)    # (N, D+n_h+Z)

                    beta_logit = metacontroller.switch_net(switch_inp)            # (N, 1)
                    beta = torch.sigmoid(beta_logit).squeeze(-1)                  # (N,)

                    h_next = metacontroller.history_gru(e_t_clean, h_prev)        # (N, n_h)
                    for i in range(cfg.num_envs):
                        envs[i].h_switch = h_next[i:i+1]

                    will_switch = beta > envs[0].beta_threshold                   # (N,)

                # PASS 2 (CONTROLLED): sample primitive actions under adapters built from curr_z
                A_seq, B_seq = build_tokenwise_adapters(
                    curr_z=curr_z,
                    metacontroller=metacontroller,
                    config=config,
                    num_images=envs[0].num_images,
                    seq_len=cfg.seq_len,
                    device=cfg.device,
                )

                out_ctrl = base_model(
                    all_frames,
                    all_actions,
                    return_residuals=False,
                    adapter_params=(A_seq, B_seq, control_layer, "action_tokens"),
                )

                probs = torch.softmax(out_ctrl["logits"], dim=-1)             # (N, n_actions)
                prim_actions = torch.multinomial(probs, 1).squeeze(-1)        # (N,)

                # ---- step all envs one primitive step (using sampled actions) ----
                done_mask = torch.zeros(cfg.num_envs, device=cfg.device, dtype=torch.bool)
                for i in range(cfg.num_envs):
                    a = int(prim_actions[i].item())
                    r, terminated, truncated, _ = envs[i].primitive_step(a)
                    done_mask[i] = terminated or truncated

                    # primitive-step accounting
                    option_rewards[i] += float(r)
                    option_prim_steps[i] += 1
                    t_prim_steps[i] += 1
                    curr_seg_len[i] += 1

                # ---- handle option termination / trajectory termination ----
                # Every env is processed for this step, even once the batch is full:
                # stopping early would skip switch handling and, worse, leave a
                # terminated env un-reset for the next epoch.
                for i in range(cfg.num_envs):
                    # beta/logit stats (per trajectory)
                    beta_i = float(beta[i].item())
                    logit_i = float(beta_logit[i].item())
                    t_beta_max[i] = max(t_beta_max[i], beta_i)
                    t_beta_sum[i] += beta_i
                    t_beta_n[i] += 1
                    t_logit_max[i] = max(t_logit_max[i], logit_i)
                    t_logit_sum[i] += logit_i
                    t_logit_n[i] += 1

                    done = bool(done_mask[i].item())
                    ended_by_beta = bool(will_switch[i].item()) and (not done)
                    if not (ended_by_beta or done):
                        continue  # current option keeps running

                    # ---- close the current option segment ----
                    if ended_by_beta:
                        t_switches[i] += 1
                    t_seg_lens[i].append(curr_seg_len[i])
                    curr_seg_len[i] = 0

                    seg_reward = option_rewards[i]
                    option_rewards[i] = 0.0
                    option_prim_steps[i] = 0
                    t_rews[i].append(float(seg_reward))

                    if seg_reward > 0:
                        last_reward_timers[i] = 0
                    else:
                        last_reward_timers[i] += 1
                    timeout = last_reward_timers[i] >= cfg.max_steps_without_reward

                    if not (done or timeout):
                        # Option boundary: sample next z from the CURRENT (post-step) embedding.
                        # NOTE: this does a VLM forward, but only at boundaries.
                        _start_option(i, envs[i]._get_current_embedding(), curr_h[i])
                        continue

                    # ---- trajectory complete ----
                    T = len(t_rews[i])
                    assert len(t_obs[i]) == T and len(t_acts[i]) == T and len(t_logps[i]) == T

                    total_r = float(sum(t_rews[i]))
                    status = "DONE" if done else "TIMEOUT"
                    beta_mean = t_beta_sum[i] / max(1, t_beta_n[i])
                    logit_mean = t_logit_sum[i] / max(1, t_logit_n[i])
                    # Decide BEFORE incrementing the counter, so the batch-filling
                    # trajectory is correctly reported as used.
                    used_in_batch = trajs_collected < cfg.batch_size

                    if used_in_batch:
                        batch_total_rewards.append(total_r)
                        b_obs.append(torch.tensor(np.array(t_obs[i]), dtype=torch.float32, device=cfg.device))
                        b_acts.append(torch.stack(t_acts[i]).to(cfg.device))
                        b_oldlogps.append(torch.tensor(t_logps[i], dtype=torch.float32, device=cfg.device))
                        b_advs.append(torch.full((T,), total_r, dtype=torch.float32, device=cfg.device))
                        trajs_collected += 1

                        print(
                            f"\n Traj {trajs_collected}/{cfg.batch_size} | R: {total_r:.1f} | L: {T} | {status} | "
                            f"{_switch_stats(i)} | "
                            f"beta(max/mean)={t_beta_max[i]:.3f}/{beta_mean:.3f} logit(max/mean)={t_logit_max[i]:.3f}/{logit_mean:.3f}",
                            end=""
                        )

                    global_eps += 1
                    if done:
                        global_env_eps += 1

                    ep_logger.log(
                        {
                            "episode": global_eps,                 # DONE+TIMEOUT
                            "env_episode": global_env_eps,         # DONE only
                            "epoch": epoch + 1,
                            "used_in_batch": int(used_in_batch),
                            "env_id": i,
                            "status": status,                      # "DONE" or "TIMEOUT"
                            "return": total_r,
                            "L_options": T,                        # number of option decisions
                            "prim_steps": t_prim_steps[i],
                            "switches": t_switches[i],
                            "segs": len(t_seg_lens[i]),
                            "seg_len_avg": float(np.mean(t_seg_lens[i])) if len(t_seg_lens[i]) else 0.0,
                            "seg_len_med": float(np.median(t_seg_lens[i])) if len(t_seg_lens[i]) else 0.0,
                            "seg_len_p90": float(np.percentile(t_seg_lens[i], 90)) if len(t_seg_lens[i]) else 0.0,
                            "beta_max": float(t_beta_max[i]),
                            "beta_mean": float(beta_mean),
                            "logit_max": float(t_logit_max[i]),
                            "logit_mean": float(logit_mean),
                        },
                        step=global_eps,
                        flush_every=50,   # reduce IO
                    )

                    # reset env + per-traj buffers, then sample (s0, z0, logp0) for the new trajectory
                    obs_np, _ = envs[i].reset()
                    curr_h[i] = None
                    last_reward_timers[i] = 0
                    _reset_traj_stats(i)
                    _start_option(i, obs_np, None)

        print("")  # newline after collection

        # quick z stats (not used for learning, just logging)
        z_diversity = curr_z.std(dim=0).mean().item()
        z_norm = curr_z.norm(dim=1).mean().item()
        print(f"Z diversity: {z_diversity:.4f} | Z norm: {z_norm:.4f}")

        # ----------------------------
        # UPDATE (PPO, recurrent unroll)
        # ----------------------------
        if not batch_total_rewards:
            continue

        collection_time = time.time() - epoch_start
        all_r = torch.tensor(batch_total_rewards, device=cfg.device)
        mean_r = all_r.mean()
        max_r = all_r.max().item()
        # Tensor.std() is unbiased and yields NaN for a single element; treat that as zero spread
        # so the advantages become 0 instead of poisoning the loss with NaN.
        spread = all_r.std() if all_r.numel() > 1 else torch.zeros((), device=cfg.device)
        std_r = spread + 1e-8

        # normalize advantages trajectory-wise (same value repeated per step)
        for j in range(len(b_advs)):
            b_advs[j] = (b_advs[j] - mean_r) / std_r

        # PPO epochs over same batch
        ppo_epochs = 4
        clip_eps = 0.2

        total_loss_val = 0.0

        for _ in range(ppo_epochs):
            optimizer.zero_grad()

            total_pg = 0.0
            total_ent = 0.0
            total_steps = 0

            # Unroll each trajectory with agent.step() to compute new logprobs correctly for a recurrent policy
            for traj_obs, traj_act, traj_oldlogp, traj_adv in zip(b_obs, b_acts, b_oldlogps, b_advs):
                h = None
                T = traj_obs.shape[0]

                for t in range(T):
                    obs_t = traj_obs[t].unsqueeze(0)
                    act_t = traj_act[t].unsqueeze(0)

                    dist, h = agent.step(obs_t, h)
                    new_logp = dist.log_prob(act_t).sum(dim=-1)       # (1,)
                    old_logp = traj_oldlogp[t].view(1)                # (1,)
                    adv_t = traj_adv[t].view(1)                       # (1,)
                    ent = dist.entropy().sum(dim=-1)                  # (1,)

                    ratio = torch.exp(new_logp - old_logp)

                    # Clipped surrogate (as a loss, so take the max of the negated terms)
                    pg1 = -adv_t * ratio
                    pg2 = -adv_t * torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
                    pg = torch.maximum(pg1, pg2)

                    total_pg = total_pg + pg
                    total_ent = total_ent + ent
                    total_steps += 1

            # mean over all timesteps in batch
            pg_loss = total_pg / max(total_steps, 1)
            ent_mean = total_ent / max(total_steps, 1)
            total_loss = pg_loss - cfg.entropy_coef * ent_mean

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            optimizer.step()

            total_loss_val = float(total_loss.item())

        global_trajs += trajs_collected

        logger.log(
            {
                "epoch": epoch + 1,
                "step": global_trajs,
                "reward": float(mean_r.item()),
                "loss": total_loss_val,
                "z_diversity": z_diversity,
                "z_norm": z_norm,
            },
            step=epoch + 1,
            flush_every=5,
        )

        print(
            f"Epoch {epoch+1} | Mean: {mean_r.item():.2f} | Max: {max_r:.0f} | "
            f"Loss: {total_loss_val:.4f} | Z-div: {z_diversity:.4f} | Coll: {collection_time:.1f}s"
        )

        if (epoch + 1) % 5 == 0:
            torch.save(
                {
                    "policy_state_dict": agent.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch + 1,
                },
                cfg.save_path,
            )
            print(f"Saved checkpoint to {cfg.save_path}")

if __name__ == "__main__":
    # Run training when this file/cell is executed directly (not when imported).
    train_rl()