In [8]:
class TILAIEnv:
    def __init__(self, seed=None):
        """Initialize the TIL-AI environment."""
        self.grid_size = 16
        self.max_steps = 100
        self.obstacles = set()  # Will be initialized in reset()
        self.reset(seed=seed)
        
    def reset(self, seed=None):
        """Reset the environment for a new episode."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        self.step_count = 0
        self.is_scout = bool(random.randint(0, 1))  # scout is 1, guard is 0
        self.direction = random.randint(0, 3)  # 0: right, 1: down, 2: left, 3: up
        
        # Create obstacles - but not too many to ensure agent can move
        self.obstacles = set()
        for _ in range(30):
            self.obstacles.add((random.randint(0, 15), random.randint(0, 15)))
        
        # Place agent at a location without obstacles
        while True:
            self.location = [random.randint(0, 15), random.randint(0, 15)]
            if tuple(self.location) not in self.obstacles:
                break
        
        # Generate recon points and missions, excluding obstacle and agent locations
        self.recon_points = set()
        self.missions = set()
        
        for _ in range(100):
            while True:
                point = (random.randint(0, 15), random.randint(0, 15))
                if point not in self.obstacles and point != tuple(self.location):
                    self.recon_points.add(point)
                    break
                    
        for _ in range(20):
            while True:
                point = (random.randint(0, 15), random.randint(0, 15))
                if point not in self.obstacles and point != tuple(self.location) and point not in self.recon_points:
                    self.missions.add(point)
                    break
        
        self.visited = set()
        self.done = False
        return self._get_obs()

    def _get_obs(self):
        """Generate observation based on agent's current state."""
        # Create a simplified viewcone
        viewcone = np.zeros((7, 5), dtype=np.uint8)
        
        # Add some simple content to the viewcone based on agent's surroundings
        x, y = self.location
        direction = self.direction
        
        # Fill the viewcone with simplified information
        # This is a simplified implementation - the actual competition will have more complex logic
        for i in range(7):
            for j in range(5):
                # Calculate relative position in the grid
                dx = j - 2  # -2 to 2 (left to right)
                dy = i - 2  # -2 to 4 (back to front, with more visibility forward)
                
                # Rotate based on agent's direction
                if direction == 0:  # right
                    nx, ny = x + dx, y + dy - 2  # Adjust forward visibility
                elif direction == 1:  # down
                    nx, ny = x - dy + 2, y + dx  # Rotate 90° clockwise
                elif direction == 2:  # left
                    nx, ny = x - dx, y - dy + 2  # Rotate 180°
                elif direction == 3:  # up
                    nx, ny = x + dy - 2, y - dx  # Rotate 270° clockwise
                
                # Set value based on what's at this position
                if 0 <= nx < 16 and 0 <= ny < 16:  # Within grid bounds
                    if (nx, ny) in self.obstacles:
                        viewcone[i, j] = 129  # Empty tile (1) with walls (128)
                    elif (nx, ny) in self.recon_points:
                        viewcone[i, j] = 2  # Recon point
                    elif (nx, ny) in self.missions:
                        viewcone[i, j] = 3  # Mission
                    else:
                        viewcone[i, j] = 1  # Empty tile
                else:
                    viewcone[i, j] = 0  # No vision (out of bounds)
        
        return {
            "viewcone": viewcone.tolist(),
            "direction": self.direction,
            "location": self.location,
            "scout": int(self.is_scout),
            "step": self.step_count
        }

    def _min_manhattan_distance(self):
        """Calculate minimum Manhattan distance to relevant targets."""
        targets = self.recon_points if self.is_scout else self.missions
        if not targets:
            return 0 
        return min(abs(self.location[0] - tx) + abs(self.location[1] - ty) for tx, ty in targets)

    def step(self, action):
        """Take a step in the environment based on the agent's action."""
        if self.done:
            return self._get_obs(), 0, True

        self.step_count += 1
        
        # Store original location for debugging
        original_location = self.location.copy()
        
        # Process the action
        dx, dy = 0, 0
        if action == 0:  # Move forward
            dx, dy = self._move_vector(self.direction)
        elif action == 1:  # Move backward
            dx, dy = self._move_vector((self.direction + 2) % 4)
        elif action == 2:  # Turn left
            self.direction = (self.direction - 1) % 4
        elif action == 3:  # Turn right
            self.direction = (self.direction + 1) % 4
        # Action 4 is stay (do nothing)

        # Calculate new location
        if action in [0, 1]:  # Only move for forward/backward actions
            new_x = np.clip(self.location[0] + dx, 0, 15)
            new_y = np.clip(self.location[1] + dy, 0, 15)
            new_loc = (new_x, new_y)
            
            # Only update if not blocked by an obstacle
            if new_loc not in self.obstacles:
                self.location = [new_x, new_y]
        
        # Get current location as tuple for easier checking
        loc_tuple = tuple(self.location)
        
        # Base reward slightly negative to encourage efficient paths
        reward = -0.01
        
        # Penalty for revisiting locations
        if loc_tuple in self.visited:
            reward -= 0.01
        else:
            self.visited.add(loc_tuple)

        # Role-specific rewards
        if self.is_scout:
            # Scout collects recon points
            if loc_tuple in self.recon_points:
                reward += 1  # Match competition reward
                self.recon_points.remove(loc_tuple)
                
            # Scout completes missions
            if loc_tuple in self.missions:
                reward += 5  # Match competition reward
                self.missions.remove(loc_tuple)
                
            # Small chance of capture (game ending for scout)
            if random.random() < 0.01:
                reward -= 50  # Match competition punishment
                self.done = True
                
        else:  # Guard
            # Guard captures scout (simulation)
            if random.random() < 0.01:
                reward += 50  # Match competition reward
                self.done = True

        # Add small reward for getting closer to objectives
        reward += 0.01 * (1 / (1 + self._min_manhattan_distance()))
        
        # Check for episode termination
        if self.step_count >= self.max_steps:
            self.done = True

        return self._get_obs(), reward, self.done

    def _move_vector(self, direction):
        """Get the movement vector for a given direction."""
        return [(1, 0), (0, 1), (-1, 0), (0, -1)][direction]

    def render(self):
        """Render the current state of the environment."""
        grid = [['.' for _ in range(self.grid_size)] for _ in range(self.grid_size)]
        
        # Add obstacles, recon points, and missions to the grid
        for ox, oy in self.obstacles:
            grid[oy][ox] = 'X'
        
        for rx, ry in self.recon_points:
            grid[ry][rx] = 'R'
            
        for mx, my in self.missions:
            grid[my][mx] = 'M'
        
        # Add agent to the grid
        x, y = self.location
        grid[y][x] = 'S' if self.is_scout else 'G'
        
        # Print the grid
        print("\n".join(" ".join(row) for row in grid))
        print(f"Direction: {['Right', 'Down', 'Left', 'Up'][self.direction]}")
        print(f"Step: {self.step_count}/{self.max_steps}")
        print()
    
    # NEW HELPER METHOD: Convert observation to tensor state for the model
    def obs_to_tensor(self, obs):
        """Convert observation dict to flattened tensor state for the model."""
        # Extract components from observation
        viewcone = torch.tensor(obs["viewcone"], dtype=torch.float32).flatten()
        direction = F.one_hot(torch.tensor(obs["direction"]), num_classes=4).float()
        location = torch.tensor(obs["location"], dtype=torch.float32) / 15.0  # Normalize
        scout = torch.tensor([obs["scout"]], dtype=torch.float32)
        step = torch.tensor([obs["step"] / self.max_steps], dtype=torch.float32)  # Normalize
        
        # Concatenate all features
        state = torch.cat([viewcone, direction, location, scout, step])
        return state


