In [None]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time # For timing training
import imageio # For video export

# --- Configuration ---
# Environment geometry and episode horizon.
MAP_SIZE_X, MAP_SIZE_Y = 16, 16
MAX_STEPS_PER_EPISODE = 100

# Observation layout: bit-plane viewcone plus a flat vector of other features.
VIEWCONE_CHANNELS = 8
VIEWCONE_HEIGHT, VIEWCONE_WIDTH = 7, 5
OTHER_FEATURES_SIZE = sum((4, 2, 1, 1))  # summed widths of the non-viewcone feature groups

# Network architecture.
CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2 = 16, 32
KERNEL_SIZE_1, STRIDE_1 = (3, 3), 1
KERNEL_SIZE_2, STRIDE_2 = (3, 3), 1
MLP_HIDDEN_LAYER_1_SIZE = MLP_HIDDEN_LAYER_2_SIZE = 128
OUTPUT_ACTIONS = 5
DROPOUT_RATE = 0.2

# Replay buffer and optimisation.
BUFFER_SIZE = 100_000
BATCH_SIZE = 64
GAMMA = 0.99
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
TARGET_UPDATE_EVERY = 1_000  # Global steps between target-network syncs
UPDATE_EVERY = 4  # Agent steps (within an episode) between learning updates

# Epsilon-greedy exploration schedule.
EPSILON_START, EPSILON_END = 1.0, 0.05
EPSILON_DECAY_RATE = 0.999  # Multiplicative decay, applied per global step after MIN_EPSILON_FRAMES
MIN_EPSILON_FRAMES = 50_000

# Prioritized experience replay.
PER_ALPHA = 0.6
PER_BETA_START = 0.4
PER_BETA_FRAMES = 100_000  # Length of the beta annealing schedule
PER_EPSILON = 1e-6

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
VIDEO_OUTPUT_DIR = "training_videos_cnn_dqn"  # Directory for videos

# --- SumTree and PrioritizedReplayBuffer (largely unchanged) ---
# Replay-buffer record for a single transition; the state and the next state
# are each stored as a (viewcone, other) pair.
Experience = namedtuple(
    "Experience",
    (
        "state_viewcone",
        "state_other",
        "action",
        "reward",
        "next_state_viewcone",
        "next_state_other",
        "done",
    ),
)

