In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
import imageio
from collections import deque, namedtuple
import time # Added for timing

# Assuming til_environment.gridworld and RewardNames are correctly importable
from til_environment.gridworld import RewardNames

import functools
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
from pettingzoo.utils.wrappers.base import BaseWrapper

# --- Configuration ---
# Environment specific
# MAP_SIZE_X / MAP_SIZE_Y normalise the location feature and MAX_STEPS_PER_EPISODE
# normalises the step feature in TrainableRLAgent.process_observation().
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
MAX_STEPS_PER_EPISODE = 100

# --- CNNDQN Model Hyperparameters (from your CNNDQN.py) ---
# Each viewcone tile is an integer whose low VIEWCONE_CHANNELS bits are unpacked
# into separate binary input channels over a VIEWCONE_HEIGHT x VIEWCONE_WIDTH grid.
VIEWCONE_CHANNELS = 8
VIEWCONE_HEIGHT = 7
VIEWCONE_WIDTH = 5
OTHER_FEATURES_SIZE = 4 + 2 + 1 + 1 # Direction (4) + Location (2) + Scout (1) + Step (1)

# CNNDQN builds both conv layers with padding=1, so 3x3 kernels with stride 1
# keep the viewcone's spatial size unchanged.
CNN_OUTPUT_CHANNELS_1 = 16
CNN_OUTPUT_CHANNELS_2 = 32
KERNEL_SIZE_1 = (3, 3)
STRIDE_1 = 1
KERNEL_SIZE_2 = (3, 3)
STRIDE_2 = 1
MLP_HIDDEN_LAYER_1_SIZE = 128
MLP_HIDDEN_LAYER_2_SIZE = 128
OUTPUT_ACTIONS = 5 # Number of discrete actions (one Q-value per action)
DROPOUT_RATE = 0.2 # Active during training; model.eval() handles it for inference

# Training Hyperparameters
BUFFER_SIZE = int(1e5) # Replay buffer (SumTree) capacity, in transitions
BATCH_SIZE = 32 # Reduced from 64 due to potentially larger state
GAMMA = 0.99 # Discount factor used in the Double-DQN target
LEARNING_RATE = 1e-4 # Adam learning rate
# Counted in learning updates (TrainableRLAgent.global_steps_for_target_update), not env steps.
TARGET_UPDATE_EVERY = 1000 # Copy policy_net -> target_net every this many learning updates
# A learning update is attempted every UPDATE_EVERY transitions stored by the
# learning agent; the counter is not reset between episodes.
UPDATE_EVERY = 4 # Stored transitions between learning updates

# Epsilon-greedy exploration (for guard training)
# Epsilon is decayed once per episode: eps = max(EPSILON_END, EPSILON_DECAY * eps).
EPSILON_START = 0.5 # Start higher for guards to explore
EPSILON_END = 0.05  # End lower but still some exploration
EPSILON_DECAY = 0.9995 # Slower decay
# Decay only begins after the guard has performed more than this many learning
# updates (compared against global_steps_for_target_update, not raw frames).
MIN_EPSILON_FRAMES_GUARD = int(5e3) # Learning updates before epsilon decay starts

SCOUT_EPSILON_INFERENCE = 0.01 # Small epsilon for the pre-trained scout

# PER Parameters
PER_ALPHA = 0.6 # Priority exponent: priority = (|TD error| + PER_EPSILON) ** PER_ALPHA
PER_BETA_START = 0.4 # Initial importance-sampling exponent
PER_BETA_FRAMES = int(1e5) # Sampling calls over which beta is annealed linearly to 1.0
PER_EPSILON = 1e-6 # Keeps every priority strictly positive

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reward shaping (ensure this aligns with guard objectives)
# Added to a guard transition the first time ANY guard reaches a location in the
# current episode (the visited set is shared by all guards).
EXPLORATION_BONUS_REWARD = 0.5 # Reduced bonus
# Passed to the environment as `rewards_dict`; how each entry is distributed to
# agents is decided by the environment, not by this script.
CUSTOM_REWARDS_DICT = {
    RewardNames.SCOUT_MISSION: 0,
    RewardNames.SCOUT_RECON: 0,
    RewardNames.WALL_COLLISION: -10.0,
    RewardNames.AGENT_COLLIDER: -1.0,
    RewardNames.AGENT_COLLIDEE: -1.0,
    RewardNames.STATIONARY_PENALTY: -8,
    # Guard specific rewards should be positive for them
    RewardNames.GUARD_WINS: 50.0,          # Guards win (scout captured or mission failed by scout)
    RewardNames.GUARD_CAPTURES: 100.0,     # A guard directly captures the scout
    RewardNames.GUARD_TRUNCATION: 20.0,    # Guards successfully prevent scout mission for full duration
    RewardNames.GUARD_STEP: -0.05,          # Small step penalty for guards to encourage efficiency
    # Scout penalties from guard perspective (implicitly good for guards)
    RewardNames.SCOUT_CAPTURED: 100.0,     # NOTE(review): presumably paid to the scout itself, not the guards - confirm; the scout is not trained in this script
}

# Model file paths
# The scout checkpoint is also used to initialise the guards when no guard checkpoint exists.
INITIAL_SCOUT_MODEL_PATH = "my_wargame_cnn_agent_35500.pth"
SCOUT_MODEL_SAVE_PATH = "scout.pth" # Scout model will be re-saved but not trained here
GUARD_MODEL_SAVE_PATH = "guard.pth"


class CustomWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
    """Pass-through PettingZoo wrapper kept as a hook for future customisation.

    No methods are overridden, so reset/step/observe/observation_space all come
    from BaseWrapper unchanged; override them here to add custom logic.
    """

    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
        super(CustomWrapper, self).__init__(env)