# Implement the Attention Layer as recommended
class AttentionLayer(nn.Module):
    def __init__(self, input_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, input_dim),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        # Ensure x is 2D for the attention mechanism
        if x.dim() == 1:
            x = x.unsqueeze(0)
        weights = self.attention(x)
        return x * weights  # Element-wise multiplication


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
import time
import math
import copy

# Attention Layer for enhanced feature focus
class AttentionLayer(nn.Module):
    def __init__(self, feature_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, feature_dim),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x):
        # Ensure x is 2D for attention
        if x.dim() == 1:
            x = x.unsqueeze(0)
        weights = self.attention(x)
        return x * weights

# Enhanced DQN architecture with attention and residual connections
class EnhancedDQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(EnhancedDQN, self).__init__()
        
        # Attention mechanism for focusing on important state features
        self.attention = AttentionLayer(state_dim)
        
        # First feature extraction layer
        self.feature1 = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.LayerNorm(256),  # LayerNorm for better stability
            nn.ReLU(),
            nn.Dropout(0.1)  # Add dropout for regularization
        )
        
        # Second feature extraction layer with residual connection
        self.feature2 = nn.Sequential(
            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        
        # Value stream
        self.value = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
        # Advantage stream
        self.advantage = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )
        
        # Initialize weights for better performance
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Apply attention mechanism
        x = self.attention(x)
        
        # Feature extraction with residual connection
        f1 = self.feature1(x)
        f2 = self.feature2(f1)
        features = f1 + f2  # Residual connection
        
        # Value and advantage streams
        value = self.value(features)
        adv = self.advantage(features)
        
        # Dueling architecture combination
        return value + adv - adv.mean(1, keepdim=True)