class SumTree:
    """Array-backed binary sum tree for proportional priority sampling.

    Leaves hold per-item priorities and every internal node holds the sum of
    its two children, so ``tree[0]`` is the total priority.  The tree is stored
    heap-style in a flat array of ``2 * capacity - 1`` nodes; the leaf for data
    slot ``i`` lives at index ``i + capacity - 1``.  Data slots are written
    round-robin, so the oldest entry is overwritten once the tree is full.
    """

    def __init__(self, capacity):
        """Create an empty tree with room for ``capacity`` items.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # Unwritten slots keep the placeholder 0 (object dtype).
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` in the next slot (round-robin) with the given priority."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and refresh every ancestor sum.

        Ancestors are recomputed from their two children instead of being
        adjusted by a delta, so rounding error cannot accumulate in the
        internal nodes over many updates.
        """
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            left_child_idx = 2 * tree_idx + 1
            self.tree[tree_idx] = self.tree[left_child_idx] + self.tree[left_child_idx + 1]

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf whose cumulative range contains ``value``.

        ``value`` is expected to lie in ``[0, total_priority]``.  The descent
        never enters a subtree whose total is zero while its sibling holds
        priority mass, so a value that is exactly 0 or that slightly
        overshoots the total (floating-point rounding) still resolves to a
        written leaf rather than to an empty slot.
        """
        n_nodes = len(self.tree)
        node_idx = 0
        while True:
            left_child_idx = 2 * node_idx + 1
            if left_child_idx >= n_nodes:
                break  # node_idx is a leaf
            right_child_idx = left_child_idx + 1
            left_total = self.tree[left_child_idx]
            right_total = self.tree[right_child_idx]
            go_left = right_total <= 0 or (left_total > 0 and value <= left_total)
            if go_left:
                node_idx = left_child_idx
            else:
                value -= left_total
                node_idx = right_child_idx
        data_idx = node_idx - self.capacity + 1
        return node_idx, self.tree[node_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.tree[0]

class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay on top of a ``SumTree``.

    New transitions are inserted with the largest priority seen so far.
    Priorities written back by :meth:`update_priorities` are stored as
    ``(|td_error| + PER_EPSILON) ** alpha``.
    """

    # Upper bound on re-draws when a lookup lands on a slot that was never written.
    _MAX_RESAMPLE_ATTEMPTS = 16

    def __init__(self, capacity, alpha=PER_ALPHA):
        """Create a buffer holding at most ``capacity`` transitions."""
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store one transition with the current maximum priority."""
        experience = Experience(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.tree.add(self.max_priority, experience)

    def _draw(self, low, high, total):
        """Draw one stored transition, first from ``[low, high)`` and then from the whole range.

        The fallback only triggers when the tree lookup resolves to a slot that
        holds no ``Experience`` (possible through floating-point rounding in
        the tree sums); without it the caller would fail while unpacking.
        """
        index, priority, data = self.tree.get_leaf(np.random.uniform(low, high))
        attempts = 0
        while not isinstance(data, Experience) or priority <= 0:
            if attempts >= self._MAX_RESAMPLE_ATTEMPTS:
                raise RuntimeError("could not draw a stored transition from the priority tree")
            index, priority, data = self.tree.get_leaf(np.random.uniform(0.0, total))
            attempts += 1
        return index, priority, data

    def sample(self, batch_size, beta=PER_BETA_START):
        """Sample a stratified batch and its importance-sampling weights.

        The total priority is split into ``batch_size`` equal segments and one
        value is drawn uniformly from each.  Weights are
        ``(n_entries * P(i) + 1e-8) ** -beta`` divided by the batch maximum.

        Returns:
            ``(experiences, batch_idx, weights)`` where ``experiences`` is a
            7-tuple of tensors on ``DEVICE``, ``batch_idx`` holds the tree leaf
            indices to pass to :meth:`update_priorities`, and ``weights`` is a
            float tensor on ``DEVICE``.

        Raises:
            ValueError: if ``batch_size`` is not positive or the buffer is empty.
            RuntimeError: if no stored transition could be drawn.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        total = self.tree.total_priority
        if self.tree.n_entries == 0 or total <= 0:
            raise ValueError("cannot sample from an empty replay buffer")

        batch_idx = np.empty(batch_size, dtype=np.int32)
        weights = np.empty(batch_size, dtype=np.float32)
        batch = []
        segment = total / batch_size

        for i in range(batch_size):
            index, priority, data = self._draw(segment * i, segment * (i + 1), total)
            sampling_probability = priority / total
            weights[i] = np.power(self.tree.n_entries * sampling_probability + 1e-8, -beta)
            batch_idx[i] = index
            batch.append(data)

        max_weight = weights.max()
        if max_weight > 0:
            weights /= max_weight

        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = zip(*batch)
        states_viewcone = torch.from_numpy(np.array(states_viewcone)).float().to(DEVICE)
        states_other = torch.from_numpy(np.array(states_other)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(actions)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(DEVICE)
        next_states_viewcone = torch.from_numpy(np.array(next_states_viewcone)).float().to(DEVICE)
        next_states_other = torch.from_numpy(np.array(next_states_other)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(DEVICE)

        experiences = (states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones)
        return experiences, batch_idx, torch.from_numpy(weights).float().to(DEVICE)

    def update_priorities(self, batch_indices, td_errors):
        """Write new priorities for the sampled leaves and track the running maximum.

        Args:
            batch_indices: Tree leaf indices as returned by :meth:`sample`.
            td_errors: One TD error per index; only the magnitude is used.
        """
        priorities = np.power(np.abs(td_errors) + PER_EPSILON, self.alpha)
        for idx, priority_val in zip(batch_indices, priorities):
            self.tree.update(idx, priority_val)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, priorities.max())

    def __len__(self):
        """Number of transitions currently stored."""
        return self.tree.n_entries

# --- CNN-DQN Model ---
class CNNDQN(nn.Module):
    """Q-network that fuses a two-layer CNN over the viewcone with a flat feature vector.

    The viewcone goes through two padded convolutions with ReLU, is flattened
    and concatenated with the other features, and then passes through two
    hidden linear layers (ReLU + dropout) before the per-action output layer.
    """

    def __init__(self, viewcone_channels, viewcone_height, viewcone_width, other_features_size, mlp_hidden1, mlp_hidden2, num_actions, dropout_rate):
        super().__init__()
        pad = 1

        def conv_out(size, kernel, stride):
            # Output length of a padded convolution along one spatial axis.
            return (size + 2 * pad - kernel) // stride + 1

        # Convolutional trunk.
        self.conv1 = nn.Conv2d(viewcone_channels, CNN_OUTPUT_CHANNELS_1, kernel_size=KERNEL_SIZE_1, stride=STRIDE_1, padding=pad)
        self.relu_conv1 = nn.ReLU()
        self.conv2 = nn.Conv2d(CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2, kernel_size=KERNEL_SIZE_2, stride=STRIDE_2, padding=pad)
        self.relu_conv2 = nn.ReLU()

        # Spatial size after both convolutions determines the flattened width.
        out_height = conv_out(conv_out(viewcone_height, KERNEL_SIZE_1[0], STRIDE_1), KERNEL_SIZE_2[0], STRIDE_2)
        out_width = conv_out(conv_out(viewcone_width, KERNEL_SIZE_1[1], STRIDE_1), KERNEL_SIZE_2[1], STRIDE_2)
        self.cnn_output_flat_size = CNN_OUTPUT_CHANNELS_2 * out_height * out_width

        # MLP head over the concatenated CNN and other features.
        fused_width = self.cnn_output_flat_size + other_features_size
        self.fc1_mlp = nn.Linear(fused_width, mlp_hidden1)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2_mlp = nn.Linear(mlp_hidden1, mlp_hidden2)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc_output = nn.Linear(mlp_hidden2, num_actions)

    def forward(self, viewcone_input, other_features_input):
        """Return one Q-value per action for each item in the batch."""
        features = self.conv1(viewcone_input)
        features = self.relu_conv1(features)
        features = self.conv2(features)
        features = self.relu_conv2(features)

        flat = features.view(-1, self.cnn_output_flat_size)
        fused = torch.cat((flat, other_features_input), dim=1)

        hidden = self.dropout1(self.relu_fc1(self.fc1_mlp(fused)))
        hidden = self.dropout2(self.relu_fc2(self.fc2_mlp(hidden)))
        return self.fc_output(hidden)

# --- Trainable RL Agent ---
class TrainableRLAgent:
    """Double-DQN agent with a CNN Q-network and prioritized experience replay.

    The agent owns a policy network, a target network, an Adam optimizer and a
    ``PrioritizedReplayBuffer``.  Callers feed it transitions through
    :meth:`step` and call :meth:`update_target_net` on their own schedule.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_cnn_dqn_model.pth"):
        """Build both networks, optionally restoring the policy weights from disk.

        Args:
            model_load_path: Optional path to a ``state_dict`` checkpoint.  If
                the file is missing or cannot be loaded, a warning is printed
                and the policy network is Xavier-initialised instead.
            model_save_path: Destination used by :meth:`save_model`; a falsy
                value disables saving.
        """
        self.device = DEVICE
        self.policy_net = self._build_network()
        self.target_net = self._build_network()

        loaded = False
        if model_load_path:
            if os.path.exists(model_load_path):
                try:
                    # NOTE(security): torch.load unpickles the file; only load
                    # checkpoints that come from a trusted source.
                    state_dict = torch.load(model_load_path, map_location=self.device)
                    self.policy_net.load_state_dict(state_dict)
                    loaded = True
                except Exception as e:
                    print(f"Warning: Error loading model from {model_load_path}: {e}. Initializing new model.")
            else:
                # Previously a wrong path was ignored without any message.
                print(f"Warning: Model file {model_load_path} not found. Initializing new model.")
        if not loaded:
            self.policy_net.apply(self._initialize_weights)

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)
        self.model_save_path = model_save_path
        self.t_step_episode = 0  # Counts steps modulo UPDATE_EVERY
        self.beta = PER_BETA_START
        # Added to beta after every learning update (see step()).
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0

    def _build_network(self):
        """Instantiate one CNNDQN with the module-level architecture settings on ``self.device``."""
        return CNNDQN(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
                      OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
                      MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE).to(self.device)

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero biases for linear and convolutional layers."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return the lowest ``VIEWCONE_CHANNELS`` bits of one tile as floats, least significant first."""
        return [float((tile_value >> i) & 1) for i in range(VIEWCONE_CHANNELS)]

    @staticmethod
    def _fit_viewcone(raw_viewcone):
        """Crop or zero-pad ``raw_viewcone`` to (VIEWCONE_HEIGHT, VIEWCONE_WIDTH).

        Anything that is not 2-D cannot be aligned and yields an all-zero grid.
        """
        expected_shape = (VIEWCONE_HEIGHT, VIEWCONE_WIDTH)
        if raw_viewcone.shape == expected_shape:
            return raw_viewcone
        fitted = np.zeros(expected_shape, dtype=np.uint8)
        if raw_viewcone.ndim == 2:
            rows = min(raw_viewcone.shape[0], VIEWCONE_HEIGHT)
            cols = min(raw_viewcone.shape[1], VIEWCONE_WIDTH)
            fitted[:rows, :cols] = raw_viewcone[:rows, :cols]
        return fitted

    def process_observation(self, observation_dict):
        """Convert a raw observation dict into the two network inputs.

        Reads the keys ``viewcone``, ``direction``, ``location``, ``scout`` and
        ``step``; missing keys fall back to zeros.

        Returns:
            ``(viewcone_planes, other_features)``: a float32 array of shape
            (VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH) holding one
            bit-plane per channel, and a float32 vector made of the direction
            one-hot (4), the normalised location (2), the scout flag (1) and
            the normalised step (1).
        """
        default_viewcone = np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8)
        raw_viewcone = np.asarray(observation_dict.get("viewcone", default_viewcone))
        viewcone = self._fit_viewcone(raw_viewcone).astype(np.int64)

        # Vectorised form of _unpack_viewcone_tile: plane i holds bit i of every tile.
        bit_shifts = np.arange(VIEWCONE_CHANNELS, dtype=np.int64).reshape(-1, 1, 1)
        processed_viewcone_channels_data = ((viewcone[np.newaxis, :, :] >> bit_shifts) & 1).astype(np.float32)

        direction = observation_dict.get("direction", 0)
        direction_one_hot = [0.0] * 4
        direction_one_hot[direction % 4] = 1.0

        location = observation_dict.get("location", [0, 0])
        other_features_list = [
            *direction_one_hot,
            location[0] / MAP_SIZE_X,
            location[1] / MAP_SIZE_Y,
            float(observation_dict.get("scout", 0)),
            observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE,
        ]
        state_other_np = np.array(other_features_list, dtype=np.float32)

        return processed_viewcone_channels_data, state_other_np

    def select_action(self, state_viewcone_np, state_other_np, epsilon=0.0):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the arg-max of the policy network's Q-values, computed with
        dropout disabled.  The network's previous train/eval mode is restored
        afterwards instead of unconditionally switching it back to train mode.
        """
        if random.random() > epsilon:
            state_viewcone_tensor = torch.from_numpy(state_viewcone_np).float().unsqueeze(0).to(self.device)
            state_other_tensor = torch.from_numpy(state_other_np).float().unsqueeze(0).to(self.device)
            was_training = self.policy_net.training
            self.policy_net.eval()
            try:
                with torch.no_grad():
                    action_values = self.policy_net(state_viewcone_tensor, state_other_tensor)
            finally:
                self.policy_net.train(was_training)
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store a transition and learn on every ``UPDATE_EVERY``-th call.

        Learning only happens once the buffer holds more than ``BATCH_SIZE``
        transitions; beta is annealed towards 1.0 after each learning update.
        """
        self.memory.add(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.t_step_episode = (self.t_step_episode + 1) % UPDATE_EVERY
        if self.t_step_episode == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            self.learn(experiences, indices, weights, GAMMA)
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """Run one Double-DQN update on a sampled batch.

        The policy network chooses the next action, the target network
        evaluates it, the replay priorities are refreshed from the absolute TD
        errors, and the importance-weighted squared error is minimised with the
        gradient norm clipped to 1.0.
        """
        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = experiences
        # NOTE(review): the policy net is normally in train mode here, so dropout
        # also affects the arg-max below - confirm that this is intended.
        q_next_actions_policy = self.policy_net(next_states_viewcone, next_states_other).detach().max(1)[1].unsqueeze(1)
        q_targets_next = self.target_net(next_states_viewcone, next_states_other).detach().gather(1, q_next_actions_policy)
        q_targets = rewards + (gamma * q_targets_next * (1 - dones))
        q_expected = self.policy_net(states_viewcone, states_other).gather(1, actions)
        td_errors = (q_targets - q_expected).abs().cpu().detach().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)
        loss = (importance_sampling_weights * nn.MSELoss(reduction='none')(q_expected, q_targets)).mean()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self):
        """Write the policy network's ``state_dict`` to ``model_save_path`` (no-op if unset)."""
        if self.model_save_path:
            # Create the destination directory so a fresh output path does not fail the save.
            save_dir = os.path.dirname(self.model_save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.policy_net.state_dict(), self.model_save_path)
            print(f"Model saved to {self.model_save_path}")

    def reset_episode_counters(self):
        """Restart the ``UPDATE_EVERY`` step counter at the beginning of an episode."""
        self.t_step_episode = 0

