In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir ./til_marl_tensorboard_rs_explore_v7/

In [None]:
from __future__ import annotations

import glob
import os
import functools
from functools import partial

import imageio
import numpy as np # For frame manipulation
from stable_baselines3 import PPO

from til_environment import gridworld
from til_environment.gridworld import NUM_ITERS
from til_environment.flatten_dict import FlattenDictWrapper
from supersuit import frame_stack_v2
from til_environment.types import RewardNames, Action

from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
from pettingzoo.utils.wrappers.base import BaseWrapper

# --- Custom Wrapper to Clip Step Observation (from training script) ---
class ClipStepObservationWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
        super().__init__(env)

    def observe(self, agent: AgentID) -> ObsType | None:
        obs = super().observe(agent)
        if obs is not None and isinstance(obs, dict) and "step" in obs:
            obs["step"] = min(obs["step"], NUM_ITERS - 1)
        return obs

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent: AgentID):
        return super().observation_space(agent)

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent: AgentID):
        return super().action_space(agent)

# --- Define constants and configurations relevant to v6 model ---
REWARDS_DICT_V6 = {
    RewardNames.GUARD_CAPTURES: 50,
    RewardNames.SCOUT_CAPTURED: -50,
    RewardNames.SCOUT_RECON: 5,
    RewardNames.SCOUT_MISSION: 20,
    RewardNames.WALL_COLLISION: -0.5,
    RewardNames.STATIONARY_PENALTY: -0.1,
    RewardNames.SCOUT_STEP: -0.01,
    RewardNames.GUARD_STEP: -0.01,
    RewardNames.GUARD_TRUNCATION: -10,
    RewardNames.SCOUT_TRUNCATION: 0,
}
print(f"Using v6 rewards_dict for evaluation: {REWARDS_DICT_V6}")

CUSTOM_ENV_WRAPPERS_LIST_V6 = [
    ClipStepObservationWrapper,
    FlattenDictWrapper,
    partial(frame_stack_v2, stack_size=4, stack_dim=-1),
]

RANDOM_SEED = 42
NOVICE_MODE = False
MODEL_NAME_FOR_VIDEOS = "til_ai_marl_ppo_rs_explore_v6"

def create_placeholder_frame(frame_shape, color=(0,0,0)):
    """Creates a blank frame of a given shape and color."""
    ph_frame = np.full(frame_shape, color, dtype=np.uint8)
    return ph_frame

def create_composite_frame(frames_list, grid_layout=(2,2)):
    """
    Creates a composite frame by arranging frames in a grid.
    Assumes all frames in frames_list have the same shape.
    Handles cases where fewer frames are provided than grid slots by using placeholders.
    """
    if not frames_list or not any(f is not None for f in frames_list):
        return None # Or a default blank image of expected composite size

    # Determine frame shape from the first valid frame
    first_valid_frame = next((f for f in frames_list if f is not None), None)
    if first_valid_frame is None: # Should not happen if a game ran
        return None
    frame_h, frame_w, frame_c = first_valid_frame.shape
    placeholder = create_placeholder_frame((frame_h, frame_w, frame_c))

    rows, cols = grid_layout
    
    # Ensure frames_list has enough elements for the grid, padding with placeholders
    padded_frames = list(frames_list) # Make a mutable copy
    while len(padded_frames) < rows * cols:
        padded_frames.append(placeholder)
    
    # Replace None frames with placeholders
    for i in range(len(padded_frames)):
        if padded_frames[i] is None:
            padded_frames[i] = placeholder

    grid_rows = []
    for i in range(rows):
        row_frames = padded_frames[i*cols : (i+1)*cols]
        if not row_frames: continue # Should not happen with padding
        # If a row has fewer frames than cols (e.g. 3 games for 2x2), pad it
        while len(row_frames) < cols:
            row_frames.append(placeholder)
        grid_rows.append(np.hstack(row_frames))
    
    if not grid_rows:
        return placeholder # Fallback, should not happen
    
    composite = np.vstack(grid_rows)
    return composite

