In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time

# --- Configuration -----------------------------------------------------------
# Environment geometry.
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
MAX_STEPS_PER_EPISODE = 100

# Observation encoding: one viewcone channel per tile bit, plus an auxiliary
# vector made of direction one-hot (4) + normalised (x, y) location (2)
# + scout flag (1) + normalised step counter (1).
VIEWCONE_CHANNELS = 8
VIEWCONE_HEIGHT = 7
VIEWCONE_WIDTH = 5
OTHER_FEATURES_SIZE = sum((4, 2, 1, 1))

# Network architecture.
CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2 = 16, 32
KERNEL_SIZE_1, STRIDE_1 = (3, 3), 1
KERNEL_SIZE_2, STRIDE_2 = (3, 3), 1
MLP_HIDDEN_LAYER_1_SIZE = MLP_HIDDEN_LAYER_2_SIZE = 128
OUTPUT_ACTIONS = 5
DROPOUT_RATE = 0.2

# Replay and optimisation.
BUFFER_SIZE = 100_000
BATCH_SIZE = 64
GAMMA = 0.99
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
TARGET_UPDATE_EVERY = 1_000  # counted in global (learning-agent) steps
UPDATE_EVERY = 4  # learn once every this many agent steps

# StepLR schedule; the training loop steps it once per episode.
LR_SCHEDULER_STEP_SIZE = 30_000
LR_SCHEDULER_GAMMA = 0.5

# Epsilon-greedy exploration. A loaded checkpoint supplies the resumed epsilon.
EPSILON_START = 0.8
EPSILON_END = 0.1
EPSILON_DECAY_RATE = 0.9999
MIN_EPSILON_FRAMES = 10_000  # no decay until this many global steps

# Prioritised replay. A loaded checkpoint supplies the resumed beta.
PER_ALPHA = 0.6
PER_BETA_START = 0.4
PER_BETA_FRAMES = 100_000
PER_EPSILON = 1e-6

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One stored transition of the learning agent.
Experience = namedtuple(
    "Experience",
    ("state_viewcone", "state_other", "action", "reward",
     "next_state_viewcone", "next_state_other", "done"),
)