# --- Main Training Loop ---
def train_agent(env_module, num_episodes=10000, novice_track=False, load_model_from=None, save_model_to="trained_cnn_dqn_agent.pth", video_export_every_n_episodes=100):
    """Train a ``TrainableRLAgent`` in the environment provided by ``env_module``.

    The first entry of ``env.possible_agents`` is controlled by the learning
    agent; every other agent samples random actions from its action space.
    The model is saved every 100 episodes and once more at the end.

    Args:
        env_module: Object exposing ``env(env_wrappers=..., render_mode=...,
            novice=...)``.  The returned environment is driven through
            ``agent_iter``/``last``/``step`` (appears to follow the PettingZoo
            AEC API - confirm against the environment package).
        num_episodes: Number of episodes to run.
        novice_track: Forwarded to the environment as ``novice``.
        load_model_from: Optional checkpoint path to resume from.
        save_model_to: Checkpoint path; a falsy value disables saving.
        video_export_every_n_episodes: Record an MP4 every N episodes; values
            <= 0 disable recording.

    Returns:
        List with the controlled agent's accumulated reward for each episode.
    """
    print(f"Starting CNN DQN training: {num_episodes} episodes, Novice: {novice_track}, Video every: {video_export_every_n_episodes} episodes")
    print(f"Load: {load_model_from}, Save: {save_model_to}")
    print(f"Using device: {DEVICE}")
    if video_export_every_n_episodes > 0 and not os.path.exists(VIDEO_OUTPUT_DIR):
        os.makedirs(VIDEO_OUTPUT_DIR)
        print(f"Created video output directory: {VIDEO_OUTPUT_DIR}")

    agent = TrainableRLAgent(model_load_path=load_model_from, model_save_path=save_model_to)
    scores_deque = deque(maxlen=100)
    scores = []
    epsilon = EPSILON_START
    global_total_steps = 0

    def make_env(render_mode):
        return env_module.env(env_wrappers=[], render_mode=render_mode, novice=novice_track)

    # Start without a render mode for performance; rendering is only enabled
    # for the episodes that are recorded.
    current_render_mode = None
    env = make_env(current_render_mode)
    try:
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"

        for i_episode in range(1, num_episodes + 1):
            episode_frames = []
            record_video_this_episode = (video_export_every_n_episodes > 0 and i_episode % video_export_every_n_episodes == 0)

            # Rebuild the environment only when the required render mode changes.
            wanted_render_mode = "rgb_array" if record_video_this_episode else None
            if wanted_render_mode != current_render_mode:
                env.close()
                env = None  # nothing left to close if the constructor below raises
                env = make_env(wanted_render_mode)
                current_render_mode = wanted_render_mode

            env.reset()
            agent.reset_episode_counters()
            current_episode_rewards = {agent_id: 0.0 for agent_id in env.possible_agents}
            last_processed_exp_my_agent = {}  # Stores (prev_s_vc, prev_s_other, prev_a)

            for pet_agent_id_turn in env.agent_iter():
                obs_raw, _, termination, truncation, _ = env.last()
                # NOTE(review): the reward is read from env.rewards rather than from
                # env.last(); confirm this matches the environment's reward bookkeeping.
                reward_for_last_action = env.rewards.get(my_agent_id, 0.0)

                # Accumulate all rewards for the episode score.
                for r_ag_id, r_val in env.rewards.items():
                    if r_ag_id in current_episode_rewards:
                        current_episode_rewards[r_ag_id] += r_val

                done = termination or truncation
                action_to_take = None

                if pet_agent_id_turn == my_agent_id:
                    obs_dict_current = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in obs_raw.items()}
                    current_s_vc, current_s_other = agent.process_observation(obs_dict_current)

                    # The previous action's transition is completed by the current observation.
                    if my_agent_id in last_processed_exp_my_agent:
                        prev_s_vc, prev_s_other, prev_a = last_processed_exp_my_agent.pop(my_agent_id)
                        agent.step(prev_s_vc, prev_s_other, prev_a, reward_for_last_action, current_s_vc, current_s_other, done)

                    if not done:
                        action_to_take = agent.select_action(current_s_vc, current_s_other, epsilon)
                        last_processed_exp_my_agent[my_agent_id] = (current_s_vc, current_s_other, action_to_take)

                    global_total_steps += 1
                    if global_total_steps > MIN_EPSILON_FRAMES:
                        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY_RATE)
                    if global_total_steps % TARGET_UPDATE_EVERY == 0:
                        agent.update_target_net()

                elif not done and env.action_space(pet_agent_id_turn) is not None:
                    action_to_take = env.action_space(pet_agent_id_turn).sample()

                env.step(action_to_take)
                if record_video_this_episode and current_render_mode == "rgb_array":  # Add frame after step
                    frame = env.render()
                    if frame is not None:
                        episode_frames.append(frame)

            # End of episode
            if record_video_this_episode and episode_frames:
                video_path = os.path.join(VIDEO_OUTPUT_DIR, f"episode_{i_episode:05d}.mp4")
                try:
                    imageio.mimsave(video_path, episode_frames, fps=10)
                except Exception as e:
                    # Video export is best-effort and must not abort training.
                    print(f"Error saving video for episode {i_episode}: {e}")

            episode_score = current_episode_rewards.get(my_agent_id, 0.0)
            scores_deque.append(episode_score)
            scores.append(episode_score)

            if i_episode % 20 == 0:
                print(f'\rEp {i_episode}\tAvgScore(100): {np.mean(scores_deque):.2f}\tEps: {epsilon:.4f}\tGlobalSteps: {global_total_steps}', end="")
            if i_episode % 100 == 0:
                print(f'\rEp {i_episode}\tAvgScore(100): {np.mean(scores_deque):.2f}\tEps: {epsilon:.4f}\tGlobalSteps: {global_total_steps}')
                if save_model_to:
                    agent.save_model()
    finally:
        # Release the environment even when training is interrupted by an error.
        if env is not None:
            env.close()

    if save_model_to:
        agent.save_model()
    print(f"\nCNN DQN Training finished. Final model saved to {save_model_to if save_model_to else 'N/A'}")
    return scores

