In [None]:
from __future__ import annotations

import glob
import os
import time
import functools
from functools import partial

import imageio
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# Import the TIL-AI environment
from til_environment import gridworld
from til_environment.gridworld import NUM_ITERS
from til_environment.flatten_dict import FlattenDictWrapper
from supersuit import frame_stack_v2
from til_environment.types import RewardNames, Action

from supersuit.vector.sb3_vector_wrapper import SB3VecEnvWrapper
import types

from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
from pettingzoo.utils.wrappers.base import BaseWrapper

# --- Custom Wrapper to Clip Step Observation ---
class ClipStepObservationWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
    """AEC wrapper that caps the ``"step"`` entry of dict observations.

    Any dict observation containing a ``"step"`` key has that value clamped to
    at most ``NUM_ITERS - 1``; every other observation passes through
    unchanged.

    NOTE(review): presumably the raw step counter can reach ``NUM_ITERS``,
    which would fall outside the step component of the observation space —
    confirm against til_environment.
    """

    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
        # Per-instance memo tables for the space lookups. They replace
        # ``functools.lru_cache`` on the methods: that cache lives on the class,
        # is keyed on ``self``, and keeps every wrapper (and the env it wraps)
        # alive for the whole process. Created before ``super().__init__`` so
        # they already exist if the base initializer queries a space.
        self._clip_obs_space_cache = {}
        self._clip_act_space_cache = {}
        super().__init__(env)

    def observe(self, agent: AgentID) -> ObsType | None:
        """Return ``agent``'s observation with ``"step"`` capped at ``NUM_ITERS - 1``.

        The dict is copied before the value is replaced, so an observation
        object held by the wrapped env is never modified by this wrapper.
        """
        obs = super().observe(agent)
        if isinstance(obs, dict) and "step" in obs:
            obs = obs.copy()
            obs["step"] = min(obs["step"], NUM_ITERS - 1)
        return obs

    def observation_space(self, agent: AgentID):
        """Return the wrapped env's observation space for ``agent``, memoized per wrapper."""
        try:
            return self._clip_obs_space_cache[agent]
        except KeyError:
            space = super().observation_space(agent)
            self._clip_obs_space_cache[agent] = space
            return space

    def action_space(self, agent: AgentID):
        """Return the wrapped env's action space for ``agent``, memoized per wrapper."""
        try:
            return self._clip_act_space_cache[agent]
        except KeyError:
            space = super().action_space(agent)
            self._clip_act_space_cache[agent] = space
            return space

# --- Reward Shaping Definition (v7) ---
# Shaped rewards passed to the environment. Built from an ordered list of
# (reward name, value) pairs; dict() keeps that order, so the summary printed
# below lists the entries in exactly this sequence.
rewards_dict = dict([
    (RewardNames.GUARD_CAPTURES, 50),
    (RewardNames.SCOUT_CAPTURED, -50),
    (RewardNames.SCOUT_RECON, 5),
    (RewardNames.SCOUT_MISSION, 20),
    (RewardNames.WALL_COLLISION, -0.5),
    (RewardNames.STATIONARY_PENALTY, -0.2),  # stronger penalty for staying still or spinning
    (RewardNames.SCOUT_STEP, -0.01),
    (RewardNames.GUARD_STEP, -0.02),  # stronger per-step penalty to push guards to move
    (RewardNames.GUARD_TRUNCATION, -10),
    (RewardNames.SCOUT_TRUNCATION, 0),
])
print(f"Using custom rewards_dict (v7): {rewards_dict}")

# --- Training Parameters ---
TOTAL_TIMESTEPS = 5 * 1_000_000  # larger training budget, aimed at better performance
RANDOM_SEED = 42
MODEL_SAVE_NAME = "til_ai_marl_ppo_rs_explore_v7"
NOVICE_MODE = True
LEARNING_RATE = 3e-4

# Wrappers handed to the TIL-AI env: clamp the step counter, flatten the dict
# observation, then stack 4 frames along the last axis.
# NOTE(review): presumably applied in list order by gridworld — confirm.
custom_env_wrappers_list = [
    ClipStepObservationWrapper,
    FlattenDictWrapper,
    partial(frame_stack_v2, stack_size=4, stack_dim=-1),
]