# Enhanced Prioritized Replay Buffer with better prioritization
class StabilizedReplayBuffer:
    def __init__(self, capacity=100000, alpha=0.6, beta_start=0.4, beta_frames=100000):
        self.capacity = capacity
        # Separate buffers for scout and guard
        self.buffers = {
            0: {'data': [], 'priorities': []},  # Guard
            1: {'data': [], 'priorities': []}   # Scout
        }
        self.pos = {0: 0, 1: 0}
        self.alpha = alpha
        # Beta annealing for importance sampling
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame_idx = 0
        
    def get_beta(self):
        """Calculate current beta value for importance sampling with slower annealing"""
        fraction = min(self.frame_idx / self.beta_frames, 1.0)
        # Slower annealing curve
        return min(self.beta_start + (fraction**0.75) * (1.0 - self.beta_start), 1.0)
        
    def push(self, s, a, r, ns, d, role):
        """Add experience to the buffer with improved priority assignment"""
        self.frame_idx += 1
        buffer = self.buffers[role]
        
        # Use max priority for new samples or a default value
        max_prio = max(buffer['priorities'], default=1.0)
        
        # Boost priority for important experiences
        if abs(r) > 5.0:  # High reward or penalty
            max_prio = max_prio * 1.5
        
        data = (s, a, r, ns, d)
        
        if len(buffer['data']) < self.capacity:
            buffer['data'].append(data)
            buffer['priorities'].append(max_prio)
        else:
            buffer['data'][self.pos[role]] = data
            buffer['priorities'][self.pos[role]] = max_prio
            
        self.pos[role] = (self.pos[role] + 1) % self.capacity
        
    def sample(self, batch_size, role):
        """Sample with prioritization and proper beta scheduling"""
        buffer = self.buffers[role]
        if len(buffer['data']) < batch_size:
            return None
        
        # Get current beta value
        beta = self.get_beta()
            
        # Calculate sampling probabilities with smoothing
        probs = np.array(buffer['priorities']) ** self.alpha
        probs = np.clip(probs, 1e-5, 1e5)  # Prevent numerical instabilities
        probs /= probs.sum()
        
        # Sample with priority
        indices = np.random.choice(len(buffer['data']), batch_size, p=probs)
        samples = [buffer['data'][i] for i in indices]
        
        # Calculate importance sampling weights
        weights = (len(buffer['data']) * probs[indices]) ** -beta
        weights /= weights.max()  # Normalize weights
        
        s, a, r, ns, d = zip(*samples)
        return (np.array(s), np.array(a), np.array(r), np.array(ns),
                np.array(d), weights, indices)
                
    def update_priorities(self, indices, priorities, role):
        """Update priorities with more sophisticated boosting for high rewards"""
        # Clip priorities to prevent extreme values
        priorities = np.clip(priorities, 0.01, 15.0)  # Higher upper clip as recommended
        
        # Boost priorities for experiences with high rewards
        for i, p, idx in zip(range(len(indices)), priorities, indices):
            # Extract the reward from the experience
            _, _, r, _, _ = self.buffers[role]['data'][idx]
            
            # Apply boosting based on reward magnitude
            if abs(r) > 10.0:  # Very important experience
                p = p * 1.5
            elif abs(r) > 5.0:  # Important experience
                p = p * 1.2
            elif abs(r) > 1.0:  # Somewhat important
                p = p * 1.1
            
            # Additional boost for scout completing missions
            if role == 1 and r > 4.5:  # Scout completing missions (reward ~5)
                p = p * 1.3
                
            self.buffers[role]['priorities'][idx] = p
            
    def __len__(self, role):
        return len(self.buffers[role]['data'])