# Script entry point: train with the settings below and report the wall-clock time.
if __name__ == '__main__':
    training_start_time = time.time()
    print(f"Initiating CNN DQN training at {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(training_start_time))} UTC")
    try:
        # Only the import is covered by the ImportError handler, so an
        # ImportError raised later during training is reported with its
        # traceback instead of being mistaken for a missing environment package.
        try:
            from til_environment import gridworld
        except ImportError:
            print("Could not import 'til_environment.gridworld'. Ensure it's accessible.")
        else:
            NUM_TRAIN_EPISODES = 10000 # Matched to MLP script's typical setting
            VIDEO_EVERY = 200 # Export video every 200 episodes (can be adjusted)

            trained_scores = train_agent(
                gridworld,
                num_episodes=NUM_TRAIN_EPISODES,
                novice_track=False, # Forwarded to the environment as `novice`
                load_model_from=None, # "my_wargame_cnn_agent.pth" to resume
                save_model_to="my_wargame_cnn_agent_final.pth",
                video_export_every_n_episodes=VIDEO_EVERY
            )
    except Exception as e:
        print(f"An error occurred during CNN DQN training: {e}")
        import traceback
        traceback.print_exc()

    total_time_seconds = time.time() - training_start_time
    print(f"Total CNN DQN training time: {total_time_seconds:.2f} seconds ({total_time_seconds/60:.2f} minutes or {total_time_seconds/3600:.2f} hours).")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time

# --- Configuration ---
# Environment geometry and episode horizon.
MAP_SIZE_X, MAP_SIZE_Y = 16, 16
MAX_STEPS_PER_EPISODE = 100

# Observation layout: bit-plane viewcone plus a flat vector of other features.
VIEWCONE_CHANNELS = 8
VIEWCONE_HEIGHT, VIEWCONE_WIDTH = 7, 5
OTHER_FEATURES_SIZE = sum((4, 2, 1, 1))  # summed widths of the non-viewcone feature groups

# Network architecture.
CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2 = 16, 32
KERNEL_SIZE_1, STRIDE_1 = (3, 3), 1
KERNEL_SIZE_2, STRIDE_2 = (3, 3), 1
MLP_HIDDEN_LAYER_1_SIZE = MLP_HIDDEN_LAYER_2_SIZE = 128
OUTPUT_ACTIONS = 5
DROPOUT_RATE = 0.2