# --- Evaluation Function ---
def evaluate_til_ai_marl(
    num_games_to_display_concurrently: int = 4, # Number of games for the grid display
    render_mode: str | None = None, # MUST be "rgb_array" for concurrent video
    novice_eval: bool = True,
    model_to_load_path: str | None = None,
    eval_device: str = "cpu",
    print_actions: bool = False
):
    if model_to_load_path is None:
        print("Error: model_to_load_path must be provided.")
        return
    
    if render_mode != "rgb_array" and num_games_to_display_concurrently > 0:
        print("Warning: For concurrent video, render_mode must be 'rgb_array'.")
        print("Video generation for concurrent display will be skipped.")
        # Fallback to sequential non-video evaluation if not rgb_array
        # Or, you could force render_mode to rgb_array here if a concurrent display is explicitly requested.

    print(f"\nStarting evaluation for model: {model_to_load_path}")
    print(f"Preparing to display {num_games_to_display_concurrently} games concurrently in video.")
    print(f"(Novice: {novice_eval}, Render: {render_mode}, Device: {eval_device})...")

    eval_rewards_dict = REWARDS_DICT_V6
    eval_env_wrappers = CUSTOM_ENV_WRAPPERS_LIST_V6

    eval_env_kwargs = {
        "render_mode": "rgb_array", # Force rgb_array for frame collection
        "novice": novice_eval,
        "env_wrappers": eval_env_wrappers,
        "rewards_dict": eval_rewards_dict,
    }

    eval_env = gridworld.env(**eval_env_kwargs) # One env instance, used sequentially
            
    try:
        loaded_model = PPO.load(model_to_load_path, device=eval_device)
        print(f"Successfully loaded model: {model_to_load_path}")
    except Exception as e:
        print(f"Error loading model {model_to_load_path}: {e}")
        eval_env.close()
        return

    total_rewards_all_games_stats = {agent: 0.0 for agent in eval_env.possible_agents} # For stats
    
    video_folder = f"./logs/videos_{MODEL_NAME_FOR_VIDEOS}_eval_concurrent/"
    os.makedirs(video_folder, exist_ok=True)

    all_runs_all_frames = [] # List to store lists of frames for each game run

    for game_idx in range(num_games_to_display_concurrently):
        print(f"\n--- Simulating Game {game_idx + 1} of {num_games_to_display_concurrently} ---")
        eval_env.reset(seed=RANDOM_SEED + game_idx + 3000) # Different seed for each game
        
        current_game_collected_frames = []
        game_rewards_this_round = {agent: 0.0 for agent in eval_env.possible_agents}
        # action_counts_this_game = {act: 0 for act in Action} # If needed per game

        for agent_id in eval_env.agent_iter():
            observation, reward, termination, truncation, info = eval_env.last()
            game_rewards_this_round[agent_id] += reward

            if termination or truncation: action_val = None
            else: action_val, _ = loaded_model.predict(observation, deterministic=False)
            
            eval_env.step(action_val)

            if print_actions and action_val is not None:
                print(f"Game {game_idx+1}, Agent {agent_id}, Action: {Action(action_val).name}")

            # Always render if we intend to make a video
            try:
                frame = eval_env.render() # mode is already rgb_array from env creation
                if frame is not None:
                    current_game_collected_frames.append(frame)
            except Exception as e:
                print(f"Warning: Could not render frame for game {game_idx+1}, agent {agent_id}: {e}")
        
        for agent_id_sum in eval_env.possible_agents:
            total_rewards_all_games_stats[agent_id_sum] += game_rewards_this_round[agent_id_sum]
        
        all_runs_all_frames.append(current_game_collected_frames)
        print(f"Game {game_idx + 1} finished. Rewards: {game_rewards_this_round}. Collected {len(current_game_collected_frames)} frames.")

    eval_env.close() # Close the single environment instance

    # --- Post-simulation: Create Composite Video ---
    if render_mode == "rgb_array" and all_runs_all_frames and any(all_runs_all_frames):
        print("\n--- Creating Composite Video ---")
        # Determine grid layout (e.g., 2x2 for 4 games, 1xN or Nx1 for others)
        if num_games_to_display_concurrently == 0:
             print("No games specified for concurrent display.")
        elif num_games_to_display_concurrently == 1:
            grid_layout = (1,1)
        elif num_games_to_display_concurrently == 2:
            grid_layout = (1,2) # Side-by-side
        elif num_games_to_display_concurrently == 3:
            grid_layout = (1,3) # Side-by-side
        elif num_games_to_display_concurrently == 4:
            grid_layout = (2,2)
        else: # Default for > 4, could be smarter
            cols = int(np.ceil(np.sqrt(num_games_to_display_concurrently)))
            rows = int(np.ceil(num_games_to_display_concurrently / cols))
            grid_layout = (rows, cols)
            print(f"Using grid layout: {grid_layout} for {num_games_to_display_concurrently} games.")

        max_frames_in_any_run = 0
        if all_runs_all_frames:
             max_frames_in_any_run = max(len(frames) for frames in all_runs_all_frames if frames) if any(all_runs_all_frames) else 0


        composite_video_frames = []
        example_frame_shape = None
        if max_frames_in_any_run > 0:
            for game_frames in all_runs_all_frames:
                if game_frames:
                    example_frame_shape = game_frames[0].shape
                    break
        
        if example_frame_shape is None and num_games_to_display_concurrently > 0 :
             print("Could not determine frame shape. Skipping composite video.")
        elif num_games_to_display_concurrently > 0 :
            placeholder_frame = create_placeholder_frame(example_frame_shape)

            for t in range(max_frames_in_any_run):
                frames_at_this_timestep = []
                for game_idx in range(num_games_to_display_concurrently):
                    if game_idx < len(all_runs_all_frames) and t < len(all_runs_all_frames[game_idx]):
                        frames_at_this_timestep.append(all_runs_all_frames[game_idx][t])
                    else: # Game ended or no frame
                        # Use last available frame for that game, or placeholder
                        if game_idx < len(all_runs_all_frames) and all_runs_all_frames[game_idx]:
                             frames_at_this_timestep.append(all_runs_all_frames[game_idx][-1])
                        else:
                             frames_at_this_timestep.append(placeholder_frame)
                
                composite = create_composite_frame(frames_at_this_timestep, grid_layout)
                if composite is not None:
                    composite_video_frames.append(composite)
            
            if composite_video_frames:
                video_file_suffix = f"concurrent_{grid_layout[0]}x{grid_layout[1]}_{num_games_to_display_concurrently}games"
                composite_video_filename = f"{os.path.basename(model_to_load_path).replace('.zip', '')}_{video_file_suffix}.mp4"
                composite_video_path = os.path.join(video_folder, composite_video_filename)
                try:
                    imageio.mimsave(composite_video_path, composite_video_frames, fps=eval_env.metadata.get("render_fps", 10))
                    print(f"Saved composite video to {composite_video_path}")
                except Exception as e:
                    print(f"Error saving composite video: {e}")
            else:
                print("No composite frames generated for video.")

    # --- Final Statistics ---
    print("\n--- Overall Evaluation Summary ---")
    avg_rewards_per_agent = {
        agent: total_rewards_all_games_stats[agent] / num_games_to_display_concurrently for agent in eval_env.possible_agents
    }
    print(f"Average rewards per agent over {num_games_to_display_concurrently} simulated games (using v6 training rewards): {avg_rewards_per_agent}")
    
    if eval_env.possible_agents: # Check if possible_agents is not empty
        team_avg_reward = sum(avg_rewards_per_agent.values())
        print(f"Sum of average rewards for all agents (team perspective): {team_avg_reward:.4f}")
        print(f"Note: The avg_rewards_per_agent is based on the v6 training rewards_dict values.")

