# V-JEPA2 Policy Training Pipeline

This notebook trains an Atari policy on top of features from a frozen V-JEPA2 encoder (or, with `USE_VJEPA_ENCODER=0`, a trainable CNN baseline), following the same structure as the RSSM-based policy training notebook.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
import random
from pathlib import Path
import matplotlib.pyplot as plt
import ale_py
import imageio
gym.register_envs(ale_py)

# Prefer the GPU when one is visible to PyTorch
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")


Using device: cuda


In [None]:
# --- Experiment Configuration ---
import os


def read_env_flag(name, default="0"):
    """Read a boolean flag from environment variable `name`, falling back to `default`.

    Accepts integers (any non-zero integer is true) as well as true/false, yes/no and
    on/off in any letter case; an empty value counts as false.

    Raises:
        ValueError: if the value is none of the above.
    """
    raw = os.environ.get(name, default).strip().lower()
    if raw in ("true", "yes", "on"):
        return True
    if raw in ("false", "no", "off", ""):
        return False
    try:
        return bool(int(raw))
    except ValueError:
        raise ValueError(
            f"Environment variable {name}={raw!r} is not a boolean flag (use 0/1 or true/false)"
        ) from None


GAME_ID = os.environ.get("ATARI_GAME", "PongNoFrameskip-v4")
USE_VJEPA_ENCODER = read_env_flag("USE_VJEPA_ENCODER", "0")
# Temporal context (number of frames provided to encoder / policy)
T_CONTEXT = int(os.environ.get("T_CONTEXT", "4"))
# Action embedding only applies when using V-JEPA encoder
EMBED_ACTIONS = read_env_flag("VJEPA_EMBED_ACTIONS", "1") if USE_VJEPA_ENCODER else False
NUM_EPISODES = int(os.environ.get("NUM_TRAIN_EPISODES", "100"))
EVAL_EPISODES = int(os.environ.get("NUM_EVAL_EPISODES", "5"))
GAMMA = float(os.environ.get("DISCOUNT_GAMMA", "0.99"))
LEARNING_RATE = float(os.environ.get("POLICY_LR", "1e-4"))

print(f"Game: {GAME_ID}")
print(f"Use V-JEPA encoder: {USE_VJEPA_ENCODER}")
print(f"Temporal context (frames): {T_CONTEXT}")
print(f"Action embedding (V-JEPA only): {EMBED_ACTIONS}")
print(f"Training episodes: {NUM_EPISODES}")
print(f"Evaluation episodes: {EVAL_EPISODES}")
print(f"Discount gamma: {GAMMA}")
print(f"Policy learning rate: {LEARNING_RATE}")


## Environment Setup


In [None]:
from gymnasium import spaces

# Create the base Atari environment and downscale its frames to 84x84
env = gym.wrappers.ResizeObservation(gym.make(GAME_ID), (84, 84))

def transform_obs(obs):
    """Reorder an (H, W, C) frame to (C, H, W) and rescale 0..255 pixel values to float32 in [0, 1]."""
    channels_first = np.transpose(obs, axes=(2, 0, 1))
    as_float = channels_first.astype(np.float32)
    return as_float / 255.0

# Observations leave the wrapper as float32 (C, H, W) arrays in [0, 1]; declare the matching space
new_obs_space = spaces.Box(
    low=0.0,
    high=1.0,
    shape=(3, 84, 84),
    dtype=np.float32,
)
env = gym.wrappers.TransformObservation(env, func=transform_obs, observation_space=new_obs_space)

# Smoke-test the wrapped environment once
obs, info = env.reset()
assert obs.shape == (3, 84, 84), f"Expected (3, 84, 84), got {obs.shape}"
print(f"Environment initialized successfully for {GAME_ID}")


Environment initialized successfully


## Load V-JEPA2 Encoder


In [None]:
import os
import torch
import torch.nn.functional as F

# V-JEPA2 configuration (only used when USE_VJEPA_ENCODER=True)
VJEPA_BACKEND = os.environ.get("VJEPA_BACKEND", "hub")
# Note: V-JEPA2-AC was designed for robot control with explicit states (end-effector, joints)
# For Atari (visual-only), regular V-JEPA2 may be more appropriate
# Set USE_AC=1 to try AC model (uses visual tokens as states), USE_AC=0 for regular model
USE_AC = bool(int(os.environ.get("VJEPA_USE_AC", "0")))
HF_REPO = os.environ.get("VJEPA_HF_REPO", "facebook/vjepa2-vitl-fpc64-256")

preprocessor = None
vjepa_model = None
vjepa_ac_predictor = None


def _freeze(module):
    """Move `module` to `device`, put it in eval mode, disable grads on all its parameters, and return it."""
    module.to(device).eval()
    for p in module.parameters():
        p.requires_grad_(False)
    return module


if not USE_VJEPA_ENCODER:
    print("Skipping V-JEPA encoder load (baseline CNN mode).")
elif VJEPA_BACKEND not in ("hub", "hf"):
    raise ValueError("VJEPA_BACKEND must be 'hub' or 'hf'.")
elif VJEPA_BACKEND == "hub":
    preprocessor = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_preprocessor')
    if USE_AC:
        obj = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_ac_vit_giant')
        # The AC entry point may return (encoder, predictor) or just the encoder
        if isinstance(obj, tuple):
            vjepa_model, vjepa_ac_predictor = obj[0], obj[1]
        else:
            vjepa_model = obj
        print("Loaded V-JEPA 2-AC (ViT-g) from hub.")
        if vjepa_ac_predictor is not None:
            _freeze(vjepa_ac_predictor)
            print("Action-conditioned predictor loaded and frozen.")
        else:
            # Without a predictor the feature extractor silently falls back to plain features
            print("Warning: hub did not return an action-conditioned predictor; AC conditioning is disabled.")
    else:
        obj = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_large')
        vjepa_model = obj[0] if isinstance(obj, tuple) else obj
        print("Loaded V-JEPA 2 (ViT-L) from hub.")
    _freeze(vjepa_model)
else:  # VJEPA_BACKEND == "hf"
    from transformers import AutoVideoProcessor, AutoModel
    vjepa_model = _freeze(AutoModel.from_pretrained(HF_REPO))
    # NOTE(review): HF video processors return a dict-like batch object; confirm the
    # feature extractor's preprocessing step can unpack it before relying on this backend.
    preprocessor = AutoVideoProcessor.from_pretrained(HF_REPO)
    print(f"Loaded {HF_REPO} from Hugging Face.")
    if USE_AC:
        print("Warning: Action-conditioned model not available via HuggingFace. Use hub backend.")

if USE_VJEPA_ENCODER:
    print(f"V-JEPA2 encoder loaded successfully (AC mode: {USE_AC})")


Downloading: "https://github.com/facebookresearch/vjepa2/zipball/main" to /root/.cache/torch/hub/main.zip


Using cache found in /root/.cache/torch/hub/facebookresearch_vjepa2_main


Downloading: "https://dl.fbaipublicfiles.com/vjepa2/vitl.pt" to /root/.cache/torch/hub/checkpoints/vitl.pt


100%|██████████| 4.78G/4.78G [00:17<00:00, 294MB/s]


Loaded V-JEPA 2 (ViT-L) from hub.
V-JEPA2 encoder loaded successfully (AC mode: False)


## V-JEPA2 Feature Extractor


In [None]:
class VJEPA2FeatureExtractor(nn.Module):
    """
    Turns one Atari observation (plus optional frame/action history) into a pooled
    V-JEPA2 feature vector for the policy.

    Pipeline:
    1. Build a clip of the last `context` frames (left-padded with the oldest frame)
       and upsample each frame to 256x256.
    2. Run every frame through the V-JEPA preprocessor (if one is given) and encode the
       (1, C, T, H, W) clip with the frozen V-JEPA model; if the model returns
       (B, T', N, D) tokens the centre time slice is kept.
    3. Optionally add a *trainable* embedding of the most recent action to every token,
       and/or pass the tokens through a frozen action-conditioned (AC) predictor.
    4. Mean-pool over tokens -> (D,) feature vector.

    The encoder runs under torch.no_grad(), but the action embedding stays
    differentiable so an optimizer holding its parameters can actually train it.

    Args:
        vjepa_model: frozen V-JEPA encoder module.
        preprocessor: optional per-frame transform (see `_preprocess_frame`), or None.
        device: device for all tensors created here.
        ac_predictor: optional frozen AC predictor; enables AC conditioning when not None.
        action_dim: number of discrete actions (size of the embedding / one-hot tables).
        embed_actions_in_encoder: add a learned action embedding to the tokens.
        context_frames: clip length; defaults to the notebook-level T_CONTEXT.
    """

    def __init__(self, vjepa_model, preprocessor, device, ac_predictor=None, action_dim=None,
                 embed_actions_in_encoder=False, context_frames=None):
        super().__init__()
        self.vjepa_model = vjepa_model
        self.preprocessor = preprocessor
        self.device = device
        self.ac_predictor = ac_predictor
        self.use_ac = ac_predictor is not None
        self.action_dim = action_dim
        self.embed_actions_in_encoder = embed_actions_in_encoder
        # Clip length; falls back to the notebook-level T_CONTEXT when not given explicitly
        self.context = max(T_CONTEXT if context_frames is None else context_frames, 1)
        self._ac_warned = False  # print the AC-failure warning once, not on every step

        # Probe the encoder first: the token width depends on the backbone, so the
        # action embedding must be sized from the real token_dim, not a fixed 1024.
        self._infer_output_dims()

        if self.embed_actions_in_encoder and action_dim is not None:
            # Learned table: action index -> vector added to every encoder token
            self.action_embedding = nn.Embedding(action_dim, self.token_dim).to(device)
            print(f"Created action embedding layer: {action_dim} actions -> {self.token_dim} dim")
        else:
            self.action_embedding = None

    # ------------------------------------------------------------------ helpers

    def _to_frame_tensor(self, frame):
        """Convert one (3, H, W) frame to a float32 (3, 256, 256) tensor on self.device."""
        frame_t = torch.as_tensor(frame, dtype=torch.float32, device=self.device).unsqueeze(0)
        return F.interpolate(frame_t, size=(256, 256), mode='nearest').squeeze(0)

    def _build_sequence(self, obs, obs_history):
        """Return the last `self.context` frames (oldest first) as a (T, 3, 256, 256) tensor."""
        context = self.context
        # Only the newest context-1 history frames are used; slicing first avoids copying
        # the caller's whole (ever-growing) history on every step.
        if obs_history is not None and len(obs_history) > 0 and context > 1:
            frames = list(obs_history[-(context - 1):]) + [obs]
        else:
            frames = [obs]
        # Left-pad with the oldest available frame so the clip always has `context` frames
        frames = [frames[0]] * (context - len(frames)) + frames
        return torch.stack([self._to_frame_tensor(f) for f in frames], dim=0)

    def _preprocess_frame(self, frame):
        """Run one (3, H, W) frame through the preprocessor (if any) and return a (3, H', W') tensor."""
        if self.preprocessor is None:
            return frame
        # NOTE(review): frames are handed over as float (H, W, C) arrays clipped to [0, 1];
        # confirm the preprocessor does not expect uint8 values in [0, 255].
        frame_np = np.clip(frame.detach().cpu().permute(1, 2, 0).numpy(), 0, 1)
        out = self.preprocessor([frame_np])
        if isinstance(out, (list, tuple)):
            out = out[0]
        if isinstance(out, np.ndarray):
            out = torch.from_numpy(out)
        if not isinstance(out, torch.Tensor):
            raise TypeError(f"Unsupported preprocessor output type: {type(out).__name__}")
        if out.dim() == 4:
            if out.shape[0] == 3:      # (C, T, H, W) -> first time slice
                out = out[:, 0]
            elif out.shape[1] == 3:    # (B, C, H, W) -> first batch item
                out = out[0]
        elif out.dim() == 3 and out.shape[0] != 3 and out.shape[2] == 3:
            out = out.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
        return out

    @staticmethod
    def _tokens_from_output(out):
        """Pull the token tensor out of the encoder output; (B, T', N, D) is reduced to its centre time slice."""
        if isinstance(out, (list, tuple)) and len(out) > 0:
            tokens = out[0]
        elif isinstance(out, dict):
            tokens = out.get('x', list(out.values())[0])
        else:
            tokens = out
        if tokens.dim() == 4:  # (B, T', N, D)
            tokens = tokens[:, tokens.size(1) // 2]
        elif tokens.dim() != 3:
            raise RuntimeError(f"Unexpected V-JEPA output shape: {tokens.shape}")
        return tokens  # (B, N, D)

    def _encode_tokens(self, frames):
        """Preprocess a (T, 3, H, W) frame sequence and encode it with the frozen model -> (B, N, D) tokens."""
        processed = [self._preprocess_frame(f) for f in frames]
        for i, frame in enumerate(processed):
            if frame.dim() != 3 or frame.shape[0] != 3:
                raise RuntimeError(f"Frame {i} has unexpected shape: {frame.shape}, expected (3, H, W)")
        # (T, C, H, W) -> (1, C, T, H, W): the video layout fed to the encoder
        clip = torch.stack(processed, dim=0).permute(1, 0, 2, 3).unsqueeze(0).contiguous()
        clip = clip.to(self.device)
        with torch.no_grad():  # frozen backbone: never build a graph through it
            out = self.vjepa_model(clip)
        return self._tokens_from_output(out)

    def _infer_output_dims(self):
        """Probe the encoder with a constant dummy clip and record token_dim, spatial_tokens and flat_feature_dim."""
        # Constant frame: output shapes do not depend on content, and this avoids
        # consuming the global NumPy RNG (which would shift seeded experiments).
        dummy_obs = np.full((3, 84, 84), 0.5, dtype=np.float32)
        tokens = self._encode_tokens(self._build_sequence(dummy_obs, None))
        _, num_tokens, token_dim = tokens.shape
        self.token_dim = token_dim
        # Number of tokens that get mean-pooled. The encoder may flatten time and space
        # into one token axis, so this need not be a square spatial grid.
        self.spatial_tokens = num_tokens
        self.flat_feature_dim = token_dim  # tokens are mean-pooled before the policy
        print(f"V-JEPA2 feature dims: token_dim={token_dim}, num_tokens={num_tokens}")

    def _embed_recent_action(self, actions):
        """Return the (1, D) learned embedding of the most recent action in `actions` (action 0 if empty)."""
        recent = actions[-1] if len(actions) > 0 else 0
        if hasattr(recent, 'item'):
            recent = recent.item()
        idx = torch.tensor([int(recent)], dtype=torch.long, device=self.device)
        return self.action_embedding(idx)

    def _ac_actions(self, actions, length):
        """Normalise `actions` for the AC predictor: integer (T,) indices become one-hot (T, action_dim)."""
        if actions is None:
            # No history: zero (no-op) actions for each timestep
            if self.action_dim is not None:
                return torch.zeros(length, self.action_dim, device=self.device)
            return torch.zeros(length, dtype=torch.long, device=self.device)
        actions = torch.as_tensor(actions, device=self.device)
        if actions.dim() == 1 and self.action_dim is not None and not actions.is_floating_point():
            onehot = torch.zeros(len(actions), self.action_dim, device=self.device)
            onehot.scatter_(1, actions.long().unsqueeze(1), 1.0)
            return onehot
        return actions  # already one-hot / embedded

    def _apply_ac_predictor(self, tokens, actions):
        """Best-effort conditioning of (B, N, D) tokens with the AC predictor.

        The predictor's call signature is not known here, so a few argument orders are
        tried. Atari has no proprioceptive state, so the visual tokens double as `states`.
        If every attempt fails, a warning is printed once and `tokens` is returned unchanged.
        """
        acts = self._ac_actions(actions, self.context)
        tokens_t = tokens.unsqueeze(1) if tokens.dim() == 3 else tokens  # add a time axis
        acts_b = acts.unsqueeze(0) if acts.dim() in (1, 2) else acts     # add a batch axis
        states = tokens_t
        attempts = (
            lambda: self.ac_predictor(tokens_t, states, acts_b),
            lambda: self.ac_predictor(tokens_t, acts_b, states),
            lambda: self.ac_predictor(states, acts_b),
        )
        last_err = None
        # Frozen predictor: only track grads when the incoming tokens carry them (action embedding)
        with torch.set_grad_enabled(tokens.requires_grad):
            for attempt in attempts:
                try:
                    conditioned = attempt()
                    # (B, T, N, D) -> first timestep
                    return conditioned[:, 0] if conditioned.dim() == 4 else conditioned
                except Exception as err:  # deliberate best-effort: predictor API is unknown here
                    last_err = err
        if not self._ac_warned:
            print(f"Warning: Could not apply AC predictor: {last_err}")
            print("Note: V-JEPA2-AC was designed for robot control with explicit states.")
            print("For Atari, consider disabling AC (USE_AC=0) or use visual tokens as states.")
            print("Continuing with unconditioned features (further warnings suppressed).")
            self._ac_warned = True
        return tokens

    # ------------------------------------------------------------------ forward

    def forward(self, obs, obs_history=None, actions=None):
        """
        obs: current frame, (3, H, W) float in [0, 1] (84x84 in this notebook).
        obs_history: optional sequence of previous frames, oldest first. Only the newest
            `self.context - 1` are used; the clip is left-padded with the oldest frame when
            fewer are available. None/empty means the current frame is repeated.
        actions: optional action history, most recent last. The action embedding uses
            only actions[-1]; the AC predictor expects indices (T,) or one-hot
            (T, action_dim) and substitutes zero actions when this is None.
            NOTE: when `actions` is None, no action embedding is added.
        Returns: (D,) feature vector (mean over tokens). It carries gradients to the action
            embedding when that is applied; no gradients reach the frozen encoder.
        """
        tokens = self._encode_tokens(self._build_sequence(obs, obs_history))  # (B, N, D)

        if self.embed_actions_in_encoder and actions is not None and self.action_embedding is not None:
            # (1, D) -> (1, 1, D), broadcast-added to every token
            tokens = tokens + self._embed_recent_action(actions).unsqueeze(1)

        if self.use_ac and self.ac_predictor is not None:
            tokens = self._apply_ac_predictor(tokens, actions)

        return tokens.mean(dim=1).squeeze(0)  # (D,)


In [None]:
class CNNFeatureExtractor(nn.Module):
    """Trainable CNN over the last `context_frames` RGB frames (baseline without V-JEPA).

    Frames are stacked along the channel axis, so the conv stack sees 3 * context input
    channels. The 64 * 7 * 7 flatten size of the first linear layer assumes 84x84 frames.

    Args:
        context_frames: number of frames to stack (values < 1 are treated as 1).
        device: device the stacked input tensor is moved to.
        output_dim: width of the returned feature vector.
    """

    def __init__(self, context_frames: int, device: str, output_dim: int = 256):
        super().__init__()
        self.context = max(context_frames, 1)
        in_channels = self.context * 3  # RGB frames stacked along channel dimension
        self.device = device

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, output_dim),
            nn.ReLU(),
        )

        self.flat_feature_dim = output_dim

    def _stack_frames(self, obs, obs_history):
        """Return the last `self.context` frames (oldest first) stacked to a (1, context * 3, H, W) tensor."""
        # Slice before copying: the caller's history grows every step, but only the
        # newest context-1 frames are needed (copying it all made episodes O(n^2)).
        if obs_history is not None and len(obs_history) > 0 and self.context > 1:
            frames = list(obs_history[-(self.context - 1):]) + [obs]
        else:
            frames = [obs]
        # Left-pad with the oldest available frame when history is short
        frames = [frames[0]] * (self.context - len(frames)) + frames

        frame_array = np.stack(frames, axis=0)
        frame_tensor = torch.from_numpy(frame_array).to(self.device, dtype=torch.float32)
        # (context, C, H, W) -> (1, context*C, H, W)
        return frame_tensor.reshape(1, self.context * 3, frame_tensor.size(-2), frame_tensor.size(-1))

    def forward(self, obs, obs_history=None, actions=None):
        """Encode `obs` (plus optional history) into a (output_dim,) feature vector; `actions` is ignored."""
        x = self._stack_frames(obs, obs_history)
        feats = self.conv(x)
        feats = feats.view(feats.size(0), -1)
        feats = self.fc(feats)
        return feats.squeeze(0)



## Policy Network and Training Setup


In [None]:
class LinearPolicy(nn.Module):
    """MLP mapping a feature vector to unnormalised action logits.

    Despite the name it has one hidden ReLU layer: latent_dim -> hidden_dim -> action_dim.
    """

    def __init__(self, latent_dim, action_dim, hidden_dim=512):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return logits with trailing size action_dim for features z with trailing size latent_dim."""
        logits = self.net(z)
        return logits

# Initialize feature extractor and policy
action_dim = env.action_space.n

if USE_VJEPA_ENCODER:
    print("Using V-JEPA2 feature extractor (frozen).")
    encoder = VJEPA2FeatureExtractor(
        vjepa_model,
        preprocessor,
        device,
        ac_predictor=vjepa_ac_predictor,
        action_dim=action_dim,
        embed_actions_in_encoder=EMBED_ACTIONS,
    ).to(device)
else:
    print("Using simple CNN feature extractor (trainable baseline).")
    encoder = CNNFeatureExtractor(T_CONTEXT, device).to(device)

# Both extractors expose the width of the vector they hand to the policy
latent_dim = encoder.flat_feature_dim
policy = LinearPolicy(latent_dim, action_dim).to(device)

# Trainable parameters: always the policy; the whole CNN in baseline mode;
# in V-JEPA mode only the optional action embedding (the backbone is frozen).
params = list(policy.parameters())
if not USE_VJEPA_ENCODER:
    params += list(encoder.parameters())
elif encoder.embed_actions_in_encoder and encoder.action_embedding is not None:
    params += list(encoder.action_embedding.parameters())
    print(f"Optimizer will update {sum(p.numel() for p in encoder.action_embedding.parameters())} action-embedding parameters")
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)

print(f"Encoder feature dim: {latent_dim}")
print(f"Action dim: {action_dim}")
if USE_VJEPA_ENCODER:
    print(f"Using action-conditioned predictor (AC): {encoder.use_ac}")
    print(f"Using action embedding in encoder: {encoder.embed_actions_in_encoder}")

Created action embedding layer: 18 actions -> 1024 dim
V-JEPA2 feature dims: token_dim=1024, spatial_grid=22x22
Optimizer includes 18432 action embedding parameters
Encoder feature dim: 1024
Action dim: 18
Using action-conditioned predictor (AC): False
Using action embedding in encoder: True
Action embedding parameters will be learned during policy training


  self.gen = func(*args, **kwds)


In [None]:
# Quick sanity check: encode one freshly reset frame with no frame or action history
sample_obs, _ = env.reset()
with torch.no_grad():
    sample_features = encoder(sample_obs)
print(f"Sample feature vector shape: {sample_features.shape}")


## Policy Training Loop


In [None]:
def run_episode(env, encoder, policy, gamma=GAMMA, seed=None, save_video=False):
    """Roll out one episode with the current policy and build a REINFORCE loss.

    Args:
        env: Gymnasium environment whose observations are channel-first float
            arrays in [0, 1] (converted back to HxWxC uint8 for the GIF).
        encoder: Feature extractor called as
            ``encoder(obs, obs_history=..., actions=...)``. An optional
            ``context`` attribute sets the temporal window (falls back to
            ``max(T_CONTEXT, 1)``); an optional truthy ``use_ac`` attribute
            makes the last ``context`` actions (left-padded with 0) be passed.
        policy: Module mapping encoder features to action logits.
        gamma: Discount factor for the returns.
        seed: If given, seeds torch / numpy / ``random`` and the env reset.
        save_video: If True, writes the episode frames to ``vjepa2_raw.gif``.

    Returns:
        ``(loss, total_reward)``: the scalar policy-gradient loss (graph
        attached, ready for ``backward()``) and the undiscounted episode return.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
    # One reset both seeds the env and yields the first observation;
    # reset(seed=None) behaves like an unseeded reset.
    obs, _ = env.reset(seed=seed)

    context = getattr(encoder, "context", None)
    if context is None:
        context = max(T_CONTEXT, 1)
    # Previous frames to keep; the current frame is passed to the encoder separately.
    history_len = max(context - 1, 0)
    uses_ac = getattr(encoder, "use_ac", False)  # constant for the whole episode
    n_valid = env.action_space.n

    log_probs, rewards, raw_frames = [], [], []
    action_history, obs_history = [], []
    total_reward = 0
    done = False

    while not done:
        if save_video:  # only hold frames in memory when they will be written
            raw_frames.append((obs * 255).astype(np.uint8).transpose(1, 2, 0))

        if uses_ac:
            # Left-pad with no-op (0) so the encoder always gets `context` actions.
            recent = action_history[-context:]
            actions_for_encoder = [0] * (context - len(recent)) + recent
        else:
            actions_for_encoder = None

        z = encoder(obs, obs_history=obs_history, actions=actions_for_encoder)
        logits = policy(z)

        # Mask logits beyond the env's action count. Index the LAST axis so a
        # batched (1, A) output is masked too, and mask a copy so the policy's
        # autograd graph is never modified in place.
        if logits.shape[-1] > n_valid:
            logits = logits.clone()
            logits[..., n_valid:] = -1e9

        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        # Flatten so each step contributes exactly one log-prob; a batched (1,)
        # entry would otherwise broadcast against `returns` into a T x T product.
        log_probs.append(dist.log_prob(action).reshape(-1))
        action_scalar = action.item()  # works for any single-element tensor

        action_history.append(action_scalar)
        obs_history.append(obs.copy())
        # Slicing with -(context - 1) would keep everything when context == 1,
        # so trim explicitly to the last `history_len` frames.
        if len(obs_history) > history_len:
            obs_history = obs_history[len(obs_history) - history_len:]

        obs, reward, terminated, truncated, _ = env.step(action_scalar)
        done = terminated or truncated
        rewards.append(reward)
        total_reward += reward

    if save_video:
        imageio.mimsave("vjepa2_raw.gif", raw_frames, fps=15)
        print("Saved: vjepa2_raw.gif")

    log_probs_t = torch.cat(log_probs)

    # Discounted returns, accumulated backwards and reversed once (O(T)).
    discounted, G = [], 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        discounted.append(G)
    discounted.reverse()
    returns = torch.tensor(discounted, dtype=torch.float32, device=log_probs_t.device)

    # Normalise for variance reduction. The (unbiased) std of a single element
    # is NaN, which would poison the update, so a length-1 episode gets zeros.
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    else:
        returns = torch.zeros_like(returns)

    loss = -torch.sum(log_probs_t * returns)
    return loss, float(total_reward)


## Training


In [None]:
# Reference point for reading the reward log when training on Pong.
if "Pong" in GAME_ID:
    print("Pong theoretical max episodic reward: 21 (win 21-0).")

reward_history = []
best_reward = float("-inf")

for ep in range(NUM_EPISODES):
    # The policy head always trains. The encoder is switched to train mode only
    # for the CNN baseline; the V-JEPA encoder is kept in eval mode.
    policy.train()
    if not USE_VJEPA_ENCODER:
        encoder.train()
    else:
        encoder.eval()

    # One on-policy rollout -> one REINFORCE gradient step.
    loss, total_reward = run_episode(env, encoder, policy, seed=ep)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    reward_history.append(float(total_reward))
    torch.cuda.empty_cache()  # release unoccupied cached GPU memory between episodes

    if total_reward > best_reward:
        best_reward = total_reward
    print(f"Episode {ep+1}/{NUM_EPISODES} | Reward: {total_reward:.1f} | Best so far: {best_reward:.1f}")

print(f"Training complete. Best episodic reward achieved: {best_reward:.1f}")


Episode 1/200 | Reward: 130.0
Episode 2/200 | Reward: 150.0
Episode 3/200 | Reward: 230.0
Episode 4/200 | Reward: 110.0
Episode 5/200 | Reward: 140.0
Episode 6/200 | Reward: 110.0
Episode 7/200 | Reward: 110.0
Episode 8/200 | Reward: 180.0
Episode 9/200 | Reward: 260.0
Episode 10/200 | Reward: 280.0
Episode 11/200 | Reward: 260.0
Episode 12/200 | Reward: 230.0
Episode 13/200 | Reward: 220.0
Episode 14/200 | Reward: 140.0
Episode 15/200 | Reward: 340.0
Episode 16/200 | Reward: 270.0
Episode 17/200 | Reward: 250.0
Episode 18/200 | Reward: 140.0
Episode 19/200 | Reward: 260.0
Episode 20/200 | Reward: 140.0
Episode 21/200 | Reward: 140.0
Episode 22/200 | Reward: 250.0
Episode 23/200 | Reward: 140.0
Episode 24/200 | Reward: 140.0
Episode 25/200 | Reward: 250.0
Episode 26/200 | Reward: 230.0
Episode 27/200 | Reward: 230.0
Episode 28/200 | Reward: 140.0
Episode 29/200 | Reward: 230.0
Episode 30/200 | Reward: 140.0
Episode 31/200 | Reward: 140.0
Episode 32/200 | Reward: 140.0
Episode 33/200 | 

## Results Visualization


In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
plt.figure()
plt.plot(reward_history)
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("Policy Training Reward Curve")
plt.tight_layout()
plt.show()

UsageError: unrecognized arguments: # or: %config InlineBackend.figure_format = 'retina'


In [None]:
# === EVALUATE AND RECORD A GAME RUN ===
import numpy as np, torch, imageio
from PIL import Image, ImageDraw, ImageFont

@torch.no_grad()
def evaluate_and_record(env, encoder, policy, seed=0, save_path="vjepa2_eval.mp4",
                        fps=30, overlay_actions=True):
    """Play one evaluation episode with the current policy; optionally save a video.

    Actions are *sampled* from the policy's categorical distribution (not
    argmax), so the return depends on ``seed``. The encoder receives the
    previous ``context - 1`` observations and, if ``use_ac`` is set, the last
    ``context`` actions left-padded with 0.

    Args:
        env: Gymnasium env whose observations are channel-first float arrays
            in [0, 1] (converted to HxWxC uint8 for the video).
        encoder: Called as ``encoder(obs, obs_history=..., actions=...)``;
            optional attributes ``context`` (temporal window, falls back to
            ``max(T_CONTEXT, 1)``) and ``use_ac``.
        policy: Maps encoder features to action logits.
        seed: Seeds torch / numpy / ``random`` and the env reset.
        save_path: Video output path. Falsy disables frame capture and saving.
        fps: Frame rate of the saved video.
        overlay_actions: Draw the sampled action id on each saved frame.

    Returns:
        The undiscounted episode return.

    Side effects:
        Encoder and policy are put in eval mode for the rollout and restored
        to their previous train/eval mode afterwards.
    """
    import random

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    def obs_to_rgb(o):
        # Channel-first float in [0, 1] -> HxWxC uint8 for image/video writers.
        return (np.clip(o, 0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)

    context = getattr(encoder, "context", None)
    if context is None:  # only fall back to the global when the encoder has no window
        context = max(T_CONTEXT, 1)
    # Previous frames to keep; the current frame is passed separately.
    history_len = max(context - 1, 0)
    uses_ac = getattr(encoder, "use_ac", False)
    n_valid = env.action_space.n
    record = bool(save_path)  # frames are only needed when a video is written

    modules = [m for m in (encoder, policy) if isinstance(m, torch.nn.Module)]
    previous_modes = [m.training for m in modules]
    for m in modules:
        m.eval()

    total_reward = 0.0
    frames = []
    try:
        obs, _ = env.reset(seed=seed)
        action_history, obs_history = [], []
        done = False

        while not done:
            if uses_ac:
                recent = action_history[-context:]
                actions_for_encoder = [0] * (context - len(recent)) + recent
            else:
                actions_for_encoder = None

            # Encode (with temporal obs history)
            z = encoder(obs, obs_history=obs_history, actions=actions_for_encoder)
            logits = policy(z)

            # Mask actions the env does not have; index the last axis so a
            # batched (1, A) output is masked too.
            if logits.shape[-1] > n_valid:
                logits = logits.clone()
                logits[..., n_valid:] = -1e9
            action_scalar = torch.distributions.Categorical(logits=logits).sample().item()

            # Render current frame (before step) with optional overlay
            if record:
                frame = obs_to_rgb(obs)
                if overlay_actions:
                    img = Image.fromarray(frame)
                    draw = ImageDraw.Draw(img)
                    txt = f"action={action_scalar}"
                    draw.rectangle([0, 0, 140, 20], fill=(0, 0, 0, 160))
                    draw.text((5, 3), txt, fill=(255, 255, 255))
                    frame = np.array(img)
                frames.append(frame)

            # Update histories for next step. Slicing with -(context - 1)
            # would keep everything when context == 1, so trim explicitly.
            action_history.append(action_scalar)
            obs_history.append(obs.copy())
            if len(obs_history) > history_len:
                obs_history = obs_history[len(obs_history) - history_len:]

            obs, reward, terminated, truncated, _ = env.step(action_scalar)
            total_reward += reward
            done = terminated or truncated
    finally:
        for m, was_training in zip(modules, previous_modes):
            m.train(was_training)

    # Save video (optional)
    if save_path:
        with imageio.get_writer(save_path, fps=fps, macro_block_size=1) as w:
            for f in frames:
                w.append_data(f)
        print(f"Saved video to {save_path} | Episode return: {total_reward:.1f}")
    else:
        print(f"Episode return: {total_reward:.1f}")
    return total_reward

# Record one evaluation rollout using the env / encoder / policy trained above.
# Binding the result to `_` keeps the cell from echoing the returned reward.
_ = evaluate_and_record(
    env,
    encoder,
    policy,
    seed=123,
    save_path="vjepa2_eval.mp4",
    fps=30,
    overlay_actions=True,
)

  self.gen = func(*args, **kwds)


Saved video to vjepa2_eval.mp4 | Episode return: 250.0


In [None]:
def make_eval_env(seed=None):
    """Build a fresh env with the same preprocessing as the training env.

    The wrappers are: resize to 84x84, then ``transform_obs`` with
    ``new_obs_space``. If ``seed`` is given, the env is reset once with it so
    its RNGs are seeded.
    """
    eval_env = gym.wrappers.TransformObservation(
        gym.wrappers.ResizeObservation(gym.make(GAME_ID), (84, 84)),
        func=transform_obs,
        observation_space=new_obs_space,
    )
    if seed is not None:
        eval_env.reset(seed=seed)
    return eval_env


def evaluate_policy(encoder, policy, episodes=EVAL_EPISODES, seed_offset=1000):
    """Evaluate the policy on fresh environments and print summary statistics.

    Each episode runs in its own env from ``make_eval_env``, seeded with
    ``seed_offset + idx``, and is played by ``evaluate_and_record`` without
    saving a video. Every env is closed even if its episode raises.

    Args:
        encoder: Feature extractor passed through to ``evaluate_and_record``.
        policy: Policy module; put in eval mode before evaluation.
        episodes: Number of evaluation episodes.
        seed_offset: Base seed; episode ``idx`` uses ``seed_offset + idx``.

    Returns:
        np.ndarray of per-episode returns (empty when ``episodes <= 0``).
    """
    policy.eval()
    returns = []
    for idx in range(episodes):
        episode_seed = seed_offset + idx
        eval_env = make_eval_env(seed=episode_seed)
        try:
            reward = evaluate_and_record(eval_env, encoder, policy, seed=episode_seed,
                                         save_path=None, overlay_actions=False)
        finally:
            eval_env.close()  # do not leak emulator instances on failure
        returns.append(reward)
    returns = np.array(returns)

    # mean/std/max are undefined for an empty array (max() would raise).
    if returns.size == 0:
        print("Evaluation skipped: no episodes requested.")
        return returns

    print(f"Evaluation over {episodes} episode(s): mean={returns.mean():.2f}, std={returns.std():.2f}, max={returns.max():.2f}")
    return returns



## Notes

### Pipeline Overview

This notebook implements a policy-training pipeline that uses the V-JEPA2 encoder as a frozen feature extractor (enabled with `USE_VJEPA_ENCODER=1`; otherwise a trainable CNN baseline encoder is used):

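1. **Environment**: Atari game selected by `ATARI_GAME` (default `PongNoFrameskip-v4`), with 84×84 RGB observations
2. **Encoder**: V-JEPA2 (ViT-L), pre-trained and frozen; when `VJEPA_EMBED_ACTIONS=1`, only the action-embedding layer is trained alongside the policy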
3. **Feature Extraction**:
   - Resize observations from 84×84 → 256×256 (V-JEPA expects 256×256)
   - Create temporal context from the last `T_CONTEXT` frames (default 4) of real observation history
   - Extract spatial tokens from the center temporal slice
   - Average pool over spatial tokens to get flat features
4. **Policy**: 2-layer MLP (linear → ReLU → linear)
5. **Training**: REINFORCE with normalized discounted returns (one gradient step per episode)

### Key Differences from RSSM Pipeline

- **Image Resolution**: RSSM uses 64×64; V-JEPA uses 84×84 → 256×256
- **Temporal Context**: V-JEPA uses explicit multi-frame input; RSSM uses an RNN state
- **Feature Extraction**: V-JEPA outputs spatial tokens that we pool; RSSM learns a global embedding
- **Dynamics**: No learned dynamics here; encoder features are used directly

### Future Improvements

- Consider using the V-JEPA-AC variant for action-conditioned features
- Explore different spatial pooling strategies (attention-weighted, max pooling, etc.)
- Add a dynamics model on top of V-JEPA features (as done in RSSM)
- Implement imagination-based training using predicted futures


In [None]:
# Cleanup: close the training environment if this kernel created one
# (the cell stays safe to run on its own before `env` exists).
if "env" in globals():
    env.close()

