# Train Multiple Agents

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time

# --- Configuration ---
# Map extent; used to normalise the agent's (x, y) location into [0, 1).
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
# Used to normalise the "step" observation field.
MAX_STEPS_PER_EPISODE = 100
# Each viewcone tile is an integer whose low VIEWCONE_CHANNELS bits are
# unpacked into separate binary feature planes.
VIEWCONE_CHANNELS = 8
VIEWCONE_HEIGHT = 7
VIEWCONE_WIDTH = 5
# direction one-hot (4) + normalised location (2) + scout flag (1) + normalised step (1)
OTHER_FEATURES_SIZE = 4 + 2 + 1 + 1

# --- Network architecture ---
CNN_OUTPUT_CHANNELS_1 = 16
CNN_OUTPUT_CHANNELS_2 = 32
KERNEL_SIZE_1 = (3, 3)
STRIDE_1 = 1
KERNEL_SIZE_2 = (3, 3)
STRIDE_2 = 1
MLP_HIDDEN_LAYER_1_SIZE = 128
MLP_HIDDEN_LAYER_2_SIZE = 128
OUTPUT_ACTIONS = 5  # size of the Q-value output vector (one entry per action)
DROPOUT_RATE = 0.2  # applied after each hidden MLP layer

# --- Optimisation / replay ---
BUFFER_SIZE = int(1e5)  # capacity of the prioritized replay buffer
BATCH_SIZE = 64
GAMMA = 0.99  # discount factor used in the TD target
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5  # passed to Adam as weight_decay
TARGET_UPDATE_EVERY = 1000  # learning-agent steps between target-network syncs
UPDATE_EVERY = 4  # a learning update is attempted every UPDATE_EVERY stored transitions

# --- Exploration ---
EPSILON_START = 0.8
EPSILON_END = 0.1
EPSILON_DECAY_RATE = 0.9999  # multiplicative decay applied once per learning-agent step
MIN_EPSILON_FRAMES = int(1e4)  # epsilon stays at EPSILON_START until this many steps have elapsed

# --- Prioritized experience replay ---
PER_ALPHA = 0.6  # exponent applied to (|TD error| + PER_EPSILON) to form a priority
PER_BETA_START = 0.4  # initial importance-sampling exponent
PER_BETA_FRAMES = int(1e5)  # number of sampling calls over which beta rises linearly to 1.0
PER_EPSILON = 1e-6  # added to |TD error| so no stored transition gets zero priority

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One stored transition; the state is split into the viewcone planes and the
# flat "other" feature vector.
Experience = namedtuple("Experience", field_names=["state_viewcone", "state_other", "action", "reward", "next_state_viewcone", "next_state_other", "done"])