# Replay buffer and optimisation.
BUFFER_SIZE = 100_000  # Consider increasing if memory allows for very long runs
BATCH_SIZE = 64
GAMMA = 0.99
LEARNING_RATE = 1e-4  # A smaller rate (e.g. 5e-5) could suit very long fine-tuning
WEIGHT_DECAY = 1e-5
TARGET_UPDATE_EVERY = 1_000  # Global steps between target-network syncs
UPDATE_EVERY = 4  # Agent steps (within an episode) between learning updates

# Epsilon-greedy schedule for a long (overnight) resumed run.
# NOTE(review): the original notes describe resuming from 0.05, the previous
# run's final epsilon, but EPSILON_START is 0.5 - confirm which is intended.
EPSILON_START = 0.5
EPSILON_END = 0.01  # Lower floor for more exploitation
EPSILON_DECAY_RATE = 0.9999  # Slower multiplicative decay so 0.01 is reached gradually
MIN_EPSILON_FRAMES = 10_000  # Global steps before the decay starts being applied

# Prioritized experience replay.
PER_ALPHA = 0.6
PER_BETA_START = 0.4  # Beta and the global step could instead be restored from a checkpoint
PER_BETA_FRAMES = 100_000  # Length of the beta annealing schedule
PER_EPSILON = 1e-6

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# The SumTree, PrioritizedReplayBuffer, CNNDQN and TrainableRLAgent class definitions below are
# the same as in train_cnn_dqn_optimized_resumed_advanced.py. They are essential, so they are
# repeated here in full rather than omitted.

# --- SumTree and PrioritizedReplayBuffer (largely unchanged) ---
# Replay-buffer record for a single transition; the state and the next state
# are each stored as a (viewcone, other) pair.
Experience = namedtuple(
    "Experience",
    (
        "state_viewcone",
        "state_other",
        "action",
        "reward",
        "next_state_viewcone",
        "next_state_other",
        "done",
    ),
)