# Improved DQN Agent with recommended enhancements
class RobustDQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99):
        self.gamma = gamma
        
        # Role-specific configurations with refined parameters
        self.configs = {
            0: {  # Guard
                "lr": 2e-4,  # Higher initial learning rate as recommended
                "update_freq": 250,
                "grad_steps": 1,
                "target_update_freq": 1000  # Less frequent target updates for stability
            },
            1: {  # Scout
                "lr": 5e-5,  # Lower for more stability
                "update_freq": 200,
                "grad_steps": 2,  # More gradient steps for scout
                "target_update_freq": 800  # More frequent target updates for scout
            }
        }
        
        # Build separate agents for scout and guard
        self.agents = {
            0: self._build_agent(state_dim, action_dim, self.configs[0]),  # Guard
            1: self._build_agent(state_dim, action_dim, self.configs[1])   # Scout
        }
        
        # Exploration parameters with slower decay as recommended
        self.epsilon = {
            0: 0.3,  # Start with higher exploration
            1: 0.4   # Even higher for scout
        }
        self.epsilon_decay = {
            0: 0.9998,  # Slower decay for guard
            1: 0.9999   # Even slower decay for scout
        }
        self.epsilon_final = {
            0: 0.08,    # Higher final epsilon
            1: 0.12     # Even higher for scout
        }
        
        # Training counters
        self.update_counter = {0: 0, 1: 0}
        
        # Historical models for self-play (maintain this feature from original code)
        self.historical_models = {
            0: [],  # Guard history
            1: []   # Scout history
        }
        self.historical_model_episodes = []
        
        # Parameter noise for exploration
        self.param_noise_std = {0: 0.01, 1: 0.02}
        self.param_noise_decay = 0.9995
        
    def _build_agent(self, state_dim, action_dim, config):
        """Build agent components with improved architecture"""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = EnhancedDQN(state_dim, action_dim).to(device)
        target = EnhancedDQN(state_dim, action_dim).to(device)
        target.load_state_dict(model.state_dict())
        
        # Freeze target network parameters
        for param in target.parameters():
            param.requires_grad = False
            
        # Use Adam with improved parameters
        optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=config["lr"], 
            eps=1e-5,  # For numerical stability
            weight_decay=1e-5  # Light L2 regularization
        )
        
        # Add learning rate scheduler for better convergence
        # Use cosine annealing for smoother learning rate decay
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=10000,  # Longer cycle
            eta_min=config["lr"] / 10  # Don't reduce LR too much
        )
        
        return {
            "model": model, 
            "target": target, 
            "optimizer": optimizer,
            "scheduler": scheduler,
            "device": device,
            "config": config
        }
    
    # Store historical models for self-play
    def store_historical_model(self, episode):
        """Store a snapshot of current models for later self-play"""
        guard_model = copy.deepcopy(self.agents[0]["model"].state_dict())
        scout_model = copy.deepcopy(self.agents[1]["model"].state_dict())
        
        self.historical_models[0].append(guard_model)
        self.historical_models[1].append(scout_model)
        self.historical_model_episodes.append(episode)
        
        # Keep only last 5 historical models to save memory
        if len(self.historical_models[0]) > 5:
            self.historical_models[0].pop(0)
            self.historical_models[1].pop(0)
            self.historical_model_episodes.pop(0)
        
        print(f"Stored historical models at episode {episode}")
    
    def add_parameter_noise(self, role):
        """Add noise to model parameters for better exploration"""
        if self.param_noise_std[role] > 0.001:  # Only add noise if it's significant
            for param in self.agents[role]["model"].parameters():
                noise = torch.randn_like(param.data) * self.param_noise_std[role]
                param.data += noise
            
            # Decay noise standard deviation
            self.param_noise_std[role] *= self.param_noise_decay
    
    # Act using either current or historical model
    def act(self, state, role, evaluation=False, use_historical=False, historical_idx=None):
        """Select action using epsilon-greedy with optional historical model"""
        epsilon = 0.01 if evaluation else self.epsilon[role]  # Use minimal exploration in eval mode
        device = self.agents[role]["device"]
        
        # Epsilon-greedy exploration
        if random.random() < epsilon:
            return random.randint(0, 4)
        
        # Convert state to tensor
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        
        # Use historical model if specified
        if use_historical and historical_idx is not None and len(self.historical_models[role]) > historical_idx:
            # Create temporary model with historical weights
            # Access the dimensions correctly from the model itself
            in_features = self.agents[role]["model"].feature1[0].in_features
            out_features = self.agents[role]["model"].advantage[-1].out_features
            
            temp_model = EnhancedDQN(in_features, out_features).to(device)
            temp_model.load_state_dict(self.historical_models[role][historical_idx])
            temp_model.eval()
            
            with torch.no_grad():
                q_values = temp_model(state)
                action = q_values.argmax().item()
                
            return action
        
        # Otherwise use current model
        self.agents[role]["model"].eval()
        with torch.no_grad():
            q_values = self.agents[role]["model"](state)
            action = q_values.argmax().item()
        self.agents[role]["model"].train()
        
        return action
        
    def decay_epsilon(self, role):
        """Decay epsilon according to recommended slower schedule"""
        self.epsilon[role] = max(
            self.epsilon_final[role],
            self.epsilon[role] * self.epsilon_decay[role]
        )
            
    def update(self, buffer, role, batch_size=64):
        """Update the agent with improved training stability measures"""
        result = buffer.sample(batch_size, role)
        if result is None:
            return
            
        s, a, r, ns, d, w, idx = result
        agent = self.agents[role]
        device = agent["device"]
        config = agent["config"]
        
        # Convert numpy arrays to tensors
        s_t = torch.tensor(s, dtype=torch.float32).to(device)
        ns_t = torch.tensor(ns, dtype=torch.float32).to(device)
        a_t = torch.tensor(a, dtype=torch.long).unsqueeze(1).to(device)
        r_t = torch.tensor(r, dtype=torch.float32).unsqueeze(1).to(device)
        d_t = torch.tensor(d, dtype=torch.float32).unsqueeze(1).to(device)
        w_t = torch.tensor(w, dtype=torch.float32).unsqueeze(1).to(device)
        
        # Ensure model is in training mode for batch updates
        agent["model"].train()
        agent["target"].eval()
        
        loss_sum = 0
        
        # Update step with multiple gradient accumulations for stability
        for _ in range(config["grad_steps"]):
            # Calculate current Q values
            q_vals = agent["model"](s_t).gather(1, a_t)
            
            # Double DQN update - reduces overestimation bias
            with torch.no_grad():
                # Select actions using online network
                next_actions = agent["model"](ns_t).argmax(1, keepdim=True)
                # Evaluate Q-values using target network
                next_q = agent["target"](ns_t).gather(1, next_actions)
                # Calculate expected Q values
                expected = r_t + self.gamma * next_q * (1 - d_t)
            
            # Use Huber loss for more stable updates
            loss = F.smooth_l1_loss(q_vals, expected, reduction='none')
            weighted_loss = (w_t * loss).mean()
            
            loss_sum += weighted_loss.item()
            
            # Calculate TD errors for priority update
            with torch.no_grad():
                td_err = (q_vals - expected).abs().detach().cpu().numpy().flatten()
            
            # Optimization step
            agent["optimizer"].zero_grad()
            weighted_loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(agent["model"].parameters(), max_norm=10.0)
            
            agent["optimizer"].step()
        
        # Update priorities in the replay buffer
        buffer.update_priorities(idx, td_err, role)
        
        # Step the learning rate scheduler
        agent["scheduler"].step()
        
        # Target network update with role-specific frequency
        self.update_counter[role] += 1
        if self.update_counter[role] % config["target_update_freq"] == 0:
            agent["target"].load_state_dict(agent["model"].state_dict())
            
        return loss_sum / config["grad_steps"]
    
    def save(self, path_prefix="./model5/"):
        """Save both scout and guard models using safetensors."""
        try:
            from safetensors.torch import save_file
            os.makedirs(path_prefix, exist_ok=True)
            flattened = {}
            for role_name, role_id in [("guard", 0), ("scout", 1)]:
                for part in ["model", "target"]:
                    state_dict = self.agents[role_id][part].state_dict()
                    for key, tensor in state_dict.items():
                        if isinstance(tensor, torch.Tensor):
                            flattened[f"{role_name}_{part}.{key}"] = tensor
            
            save_file(flattened, os.path.join(path_prefix, "model.safetensors"))
            print(f"Safetensors model saved to {path_prefix}model.safetensors")
        except Exception as e:
            print(f"Error saving model: {e}")
            
    def load(self, path):
        """Load model from safetensors file"""
        try:
            from safetensors.torch import load_file
            loaded = load_file(path)
            
            # Extract and load the model weights
            for role_name, role_id in [("guard", 0), ("scout", 1)]:
                for part in ["model", "target"]:
                    model = self.agents[role_id][part]
                    state_dict = model.state_dict()
                    
                    # Update state dict with loaded weights
                    for key in state_dict:
                        loaded_key = f"{role_name}_{part}.{key}"
                        if loaded_key in loaded:
                            state_dict[key] = loaded[loaded_key]
                    
                    # Load updated state dict
                    model.load_state_dict(state_dict)
            print(f"Model loaded from {path}")
        except Exception as e:
            print(f"Error loading model: {e}")