# --- SumTree (Unchanged) ---
class SumTree:
    """Array-backed binary sum tree used for proportional prioritized sampling.

    Leaves (indices ``capacity - 1`` .. ``2 * capacity - 2``) hold priorities and
    every internal node holds the sum of its children, so ``tree[0]`` is the total
    priority. ``data[i]`` is the payload stored at leaf ``i + capacity - 1``;
    never-filled slots hold ``None``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # None (not 0) marks empty slots so callers can reliably tell them apart from data.
        self.data = np.full(capacity, None, dtype=object)
        self.data_pointer = 0
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` with ``priority``, overwriting the oldest slot once full."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and propagate the change to the root."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf covering cumulative ``value``.

        Zero-priority (never-filled) subtrees are never entered while a sibling
        has priority mass. This guards against floating-point drift in the
        accumulated sums (``value`` slightly above a subtree's real total) and
        against ``value == 0`` when the left subtree is still empty, which happens
        when ``capacity`` is not a power of two.
        """
        parent_idx = 0
        while True:
            left_child_idx = 2 * parent_idx + 1
            right_child_idx = left_child_idx + 1
            if left_child_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            left_sum = self.tree[left_child_idx]
            right_sum = self.tree[right_child_idx]
            if right_sum <= 0 or (left_sum > 0 and value <= left_sum):
                parent_idx = left_child_idx
            else:
                value -= left_sum
                parent_idx = right_child_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all stored priorities (the root node)."""
        return self.tree[0]

# --- Prioritized Replay Buffer (Adapted for two-part state) ---
Experience = namedtuple("Experience", field_names=["state_viewcone", "state_other", "action", "reward", "next_state_viewcone", "next_state_other", "done"])


class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay backed by a SumTree.

    New transitions are stored at the largest priority seen so far. Priorities
    written back by ``update_priorities`` are ``(|td_error| + PER_EPSILON) ** alpha``.
    """

    def __init__(self, capacity, alpha=PER_ALPHA):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store one transition at the current maximum priority."""
        experience = Experience(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.tree.add(self.max_priority, experience)

    @staticmethod
    def _empty_sample():
        """Placeholder result returned when nothing can be sampled."""
        return tuple(torch.empty(0) for _ in range(7)), np.array([]), torch.empty(0)

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw up to ``batch_size`` transitions with stratified priority sampling.

        Returns:
            ``(batch, tree_indices, weights)``:
            - ``batch``: 7-tuple of tensors on DEVICE (viewcones, other features,
              actions (B, 1) long, rewards (B, 1), next viewcones, next other
              features, dones (B, 1) float).
            - ``tree_indices``: int32 SumTree leaf indices for ``update_priorities``.
            - ``weights``: (B,) max-normalised importance-sampling weights.

            All three have the same number of rows B. Slots that hit a
            never-filled leaf are dropped from all three together. An
            all-empty result is returned when the buffer is empty.
        """
        if self.tree.n_entries == 0:
            return self._empty_sample()

        total_priority = self.tree.total_priority
        priority_segment = total_priority / batch_size
        indices, experiences, weights = [], [], []
        for i in range(batch_size):
            # One uniform draw inside each of batch_size equal priority segments.
            value = np.random.uniform(priority_segment * i, priority_segment * (i + 1))
            index, priority, data = self.tree.get_leaf(value)
            # Unfilled slots hold a placeholder (0 or None), never an Experience.
            # Skip them here so indices, weights and data stay row-aligned.
            if not isinstance(data, Experience):
                continue
            sampling_probability = priority / total_priority if total_priority > 0 else 0
            weights.append(np.power(self.tree.n_entries * sampling_probability + 1e-8, -beta))
            indices.append(index)
            experiences.append(data)

        if not experiences:
            return self._empty_sample()

        weights = np.asarray(weights, dtype=np.float32)
        weights /= (weights.max() if weights.max() > 0 else 1.0)

        batch = Experience(*zip(*experiences))
        states_vc = torch.from_numpy(np.array(batch.state_viewcone)).float().to(DEVICE)
        states_o = torch.from_numpy(np.array(batch.state_other)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(batch.action)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(batch.reward)).float().to(DEVICE)
        next_states_vc = torch.from_numpy(np.array(batch.next_state_viewcone)).float().to(DEVICE)
        next_states_o = torch.from_numpy(np.array(batch.next_state_other)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(batch.done).astype(np.uint8)).float().to(DEVICE)

        return ((states_vc, states_o, actions, rewards, next_states_vc, next_states_o, dones),
                np.asarray(indices, dtype=np.int32),
                torch.from_numpy(weights).float().to(DEVICE))

    def update_priorities(self, batch_indices, td_errors):
        """Write back ``(|td_error| + PER_EPSILON) ** alpha`` for the given leaves."""
        if len(batch_indices) == 0:
            return
        priorities = np.abs(td_errors) + PER_EPSILON
        priorities = np.power(priorities, self.alpha)
        for idx, priority_val in zip(batch_indices, priorities):
            self.tree.update(idx, priority_val)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, np.max(priorities))

    def __len__(self):
        return self.tree.n_entries

# --- CNNDQN Model (from your CNNDQN.py) ---
class CNNDQN(nn.Module):
    """Q-network: two padded conv layers over the viewcone, then a 2-layer MLP.

    The flattened conv features are concatenated with the non-spatial feature
    vector before the MLP, which outputs one Q-value per action.
    """

    def __init__(self, viewcone_channels, viewcone_height, viewcone_width, other_features_size, mlp_hidden1, mlp_hidden2, num_actions, dropout_rate):
        super().__init__()
        pad = 1

        def conv_out(size, kernel, stride):
            # Conv2d output-size formula (dilation 1) for the padding used above.
            return (size + 2 * pad - kernel) // stride + 1

        # Submodules are created in the same order and under the same attribute
        # names as before, keeping state_dict keys and init order stable.
        self.conv1 = nn.Conv2d(viewcone_channels, CNN_OUTPUT_CHANNELS_1, kernel_size=KERNEL_SIZE_1, stride=STRIDE_1, padding=pad)
        self.relu_conv1 = nn.ReLU()
        self.conv2 = nn.Conv2d(CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2, kernel_size=KERNEL_SIZE_2, stride=STRIDE_2, padding=pad)
        self.relu_conv2 = nn.ReLU()

        final_h = conv_out(conv_out(viewcone_height, KERNEL_SIZE_1[0], STRIDE_1), KERNEL_SIZE_2[0], STRIDE_2)
        final_w = conv_out(conv_out(viewcone_width, KERNEL_SIZE_1[1], STRIDE_1), KERNEL_SIZE_2[1], STRIDE_2)
        self.cnn_output_flat_size = CNN_OUTPUT_CHANNELS_2 * final_h * final_w

        self.fc1_mlp = nn.Linear(self.cnn_output_flat_size + other_features_size, mlp_hidden1)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2_mlp = nn.Linear(mlp_hidden1, mlp_hidden2)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc_output = nn.Linear(mlp_hidden2, num_actions)

    def forward(self, viewcone_input, other_features_input):
        """Return Q-values of shape (batch, num_actions)."""
        features = viewcone_input
        for conv, activation in ((self.conv1, self.relu_conv1), (self.conv2, self.relu_conv2)):
            features = activation(conv(features))
        flat = features.view(-1, self.cnn_output_flat_size)
        hidden = torch.cat((flat, other_features_input), dim=1)
        hidden = self.dropout1(self.relu_fc1(self.fc1_mlp(hidden)))
        hidden = self.dropout2(self.relu_fc2(self.fc2_mlp(hidden)))
        return self.fc_output(hidden)