# --- Run Evaluation ---
if __name__ == "__main__":
    PATH_TO_V6_MODEL = "./til_ai_marl_ppo_rs_explore_v6_20250520-071438.zip" # YOUR MODEL PATH

    if not os.path.exists(PATH_TO_V6_MODEL):
        print(f"ERROR: Model file not found at {PATH_TO_V6_MODEL}")
        print("Please update PATH_TO_V6_MODEL with the correct path to your .zip file.")
    else:
        # Example: Render 4 games into a single 2x2 concurrent video
        evaluate_til_ai_marl(
            num_games_to_display_concurrently=4, # Number of games for the grid
            render_mode="rgb_array",      # Must be "rgb_array" for this to work
            novice_eval=NOVICE_MODE,      # Should match training
            model_to_load_path=PATH_TO_V6_MODEL,
            eval_device="cpu",            # Can be "cuda" if preferred and available
            print_actions=False           # Typically false when focusing on video
        )
        
        # Example: Render 2 games side-by-side
        # evaluate_til_ai_marl(
        #     num_games_to_display_concurrently=2,
        #     render_mode="rgb_array",
        #     novice_eval=NOVICE_MODE,
        #     model_to_load_path=PATH_TO_V6_MODEL,
        #     eval_device="cpu",
        #     print_actions=False
        # )

        video_output_dir = os.path.abspath(f"./logs/videos_{MODEL_NAME_FOR_VIDEOS}_eval_concurrent/")
        print(f"\nTo view concurrent videos, check the '{video_output_dir}' directory.")