# --- SumTree Class (Identical to previous script) ---
class SumTree:
    """Binary sum-tree used as the index of a prioritized replay buffer.

    ``self.tree`` has ``2 * capacity - 1`` nodes. The leaves are indices
    ``capacity - 1`` .. ``2 * capacity - 2`` and hold per-item priorities. Each
    internal node holds the sum of its two children, so ``self.tree[0]`` is the
    total priority. ``self.data`` is a ring buffer of stored items aligned with
    the leaves; slots that were never written hold the placeholder ``0``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0  # next ring-buffer slot to (over)write
        self.n_entries = 0  # number of filled slots; saturates at capacity

    def add(self, priority, data):
        """Store ``data`` with ``priority``, overwriting the oldest item when full."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and refresh all of its ancestors.

        Each parent is recomputed as the exact sum of its two children instead
        of adding a delta. Delta propagation accumulates floating-point drift
        over many updates, which lets the root disagree with the leaves and
        steers sampling into empty slots.
        """
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            left = 2 * tree_idx + 1
            self.tree[tree_idx] = self.tree[left] + self.tree[left + 1]

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf whose cumulative
        priority interval contains ``value``.

        The right child is entered only when its subtree has positive
        priority. A ``value`` that slightly overshoots the total because of
        rounding therefore ends on the last non-empty leaf instead of an
        unfilled slot.
        """
        parent_idx = 0
        tree_len = len(self.tree)
        while True:
            left = 2 * parent_idx + 1
            if left >= tree_len:
                break  # parent_idx is a leaf
            right = left + 1
            if value <= self.tree[left] or self.tree[right] <= 0:
                parent_idx = left
            else:
                value -= self.tree[left]
                parent_idx = right
        leaf_idx = parent_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the root node)."""
        return self.tree[0]

# --- PrioritizedReplayBuffer Class (Identical to previous script) ---
class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions in proportion to their priority.

    Priorities are ``(|td_error| + PER_EPSILON) ** alpha`` and are kept in a
    SumTree. New transitions receive the largest priority seen so far.
    """

    def __init__(self, capacity, alpha=PER_ALPHA):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        # Stored already raised to ``alpha`` (update_priorities exponentiates
        # before taking the max), so it is used as-is for new transitions.
        self.max_priority = 1.0

    def add(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Append one transition with the current maximum priority."""
        experience = Experience(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.tree.add(self.max_priority, experience)

    @staticmethod
    def _empty_batch():
        """Placeholder result; callers detect it via ``experiences[0].nelement() == 0``."""
        return tuple(torch.empty(0) for _ in range(7)), np.array([]), torch.empty(0)

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw up to ``batch_size`` transitions by stratified priority sampling.

        The total priority is split into ``batch_size`` equal segments and one
        value is drawn uniformly from each. Importance-sampling weights are
        ``(N * P(i) + 1e-8) ** -beta``, normalised by their maximum.

        Returns:
            ``(experiences, batch_idx, weights)``:
            - ``experiences``: a 7-tuple of tensors on DEVICE. Actions are long
              ``(B, 1)``; rewards and dones are float ``(B, 1)``.
            - ``batch_idx``: int32 SumTree leaf indices, to pass back to
              ``update_priorities``.
            - ``weights``: float ``(B,)`` importance-sampling weights.
            All three share the same length B. B may be smaller than
            ``batch_size`` if some draws hit unfilled slots. The empty
            placeholder is returned when fewer than ``batch_size``
            transitions are stored.
        """
        total_priority = self.tree.total_priority
        n_entries = self.tree.n_entries
        if batch_size <= 0 or n_entries < batch_size or total_priority <= 0:
            return self._empty_batch()

        segment = total_priority / batch_size
        leaf_indices, leaf_priorities, transitions = [], [], []
        for i in range(batch_size):
            value = np.random.uniform(segment * i, segment * (i + 1))
            leaf_idx, priority, data = self.tree.get_leaf(value)
            # Unfilled SumTree slots hold the placeholder 0 (not an Experience)
            # with priority 0, and rounding in the tree can route a draw there.
            # Skip such draws so that indices, weights and transitions stay
            # aligned.
            if priority <= 0 or not isinstance(data, Experience):
                continue
            leaf_indices.append(leaf_idx)
            leaf_priorities.append(priority)
            transitions.append(data)
        if not transitions:
            return self._empty_batch()

        probabilities = np.asarray(leaf_priorities, dtype=np.float64) / total_priority
        weights = np.power(n_entries * probabilities + 1e-8, -beta).astype(np.float32)
        max_weight = weights.max()
        if max_weight > 0:
            weights /= max_weight

        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = zip(*transitions)
        states_viewcone = torch.from_numpy(np.array(states_viewcone)).float().to(DEVICE)
        states_other = torch.from_numpy(np.array(states_other)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(actions)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(DEVICE)
        next_states_viewcone = torch.from_numpy(np.array(next_states_viewcone)).float().to(DEVICE)
        next_states_other = torch.from_numpy(np.array(next_states_other)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(DEVICE)
        batch_idx = np.asarray(leaf_indices, dtype=np.int32)
        return (states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones), batch_idx, torch.from_numpy(weights).float().to(DEVICE)

    def update_priorities(self, batch_indices, td_errors):
        """Set leaf priorities from TD errors and track the running maximum."""
        if len(batch_indices) == 0:
            return
        priorities = np.abs(td_errors) + PER_EPSILON
        priorities = np.power(priorities, self.alpha)
        for idx, priority_val in zip(batch_indices, priorities):
            self.tree.update(idx, priority_val)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, np.max(priorities))

    def __len__(self):
        return self.tree.n_entries

# --- CNNDQN Class (Identical to previous script) ---
class CNNDQN(nn.Module):
    """Q-network: two padded conv layers over the viewcone, whose flattened
    output is concatenated with the auxiliary feature vector and fed through a
    two-hidden-layer MLP (ReLU + dropout) that emits one Q-value per action.

    Attribute names and their creation order are relied on by saved
    state_dicts and optimizer checkpoints; do not rename or reorder them.
    """

    def __init__(self, viewcone_channels, viewcone_height, viewcone_width, other_features_size, mlp_hidden1, mlp_hidden2, num_actions, dropout_rate):
        super().__init__()
        pad = 1

        def conv_extent(size, kernel, stride):
            # Conv2d output size for padding `pad` and dilation 1.
            return (size + 2 * pad - kernel) // stride + 1

        self.conv1 = nn.Conv2d(viewcone_channels, CNN_OUTPUT_CHANNELS_1, kernel_size=KERNEL_SIZE_1, stride=STRIDE_1, padding=pad)
        self.relu_conv1 = nn.ReLU()
        self.conv2 = nn.Conv2d(CNN_OUTPUT_CHANNELS_1, CNN_OUTPUT_CHANNELS_2, kernel_size=KERNEL_SIZE_2, stride=STRIDE_2, padding=pad)
        self.relu_conv2 = nn.ReLU()

        out_height = conv_extent(conv_extent(viewcone_height, KERNEL_SIZE_1[0], STRIDE_1), KERNEL_SIZE_2[0], STRIDE_2)
        out_width = conv_extent(conv_extent(viewcone_width, KERNEL_SIZE_1[1], STRIDE_1), KERNEL_SIZE_2[1], STRIDE_2)
        self.cnn_output_flat_size = CNN_OUTPUT_CHANNELS_2 * out_height * out_width

        self.fc1_mlp = nn.Linear(self.cnn_output_flat_size + other_features_size, mlp_hidden1)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2_mlp = nn.Linear(mlp_hidden1, mlp_hidden2)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc_output = nn.Linear(mlp_hidden2, num_actions)

    def forward(self, viewcone_input, other_features_input):
        """Return Q-values of shape (batch, num_actions)."""
        feature_maps = self.relu_conv1(self.conv1(viewcone_input))
        feature_maps = self.relu_conv2(self.conv2(feature_maps))
        flattened = feature_maps.view(-1, self.cnn_output_flat_size)
        hidden = torch.cat((flattened, other_features_input), dim=1)
        hidden_stages = (
            (self.fc1_mlp, self.relu_fc1, self.dropout1),
            (self.fc2_mlp, self.relu_fc2, self.dropout2),
        )
        for linear, activation, dropout in hidden_stages:
            hidden = dropout(activation(linear(hidden)))
        return self.fc_output(hidden)

# --- TrainableRLAgent Class (Modified for checkpointing awareness, though logic is in train_agent) ---
class TrainableRLAgent:
    """Double-DQN learner with a CNN Q-network and prioritized experience replay.

    The policy network selects the next action and the target network
    evaluates it. The target network is refreshed only through
    ``update_target_net``.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_cnn_dqn_model.pth"):  # model_save_path is for FINAL model
        self.device = DEVICE
        self.policy_net = CNNDQN(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
                                 OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
                                 MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE).to(self.device)
        self.target_net = CNNDQN(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT, VIEWCONE_WIDTH,
                                 OTHER_FEATURES_SIZE, MLP_HIDDEN_LAYER_1_SIZE,
                                 MLP_HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS, DROPOUT_RATE).to(self.device)

        # Optional initial policy weights; a training checkpoint loaded later
        # (in train_agent) overrides them.
        self.initial_model_load_path = model_load_path
        if self.initial_model_load_path and os.path.exists(self.initial_model_load_path):
            try:
                print(f"TrainableRLAgent: Attempting to load initial policy_net weights from {self.initial_model_load_path}")
                # strict=False: tolerate missing/unexpected keys in a weights-only file.
                self.policy_net.load_state_dict(torch.load(self.initial_model_load_path, map_location=self.device), strict=False)
            except Exception as e:
                print(f"TrainableRLAgent: Warning: Error loading initial policy_net from {self.initial_model_load_path}: {e}. Initializing new model.")
                self.policy_net.apply(self._initialize_weights)
        else:
            if self.initial_model_load_path:
                print(f"TrainableRLAgent: Warning: Initial model file not found at {self.initial_model_load_path}. Initializing new model.")
            else:
                print("TrainableRLAgent: No initial model_load_path specified. Initializing new model.")
            self.policy_net.apply(self._initialize_weights)

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                      step_size=LR_SCHEDULER_STEP_SIZE,
                                                      gamma=LR_SCHEDULER_GAMMA)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)
        self.model_final_save_path = model_save_path
        self.t_step_episode = 0  # steps modulo UPDATE_EVERY; learning happens when it wraps to 0
        self.beta = PER_BETA_START
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0
        self.current_loss = 0.0  # loss of the most recent learn() call, for logging

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero biases for Linear/Conv2d layers."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return bits 0..VIEWCONE_CHANNELS-1 of ``tile_value`` as floats (LSB first)."""
        return [float((tile_value >> i) & 1) for i in range(VIEWCONE_CHANNELS)]

    def process_observation(self, observation_dict):
        """Convert a raw observation dict into network inputs.

        Returns:
            ``(viewcone, other)``:
            - ``viewcone``: float32 ``(VIEWCONE_CHANNELS, VIEWCONE_HEIGHT,
              VIEWCONE_WIDTH)``. Channel i is bit i of each tile. A
              mis-sized grid is cropped or zero-padded first.
            - ``other``: float32 vector of length 8. It holds the direction
              one-hot (4), location normalised by map size (2), the scout
              flag (1) and step / MAX_STEPS_PER_EPISODE (1).
        """
        raw_viewcone = observation_dict.get("viewcone", np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8))
        if not isinstance(raw_viewcone, np.ndarray):
            raw_viewcone = np.array(raw_viewcone)
        if raw_viewcone.shape != (VIEWCONE_HEIGHT, VIEWCONE_WIDTH):
            padded_viewcone = np.zeros((VIEWCONE_HEIGHT, VIEWCONE_WIDTH), dtype=np.uint8)
            h, w = raw_viewcone.shape
            h_min, w_min = min(h, VIEWCONE_HEIGHT), min(w, VIEWCONE_WIDTH)
            padded_viewcone[:h_min, :w_min] = raw_viewcone[:h_min, :w_min]
            raw_viewcone = padded_viewcone
        # Vectorized equivalent of applying _unpack_viewcone_tile to every tile:
        # shift the whole grid by 0..C-1 bits at once and keep the low bit.
        bit_shifts = np.arange(VIEWCONE_CHANNELS).reshape(-1, 1, 1)
        processed_viewcone_channels_data = ((raw_viewcone[np.newaxis, :, :] >> bit_shifts) & 1).astype(np.float32)

        other_features_list = []
        direction = observation_dict.get("direction", 0)
        direction_one_hot = [0.0] * 4
        direction_one_hot[direction % 4] = 1.0
        other_features_list.extend(direction_one_hot)
        location = observation_dict.get("location", [0, 0])
        other_features_list.extend([location[0] / MAP_SIZE_X, location[1] / MAP_SIZE_Y])
        other_features_list.append(float(observation_dict.get("scout", 0)))
        other_features_list.append(observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE)
        state_other_np = np.array(other_features_list, dtype=np.float32)
        return processed_viewcone_channels_data, state_other_np

    def select_action(self, state_viewcone_np, state_other_np, epsilon=0.0):
        """Epsilon-greedy action: greedy w.r.t. the policy net with probability 1 - epsilon."""
        if random.random() > epsilon:
            state_viewcone_tensor = torch.from_numpy(state_viewcone_np).float().unsqueeze(0).to(self.device)
            state_other_tensor = torch.from_numpy(state_other_np).float().unsqueeze(0).to(self.device)
            self.policy_net.eval()  # disable dropout for acting
            with torch.no_grad():
                action_values = self.policy_net(state_viewcone_tensor, state_other_tensor)
            self.policy_net.train()
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done):
        """Store a transition and learn from a prioritized batch every UPDATE_EVERY steps."""
        self.memory.add(state_viewcone, state_other, action, reward, next_state_viewcone, next_state_other, done)
        self.t_step_episode = (self.t_step_episode + 1) % UPDATE_EVERY
        if self.t_step_episode == 0 and len(self.memory) >= BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            if experiences[0].nelement() > 0:
                self.learn(experiences, indices, weights, GAMMA)
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """One Double-DQN update with importance-sampling-weighted MSE.

        Args:
            experiences: 7-tuple of tensors from the replay buffer. Actions,
                rewards and dones are ``(B, 1)``.
            indices: replay-buffer leaf indices whose priorities get refreshed.
            importance_sampling_weights: per-sample weights, shape ``(B,)`` or
                ``(B, 1)``.
            gamma: discount factor.
        """
        states_viewcone, states_other, actions, rewards, next_states_viewcone, next_states_other, dones = experiences
        if states_viewcone.nelement() == 0:
            return
        with torch.no_grad():
            # Double DQN: the policy net picks the next action, the target net scores it.
            q_next_actions_policy = self.policy_net(next_states_viewcone, next_states_other).max(1)[1].unsqueeze(1)
            q_targets_next = self.target_net(next_states_viewcone, next_states_other).gather(1, q_next_actions_policy)
            q_targets = rewards + (gamma * q_targets_next * (1 - dones))
        q_expected = self.policy_net(states_viewcone, states_other).gather(1, actions)
        td_errors = (q_targets - q_expected).abs().cpu().detach().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)
        # The per-sample loss is (B, 1). Reshape the weights to (B, 1) as well:
        # a (B,) * (B, 1) product broadcasts to (B, B) and averages away the
        # per-sample importance weighting.
        per_sample_loss = nn.functional.mse_loss(q_expected, q_targets, reduction='none')
        loss = (importance_sampling_weights.reshape(-1, 1) * per_sample_loss).mean()
        self.current_loss = loss.item()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy policy-net weights into the target net."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
        print(f"\nTarget network updated. Current LR: {self.optimizer.param_groups[0]['lr']:.2e}")

    def save_final_model(self):
        """Save policy-net weights to ``model_final_save_path`` (if set)."""
        if self.model_final_save_path:
            torch.save(self.policy_net.state_dict(), self.model_final_save_path)
            print(f"Final model saved to {self.model_final_save_path}")

    def reset_episode_counters(self):
        self.t_step_episode = 0