# Enhanced flatten observation function with normalization
def preprocess_observation(obs):
    """
    Convert observation to a flat vector with consistent size and better normalization.
    """
    # Safety check for None observation
    if obs is None:
        print("Warning: Received None observation in preprocess_observation. Returning None.")
        return None
        
    try:
        # Get viewcone and normalize to [0,1]
        if 'viewcone' not in obs:
            print("Warning: 'viewcone' not found in observation. Using zeros.")
            flat_view = np.zeros(35)  # Assuming 7x5 viewcone
        else:
            flat_view = np.array(obs['viewcone']).flatten() / 255.0
        
        # Direction as one-hot encoding
        if 'direction' not in obs:
            print("Warning: 'direction' not found in observation. Using zeros.")
            direction_onehot = np.zeros(4)
        else:
            direction_onehot = np.zeros(4)
            direction_onehot[obs['direction']] = 1
        
        # Role indicator (scout or guard)
        if 'scout' not in obs:
            print("Warning: 'scout' not found in observation. Using default False.")
            is_scout = np.array([0])
        else:
            is_scout = np.array([obs['scout']])
        
        # Location normalized to [0,1]
        if 'location' not in obs:
            print("Warning: 'location' not found in observation. Using zeros.")
            location = np.zeros(2)
        else:
            location = np.array(obs['location']) / 15.0
        
        # Step count normalized
        if 'step' not in obs:
            print("Warning: 'step' not found in observation. Using zero.")
            step = np.array([0])
        else:
            step = np.array([obs['step'] / 100.0])
        
        # Combine all features
        return np.concatenate([flat_view, direction_onehot, is_scout, location, step])
        
    except Exception as e:
        print(f"Error in preprocess_observation: {e}")
        return None

