In [3]:
from __future__ import annotations

import glob
import os
import time
import functools # Make sure functools is imported
from functools import partial

import imageio
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# Import the TIL-AI environment
from til_environment import gridworld
from til_environment.gridworld import NUM_ITERS # Import NUM_ITERS
from til_environment.flatten_dict import FlattenDictWrapper # For explicit wrapper list
from supersuit import frame_stack_v2 # For explicit wrapper list

# Import the SB3 wrapper from SuperSuit and types for monkey-patching
from supersuit.vector.sb3_vector_wrapper import SB3VecEnvWrapper
import types

# For custom wrapper type hints
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
from pettingzoo.utils.wrappers.base import BaseWrapper


# --- Custom Wrapper to Clip Step Observation ---
class ClipStepObservationWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
    """AEC wrapper that clamps the ``"step"`` observation entry to ``NUM_ITERS - 1``.

    Only dict observations that contain a ``"step"`` key are modified; every
    other observation is passed through unchanged.

    NOTE(review): presumably the env can report ``step == NUM_ITERS`` on its
    final turn, one past the range its observation space declares — confirm
    against ``til_environment.gridworld``.
    """

    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
        super().__init__(env)
        # Per-instance memo of the delegated spaces, so each agent always gets
        # the same space object back. This replaces ``functools.lru_cache`` on
        # the methods, whose class-level cache is keyed on ``self`` and keeps
        # every wrapper instance alive for the life of the process (ruff B019).
        self._observation_space_cache = {}
        self._action_space_cache = {}

    def observe(self, agent: AgentID) -> ObsType | None:
        """Return the wrapped env's observation for ``agent`` with ``"step"`` clipped."""
        obs = super().observe(agent)
        if isinstance(obs, dict) and "step" in obs:
            # Clip a shallow copy (``.copy()`` keeps dict subclasses such as
            # OrderedDict) so the object handed out by the wrapped env is never
            # mutated in place.
            obs = obs.copy()
            obs["step"] = min(obs["step"], NUM_ITERS - 1)
        return obs

    def observation_space(self, agent: AgentID):
        """Return the wrapped env's observation space for ``agent`` (memoised per instance)."""
        if agent not in self._observation_space_cache:
            self._observation_space_cache[agent] = super().observation_space(agent)
        return self._observation_space_cache[agent]

    def action_space(self, agent: AgentID):
        """Return the wrapped env's action space for ``agent`` (memoised per instance)."""
        if agent not in self._action_space_cache:
            self._action_space_cache[agent] = super().action_space(agent)
        return self._action_space_cache[agent]


# --- Training Parameters ---
# Seed handed to PPO; evaluation also derives its per-game reset seeds from it.
RANDOM_SEED = 42
# Total environment steps collected by PPO.learn(), counted across all agents.
TOTAL_TIMESTEPS = 100000
LEARNING_RATE = 3e-4  # PPO optimizer step size
# Passed to gridworld as its ``novice`` flag for both training and evaluation.
NOVICE_MODE = True
# Prefix for saved model archives; evaluation globs for "<prefix>*.zip".
MODEL_SAVE_NAME = "til_ai_marl_ppo"


# --- Environment Setup ---
# Wrapper list handed to gridworld via ``env_wrappers``; shared by training and
# evaluation so both see identically shaped observations.
# NOTE(review): presumably applied in list order by gridworld — confirm.
custom_env_wrappers_list = [
    ClipStepObservationWrapper,
    FlattenDictWrapper,
    partial(frame_stack_v2, stack_size=4, stack_dim=-1),
]

train_env_kwargs = dict(
    render_mode=None,
    novice=NOVICE_MODE,
    env_wrappers=custom_env_wrappers_list,
)


def create_training_env():
    """Build the parallel gridworld env used for training, wrapped with ``ss.black_death_v3``."""
    base_env = gridworld.parallel_env(**train_env_kwargs)
    return ss.black_death_v3(base_env)

# Build the PettingZoo parallel env and expose it as a SuperSuit vector env in
# which each agent is one sub-environment, so SB3 trains a single shared policy
# on all agents (see the "Number of environments" print below).
print(f"Initializing training environment (Novice mode: {NOVICE_MODE})...")

pz_env = create_training_env()
ss_markov_vec_env = ss.pettingzoo_env_to_vec_env_v1(pz_env)

# Workaround: replace the vector env's ``seed`` method with a no-op, presumably
# because the original call fails when SB3 seeds the env (PPO is given
# ``seed=RANDOM_SEED`` below) — confirm.
# NOTE(review): as a consequence RANDOM_SEED likely never reaches the
# environment itself, so env rollouts may not be reproducible.
def no_op_seed_for_markov_vec_env(self, seed=None):
    pass
ss_markov_vec_env.seed = types.MethodType(no_op_seed_for_markov_vec_env, ss_markov_vec_env)

# Adapt the SuperSuit vector env to Stable-Baselines3's VecEnv interface.
# The patch above must be applied before this wrapping.
vec_env = SB3VecEnvWrapper(ss_markov_vec_env)