class SumTree:
    """Array-backed binary sum tree for proportional priority sampling.

    Leaves hold per-item priorities and every internal node holds the sum of
    its two children, so ``tree[0]`` is the total priority.  The tree is stored
    heap-style in a flat array of ``2 * capacity - 1`` nodes; the leaf for data
    slot ``i`` lives at index ``i + capacity - 1``.  Data slots are written
    round-robin, so the oldest entry is overwritten once the tree is full.
    """

    def __init__(self, capacity):
        """Create an empty tree with room for ``capacity`` items.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # Unwritten slots keep the placeholder 0 (object dtype).
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` in the next slot (round-robin) with the given priority."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and refresh every ancestor sum.

        Ancestors are recomputed from their two children instead of being
        adjusted by a delta, so rounding error cannot accumulate in the
        internal nodes over many updates.
        """
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            left_child_idx = 2 * tree_idx + 1
            self.tree[tree_idx] = self.tree[left_child_idx] + self.tree[left_child_idx + 1]

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf whose cumulative range contains ``value``.

        ``value`` is expected to lie in ``[0, total_priority]``.  The descent
        never enters a subtree whose total is zero while its sibling holds
        priority mass, so a value that is exactly 0 or that slightly
        overshoots the total (floating-point rounding) still resolves to a
        written leaf rather than to an empty slot.
        """
        n_nodes = len(self.tree)
        node_idx = 0
        while True:
            left_child_idx = 2 * node_idx + 1
            if left_child_idx >= n_nodes:
                break  # node_idx is a leaf
            right_child_idx = left_child_idx + 1
            left_total = self.tree[left_child_idx]
            right_total = self.tree[right_child_idx]
            go_left = right_total <= 0 or (left_total > 0 and value <= left_total)
            if go_left:
                node_idx = left_child_idx
            else:
                value -= left_total
                node_idx = right_child_idx
        data_idx = node_idx - self.capacity + 1
        return node_idx, self.tree[node_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.tree[0]

class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay on top of a ``SumTree``.

    New transitions are inserted with the largest priority seen so far.
    Priorities written back by :meth:`update_priorities` are stored as
    ``(|td_error| + PER_EPSILON) ** alpha``.
    """

    # Upper bound on re-draws when a lookup lands on a slot that was never written.
    _MAX_RESAMPLE_ATTEMPTS = 16

    def __init__(self, capacity, alpha=PER_ALPHA):
        """Create a buffer holding at most ``capacity`` transitions."""
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store one transition with the current maximum priority."""
        experience = Experience(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.tree.add(self.max_priority, experience)

    def _draw(self, low, high, total):
        """Draw one stored transition, first from ``[low, high)`` and then from the whole range.

        The fallback only triggers when the tree lookup resolves to a slot that
        holds no ``Experience`` (possible through floating-point rounding in
        the tree sums); without it the caller would fail while unpacking.
        """
        index, priority, data = self.tree.get_leaf(np.random.uniform(low, high))
        attempts = 0
        while not isinstance(data, Experience) or priority <= 0:
            if attempts >= self._MAX_RESAMPLE_ATTEMPTS:
                raise RuntimeError("could not draw a stored transition from the priority tree")
            index, priority, data = self.tree.get_leaf(np.random.uniform(0.0, total))
            attempts += 1
        return index, priority, data

    def sample(self, batch_size, beta=PER_BETA_START):
        """Sample a stratified batch and its importance-sampling weights.

        The total priority is split into ``batch_size`` equal segments and one
        value is drawn uniformly from each.  Weights are
        ``(n_entries * P(i) + 1e-8) ** -beta`` divided by the batch maximum.

        Returns:
            ``(experiences, batch_idx, weights)`` where ``experiences`` is a
            7-tuple of tensors on ``DEVICE``, ``batch_idx`` holds the tree leaf
            indices to pass to :meth:`update_priorities`, and ``weights`` is a
            float tensor on ``DEVICE``.

        Raises:
            ValueError: if ``batch_size`` is not positive or the buffer is empty.
            RuntimeError: if no stored transition could be drawn.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        total = self.tree.total_priority
        if self.tree.n_entries == 0 or total <= 0:
            raise ValueError("cannot sample from an empty replay buffer")

        batch_idx = np.empty(batch_size, dtype=np.int32)
        weights = np.empty(batch_size, dtype=np.float32)
        batch = []
        segment = total / batch_size

        for i in range(batch_size):
            index, priority, data = self._draw(segment * i, segment * (i + 1), total)
            sampling_probability = priority / total
            weights[i] = np.power(self.tree.n_entries * sampling_probability + 1e-8, -beta)
            batch_idx[i] = index
            batch.append(data)

        max_weight = weights.max()
        if max_weight > 0:
            weights /= max_weight

        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = zip(*batch)
        states_viewcone = torch.from_numpy(np.array(states_viewcone)).float().to(DEVICE)
        states_other = torch.from_numpy(np.array(states_other)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(actions)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(DEVICE)
        next_states_viewcone = torch.from_numpy(np.array(next_states_viewcone)).float().to(DEVICE)
        next_states_other = torch.from_numpy(np.array(next_states_other)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(DEVICE)

        experiences = (states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones)
        return experiences, batch_idx, torch.from_numpy(weights).float().to(DEVICE)

    def update_priorities(self, batch_indices, td_errors):
        """Write new priorities for the sampled leaves and track the running maximum.

        Args:
            batch_indices: Tree leaf indices as returned by :meth:`sample`.
            td_errors: One TD error per index; only the magnitude is used.
        """
        priorities = np.power(np.abs(td_errors) + PER_EPSILON, self.alpha)
        for idx, priority_val in zip(batch_indices, priorities):
            self.tree.update(idx, priority_val)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, priorities.max())

    def __len__(self):
        """Number of transitions currently stored."""
        return self.tree.n_entries

# --- CNN-DQN Model ---
class CNNDQN(nn.Module):
    """Q-network that fuses a two-layer CNN over the viewcone with a flat feature vector.

    The viewcone goes through two padded convolutions with ReLU, is flattened
    and concatenated with the other features, and then passes through two
    hidden linear layers (ReLU + dropout) before the per-action output layer.
    """

    def __init__(self, viewcone_channels, viewcone_height, viewcone_width, other_features_size, mlp_hidden1, mlp_hidden2, num_actions, dropout_rate):
        super().__init__()
        pad = 1

        def conv_out(size, kernel, stride):
            # Output length of a padded convolution along one spatial axis.
            return (size + 2 * pad - kernel) // stride + 1

        # Convolutional trunk.
        self.conv1 = nn.Conv2d(viewcone_channels, CNN_OUTPUT_CHANNELS_1, kernel_size=KERNEL_SIZE_1, stride=STRIDE_1, padding=pad)
        self.relu_conv1 = nn.ReLU()
        self.conv2 = nn.Conv2d(CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2, kernel_size=KERNEL_SIZE_2, stride=STRIDE_2, padding=pad)
        self.relu_conv2 = nn.ReLU()

        # Spatial size after both convolutions determines the flattened width.
        out_height = conv_out(conv_out(viewcone_height, KERNEL_SIZE_1[0], STRIDE_1), KERNEL_SIZE_2[0], STRIDE_2)
        out_width = conv_out(conv_out(viewcone_width, KERNEL_SIZE_1[1], STRIDE_1), KERNEL_SIZE_2[1], STRIDE_2)
        self.cnn_output_flat_size = CNN_OUTPUT_CHANNELS_2 * out_height * out_width

        # MLP head over the concatenated CNN and other features.
        fused_width = self.cnn_output_flat_size + other_features_size
        self.fc1_mlp = nn.Linear(fused_width, mlp_hidden1)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2_mlp = nn.Linear(mlp_hidden1, mlp_hidden2)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc_output = nn.Linear(mlp_hidden2, num_actions)

    def forward(self, viewcone_input, other_features_input):
        """Return one Q-value per action for each item in the batch."""
        features = self.conv1(viewcone_input)
        features = self.relu_conv1(features)
        features = self.conv2(features)
        features = self.relu_conv2(features)

        flat = features.view(-1, self.cnn_output_flat_size)
        fused = torch.cat((flat, other_features_input), dim=1)

        hidden = self.dropout1(self.relu_fc1(self.fc1_mlp(fused)))
        hidden = self.dropout2(self.relu_fc2(self.fc2_mlp(hidden)))
        return self.fc_output(hidden)

# --- Trainable RL Agent ---
class TrainableRLAgent:
    """DQN agent with a policy/target network pair and prioritized replay.

    The agent owns two ``CNNDQN`` networks (the target network is kept in
    eval mode and only changes through :meth:`update_target_net`), an Adam
    optimizer over the policy network and a ``PrioritizedReplayBuffer``.
    It converts raw observation dicts into network inputs, selects
    epsilon-greedy actions and runs one learning update every
    ``UPDATE_EVERY`` calls to :meth:`step`.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_cnn_dqn_model.pth"):
        """Build the networks, optionally restore weights, set up training state.

        Args:
            model_load_path: Optional path to a saved policy ``state_dict``.
                If the path is falsy, does not exist, or loading fails, the
                policy network is Xavier-initialised instead.
            model_save_path: Destination used by :meth:`save_model`.
        """
        self.device = DEVICE
        network_args = (
            VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
            OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
            MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE,
        )
        self.policy_net = CNNDQN(*network_args).to(self.device)
        self.target_net = CNNDQN(*network_args).to(self.device)

        weights_restored = False
        if model_load_path and os.path.exists(model_load_path):
            try:
                # torch.load is pickle-based; weights_only=True restricts the
                # unpickler to tensors and plain containers, which is all a
                # state_dict needs, so a tampered checkpoint cannot run code.
                state_dict = torch.load(model_load_path, map_location=self.device, weights_only=True)
                self.policy_net.load_state_dict(state_dict)
                weights_restored = True
            except Exception as e:
                # Deliberate best-effort: fall back to a fresh model below.
                print(f"Warning: Error loading model from {model_load_path}: {e}. Initializing new model.")
        if not weights_restored:
            self.policy_net.apply(self._initialize_weights)

        # Target network starts as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)
        self.model_save_path = model_save_path
        self.t_step_episode = 0  # Counts step() calls modulo UPDATE_EVERY
        self.beta = PER_BETA_START
        # Beta grows linearly towards 1.0, one increment per learning update.
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0

    def _initialize_weights(self, m):
        """Xavier-initialise the weights (and zero the bias) of Linear/Conv2d modules."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return the lowest ``VIEWCONE_CHANNELS`` bits of *tile_value* as floats, LSB first."""
        return [float((tile_value >> i) & 1) for i in range(VIEWCONE_CHANNELS)]

    def process_observation(self, observation_dict):
        """Convert a raw observation dict into the two network input arrays.

        Args:
            observation_dict: Mapping read with ``.get`` for the keys
                ``"viewcone"``, ``"direction"``, ``"location"``, ``"scout"``
                and ``"step"``; missing keys fall back to zeros.

        Returns:
            Tuple ``(viewcone_planes, other_features)``:
            ``viewcone_planes`` is a float32 array of shape
            ``(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH)`` where
            plane ``i`` holds bit ``i`` of every tile;
            ``other_features`` is a float32 vector of the one-hot direction
            (mod 4), the location divided by the map size, the scout flag and
            the step divided by ``MAX_STEPS_PER_EPISODE``.
        """
        grid_shape = (VIEWCONE_HEIGHT, VIEWCONE_WIDTH)
        raw_viewcone = np.asarray(observation_dict.get("viewcone", np.zeros(grid_shape, dtype=np.uint8)))
        if raw_viewcone.shape != grid_shape:
            # Crop/zero-pad into the expected grid, anchored at the top-left.
            fitted = np.zeros(grid_shape, dtype=np.uint8)
            if raw_viewcone.ndim == 2:
                rows = min(raw_viewcone.shape[0], VIEWCONE_HEIGHT)
                cols = min(raw_viewcone.shape[1], VIEWCONE_WIDTH)
                fitted[:rows, :cols] = raw_viewcone[:rows, :cols]
            # A viewcone that is not 2-D (e.g. [] or None) cannot be aligned,
            # so it is treated as an all-zero view instead of raising.
            raw_viewcone = fitted

        # Unpack all tiles at once: shift every tile by each bit position and
        # mask the low bit, giving one 0/1 plane per bit.
        tiles = raw_viewcone.astype(np.int64)
        bit_positions = np.arange(VIEWCONE_CHANNELS, dtype=np.int64).reshape(-1, 1, 1)
        viewcone_planes = ((tiles[np.newaxis, :, :] >> bit_positions) & 1).astype(np.float32)

        direction_one_hot = [0.0] * 4
        direction_one_hot[observation_dict.get("direction", 0) % 4] = 1.0
        location = observation_dict.get("location", [0, 0])
        other_features = np.array(
            [
                *direction_one_hot,
                location[0] / MAP_SIZE_X,
                location[1] / MAP_SIZE_Y,
                float(observation_dict.get("scout", 0)),
                observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE,
            ],
            dtype=np.float32,
        )
        return viewcone_planes, other_features

    def select_action(self, state_viewcone_np, state_other_np, epsilon=0.0):
        """Pick an action epsilon-greedily.

        Args:
            state_viewcone_np: Unbatched viewcone array for the policy network.
            state_other_np: Unbatched auxiliary feature vector.
            epsilon: Probability of returning a uniformly random action.

        Returns:
            Index of the chosen action (a NumPy integer).
        """
        if random.random() > epsilon:
            state_viewcone_tensor = torch.from_numpy(state_viewcone_np).float().unsqueeze(0).to(self.device)
            state_other_tensor = torch.from_numpy(state_other_np).float().unsqueeze(0).to(self.device)
            # Evaluate without dropout, then put the network back in train mode.
            self.policy_net.eval()
            with torch.no_grad():
                action_values = self.policy_net(state_viewcone_tensor, state_other_tensor)
            self.policy_net.train()
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store one transition and learn on every ``UPDATE_EVERY``-th call.

        A learning update only happens once the buffer holds more than
        ``BATCH_SIZE`` transitions; each update also advances ``self.beta``
        towards 1.0.
        """
        self.memory.add(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.t_step_episode = (self.t_step_episode + 1) % UPDATE_EVERY
        if self.t_step_episode == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            self.learn(experiences, indices, weights, GAMMA)
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """Run one Double-DQN update on a sampled batch.

        Args:
            experiences: 7-tuple ``(states_viewcone, states_other, actions,
                rewards, next_states_viewcone, next_states_other, dones)`` of
                batched tensors.
            indices: Buffer indices of the sampled transitions, passed back to
                ``self.memory.update_priorities``.
            importance_sampling_weights: Per-sample weights multiplied into the
                element-wise squared error.
            gamma: Discount factor.
        """
        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = experiences

        # Targets are constants for the optimisation step, so compute them
        # without recording an autograd graph.
        with torch.no_grad():
            # Double DQN: the policy net picks the next action, the target net scores it.
            q_next_actions_policy = self.policy_net(next_states_viewcone, next_states_other).max(1)[1].unsqueeze(1)
            q_targets_next = self.target_net(next_states_viewcone, next_states_other).gather(1, q_next_actions_policy)
            q_targets = rewards + (gamma * q_targets_next * (1 - dones))

        q_expected = self.policy_net(states_viewcone, states_other).gather(1, actions)

        # New priorities are the absolute TD errors of this batch.
        td_errors = (q_targets - q_expected).detach().abs().cpu().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)

        # NOTE(review): assumes importance_sampling_weights has the same shape
        # as q_expected; a 1-D weight tensor would broadcast against a
        # (batch, 1) error to (batch, batch) - confirm against
        # PrioritizedReplayBuffer.sample.
        loss = (importance_sampling_weights * nn.MSELoss(reduction='none')(q_expected, q_targets)).mean()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy the policy network's parameters into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self):
        """Write the policy network's ``state_dict`` to ``self.model_save_path`` (if set)."""
        if self.model_save_path:
            torch.save(self.policy_net.state_dict(), self.model_save_path)
            print(f"Model saved to {self.model_save_path}")

    def reset_episode_counters(self):
        """Restart the ``UPDATE_EVERY`` step counter at the start of an episode."""
        self.t_step_episode = 0


# --- Main Training Loop (Same structure as train_cnn_dqn_optimized_resumed_advanced.py) ---
def train_agent(env_module, num_episodes=100000, novice_track=False, load_model_from=None, save_model_to="trained_cnn_dqn_agent.pth"):
    """Train a ``TrainableRLAgent`` on the first agent slot of the environment.

    The first entry of ``env.possible_agents`` is driven by the learning
    agent; every other agent takes a random action sampled from its action
    space. A summary is printed and the model saved every 100 episodes, and
    once more after the last episode.

    Args:
        env_module: Module exposing ``env(env_wrappers=..., render_mode=...,
            novice=...)`` that returns an environment offering the
            ``agent_iter`` / ``last`` / ``step`` / ``rewards`` interface used
            below.
        num_episodes: Number of episodes to run.
        novice_track: Forwarded to the environment as ``novice``.
        load_model_from: Optional checkpoint path handed to the agent.
        save_model_to: Path the agent saves to; falsy disables saving here.
    """
    print(f"Starting CNN DQN training: {num_episodes} episodes, Novice: {novice_track}")
    if load_model_from: print(f"Attempting to load model from: {load_model_from}")
    print(f"Models will be saved to: {save_model_to}")
    print(f"Using device: {DEVICE}")
    print(f"Epsilon will start at {EPSILON_START:.4f} and decay towards {EPSILON_END:.4f}")

    agent = TrainableRLAgent(model_load_path=load_model_from, model_save_path=save_model_to)
    scores_deque = deque(maxlen=100)  # Rolling window for the average score
    # NOTE: only the network weights are restored from a checkpoint. The step
    # counter, epsilon and the agent's PER beta always restart here, so a
    # resumed run replays the exploration schedule from EPSILON_START. Persist
    # and restore these three values alongside the weights to resume exactly.
    epsilon = EPSILON_START
    global_total_steps = 0

    env = env_module.env(env_wrappers=[], render_mode=None, novice=novice_track)
    try:
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            agent.reset_episode_counters()
            current_episode_rewards = {agent_name: 0.0 for agent_name in env.possible_agents}
            # (state_viewcone, state_other, action) of the learner's previous
            # turn, waiting for its reward and successor state.
            pending_transition = None

            for pet_agent_id_turn in env.agent_iter():
                obs_raw, _, termination, truncation, _ = env.last()
                # Reward credited to the learner's previous action.
                reward_for_last_action = env.rewards.get(my_agent_id, 0.0)

                # Accumulate rewards for episode score tracking.
                # NOTE(review): env.rewards is read on every agent's turn; if
                # the environment keeps the same values across turns within a
                # round, the score (and possibly the learner's reward) is
                # over- or under-counted - confirm against the environment's
                # reward semantics.
                for r_ag_id, r_val in env.rewards.items():
                    if r_ag_id in current_episode_rewards:
                        current_episode_rewards[r_ag_id] += r_val

                done = termination or truncation
                action_to_take = None  # Finished agents must be stepped with None

                if pet_agent_id_turn == my_agent_id:
                    obs_dict_current = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in obs_raw.items()}
                    current_s_vc, current_s_other = agent.process_observation(obs_dict_current)

                    # Complete the transition started on the learner's previous turn.
                    if pending_transition is not None:
                        prev_s_vc, prev_s_other, prev_a = pending_transition
                        pending_transition = None
                        agent.step(prev_s_vc, prev_s_other, prev_a, reward_for_last_action, current_s_vc, current_s_other, done)

                    if not done:
                        action_to_take = agent.select_action(current_s_vc, current_s_other, epsilon)
                        pending_transition = (current_s_vc, current_s_other, action_to_take)

                    global_total_steps += 1
                    # Epsilon stays at its start value for the first
                    # MIN_EPSILON_FRAMES learner steps, then decays geometrically.
                    if global_total_steps > MIN_EPSILON_FRAMES and epsilon > EPSILON_END:
                        epsilon *= EPSILON_DECAY_RATE
                        epsilon = max(EPSILON_END, epsilon)

                    # Periodic hard update of the target network.
                    if global_total_steps % TARGET_UPDATE_EVERY == 0 and global_total_steps > 0:
                        agent.update_target_net()

                elif not done and env.action_space(pet_agent_id_turn) is not None:
                    # Other agents act randomly.
                    action_to_take = env.action_space(pet_agent_id_turn).sample()

                env.step(action_to_take)

            # End of episode
            episode_score = current_episode_rewards.get(my_agent_id, 0.0)
            scores_deque.append(episode_score)

            if i_episode % 100 == 0:
                print(f'\rEp {i_episode}\tAvgScore(100): {np.mean(scores_deque):.2f}\tEps: {epsilon:.4f}\tGlobalSteps: {global_total_steps}\tBeta: {agent.beta:.4f}')
                if save_model_to: agent.save_model()
    finally:
        # Release the environment even when training raises or is interrupted.
        env.close()

    if save_model_to: agent.save_model()  # Save final model
    print(f"\nCNN DQN Training finished. Final model saved to {save_model_to if save_model_to else 'N/A'}")