# Improved reward shaping function with better gradients
def compute_improved_reward(obs, reward, role, env, current_phase=0):
    """Shaped rewards with better learning signals and curriculum adaptation"""
    # Safety check for None observation
    if obs is None:
        return reward
        
    agent_pos = obs['location']
    
    # Start with original reward
    shaped_reward = reward
    
    if role == 1:  # scout
        # Scout reward modifiers
        
        # 1. Make capture penalty less severe
        if reward <= -40:  # Captured
            shaped_reward = -10  # Less severe than before
        
        # 2. Increase proximity rewards
        targets = env.recon_points.union(env.missions) if hasattr(env, 'recon_points') and hasattr(env, 'missions') else set()
        if targets:
            # Calculate distances to targets with mission prioritization
            distances = []
            for tx, ty in targets:
                # Manhattan distance
                dist = abs(agent_pos[0] - tx) + abs(agent_pos[1] - ty)
                # Prioritize missions (higher reward targets)
                is_mission = (tx, ty) in env.missions if hasattr(env, 'missions') else False
                if is_mission:
                    dist *= 0.8  # Make missions seem "closer" for reward calculation
                distances.append(dist)
                
            min_dist = min(distances) if distances else 16
            
            # Stronger proximity reward with sigmoid curve for smoother transitions
            proximity_reward = 1.0 / (1.0 + np.exp((min_dist - 3) / 2))
            shaped_reward += proximity_reward
            
            # Progress reward - increase for getting closer
            if hasattr(env, 'last_scout_pos') and env.last_scout_pos is not None:
                old_distances = [abs(env.last_scout_pos[0] - tx) + abs(env.last_scout_pos[1] - ty) 
                               for tx, ty in targets]
                old_min_dist = min(old_distances) if old_distances else 16
                
                if min_dist < old_min_dist:
                    # Reward proportional to progress made
                    progress = old_min_dist - min_dist
                    shaped_reward += 0.3 * progress
        
        # Store position for next step
        env.last_scout_pos = agent_pos.copy()
        
        # 3. Amplify collection rewards
        if reward >= 4.5:  # Completed mission (original reward ~5)
            shaped_reward *= 1.3
        elif reward >= 0.8:  # Collected recon point (original reward ~1)
            shaped_reward *= 1.2
            
    else:  # guard
        # Add small positive reward just for being a guard to offset negative bias
        shaped_reward += 0.1
        
        # Focus more on patrolling high-value areas
        targets = env.missions if hasattr(env, 'missions') else set()
        if targets:
            distances = [abs(agent_pos[0] - tx) + abs(agent_pos[1] - ty) for tx, ty in targets]
            min_dist = min(distances) if distances else 16
            
            # Stronger rewards for guards near targets
            shaped_reward += 0.3 * np.exp(-0.2 * min_dist)
        
        # Reward guards for captures
        if reward > 40:  # If there's a capture reward
            shaped_reward *= 1.4  # Amplify it
    
    # Add time-dependent behavior
    time_progress = obs['step'] / 100.0
    if role == 1:  # Scout - incentivize acting quickly
        time_factor = 1.0 - time_progress * 0.3  # Linear decay from 1.0 to 0.7
        shaped_reward *= time_factor
    else:  # Guard - be more aggressive later
        if time_progress > 0.7:  # Last 30% of episode
            shaped_reward *= (1.0 + 0.3 * time_progress)  # Up to 30% boost late in episode
    
    # Phase-specific reward adjustments
    if current_phase >= 1:  # In later phases
        # Add small exploration bonus to encourage trying new strategies
        shaped_reward += 0.05 * random.random()
    
    return shaped_reward

