In [None]:
# Install required packages
!pip install torch torchvision numpy matplotlib -q

print("✓ All packages installed successfully!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
from collections import deque
import random
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt
from IPython.display import clear_output
import warnings
# Silence only UserWarnings raised from torch modules. A blanket 'ignore'
# would also hide numpy shape/deprecation warnings that signal real bugs.
warnings.filterwarnings('ignore', category=UserWarning, module='torch')

# Set random seeds for reproducibility (torch, numpy's global RNG, stdlib random)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

print("✓ Libraries imported successfully!")

In [None]:
class LowerTransformer(nn.Module):
    """Aggregates observations, actions, and rewards from target and neighbors.

    Each of the ``1 + n_neighbors`` entity slots (target plus neighbors) is turned
    into three tokens -- observation, action, reward -- and a learned decision
    token is prepended. The encoder output at the decision-token position is the
    per-timestep summary.

    Args:
        d_model: token embedding width.
        nhead: attention heads per encoder layer.
        num_layers: number of encoder layers.
        n_neighbors: neighbor slots per target; fixes the token sequence length.
        obs_dim: width of each slot's observation vector (default 25).
        action_dim: width of each slot's action vector (default 8).
    """
    def __init__(self, d_model=64, nhead=4, num_layers=3, n_neighbors=4,
                 obs_dim=25, action_dim=8):
        super().__init__()
        self.d_model = d_model
        self.n_neighbors = n_neighbors

        # Separate embeddings for o, a, r
        self.obs_embed = nn.Linear(obs_dim, d_model)
        self.action_embed = nn.Linear(action_dim, d_model)
        self.reward_embed = nn.Linear(1, d_model)

        # Decision token (its output summarizes the whole timestep)
        self.decision_token = nn.Parameter(torch.randn(1, 1, d_model))

        # One learned position per token: decision token + 3 tokens per entity slot
        self.pos_embed = nn.Parameter(torch.randn(1, 3 * (1 + n_neighbors) + 1, d_model))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, obs, actions, rewards, mask=None):
        """Encode one timestep of target + neighbor data.

        Args:
            obs: (batch, 1 + n_neighbors, obs_dim) observations.
            actions: (batch, 1 + n_neighbors, action_dim) actions.
            rewards: (batch, 1 + n_neighbors, 1) rewards.
            mask: optional (batch, 1 + n_neighbors) tensor, truthy where a slot is
                padding that attention must ignore. ``None`` attends to every slot.

        Returns:
            (batch, d_model) encoder output at the decision token.
        """
        batch_size = obs.shape[0]

        obs_emb = self.obs_embed(obs)
        act_emb = self.action_embed(actions)
        rew_emb = self.reward_embed(rewards)

        # (batch, slots, 3, d) -> (batch, 3*slots, d): tokens ordered o, a, r per slot
        seq = torch.stack([obs_emb, act_emb, rew_emb], dim=2)
        seq = seq.reshape(batch_size, -1, self.d_model)

        decision = self.decision_token.expand(batch_size, -1, -1)
        seq = torch.cat([decision, seq], dim=1)
        seq = seq + self.pos_embed

        key_padding_mask = None
        if mask is not None:
            # Each slot owns 3 consecutive tokens. The decision token is never
            # masked, so no query row is ever fully masked (which would yield NaNs).
            slot_pad = mask.to(torch.bool).repeat_interleave(3, dim=1)
            decision_pad = slot_pad.new_zeros(batch_size, 1)
            key_padding_mask = torch.cat([decision_pad, slot_pad], dim=1)

        out = self.transformer(seq, src_key_padding_mask=key_padding_mask)

        return out[:, 0, :]  # Return decision token output

print("✓ Lower Transformer defined")

In [None]:
class UpperTransformer(nn.Module):
    """Learns scenario-agnostic decision policies across timesteps.

    Projects each per-timestep Lower Transformer summary to ``d_output``, adds a
    learned positional embedding, and runs a Transformer encoder over the sequence.

    Args:
        d_model: width of the incoming per-timestep summaries.
        d_output: width of the encoder and of the returned embeddings.
        nhead: attention heads per encoder layer.
        num_layers: number of encoder layers.
        history_len: maximum sequence length (size of the positional table).
        causal: if True, timestep t attends only to timesteps <= t. The default
            (False) keeps full bidirectional attention, so every output can see
            the whole window, including later steps.
    """
    def __init__(self, d_model=64, d_output=128, nhead=4, num_layers=3, history_len=10,
                 causal=False):
        super().__init__()
        self.d_model = d_model
        self.d_output = d_output
        self.history_len = history_len
        self.causal = causal

        # Project lower transformer output
        self.input_proj = nn.Linear(d_model, d_output)

        # Position embedding (one row per timestep, up to history_len)
        self.pos_embed = nn.Parameter(torch.randn(1, history_len, d_output))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_output,
            nhead=nhead,
            dim_feedforward=d_output * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, lower_outputs):
        """Encode a (batch, seq_len, d_model) sequence into (batch, seq_len, d_output).

        Raises:
            ValueError: if ``seq_len`` exceeds ``history_len``.
        """
        seq_len = lower_outputs.shape[1]
        if seq_len > self.history_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds history_len {self.history_len}"
            )

        # Slice the positional table so histories shorter than history_len work too
        x = self.input_proj(lower_outputs) + self.pos_embed[:, :seq_len]

        attn_mask = None
        if self.causal:
            # True above the diagonal = position may not attend to later timesteps
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )

        return self.transformer(x, mask=attn_mask)