# --- Main Training Loop (Modified for Checkpointing) ---
def _load_training_checkpoint(agent, checkpoint_path):
    """Restore agent state and loop counters from a full training checkpoint.

    Restores, in place, the agent's policy/target nets, optimizer, LR
    scheduler and PER ``beta``.

    Returns:
        ``(start_episode, global_total_steps, epsilon, scores)``. ``scores``
        is the saved list of recent episode scores, or ``None`` if the
        checkpoint has none.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, or ``KeyError``
        for a missing entry. Scalars are read before any state dict is
        loaded, so a missing scalar leaves the agent untouched. A failing
        ``load_state_dict`` can still leave earlier components restored.
    """
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    start_episode = checkpoint['episode'] + 1
    global_total_steps = checkpoint['global_total_steps']
    epsilon = checkpoint['epsilon']
    beta = checkpoint['beta']

    agent.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
    agent.target_net.load_state_dict(checkpoint['target_net_state_dict'])
    agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    agent.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    agent.beta = beta
    # The replay buffer is not checkpointed; it refills after resuming.
    return start_episode, global_total_steps, epsilon, checkpoint.get('scores_deque')


def _save_training_checkpoint(agent, checkpoint_path, episode, global_total_steps, epsilon, scores_deque):
    """Write a full training checkpoint atomically.

    The data goes to ``<checkpoint_path>.tmp`` first and is then moved over
    ``checkpoint_path`` with ``os.replace``, so an interruption mid-save
    cannot corrupt the previous checkpoint.
    """
    checkpoint_data = {
        'episode': episode,
        'global_total_steps': global_total_steps,
        'epsilon': float(epsilon),
        'beta': float(agent.beta),
        'policy_net_state_dict': agent.policy_net.state_dict(),
        'target_net_state_dict': agent.target_net.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'lr_scheduler_state_dict': agent.lr_scheduler.state_dict(),
        # Plain floats: env rewards may arrive as NumPy scalars, which newer
        # torch.load versions refuse by default (weights_only mode).
        'scores_deque': [float(score) for score in scores_deque],
    }
    tmp_path = f"{checkpoint_path}.tmp"
    torch.save(checkpoint_data, tmp_path)
    os.replace(tmp_path, checkpoint_path)