# Main training function with improved curriculum
def train_with_curriculum(env_class, episodes=5000, save_interval=100):
    """Training with curriculum learning, improved stability, and self-play"""
    
    # Initialize environment
    env = env_class()
    state_dim = 43  # Based on flatten_obs output
    action_dim = 5
    
    # Create agent with optimized parameters
    agent = RobustDQNAgent(state_dim=state_dim, action_dim=action_dim)
    
    # Use replay buffer with proper prioritization
    buffer = StabilizedReplayBuffer(capacity=100000)
    
    # Training tracking
    rewards_window = {
        'all': [],
        'scout': [],
        'guard': []
    }
    scout_captures = 0
    scout_collections = 0
    guard_captures = 0
    
    # Ensure environment has attributes we need
    if not hasattr(env, 'last_scout_pos'):
        env.last_scout_pos = None
    
    start_time = time.time()
    
    # Modified curriculum learning phases with gentler transitions
    curriculum = [
        {'episodes': 1000, 'scout_ratio': 0.6, 'scout_lr': 5e-5, 'guard_lr': 2e-4},
        {'episodes': 2000, 'scout_ratio': 0.5, 'scout_lr': 4.5e-5, 'guard_lr': 1.8e-4},
        {'episodes': 2000, 'scout_ratio': 0.5, 'scout_lr': 4e-5, 'guard_lr': 1.5e-4},
        {'episodes': 5000, 'scout_ratio': 0.5, 'scout_lr': 3.5e-5, 'guard_lr': 1.2e-4},
        {'episodes': 10000, 'scout_ratio': 0.5, 'scout_lr': 3e-5, 'guard_lr': 1e-4}
    ]
    
    # Track progress through curriculum
    current_phase = 0
    phase_progress = 0
    phase_role_counts = {0: 0, 1: 0}
    
    # Initialize counter for role balancing
    episodes_since_role = {0: 0, 1: 0}
    
    # Historical model snapshots
    historical_snapshot_interval = 500
    
    # Optional: Enable periodic evaluation
    evaluation_interval = 250
    best_eval_reward = {0: -float('inf'), 1: -float('inf')}
    
    for ep in range(1, episodes+1):
        # Update curriculum phase if needed
        if current_phase < len(curriculum) - 1:
            if phase_progress >= curriculum[current_phase]['episodes']:
                current_phase += 1
                phase_progress = 0
                phase_role_counts = {0: 0, 1: 0}
                print(f"Moving to curriculum phase {current_phase+1}")
                
                # Update learning rates
                for role in [0, 1]:
                    new_lr = curriculum[current_phase][f"{'scout' if role==1 else 'guard'}_lr"]
                    for param_group in agent.agents[role]["optimizer"].param_groups:
                        param_group['lr'] = new_lr
        
        # Store historical models periodically
        if ep % historical_snapshot_interval == 0:
            agent.store_historical_model(ep)
        
        # Apply parameter noise periodically for enhanced exploration
        if ep % 10 == 0:
            for role in [0, 1]:
                agent.add_parameter_noise(role)
        
        # Role selection logic for balanced training
        scout_ratio = curriculum[current_phase]['scout_ratio']
        expected_scout_count = phase_progress * scout_ratio
        expected_guard_count = phase_progress * (1 - scout_ratio)
        
        # Force a scout episode if we're behind or it's been too long
        if phase_role_counts[1] < expected_scout_count - 5 or episodes_since_role[1] > 15:
            role = 1  # Scout
            episodes_since_role[1] = 0
            episodes_since_role[0] += 1
        # Force a guard episode if we're behind or it's been too long
        elif phase_role_counts[0] < expected_guard_count - 5 or episodes_since_role[0] > 15:
            role = 0  # Guard
            episodes_since_role[0] = 0
            episodes_since_role[1] += 1
        # Regular probability-based selection
        else:
            if random.random() < scout_ratio:
                role = 1  # Scout
                episodes_since_role[1] = 0
                episodes_since_role[0] += 1
            else:
                role = 0  # Guard
                episodes_since_role[0] = 0
                episodes_since_role[1] += 1
            
        # Update phase tracking
        phase_progress += 1
        phase_role_counts[role] += 1
        
        # Set environment role
        env.is_scout = (role == 1)
        
        # Use historical opponent occasionally for self-play
        use_historical_opponent = (len(agent.historical_models[0]) > 0 and
                                  random.random() < 0.2)  # 20% chance
        historical_idx = random.randint(0, len(agent.historical_models[0])-1) if use_historical_opponent else None
        
        # Reset environment
        obs = env.reset()
        state = preprocess_observation(obs)
        episode_reward = 0
        
        # Track recon/mission collection
        initial_recon_count = len(env.recon_points) if hasattr(env, 'recon_points') else 0
        initial_mission_count = len(env.missions) if hasattr(env, 'missions') else 0
        
        # Episode loop
        done = False
        steps = 0
        
        while not done:
            # Select action (potentially using historical model for opponent)
            action = agent.act(state, role, use_historical=use_historical_opponent, historical_idx=historical_idx)
            next_obs, reward, done = env.step(action)
            next_state = preprocess_observation(next_obs)
            
            # Shape reward for better learning signal
            shaped_reward = compute_improved_reward(next_obs, reward, role, env, current_phase)
            
            # Store experience in buffer
            buffer.push(state, action, shaped_reward, next_state, done, role)
            
            # Update agent with adaptive frequency
            update_freq = 4  # Less frequent updates for better stability
            if steps % update_freq == 0:
                # Multiple updates per batch
                for _ in range(2):
                    loss = agent.update(buffer, role, batch_size=128)
            
            # Move to next state
            state = next_state
            episode_reward += reward
            steps += 1
            
            # Optional: Early stopping if episode is very long
            if steps >= 150:  # Safeguard against very long episodes
                done = True
        
        # Check for scout captures or collections
        if role == 1:  # scout
            if reward <= -40:  # Captured
                scout_captures += 1
            
            # Check for collections
            if hasattr(env, 'recon_points') and hasattr(env, 'missions'):
                recon_collected = initial_recon_count - len(env.recon_points)
                missions_completed = initial_mission_count - len(env.missions)
                if recon_collected > 0 or missions_completed > 0:
                    scout_collections += 1
        else:  # guard
            # Assume positive reward for guard might be capture
            if reward > 40:
                guard_captures += 1
        
        # Decay exploration rate
        agent.decay_epsilon(role)
        
        # Track rewards by role (using sliding window)
        window_size = 50
        if len(rewards_window['all']) >= window_size:
            rewards_window['all'].pop(0)
        rewards_window['all'].append(episode_reward)
        
        if role == 1:  # scout
            if len(rewards_window['scout']) >= window_size:
                rewards_window['scout'].pop(0)
            rewards_window['scout'].append(episode_reward)
        else:  # guard
            if len(rewards_window['guard']) >= window_size:
                rewards_window['guard'].pop(0)
            rewards_window['guard'].append(episode_reward)
        
        # Log training progress periodically
        if ep % 100 == 0:
            elapsed = time.time() - start_time
            
            # Calculate stats
            avg_r = np.mean(rewards_window['all']) if rewards_window['all'] else 0
            avg_scout = np.mean(rewards_window['scout']) if rewards_window['scout'] else 0
            avg_guard = np.mean(rewards_window['guard']) if rewards_window['guard'] else 0
            
            # Log training status
            print(f"Ep {ep}/{episodes} | "
                  f"Time: {elapsed:.1f}s | "
                  f"Avg Reward: {avg_r:.2f} | "
                  f"Scout: {avg_scout:.2f} | "
                  f"Guard: {avg_guard:.2f} | "
                  f"Scout Captures: {scout_captures} | "
                  f"Scout Collections: {scout_collections} | "
                  f"Guard Captures: {guard_captures} | "
                  f"Epsilon (S/G): {agent.epsilon[1]:.3f}/{agent.epsilon[0]:.3f}")
            
            # Print current learning rates
            print(f"Learning rates - Scout: {agent.agents[1]['optimizer'].param_groups[0]['lr']:.6f} | "
                  f"Guard: {agent.agents[0]['optimizer'].param_groups[0]['lr']:.6f}")
        
        # Periodic evaluation
        if ep % evaluation_interval == 0:
            eval_rewards = {0: [], 1: []}
            
            # Run evaluation episodes
            print("\n----- Running Evaluation -----")
            for eval_role in [0, 1]:
                role_name = "Scout" if eval_role == 1 else "Guard"
                print(f"Evaluating {role_name}...")
                
                for _ in range(10):  # 10 episodes per role
                    env.is_scout = (eval_role == 1)
                    eval_obs = env.reset()
                    eval_state = preprocess_observation(eval_obs)
                    eval_episode_reward = 0
                    eval_done = False
                    eval_steps = 0
                    
                    while not eval_done and eval_steps < 150:
                        # Use evaluation mode (minimal exploration)
                        eval_action = agent.act(eval_state, eval_role, evaluation=True)
                        eval_next_obs, eval_reward, eval_done = env.step(eval_action)
                        eval_next_state = preprocess_observation(eval_next_obs)
                        eval_episode_reward += eval_reward
                        eval_state = eval_next_state
                        eval_steps += 1
                    
                    eval_rewards[eval_role].append(eval_episode_reward)
            
            # Calculate average evaluation rewards
            avg_eval_scout = np.mean(eval_rewards[1]) if eval_rewards[1] else 0
            avg_eval_guard = np.mean(eval_rewards[0]) if eval_rewards[0] else 0
            
            print(f"Evaluation results - Scout: {avg_eval_scout:.2f} | Guard: {avg_eval_guard:.2f}")
            
            # Save models if they're better than previous best
            improved = False
            
            if avg_eval_scout > best_eval_reward[1]:
                best_eval_reward[1] = avg_eval_scout
                improved = True
                print(f"New best Scout model! Reward: {avg_eval_scout:.2f}")
            
            if avg_eval_guard > best_eval_reward[0]:
                best_eval_reward[0] = avg_eval_guard
                improved = True
                print(f"New best Guard model! Reward: {avg_eval_guard:.2f}")
            
            if improved:
                agent.save("./model5/")
                print("Best model saved!")
        
        # Regular model saving
        if ep % save_interval == 0:
            agent.save()
    
    # Final save
    agent.save("./model5/")
    print(f"Training complete! Scout episodes: {phase_role_counts[1]}, "
          f"Guard episodes: {phase_role_counts[0]}")
    return agent