print("✓ Upper Transformer defined")

In [None]:
class DynamicPredictor(nn.Module):
    """Predicts next state embedding for learning environment dynamics.

    An MLP maps [previous embedding, flattened actions, flattened rewards] of
    the target and its neighbors to a predicted next embedding.

    Args:
        d_model: width of the embeddings being consumed and predicted.
        n_neighbors: neighbor slots per target (1 + n_neighbors slots in total).
        action_dim: per-slot action width (default 8); each slot also carries
            one reward value.
    """
    def __init__(self, d_model=128, n_neighbors=4, action_dim=8):
        super().__init__()

        # embedding + per-slot (action_dim actions + 1 reward), flattened
        input_dim = d_model + (1 + n_neighbors) * (action_dim + 1)

        self.predictor = nn.Sequential(
            nn.Linear(input_dim, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model)
        )

    def forward(self, prev_embedding, actions, rewards):
        """Predict the next embedding.

        Args:
            prev_embedding: (batch, d_model) current embedding.
            actions: (batch, ...) tensor flattening to (1 + n_neighbors) * action_dim values.
            rewards: (batch, ...) tensor flattening to 1 + n_neighbors values.

        Returns:
            (batch, d_model) predicted embedding.
        """
        batch_size = prev_embedding.shape[0]

        actions_flat = actions.reshape(batch_size, -1)
        rewards_flat = rewards.reshape(batch_size, -1)

        x = torch.cat([prev_embedding, actions_flat, rewards_flat], dim=-1)

        return self.predictor(x)

print("✓ Dynamic Predictor defined")