# Keyword arguments for building the (non-rendering) training env.
train_env_kwargs = dict(
    render_mode=None,
    novice=NOVICE_MODE,
    env_wrappers=custom_env_wrappers_list,
    rewards_dict=rewards_dict,
)

def create_training_env():
    """Build the parallel TIL-AI environment used for training.

    The env is created from ``train_env_kwargs`` and then wrapped with
    SuperSuit's ``black_death_v3``. Per SuperSuit's documentation, agents that
    are done are then kept with zeroed observations/rewards instead of being
    removed.
    """
    return ss.black_death_v3(gridworld.parallel_env(**train_env_kwargs))

print(f"Initializing training environment (Novice mode: {NOVICE_MODE})...")

# PettingZoo parallel env -> SuperSuit Markov vector env -> SB3 VecEnv.
# Each agent is exposed to SB3 as one sub-environment.
pz_env = create_training_env()
ss_markov_vec_env = ss.pettingzoo_env_to_vec_env_v1(pz_env)


def no_op_seed_for_markov_vec_env(self, seed=None):
    """Accept a seed and ignore it; bound as ``ss_markov_vec_env.seed`` below.

    NOTE(review): any seed passed here is dropped, so presumably
    ``PPO(seed=...)`` does not seed the environment itself — confirm this is
    acceptable for reproducibility.
    """
    return None


ss_markov_vec_env.seed = types.MethodType(no_op_seed_for_markov_vec_env, ss_markov_vec_env)
vec_env = SB3VecEnvWrapper(ss_markov_vec_env)

N_STEPS_PPO = 2048
# Sub-env count x rollout length: a single minibatch spans the whole rollout.
BATCH_SIZE_PPO = vec_env.num_envs * N_STEPS_PPO

print(f"Observation space type: {type(vec_env.observation_space)}")
print(f"Observation space shape: {vec_env.observation_space.shape}")
print(f"Observation space dtype: {vec_env.observation_space.dtype}")

print(f"Action space: {vec_env.action_space}")
print(f"Number of environments (agents passed to SB3): {vec_env.num_envs}")

# --- Model Training ---
print(f"Starting training with PPO and MlpPolicy for {TOTAL_TIMESTEPS} timesteps.")

# Everything that uses vec_env runs inside try/finally so the env is closed
# even if construction, training (e.g. a notebook KeyboardInterrupt) or saving
# raises; the exception still propagates after cleanup.
try:
    model = PPO(
        MlpPolicy,
        vec_env,
        verbose=0,  # Set to 0 to suppress logging output
        seed=RANDOM_SEED,
        learning_rate=LEARNING_RATE,
        n_steps=N_STEPS_PPO,
        batch_size=BATCH_SIZE_PPO,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        ent_coef=0.001,
        tensorboard_log="./til_marl_tensorboard_rs_explore_v7/",
        device="cpu",  # Change to "cuda" if GPU is available and configured
    )

    model.learn(total_timesteps=TOTAL_TIMESTEPS, progress_bar=True)

    # --- Save Model ---
    # Timestamped name so successive runs do not overwrite each other.
    model_filename = f"{MODEL_SAVE_NAME}_{time.strftime('%Y%m%d-%H%M%S')}.zip"
    model_path = os.path.join(".", model_filename)
    model.save(model_path)
    print(f"Model saved to {model_path}")
finally:
    vec_env.close()

print("Training finished.")

# --- Evaluation Function ---
def _find_latest_model_path() -> str | None:
    """Return the newest ``./{MODEL_SAVE_NAME}*.zip`` by modification time, or None if none exist."""
    candidates = glob.glob(f"./{MODEL_SAVE_NAME}*.zip")
    if not candidates:
        return None
    # getmtime rather than getctime: on POSIX ctime is the inode-change time,
    # not the creation time, so mtime is the portable "most recently written".
    return max(candidates, key=os.path.getmtime)