def train_agent(env_module, num_episodes=100000, novice_track=False,
                initial_model_load_path=None,  # For initial policy weights if no checkpoint
                final_model_save_path="trained_cnn_dqn_agent.pth",
                checkpoint_path="training_checkpoint.pth"):  # Path for saving/loading full checkpoints
    """Train the first agent of a PettingZoo AEC environment with Double DQN + PER.

    The other agents act with the same policy network at a fixed epsilon of
    0.2 but are not learned from.

    Args:
        env_module: module exposing ``env(env_wrappers=..., render_mode=..., novice=...)``.
        num_episodes: index of the LAST episode to run (inclusive), not a count
            of extra episodes. When resuming from a checkpoint saved at
            episode E, episodes E+1 .. num_episodes run.
        novice_track: forwarded to the environment as ``novice``.
        initial_model_load_path: policy weights used when no checkpoint loads.
        final_model_save_path: where the final policy weights are written.
        checkpoint_path: full-state checkpoint, loaded at start (if present)
            and rewritten every 100 episodes.
    """
    print(f"Starting CNN DQN training: {num_episodes} episodes, Novice: {novice_track}")
    print(f"Final models will be saved to: {final_model_save_path}")
    print(f"Checkpoints will be saved to: {checkpoint_path}")
    print(f"Using device: {DEVICE}")

    agent = TrainableRLAgent(model_load_path=initial_model_load_path, model_save_path=final_model_save_path)

    scores_deque = deque(maxlen=100)
    epsilon = EPSILON_START
    global_total_steps = 0
    start_episode = 1

    # --- Load from Checkpoint if available ---
    if checkpoint_path and os.path.exists(checkpoint_path):
        try:
            print(f"Loading checkpoint from {checkpoint_path}...")
            # Loop counters are only replaced if the whole restore succeeds.
            start_episode, global_total_steps, epsilon, saved_scores = _load_training_checkpoint(agent, checkpoint_path)
            if saved_scores is not None:
                scores_deque = deque(saved_scores, maxlen=100)
            print(f"Resuming training from episode {start_episode}, global steps {global_total_steps}, epsilon {epsilon:.4f}, beta {agent.beta:.4f}")
            print(f"Optimizer and LR Scheduler states loaded. Current LR: {agent.optimizer.param_groups[0]['lr']:.2e}")
        except Exception as e:
            # Best effort: continue with initial_model_load_path weights (or random ones).
            print(f"Error loading checkpoint: {e}. Starting fresh or from initial_model_load_path if provided.")
    elif initial_model_load_path:
        print(f"No checkpoint found at {checkpoint_path}. Using initial model from {initial_model_load_path} if provided in agent.")
    else:
        print(f"No checkpoint found at {checkpoint_path} and no initial_model_load_path. Starting training from scratch.")

    print(f"Epsilon for learning agent will start/resume at {epsilon:.4f} and decay towards {EPSILON_END:.4f}")
    print(f"Learning rate scheduler: StepLR, step_size={LR_SCHEDULER_STEP_SIZE} episodes, gamma={LR_SCHEDULER_GAMMA}")

    env = env_module.env(env_wrappers=[], render_mode=None, novice=novice_track)
    try:
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"
        print(f"Primary learning agent ID: {my_agent_id}")

        for i_episode in range(start_episode, num_episodes + 1):
            env.reset()
            agent.reset_episode_counters()
            current_episode_rewards_accumulator = {agent_id: 0.0 for agent_id in env.possible_agents}
            # (state_viewcone, state_other, action) of the learning agent's last
            # move, completed into a transition on its next turn.
            last_processed_exp_learning_agent = {}

            for agent_id_turn in env.agent_iter():
                current_obs_raw, reward_for_current_agent_turn, termination, truncation, info = env.last()
                if agent_id_turn in current_episode_rewards_accumulator:
                    current_episode_rewards_accumulator[agent_id_turn] += reward_for_current_agent_turn
                done = termination or truncation
                action_to_take = None  # finished agents must step with None

                if done:
                    if agent_id_turn == my_agent_id and my_agent_id in last_processed_exp_learning_agent:
                        prev_s_vc, prev_s_other, prev_a = last_processed_exp_learning_agent.pop(my_agent_id)
                        obs_dict_terminal = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in current_obs_raw.items()} if current_obs_raw else {}
                        terminal_s_vc, terminal_s_other = agent.process_observation(obs_dict_terminal) if current_obs_raw else (np.zeros_like(prev_s_vc), np.zeros_like(prev_s_other))
                        agent.step(prev_s_vc, prev_s_other, prev_a, reward_for_current_agent_turn, terminal_s_vc, terminal_s_other, True)
                else:
                    if current_obs_raw is None:
                        action_to_take = env.action_space(agent_id_turn).sample() if env.action_space(agent_id_turn) else None
                    else:
                        obs_dict_current = {k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in current_obs_raw.items()}
                        current_s_vc, current_s_other = agent.process_observation(obs_dict_current)
                        if agent_id_turn == my_agent_id:
                            if my_agent_id in last_processed_exp_learning_agent:
                                prev_s_vc, prev_s_other, prev_a = last_processed_exp_learning_agent.pop(my_agent_id)
                                agent.step(prev_s_vc, prev_s_other, prev_a, reward_for_current_agent_turn, current_s_vc, current_s_other, False)
                            action_to_take = agent.select_action(current_s_vc, current_s_other, epsilon)
                            last_processed_exp_learning_agent[my_agent_id] = (current_s_vc, current_s_other, action_to_take)
                            global_total_steps += 1
                            if global_total_steps > MIN_EPSILON_FRAMES and epsilon > EPSILON_END:
                                epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY_RATE)
                            if global_total_steps % TARGET_UPDATE_EVERY == 0 and global_total_steps > 0:
                                agent.update_target_net()
                        else:
                            # Opponents/teammates reuse the policy net with fixed exploration.
                            action_to_take = agent.select_action(current_s_vc, current_s_other, epsilon=0.2)
                env.step(action_to_take)

            agent.lr_scheduler.step()  # StepLR is counted in episodes
            episode_score_my_agent = current_episode_rewards_accumulator.get(my_agent_id, 0.0)
            scores_deque.append(episode_score_my_agent)

            if i_episode % 100 == 0:
                avg_score_str = f"{np.mean(scores_deque):.2f}" if scores_deque else "N/A"
                current_lr = agent.optimizer.param_groups[0]['lr']
                print(f'\rEp {i_episode}/{num_episodes}\tAvgScore: {avg_score_str}\tEps: {epsilon:.4f}\tLR: {current_lr:.2e}\tLoss: {agent.current_loss:.4f}\tBeta: {agent.beta:.3f}\tSteps: {global_total_steps}')

                if checkpoint_path:
                    _save_training_checkpoint(agent, checkpoint_path, i_episode, global_total_steps, epsilon, scores_deque)
                    print(f"Checkpoint saved to {checkpoint_path} at episode {i_episode}")
    finally:
        env.close()

    agent.save_final_model()
    print(f"\nCNN DQN Training finished. Final model saved to {agent.model_final_save_path if agent.model_final_save_path else 'N/A'}")