# --- Trainable RL Agent (Adapted for CNNDQN) ---
class TrainableRLAgent:
    """Epsilon-greedy DQN agent built around CNNDQN.

    A learning agent (``is_learning_agent=True``) also owns a target network, an
    Adam optimizer and a PrioritizedReplayBuffer. It trains with Double-DQN
    targets and an importance-sampling-weighted MSE loss. A non-learning agent
    only loads a policy and acts; this script uses that mode for the frozen scout.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_cnn_dqn_model.pth", is_learning_agent=True):
        self.device = DEVICE
        self.is_learning_agent = is_learning_agent
        print(f"Agent ({'Learning' if is_learning_agent else 'Acting'}): Using device: {self.device}")

        self.policy_net = self._build_network()
        if self.is_learning_agent:
            self.target_net = self._build_network()

        if model_load_path and os.path.exists(model_load_path):
            try:
                # NOTE: torch.load unpickles the file; only load checkpoints you trust.
                self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))
                print(f"Loaded policy_net from {model_load_path}")
            except Exception as e:
                print(f"Error loading model from {model_load_path}: {e}. Initializing random weights.")
                self.policy_net.apply(self._initialize_weights)
        else:
            if model_load_path:
                print(f"Model path {model_load_path} not found. Initializing random weights.")
            else:
                print("No model_load_path. Initializing random weights.")
            self.policy_net.apply(self._initialize_weights)

        if self.is_learning_agent:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            self.target_net.eval()
            self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
            self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)
            self.beta = PER_BETA_START
            # beta is annealed linearly to 1.0 over PER_BETA_FRAMES sampling calls.
            self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0
            self.t_step_episode = 0 # For UPDATE_EVERY (agent's own steps); not reset between episodes

        self.model_save_path = model_save_path
        self.global_steps_for_target_update = 0 # Learning-update counter driving TARGET_UPDATE_EVERY

    def _build_network(self):
        """Create a CNNDQN with this file's hyperparameters on ``self.device``."""
        return CNNDQN(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
                      OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
                      MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE).to(self.device)

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero biases for Linear/Conv2d layers."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value): # From CNNDQN.py
        """Return bits 0..VIEWCONE_CHANNELS-1 of one tile value as floats."""
        return [float((tile_value >> i) & 1) for i in range(VIEWCONE_CHANNELS)]

    def process_observation(self, observation_dict): # From CNNDQN.py
        """Convert an observation dict into ``(viewcone, other_features)`` float32 arrays.

        Returns:
            viewcone: (VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH); channel
                ``i`` holds bit ``i`` of every tile. A viewcone of another shape is
                cropped / zero-padded to the expected size.
            other_features: length-OTHER_FEATURES_SIZE vector
                [direction one-hot (4), x / MAP_SIZE_X, y / MAP_SIZE_Y, scout flag,
                step / MAX_STEPS_PER_EPISODE].
        """
        raw_viewcone = observation_dict.get("viewcone", np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8))
        if not isinstance(raw_viewcone, np.ndarray):
            raw_viewcone = np.array(raw_viewcone)
        if raw_viewcone.shape != (VIEWCONE_HEIGHT, VIEWCONE_WIDTH):
            padded_viewcone = np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8)
            h, w = raw_viewcone.shape
            h_min, w_min = min(h, VIEWCONE_HEIGHT), min(w, VIEWCONE_WIDTH)
            padded_viewcone[:h_min, :w_min] = raw_viewcone[:h_min, :w_min]
            raw_viewcone = padded_viewcone

        # Vectorised bit unpacking: broadcast (1, H, W) tiles against (C, 1, 1) shifts.
        bit_shifts = np.arange(VIEWCONE_CHANNELS).reshape(-1, 1, 1)
        processed_vc_data = ((raw_viewcone[np.newaxis, :, :] >> bit_shifts) & 1).astype(np.float32)

        direction = observation_dict.get("direction", 0)
        dir_one_hot = [0.0] * 4
        dir_one_hot[direction % 4] = 1.0
        loc = observation_dict.get("location", [0, 0])
        other_list = dir_one_hot + [loc[0] / MAP_SIZE_X, loc[1] / MAP_SIZE_Y]
        other_list.append(float(observation_dict.get("scout", 0))) # This will be 0 for guards, 1 for scout
        other_list.append(observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE)
        state_other_np = np.array(other_list, dtype=np.float32)

        return processed_vc_data, state_other_np

    def select_action(self, state_viewcone_np, state_other_np, epsilon=0.0):
        """Epsilon-greedy action: greedy w.r.t. policy_net with probability 1 - epsilon."""
        if random.random() > epsilon:
            vc_tensor = torch.from_numpy(state_viewcone_np).float().unsqueeze(0).to(self.device)
            o_tensor = torch.from_numpy(state_other_np).float().unsqueeze(0).to(self.device)
            # Disable dropout for acting; learning agents are switched back to train mode.
            self.policy_net.eval()
            with torch.no_grad():
                action_values = self.policy_net(vc_tensor, o_tensor)
            if self.is_learning_agent:
                self.policy_net.train()
            return np.argmax(action_values.cpu().numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn):
        """Store a transition and, every UPDATE_EVERY calls, run one learning update."""
        if not self.is_learning_agent:
            return
        self.memory.add(s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn)
        self.t_step_episode = (self.t_step_episode + 1) % UPDATE_EVERY
        if self.t_step_episode == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            if experiences[0].nelement() > 0:
                self.learn(experiences, indices, weights, GAMMA)
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)
            self.global_steps_for_target_update += 1 # Count learning steps for target update
            if self.global_steps_for_target_update % TARGET_UPDATE_EVERY == 0:
                self.update_target_net()

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """One Double-DQN gradient step with PER importance-sampling weights.

        Args:
            experiences: 7-tuple of batched tensors as returned by
                PrioritizedReplayBuffer.sample (actions/rewards/dones are (B, 1)).
            indices: SumTree leaf indices of the sampled transitions.
            importance_sampling_weights: (B,) tensor of IS weights.
            gamma: discount factor.
        """
        s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn = experiences
        if s_vc.nelement() == 0:
            return

        # Targets need no gradients: the online net picks next actions, the target net scores them.
        with torch.no_grad():
            q_next_policy_actions = self.policy_net(next_s_vc, next_s_o).argmax(dim=1, keepdim=True)
            q_targets_next = self.target_net(next_s_vc, next_s_o).gather(1, q_next_policy_actions)
            q_targets = rwd + gamma * q_targets_next * (1 - dn)
        q_expected = self.policy_net(s_vc, s_o).gather(1, act)

        td_errors = (q_targets - q_expected).detach().abs().cpu().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)

        # Per-sample losses are (B, 1). Reshape the (B,) weights to (B, 1) so each
        # weight scales its own sample instead of broadcasting to a (B, B) matrix.
        per_sample_loss = nn.MSELoss(reduction='none')(q_expected, q_targets)
        loss = (importance_sampling_weights.reshape(-1, 1) * per_sample_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Hard-copy policy_net weights into target_net (learning agents only)."""
        if not self.is_learning_agent:
            return
        self.target_net.load_state_dict(self.policy_net.state_dict())
        print(f"Guard Target network updated at global step {self.global_steps_for_target_update}.")

    def save_model(self):
        """Save policy_net's state_dict to ``model_save_path`` (if set)."""
        if self.model_save_path:
            torch.save(self.policy_net.state_dict(), self.model_save_path)
            role = "Guard" if self.is_learning_agent else "Scout"
            print(f"{role} model saved to {self.model_save_path}")

# --- Main Training Loop ---
def _observation_to_dict(observation_raw):
    """Copy an env observation dict, converting numpy arrays to plain lists."""
    return {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in observation_raw.items()}


def train_guards_with_scout(env_module, num_episodes=2000, novice_track=False,
                           initial_scout_model=INITIAL_SCOUT_MODEL_PATH,
                           load_guard_model_from=None, # Optional: path to existing guard.pth
                           save_scout_to=SCOUT_MODEL_SAVE_PATH,
                           save_guard_to=GUARD_MODEL_SAVE_PATH,
                           render_mode=None, video_folder=None):
    """Train one shared guard policy against a frozen, pre-trained scout.

    Every guard acts with (and feeds transitions into) the same learning agent.
    The scout acts with a non-learning agent at SCOUT_EPSILON_INFERENCE.

    Args:
        env_module: module whose ``env(env_wrappers=..., render_mode=..., novice=...,
            rewards_dict=...)`` returns a PettingZoo AEC environment.
        num_episodes: number of training episodes.
        novice_track: forwarded to the environment as ``novice``.
        initial_scout_model: scout checkpoint. It also initialises the guards
            when ``load_guard_model_from`` is missing.
        load_guard_model_from: optional guard checkpoint to resume from.
        save_scout_to / save_guard_to: checkpoint paths, written every 100
            episodes and at the end.
        render_mode: forwarded to the environment. With ``"rgb_array"`` and a
            ``video_folder``, every 100th episode is saved as an MP4.
        video_folder: output directory for recorded episodes.

    Returns:
        deque (maxlen=100) with the total guard reward of the most recent episodes.
    """
    env = env_module.env(env_wrappers=[CustomWrapper], render_mode=render_mode, novice=novice_track, rewards_dict=CUSTOM_REWARDS_DICT)
    try:
        if render_mode == "rgb_array" and video_folder:
            os.makedirs(video_folder, exist_ok=True)

        print(f"Possible agents: {env.possible_agents}")
        # Roles are detected per turn from obs["scout"]; these IDs are only used for logging.
        scout_agent_id_example = env.possible_agents[0]
        guard_agent_id_example = env.possible_agents[1] if len(env.possible_agents) > 1 else env.possible_agents[0]
        print(f"Example Scout ID for obs space check: {scout_agent_id_example}")
        print(f"Example Guard ID for obs space check: {guard_agent_id_example}")

        # --- Initialize Scout Agent (Not Learning) ---
        scout_agent = TrainableRLAgent(model_load_path=initial_scout_model,
                                       model_save_path=save_scout_to,
                                       is_learning_agent=False)
        scout_agent.policy_net.eval() # Scout is purely in evaluation mode

        # --- Initialize Guard Agent (Learning) ---
        # Load guard model if path provided, else load from initial scout model
        guard_initial_load_path = load_guard_model_from
        if not guard_initial_load_path or not os.path.exists(guard_initial_load_path):
            print(f"Guard model at '{load_guard_model_from}' not found or not provided. Initializing guards from scout model: '{initial_scout_model}'")
            guard_initial_load_path = initial_scout_model
            if not os.path.exists(guard_initial_load_path):
                print(f"CRITICAL WARNING: Initial scout model '{initial_scout_model}' also not found for guard initialization!")

        guard_trainer_agent = TrainableRLAgent(model_load_path=guard_initial_load_path,
                                               model_save_path=save_guard_to,
                                               is_learning_agent=True)

        scores_deque = deque(maxlen=100) # Tracks guard scores
        epsilon_guard = EPSILON_START
        # guard id -> (state_vc, state_other, action) awaiting its next observation/reward
        pending_guard_experiences = {}

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            pending_guard_experiences.clear()
            current_episode_rewards = {agent_id: 0 for agent_id in env.possible_agents}
            scout_ids_episode = set() # agents whose observation carried scout == 1 this episode
            visited_guard_locations_episode = set() # shared by all guards, for the exploration bonus
            episode_frames = []
            should_record_video = (render_mode == "rgb_array" and video_folder and i_episode % 100 == 0)

            for pet_agent_id in env.agent_iter():
                observation_raw, reward, termination, truncation, _ = env.last()
                done = termination or truncation

                # env.last() reports the reward earned since this agent last acted.
                if pet_agent_id in current_episode_rewards:
                    current_episode_rewards[pet_agent_id] += reward

                if should_record_video:
                    try:
                        episode_frames.append(env.render())
                    except Exception as e:
                        print(f"Frame render error: {e}")

                # Convert the observation once; both the transition and the action use it.
                obs_dict = None
                if not done and observation_raw is not None:
                    obs_dict = _observation_to_dict(observation_raw)
                    if obs_dict.get("scout", 0) == 1:
                        scout_ids_episode.add(pet_agent_id)

                # --- 1. Complete pending transition for the current guard agent (if any) ---
                if pet_agent_id in pending_guard_experiences:
                    prev_s_vc, prev_s_o, prev_action = pending_guard_experiences.pop(pet_agent_id)
                    exploration_bonus = 0.0
                    if obs_dict is not None:
                        next_s_vc_np, next_s_o_np = guard_trainer_agent.process_observation(obs_dict)
                        current_loc_tuple = tuple(obs_dict.get("location", [None, None]))
                        if current_loc_tuple != (None, None) and current_loc_tuple not in visited_guard_locations_episode:
                            visited_guard_locations_episode.add(current_loc_tuple)
                            exploration_bonus = EXPLORATION_BONUS_REWARD
                    else:
                        # Terminal (or missing) observation: store a zero next-state.
                        next_s_vc_np, next_s_o_np = np.zeros_like(prev_s_vc), np.zeros_like(prev_s_o)
                    guard_trainer_agent.step(prev_s_vc, prev_s_o, prev_action, reward + exploration_bonus,
                                             next_s_vc_np, next_s_o_np, done)

                # --- 2. Agent selects and takes an action (None for finished agents) ---
                if done:
                    action_to_take = None
                elif obs_dict is None:
                    # Should not happen for a live agent; fall back to a random action.
                    action_space = env.action_space(pet_agent_id)
                    action_to_take = action_space.sample() if action_space is not None else None
                elif obs_dict.get("scout", 0) == 1:
                    s_vc, s_o = scout_agent.process_observation(obs_dict)
                    action_to_take = scout_agent.select_action(s_vc, s_o, SCOUT_EPSILON_INFERENCE)
                else: # Guard's turn
                    s_vc, s_o = guard_trainer_agent.process_observation(obs_dict)
                    action_to_take = guard_trainer_agent.select_action(s_vc, s_o, epsilon_guard)
                    pending_guard_experiences[pet_agent_id] = (s_vc, s_o, action_to_take)

                env.step(action_to_take)

            # --- End of Episode ---
            # Guard score = summed reward of every non-scout agent. The scout is identified
            # from obs["scout"]; fall back to the first agent if no scout turn was seen.
            scout_ids = scout_ids_episode or {env.possible_agents[0]}
            total_guard_score_episode = sum(rwd for ag_id, rwd in current_episode_rewards.items() if ag_id not in scout_ids)
            scores_deque.append(total_guard_score_episode)

            if guard_trainer_agent.global_steps_for_target_update > MIN_EPSILON_FRAMES_GUARD: # Decay epsilon for guards
                epsilon_guard = max(EPSILON_END, EPSILON_DECAY * epsilon_guard)

            if should_record_video and video_folder and episode_frames:
                try:
                    imageio.mimsave(os.path.join(video_folder, f"ep_{i_episode:04d}.mp4"), episode_frames, fps=15)
                except Exception as e:
                    print(f"Video save error: {e}")

            avg_score_str = f"{np.mean(scores_deque):.2f}" if scores_deque else "N/A"
            progress = f'\rEp {i_episode}\tAvg Guard Score: {avg_score_str}\tGuard Eps: {epsilon_guard:.3f}\tGuard Train Steps: {guard_trainer_agent.global_steps_for_target_update}'
            print(progress, end="")
            if i_episode % 100 == 0:
                print(progress)
                guard_trainer_agent.save_model()
                scout_agent.save_model() # Re-save scout model (though it's not changing)
    finally:
        # Release the environment even if training raises.
        env.close()

    print("\nTraining finished.")
    guard_trainer_agent.save_model() # Save final guard model
    scout_agent.save_model()       # Save final scout model
    return scores_deque # Return guard scores

if __name__ == '__main__':
    t_start = time.time()

    # Import the environment on its own so an ImportError raised later (e.g. a
    # missing matplotlib) is not misreported as a missing gridworld.
    try:
        from til_environment import gridworld
    except ImportError:
        gridworld = None
        print("Could not import 'til_environment.gridworld'. Ensure it's accessible.")

    if gridworld is not None:
        print("Successfully imported til_environment.gridworld")
        num_episodes = 50000 # Adjust
        try:
            guard_scores = train_guards_with_scout(
                gridworld,
                num_episodes=num_episodes,
                novice_track=False,
                initial_scout_model=INITIAL_SCOUT_MODEL_PATH,
                load_guard_model_from=GUARD_MODEL_SAVE_PATH, # Try to load existing guard model
                save_scout_to=SCOUT_MODEL_SAVE_PATH,
                save_guard_to=GUARD_MODEL_SAVE_PATH,
                render_mode="rgb_array", # "human" or "rgb_array" or None
                video_folder="./rl_renders_guard_vs_scout"
            )

            if guard_scores:
                import matplotlib.pyplot as plt
                # guard_scores only holds the most recent episodes, so plot them
                # against their real episode numbers rather than 0..len-1.
                first_episode = num_episodes - len(guard_scores) + 1
                plt.plot(np.arange(first_episode, num_episodes + 1), guard_scores)
                plt.ylabel('Total Guard Score per Episode')
                plt.xlabel('Episode #')
                plt.title('Guard Training Performance')
                plt.savefig("guard_training_scores.png")
                plt.show()
        except Exception as e:
            print(f"An error occurred: {e}")
            import traceback
            traceback.print_exc()

    t_end = time.time()
    print(f"Total script time: {(t_end - t_start)/60:.2f} minutes")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
# import imageio # Removed
from collections import deque, namedtuple
import time

from til_environment.gridworld import RewardNames # Ensure this import works

# import functools # Not used directly, can be removed if CustomWrapper doesn't need it explicitly
from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
from pettingzoo.utils.wrappers.base import BaseWrapper

# --- Configuration ---
# Map / episode sizes (presumably used to normalise observation features, as in
# the guard-only cell - confirm against the code below this cell).
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
MAX_STEPS_PER_EPISODE = 100

# Viewcone input: VIEWCONE_CHANNELS bit-planes over a VIEWCONE_HEIGHT x VIEWCONE_WIDTH grid.
VIEWCONE_CHANNELS = 8
VIEWCONE_HEIGHT = 7
VIEWCONE_WIDTH = 5
OTHER_FEATURES_SIZE = 4 + 2 + 1 + 1 # Direction (4) + Location (2) + Scout (1) + Step (1)

# CNNDQN architecture; the conv layers use padding=1.
CNN_OUTPUT_CHANNELS_1 = 16
CNN_OUTPUT_CHANNELS_2 = 32
KERNEL_SIZE_1 = (3, 3); STRIDE_1 = 1
KERNEL_SIZE_2 = (3, 3); STRIDE_2 = 1
MLP_HIDDEN_LAYER_1_SIZE = 128
MLP_HIDDEN_LAYER_2_SIZE = 128
OUTPUT_ACTIONS = 5 # Number of discrete actions (Q-values per state)
DROPOUT_RATE = 0.2

# Replay / optimisation
BUFFER_SIZE = int(1e5)
BATCH_SIZE = 32
GAMMA = 0.99
LEARNING_RATE = 1e-4

# Epsilon for Scout
SCOUT_EPSILON_START = 0.1
SCOUT_EPSILON_END = 0.01
SCOUT_EPSILON_DECAY = 0.9999
MIN_EPSILON_FRAMES_SCOUT = int(2e3)

# Epsilon for Guards
GUARD_EPSILON_START = 0.5
GUARD_EPSILON_END = 0.05
GUARD_EPSILON_DECAY = 0.9995
MIN_EPSILON_FRAMES_GUARD = int(5e3)

# Update cadence (the unit each counter uses is defined by code past this view - confirm)
TARGET_UPDATE_EVERY = 1000
UPDATE_EVERY = 4

# PER: priority = (|TD error| + PER_EPSILON) ** PER_ALPHA; PER_BETA_START is the default
# importance-sampling exponent for sample(); PER_BETA_FRAMES presumably sets the annealing length.
PER_ALPHA = 0.6; PER_BETA_START = 0.4; PER_BETA_FRAMES = int(1e5); PER_EPSILON = 1e-6
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reward table for the environment; values here are for co-training both roles.
CUSTOM_REWARDS_DICT = {
    RewardNames.SCOUT_MISSION: 100.0, RewardNames.SCOUT_RECON: 10.0,
    RewardNames.SCOUT_TRUNCATION: 50.0, RewardNames.SCOUT_CAPTURED: -100.0,
    RewardNames.GUARD_WINS: 50.0, RewardNames.GUARD_CAPTURES: 100.0,
    RewardNames.GUARD_TRUNCATION: 20.0, RewardNames.WALL_COLLISION: -10.0,
    RewardNames.AGENT_COLLIDER: -2.0, RewardNames.AGENT_COLLIDEE: -1.0,
    RewardNames.STATIONARY_PENALTY: -0.5, RewardNames.SCOUT_STEP: -0.1,
    RewardNames.GUARD_STEP: -0.05,
}
# EXPLORATION_BONUS = 0.1 # Not currently used in the loop, can be removed or re-added if needed

# Checkpoint paths for the co-training run.
INITIAL_SCOUT_MODEL_PATH = "my_wargame_cnn_agent_35500.pth"
SCOUT_MODEL_SAVE_PATH = "scout_learning.pth"
GUARD_MODEL_SAVE_PATH = "guard_learning.pth"


class CustomWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
    """Identity wrapper around an AEC environment (placeholder for custom logic).

    Nothing is overridden; all behaviour comes from BaseWrapper.
    """

    def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
        super(CustomWrapper, self).__init__(env)

class SumTree:
    """Array-backed binary sum tree for proportional prioritized sampling.

    Leaves (``capacity - 1`` .. ``2 * capacity - 2``) hold priorities; internal
    nodes hold the sum of their children, so ``tree[0]`` is the total priority.
    ``data[i]`` is the payload of leaf ``i + capacity - 1``; empty slots are ``None``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # None (not 0) marks never-filled slots so callers can detect them.
        self.data = np.full(capacity, None, dtype=object)
        self.data_pointer = 0
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` with ``priority``, overwriting the oldest slot once full."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set leaf ``tree_idx`` to ``priority`` and propagate the delta to the root."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf covering cumulative ``value``.

        A zero-priority (never-filled) subtree is never entered while its sibling
        has mass. This avoids returning placeholder data when float drift pushes
        ``value`` past a subtree's real total, or when ``value == 0`` and the left
        subtree is still empty (possible for non-power-of-two capacities).
        """
        parent_idx = 0
        while True:
            l_child = 2 * parent_idx + 1
            r_child = l_child + 1
            if l_child >= len(self.tree):
                leaf_idx = parent_idx
                break
            l_sum, r_sum = self.tree[l_child], self.tree[r_child]
            if r_sum <= 0 or (l_sum > 0 and value <= l_sum):
                parent_idx = l_child
            else:
                value -= l_sum
                parent_idx = r_child
        return leaf_idx, self.tree[leaf_idx], self.data[leaf_idx - self.capacity + 1]

    @property
    def total_priority(self):
        """Sum of all stored priorities (the root node)."""
        return self.tree[0]

Experience = namedtuple("Experience", ["state_viewcone", "state_other", "action", "reward", "next_state_viewcone", "next_state_other", "done"])


class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay backed by a SumTree.

    New transitions are stored at the largest priority seen so far; updated
    priorities are ``(|td_error| + PER_EPSILON) ** alpha``.
    """

    def __init__(self, capacity, alpha=PER_ALPHA):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, s_vc, s_o, act, r, next_s_vc, next_s_o, d):
        """Store one transition at the current maximum priority."""
        self.tree.add(self.max_priority, Experience(s_vc, s_o, act, r, next_s_vc, next_s_o, d))

    @staticmethod
    def _empty_sample():
        """Placeholder result returned when nothing can be sampled."""
        return tuple(torch.empty(0) for _ in range(7)), np.array([]), torch.empty(0)

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw up to ``batch_size`` transitions with stratified priority sampling.

        Returns ``(batch, tree_indices, weights)``. ``batch`` is a 7-tuple of tensors
        on DEVICE (actions/rewards/dones shaped (B, 1)), ``tree_indices`` are int32
        SumTree leaf indices and ``weights`` are (B,) max-normalised IS weights.
        All three are row-aligned. An all-empty result is returned until at least
        ``batch_size`` transitions are stored.
        """
        if self.tree.n_entries < batch_size:
            return self._empty_sample()

        total = self.tree.total_priority
        pri_seg = total / batch_size
        indices, experiences, weights = [], [], []
        for i in range(batch_size):
            val = np.random.uniform(pri_seg * i, pri_seg * (i + 1))
            idx, pri, data = self.tree.get_leaf(val)
            # Unfilled slots hold a placeholder (0 or None). Drop them together with
            # their index and weight so all outputs stay aligned.
            if not isinstance(data, Experience):
                continue
            prob = pri / total if total > 0 else 0
            weights.append(np.power(self.tree.n_entries * prob + 1e-8, -beta))
            indices.append(idx)
            experiences.append(data)

        if not experiences:
            return self._empty_sample()

        w = np.asarray(weights, dtype=np.float32)
        w /= (w.max() if w.max() > 0 else 1.0)

        batch = Experience(*zip(*experiences))
        s_vc_t = torch.from_numpy(np.array(batch.state_viewcone)).float().to(DEVICE)
        s_o_t = torch.from_numpy(np.array(batch.state_other)).float().to(DEVICE)
        act_t = torch.from_numpy(np.vstack(batch.action)).long().to(DEVICE)
        r_t = torch.from_numpy(np.vstack(batch.reward)).float().to(DEVICE)
        next_s_vc_t = torch.from_numpy(np.array(batch.next_state_viewcone)).float().to(DEVICE)
        next_s_o_t = torch.from_numpy(np.array(batch.next_state_other)).float().to(DEVICE)
        d_t = torch.from_numpy(np.vstack(batch.done).astype(np.uint8)).float().to(DEVICE)
        return ((s_vc_t, s_o_t, act_t, r_t, next_s_vc_t, next_s_o_t, d_t),
                np.asarray(indices, dtype=np.int32),
                torch.from_numpy(w).float().to(DEVICE))

    def update_priorities(self, b_indices, td_errs):
        """Write back ``(|td_err| + PER_EPSILON) ** alpha`` for the given leaves."""
        if len(b_indices) == 0:
            return
        prios = np.power(np.abs(td_errs) + PER_EPSILON, self.alpha)
        for idx, p in zip(b_indices, prios):
            self.tree.update(idx, p)
        if prios.size > 0:
            self.max_priority = max(self.max_priority, np.max(prios))

    def __len__(self):
        return self.tree.n_entries

class CNNDQN(nn.Module):
    """Q-network: two conv layers over the viewcone, then an MLP that also sees extra features.

    Args:
        v_c, v_h, v_w: viewcone channel count, height and width.
        o_f_s: length of the non-image feature vector concatenated after the CNN.
        mlp_h1, mlp_h2: sizes of the two hidden MLP layers.
        n_a: number of actions (one Q-value each).
        dr: dropout probability applied after each hidden MLP layer.

    Conv channel counts, kernels and strides come from the module-level
    CNN_OUTPUT_CHANNELS_* / KERNEL_SIZE_* / STRIDE_* constants. The layer attribute
    names below are the keys of the saved/loaded ``state_dict`` and must stay as-is.
    """

    def __init__(self, v_c, v_h, v_w, o_f_s, mlp_h1, mlp_h2, n_a, dr):
        super().__init__()
        pad = 1

        def conv_out(size, kernel, stride):
            # Conv2d output length for the fixed padding above (dilation 1).
            return (size + 2 * pad - kernel) // stride + 1

        # Layers are created in the original order so parameter registration order
        # (and therefore state_dict / optimizer ordering) is unchanged.
        self.conv1 = nn.Conv2d(v_c, CNN_OUTPUT_CHANNELS_1, KERNEL_SIZE_1, STRIDE_1, padding=pad)
        self.relu_conv1 = nn.ReLU()
        self.conv2 = nn.Conv2d(CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2, KERNEL_SIZE_2, STRIDE_2, padding=pad)
        self.relu_conv2 = nn.ReLU()

        height_after_1 = conv_out(v_h, KERNEL_SIZE_1[0], STRIDE_1)
        width_after_1 = conv_out(v_w, KERNEL_SIZE_1[1], STRIDE_1)
        height_after_2 = conv_out(height_after_1, KERNEL_SIZE_2[0], STRIDE_2)
        width_after_2 = conv_out(width_after_1, KERNEL_SIZE_2[1], STRIDE_2)
        self.cnn_flat_size = CNN_OUTPUT_CHANNELS_2 * height_after_2 * width_after_2

        self.fc1_mlp = nn.Linear(self.cnn_flat_size + o_f_s, mlp_h1)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dr)
        self.fc2_mlp = nn.Linear(mlp_h1, mlp_h2)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dr)
        self.fc_output = nn.Linear(mlp_h2, n_a)

    def forward(self, vc_in, of_in):
        """Return Q-values of shape (batch, n_a) for viewcone input ``vc_in`` and features ``of_in``."""
        conv_feat = self.relu_conv2(self.conv2(self.relu_conv1(self.conv1(vc_in))))
        flat = conv_feat.view(-1, self.cnn_flat_size)
        hidden = torch.cat((flat, of_in), dim=1)
        hidden = self.dropout1(self.relu_fc1(self.fc1_mlp(hidden)))
        hidden = self.dropout2(self.relu_fc2(self.fc2_mlp(hidden)))
        return self.fc_output(hidden)