In [None]:
class XLight(nn.Module):
    """Complete X-Light model.

    Pipeline: a Lower Transformer summarizes each timestep (target + neighbors),
    an Upper Transformer mixes those summaries across timesteps, a dynamic
    predictor forecasts each next-timestep embedding (auxiliary task), and
    actor/critic heads read the last timestep's embedding.
    """
    def __init__(self,
                 obs_dim=25,
                 action_dim=8,
                 d_model=64,
                 d_output=128,
                 n_neighbors=4,
                 history_len=10,
                 nhead=4,
                 num_layers=3):
        super().__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_neighbors = n_neighbors
        self.history_len = history_len

        # Lower Transformer
        # NOTE(review): obs_dim/action_dim are not forwarded here, so the lower
        # transformer uses its own default input widths -- keep them in sync.
        self.lower_transformer = LowerTransformer(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            n_neighbors=n_neighbors
        )

        # Upper Transformer
        self.upper_transformer = UpperTransformer(
            d_model=d_model,
            d_output=d_output,
            nhead=nhead,
            num_layers=num_layers,
            history_len=history_len
        )

        # Dynamic Predictor
        self.dynamic_predictor = DynamicPredictor(
            d_model=d_output,
            n_neighbors=n_neighbors
        )

        # Actor: the raw current observation is concatenated to the embedding
        # (a skip connection around the transformers)
        self.actor = nn.Sequential(
            nn.Linear(d_output + obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

        # Critic
        self.critic = nn.Sequential(
            nn.Linear(d_output, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, obs_history, action_history, reward_history,
                current_obs, neighbor_mask=None):
        """Run the full model.

        Args:
            obs_history: (batch, T, slots, obs_width); the first ``history_len``
                timesteps are used.
            action_history: (batch, T, slots, action_width).
            reward_history: (batch, T, slots, 1).
            current_obs: (batch, obs_dim) observation fed to the actor skip path.
            neighbor_mask: optional (batch, T, slots) padding mask forwarded to
                the lower transformer.

        Returns:
            action_logits (batch, action_dim), value (batch, 1),
            predicted next-step embeddings (batch, history_len - 1, d_output) or
            None when history_len == 1, and the matching targets
            ``upper_outputs[:, 1:]``.
        """
        batch_size = obs_history.shape[0]
        n_slots = obs_history.shape[2]
        steps = self.history_len

        obs_steps = obs_history[:, :steps]
        act_steps = action_history[:, :steps]
        rew_steps = reward_history[:, :steps]

        # Run the Lower Transformer on all timesteps at once by folding time into
        # the batch dimension (equivalent to one call per timestep).
        flat_mask = None
        if neighbor_mask is not None:
            flat_mask = neighbor_mask[:, :steps].reshape(batch_size * steps, -1)
        lower_flat = self.lower_transformer(
            obs_steps.reshape(batch_size * steps, n_slots, -1),
            act_steps.reshape(batch_size * steps, n_slots, -1),
            rew_steps.reshape(batch_size * steps, n_slots, -1),
            flat_mask
        )
        lower_outputs = lower_flat.reshape(batch_size, steps, -1)

        # Process through Upper Transformer
        upper_outputs = self.upper_transformer(lower_outputs)

        # Dynamic prediction for pretext task: predict upper[t+1] from upper[t].
        # The rollout code stores at index t the observation o_t together with
        # the actions/rewards of the preceding env step, so the action/reward
        # driving the t -> t+1 transition sits at index t+1.
        predicted_embeddings = None
        if steps > 1:
            d_out = upper_outputs.shape[-1]
            n_pairs = batch_size * (steps - 1)
            prev = upper_outputs[:, :-1].reshape(n_pairs, d_out)
            next_actions = act_steps[:, 1:].reshape(n_pairs, n_slots, -1)
            next_rewards = rew_steps[:, 1:].reshape(n_pairs, n_slots, -1)
            predicted_embeddings = self.dynamic_predictor(
                prev, next_actions, next_rewards
            ).reshape(batch_size, steps - 1, d_out)

        # Use last timestep for decision
        last_embedding = upper_outputs[:, -1, :]

        actor_input = torch.cat([last_embedding, current_obs], dim=-1)
        action_logits = self.actor(actor_input)

        value = self.critic(last_embedding)

        return action_logits, value, predicted_embeddings, upper_outputs[:, 1:, :]

print("✓ X-Light model defined")

In [None]:
class SimpleTrafficEnv:
    """Simplified traffic simulation environment.

    Intersections are laid out row-major on a grid with ``int(sqrt(n_intersections))``
    columns (4x4 for the default 16); when the count is not a perfect square the
    last row is partial. Each intersection has 4 queues and 8 signal phases.

    Observation per intersection (25 values, matching the model's obs_dim):
        queues / max_queue (4), one-hot current phase (8), noisy occupancy (4),
        random flow (4), queue > 0.3 * max_queue flags (4), and time spent in the
        current phase (1), scaled to [0, 1] and saturating at ``max_phase_duration``.

    Args:
        n_intersections: number of intersections.
        n_neighbors: neighbor slots returned by ``get_neighbors`` (-1 padded).
        max_queue: cap on every queue length.
        episode_length: number of steps before ``step`` reports done.
        max_phase_duration: steps at which the time-in-phase feature saturates.
    """

    # Queue indices (0-3) served by each phase
    _PHASE_MOVEMENTS = {
        0: [0, 2], 1: [1, 3], 2: [0, 1], 3: [0, 3],
        4: [2, 1], 5: [2, 3], 6: [0], 7: [1],
    }

    def __init__(self, n_intersections=16, n_neighbors=4, max_queue=50,
                 episode_length=360, max_phase_duration=30):
        self.n_intersections = n_intersections
        self.n_neighbors = n_neighbors
        self.max_queue = max_queue
        self.episode_length = episode_length
        self.max_phase_duration = max_phase_duration
        self.n_phases = 8
        self.obs_dim = 25

        # Grid topology: grid_size columns, row-major ids
        self.grid_size = int(np.sqrt(n_intersections))
        self.adjacency = self._create_grid_adjacency()

        self.reset()

    def _create_grid_adjacency(self):
        """Create a symmetric 0/1 adjacency matrix linking 4-connected grid cells.

        Neighbors are bounds-checked against ``n_intersections`` so a partial
        last row neither indexes out of range nor yields one-way links.
        """
        n = self.n_intersections
        cols = self.grid_size
        adj = np.zeros((n, n))
        for i in range(n):
            col = i % cols
            candidates = []
            if i - cols >= 0:
                candidates.append(i - cols)        # up
            if i + cols < n:
                candidates.append(i + cols)        # down
            if col > 0:
                candidates.append(i - 1)           # left
            if col < cols - 1 and i + 1 < n:
                candidates.append(i + 1)           # right
            for nb in candidates:
                adj[i, nb] = 1
        return adj

    def reset(self):
        """Reset queues, phases and the clock; return all observations."""
        self.queues = np.random.randint(0, self.max_queue // 2,
                                        (self.n_intersections, 4))
        self.current_phase = np.zeros(self.n_intersections, dtype=int)
        self.phase_duration = np.zeros(self.n_intersections)
        self.time_step = 0

        return self._get_all_observations()

    def _get_all_observations(self):
        """Get observations for all intersections as an (n_intersections, 25) array."""
        return np.array([self._get_observation(i) for i in range(self.n_intersections)])

    def _get_observation(self, intersection_id):
        """Get the 25-value observation for a single intersection."""
        queues = self.queues[intersection_id] / self.max_queue

        phase_onehot = np.zeros(self.n_phases)
        phase_onehot[self.current_phase[intersection_id]] = 1

        occupancy = np.clip(queues + np.random.randn(4) * 0.1, 0, 1)
        flow = np.random.rand(4) * 0.5
        num_stops = (queues > 0.3).astype(float)

        cap = max(self.max_phase_duration, 1)
        phase_time = np.array([min(self.phase_duration[intersection_id], cap) / cap])

        return np.concatenate([queues, phase_onehot, occupancy, flow, num_stops, phase_time])

    def get_neighbors(self, intersection_id):
        """Return exactly ``n_neighbors`` neighbor ids, padded with -1."""
        neighbors = np.where(self.adjacency[intersection_id] == 1)[0]

        if len(neighbors) < self.n_neighbors:
            neighbors = np.pad(neighbors, (0, self.n_neighbors - len(neighbors)),
                               constant_values=-1)
        else:
            neighbors = neighbors[:self.n_neighbors]

        return neighbors

    def step(self, actions):
        """Apply one phase per intersection and advance the simulation one step.

        Args:
            actions: sequence of phase indices (0-7), one per intersection.

        Returns:
            (observations, rewards, done, info) where rewards are non-positive
            queue penalties and ``done`` becomes True after ``episode_length`` steps.
        """
        rewards = np.zeros(self.n_intersections)

        for i in range(self.n_intersections):
            # Switching phase resets its timer; keeping it extends the timer
            if actions[i] != self.current_phase[i]:
                self.current_phase[i] = actions[i]
                self.phase_duration[i] = 0
            else:
                self.phase_duration[i] += 1

            # Simulate traffic flow: served movements discharge vehicles
            for movement in self._get_active_movements(actions[i]):
                reduction = np.random.poisson(3)
                self.queues[i, movement] = max(0, self.queues[i, movement] - reduction)

            arrivals = np.random.poisson(2, 4)
            self.queues[i] += arrivals
            self.queues[i] = np.clip(self.queues[i], 0, self.max_queue)

            queue_penalty = -np.sum(self.queues[i])
            wait_penalty = -np.sum(self.queues[i] ** 1.5) * 0.01
            rewards[i] = queue_penalty + wait_penalty

        self.time_step += 1
        done = self.time_step >= self.episode_length

        obs = self._get_all_observations()

        return obs, rewards, done, {}

    def _get_active_movements(self, phase):
        """Get which movements (queue indices) are active for a phase; [] if unknown."""
        return list(self._PHASE_MOVEMENTS.get(phase, ()))

print("✓ Traffic Environment defined")

In [None]:
class RolloutBuffer:
    """Store trajectories for PPO training.

    A FIFO of transition dicts: once ``capacity`` is reached, pushing a new
    transition evicts the oldest. Backed by ``deque(maxlen=capacity)`` so
    eviction is O(1) (``list.pop(0)`` is O(n)); the deque still supports
    ``len``, iteration and negative indexing such as ``buffer.buffer[-1]``.
    """
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def push(self, obs_history, action_history, reward_history,
             current_obs, action, reward, value, log_prob, neighbor_indices):
        """Append one transition; the oldest is dropped automatically when full."""
        self.buffer.append({
            'obs_history': obs_history,
            'action_history': action_history,
            'reward_history': reward_history,
            'current_obs': current_obs,
            'action': action,
            'reward': reward,
            'value': value,
            'log_prob': log_prob,
            'neighbor_indices': neighbor_indices
        })

    def sample(self, batch_size):
        """Return up to ``batch_size`` distinct transitions chosen uniformly."""
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def clear(self):
        """Drop every stored transition."""
        self.buffer.clear()

    def __len__(self):
        return len(self.buffer)


def compute_ppo_loss(model, batch, clip_epsilon=0.2, alpha=1.0, beta=0.5, gamma=1.0,
                     entropy_coef=0.0):
    """Compute PPO loss with dynamic prediction.

    Advantages are one-step: ``reward - stored value`` (no discounting,
    bootstrapping or GAE). Note that ``gamma`` is the weight of the
    dynamic-prediction loss, not a discount factor.

    Args:
        model: called as ``model(obs_hist, act_hist, rew_hist, current_obs)`` and
            returning ``(action_logits, values, predicted_embs, target_embs)``.
        batch: non-empty list of transition dicts as stored by ``RolloutBuffer.push``.
        clip_epsilon: PPO ratio clipping range.
        alpha: weight of the clipped-surrogate actor loss.
        beta: weight of the critic MSE loss.
        gamma: weight of the dynamic-prediction MSE loss.
        entropy_coef: weight of an entropy bonus subtracted from the loss
            (0 disables it).

    Returns:
        ``(total_loss, info)`` where ``info`` holds float diagnostics
        ``actor_loss``, ``critic_loss``, ``pred_loss`` and ``entropy``.
    """
    def _stack(key):
        # One numpy stack + one tensor conversion instead of a tensor per item
        return torch.as_tensor(np.stack([np.asarray(t[key], dtype=np.float32) for t in batch]))

    obs_hist = _stack('obs_history')
    act_hist = _stack('action_history')
    rew_hist = _stack('reward_history')
    current_obs = _stack('current_obs')
    actions = torch.as_tensor([t['action'] for t in batch], dtype=torch.long)
    rewards = torch.as_tensor([float(t['reward']) for t in batch], dtype=torch.float32)
    old_values = torch.as_tensor([float(t['value']) for t in batch], dtype=torch.float32)
    old_log_probs = torch.as_tensor([float(t['log_prob']) for t in batch], dtype=torch.float32)

    action_logits, values, predicted_embs, target_embs = model(
        obs_hist, act_hist, rew_hist, current_obs
    )

    # Actor loss (PPO clipped surrogate)
    dist = Categorical(logits=action_logits)
    new_log_probs = dist.log_prob(actions)
    entropy = dist.entropy().mean()

    ratio = torch.exp(new_log_probs - old_log_probs)
    advantages = rewards - old_values

    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    actor_loss = -torch.min(surr1, surr2).mean()

    # Critic loss; reshape(-1) keeps a (batch,) shape even when batch == 1
    critic_loss = F.mse_loss(values.reshape(-1), rewards)

    # Dynamic prediction loss (targets are detached: only the predictor chases them)
    if predicted_embs is not None and target_embs is not None:
        pred_loss = F.mse_loss(predicted_embs, target_embs.detach())
    else:
        pred_loss = torch.zeros((), device=action_logits.device)

    total_loss = (alpha * actor_loss + beta * critic_loss + gamma * pred_loss
                  - entropy_coef * entropy)

    return total_loss, {
        'actor_loss': actor_loss.item(),
        'critic_loss': critic_loss.item(),
        'pred_loss': pred_loss.item(),
        'entropy': entropy.item()
    }

print("✓ Training components defined")

In [None]:
class XLightTrainer:
    """Trainer for X-Light model: collects rollouts from ``env`` and runs PPO updates.

    For every intersection it keeps rolling histories (length ``history_len``)
    of observation / one-hot action / reward snapshots covering the
    intersection and its ``n_neighbors`` neighbor slots.
    """
    def __init__(self, model, env, lr=3e-4, history_len=10, n_neighbors=4):
        self.model = model
        self.env = env
        self.history_len = history_len
        self.n_neighbors = n_neighbors
        # Padding widths: taken from the model when it exposes them
        self.obs_dim = getattr(model, 'obs_dim', 25)
        self.action_dim = getattr(model, 'action_dim', 8)

        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.buffer = RolloutBuffer(capacity=10000)

        self.obs_histories = {}
        self.action_histories = {}
        self.reward_histories = {}

        self._init_histories()

    def _init_histories(self):
        """(Re)fill every intersection's histories with all-zero padding snapshots."""
        n_slots = 1 + self.n_neighbors
        for i in range(self.env.n_intersections):
            self.obs_histories[i] = deque(maxlen=self.history_len)
            self.action_histories[i] = deque(maxlen=self.history_len)
            self.reward_histories[i] = deque(maxlen=self.history_len)

            dummy_obs = np.zeros((n_slots, self.obs_dim))
            dummy_action = np.zeros((n_slots, self.action_dim))
            dummy_reward = np.zeros((n_slots, 1))

            for _ in range(self.history_len):
                self.obs_histories[i].append(dummy_obs)
                self.action_histories[i].append(dummy_action)
                self.reward_histories[i].append(dummy_reward)

    def _get_neighbor_data(self, intersection_id, all_obs, all_actions, all_rewards):
        """Stack target + neighbor data; padded neighbor slots (-1) get zeros.

        Returns:
            (obs (slots, obs_width), actions (slots, action_dim),
             rewards (slots, 1), neighbor_ids)
        """
        neighbor_ids = self.env.get_neighbors(intersection_id)

        obs_batch = [all_obs[intersection_id]]
        action_batch = [all_actions[intersection_id]]
        reward_batch = [[all_rewards[intersection_id]]]

        for nid in neighbor_ids:
            if nid >= 0:
                obs_batch.append(all_obs[nid])
                action_batch.append(all_actions[nid])
                reward_batch.append([all_rewards[nid]])
            else:
                obs_batch.append(np.zeros_like(all_obs[0]))
                action_batch.append(np.zeros(self.action_dim))
                reward_batch.append([0.0])

        return (np.array(obs_batch),
                np.array(action_batch),
                np.array(reward_batch),
                neighbor_ids)

    def train_episode(self):
        """Roll out one episode and store every intersection's transitions.

        Histories are re-initialized first so an episode never sees context left
        over from the previous one. All intersections are scored with a single
        batched forward pass per step, and each transition is pushed after
        ``env.step`` so it is stored with its own reward.

        Returns:
            (episode_reward, step_count); episode_reward sums, over steps, the
            mean reward across intersections.
        """
        self._init_histories()
        obs = self.env.reset()
        n_inter = self.env.n_intersections
        done = False
        episode_reward = 0
        step_count = 0

        # Before the first step every intersection is reported as having chosen phase 0
        all_actions_onehot = np.zeros((n_inter, self.action_dim))
        all_actions_onehot[:, 0] = 1
        all_rewards = np.zeros(n_inter)

        while not done:
            neighbor_ids_per = []
            for i in range(n_inter):
                obs_data, action_data, reward_data, neighbor_ids = \
                    self._get_neighbor_data(i, obs, all_actions_onehot, all_rewards)

                self.obs_histories[i].append(obs_data)
                self.action_histories[i].append(action_data)
                self.reward_histories[i].append(reward_data)
                neighbor_ids_per.append(neighbor_ids)

            obs_hist = np.stack([np.array(self.obs_histories[i]) for i in range(n_inter)])
            act_hist = np.stack([np.array(self.action_histories[i]) for i in range(n_inter)])
            rew_hist = np.stack([np.array(self.reward_histories[i]) for i in range(n_inter)])

            with torch.no_grad():
                action_logits, values, _, _ = self.model(
                    torch.as_tensor(obs_hist, dtype=torch.float32),
                    torch.as_tensor(act_hist, dtype=torch.float32),
                    torch.as_tensor(rew_hist, dtype=torch.float32),
                    torch.as_tensor(np.asarray(obs), dtype=torch.float32)
                )
                dist = Categorical(logits=action_logits)
                sampled = dist.sample()
                log_probs = dist.log_prob(sampled)

            actions = sampled.tolist()
            next_obs, rewards, done, _ = self.env.step(actions)

            for i in range(n_inter):
                self.buffer.push(
                    obs_history=obs_hist[i],
                    action_history=act_hist[i],
                    reward_history=rew_hist[i],
                    current_obs=obs[i],
                    action=actions[i],
                    reward=rewards[i],
                    value=values[i].item(),
                    log_prob=log_probs[i].item(),
                    neighbor_indices=neighbor_ids_per[i]
                )

            all_actions_onehot = np.zeros((n_inter, self.action_dim))
            all_actions_onehot[np.arange(n_inter), actions] = 1

            all_rewards = rewards
            obs = next_obs
            episode_reward += np.mean(rewards)
            step_count += 1

        return episode_reward, step_count

    def update(self, batch_size=64, n_updates=4):
        """Update model using PPO on ``n_updates`` random minibatches.

        Returns:
            dict of mean loss diagnostics, or {} when the buffer holds fewer
            than ``batch_size`` transitions.
        """
        if len(self.buffer) < batch_size:
            return {}

        losses = {
            'actor_loss': [],
            'critic_loss': [],
            'pred_loss': [],
            'entropy': []
        }

        for _ in range(n_updates):
            batch = self.buffer.sample(batch_size)

            self.optimizer.zero_grad()
            loss, loss_dict = compute_ppo_loss(self.model, batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
            self.optimizer.step()

            for key in loss_dict:
                losses[key].append(loss_dict[key])

        avg_losses = {k: np.mean(v) for k, v in losses.items()}
        return avg_losses

print("✓ Trainer class defined")

In [None]:
def train_xlight(n_episodes=50, update_interval=5, plot_live=True):
    """Main training loop with live plotting.

    Builds a 16-intersection SimpleTrafficEnv, an XLight model and an
    XLightTrainer, then alternates rollouts and PPO updates.

    Args:
        n_episodes: number of training episodes.
        update_interval: run a PPO update every this many episodes (episode 0
            excluded); the buffer is cleared every ``4 * update_interval`` episodes.
        plot_live: redraw reward/loss curves after each update.

    Returns:
        (model, trainer, episode_rewards)

    Raises:
        ValueError: if ``update_interval`` is not positive.
    """
    if update_interval <= 0:
        raise ValueError(f"update_interval must be positive, got {update_interval}")
    ma_window = 5  # moving-average window for the reward curve

    # Create environment
    env = SimpleTrafficEnv(n_intersections=16, n_neighbors=4)

    # Create model
    model = XLight(
        obs_dim=25,
        action_dim=8,
        d_model=64,
        d_output=128,
        n_neighbors=4,
        history_len=10,
        nhead=4,
        num_layers=3
    )

    # Create trainer
    trainer = XLightTrainer(model, env, lr=3e-4)

    episode_rewards = []
    actor_losses = []
    critic_losses = []
    loss_episodes = []  # episode index at which each recorded loss was computed

    print("Starting training...")
    print("=" * 50)

    for episode in range(n_episodes):
        episode_reward, steps = trainer.train_episode()
        episode_rewards.append(episode_reward)

        if episode % update_interval == 0 and episode > 0:
            losses = trainer.update(batch_size=64, n_updates=4)

            if losses:
                actor_losses.append(losses['actor_loss'])
                critic_losses.append(losses['critic_loss'])
                loss_episodes.append(episode)

            print(f"Episode {episode}/{n_episodes} | Reward: {episode_reward:.2f} | Steps: {steps}")
            if losses:
                print(f"  Actor Loss: {losses['actor_loss']:.4f} | Critic Loss: {losses['critic_loss']:.4f}")
                print(f"  Pred Loss: {losses['pred_loss']:.4f} | Entropy: {losses['entropy']:.4f}")

            # Live plotting
            if plot_live and episode > update_interval:
                clear_output(wait=True)

                fig, axes = plt.subplots(1, 2, figsize=(14, 4))

                # Plot rewards
                axes[0].plot(episode_rewards, label='Episode Reward', color='blue', alpha=0.6)
                if len(episode_rewards) >= ma_window:
                    moving_avg = np.convolve(episode_rewards, np.ones(ma_window) / ma_window,
                                             mode='valid')
                    # 'valid' entry k averages episodes k..k+ma_window-1; draw it at
                    # the window's last episode so it lines up with the raw curve
                    axes[0].plot(np.arange(ma_window - 1, len(episode_rewards)), moving_avg,
                                 label=f'Moving Avg ({ma_window})', color='red', linewidth=2)
                axes[0].set_xlabel('Episode')
                axes[0].set_ylabel('Average Reward')
                axes[0].set_title('Training Progress')
                axes[0].legend()
                axes[0].grid(True, alpha=0.3)

                # Plot losses at the episodes where they were actually computed
                if actor_losses:
                    axes[1].plot(loss_episodes, actor_losses, label='Actor Loss', color='orange')
                    axes[1].plot(loss_episodes, critic_losses, label='Critic Loss', color='green')
                    axes[1].set_xlabel('Episode')
                    axes[1].set_ylabel('Loss')
                    axes[1].set_title('Training Losses')
                    axes[1].legend()
                    axes[1].grid(True, alpha=0.3)

                plt.tight_layout()
                plt.show()
                plt.close(fig)  # avoid accumulating figures inside this long-running loop

                print(f"\nEpisode {episode}/{n_episodes}")
                print(f"Latest Reward: {episode_reward:.2f}")
                print(f"Average Reward (last 10): {np.mean(episode_rewards[-10:]):.2f}")

            # Clear buffer periodically
            if episode % (update_interval * 4) == 0:
                trainer.buffer.clear()

    print("\n" + "=" * 50)
    print("Training completed!")

    return model, trainer, episode_rewards

# This cell just defines the function, don't run training yet
print("✓ Training function defined")

In [None]:
# RUN THIS CELL TO START TRAINING
# Yields the trained model, its trainer (holding the env and buffer) and the
# per-episode reward history.
model, trainer, rewards = train_xlight(
    plot_live=True,
    update_interval=5,
    n_episodes=50,
)

In [None]:
# Plot final results
ma_window = 5  # moving-average window (episodes)
fig, (ax_progress, ax_dist) = plt.subplots(1, 2, figsize=(12, 5))

ax_progress.plot(rewards, label='Episode Reward', alpha=0.6)
if len(rewards) >= ma_window:
    moving_avg = np.convolve(rewards, np.ones(ma_window) / ma_window, mode='valid')
    # 'valid' entry k averages episodes k..k+ma_window-1; draw it at the window's
    # last episode so it lines up with the raw curve
    ax_progress.plot(np.arange(ma_window - 1, len(rewards)), moving_avg,
                     label=f'Moving Average ({ma_window})', linewidth=2)
ax_progress.set_xlabel('Episode')
ax_progress.set_ylabel('Average Reward')
ax_progress.set_title('Training Progress - Final')
ax_progress.legend()
ax_progress.grid(True, alpha=0.3)

ax_dist.hist(rewards, bins=20, edgecolor='black', alpha=0.7)
ax_dist.set_xlabel('Reward')
ax_dist.set_ylabel('Frequency')
ax_dist.set_title('Reward Distribution')
ax_dist.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print(f"Final Statistics:")
print(f"  Mean Reward: {np.mean(rewards):.2f}")
print(f"  Std Reward: {np.std(rewards):.2f}")
print(f"  Best Reward: {np.max(rewards):.2f}")
print(f"  Worst Reward: {np.min(rewards):.2f}")

In [None]:
def evaluate_model(model, env, n_episodes=5):
    """Evaluate trained model with a greedy (argmax) policy.

    The model is switched to eval mode for the rollouts so that training-only
    behavior such as dropout is disabled. Its previous train/eval mode is
    restored afterwards, even if an episode raises. History and padding sizes
    come from ``model.history_len``, ``model.obs_dim``, ``model.action_dim`` and
    ``env.n_neighbors`` when present; otherwise the defaults 10 / 25 / 8 / 4
    are used.

    Args:
        model: module called as ``model(obs_hist, act_hist, rew_hist, current_obs)``
            whose first output is the action logits.
        env: environment exposing ``reset``, ``step``, ``get_neighbors`` and
            ``n_intersections``.
        n_episodes: number of evaluation episodes.

    Returns:
        List of per-episode rewards (sum over steps of the mean intersection reward).
    """
    print("Evaluating model...")
    history_len = getattr(model, 'history_len', 10)
    obs_dim = getattr(model, 'obs_dim', 25)
    action_dim = getattr(model, 'action_dim', 8)
    n_slots = 1 + getattr(env, 'n_neighbors', 4)

    total_rewards = []
    was_training = model.training
    model.eval()
    try:
        for ep in range(n_episodes):
            obs = env.reset()
            n_inter = env.n_intersections
            done = False
            episode_reward = 0

            # Zero-padded histories so the first decisions see a full-length window
            obs_histories = [deque([np.zeros((n_slots, obs_dim))] * history_len, maxlen=history_len)
                             for _ in range(n_inter)]
            action_histories = [deque([np.zeros((n_slots, action_dim))] * history_len, maxlen=history_len)
                                for _ in range(n_inter)]
            reward_histories = [deque([np.zeros((n_slots, 1))] * history_len, maxlen=history_len)
                                for _ in range(n_inter)]

            # Before the first step every intersection is reported as having chosen phase 0
            all_actions_onehot = np.zeros((n_inter, action_dim))
            all_actions_onehot[:, 0] = 1
            all_rewards = np.zeros(n_inter)

            while not done:
                for i in range(n_inter):
                    obs_batch = [obs[i]]
                    action_batch = [all_actions_onehot[i]]
                    reward_batch = [[all_rewards[i]]]

                    for nid in env.get_neighbors(i):
                        if nid >= 0:
                            obs_batch.append(obs[nid])
                            action_batch.append(all_actions_onehot[nid])
                            reward_batch.append([all_rewards[nid]])
                        else:
                            obs_batch.append(np.zeros_like(obs[0]))
                            action_batch.append(np.zeros(action_dim))
                            reward_batch.append([0.0])

                    obs_histories[i].append(np.array(obs_batch))
                    action_histories[i].append(np.array(action_batch))
                    reward_histories[i].append(np.array(reward_batch))

                # One batched forward pass for all intersections
                obs_hist = torch.as_tensor(np.stack([np.array(h) for h in obs_histories]),
                                           dtype=torch.float32)
                act_hist = torch.as_tensor(np.stack([np.array(h) for h in action_histories]),
                                           dtype=torch.float32)
                rew_hist = torch.as_tensor(np.stack([np.array(h) for h in reward_histories]),
                                           dtype=torch.float32)
                current_obs = torch.as_tensor(np.asarray(obs), dtype=torch.float32)

                with torch.no_grad():
                    action_logits, _, _, _ = model(obs_hist, act_hist, rew_hist, current_obs)
                actions = torch.argmax(action_logits, dim=-1).tolist()

                next_obs, rewards_step, done, _ = env.step(actions)

                all_actions_onehot = np.zeros((n_inter, action_dim))
                all_actions_onehot[np.arange(n_inter), actions] = 1

                all_rewards = rewards_step
                obs = next_obs
                episode_reward += np.mean(rewards_step)

            total_rewards.append(episode_reward)
            print(f"  Eval Episode {ep+1}/{n_episodes}: Reward = {episode_reward:.2f}")
    finally:
        model.train(was_training)

    print(f"\n✓ Evaluation Complete!")
    print(f"  Average Reward: {np.mean(total_rewards):.2f} ± {np.std(total_rewards):.2f}")
    return total_rewards

# Run evaluation: greedy (argmax) rollouts on the same environment used for training
eval_rewards = evaluate_model(model=model, env=trainer.env, n_episodes=5)