if __name__ == '__main__':
    training_start_time = time.time()
    print(f"Initiating CNN DQN training with Checkpointing at {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(training_start_time))} UTC")

    # --- PATHS FOR RESUMING TRAINING ---
    MODEL_VERSION_TAG = "best_100000_resumed_chkpt"  # Example tag for this run

    # Initial policy-network weights; only used when the checkpoint below is
    # missing or fails to load.
    INITIAL_POLICY_LOAD_PATH = "my_wargame_cnn_agent_best.pth"

    # Where the policy weights are written when this session finishes.
    FINAL_MODEL_SAVE_PATH = f"my_wargame_cnn_agent_{MODEL_VERSION_TAG}.pth"

    # Full training state (nets, optimizer, scheduler, epsilon, beta, counters).
    CHECKPOINT_SAVE_LOAD_PATH = f"training_checkpoint_{MODEL_VERSION_TAG}.pth"

    NUM_TOTAL_EPISODES = 200000  # Not passed to train_agent; kept for reference only.

    # NOTE: despite its name, train_agent treats num_episodes as the LAST
    # episode index (inclusive), not as a count of extra episodes. When
    # resuming from a checkpoint saved at episode E, it runs episodes
    # E+1 .. NUM_ADDITIONAL_EPISODES_THIS_SESSION. To run K more episodes,
    # pass E + K instead.
    NUM_ADDITIONAL_EPISODES_THIS_SESSION = 100000

    NOVICE_MODE = False

    try:
        from til_environment import gridworld

        train_agent(
            gridworld,
            num_episodes=NUM_ADDITIONAL_EPISODES_THIS_SESSION,  # last episode index (inclusive), see NOTE above
            novice_track=NOVICE_MODE,
            initial_model_load_path=INITIAL_POLICY_LOAD_PATH,  # For initial weights if no checkpoint
            final_model_save_path=FINAL_MODEL_SAVE_PATH,  # Where the final model goes
            checkpoint_path=CHECKPOINT_SAVE_LOAD_PATH,  # For saving/loading full state
        )

    except ImportError:
        print("Could not import 'til_environment.gridworld'. Ensure it's accessible and contains your PettingZoo environment.")
    except FileNotFoundError as fnf_error:
        print(f"Error: A model or checkpoint file not found. {fnf_error}")
    except Exception as e:
        # Top-level boundary: report and print the traceback instead of crashing the notebook.
        print(f"An error occurred during CNN DQN training: {e}")
        import traceback
        traceback.print_exc()

    total_time_seconds = time.time() - training_start_time
    print(f"Total CNN DQN training time for this session: {total_time_seconds:.2f} seconds ({total_time_seconds/3600:.2f} hours).")


Initiating CNN DQN training with Checkpointing at 2025-05-24 12:37:26 UTC
Starting CNN DQN training: 100000 episodes, Novice: False
Final models will be saved to: my_wargame_cnn_agent_best_100000_resumed_chkpt.pth
Checkpoints will be saved to: training_checkpoint_best_100000_resumed_chkpt.pth
Using device: cuda
TrainableRLAgent: Attempting to load initial policy_net weights from my_wargame_cnn_agent_best.pth
Loading checkpoint from training_checkpoint_best_100000_resumed_chkpt.pth...
Resuming training from episode 3601, global steps 287119, epsilon 0.1000, beta 0.8273
Optimizer and LR Scheduler states loaded. Current LR: 1.00e-04
Epsilon for learning agent will start/resume at 0.1000 and decay towards 0.1000
Learning rate scheduler: StepLR, step_size=30000 episodes, gamma=0.5
Primary learning agent ID: player_0

Target network updated. Current LR: 1.00e-04

Target network updated. Current LR: 1.00e-04

Target network updated. Current LR: 1.00e-04

Target network updated. Current LR: 1.