class TrainableRLAgent:
    """Double-DQN learner with prioritized experience replay for one agent role.

    Owns a policy network, a target network, an Adam optimizer and a
    ``PrioritizedReplayBuffer``. ``step`` stores transitions and runs a learning
    update every ``UPDATE_EVERY`` calls once the buffer holds more than
    ``BATCH_SIZE`` transitions.
    """

    def __init__(self, agent_role_name, model_load_path=None, model_save_path="trained_model.pth"):
        self.device = DEVICE
        self.agent_role_name = agent_role_name
        print(f"Agent Role: {self.agent_role_name} | Using device: {self.device}")
        net_args = (VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH, OTHER_FEATURES_SIZE,
                    MLP_HIDDEN_LAYER_1_SIZE, MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE)
        self.policy_net = CNNDQN(*net_args).to(self.device)
        self.target_net = CNNDQN(*net_args).to(self.device)

        loaded = False
        if model_load_path and os.path.exists(model_load_path):
            try:
                self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))
                print(f"Loaded {self.agent_role_name} policy_net from {model_load_path}")
                loaded = True
            except Exception as e:
                print(f"Error loading {self.agent_role_name} model: {e}. Init random.")
        elif model_load_path:
            print(f"{self.agent_role_name} model not found at {model_load_path}. Init random.")
        else:
            print(f"No model_load_path for {self.agent_role_name}. Init random.")
        if not loaded:
            # Applied on every "Init random." path; previously the missing-file
            # branch skipped it and silently kept PyTorch's default init.
            self.policy_net.apply(self._init_w)

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, PER_ALPHA)
        self.beta = PER_BETA_START
        self.beta_inc = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0
        self.t_step_episode = 0
        self.model_save_path = model_save_path
        self.total_learning_steps = 0

    def _init_w(self, m):
        """Xavier-uniform weights and zero bias for Linear/Conv2d modules (for ``Module.apply``)."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _unpack_vc_tile(self, val):
        """Split an integer tile value into VIEWCONE_CHANNELS bits, least-significant first."""
        return [float((val >> i) & 1) for i in range(VIEWCONE_CHANNELS)]

    def proc_obs(self, obs_d):
        """Turn an observation dict into ``(viewcone, other_features)`` float32 arrays.

        Returns:
            vc_proc: shape (VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH); channel
                ``i`` holds bit ``i`` of each viewcone tile (same as ``_unpack_vc_tile``).
            s_o_np: [direction one-hot (4), x / MAP_SIZE_X, y / MAP_SIZE_Y,
                scout flag, step / MAX_STEPS_PER_EPISODE].
        """
        vc_raw = obs_d.get("viewcone", np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8))
        if not isinstance(vc_raw, np.ndarray):
            vc_raw = np.array(vc_raw)
        if vc_raw.shape != (VIEWCONE_HEIGHT, VIEWCONE_WIDTH):
            # Crop / zero-pad to the fixed viewcone size the network expects.
            pad_vc = np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8)
            h, w = vc_raw.shape
            h_m, w_m = min(h, VIEWCONE_HEIGHT), min(w, VIEWCONE_WIDTH)
            pad_vc[:h_m, :w_m] = vc_raw[:h_m, :w_m]
            vc_raw = pad_vc
        # Vectorised bit unpacking (called every env step; replaces a per-tile Python loop).
        shifts = np.arange(VIEWCONE_CHANNELS).reshape(-1, 1, 1)
        vc_proc = ((vc_raw[np.newaxis, :, :] >> shifts) & 1).astype(np.float32)

        direction = obs_d.get("direction", 0)
        dir_one_hot = [0.0] * 4
        dir_one_hot[direction % 4] = 1.0
        loc = obs_d.get("location", [0, 0])
        o_list = dir_one_hot + [
            loc[0] / MAP_SIZE_X,
            loc[1] / MAP_SIZE_Y,
            float(obs_d.get("scout", 0)),
            obs_d.get("step", 0) / MAX_STEPS_PER_EPISODE,
        ]
        s_o_np = np.array(o_list, dtype=np.float32)
        return vc_proc, s_o_np

    def sel_act(self, s_vc, s_o, eps=0.):
        """Epsilon-greedy action: random with probability ``eps``, else argmax Q.

        The network is switched to eval mode (dropout off) for the forward pass and
        then restored to whatever mode it was in before. Returns a Python ``int``.
        """
        if random.random() > eps:
            vc_t = torch.from_numpy(s_vc).float().unsqueeze(0).to(self.device)
            o_t = torch.from_numpy(s_o).float().unsqueeze(0).to(self.device)
            was_training = self.policy_net.training
            self.policy_net.eval()
            with torch.no_grad():
                av = self.policy_net(vc_t, o_t)
            self.policy_net.train(was_training)
            return int(np.argmax(av.cpu().numpy()))
        return int(random.choice(np.arange(OUTPUT_ACTIONS)))

    def step(self, s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn):
        """Store a transition; every UPDATE_EVERY calls, sample a batch and learn.

        Learning only starts once the buffer holds more than BATCH_SIZE transitions.
        Each update anneals beta toward 1, and the target network is synced every
        TARGET_UPDATE_EVERY updates.
        """
        self.memory.add(s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn)
        self.t_step_episode = (self.t_step_episode + 1) % UPDATE_EVERY
        if self.t_step_episode != 0 or len(self.memory) <= BATCH_SIZE:
            return
        exp, idx, w = self.memory.sample(BATCH_SIZE, self.beta)
        if exp[0].nelement() > 0:
            self.learn(exp, idx, w, GAMMA)
        self.beta = min(1., self.beta + self.beta_inc)
        self.total_learning_steps += 1
        if self.total_learning_steps % TARGET_UPDATE_EVERY == 0:
            self.update_target_net()

    def learn(self, exp, idx, is_w, gam):
        """Run one Double-DQN gradient step on a prioritized batch.

        Args:
            exp: (s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn) tensors as built by the
                replay buffer's ``sample``; act/rwd/dn are column vectors (B, 1).
            idx: tree indices of the sampled transitions, used for priority updates.
            is_w: importance-sampling weights, one per sampled transition.
            gam: discount factor.
        """
        s_vc, s_o, act, rwd, next_s_vc, next_s_o, dn = exp
        if s_vc.nelement() == 0:
            return
        with torch.no_grad():
            # Double DQN: the online net chooses the next action, the target net scores it.
            next_act = self.policy_net(next_s_vc, next_s_o).max(1)[1].unsqueeze(1)
            q_next = self.target_net(next_s_vc, next_s_o).gather(1, next_act)
            q_targets = rwd + gam * q_next * (1 - dn)
        q_exp = self.policy_net(s_vc, s_o).gather(1, act)
        td_errs = (q_targets - q_exp).abs().detach().cpu().numpy().flatten()

        # is_w is 1-D (B,) but the per-sample loss is (B, 1); multiplying them directly
        # broadcasts to (B, B) and averages the weights away. Reshape to a column.
        weights = is_w.reshape(-1, 1)
        if weights.shape[0] == q_exp.shape[0]:
            self.memory.update_priorities(idx, td_errs)
        else:
            # sample() filtered out empty leaves, so rows no longer line up with
            # idx / is_w: fall back to unweighted loss and leave priorities untouched.
            weights = torch.ones_like(q_exp)
        loss = (weights * nn.MSELoss(reduction='none')(q_exp, q_targets)).mean()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.)
        self.optimizer.step()

    def update_target_net(self):
        """Hard-copy the policy network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
        print(f"{self.agent_role_name} Target network updated at {self.total_learning_steps} learning steps.")

    def save_model(self):
        """Save the policy network's state_dict to ``model_save_path`` (if one is set)."""
        if self.model_save_path:
            torch.save(self.policy_net.state_dict(), self.model_save_path)
            print(f"{self.agent_role_name} model saved: {self.model_save_path}")