def _play_eval_game(eval_env, loaded_model, game_num, render_mode, print_actions):
    """Play one seeded game with ``loaded_model`` choosing every agent's action.

    Returns:
        ``(rewards, action_counts, frames)``: per-agent reward totals for the
        game, a count for each ``Action`` member, and the rendered frames
        (collected only when ``render_mode == "rgb_array"``).
    """
    # Each game gets its own fixed seed, offset from RANDOM_SEED.
    eval_env.reset(seed=RANDOM_SEED + game_num + 1000)

    rewards = {agent: 0.0 for agent in eval_env.possible_agents}
    action_counts = {act: 0 for act in Action}
    frames = []

    for agent_id in eval_env.agent_iter():
        observation, reward, termination, truncation, _info = eval_env.last()
        rewards[agent_id] += reward

        if termination or truncation:
            # Finished agents are still stepped, with a None action.
            action = None
        else:
            predicted, _ = loaded_model.predict(observation, deterministic=True)
            # predict() returns a NumPy array; give the env and the Action
            # enum a plain int instead.
            action = int(predicted)

        eval_env.step(action)

        if action is None:
            continue

        action_member = Action(action)
        action_counts[action_member] += 1
        if print_actions:
            print(f"Game {game_num+1}, Agent {agent_id}, Action: {action_member.name} ({action})")

        if render_mode == "rgb_array":
            try:
                frame = eval_env.render()
            except Exception as e:
                # Best effort: one frame failing to render must not abort evaluation.
                print(f"Warning: Could not render frame for game {game_num+1}, agent {agent_id}: {e}")
            else:
                if frame is not None:
                    frames.append(frame)

    return rewards, action_counts, frames


def _save_game_video(frames, video_folder, game_num, fps):
    """Write ``frames`` to an .mp4 for game ``game_num``; errors are printed, not raised."""
    video_path = os.path.join(video_folder, f"{MODEL_SAVE_NAME}_game_{game_num}.mp4")
    try:
        imageio.mimsave(video_path, frames, fps=fps)
        print(f"Saved video of game {game_num+1} to {video_path}")
    except Exception as e:
        print(f"Error saving video for game {game_num+1}: {e}")


def evaluate_til_ai_marl(
    num_games: int = 10,
    render_mode: str | None = None,
    novice_eval: bool = True,
    model_to_load_path: str | None = None,
    eval_device: str = "cpu",
    print_actions: bool = False
):
    """Play ``num_games`` games with a saved PPO policy controlling every agent.

    Args:
        num_games: Number of games to play; must be at least 1.
        render_mode: Passed to ``gridworld.env``. With ``"rgb_array"`` each
            game's frames are written as a video under
            ``./logs/videos_marl_rs_explore_v7``.
        novice_eval: Passed to ``gridworld.env`` as ``novice``.
        model_to_load_path: Model file to load. When None, the most recently
            modified ``./{MODEL_SAVE_NAME}*.zip`` is used; if there is none,
            a message is printed and the function returns.
        eval_device: Device passed to ``PPO.load``.
        print_actions: Print every action taken.

    Raises:
        ValueError: If ``num_games`` is less than 1 (averages would be undefined).

    Rewards come from the shaped training ``rewards_dict``, so the printed
    averages are not official scores.
    """
    if num_games < 1:
        raise ValueError(f"num_games must be at least 1, got {num_games}")

    print(f"\nStarting evaluation (Novice: {novice_eval}, Games: {num_games}, Render: {render_mode}, Device: {eval_device})...")

    eval_env_kwargs = {
        "render_mode": render_mode,
        "novice": novice_eval,
        "env_wrappers": custom_env_wrappers_list,
        "rewards_dict": rewards_dict,
    }
    eval_env = gridworld.env(**eval_env_kwargs)

    # The env is closed on every exit path: early return, load failure, or an
    # error mid-game.
    try:
        if model_to_load_path is None:
            model_to_load_path = _find_latest_model_path()
            if model_to_load_path is None:
                print(f"No trained model found matching pattern ./{MODEL_SAVE_NAME}*.zip. Exiting evaluation.")
                return
            print(f"Loading latest model for evaluation: {model_to_load_path}")

        loaded_model = PPO.load(model_to_load_path, device=eval_device)

        agents = list(eval_env.possible_agents)
        total_rewards_all_games = {agent: 0.0 for agent in agents}

        video_folder = "./logs/videos_marl_rs_explore_v7"
        os.makedirs(video_folder, exist_ok=True)

        for game_num in range(num_games):
            game_rewards, action_counts, frames = _play_eval_game(
                eval_env, loaded_model, game_num, render_mode, print_actions
            )
            for agent in agents:
                total_rewards_all_games[agent] += game_rewards[agent]

            print(f"Game {game_num + 1} finished. Rewards this game: {game_rewards}")
            print(f"Action distribution for Game {game_num + 1}: { {act.name: count for act, count in action_counts.items()} }")

            if render_mode == "rgb_array" and frames:
                _save_game_video(
                    frames, video_folder, game_num, eval_env.metadata.get("render_fps", 10)
                )
    finally:
        eval_env.close()

    print("\n--- Evaluation Summary ---")
    avg_rewards_per_agent = {
        agent: total / num_games for agent, total in total_rewards_all_games.items()
    }
    print(f"Average rewards per agent over {num_games} games: {avg_rewards_per_agent}")

    if avg_rewards_per_agent:
        team_avg_reward = sum(avg_rewards_per_agent.values())
        print(f"Sum of average rewards for all agents (team perspective): {team_avg_reward:.4f}")
        print("Note: The avg_rewards_per_agent is based on the training rewards_dict values.")
        print("A true qualifier score would need evaluation with official reward values.")