if __name__ == '__main__':
    # Script entry point: run one long training session and report its duration.
    training_start_time = time.time()
    print(f"Initiating Overnight Optimized CNN DQN training at {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(training_start_time))} UTC")

    # Guard only the import itself, so that an ImportError raised somewhere
    # inside training is not misreported as a missing environment package.
    try:
        from til_environment import gridworld  # Environment module used for training
    except ImportError:
        print("Could not import 'til_environment.gridworld'. Ensure it's accessible.")
    else:
        # --- Overnight Training Configuration for Advanced Track ---
        NUM_OVERNIGHT_EPISODES = 100000
        LOAD_MODEL_PATH = "my_wargame_cnn_agent_best_100k.pth"
        SAVE_MODEL_PATH = "my_wargame_cnn_agent_cont_100k.pth"  # New save path
        NOVICE_MODE = False  # For Advanced Track

        # The exploration schedule is taken from the module-level
        # EPSILON_START, EPSILON_END, EPSILON_DECAY_RATE and
        # MIN_EPSILON_FRAMES constants; adjust them there before a run.
        try:
            train_agent(
                gridworld,
                num_episodes=NUM_OVERNIGHT_EPISODES,
                novice_track=NOVICE_MODE,
                load_model_from=LOAD_MODEL_PATH,
                save_model_to=SAVE_MODEL_PATH
            )
        except FileNotFoundError as fnf_error:
            print(f"Error: Model file not found. {fnf_error}")
        except Exception as e:
            # Top-level boundary: report the failure with its traceback, then
            # fall through so the elapsed time is still printed.
            print(f"An error occurred during CNN DQN training: {e}")
            import traceback
            traceback.print_exc()

    total_time_seconds = time.time() - training_start_time
    print(f"Total Optimized CNN DQN training time for this session: {total_time_seconds:.2f} seconds ({total_time_seconds/3600:.2f} hours).")