# PPO rollout length per sub-env. BATCH_SIZE_PPO equals the whole rollout
# buffer (n_steps * num_envs), so each epoch is a single full-batch update.
N_STEPS_PPO = 2048
BATCH_SIZE_PPO = N_STEPS_PPO * vec_env.num_envs

# Print the observation space piecewise: printing the space object itself was
# observed to raise RecursionError.
print(f"Observation space type: {type(vec_env.observation_space)}")
print(f"Observation space shape: {vec_env.observation_space.shape}")
print(f"Observation space dtype: {vec_env.observation_space.dtype}")

print(f"Action space: {vec_env.action_space}") # Action spaces print without that problem
print(f"Number of environments (agents passed to SB3): {vec_env.num_envs}")


# --- Model Training ---
print(f"Starting training with PPO and MlpPolicy for {TOTAL_TIMESTEPS} timesteps.")

# Optimisation hyper-parameters, kept apart from run/bookkeeping options.
ppo_hyperparams = dict(
    learning_rate=LEARNING_RATE,
    n_steps=N_STEPS_PPO,
    batch_size=BATCH_SIZE_PPO,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
)
model = PPO(
    MlpPolicy,
    vec_env,
    seed=RANDOM_SEED,
    verbose=1,
    device="cpu",
    tensorboard_log="./til_marl_tensorboard/",
    **ppo_hyperparams,
)
model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)

# --- Save Model ---
# A timestamp suffix keeps earlier runs; evaluation picks up the newest archive.
timestamp = time.strftime("%Y%m%d-%H%M%S")
model_filename = f"{MODEL_SAVE_NAME}_{timestamp}.zip"
model_path = os.path.join(".", model_filename)
model.save(model_path)
print(f"Model saved to {model_path}")

# Evaluation builds its own AEC env, so the training env can be released now.
vec_env.close()
print("Training finished.")

# --- Evaluation Function ---
def evaluate_til_ai_marl(
    num_games: int = 10,
    render_mode: str | None = None,
    novice_eval: bool = True,
    model_to_load_path: str | None = None
):
    print(f"\nStarting evaluation (Novice: {novice_eval}, Games: {num_games}, Render: {render_mode})...")

    eval_env_kwargs = {
        "render_mode": render_mode,
        "novice": novice_eval,
        "env_wrappers": custom_env_wrappers_list,
    }
    
    eval_env = gridworld.env(**eval_env_kwargs)

    if model_to_load_path is None:
        try:
            list_of_models = glob.glob(f"./{MODEL_SAVE_NAME}*.zip")
            if not list_of_models:
                print("No trained model found for evaluation. Exiting evaluation.")
                eval_env.close()
                return
            model_to_load_path = max(list_of_models, key=os.path.getctime)
            print(f"Loading latest model for evaluation: {model_to_load_path}")
        except ValueError:
            print("Could not find a model to load. Exiting evaluation.")
            eval_env.close()
            return
            
    loaded_model = PPO.load(model_to_load_path, device="cpu")

    total_rewards_all_games = {agent: 0.0 for agent in eval_env.possible_agents}
    
    video_folder = "./logs/videos_marl"
    os.makedirs(video_folder, exist_ok=True)

    for game_num in range(num_games):
        eval_env.reset(seed=RANDOM_SEED + game_num + 1000)
        
        current_game_frames = []
        game_rewards_this_round = {agent: 0.0 for agent in eval_env.possible_agents}

        for agent_id in eval_env.agent_iter():
            observation, reward, termination, truncation, info = eval_env.last()
            
            game_rewards_this_round[agent_id] += reward

            if termination or truncation:
                action = None
            else:
                action, _ = loaded_model.predict(observation, deterministic=True)
            
            eval_env.step(action)

            if render_mode == "rgb_array" and action is not None:
                try:
                    frame = eval_env.render()
                    if frame is not None:
                        current_game_frames.append(frame)
                except Exception as e:
                    print(f"Warning: Could not render frame for game {game_num}, agent {agent_id}: {e}")
        
        for agent_id_sum in eval_env.possible_agents:
            total_rewards_all_games[agent_id_sum] += game_rewards_this_round[agent_id_sum]
        
        print(f"Game {game_num + 1} finished. Rewards this game: {game_rewards_this_round}")

        if render_mode == "rgb_array" and current_game_frames:
            video_path = os.path.join(video_folder, f"{MODEL_SAVE_NAME}_game_{game_num}.mp4")
            try:
                imageio.mimsave(video_path, current_game_frames, fps=eval_env.metadata.get("render_fps", 10))
                print(f"Saved video of game {game_num} to {video_path}")
            except Exception as e:
                print(f"Error saving video for game {game_num}: {e}")
    
    eval_env.close()

    print("\n--- Evaluation Summary ---")
    avg_rewards_per_agent = {
        agent: total_rewards_all_games[agent] / num_games for agent in eval_env.possible_agents
    }
    print(f"Average rewards per agent over {num_games} games: {avg_rewards_per_agent}")
    
    if eval_env.possible_agents:
        team_avg_reward = sum(avg_rewards_per_agent.values())
        print(f"Sum of average rewards for all agents (team perspective): {team_avg_reward:.4f}")
        
        main_agent_example = eval_env.possible_agents[0]
        til_score_example = avg_rewards_per_agent[main_agent_example] / 100
        print(f"Example TIL-AI style score for '{main_agent_example}' (avg per game / 100): {til_score_example:.4f}")

