In [None]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time # For total training time

# --- Configuration ---
# Map dimensions; used only to normalise the agent's (x, y) location feature.
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
# Divisor used to normalise the observation's "step" field.
MAX_STEPS_PER_EPISODE = 100

# Length of the feature vector built by TrainableRLAgent.process_observation:
# 7 * 5 viewcone tiles * 8 bits + 4 (direction one-hot) + 2 (location)
# + 1 (scout flag) + 1 (normalised step) = 288.
INPUT_FEATURES = 288
# Widths of the two hidden layers of the DQN.
HIDDEN_LAYER_1_SIZE = 256
HIDDEN_LAYER_2_SIZE = 256
# Number of Q-value outputs, i.e. number of discrete actions the agent picks from.
OUTPUT_ACTIONS = 5

# Capacity of the prioritized replay buffer (number of stored transitions).
BUFFER_SIZE = int(1e5)
# Number of transitions sampled per learning update.
BATCH_SIZE = 32
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.99
# Learning rate of the Adam optimizer.
LEARNING_RATE = 1e-4
# The target network is synchronised every this many calls to agent.step.
TARGET_UPDATE_EVERY = 100 # Updated based on agent's own steps
# A learning update is attempted every this many calls to agent.step.
UPDATE_EVERY = 4

# Epsilon-greedy exploration: initial value, floor, and multiplicative decay.
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995 # Per episode

# Prioritized experience replay:
# exponent applied to (|TD error| + PER_EPSILON) to obtain a priority.
PER_ALPHA = 0.6
# Initial importance-sampling exponent (beta).
PER_BETA_START = 0.4
# Number of sampling calls over which beta is annealed linearly to 1.0.
PER_BETA_FRAMES = int(1e5)
# Small constant added to |TD error| so that no priority is exactly zero.
PER_EPSILON = 1e-6

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- SumTree for Prioritized Replay Buffer ---
class SumTree:
    """Array-backed binary sum tree for priority-proportional sampling.

    ``tree`` has ``2 * capacity - 1`` nodes. Indices ``capacity - 1`` and up
    are leaves holding one priority each; every other node holds the sum of
    its two children, so ``tree[0]`` is the total priority. ``data`` holds
    the payload of each leaf and is overwritten in ring order once full.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0  # next slot of ``data`` to (over)write
        self.n_entries = 0  # number of filled slots, capped at ``capacity``

    def add(self, priority, data):
        """Store ``data`` with ``priority`` in the next ring-buffer slot."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and fix up every ancestor sum."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf covering ``value``.

        ``value`` is interpreted as an offset into the cumulative priority
        mass. A subtree whose sum is not positive is never entered while its
        sibling holds mass, so an offset that overshoots the total (float
        drift in the running sums) or that equals 0 next to an empty left
        subtree still resolves to a filled leaf instead of an unused slot.
        Only a completely empty tree can yield an unused slot.
        """
        tree = self.tree
        n_nodes = len(tree)
        parent_idx = 0
        while True:
            left_child_idx = 2 * parent_idx + 1
            if left_child_idx >= n_nodes:
                # No children: ``parent_idx`` is a leaf.
                leaf_idx = parent_idx
                break
            # The node count is odd, so an internal node always has both children.
            right_child_idx = left_child_idx + 1
            left_sum = tree[left_child_idx]
            right_sum = tree[right_child_idx]
            if right_sum <= 0 or (left_sum > 0 and value <= left_sum):
                parent_idx = left_child_idx
            else:
                value -= left_sum
                parent_idx = right_child_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, tree[leaf_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the root of the tree)."""
        return self.tree[0]

# --- Prioritized Replay Buffer ---
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])


class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions in proportion to their priority.

    Transitions are stored in a ``SumTree``. New transitions enter with the
    largest priority seen so far; ``update_priorities`` later replaces it
    with ``(|TD error| + PER_EPSILON) ** alpha``.
    """

    # How many times one stratum is re-drawn when the tree hands back an
    # unfilled slot before sampling gives up with an explicit error.
    _MAX_RESAMPLE_ATTEMPTS = 8

    def __init__(self, capacity, alpha=PER_ALPHA):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state, action, reward, next_state, done):
        """Store one transition with the current maximum priority."""
        experience = Experience(state, action, reward, next_state, done)
        self.tree.add(self.max_priority, experience)

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw a stratified, priority-proportional batch.

        The total priority is cut into ``batch_size`` equal segments and one
        offset is drawn uniformly from each.

        Args:
            batch_size: Number of transitions to draw.
            beta: Importance-sampling exponent.

        Returns:
            ``((states, actions, rewards, next_states, dones), tree_indices,
            weights)``. The five batch tensors and ``weights`` live on
            ``DEVICE``; ``weights`` has shape ``(batch_size,)`` and is
            normalised by the largest weight in the batch.

        Raises:
            RuntimeError: If a segment keeps resolving to an unfilled slot
                (for example when the buffer is empty).
        """
        batch_idx = np.empty(batch_size, dtype=np.int32)
        batch_data = []
        weights = np.empty(batch_size, dtype=np.float32)
        total_priority = self.tree.total_priority
        n_entries = self.tree.n_entries
        priority_segment = total_priority / batch_size

        for i in range(batch_size):
            low, high = priority_segment * i, priority_segment * (i + 1)
            # Unfilled slots hold a placeholder rather than an Experience;
            # float drift in the tree sums can occasionally lead there, so
            # re-draw inside the same segment instead of crashing on unpack.
            for _ in range(self._MAX_RESAMPLE_ATTEMPTS):
                value = np.random.uniform(low, high)
                index, priority, data = self.tree.get_leaf(value)
                if isinstance(data, Experience):
                    break
            else:
                raise RuntimeError(
                    "PrioritizedReplayBuffer.sample could not draw a stored experience "
                    f"(entries={n_entries}, total_priority={total_priority})"
                )

            sampling_probability = priority / total_priority if total_priority > 0 else 0
            weights[i] = np.power(n_entries * sampling_probability + 1e-8, -beta)  # epsilon for stability
            batch_idx[i] = index
            batch_data.append(data)

        max_weight = weights.max()
        if max_weight > 0:
            weights /= max_weight  # Normalize

        states, actions, rewards, next_states, dones = zip(*batch_data)
        states = torch.from_numpy(np.vstack(states)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(actions)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(DEVICE)
        next_states = torch.from_numpy(np.vstack(next_states)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(DEVICE)

        return (states, actions, rewards, next_states, dones), batch_idx, torch.from_numpy(weights).float().to(DEVICE)

    def update_priorities(self, batch_indices, td_errors):
        """Re-prioritise sampled leaves from their new TD errors.

        Args:
            batch_indices: Tree indices returned by ``sample``.
            td_errors: One TD error per index; only the magnitude is used.
        """
        priorities = np.abs(td_errors) + PER_EPSILON
        priorities = np.power(priorities, self.alpha)
        for idx, priority in zip(batch_indices, priorities):
            self.tree.update(idx, priority)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, priorities.max())

    def __len__(self):
        return self.tree.n_entries

# --- Deep Q-Network (DQN) Model ---
class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output head."""

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()
        # Attribute names are the state_dict key prefixes of saved checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        """Map a batch of feature vectors to one Q-value per action."""
        hidden = x
        for linear, activation in ((self.fc1, self.relu1), (self.fc2, self.relu2)):
            hidden = activation(linear(hidden))
        return self.fc3(hidden)

# --- Trainable RL Agent ---
class TrainableRLAgent:
    """Double-DQN agent trained from a prioritized replay buffer.

    Owns a policy network, a periodically synchronised target network, an
    Adam optimizer and a ``PrioritizedReplayBuffer``. ``step`` stores one
    transition and, every ``UPDATE_EVERY`` calls, runs one ``learn`` update.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_dqn_model.pth"):
        """Build the networks, optionally restoring policy weights from disk.

        Args:
            model_load_path: Optional path to a saved policy ``state_dict``.
                If it is not given, does not exist, or fails to load, the
                policy network is Xavier-initialised instead.
            model_save_path: Destination used by ``save_model``.
        """
        self.device = DEVICE
        self.policy_net = DQN(INPUT_FEATURES, HIDDEN_LAYER_1_SIZE, HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS).to(self.device)
        self.target_net = DQN(INPUT_FEATURES, HIDDEN_LAYER_1_SIZE, HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS).to(self.device)

        if model_load_path and os.path.exists(model_load_path):
            try:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))
            except Exception as e:
                print(f"Warning: Error loading model from {model_load_path}: {e}. Initializing new model.")
                self.policy_net.apply(self._initialize_weights)
        else:
            self.policy_net.apply(self._initialize_weights)

        # The target network starts as an exact copy and is never trained directly.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)

        self.model_save_path = model_save_path
        self.learn_step_counter = 0  # For UPDATE_EVERY
        self.total_agent_steps = 0  # For TARGET_UPDATE_EVERY
        self.beta = PER_BETA_START
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0

    def _initialize_weights(self, m):
        """Xavier-initialise the weights and zero the bias of every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return the 8 lowest bits of ``tile_value`` as floats, least significant first."""
        return [float((tile_value >> i) & 1) for i in range(8)]

    def process_observation(self, observation_dict):
        """Flatten an observation dict into a float32 feature vector.

        Layout: 7 x 5 viewcone tiles x 8 bits (missing tiles count as 0),
        4-way one-hot direction, normalised (x, y) location, scout flag and
        normalised step counter.

        Raises:
            ValueError: If the vector length differs from ``INPUT_FEATURES``.
        """
        processed_features = []
        viewcone = observation_dict.get("viewcone", [])
        for r in range(7):  # Rows in viewcone
            for c in range(5):  # Columns in viewcone
                tile_value = viewcone[r][c] if r < len(viewcone) and c < len(viewcone[r]) else 0
                processed_features.extend(self._unpack_viewcone_tile(tile_value))

        direction = observation_dict.get("direction", 0)
        direction_one_hot = [0.0] * 4
        if 0 <= direction < 4:
            direction_one_hot[direction] = 1.0
        processed_features.extend(direction_one_hot)

        location = observation_dict.get("location", [0, 0])
        norm_x = location[0] / MAP_SIZE_X if MAP_SIZE_X > 0 else 0.0
        norm_y = location[1] / MAP_SIZE_Y if MAP_SIZE_Y > 0 else 0.0
        processed_features.extend([norm_x, norm_y])

        processed_features.append(float(observation_dict.get("scout", 0)))
        norm_step = observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE if MAX_STEPS_PER_EPISODE > 0 else 0.0
        processed_features.append(norm_step)

        if len(processed_features) != INPUT_FEATURES:
            raise ValueError(f"Feature length mismatch. Expected {INPUT_FEATURES}, got {len(processed_features)}")
        return np.array(processed_features, dtype=np.float32)

    def select_action(self, state_np, epsilon=0.0):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest policy-network Q-value.
        """
        if random.random() > epsilon:
            state_tensor = torch.from_numpy(state_np).float().unsqueeze(0).to(self.device)
            self.policy_net.eval()
            with torch.no_grad():
                action_values = self.policy_net(state_tensor)
            self.policy_net.train()
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, state, action, reward, next_state, done):
        """Record one transition and run the scheduled learning/target updates."""
        self.memory.add(state, action, reward, next_state, done)
        self.total_agent_steps += 1

        self.learn_step_counter = (self.learn_step_counter + 1) % UPDATE_EVERY
        if self.learn_step_counter == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            self.learn(experiences, indices, weights, GAMMA)
            # Anneal beta towards 1.0, once per sampled batch.
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

        if self.total_agent_steps % TARGET_UPDATE_EVERY == 0:
            self.update_target_net()

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """Run one Double-DQN update on a sampled batch.

        Args:
            experiences: ``(states, actions, rewards, next_states, dones)``
                tensors; the last four are used as ``(batch, 1)`` columns.
            indices: Replay-buffer indices of the batch, passed back to
                ``memory.update_priorities`` together with the TD errors.
            importance_sampling_weights: One weight per sample, shape
                ``(batch,)`` or ``(batch, 1)``.
            gamma: Discount factor.
        """
        states, actions, rewards, next_states, dones = experiences

        # Targets are constants for the optimiser, so build them without a graph.
        with torch.no_grad():
            # Double DQN: the policy net chooses the next action, the target net scores it.
            q_next_policy_actions = self.policy_net(next_states).max(1)[1].unsqueeze(1)
            q_targets_next = self.target_net(next_states).gather(1, q_next_policy_actions)
            q_targets = rewards + (gamma * q_targets_next * (1 - dones))
        q_expected = self.policy_net(states).gather(1, actions)

        td_errors = (q_targets - q_expected).abs().detach().cpu().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)

        per_sample_loss = nn.MSELoss(reduction='none')(q_expected, q_targets)
        # Reshape to (batch, 1): a (batch,) weight vector would broadcast against the
        # (batch, 1) loss into a (batch, batch) matrix and erase per-sample weighting.
        sample_weights = importance_sampling_weights.reshape(-1, 1)
        loss = (sample_weights * per_sample_loss).mean()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy the policy-network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self):
        """Write the policy-network ``state_dict`` to ``model_save_path``."""
        torch.save(self.policy_net.state_dict(), self.model_save_path)
        print(f"Model saved to {self.model_save_path}")

    def reset_state(self):
        """Per-episode reset hook; the agent keeps no episode state, so it does nothing."""
        pass

# --- Main Training Loop ---
def train_agent(env_module, num_episodes=2000, novice_track=False, load_model_from=None, save_model_to="trained_dqn_agent.pth"):
    """Train a ``TrainableRLAgent`` as the first agent of a turn-based environment.

    The environment is driven through ``agent_iter`` / ``last`` / ``step``.
    Every other agent acts by sampling its action space.

    Args:
        env_module: Module exposing ``env(env_wrappers, render_mode, novice)``.
        num_episodes: Number of episodes to run.
        novice_track: Forwarded to the environment factory as ``novice``.
        load_model_from: Optional checkpoint path for the policy network.
        save_model_to: Checkpoint path written every 100 episodes and at the
            end; a falsy value disables saving.

    Returns:
        List with the trained agent's accumulated reward for each episode.
    """
    print(f"Starting training: {num_episodes} episodes, Novice: {novice_track}, Load: {load_model_from}, Save: {save_model_to}")
    print(f"Using device: {DEVICE}")

    env = env_module.env(env_wrappers=[], render_mode=None, novice=novice_track)
    try:
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"  # Fallback if empty

        agent = TrainableRLAgent(model_load_path=load_model_from, model_save_path=save_model_to)
        scores_deque = deque(maxlen=100)
        scores = []
        epsilon = EPSILON_START

        # (state, action) chosen on my agent's previous turn; it is completed with
        # (reward, next_state, done) on its next turn.
        pending_experience_my_agent = {}

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            agent.reset_state()
            # Never pair a leftover (state, action) from the previous episode with
            # the first observation of this one.
            pending_experience_my_agent.clear()
            current_episode_rewards = {agent_id: 0.0 for agent_id in env.possible_agents}

            for pet_agent_id in env.agent_iter():
                current_observation_raw, current_reward, termination, truncation, info = env.last()
                done = termination or truncation

                # NOTE(review): env.rewards is summed on every agent turn; this assumes
                # the environment refreshes it on each env.step -- confirm, otherwise
                # the episode score is over-counted.
                for r_agent_id, r_value in env.rewards.items():
                    if r_agent_id in current_episode_rewards:  # Ensure agent is still tracked
                        current_episode_rewards[r_agent_id] += r_value

                action_to_take = None
                if pet_agent_id == my_agent_id:
                    obs_dict_current_turn = {k: v if isinstance(v, (int, float)) else np.array(v).tolist() for k, v in current_observation_raw.items()}

                    has_pending = my_agent_id in pending_experience_my_agent
                    if has_pending or not done:
                        # Featurise once; the vector is both the next_state of the
                        # pending transition and the state for the new action.
                        current_state_np = agent.process_observation(obs_dict_current_turn)

                    if has_pending:
                        prev_state_np, prev_action = pending_experience_my_agent.pop(my_agent_id)
                        agent.step(prev_state_np, prev_action, current_reward, current_state_np, done)

                    if not done:
                        action_to_take = agent.select_action(current_state_np, epsilon)
                        pending_experience_my_agent[my_agent_id] = (current_state_np, action_to_take)

                elif not done:  # Other agents take random actions if not done
                    other_action_space = env.action_space(pet_agent_id)
                    if other_action_space is not None:
                        action_to_take = other_action_space.sample()

                # A finished agent must step with None, which is the default above.
                env.step(action_to_take)

            episode_score = current_episode_rewards.get(my_agent_id, 0.0)
            scores_deque.append(episode_score)
            scores.append(episode_score)
            epsilon = max(EPSILON_END, EPSILON_DECAY * epsilon)

            if i_episode % 20 == 0:  # Print less frequently
                status = f'\rEpisode {i_episode}\tAvg Score (100): {np.mean(scores_deque):.2f}\tEpsilon: {epsilon:.4f}\tSteps: {agent.total_agent_steps}'
                if i_episode % 100 == 0:
                    # Every 100 episodes keep the line (newline) and checkpoint.
                    print(status)
                    if save_model_to:
                        agent.save_model()
                else:
                    print(status, end="")
    finally:
        # Release the environment even when training aborts with an exception.
        env.close()

    if save_model_to:
        agent.save_model()  # Save one last time
    print(f"\nTraining finished. Final model saved to {save_model_to if save_model_to else 'N/A'}")
    return scores

if __name__ == '__main__':
    # Script entry point: run one training session and report the wall-clock time.
    training_start_time = time.time()
    try:
        # Only the import itself is covered by the ImportError handler, so an
        # ImportError raised later during training is reported with a traceback
        # instead of being mistaken for a missing environment package.
        try:
            from til_environment import gridworld
        except ImportError:
            print("Could not import 'til_environment.gridworld'. Ensure it's accessible.")
        else:
            trained_scores = train_agent(
                gridworld,
                num_episodes=10000,  # Lower this for a quick smoke test
                novice_track=False,  # Forwarded to the environment as `novice`
                load_model_from="trained_dqn_agent_1k.pth",  # Checkpoint to resume from; None starts fresh
                save_model_to="trained_dqn_agent_11k.pth"
            )
    except Exception as e:
        # Top-level boundary: report the failure and still print the timing below.
        print(f"An error occurred during training: {e}")
        import traceback
        traceback.print_exc()

    training_end_time = time.time()
    total_time = training_end_time - training_start_time
    print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

## Gemini 2.5 v2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time # For total training time

# --- Configuration ---
# Map dimensions; used only to normalise the agent's (x, y) location feature.
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
# Divisor used to normalise the observation's "step" field.
MAX_STEPS_PER_EPISODE = 100

# Length of the feature vector built by TrainableRLAgent.process_observation:
# 7 * 5 viewcone tiles * 8 bits + 4 (direction one-hot) + 2 (location)
# + 1 (scout flag) + 1 (normalised step) = 288.
INPUT_FEATURES = 288
# Widths of the two hidden layers of the DQN.
HIDDEN_LAYER_1_SIZE = 256
HIDDEN_LAYER_2_SIZE = 256
# Number of Q-value outputs, i.e. number of discrete actions the agent picks from.
OUTPUT_ACTIONS = 5

# Capacity of the prioritized replay buffer (number of stored transitions).
BUFFER_SIZE = int(1e5)
# Number of transitions sampled per learning update.
BATCH_SIZE = 32
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.99
# Learning rate of the Adam optimizer.
LEARNING_RATE = 1e-4
# The target network is synchronised every this many calls to agent.step.
TARGET_UPDATE_EVERY = 100 # Updated based on agent's own steps
# A learning update is attempted every this many calls to agent.step.
UPDATE_EVERY = 4

# Epsilon-greedy exploration: value used for a freshly initialised model,
# floor (also the starting value when a checkpoint was loaded), and
# multiplicative decay.
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995 # Per episode

# Prioritized experience replay:
# exponent applied to (|TD error| + PER_EPSILON) to obtain a priority.
PER_ALPHA = 0.6
# Initial importance-sampling exponent (beta).
PER_BETA_START = 0.4
# Number of sampling calls over which beta is annealed linearly to 1.0.
PER_BETA_FRAMES = int(1e5)
# Small constant added to |TD error| so that no priority is exactly zero.
PER_EPSILON = 1e-6

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- SumTree for Prioritized Replay Buffer ---
class SumTree:
    """Array-backed binary sum tree for priority-proportional sampling.

    ``tree`` has ``2 * capacity - 1`` nodes. Indices ``capacity - 1`` and up
    are leaves holding one priority each; every other node holds the sum of
    its two children, so ``tree[0]`` is the total priority. ``data`` holds
    the payload of each leaf and is overwritten in ring order once full.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0  # next slot of ``data`` to (over)write
        self.n_entries = 0  # number of filled slots, capped at ``capacity``

    def add(self, priority, data):
        """Store ``data`` with ``priority`` in the next ring-buffer slot."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and fix up every ancestor sum."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the leaf covering ``value``.

        ``value`` is interpreted as an offset into the cumulative priority
        mass. A subtree whose sum is not positive is never entered while its
        sibling holds mass, so an offset that overshoots the total (float
        drift in the running sums) or that equals 0 next to an empty left
        subtree still resolves to a filled leaf instead of an unused slot.
        Only a completely empty tree can yield an unused slot.
        """
        tree = self.tree
        n_nodes = len(tree)
        parent_idx = 0
        while True:
            left_child_idx = 2 * parent_idx + 1
            if left_child_idx >= n_nodes:
                # No children: ``parent_idx`` is a leaf.
                leaf_idx = parent_idx
                break
            # The node count is odd, so an internal node always has both children.
            right_child_idx = left_child_idx + 1
            left_sum = tree[left_child_idx]
            right_sum = tree[right_child_idx]
            if right_sum <= 0 or (left_sum > 0 and value <= left_sum):
                parent_idx = left_child_idx
            else:
                value -= left_sum
                parent_idx = right_child_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, tree[leaf_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the root of the tree)."""
        return self.tree[0]

# --- Prioritized Replay Buffer ---
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])


class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions in proportion to their priority.

    Transitions are stored in a ``SumTree``. New transitions enter with the
    largest priority seen so far; ``update_priorities`` later replaces it
    with ``(|TD error| + PER_EPSILON) ** alpha``.
    """

    # How many times one stratum is re-drawn when the tree hands back an
    # unfilled slot before sampling gives up with an explicit error.
    _MAX_RESAMPLE_ATTEMPTS = 8

    def __init__(self, capacity, alpha=PER_ALPHA):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state, action, reward, next_state, done):
        """Store one transition with the current maximum priority."""
        experience = Experience(state, action, reward, next_state, done)
        self.tree.add(self.max_priority, experience)

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw a stratified, priority-proportional batch.

        The total priority is cut into ``batch_size`` equal segments and one
        offset is drawn uniformly from each.

        Args:
            batch_size: Number of transitions to draw.
            beta: Importance-sampling exponent.

        Returns:
            ``((states, actions, rewards, next_states, dones), tree_indices,
            weights)``. The five batch tensors and ``weights`` live on
            ``DEVICE``; ``weights`` has shape ``(batch_size,)`` and is
            normalised by the largest weight in the batch.

        Raises:
            RuntimeError: If a segment keeps resolving to an unfilled slot
                (for example when the buffer is empty).
        """
        batch_idx = np.empty(batch_size, dtype=np.int32)
        batch_data = []
        weights = np.empty(batch_size, dtype=np.float32)
        total_priority = self.tree.total_priority
        n_entries = self.tree.n_entries
        priority_segment = total_priority / batch_size

        for i in range(batch_size):
            low, high = priority_segment * i, priority_segment * (i + 1)
            # Unfilled slots hold a placeholder rather than an Experience;
            # float drift in the tree sums can occasionally lead there, so
            # re-draw inside the same segment instead of crashing on unpack.
            for _ in range(self._MAX_RESAMPLE_ATTEMPTS):
                value = np.random.uniform(low, high)
                index, priority, data = self.tree.get_leaf(value)
                if isinstance(data, Experience):
                    break
            else:
                raise RuntimeError(
                    "PrioritizedReplayBuffer.sample could not draw a stored experience "
                    f"(entries={n_entries}, total_priority={total_priority})"
                )

            sampling_probability = priority / total_priority if total_priority > 0 else 0
            weights[i] = np.power(n_entries * sampling_probability + 1e-8, -beta)  # epsilon for stability
            batch_idx[i] = index
            batch_data.append(data)

        max_weight = weights.max()
        if max_weight > 0:
            weights /= max_weight  # Normalize

        states, actions, rewards, next_states, dones = zip(*batch_data)
        states = torch.from_numpy(np.vstack(states)).float().to(DEVICE)
        actions = torch.from_numpy(np.vstack(actions)).long().to(DEVICE)
        rewards = torch.from_numpy(np.vstack(rewards)).float().to(DEVICE)
        next_states = torch.from_numpy(np.vstack(next_states)).float().to(DEVICE)
        dones = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(DEVICE)

        return (states, actions, rewards, next_states, dones), batch_idx, torch.from_numpy(weights).float().to(DEVICE)

    def update_priorities(self, batch_indices, td_errors):
        """Re-prioritise sampled leaves from their new TD errors.

        Args:
            batch_indices: Tree indices returned by ``sample``.
            td_errors: One TD error per index; only the magnitude is used.
        """
        priorities = np.abs(td_errors) + PER_EPSILON
        priorities = np.power(priorities, self.alpha)
        for idx, priority in zip(batch_indices, priorities):
            self.tree.update(idx, priority)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, priorities.max())

    def __len__(self):
        return self.tree.n_entries

# --- Deep Q-Network (DQN) Model ---
class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output head."""

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()
        # Attribute names are the state_dict key prefixes of saved checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        """Map a batch of feature vectors to one Q-value per action."""
        hidden = x
        for linear, activation in ((self.fc1, self.relu1), (self.fc2, self.relu2)):
            hidden = activation(linear(hidden))
        return self.fc3(hidden)

# --- Trainable RL Agent ---
class TrainableRLAgent:
    """Double-DQN agent trained from a prioritized replay buffer.

    Owns a policy network, a periodically synchronised target network, an
    Adam optimizer and a ``PrioritizedReplayBuffer``. ``step`` stores one
    transition and, every ``UPDATE_EVERY`` calls, runs one ``learn`` update.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_dqn_model.pth"):
        """Build the networks, optionally restoring policy weights from disk.

        Args:
            model_load_path: Optional path to a saved policy ``state_dict``.
                If it is not given, does not exist, or fails to load, the
                policy network is Xavier-initialised instead.
            model_save_path: Destination used by ``save_model``.
        """
        self.device = DEVICE
        self.policy_net = DQN(INPUT_FEATURES, HIDDEN_LAYER_1_SIZE, HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS).to(self.device)
        self.target_net = DQN(INPUT_FEATURES, HIDDEN_LAYER_1_SIZE, HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS).to(self.device)

        # True only when a checkpoint was loaded successfully.
        self.is_pretrained = False
        if model_load_path and os.path.exists(model_load_path):
            try:
                # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
                self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))
                print(f"Successfully loaded pre-trained policy_net from {model_load_path}")
                self.is_pretrained = True
            except Exception as e:
                print(f"Warning: Error loading model from {model_load_path}: {e}. Initializing new model with random weights.")
                self.policy_net.apply(self._initialize_weights)
        elif model_load_path:  # Path provided but does not exist
            print(f"Model path {model_load_path} not found. Initializing policy_net with random weights.")
            self.policy_net.apply(self._initialize_weights)
        else:  # No path provided
            print("No model load path provided. Initializing policy_net with random weights.")
            self.policy_net.apply(self._initialize_weights)

        # The target network starts as an exact copy and is never trained directly.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)

        self.model_save_path = model_save_path
        self.learn_step_counter = 0  # For UPDATE_EVERY
        self.total_agent_steps = 0  # For TARGET_UPDATE_EVERY
        self.beta = PER_BETA_START
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0

    def _initialize_weights(self, m):
        """Xavier-initialise the weights and zero the bias of every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return the 8 lowest bits of ``tile_value`` as floats, least significant first."""
        return [float((tile_value >> i) & 1) for i in range(8)]

    def process_observation(self, observation_dict):
        """Flatten an observation dict into a float32 feature vector.

        Layout: 7 x 5 viewcone tiles x 8 bits (missing tiles count as 0),
        4-way one-hot direction, normalised (x, y) location, scout flag and
        normalised step counter.

        Raises:
            ValueError: If the vector length differs from ``INPUT_FEATURES``.
        """
        processed_features = []
        viewcone = observation_dict.get("viewcone", [])
        for r in range(7):  # Rows in viewcone
            for c in range(5):  # Columns in viewcone
                tile_value = viewcone[r][c] if r < len(viewcone) and c < len(viewcone[r]) else 0
                processed_features.extend(self._unpack_viewcone_tile(tile_value))

        direction = observation_dict.get("direction", 0)
        direction_one_hot = [0.0] * 4
        if 0 <= direction < 4:
            direction_one_hot[direction] = 1.0
        processed_features.extend(direction_one_hot)

        location = observation_dict.get("location", [0, 0])
        norm_x = location[0] / MAP_SIZE_X if MAP_SIZE_X > 0 else 0.0
        norm_y = location[1] / MAP_SIZE_Y if MAP_SIZE_Y > 0 else 0.0
        processed_features.extend([norm_x, norm_y])

        processed_features.append(float(observation_dict.get("scout", 0)))
        norm_step = observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE if MAX_STEPS_PER_EPISODE > 0 else 0.0
        processed_features.append(norm_step)

        if len(processed_features) != INPUT_FEATURES:
            raise ValueError(f"Feature length mismatch. Expected {INPUT_FEATURES}, got {len(processed_features)}")
        return np.array(processed_features, dtype=np.float32)

    def select_action(self, state_np, epsilon=0.0):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest policy-network Q-value.
        """
        if random.random() > epsilon:
            state_tensor = torch.from_numpy(state_np).float().unsqueeze(0).to(self.device)
            self.policy_net.eval()
            with torch.no_grad():
                action_values = self.policy_net(state_tensor)
            self.policy_net.train()
            return np.argmax(action_values.cpu().data.numpy())
        return random.choice(np.arange(OUTPUT_ACTIONS))

    def step(self, state, action, reward, next_state, done):
        """Record one transition and run the scheduled learning/target updates."""
        self.memory.add(state, action, reward, next_state, done)
        self.total_agent_steps += 1

        self.learn_step_counter = (self.learn_step_counter + 1) % UPDATE_EVERY
        if self.learn_step_counter == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            self.learn(experiences, indices, weights, GAMMA)
            # Anneal beta towards 1.0, once per sampled batch.
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

        if self.total_agent_steps % TARGET_UPDATE_EVERY == 0:
            self.update_target_net()

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """Run one Double-DQN update on a sampled batch.

        Args:
            experiences: ``(states, actions, rewards, next_states, dones)``
                tensors; the last four are used as ``(batch, 1)`` columns.
            indices: Replay-buffer indices of the batch, passed back to
                ``memory.update_priorities`` together with the TD errors.
            importance_sampling_weights: One weight per sample, shape
                ``(batch,)`` or ``(batch, 1)``.
            gamma: Discount factor.
        """
        states, actions, rewards, next_states, dones = experiences

        # Targets are constants for the optimiser, so build them without a graph.
        with torch.no_grad():
            # Double DQN: the policy net chooses the next action, the target net scores it.
            q_next_policy_actions = self.policy_net(next_states).max(1)[1].unsqueeze(1)
            q_targets_next = self.target_net(next_states).gather(1, q_next_policy_actions)
            q_targets = rewards + (gamma * q_targets_next * (1 - dones))
        q_expected = self.policy_net(states).gather(1, actions)

        td_errors = (q_targets - q_expected).abs().detach().cpu().numpy().flatten()
        self.memory.update_priorities(indices, td_errors)

        per_sample_loss = nn.MSELoss(reduction='none')(q_expected, q_targets)
        # Reshape to (batch, 1): a (batch,) weight vector would broadcast against the
        # (batch, 1) loss into a (batch, batch) matrix and erase per-sample weighting.
        sample_weights = importance_sampling_weights.reshape(-1, 1)
        loss = (sample_weights * per_sample_loss).mean()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy the policy-network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self):
        """Write the policy-network ``state_dict`` to ``model_save_path``."""
        torch.save(self.policy_net.state_dict(), self.model_save_path)
        print(f"Model saved to {self.model_save_path}")

    def reset_state(self):
        """Per-episode reset hook; the agent keeps no episode state, so it does nothing."""
        pass

# --- Main Training Loop ---
def train_agent(env_module, num_episodes=2000, novice_track=False, load_model_from=None, save_model_to="trained_dqn_agent.pth"):
    """Train a ``TrainableRLAgent`` as the first agent of a turn-based environment.

    The environment is driven through ``agent_iter`` / ``last`` / ``step``.
    Every other agent acts by sampling its action space. When a checkpoint
    was loaded, exploration starts at ``EPSILON_END`` instead of
    ``EPSILON_START``.

    Args:
        env_module: Module exposing ``env(env_wrappers, render_mode, novice)``.
        num_episodes: Number of episodes to run.
        novice_track: Forwarded to the environment factory as ``novice``.
        load_model_from: Optional checkpoint path for the policy network.
        save_model_to: Checkpoint path written every 100 episodes and at the
            end; a falsy value disables saving.

    Returns:
        List with the trained agent's accumulated reward for each episode.
    """
    print(f"Starting training: {num_episodes} episodes, Novice: {novice_track}, Load: {load_model_from}, Save: {save_model_to}")
    print(f"Using device: {DEVICE}")

    env = env_module.env(env_wrappers=[], render_mode=None, novice=novice_track)
    try:
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"  # Fallback if empty

        agent = TrainableRLAgent(model_load_path=load_model_from, model_save_path=save_model_to)
        scores_deque = deque(maxlen=100)
        scores = []

        if agent.is_pretrained:
            print(f"Resuming with loaded model. Setting epsilon to EPSILON_END: {EPSILON_END}")
            epsilon = EPSILON_END
        else:
            print(f"Starting with new or re-initialized model. Setting epsilon to EPSILON_START: {EPSILON_START}")
            epsilon = EPSILON_START

        # (state, action) chosen on my agent's previous turn; it is completed with
        # (reward, next_state, done) on its next turn.
        pending_experience_my_agent = {}

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            agent.reset_state()
            # Never pair a leftover (state, action) from the previous episode with
            # the first observation of this one.
            pending_experience_my_agent.clear()
            current_episode_rewards = {agent_id: 0.0 for agent_id in env.possible_agents}

            for pet_agent_id in env.agent_iter():
                current_observation_raw, current_reward, termination, truncation, info = env.last()
                done = termination or truncation

                # NOTE(review): env.rewards is summed on every agent turn; this assumes
                # the environment refreshes it on each env.step -- confirm, otherwise
                # the episode score is over-counted.
                for r_agent_id, r_value in env.rewards.items():
                    if r_agent_id in current_episode_rewards:
                        current_episode_rewards[r_agent_id] += r_value

                action_to_take = None
                if pet_agent_id == my_agent_id:
                    obs_dict_current_turn = {k: v if isinstance(v, (int, float)) else np.array(v).tolist() for k, v in current_observation_raw.items()}
                    # Featurise once; the vector is both the next_state of the pending
                    # transition and the state for the new action.
                    processed_current_state_np = agent.process_observation(obs_dict_current_turn)

                    if my_agent_id in pending_experience_my_agent:
                        prev_state_np, prev_action = pending_experience_my_agent.pop(my_agent_id)
                        agent.step(prev_state_np, prev_action, current_reward, processed_current_state_np, done)

                    if not done:
                        action_to_take = agent.select_action(processed_current_state_np, epsilon)
                        pending_experience_my_agent[my_agent_id] = (processed_current_state_np, action_to_take)

                elif not done:  # Other agents take random actions if not done
                    other_action_space = env.action_space(pet_agent_id)
                    if other_action_space is not None:
                        action_to_take = other_action_space.sample()

                # A finished agent must step with None, which is the default above.
                env.step(action_to_take)

            episode_score = current_episode_rewards.get(my_agent_id, 0.0)
            scores_deque.append(episode_score)
            scores.append(episode_score)
            epsilon = max(EPSILON_END, EPSILON_DECAY * epsilon)

            if i_episode % 20 == 0:
                status = f'\rEpisode {i_episode}\tAvg Score (100): {np.mean(scores_deque):.2f}\tEpsilon: {epsilon:.4f}\tSteps: {agent.total_agent_steps}'
                if i_episode % 100 == 0:
                    # Every 100 episodes keep the line (newline) and checkpoint.
                    print(status)
                    if save_model_to:
                        agent.save_model()
                else:
                    print(status, end="")
    finally:
        # Release the environment even when training aborts with an exception.
        env.close()

    if save_model_to:
        agent.save_model()
    print(f"\nTraining finished. Final model saved to {save_model_to if save_model_to else 'N/A'}")
    return scores

if __name__ == '__main__':
    # Script entry point: profile one short training session and report the
    # wall-clock time.
    training_start_time = time.time()
    try:
        # Only the import itself is covered by the ImportError handler, so an
        # ImportError raised later during training is reported with a traceback
        # instead of being mistaken for a missing environment package.
        try:
            from til_environment import gridworld
        except ImportError:
            print("Could not import 'til_environment.gridworld'. Ensure it's accessible.")
        else:
            import cProfile
            import pstats

            # To profile your main training function
            profiler = cProfile.Profile()
            profiler.enable()
            try:
                trained_scores = train_agent(
                    gridworld,
                    num_episodes=100,
                    novice_track=False,
                    load_model_from="trained_dqn_agent_1300.pth",
                    save_model_to="trained_dqn_agent_11k.pth"
                )
            finally:
                # Stop profiling even if training raises, so the error handling
                # below is not profiled as well.
                profiler.disable()
            stats = pstats.Stats(profiler).sort_stats('cumtime')  # or 'tottime'
            stats.print_stats(20)  # Print top 20 time-consuming functions
    except Exception as e:
        # Top-level boundary: report the failure and still print the timing below.
        print(f"An error occurred during training: {e}")
        import traceback
        traceback.print_exc()

    training_end_time = time.time()
    total_time = training_end_time - training_start_time
    print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

## Gemini 2.5 V3
This profiling output is very insightful! It clearly shows where the time is being spent.

Key Takeaways from the Profiler:

Environment (gridworld.py) is the Major Bottleneck:

gridworld.py:401(observe): 121.812 seconds (cumulative time). This is the single most time-consuming function.
gridworld.py:760(step): 63.757 seconds.
Helper functions within your environment like _is_visible (54.278s) and supercover_line (24.421s) are also very significant due to high call counts and their contribution to observe.
Conclusion: Optimizing the environment's observe and step methods, and their internal helpers, would yield the largest speedups. Since I can't modify gridworld.py, my suggestions will focus on your agent and training loop code (dqnv2.py).
Data Transfer to/from Device (.to() and .cpu()):

{method 'to' of 'torch._C.TensorBase' objects}: 79.240 seconds (total time). This is primarily moving data (observations, batch data) to your DEVICE (e.g., GPU).
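{method 'cpu' of 'torch._C.TensorBase' objects}: 45.789 seconds (total time). This is moving data back to the CPU before it is converted to NumPy, which happens in select_action (to run np.argmax on the action values) and in learn (to hand the TD errors to update_priorities).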
Replay Buffer Sampling (PrioritizedReplayBuffer.sample):

The sample method itself has a cumulative time of 62.627 seconds. A significant portion of this is likely the data conversion and transfer (.to()) mentioned above, as batches of experiences are prepared.
Action Selection (TrainableRLAgent.select_action):

Cumulative time of 65.778 seconds. This involves a forward pass of the policy network and the data transfers to and from the device.
Optimizations for dqnv2.py Based on Profiler:

Here are changes to your dqnv2.py script focusing on the identified bottlenecks within your agent's code:

1. Optimize Data Transfers in PrioritizedReplayBuffer.sample()

We can make the transfer of batched experiences to the device more efficient using pinned memory and non-blocking transfers. This allows the CPU to continue working while the GPU is busy with the data transfer.

2. Optimize select_action() by using torch.argmax

Instead of moving the whole action-value tensor to the CPU to use np.argmax, we can call torch.argmax directly on the device tensor and then read the chosen action with .item(), so only a single integer is copied back to the CPU.

Here's the modified code:


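dqnv2.py (v2) — note: the code listing is truncated in this export; only its first lines survive (the full modified script is presumably the code cell that follows this note):

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    import random
    import os
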
Further Recommendations (Beyond these code changes):

Optimize gridworld.py: This is crucial. Profile gridworld.py itself to pinpoint slow parts within observe and step. Vectorize operations with NumPy if there are Python loops handling grid data, optimize visibility calculations, etc.
Parallel Environment Stepping: As discussed before, running multiple environment instances in parallel to collect data is one of the most effective ways to speed up RL training when the environment is slow. This involves more significant architectural changes (e.g., using multiprocessing or libraries like Ray).
torch.compile() (PyTorch 2.0+): If you are using PyTorch 2.0 or newer, wrap your DQN models with torch.compile() (for example, self.policy_net = torch.compile(self.policy_net)). This can JIT-compile your models and potentially speed up their execution significantly with minimal code change.
Mixed Precision Training: If you have a compatible NVIDIA GPU, using torch.cuda.amp for mixed-precision training can speed up GPU computations and reduce memory usage. This involves using GradScaler and autocast.
Re-profile: After these changes, re-profile to see how the bottlenecks have shifted and identify the next areas for optimization.
The changes I've made should help reduce the overhead from data transfers and action selection in your agent's code. However, the environment's performance remains the dominant factor according to your profiler.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import deque, namedtuple
import time # For total training time
import cProfile 
import pstats   

# --- Configuration ---
# Map dimensions; used to normalise the agent's (x, y) location feature.
MAP_SIZE_X = 16
MAP_SIZE_Y = 16
# Used to normalise the "step" observation field into a feature.
MAX_STEPS_PER_EPISODE = 100

# Length of the flat feature vector built by process_observation:
# 7 x 5 viewcone tiles x 8 bits + 4 (direction one-hot) + 2 (location)
# + 1 (scout flag) + 1 (normalised step) = 288.
INPUT_FEATURES = 288
HIDDEN_LAYER_1_SIZE = 256
HIDDEN_LAYER_2_SIZE = 256
# Number of Q-values the network outputs (one per discrete action).
OUTPUT_ACTIONS = 5

# Capacity of the prioritized replay buffer (number of transitions).
BUFFER_SIZE = int(1e5)
BATCH_SIZE = 32
# Discount factor passed to the learning update.
GAMMA = 0.99
LEARNING_RATE = 1e-4
# The target network is synced every this many calls to the agent's step().
TARGET_UPDATE_EVERY = 100 # Updated based on agent's own steps
# A learning update is attempted once every this many calls to the agent's step().
UPDATE_EVERY = 4

# Epsilon-greedy exploration schedule.
EPSILON_START = 1.0
EPSILON_END = 0.01
# Multiplicative decay applied once per episode in the training loop.
EPSILON_DECAY = 0.995 # Per episode

# Prioritized experience replay (PER) settings.
# Exponent applied to (|TD error| + PER_EPSILON) when computing priorities.
PER_ALPHA = 0.6
# Initial importance-sampling exponent; annealed linearly towards 1.0.
PER_BETA_START = 0.4
# Number of sampling (learning) calls over which beta reaches 1.0.
PER_BETA_FRAMES = int(1e5)
# Added to |TD error| so that no transition gets a zero priority.
PER_EPSILON = 1e-6

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- SumTree for Prioritized Replay Buffer ---
class SumTree:
    """Array-backed binary sum tree used for proportional priority sampling.

    The tree is stored in a flat array of ``2 * capacity - 1`` nodes. The last
    ``capacity`` nodes are leaves holding individual priorities; every internal
    node holds the sum of its two children, so ``tree[0]`` is the total
    priority. ``data`` stores the payload that belongs to each leaf and is used
    as a ring buffer: once ``capacity`` items were added, the oldest slot is
    overwritten.
    """

    def __init__(self, capacity):
        """Create an empty tree able to hold ``capacity`` items."""
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # Unwritten slots hold the placeholder 0 until they are first filled.
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0
        self.n_entries = 0

    def add(self, priority, data):
        """Store ``data`` with ``priority`` in the next ring-buffer slot."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, tree_idx, priority):
        """Set the priority of leaf ``tree_idx`` and refresh all ancestor sums."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        # Propagate the difference up to the root.
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, value):
        """Return ``(leaf_idx, priority, data)`` for the cumulative sum ``value``.

        The walk goes left when ``value`` fits into the left subtree and
        otherwise subtracts the left sum and goes right. A right subtree with
        no priority mass is never entered: the internal sums are maintained
        incrementally and can drift by floating-point error, so ``value`` may
        marginally exceed the left sum (or the caller may pass a value above
        the total). Without this guard the walk could end in a never-written
        leaf and return the placeholder ``0`` instead of stored data.
        """
        tree = self.tree
        node_count = len(tree)
        parent_idx = 0
        while True:
            left_child_idx = 2 * parent_idx + 1
            if left_child_idx >= node_count:
                # No children: parent_idx is a leaf.
                leaf_idx = parent_idx
                break
            right_child_idx = left_child_idx + 1
            if value <= tree[left_child_idx] or tree[right_child_idx] <= 0:
                parent_idx = left_child_idx
            else:
                value -= tree[left_child_idx]
                parent_idx = right_child_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.tree[0]

# --- Prioritized Replay Buffer ---
# One stored transition (state, action, reward, next_state, done); instances
# of this record are what PrioritizedReplayBuffer.add puts into the sum tree.
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])

class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions proportionally to their priority.

    Transitions are stored in a ``SumTree``. New transitions enter with the
    largest priority seen so far, and priorities are later replaced through
    ``update_priorities``.
    """

    def __init__(self, capacity, alpha=PER_ALPHA):
        """Create a buffer for ``capacity`` transitions with exponent ``alpha``."""
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.max_priority = 1.0

    def add(self, state, action, reward, next_state, done):
        """Store one transition using the current maximum priority."""
        self.tree.add(
            self.max_priority,
            Experience(state, action, reward, next_state, done),
        )

    def sample(self, batch_size, beta=PER_BETA_START):
        """Draw ``batch_size`` transitions by stratified priority sampling.

        The total priority is cut into ``batch_size`` equal segments and one
        value is drawn uniformly from each segment.

        Returns:
            A tuple ``((states, actions, rewards, next_states, dones),
            leaf_indices, weights)``. The five batch tensors and ``weights``
            live on ``DEVICE``; ``leaf_indices`` is an ``int32`` NumPy array of
            tree indices to pass back to ``update_priorities``. ``weights`` is
            one-dimensional and scaled so that its largest entry is 1.
        """
        leaf_indices = np.empty(batch_size, dtype=np.int32)
        drawn = np.empty(batch_size, dtype=object)
        is_weights = np.empty(batch_size, dtype=np.float32)
        segment = self.tree.total_priority / batch_size

        for slot in range(batch_size):
            low = segment * slot
            high = segment * (slot + 1)
            leaf_idx, leaf_priority, experience = self.tree.get_leaf(np.random.uniform(low, high))

            total = self.tree.total_priority
            if total > 0:
                probability = leaf_priority / total
            else:
                probability = 0
            # Importance-sampling weight (N * P(i)) ** -beta; 1e-8 guards against 0 ** -beta.
            is_weights[slot] = np.power(self.tree.n_entries * probability + 1e-8, -beta)
            leaf_indices[slot] = leaf_idx
            drawn[slot] = experience

        peak = is_weights.max()
        is_weights /= (peak if peak > 0 else 1.0)

        # Transpose the list of transitions into one column per field.
        states_col, actions_col, rewards_col, next_states_col, dones_col = zip(*drawn)

        def stacked(column, dtype):
            return np.ascontiguousarray(np.vstack(column), dtype=dtype)

        host_tensors = [
            torch.from_numpy(stacked(states_col, np.float32)).float(),
            torch.from_numpy(stacked(actions_col, np.int64)).long(),
            torch.from_numpy(stacked(rewards_col, np.float32)).float(),
            torch.from_numpy(stacked(next_states_col, np.float32)).float(),
            torch.from_numpy(
                np.ascontiguousarray(np.vstack(dones_col).astype(np.uint8), dtype=np.float32)
            ).float(),
            torch.from_numpy(is_weights).float(),
        ]

        # Pinned host memory is only used (together with non-blocking copies)
        # when the target device is CUDA.
        on_cuda = DEVICE.type == 'cuda'
        if on_cuda:
            host_tensors = [tensor.pin_memory() for tensor in host_tensors]

        states, actions, rewards, next_states, dones, weights = (
            tensor.to(DEVICE, non_blocking=on_cuda) for tensor in host_tensors
        )

        return (states, actions, rewards, next_states, dones), leaf_indices, weights

    def update_priorities(self, batch_indices, td_errors):
        """Replace the priorities of the given tree leaves.

        Each new priority is ``(|td_error| + PER_EPSILON) ** alpha``; the
        running maximum priority is raised if any new value exceeds it.
        """
        priorities = np.power(np.abs(td_errors) + PER_EPSILON, self.alpha)
        for tree_idx, new_priority in zip(batch_indices, priorities):
            self.tree.update(tree_idx, new_priority)
        if priorities.size > 0:
            self.max_priority = max(self.max_priority, priorities.max())

    def __len__(self):
        """Number of transitions currently stored."""
        return self.tree.n_entries

# --- Deep Q-Network (DQN) Model ---
class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output.

    Maps a batch of feature vectors of width ``input_dim`` to ``output_dim``
    action values per row.
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super().__init__()
        # Sub-module names define the state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        """Return the action values for the input batch ``x``."""
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu2(hidden)
        q_values = self.fc3(hidden)
        return q_values

# --- Trainable RL Agent ---
class TrainableRLAgent:
    """Double-DQN agent trained from a prioritized replay buffer.

    The agent owns a policy network, a target network that is periodically
    overwritten with the policy weights, an Adam optimizer and a
    ``PrioritizedReplayBuffer``. ``step`` stores transitions and triggers
    learning; ``select_action`` performs epsilon-greedy action selection.
    """

    def __init__(self, model_load_path=None, model_save_path="trained_dqn_model.pth"):
        """Build the networks and optionally load policy weights.

        Args:
            model_load_path: Optional path of a saved policy state dict. If it
                is missing, does not exist, or fails to load, the policy
                network is re-initialised with Xavier-uniform weights.
            model_save_path: Destination used by ``save_model``.
        """
        self.device = DEVICE
        self.policy_net = DQN(INPUT_FEATURES, HIDDEN_LAYER_1_SIZE, HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS).to(self.device)
        self.target_net = DQN(INPUT_FEATURES, HIDDEN_LAYER_1_SIZE, HIDDEN_LAYER_2_SIZE, OUTPUT_ACTIONS).to(self.device)

        # True only when weights were successfully loaded from model_load_path.
        self.is_pretrained = False
        if model_load_path and os.path.exists(model_load_path):
            try:
                # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
                self.policy_net.load_state_dict(torch.load(model_load_path, map_location=self.device))
                print(f"Successfully loaded pre-trained policy_net from {model_load_path}")
                self.is_pretrained = True
            except Exception as e:
                # Deliberate best effort: fall back to a freshly initialised network.
                print(f"Warning: Error loading model from {model_load_path}: {e}. Initializing new model with random weights.")
                self.policy_net.apply(self._initialize_weights)
        elif model_load_path:
            print(f"Model path {model_load_path} not found. Initializing policy_net with random weights.")
            self.policy_net.apply(self._initialize_weights)
        else:
            print(f"No model load path provided. Initializing policy_net with random weights.")
            self.policy_net.apply(self._initialize_weights)

        # The target network starts as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = PrioritizedReplayBuffer(BUFFER_SIZE, alpha=PER_ALPHA)

        self.model_save_path = model_save_path
        # Counts step() calls modulo UPDATE_EVERY to decide when to learn.
        self.learn_step_counter = 0
        # Total number of step() calls; drives the target-network sync.
        self.total_agent_steps = 0
        self.beta = PER_BETA_START
        # Beta grows by this amount after every sampled batch until it reaches 1.0.
        self.beta_increment_per_sampling = (1.0 - PER_BETA_START) / PER_BETA_FRAMES if PER_BETA_FRAMES > 0 else 0

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero bias for every ``nn.Linear`` module."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)

    def _unpack_viewcone_tile(self, tile_value):
        """Return the 8 low bits of ``tile_value`` as floats, least significant first."""
        return [float((tile_value >> i) & 1) for i in range(8)]

    def process_observation(self, observation_dict):
        """Convert an observation dict into a flat float32 feature vector.

        Reads the keys ``"viewcone"``, ``"direction"``, ``"location"``,
        ``"scout"`` and ``"step"`` (each with a default when absent). The
        viewcone is read as a 7 x 5 grid, missing cells counting as 0, and
        every tile contributes 8 bit features.

        Raises:
            ValueError: If the resulting vector is not ``INPUT_FEATURES`` long.
        """
        processed_features = []
        viewcone = observation_dict.get("viewcone", [])
        for r in range(7):
            for c in range(5):
                tile_value = viewcone[r][c] if r < len(viewcone) and c < len(viewcone[r]) else 0
                processed_features.extend(self._unpack_viewcone_tile(tile_value))

        direction = observation_dict.get("direction", 0)
        direction_one_hot = [0.0] * 4
        # Out-of-range directions leave the one-hot vector all zeros.
        if 0 <= direction < 4:
            direction_one_hot[direction] = 1.0
        processed_features.extend(direction_one_hot)

        location = observation_dict.get("location", [0, 0])
        norm_x = location[0] / MAP_SIZE_X if MAP_SIZE_X > 0 else 0.0
        norm_y = location[1] / MAP_SIZE_Y if MAP_SIZE_Y > 0 else 0.0
        processed_features.extend([norm_x, norm_y])

        processed_features.append(float(observation_dict.get("scout", 0)))
        norm_step = observation_dict.get("step", 0) / MAX_STEPS_PER_EPISODE if MAX_STEPS_PER_EPISODE > 0 else 0.0
        processed_features.append(norm_step)

        if len(processed_features) != INPUT_FEATURES:
            raise ValueError(f"Feature length mismatch. Expected {INPUT_FEATURES}, got {len(processed_features)}")
        return np.array(processed_features, dtype=np.float32)

    def select_action(self, state_np, epsilon=0.0):
        """Pick an action epsilon-greedily and return it as a Python ``int``.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the index of the largest policy-network output.
        """
        if random.random() <= epsilon:
            # A plain int keeps the return type identical to the greedy branch.
            return random.randrange(OUTPUT_ACTIONS)

        state_tensor = torch.from_numpy(np.ascontiguousarray(state_np)).float().unsqueeze(0).to(self.device)
        self.policy_net.eval()
        with torch.no_grad():
            action_values = self.policy_net(state_tensor)
        self.policy_net.train()
        return torch.argmax(action_values, dim=1).item()

    def step(self, state, action, reward, next_state, done):
        """Record one transition and run the periodic learning / target updates.

        A batch is sampled and learned from on every ``UPDATE_EVERY``-th call
        once the buffer holds more than ``BATCH_SIZE`` transitions. The target
        network is synced on every ``TARGET_UPDATE_EVERY``-th call.
        """
        self.memory.add(state, action, reward, next_state, done)
        self.total_agent_steps += 1

        self.learn_step_counter = (self.learn_step_counter + 1) % UPDATE_EVERY
        if self.learn_step_counter == 0 and len(self.memory) > BATCH_SIZE:
            experiences, indices, weights = self.memory.sample(BATCH_SIZE, beta=self.beta)
            self.learn(experiences, indices, weights, GAMMA)
            self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

        if self.total_agent_steps % TARGET_UPDATE_EVERY == 0:
            self.update_target_net()

    def learn(self, experiences, indices, importance_sampling_weights, gamma):
        """Run one Double-DQN update on a sampled batch.

        Args:
            experiences: Tuple ``(states, actions, rewards, next_states, dones)``
                of tensors; ``actions`` is used as a ``gather`` index along
                dimension 1.
            indices: Buffer indices of the batch, forwarded to
                ``self.memory.update_priorities`` together with the absolute
                TD errors.
            importance_sampling_weights: One weight per sample, either flat or
                as a column.
            gamma: Discount factor.
        """
        states, actions, rewards, next_states, dones = experiences

        with torch.no_grad():
            # Double DQN: the policy net picks the action, the target net evaluates it.
            q_next_policy_actions = self.policy_net(next_states).max(1)[1].unsqueeze(1)
            q_targets_next = self.target_net(next_states).gather(1, q_next_policy_actions)
            q_targets = rewards + (gamma * q_targets_next * (1 - dones))

        q_expected = self.policy_net(states).gather(1, actions)
        td_errors_tensor = (q_targets - q_expected).abs()
        self.memory.update_priorities(indices, td_errors_tensor.cpu().detach().numpy().flatten())

        # The weights must be a column so that they line up one-to-one with the
        # (batch, 1) per-sample losses. A flat (batch,) tensor would broadcast
        # to (batch, batch) and reduce the weighting to a single global factor.
        per_sample_weights = importance_sampling_weights.reshape(-1, 1)
        per_sample_loss = nn.MSELoss(reduction='none')(q_expected, q_targets)
        loss = (per_sample_weights * per_sample_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    def update_target_net(self):
        """Copy the policy-network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self):
        """Write the policy-network state dict to ``self.model_save_path``."""
        torch.save(self.policy_net.state_dict(), self.model_save_path)
        print(f"Model saved to {self.model_save_path}")

    def reset_state(self):
        """Per-episode reset hook; this agent keeps no per-episode state."""
        pass

# --- Main Training Loop ---
def train_agent(env_module, num_episodes=2000, novice_track=False, load_model_from=None, save_model_to="trained_dqn_agent.pth"):
    """Train a ``TrainableRLAgent`` on the environment built by ``env_module``.

    Only the first entry of ``env.possible_agents`` is controlled by the
    learning agent; every other agent acts by sampling its action space.

    Args:
        env_module: Module exposing ``env(env_wrappers=..., render_mode=...,
            novice=...)``. The returned object is driven through
            ``reset``, ``agent_iter``, ``last``, ``step``, ``rewards``,
            ``action_space``, ``possible_agents`` and ``close``
            (presumably a PettingZoo AEC environment; confirm against the env).
        num_episodes: Number of episodes to run.
        novice_track: Forwarded to the environment factory as ``novice``.
        load_model_from: Optional checkpoint path passed to the agent.
        save_model_to: Checkpoint path; a falsy value disables saving.

    Returns:
        List with the learning agent's accumulated score of every episode.
    """
    print(f"Starting training: {num_episodes} episodes, Novice: {novice_track}, Load: {load_model_from}, Save: {save_model_to}")
    print(f"Using device: {DEVICE}")

    env = env_module.env(env_wrappers=[], render_mode=None, novice=novice_track)

    # The environment is closed even when training aborts with an exception.
    try:
        my_agent_id = env.possible_agents[0] if env.possible_agents else "agent_0"

        agent = TrainableRLAgent(model_load_path=load_model_from, model_save_path=save_model_to)
        scores_deque = deque(maxlen=100)
        scores = []

        if agent.is_pretrained:
            print(f"Resuming with loaded model. Setting epsilon to EPSILON_END: {EPSILON_END}")
            epsilon = EPSILON_END
        else:
            print(f"Starting with new or re-initialized model. Setting epsilon to EPSILON_START: {EPSILON_START}")
            epsilon = EPSILON_START

        # Holds (state, action) of the learning agent's previous turn until the
        # resulting reward and next state are observed on its next turn.
        pending_experience_my_agent = {}

        for i_episode in range(1, num_episodes + 1):
            env.reset()
            agent.reset_state()
            # Drop any transition left open by the previous episode so that it
            # can never be completed with an observation from this episode.
            pending_experience_my_agent.clear()
            current_episode_rewards = {agent_id: 0.0 for agent_id in env.possible_agents}

            for pet_agent_id in env.agent_iter():
                current_observation_raw, current_reward, termination, truncation, info = env.last()
                done = termination or truncation

                # NOTE(review): env.rewards is added on every agent turn; whether this
                # counts a reward more than once depends on how the environment
                # refreshes env.rewards between turns. The score is used for logging only.
                for r_agent_id, r_value in env.rewards.items():
                    if r_agent_id in current_episode_rewards:
                        current_episode_rewards[r_agent_id] += r_value

                action_to_take = None
                if pet_agent_id == my_agent_id:
                    obs_dict_current_turn = {k: v if isinstance(v, (int, float)) else np.array(v).tolist() for k, v in current_observation_raw.items()}
                    processed_current_state_np = agent.process_observation(obs_dict_current_turn)

                    if my_agent_id in pending_experience_my_agent:
                        prev_state_np, prev_action = pending_experience_my_agent.pop(my_agent_id)
                        agent.step(prev_state_np, prev_action, current_reward, processed_current_state_np, done)

                    if not done:
                        action_to_take = agent.select_action(processed_current_state_np, epsilon)
                        pending_experience_my_agent[my_agent_id] = (processed_current_state_np, action_to_take)

                elif not done:
                    if env.action_space(pet_agent_id) is not None:
                        action_to_take = env.action_space(pet_agent_id).sample()

                # Finished agents step with None.
                env.step(action_to_take)

            episode_score = current_episode_rewards.get(my_agent_id, 0.0)
            scores_deque.append(episode_score)
            scores.append(episode_score)
            if not agent.is_pretrained or i_episode > 100:
                epsilon = max(EPSILON_END, EPSILON_DECAY * epsilon)

            if i_episode % 20 == 0:
                print(f'\rEpisode {i_episode}\tAvg Score (100): {(np.mean(scores_deque) if scores_deque else 0.0):.2f}\tEpsilon: {epsilon:.4f}\tSteps: {agent.total_agent_steps}', end="")
            if i_episode % 100 == 0:
                print(f'\rEpisode {i_episode}\tAvg Score (100): {(np.mean(scores_deque) if scores_deque else 0.0):.2f}\tEpsilon: {epsilon:.4f}\tSteps: {agent.total_agent_steps}')
                if save_model_to:
                    agent.save_model()
    finally:
        env.close()

    if save_model_to:
        agent.save_model()
    print(f"\nTraining finished. Final model saved to {save_model_to if save_model_to else 'N/A'}")
    return scores

# Script entry point: run a profiled training session and report wall time.
if __name__ == '__main__':
    training_start_time = time.time()

    profiler_instance = cProfile.Profile()
    profiler_enabled_this_run = False

    try:
        # Only the environment import is reported as a missing dependency. An
        # ImportError raised later, during training, falls through to the
        # generic handler below and is printed with a full traceback.
        try:
            from til_environment import gridworld
        except ImportError:
            gridworld = None
            print("Could not import 'til_environment.gridworld'. Ensure it's accessible.")

        if gridworld is not None:
            print("Attempting to enable profiler...")
            profiler_instance.enable()
            profiler_enabled_this_run = True
            print("Profiler enabled.")

            trained_scores = train_agent(
                gridworld,
                num_episodes=100,
                novice_track=False,
                load_model_from="trained_dqn_agent_1500.pth",
                save_model_to="trained_dqn_agent_1600.pth"
            )
    except Exception as e:
        # Top-level boundary: report the failure but still print profiler stats.
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()
    finally:
        if profiler_enabled_this_run:
            print("Disabling profiler...")
            profiler_instance.disable()
            print("Profiler disabled.")
            print("Generating profiler stats...")
            # Show the 20 entries with the highest cumulative time.
            stats = pstats.Stats(profiler_instance).sort_stats('cumtime')
            stats.print_stats(20)

    training_end_time = time.time()
    total_time = training_end_time - training_start_time
    print(f"Total training time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

Attempting to enable profiler...
Profiler enabled.
Starting training: 100 episodes, Novice: False, Load: trained_dqn_agent_1500.pth, Save: trained_dqn_agent_1600.pth
Using device: cuda
Successfully loaded pre-trained policy_net from trained_dqn_agent_1500.pth
Resuming with loaded model. Setting epsilon to EPSILON_END: 0.01
Episode 100	Avg Score (100): 14.21	Epsilon: 0.0100	Steps: 9941
Model saved to trained_dqn_agent_1600.pth
Model saved to trained_dqn_agent_1600.pth

Training finished. Final model saved to trained_dqn_agent_1600.pth
Disabling profiler...
Profiler disabled.
Generating profiler stats...
         89317443 function calls (87886765 primitive calls) in 234.716 seconds

   Ordered by: cumulative time
   List reduced from 734 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    80328   20.784    0.000  124.252    0.002 /home/jupyter/til-25-data-chefs/til-25-environment/til_environment/gridworld.py:401(observe)
    40164   