# --- Run Evaluation ---
if __name__ == "__main__":
    # Pass 1: a couple of text-only games, printing every action taken.
    evaluate_til_ai_marl(
        model_to_load_path=model_path,
        novice_eval=NOVICE_MODE,
        eval_device="cpu",
        num_games=2,
        render_mode=None,
        print_actions=True,
    )

    # Pass 2: a single rendered game whose frames are saved as a video.
    evaluate_til_ai_marl(
        model_to_load_path=model_path,
        novice_eval=NOVICE_MODE,
        eval_device="cpu",
        num_games=1,
        render_mode="rgb_array",
        print_actions=False,
    )

    print(f"\nTo view videos, check the '{os.path.abspath('./logs/videos_marl_rs_explore_v7/')}' directory.")
    print("To view TensorBoard logs (if enabled and tensorboard installed), run: tensorboard --logdir ./til_marl_tensorboard_rs_explore_v7/")

Using custom rewards_dict (v7): {<RewardNames.GUARD_CAPTURES: 'guard_captures'>: 50, <RewardNames.SCOUT_CAPTURED: 'scout_captured'>: -50, <RewardNames.SCOUT_RECON: 'scout_recon'>: 5, <RewardNames.SCOUT_MISSION: 'scout_mission'>: 20, <RewardNames.WALL_COLLISION: 'wall_collision'>: -0.5, <RewardNames.STATIONARY_PENALTY: 'stationary_penalty'>: -0.2, <RewardNames.SCOUT_STEP: 'scout_step'>: -0.01, <RewardNames.GUARD_STEP: 'guard_step'>: -0.02, <RewardNames.GUARD_TRUNCATION: 'guard_truncation'>: -10, <RewardNames.SCOUT_TRUNCATION: 'scout_truncation'>: 0}
Initializing training environment (Novice mode: True)...
Observation space type: <class 'gymnasium.spaces.box.Box'>
Observation space shape: (572,)
Observation space dtype: int64
Action space: Discrete(5)
Number of environments (agents passed to SB3): 4
Starting training with PPO and MlpPolicy for 5000000 timesteps.


Output()

In [None]:
# IPython magic: load the TensorBoard notebook extension, which provides the
# %tensorboard magic.
%load_ext tensorboard

In [None]:
%tensorboard --logdir ./til_marl_tensorboard_rs_explore_v6/

In [None]:
# Scratch/smoke-test cell: writes a fixed greeting line to stdout.
print("Hello World", end="\n")