def train_scout_and_guards(env_module, num_episodes=2000, novice_track=False,
                           initial_scout_path=INITIAL_SCOUT_MODEL_PATH,
                           initial_guard_path=GUARD_MODEL_SAVE_PATH,
                           save_scout_to=SCOUT_MODEL_SAVE_PATH,
                           save_guard_to=GUARD_MODEL_SAVE_PATH):
    """Jointly train a scout learner and a shared guard learner in the gridworld.

    Every AEC turn is routed to the scout or guard ``TrainableRLAgent`` according to
    the ``"scout"`` flag in that agent's observation. The transition for an agent's
    previous action is completed, and handed to the learner that chose it, the next
    time ``agent_iter`` visits that agent.

    Args:
        env_module: module exposing ``env(env_wrappers=..., render_mode=..., novice=...,
            rewards_dict=...)``.
        num_episodes: number of episodes to run.
        novice_track: forwarded to the environment as ``novice``.
        initial_scout_path: checkpoint used to initialise the scout policy.
        initial_guard_path: checkpoint for the guard policy; if missing, falls back to
            ``initial_scout_path``, then to random init.
        save_scout_to, save_guard_to: checkpoint destinations, written every 100
            episodes and at the end of training.
    """
    # Rendering is disabled for faster training.
    env = env_module.env(env_wrappers=[CustomWrapper], render_mode=None,
                         novice=novice_track, rewards_dict=CUSTOM_REWARDS_DICT)
    try:
        scout_trainer = TrainableRLAgent(agent_role_name="Scout", model_load_path=initial_scout_path,
                                         model_save_path=save_scout_to)

        guard_load_path = initial_guard_path
        if not (guard_load_path and os.path.exists(guard_load_path)):
            print(f"Guard model at '{initial_guard_path}' not found. Initializing guards from scout model: '{initial_scout_path}'")
            guard_load_path = initial_scout_path
            if not (guard_load_path and os.path.exists(guard_load_path)):
                print(f"CRITICAL: Scout model '{initial_scout_path}' also not found for guard init. Guards start random.")
                guard_load_path = None
        guard_trainer = TrainableRLAgent(agent_role_name="Guard", model_load_path=guard_load_path,
                                         model_save_path=save_guard_to)

        eps_scout, eps_guard = SCOUT_EPSILON_START, GUARD_EPSILON_START
        scout_scores_q, guard_scores_q = deque(maxlen=100), deque(maxlen=100)
        # agent id -> (state viewcone, state features, action, chosen_by_scout)
        pending_exps = {}

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            pending_exps.clear()
            ep_rewards = {ag_id: 0 for ag_id in env.possible_agents}
            scout_ids = set()  # agents whose observation flagged them as scout this episode

            for pet_id in env.agent_iter():
                obs_raw, reward, term, trunc, info = env.last()
                done = term or trunc
                if pet_id in ep_rewards:
                    ep_rewards[pet_id] += reward

                obs_d = None
                is_scout = False
                if obs_raw is not None:
                    obs_d = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in obs_raw.items()}
                    is_scout = obs_d.get("scout", 0) == 1
                    if is_scout:
                        scout_ids.add(pet_id)
                turn_agent = scout_trainer if is_scout else guard_trainer
                turn_eps = eps_scout if is_scout else eps_guard

                # Close out the transition opened by this agent's previous action.
                if pet_id in pending_exps:
                    prev_s_vc, prev_s_o, prev_act, prev_was_scout = pending_exps.pop(pet_id)
                    learner = scout_trainer if prev_was_scout else guard_trainer
                    if not done and obs_d is not None:
                        next_s_vc, next_s_o = learner.proc_obs(obs_d)
                    else:
                        next_s_vc, next_s_o = np.zeros_like(prev_s_vc), np.zeros_like(prev_s_o)
                    learner.step(prev_s_vc, prev_s_o, prev_act, reward, next_s_vc, next_s_o, done)

                act_to_take = None
                if done:
                    pass  # finished agents must be stepped with None
                elif obs_d is None:
                    space = env.action_space(pet_id)
                    act_to_take = space.sample() if space else None
                else:
                    s_vc, s_o = turn_agent.proc_obs(obs_d)
                    act_to_take = turn_agent.sel_act(s_vc, s_o, turn_eps)
                    pending_exps[pet_id] = (s_vc, s_o, act_to_take, is_scout)
                env.step(act_to_take)

            # Score by the role seen in observations; fall back to the old
            # "first possible agent is the scout" assumption if no flag was seen.
            if not scout_ids:
                scout_ids = {env.possible_agents[0]}
            scout_ep_score = sum(r for ag_id, r in ep_rewards.items() if ag_id in scout_ids)
            guard_ep_score = sum(r for ag_id, r in ep_rewards.items() if ag_id not in scout_ids)
            scout_scores_q.append(scout_ep_score)
            guard_scores_q.append(guard_ep_score)

            if scout_trainer.total_learning_steps > MIN_EPSILON_FRAMES_SCOUT:
                eps_scout = max(SCOUT_EPSILON_END, SCOUT_EPSILON_DECAY * eps_scout)
            if guard_trainer.total_learning_steps > MIN_EPSILON_FRAMES_GUARD:
                eps_guard = max(GUARD_EPSILON_END, GUARD_EPSILON_DECAY * eps_guard)

            avg_s_scr = f"{np.mean(scout_scores_q):.2f}" if scout_scores_q else "N/A"
            avg_g_scr = f"{np.mean(guard_scores_q):.2f}" if guard_scores_q else "N/A"
            progress = (f"\rEp {i_episode}| Scout AvgS: {avg_s_scr} (Eps:{eps_scout:.3f} Steps:{scout_trainer.total_learning_steps}) "
                        f"Guard AvgS: {avg_g_scr} (Eps:{eps_guard:.3f} Steps:{guard_trainer.total_learning_steps})")
            if i_episode % 100 == 0:
                print(progress)  # keep a persistent line every 100 episodes
                scout_trainer.save_model()
                guard_trainer.save_model()
            else:
                print(progress, end="")
    finally:
        env.close()
    print("\nTraining Done.")
    scout_trainer.save_model()
    guard_trainer.save_model()