In [2]:
"""Manages the RL model."""

import os
from collections import deque
import numpy as np
from stable_baselines3 import PPO
from gymnasium.spaces import Box, Discrete, Dict, flatten

# Assuming til_environment is in the python path or PYTHONPATH is set.
try:
    from til_environment.gridworld import NUM_ITERS
except ImportError:
    print("Warning: Could not import NUM_ITERS from til_environment.gridworld.")
    print("Ensure til_environment is installed and accessible.")
    print("Using a default NUM_ITERS = 100 for clipping. This might be incorrect.")
    NUM_ITERS = 100

class RLManager:
    """
    Manages the RL model for inference, including observation preprocessing
    to match the training setup (clipping, flattening, frame stacking).
    """

    def __init__(self, model_path: str = "til_ai_marl_ppo_rs_explore_v7_20250520-104436.zip"):
        """
        Initializes the RLManager, loads the model, and sets up
        observation processing components.

        Args:
            model_path (str): Path to the trained PPO model .zip file.
        """
        self.observation_space_unflattened = Dict({
            "viewcone": Box(0, 255, (7, 5), np.uint8),
            "direction": Discrete(4),
            "scout": Discrete(2),
            "location": Box(np.array([0, 0]), np.array([15, 15]), (2,), np.int8),
            "step": Box(0, NUM_ITERS - 1, (1,), np.int8)
        })
        print(f"RLManager: Observation space for flattening defined with NUM_ITERS={NUM_ITERS}.")

        absolute_model_path = model_path
        if not os.path.isabs(model_path):
            try:
                # Preferred method: model path relative to this script file
                script_dir = os.path.dirname(os.path.abspath(__file__))
                absolute_model_path = os.path.join(script_dir, model_path)
            except NameError:
                # Fallback for environments where __file__ is not defined (e.g., Jupyter)
                print(f"RLManager Warning: __file__ not defined. Assuming model path '{model_path}' is relative to current working directory.")
                absolute_model_path = os.path.join(os.getcwd(), model_path)
        
        if not os.path.exists(absolute_model_path):
            raise FileNotFoundError(
                f"RLManager Error: Model file not found at {absolute_model_path}. "
                "Please ensure the model .zip file is correctly located."
            )
        
        self.model = PPO.load(absolute_model_path, device="cpu")
        print(f"RLManager: Successfully loaded model from {absolute_model_path}")

        self.frame_stack_size = 4
        self._frame_stack_buffer = deque(maxlen=self.frame_stack_size)

        dummy_unflattened_obs = self.observation_space_unflattened.sample()
        dummy_unflattened_obs["step"] = np.clip(dummy_unflattened_obs["step"], 0, NUM_ITERS - 1).astype(np.int8)
        self._single_flat_obs_shape = flatten(
            self.observation_space_unflattened, dummy_unflattened_obs
        ).shape
        print(f"RLManager: Shape of a single flattened observation: {self._single_flat_obs_shape}")

        self.reset_agent_state()
        print("RLManager: Initialization complete.")

    def _preprocess_observation(self, observation: dict) -> np.ndarray:
        processed_observation = observation.copy()
        if "step" in processed_observation:
            step_val = np.array(processed_observation["step"])
            processed_observation["step"] = np.clip(step_val, 0, NUM_ITERS - 1).astype(np.int8)
        else:
            print("Warning: 'step' not in observation. Adding default 0.")
            processed_observation["step"] = np.array([0], dtype=np.int8)

        if "location" in processed_observation and isinstance(processed_observation["location"], list):
            processed_observation["location"] = np.array(processed_observation["location"], dtype=np.int8)
        if "viewcone" in processed_observation and isinstance(processed_observation["viewcone"], list):
            processed_observation["viewcone"] = np.array(processed_observation["viewcone"], dtype=np.uint8)
        if "step" in processed_observation and isinstance(processed_observation["step"], int):
             processed_observation["step"] = np.array([processed_observation["step"]], dtype=np.int8)

        flat_obs = flatten(self.observation_space_unflattened, processed_observation)
        return flat_obs

    def reset_agent_state(self):
        print("RLManager: Resetting agent state (frame stack).")
        zero_flat_obs = np.zeros(self._single_flat_obs_shape, dtype=np.float32)
        self._frame_stack_buffer.clear()
        for _ in range(self.frame_stack_size):
            self._frame_stack_buffer.append(zero_flat_obs)

    def rl(self, observation: dict[str, int | list[int] | list[list[int]]]) -> int:
        current_flat_obs = self._preprocess_observation(observation)
        current_flat_obs = current_flat_obs.astype(np.float32)
        self._frame_stack_buffer.append(current_flat_obs)
        stacked_observation_list = list(self._frame_stack_buffer)
        model_input = np.concatenate(stacked_observation_list, axis=0)
        action, _ = self.model.predict(model_input, deterministic=True)
        return int(action)