if __name__ == "__main__":
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    
    # Start improved training
    agent = train_with_curriculum(TILAIEnv, episodes=20000, save_interval=500)

Ep 100/20000 | Time: 31.2s | Avg Reward: 9.95 | Scout: 7.73 | Guard: 2.40 | Scout Captures: 18 | Scout Collections: 23 | Guard Captures: 12 | Epsilon (S/G): 0.398/0.297
Learning rates - Scout: 0.000047 | Guard: 0.000189
Ep 200/20000 | Time: 64.1s | Avg Reward: -4.07 | Scout: 3.77 | Guard: 4.07 | Scout Captures: 40 | Scout Collections: 50 | Guard Captures: 23 | Epsilon (S/G): 0.395/0.295
Learning rates - Scout: 0.000039 | Guard: 0.000164

----- Running Evaluation -----
Evaluating Guard...
Evaluating Scout...
Evaluation results - Scout: 5.40 | Guard: -9.13
New best Scout model! Reward: 5.40
New best Guard model! Reward: -9.13
Safetensors model saved to ./model5/model.safetensors
Best model saved!
Ep 300/20000 | Time: 103.0s | Avg Reward: 4.43 | Scout: 7.28 | Guard: -3.47 | Scout Captures: 58 | Scout Collections: 80 | Guard Captures: 30 | Epsilon (S/G): 0.393/0.293
Learning rates - Scout: 0.000024 | Guard: 0.000137
Ep 400/20000 | Time: 138.1s | Avg Reward: 5.13 | Scout: 6.20 | Guard: 4.78