if __name__ == '__main__':
    t_start = time.time()
    try:
        # Scope the ImportError handler to the import alone so an ImportError raised
        # during training is reported with a traceback instead of being mislabelled.
        try:
            from til_environment import gridworld
        except ImportError:
            print("Could not import 'til_environment.gridworld'.")
        else:
            print("Imported til_environment.gridworld")
            train_scout_and_guards(
                gridworld, num_episodes=50000, novice_track=False,
                initial_scout_path=INITIAL_SCOUT_MODEL_PATH,
                initial_guard_path=GUARD_MODEL_SAVE_PATH,
                save_scout_to=SCOUT_MODEL_SAVE_PATH,
                save_guard_to=GUARD_MODEL_SAVE_PATH,
            )
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
    print(f"Total script time: {(time.time()-t_start)/60:.2f} minutes")

Imported til_environment.gridworld
Agent Role: Scout | Using device: cuda
Loaded Scout policy_net from my_wargame_cnn_agent_35500.pth
Guard model at 'guard_learning.pth' not found. Initializing guards from scout model: 'my_wargame_cnn_agent_35500.pth'
Agent Role: Guard | Using device: cuda
Loaded Guard policy_net from my_wargame_cnn_agent_35500.pth
Ep 15| Scout AvgS: 398.84 (Eps:0.100 Steps:320) Guard AvgS: 649.19 (Eps:0.500 Steps:976)Guard Target network updated at 1000 learning steps.
Ep 29| Scout AvgS: 304.54 (Eps:0.100 Steps:638) Guard AvgS: 422.47 (Eps:0.500 Steps:1931)Guard Target network updated at 2000 learning steps.
Ep 44| Scout AvgS: 230.64 (Eps:0.100 Steps:980) Guard AvgS: 408.73 (Eps:0.500 Steps:2957)Guard Target network updated at 3000 learning steps.
Scout Target network updated at 1000 learning steps.
Ep 59| Scout AvgS: 200.17 (Eps:0.100 Steps:1325) Guard AvgS: 472.61 (Eps:0.500 Steps:3991)Guard Target network updated at 4000 learning steps.
Ep 73| Scout AvgS: 221.79 (E