# Example usage (for testing locally, not part of the deployed RLManager typically)
if __name__ == '__main__':
    print("RLManager self-test initiated.")
    # Ensure the model zip file is in the current working directory when running this self-test
    # in an environment where __file__ is not defined (like Jupyter).
    # Or, provide an absolute path to RLManager.
    # model_file_for_test = "til_ai_marl_ppo_rs_explore_v7_20250520-104436.zip"
    
    try:
        # If __file__ is not defined, this will assume the model_path is relative to os.getcwd()
        manager = RLManager() # Uses the default model path from the constructor

        sample_raw_observation = {
            "viewcone": np.random.randint(0, 256, (7, 5), dtype=np.uint8).tolist(),
            "direction": np.random.randint(0, 4),
            "scout": np.random.randint(0, 2),
            "location": np.random.randint(0, 16, (2,), dtype=np.int8).tolist(),
            "step": np.random.randint(0, NUM_ITERS)
        }
        print(f"\nSample raw observation: {sample_raw_observation}")

        for i in range(5):
            action = manager.rl(sample_raw_observation)
            print(f"Step {i+1}: Received observation, predicted action: {action}")
            sample_raw_observation["step"] = min(NUM_ITERS -1, sample_raw_observation["step"] + 1)
            sample_raw_observation["location"][0] = (sample_raw_observation["location"][0] + 1) % 16

        print("\nSimulating reset:")
        manager.reset_agent_state()
        action_after_reset = manager.rl(sample_raw_observation)
        print(f"Action after reset: {action_after_reset}")
        
        print("\nRLManager self-test completed.")

    except FileNotFoundError as e:
        print(f"RLManager Self-Test FileNotFoundError: {e}")
        print("Ensure the model file is in the correct location (e.g., current working directory for Jupyter test, or update path).")
    except Exception as e:
        print(f"An error occurred during RLManager self-test: {e}")
        import traceback
        traceback.print_exc()

RLManager self-test initiated.
RLManager: Observation space for flattening defined with NUM_ITERS=100.
RLManager: Successfully loaded model from /home/jupyter/Tim Testing/til_ai_marl_ppo_rs_explore_v7_20250520-104436.zip
RLManager: Shape of a single flattened observation: (44,)
RLManager: Resetting agent state (frame stack).
RLManager: Initialization complete.

Sample raw observation: {'viewcone': [[102, 220, 225, 95, 179], [61, 234, 203, 92, 3], [98, 243, 14, 149, 245], [46, 106, 244, 99, 187], [71, 212, 153, 199, 188], [174, 65, 153, 20, 44], [203, 152, 102, 214, 240]], 'direction': 1, 'scout': 0, 'location': [6, 4], 'step': 74}
An error occurred during RLManager self-test: Error: Unexpected observation shape (176,) for Box environment, please use (572,) or (n_env, 572) for the observation shape.


Traceback (most recent call last):
  File "/var/tmp/ipykernel_15068/2936968207.py", line 131, in <module>
    action = manager.rl(sample_raw_observation)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/tmp/ipykernel_15068/2936968207.py", line 106, in rl
    action, _ = self.model.predict(model_input, deterministic=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.12/site-packages/stable_baselines3/common/base_class.py", line 557, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.12/site-packages/stable_baselines3/common/policies.py", line 365, in predict
    obs_tensor, vectorized_env = self.obs_to_tensor(observation)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.12/site-packages/stable_basel