class SumTree:
    """Binary sum tree stored in a flat array, for priority-proportional sampling.

    The ``capacity`` leaves occupy ``tree[capacity - 1:]`` and hold one
    priority each; every internal node holds the sum of its two children, so
    ``tree[0]`` is the total priority. ``data`` is a ring buffer of payloads:
    payload ``i`` belongs to leaf ``i + capacity - 1``. Slots that have never
    been written hold ``None`` and have priority 0.
    """

    def __init__(self, capacity):
        """Create an empty tree able to hold ``capacity`` payloads."""
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # None (not 0) marks a never-written slot, so callers can tell an
        # empty slot apart from a stored payload with an ``is None`` check.
        self.data = np.full(capacity, None, dtype=object)
        self.data_pointer = 0
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` with ``priority``, overwriting the oldest slot when full."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and propagate the delta to the root."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf covering ``value``.

        ``value`` is expected to lie in ``[0, total_priority]``. The descent
        never enters a subtree whose sum is zero while its sibling's sum is
        positive. Without that rule, ``value == 0`` or a value that slightly
        overshoots the total through float rounding could end on a leaf that
        was never filled.
        """
        node = 0
        n_nodes = len(self.tree)
        while True:
            left = 2 * node + 1
            if left >= n_nodes:
                break  # ``node`` is a leaf
            right = left + 1
            left_sum = self.tree[left]
            right_sum = self.tree[right]
            if right_sum <= 0 or (left_sum > 0 and value <= left_sum):
                node = left
            else:
                value -= left_sum
                node = right
        data_idx = node - self.capacity + 1
        return node, self.tree[node], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the root of the tree)."""
        return self.tree[0]

class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions in proportion to their priority.

    Priorities live in a ``SumTree``. New transitions are inserted with the
    largest priority seen so far, so each one is likely to be replayed at
    least once before its priority is replaced by a TD-error based value.
    """

    # Upper bound on redraws when a draw lands on a slot holding no Experience.
    _MAX_RESAMPLE_ATTEMPTS = 8

    def __init__(self, capacity, alpha=PER_ALPHA):
        """Create a buffer of ``capacity`` transitions with priority exponent ``alpha``."""
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store one transition at the current maximum priority."""
        experience = Experience(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.tree.add(self.max_priority, experience)

    @staticmethod
    def _empty_sample():
        """Return the sentinel result used when nothing can be sampled."""
        return tuple(torch.empty(0) for _ in range(7)), np.array([]), torch.empty(0)

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw up to ``batch_size`` transitions by stratified proportional sampling.

        Returns ``(experiences, batch_idx, weights)``:

        * ``experiences`` - 7-tuple of tensors on ``DEVICE`` (viewcone states,
          other-feature states, actions, rewards, next viewcone states, next
          other-feature states, done flags);
        * ``batch_idx`` - int32 array of tree indices, to be passed back to
          ``update_priorities``;
        * ``weights`` - 1-D float tensor of importance-sampling weights scaled
          so that the largest weight is 1.

        All three outputs describe the same transitions in the same order.
        When nothing can be sampled, every tensor is empty.
        """
        total = self.tree.total_priority
        if batch_size <= 0 or self.tree.n_entries == 0 or total <= 0:
            return self._empty_sample()

        segment = total / batch_size
        indices, experiences, priorities = [], [], []
        for i in range(batch_size):
            low, high = segment * i, segment * (i + 1)
            for _ in range(self._MAX_RESAMPLE_ATTEMPTS):
                index, priority, data = self.tree.get_leaf(np.random.uniform(low, high))
                # Float rounding in the tree sums can steer a draw onto a slot
                # that was never written; such a slot holds no Experience.
                if isinstance(data, Experience):
                    indices.append(index)
                    experiences.append(data)
                    priorities.append(priority)
                    break
                # Redraw over the whole range rather than the same stratum.
                low, high = 0.0, total

        if not experiences:
            return self._empty_sample()

        batch_idx = np.asarray(indices, dtype=np.int32)
        probabilities = np.asarray(priorities, dtype=np.float64) / total
        # The 1e-8 term keeps the power finite if a probability is zero.
        weights = np.power(self.tree.n_entries * probabilities + 1e-8, -beta).astype(np.float32)
        largest = weights.max()
        if largest > 0:
            weights /= largest

        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = zip(*experiences)

        states_viewcone = torch.from_numpy(np.array(states_viewcone)).float().to(DEVICE)
        states_other = torch.from_numpy(np.array(states_other)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(actions)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(DEVICE)
        next_states_viewcone = torch.from_numpy(np.array(next_states_viewcone)).float().to(DEVICE)
        next_states_other = torch.from_numpy(np.array(next_states_other)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(DEVICE)

        return (states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones), batch_idx, torch.from_numpy(weights).float().to(DEVICE)

    def update_priorities(self, batch_indices, td_errors):
        """Set the priority of each sampled leaf to ``(|td_error| + PER_EPSILON) ** alpha``."""
        if len(batch_indices) == 0:
            return
        priorities = np.power(np.abs(np.asarray(td_errors)) + PER_EPSILON, self.alpha)
        for idx, priority_val in zip(batch_indices, priorities):
            self.tree.update(idx, priority_val)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, float(np.max(priorities)))

    def __len__(self):
        """Number of transitions currently stored."""
        return self.tree.n_entries

class CNNDQN(nn.Module):
    """Q-network: two conv layers over the viewcone, then an MLP head.

    The flattened convolutional features are concatenated with the flat
    "other" feature vector before the first fully connected layer. The
    output is one Q-value per action.
    """

    def __init__(self, viewcone_channels, viewcone_height, viewcone_width, other_features_size, mlp_hidden1, mlp_hidden2, num_actions, dropout_rate):
        super().__init__()
        pad = 1  # both convolutions use the same padding

        def conv_out(size, kernel, stride):
            # Spatial size after one convolution with the padding above.
            return (size + 2 * pad - kernel) // stride + 1

        # Convolutional trunk.
        self.conv1 = nn.Conv2d(
            viewcone_channels,
            CNN_OUTPUT_CHANNELS_1,
            kernel_size=KERNEL_SIZE_1,
            stride=STRIDE_1,
            padding=pad,
        )
        self.relu_conv1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            CNN_OUTPUT_CHANNELS_1,
            CNN_OUTPUT_CHANNELS_2,
            kernel_size=KERNEL_SIZE_2,
            stride=STRIDE_2,
            padding=pad,
        )
        self.relu_conv2 = nn.ReLU()

        # Track the feature-map size so the first linear layer can be sized.
        out_h = conv_out(conv_out(viewcone_height, KERNEL_SIZE_1[0], STRIDE_1), KERNEL_SIZE_2[0], STRIDE_2)
        out_w = conv_out(conv_out(viewcone_width, KERNEL_SIZE_1[1], STRIDE_1), KERNEL_SIZE_2[1], STRIDE_2)
        self.cnn_output_flat_size = CNN_OUTPUT_CHANNELS_2 * out_h * out_w

        # MLP head over [conv features, other features].
        self.fc1_mlp = nn.Linear(self.cnn_output_flat_size + other_features_size, mlp_hidden1)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2_mlp = nn.Linear(mlp_hidden1, mlp_hidden2)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc_output = nn.Linear(mlp_hidden2, num_actions)

    def forward(self, viewcone_input, other_features_input):
        """Return Q-values for a batch of (viewcone, other-features) inputs."""
        features = self.conv1(viewcone_input)
        features = self.relu_conv1(features)
        features = self.conv2(features)
        features = self.relu_conv2(features)
        flat = features.view(-1, self.cnn_output_flat_size)

        hidden = torch.cat((flat, other_features_input), dim=1)
        hidden = self.dropout1(self.relu_fc1(self.fc1_mlp(hidden)))
        hidden = self.dropout2(self.relu_fc2(self.fc2_mlp(hidden)))
        return self.fc_output(hidden)

class TrainableRLAgent:
    """Double-DQN agent with a CNN Q-network and prioritized experience replay.

    Holds a policy network (trained), a target network (synchronised on
    request via ``update_target_net``), an Adam optimizer and the replay
    buffer. Observations are converted to network inputs by
    ``process_observation``.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_cnn_dqn_model.pth"):
        """Build both networks, optionally loading policy weights from disk.

        If ``model_load_path`` is missing or cannot be loaded, the policy
        network is initialised with Xavier-uniform weights instead.
        """
        self.device = DEVICE
        self.policy_net = CNNDQN(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
                                 OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
                                 MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE).to(self.device)
        self.target_net = CNNDQN(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
                                 OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
                                 MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE).to(self.device)

        if model_load_path and os.path.exists(model_load_path):
            try:
                print(f"Loading model from {model_load_path}")
                self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))
            except Exception as e:
                # Best effort: an unreadable/incompatible checkpoint falls back to a fresh model.
                print(f"Warning: Error loading model from {model_load_path}: {e}. Initializing new model.")
                self.policy_net.apply(self._initialize_weights)
        else:
            if model_load_path:
                print(f"Warning: Model file not found at {model_load_path}. Initializing new model.")
            else:
                print("No model_load_path specified. Initializing new model.")
            self.policy_net.apply(self._initialize_weights)

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # the target network is never trained directly
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)
        self.model_save_path = model_save_path
        self.t_step_episode = 0
        self.beta = PER_BETA_START
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero biases for Linear/Conv2d modules."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return the low ``VIEWCONE_CHANNELS`` bits of ``tile_value`` as floats, LSB first."""
        return [float((tile_value >> i) & 1) for i in range(VIEWCONE_CHANNELS)]

    def process_observation(self, observation_dict):
        """Convert a raw observation dict into network inputs.

        Returns ``(viewcone_planes, other_features)``:

        * ``viewcone_planes`` - float32 array of shape
          ``(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH)``; plane ``k``
          holds bit ``k`` of every tile.
        * ``other_features`` - float32 vector: direction one-hot (4),
          location divided by the map size (2), scout flag (1), step divided
          by ``MAX_STEPS_PER_EPISODE`` (1).

        Missing keys fall back to zeros. A viewcone of a different 2-D shape
        is cropped / zero-padded to the expected size.
        """
        raw_viewcone = observation_dict.get("viewcone", np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8))
        if not isinstance(raw_viewcone, np.ndarray):
            raw_viewcone = np.array(raw_viewcone)
        if raw_viewcone.shape != (VIEWCONE_HEIGHT, VIEWCONE_WIDTH):
            padded_viewcone = np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8)
            h, w = raw_viewcone.shape
            h_min, w_min = min(h, VIEWCONE_HEIGHT), min(w, VIEWCONE_WIDTH)
            padded_viewcone[:h_min, :w_min] = raw_viewcone[:h_min, :w_min]
            raw_viewcone = padded_viewcone

        # Unpack all bit planes at once: plane k = (tile >> k) & 1.
        tiles = raw_viewcone.astype(np.int64)
        shifts = np.arange(VIEWCONE_CHANNELS, dtype=np.int64)[:, np.newaxis, np.newaxis]
        processed_viewcone_channels_data = ((tiles[np.newaxis, :, :] >> shifts) & 1).astype(np.float32)

        direction = observation_dict.get("direction", 0)
        direction_one_hot = [0.0] * 4
        direction_one_hot[direction % 4] = 1.0

        location = observation_dict.get("location", [0, 0])
        norm_x = location[0] / MAP_SIZE_X
        norm_y = location[1] / MAP_SIZE_Y

        other_features_list = [
            *direction_one_hot,
            norm_x,
            norm_y,
            float(observation_dict.get("scout", 0)),
            observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE,
        ]
        state_other_np = np.array(other_features_list, dtype=np.float32)

        return processed_viewcone_channels_data, state_other_np

    def select_action(self, state_viewcone_np, state_other_np, epsilon=0.0):
        """Epsilon-greedy action: random with probability ``epsilon``, else argmax Q."""
        if random.random() > epsilon:
            state_viewcone_tensor = torch.from_numpy(state_viewcone_np).float().unsqueeze(0).to(self.device)
            state_other_tensor = torch.from_numpy(state_other_np).float().unsqueeze(0).to(self.device)
            # Evaluate without dropout, then restore training mode.
            self.policy_net.eval()
            with torch.no_grad():
                action_values = self.policy_net(state_viewcone_tensor, state_other_tensor)
            self.policy_net.train()
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store a transition and, every ``UPDATE_EVERY`` calls, run one learning update."""
        self.memory.add(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.t_step_episode = (self.t_step_episode + 1) % UPDATE_EVERY
        if self.t_step_episode == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            if experiences[0].nelement() > 0:
                self.learn(experiences, indices, weights, GAMMA)
            # Anneal the importance-sampling exponent towards 1.0.
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """Run one Double-DQN update on a sampled batch and refresh its priorities.

        The policy network chooses the next action and the target network
        evaluates it. The per-sample squared TD error is scaled by the
        importance-sampling weights before averaging.
        """
        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = experiences

        # sample() signals "nothing to learn from" with empty tensors.
        if states_viewcone.nelement() == 0:
            return

        # Targets are constants with respect to the loss; no graph is needed.
        with torch.no_grad():
            q_next_actions_policy = self.policy_net(next_states_viewcone, next_states_other).max(1)[1].unsqueeze(1)
            q_targets_next = self.target_net(next_states_viewcone, next_states_other).gather(1, q_next_actions_policy)
            q_targets = rewards + (gamma * q_targets_next * (1 - dones))

        q_expected = self.policy_net(states_viewcone, states_other).gather(1, actions)

        td_errors = (q_targets - q_expected).abs().cpu().detach().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)

        # The weights arrive as shape (B,) while the element-wise loss is (B, 1).
        # Without the reshape the product broadcasts to (B, B) and every sample
        # ends up scaled by the mean weight instead of its own.
        per_sample_weights = importance_sampling_weights.reshape(-1, 1)
        elementwise_loss = nn.MSELoss(reduction='none')(q_expected, q_targets)
        loss = (per_sample_weights * elementwise_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self):
        """Write the policy network's state dict to ``model_save_path`` (if set)."""
        if self.model_save_path:
            torch.save(self.policy_net.state_dict(), self.model_save_path)
            print(f"Model saved to {self.model_save_path}")

    def reset_episode_counters(self):
        """Reset the update-interval counter at the start of an episode."""
        self.t_step_episode = 0


def train_agent(env_module, num_episodes=100000, novice_track=False, load_model_from=None, save_model_to="trained_cnn_dqn_agent.pth"):
    """Train a single shared CNN-DQN policy in a multi-agent environment.

    The first entry of ``env.possible_agents`` is the learning agent: only
    its transitions are stored and learned from, and only its steps advance
    epsilon decay and target-network syncs. Every other agent acts with the
    same network at a fixed exploration rate.

    Args:
        env_module: module exposing ``env(env_wrappers=..., render_mode=...,
            novice=...)``. The returned object is driven through
            ``reset`` / ``agent_iter`` / ``last`` / ``step`` / ``close``
            (presumably a PettingZoo AEC environment - confirm against the
            environment package).
        num_episodes: number of episodes to run.
        novice_track: forwarded to the environment constructor as ``novice``.
        load_model_from: optional checkpoint path to initialise the policy.
        save_model_to: checkpoint path written every 100 episodes and at the
            end; falsy disables saving.
    """
    # Exploration rate for the agents that do not learn.
    other_agents_epsilon = 0.2

    print(f"Starting CNN DQN training: {num_episodes} episodes, Novice: {novice_track}")
    if load_model_from:
        print(f"Attempting to load model from: {load_model_from}")
    print(f"Models will be saved to: {save_model_to}")
    print(f"Using device: {DEVICE}")
    print(f"Epsilon for learning agent will start at {EPSILON_START:.4f} and decay towards {EPSILON_END:.4f}")

    agent = TrainableRLAgent(model_load_path=load_model_from, model_save_path=save_model_to)
    scores_deque = deque(maxlen=100)
    epsilon = EPSILON_START
    global_total_steps = 0

    env = env_module.env(env_wrappers=[], render_mode=None, novice=novice_track)

    try:
        # my_agent_id is the agent that learns; the others share its policy.
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"
        print(f"Primary learning agent ID: {my_agent_id}")
        other_agent_ids = [ag_id for ag_id in env.possible_agents if ag_id != my_agent_id]
        if other_agent_ids:
            print(f"Other agents ({other_agent_ids}) will use {my_agent_id}'s policy with epsilon={other_agents_epsilon}.")

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            agent.reset_episode_counters()
            current_episode_rewards_accumulator = {ag_id: 0.0 for ag_id in env.possible_agents}

            # Holds (prev_s_vc, prev_s_other, prev_a) for my_agent_id until the
            # reward and next observation for that action become available.
            last_processed_exp_learning_agent = {}

            for agent_id_turn in env.agent_iter():
                current_obs_raw, reward_for_current_agent_turn, termination, truncation, info = env.last()

                if agent_id_turn in current_episode_rewards_accumulator:
                    current_episode_rewards_accumulator[agent_id_turn] += reward_for_current_agent_turn

                done = termination or truncation
                action_to_take = None  # a finished agent is stepped with None

                if done:
                    if agent_id_turn == my_agent_id and my_agent_id in last_processed_exp_learning_agent:
                        # Close out the pending transition with the terminal observation.
                        prev_s_vc, prev_s_other, prev_a = last_processed_exp_learning_agent.pop(my_agent_id)
                        obs_dict_terminal = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in current_obs_raw.items()}
                        terminal_s_vc, terminal_s_other = agent.process_observation(obs_dict_terminal)
                        agent.step(prev_s_vc, prev_s_other, prev_a, reward_for_current_agent_turn, terminal_s_vc, terminal_s_other, True)
                else:
                    obs_dict_current = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in current_obs_raw.items()}
                    current_s_vc, current_s_other = agent.process_observation(obs_dict_current)

                    if agent_id_turn == my_agent_id:
                        if my_agent_id in last_processed_exp_learning_agent:
                            # The previous action's outcome is now known: store it.
                            prev_s_vc, prev_s_other, prev_a = last_processed_exp_learning_agent.pop(my_agent_id)
                            agent.step(prev_s_vc, prev_s_other, prev_a, reward_for_current_agent_turn, current_s_vc, current_s_other, False)

                        action_to_take = agent.select_action(current_s_vc, current_s_other, epsilon)
                        last_processed_exp_learning_agent[my_agent_id] = (current_s_vc, current_s_other, action_to_take)

                        global_total_steps += 1
                        if global_total_steps > MIN_EPSILON_FRAMES and epsilon > EPSILON_END:
                            epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY_RATE)
                        if global_total_steps % TARGET_UPDATE_EVERY == 0:
                            agent.update_target_net()
                    else:
                        # Other agents act with the shared policy; nothing is stored.
                        action_to_take = agent.select_action(current_s_vc, current_s_other, epsilon=other_agents_epsilon)

                env.step(action_to_take)

            episode_score_my_agent = current_episode_rewards_accumulator.get(my_agent_id, 0.0)
            scores_deque.append(episode_score_my_agent)

            if i_episode % 100 == 0:
                avg_score_str = f"{np.mean(scores_deque):.2f}" if scores_deque else "N/A"
                print(f'\rEp {i_episode}\tAvgScore({my_agent_id}, last 100): {avg_score_str}\tEps: {epsilon:.4f}\tGlobalSteps: {global_total_steps}\tBeta: {agent.beta:.4f}')
                if save_model_to:
                    agent.save_model()
    finally:
        # Release the environment even if training is interrupted by an error.
        env.close()

    if save_model_to:
        agent.save_model()
    print(f"\nCNN DQN Training finished. Final model saved to {save_model_to if save_model_to else 'N/A'}")

# Script entry point: run one training session and report its wall-clock duration.
if __name__ == '__main__':
    training_start_time = time.time()
    print(f"Initiating CNN DQN training at {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(training_start_time))} UTC")
    try:
        # Imported here so a missing environment package is reported by the
        # ImportError handler below instead of failing at module import.
        from til_environment import gridworld # Ensure this is your environment module
        
        NUM_OVERNIGHT_EPISODES = 100000 
        # --- MODIFIED PATHS ---
        # Checkpoint to resume from, and the file this session's checkpoints are written to.
        LOAD_MODEL_PATH = "my_wargame_cnn_agent_35500.pth" 
        SAVE_MODEL_PATH = "all_cnn_agent_135500.pth" # New save path
        # --- END MODIFIED PATHS ---
        NOVICE_MODE = False 

        train_agent(
            gridworld, 
            num_episodes=NUM_OVERNIGHT_EPISODES,
            novice_track=NOVICE_MODE,
            load_model_from=LOAD_MODEL_PATH, 
            save_model_to=SAVE_MODEL_PATH
        )

    except ImportError:
        print("Could not import 'til_environment.gridworld'. Ensure it's accessible and contains your PettingZoo environment.")
    except FileNotFoundError as fnf_error:
        print(f"Error: Model file not found. {fnf_error}")
    except Exception as e:
        # Top-level boundary: report the failure with a traceback rather than
        # propagating it, so the timing summary below is still printed.
        print(f"An error occurred during CNN DQN training: {e}")
        import traceback
        traceback.print_exc()
    
    # Printed after success and after any handled failure above.
    total_time_seconds = time.time() - training_start_time
    print(f"Total CNN DQN training time for this session: {total_time_seconds:.2f} seconds ({total_time_seconds/3600:.2f} hours).")

Initiating CNN DQN training at 2025-05-23 08:31:20 UTC
Starting CNN DQN training: 100000 episodes, Novice: False
Attempting to load model from: my_wargame_cnn_agent_35500.pth
Models will be saved to: all_cnn_agent_135500.pth
Using device: cuda
Epsilon for learning agent will start at 0.8000 and decay towards 0.1000
Loading model from my_wargame_cnn_agent_35500.pth
Primary learning agent ID: player_0
Other agents (['player_1', 'player_2', 'player_3']) will use player_0's policy greedily (epsilon=0.0).
Ep 100	AvgScore(player_0, last 100): 0.56	Eps: 0.8000	GlobalSteps: 7385	Beta: 0.4109
Model saved to all_cnn_agent_135500.pth
Ep 200	AvgScore(player_0, last 100): 4.38	Eps: 0.4618	GlobalSteps: 15494	Beta: 0.4229
Model saved to all_cnn_agent_135500.pth
Ep 300	AvgScore(player_0, last 100): 7.39	Eps: 0.2311	GlobalSteps: 22418	Beta: 0.4332
Model saved to all_cnn_agent_135500.pth
Ep 400	AvgScore(player_0, last 100): 10.51	Eps: 0.1141	GlobalSteps: 29471	Beta: 0.4436
Model saved to all_cnn_agent_1