Initiating Overnight Optimized CNN DQN training at 2025-05-24 17:47:14 UTC
Starting CNN DQN training: 100000 episodes, Novice: False
Attempting to load model from: my_wargame_cnn_agent_best_100k.pth
Models will be saved to: my_wargame_cnn_agent_cont_100k.pth
Using device: cuda
Epsilon will start at 0.5000 and decay towards 0.0100


  self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))


Ep 100	AvgScore(100): 30.98	Eps: 0.5000	GlobalSteps: 8712	Beta: 0.4127
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 200	AvgScore(100): 37.93	Eps: 0.2514	GlobalSteps: 16876	Beta: 0.4248
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 300	AvgScore(100): 48.14	Eps: 0.1201	GlobalSteps: 24264	Beta: 0.4356
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 400	AvgScore(100): 39.77	Eps: 0.0573	GlobalSteps: 31655	Beta: 0.4464
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 500	AvgScore(100): 52.78	Eps: 0.0295	GlobalSteps: 38284	Beta: 0.4561
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 600	AvgScore(100): 46.33	Eps: 0.0145	GlobalSteps: 45429	Beta: 0.4666
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 700	AvgScore(100): 44.66	Eps: 0.0100	GlobalSteps: 52710	Beta: 0.4772
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 800	AvgScore(100): 46.29	Eps: 0.0100	GlobalSteps: 60085	Beta: 0.4881
Model saved to my_wargame_cnn_agent_cont_100k.pth
Ep 900	AvgScore(100): 45.