# --- Run Evaluation ---
if __name__ == "__main__":
    # Two passes over the model trained above: a 5-game headless run for reward
    # statistics, then a single rendered ("rgb_array") game that is saved as video.
    for games, mode in ((5, None), (1, "rgb_array")):
        evaluate_til_ai_marl(
            num_games=games,
            render_mode=mode,
            novice_eval=NOVICE_MODE,
            model_to_load_path=model_path,
        )

    videos_dir = os.path.abspath('./logs/videos_marl/')
    print(f"\nTo view videos, check the '{videos_dir}' directory.")
    print("To view TensorBoard logs (if enabled and tensorboard installed), run: tensorboard --logdir ./til_marl_tensorboard/")

Initializing training environment (Novice mode: True)...
Observation space type: <class 'gymnasium.spaces.box.Box'>
Observation space shape: (572,)
Observation space dtype: int64
Action space: Discrete(5)
Number of environments (agents passed to SB3): 4
Starting training with PPO and MlpPolicy for 100000 timesteps.
Using cpu device
Logging to ./til_marl_tensorboard/PPO_1


-----------------------------
| time/              |      |
|    fps             | 647  |
|    iterations      | 1    |
|    time_elapsed    | 12   |
|    total_timesteps | 8192 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 636           |
|    iterations           | 2             |
|    time_elapsed         | 25            |
|    total_timesteps      | 16384         |
| train/                  |               |
|    approx_kl            | 0.00037209538 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1.61         |
|    explained_variance   | -0.155        |
|    learning_rate        | 0.0003        |
|    loss                 | 1.91          |
|    n_updates            | 10            |
|    policy_gradient_loss | -0.00137      |
|    value_loss           | 4.22          |
------------------------------------------

Output()

------------------------------------------
| time/                   |              |
|    fps                  | 624          |
|    iterations           | 3            |
|    time_elapsed         | 39           |
|    total_timesteps      | 24576        |
| train/                  |              |
|    approx_kl            | 0.0002950217 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.61        |
|    explained_variance   | -0.065       |
|    learning_rate        | 0.0003       |
|    loss                 | 1.98         |
|    n_updates            | 20           |
|    policy_gradient_loss | -0.00136     |
|    value_loss           | 4.2          |
------------------------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 615           |
|    iterations           | 4             |
|    time_elapsed         | 53            |
|    t

Model saved to ./til_ai_marl_ppo_20250519-154146.zip
Training finished.

Starting evaluation (Novice: True, Games: 5, Render: None)...
Game 1 finished. Rewards this game: {'player_0': 7.0, 'player_1': 0.0, 'player_2': 0.0, 'player_3': 0.0}
Game 2 finished. Rewards this game: {'player_0': 0.0, 'player_1': 7.0, 'player_2': 0.0, 'player_3': 0.0}
Game 3 finished. Rewards this game: {'player_0': 0.0, 'player_1': 0.0, 'player_2': 7.0, 'player_3': 0.0}
Game 4 finished. Rewards this game: {'player_0': 0.0, 'player_1': 0.0, 'player_2': 0.0, 'player_3': 7.0}
Game 5 finished. Rewards this game: {'player_0': 7.0, 'player_1': 0.0, 'player_2': 0.0, 'player_3': 0.0}

--- Evaluation Summary ---
Average rewards per agent over 5 games: {'player_0': 2.8, 'player_1': 1.4, 'player_2': 1.4, 'player_3': 1.4}
Sum of average rewards for all agents (team perspective): 7.0000
Example TIL-AI style score for 'player_0' (avg per game / 100): 0.0280

Starting evaluation (Novice: True, Games: 1, Render: rgb_array)...

In [4]:
# Load the TensorBoard notebook extension and show the dashboard for the
# directory that PPO's ``tensorboard_log`` points at in the training cell.
%load_ext tensorboard
%tensorboard --logdir ./til_marl_tensorboard/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 33177), started 0:06:35 ago. (Use '!kill 33177' to kill it.)

In [1]:
!pip install tensorboard --force-reinstall

Collecting tensorboard
  Using cached tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting absl-py>=0.4 (from tensorboard)
  Using cached absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting grpcio>=1.48.2 (from tensorboard)
  Downloading grpcio-1.71.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting markdown>=2.6.8 (from tensorboard)
  Using cached markdown-3.8-py3-none-any.whl.metadata (5.1 kB)
Collecting numpy>=1.12.0 (from tensorboard)
  Downloading numpy-2.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting packaging (from tensorboard)
  Using cached packaging-25.0-py3-none-any.whl.metadata (3.3 kB)
Collecting protobuf!=4.24.0,>=3.19.6 (from tensorboard)
  Using cached protobuf-6.31.0-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Collecting setuptools>=41.0.0 (from tensorboard)
  Using cached setuptools-80.7.1-py3-none-any.whl.metadata (6.6 kB)
Collecting six>1.9 (from tensor