In [1]:
%%writefile single_file_rl_training_final.py
#!/usr/bin/env python3
"""
Enhanced Single-file implementation of Reinforcement Learning for Hierarchical Employee Training Optimization
Based on the research paper by Soumedhik Bharati, Rupsha Sadhukhan, Debanjali Saha

Improvements implemented:
- Fixed learning rate scheduling with StepLR for consistent learning
- Reduced entropy coefficient for more decisive policies
- Enhanced budget management through increased cost penalties
- Amplified efficiency bonuses to encourage budget awareness

Usage:
    python single_file_rl_training.py --mode train --episodes 3000 --cost-penalty 0.015
    python single_file_rl_training.py --mode compare --episodes 1000
"""

import os
import sys
import argparse
import time
import copy
from collections import defaultdict, deque
from typing import Dict, List, Tuple, Optional, Literal
from dataclasses import dataclass

# Core libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Gymnasium for RL environment
import gymnasium as gym
from gymnasium import spaces

# PyTorch for neural networks
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.optim.lr_scheduler import StepLR, ExponentialLR

# =============================================================================
# CONFIGURATION AND HYPERPARAMETERS
# =============================================================================

@dataclass
class TrainingConfig:
    """Configuration class for training parameters with budget-aware defaults"""
    # Training parameters
    num_episodes: int = 3000  # Increased for budget management learning
    gamma: float = 0.99
    learning_rate: float = 3e-4
    hidden_dim: int = 128
    use_baseline: bool = True
    entropy_coefficient: float = 0.001  # FIXED: Reduced from 0.01 for more decisive policy
    
    # Learning rate scheduling parameters
    lr_step_size: int = 750  # NEW: Steps before LR decay
    lr_gamma: float = 0.9    # NEW: LR decay factor
    
    # Environment parameters
    D: int = 8  # Number of skills
    K: int = 4  # Number of training modules
    alpha: List[float] = None  # Learning rates for modules
    beta: float = 0.01  # Reduced forgetting rate
    kappa: float = 1.2  # Reduced diminishing returns
    C_max: float = 120.0  # Increased budget
    
    # Reward strategy
    reward_strategy: Literal['basic', 'terminal', 'efficiency', 'hybrid'] = 'hybrid'
    cost_penalty: float = 0.01  # FIXED: Increased from 0.002 for budget awareness (5x increase)
    skill_amplifier: float = 1.0  # Amplify skill improvements
    terminal_bonus_multiplier: float = 1.5  # Terminal reward multiplier
    
    # Budget penalty parameters - FIXED: Increased for harsher budget enforcement
    base_budget_penalty: float = 4.0  # FIXED: Increased from 2.0
    max_budget_penalty: float = 8.0   # FIXED: Increased from 4.0
    
    # Logging and saving
    log_interval: int = 25  # More frequent logging for better monitoring
    save_interval: int = 500  # Less frequent saves for faster training
    model_save_path: str = 'models/employee_training_model.pth'
    plot_save_path: str = 'plots/'
    
    # Evaluation parameters
    eval_episodes: int = 100
    eval_render: bool = False
    
    def __post_init__(self):
        if self.alpha is None:
            self.alpha = [0.3, 0.25, 0.2, 0.35]
        
        # Validate configuration
        if self.entropy_coefficient > 0.1:
            print(f"WARNING: High entropy coefficient ({self.entropy_coefficient}) may prevent policy convergence")
        
        if self.lr_step_size > self.num_episodes // 2:
            print(f"WARNING: LR step size ({self.lr_step_size}) is large relative to training episodes ({self.num_episodes})")
        
        # Budget management validation
        if self.cost_penalty < 0.005:
            print(f"INFO: Low cost penalty ({self.cost_penalty}) may lead to budget overruns")


# =============================================================================
# EMPLOYEE TRAINING ENVIRONMENT
# =============================================================================

class EmployeeTrainingEnv(gym.Env):
    """
    Enhanced Custom Gymnasium environment with budget-aware reward structure.
    """
    
    metadata = {"render_modes": ["human"]}
    
    def __init__(self, config: TrainingConfig):
        super().__init__()
        
        self.config = config
        self.D = config.D
        self.K = config.K
        self.beta = config.beta
        self.kappa = config.kappa
        self.C_max = config.C_max
        self.gamma = config.gamma
        
        # Learning rates for each training module
        self.alpha = config.alpha
        
        # Training module costs
        self.costs = [10.0, 15.0, 20.0, 12.0]
        
        # Define which sub-attributes each training module targets
        self.module_targets = {
            0: [0, 1],      # Technical Skills: Coding, Debugging
            1: [2, 3],      # Technical Skills: Testing, Architecture
            2: [4, 5],      # Soft Skills: Communication, Leadership
            3: [6, 7]       # Soft Skills: Teamwork, Problem-solving
        }
        
        # Cross-attribute synergy matrix
        self.synergy_matrix = self._initialize_synergy_matrix()
        
        # Gymnasium spaces
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(self.D,), dtype=np.float32)
        self.action_space = spaces.Discrete(self.K)
        
        # Episode state
        self.current_skills = None
        self.current_cost = 0.0
        self.episode_length = 0
        self.max_episode_length = 50
        self.initial_skills = None  # Track initial state for terminal rewards
        
    def _initialize_synergy_matrix(self) -> np.ndarray:
        """Initialize the cross-attribute synergy matrix."""
        synergy = np.zeros((self.D, self.D))
        
        # Within technical skills
        synergy[0, 1] = synergy[1, 0] = 0.3  # Coding <-> Debugging
        synergy[0, 2] = synergy[2, 0] = 0.2  # Coding <-> Testing
        synergy[1, 2] = synergy[2, 1] = 0.4  # Debugging <-> Testing
        synergy[2, 3] = synergy[3, 2] = 0.3  # Testing <-> Architecture
        
        # Within soft skills
        synergy[4, 5] = synergy[5, 4] = 0.4  # Communication <-> Leadership
        synergy[4, 6] = synergy[6, 4] = 0.3  # Communication <-> Teamwork
        synergy[5, 6] = synergy[6, 5] = 0.2  # Leadership <-> Teamwork
        synergy[6, 7] = synergy[7, 6] = 0.3  # Teamwork <-> Problem-solving
        
        # Cross-domain synergies
        synergy[1, 7] = synergy[7, 1] = 0.15  # Debugging <-> Problem-solving
        synergy[3, 5] = synergy[5, 3] = 0.1   # Architecture <-> Leadership
        
        return synergy
    
    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
        """Reset the environment to initial state."""
        super().reset(seed=seed)
        
        # Initialize skills randomly between 0.1 and 0.6
        self.current_skills = self.np_random.uniform(0.1, 0.6, size=self.D).astype(np.float32)
        self.initial_skills = self.current_skills.copy()  # Store initial state
        self.current_cost = 0.0
        self.episode_length = 0
        
        return self.current_skills.copy(), {}
    
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Execute one step with enhanced budget penalty handling."""
        if action < 0 or action >= self.K:
            raise ValueError(f"Invalid action: {action}")
        
        # Calculate cost and check budget constraint
        action_cost = self.costs[action]
        
        # FIXED: Enhanced budget penalty - harsher penalties for better budget discipline
        if self.current_cost + action_cost > self.C_max:
            # Calculate proportional penalty based on budget overrun
            overrun_amount = (self.current_cost + action_cost - self.C_max)
            overrun_ratio = overrun_amount / self.C_max
            
            # Penalty scales from base_penalty to max_penalty based on overrun
            penalty_scale = min(1.0, overrun_ratio)
            budget_penalty = -(self.config.base_budget_penalty + 
                             penalty_scale * (self.config.max_budget_penalty - self.config.base_budget_penalty))
            
            info = {
                "budget_exceeded": True,
                "current_cost": self.current_cost + action_cost,
                "overrun_amount": overrun_amount,
                "budget_penalty": budget_penalty
            }
            
            return self.current_skills.copy(), budget_penalty, True, False, info
        
        # Store previous skills for reward calculation
        prev_skills = self.current_skills.copy()
        
        # Apply training module
        self.current_skills = self._apply_training(self.current_skills, action)
        
        # Update cost and episode length
        self.current_cost += action_cost
        self.episode_length += 1
        
        # Check termination conditions
        terminated = (self.episode_length >= self.max_episode_length or 
                     self.current_cost >= self.C_max)
        
        # Calculate reward based on strategy
        reward = self._calculate_reward(prev_skills, self.current_skills, action_cost, terminated)
        
        info = {
            "current_cost": self.current_cost,
            "episode_length": self.episode_length,
            "skill_improvement": np.sum(self.current_skills - prev_skills),
            "total_skill_improvement": np.sum(self.current_skills - self.initial_skills),
            "budget_utilization": self.current_cost / self.C_max,
            "terminated": terminated,
            "budget_exceeded": False
        }
        
        return self.current_skills.copy(), reward, terminated, False, info
    
    def _apply_training(self, skills: np.ndarray, action: int) -> np.ndarray:
        """Apply training module to current skills."""
        new_skills = skills.copy()
        alpha_a = self.alpha[action]
        target_attributes = self.module_targets[action]
        
        # Calculate potential gains for each attribute
        for j in range(self.D):
            # Direct training effect
            if j in target_attributes:
                delta_j = (1 - skills[j]) ** self.kappa
            else:
                # Cross-attribute synergy effect
                delta_j = 0.0
                for k in target_attributes:
                    delta_j += self.synergy_matrix[j, k] * (1 - skills[j]) ** self.kappa
            
            # Apply training gain and forgetting
            new_skills[j] = skills[j] + alpha_a * delta_j - self.beta * skills[j]
        
        # Clip skills to valid range [0, 1]
        new_skills = np.clip(new_skills, 0.0, 1.0)
        
        return new_skills
    
    def _calculate_reward(self, prev_skills: np.ndarray, new_skills: np.ndarray, 
                         cost: float, terminated: bool) -> float:
        """Calculate reward with enhanced budget awareness."""
        skill_improvement = np.sum(new_skills - prev_skills)
        
        if self.config.reward_strategy == 'basic':
            # Basic: Amplified skill gain - increased cost penalty
            return (self.config.skill_amplifier * skill_improvement - 
                   self.config.cost_penalty * cost)
        
        elif self.config.reward_strategy == 'terminal':
            # Terminal: Small step rewards + large terminal bonus
            step_reward = skill_improvement - self.config.cost_penalty * cost
            
            if terminated and self.current_cost <= self.C_max:
                # Terminal bonus based on total skill level and budget efficiency
                total_skill_level = np.sum(self.current_skills)
                budget_efficiency = (self.C_max - self.current_cost) / self.C_max
                terminal_bonus = (total_skill_level * self.config.terminal_bonus_multiplier + 
                                budget_efficiency * 2.0)
                return step_reward + terminal_bonus
            
            return step_reward
        
        elif self.config.reward_strategy == 'efficiency':
            # Efficiency: Reward skill gains, bonus for finishing under budget
            step_reward = skill_improvement  # No cost penalty during episode
            
            if terminated:
                # End-of-episode rewards
                total_skills = np.sum(self.current_skills)
                if self.current_cost <= self.C_max:
                    budget_bonus = (self.C_max - self.current_cost) * 0.1
                    skill_bonus = total_skills * 1.5
                    return step_reward + budget_bonus + skill_bonus
                else:
                    return step_reward - 8.0  # Increased penalty for exceeding budget
            
            return step_reward
        
        elif self.config.reward_strategy == 'hybrid':
            # FIXED: Enhanced hybrid strategy with budget-aware terminal bonus
            base_reward = self.config.skill_amplifier * skill_improvement
            cost_penalty = self.config.cost_penalty * cost
            
            if terminated and self.current_cost <= self.C_max:
                # FIXED: Enhanced terminal bonus heavily favoring efficiency
                total_improvement = np.sum(self.current_skills - self.initial_skills)
                
                # Amplify efficiency bonus by 2x and keep improvement bonus the same
                efficiency_bonus = ((self.C_max - self.current_cost) / self.C_max) * 2.0  # 2x amplification
                improvement_bonus = total_improvement * 0.5  # Keep same
                
                terminal_bonus = improvement_bonus + efficiency_bonus
                return base_reward - cost_penalty + terminal_bonus
            
            return base_reward - cost_penalty
        
        else:
            raise ValueError(f"Unknown reward strategy: {self.config.reward_strategy}")
    
    def get_hierarchical_skills(self, skills: np.ndarray) -> Dict[str, float]:
        """Calculate hierarchical skill aggregations."""
        return {
            "technical_skills": np.mean(skills[0:4]),
            "soft_skills": np.mean(skills[4:8]),
            "coding_debugging": np.mean(skills[0:2]),
            "testing_architecture": np.mean(skills[2:4]),
            "communication_leadership": np.mean(skills[4:6]),
            "teamwork_problem_solving": np.mean(skills[6:8])
        }
    
    def render(self, mode: str = "human") -> None:
        """Render the current state."""
        if mode == "human":
            hierarchical = self.get_hierarchical_skills(self.current_skills)
            print(f"Episode Length: {self.episode_length}, Cost: {self.current_cost:.2f}/{self.C_max}")
            print("Hierarchical Skills:")
            for skill_name, value in hierarchical.items():
                print(f"  {skill_name}: {value:.3f}")
            print(f"Individual Skills: {self.current_skills}")


# =============================================================================
# ENHANCED NEURAL NETWORKS AND AGENT
# =============================================================================

class PolicyNetwork(nn.Module):
    """Enhanced policy network with better initialization and regularization."""
    
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)
        
        # Better initialization
        self._init_weights()
        
    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return F.softmax(x, dim=-1)


class ValueNetwork(nn.Module):
    """Enhanced value network with better initialization."""
    
    def __init__(self, input_dim: int, hidden_dim: int):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(0.1)
        
        # Better initialization
        self._init_weights()
        
    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)


class EnhancedREINFORCEAgent:
    """
    Enhanced REINFORCE agent with fixed learning rate scheduling.
    """
    
    def __init__(self, config: TrainingConfig):
        self.config = config
        self.state_dim = config.D
        self.action_dim = config.K
        self.use_baseline = config.use_baseline
        self.entropy_coefficient = config.entropy_coefficient
        
        # Policy network
        self.policy_net = PolicyNetwork(self.state_dim, config.hidden_dim, self.action_dim)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=config.learning_rate)
        
        # FIXED: Use StepLR instead of ExponentialLR for stable learning
        self.policy_scheduler = StepLR(
            self.policy_optimizer, 
            step_size=config.lr_step_size, 
            gamma=config.lr_gamma
        )
        
        # Value network (optional baseline)
        if self.use_baseline:
            self.value_net = ValueNetwork(self.state_dim, config.hidden_dim)
            self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=config.learning_rate)
            self.value_scheduler = StepLR(
                self.value_optimizer, 
                step_size=config.lr_step_size, 
                gamma=config.lr_gamma
            )
        
        # Episode memory
        self.log_probs: List[torch.Tensor] = []
        self.rewards: List[float] = []
        self.states: List[torch.Tensor] = []
        self.entropies: List[torch.Tensor] = []  # For entropy regularization
        
        # Training statistics
        self.training_stats = {
            'policy_losses': [],
            'value_losses': [],
            'entropy_losses': [],
            'learning_rates': [],
            'lr_decay_steps': []  # Track when LR decays occur
        }
        
        print(f"Initialized budget-aware agent with LR={config.learning_rate}, "
              f"cost_penalty={config.cost_penalty}, step_size={config.lr_step_size}")
        
    def select_action(self, state: np.ndarray) -> int:
        """Select action using current policy and track entropy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action_probs = self.policy_net(state_tensor)
        
        # Sample action from probability distribution
        dist = Categorical(action_probs)
        action = dist.sample()
        
        # Store log probability, entropy, and state
        self.log_probs.append(dist.log_prob(action))
        self.entropies.append(dist.entropy())
        self.states.append(state_tensor)
        
        return action.item()
    
    def store_reward(self, reward: float) -> None:
        """Store reward for current step."""
        self.rewards.append(reward)
    
    def update_policy(self, gamma: float = 0.99) -> Dict[str, float]:
        """Enhanced policy update with improved learning rate tracking."""
        if len(self.rewards) == 0:
            return {"policy_loss": 0.0, "value_loss": 0.0, "entropy_loss": 0.0, "learning_rate": 0.0}
        
        # Store current learning rate before any updates
        current_lr = self.policy_optimizer.param_groups[0]['lr']
        
        # Calculate discounted returns
        returns = []
        R = 0
        for r in reversed(self.rewards):
            R = r + gamma * R
            returns.insert(0, R)
        
        returns = torch.tensor(returns, dtype=torch.float32)
        
        # Normalize returns for stability
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        
        # Calculate policy loss components
        policy_losses = []
        value_losses = []
        entropy_losses = []
        
        for i, (log_prob, entropy, R) in enumerate(zip(self.log_probs, self.entropies, returns)):
            if self.use_baseline:
                # Use value function as baseline
                state = self.states[i]
                baseline = self.value_net(state).squeeze()
                advantage = R - baseline
                policy_losses.append(-log_prob * advantage.detach())
                value_losses.append(F.mse_loss(baseline, R))
            else:
                policy_losses.append(-log_prob * R)
            
            entropy_losses.append(entropy)
        
        # Combine losses
        policy_loss = torch.stack(policy_losses).mean()
        entropy_loss = torch.stack(entropy_losses).mean()
        
        # Total policy loss with entropy regularization
        total_policy_loss = policy_loss - self.entropy_coefficient * entropy_loss
        
        # Update policy network
        self.policy_optimizer.zero_grad()
        total_policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.policy_optimizer.step()
        
        # Track LR decay steps
        prev_lr = current_lr
        self.policy_scheduler.step()
        new_lr = self.policy_optimizer.param_groups[0]['lr']
        
        if new_lr != prev_lr:
            self.training_stats['lr_decay_steps'].append(len(self.training_stats['learning_rates']))
        
        # Update value network if using baseline
        value_loss = torch.tensor(0.0)
        if self.use_baseline and value_losses:
            value_loss = torch.stack(value_losses).mean()
            
            self.value_optimizer.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 1.0)
            self.value_optimizer.step()
            self.value_scheduler.step()
        
        # Store training statistics
        self.training_stats['policy_losses'].append(policy_loss.item())
        self.training_stats['value_losses'].append(value_loss.item())
        self.training_stats['entropy_losses'].append(entropy_loss.item())
        self.training_stats['learning_rates'].append(new_lr)
        
        # Clear episode memory
        self.log_probs.clear()
        self.rewards.clear()
        self.states.clear()
        self.entropies.clear()
        
        return {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "learning_rate": new_lr
        }
    
    def get_state_value(self, state: np.ndarray) -> float:
        """Get state value estimate (if using baseline)."""
        if not self.use_baseline:
            return 0.0
        
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            return self.value_net(state_tensor).item()
    
    def save_model(self, filepath: str) -> None:
        """Save trained model with training statistics."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        
        checkpoint = {
            'policy_net': self.policy_net.state_dict(),
            'policy_optimizer': self.policy_optimizer.state_dict(),
            'policy_scheduler': self.policy_scheduler.state_dict(),
            'config': self.config,
            'training_stats': self.training_stats
        }
        
        if self.use_baseline:
            checkpoint['value_net'] = self.value_net.state_dict()
            checkpoint['value_optimizer'] = self.value_optimizer.state_dict()
            checkpoint['value_scheduler'] = self.value_scheduler.state_dict()
        
        torch.save(checkpoint, filepath)
    
    def load_model(self, filepath: str) -> None:
        """Load trained model."""
        if not os.path.exists(filepath):
            print(f"Warning: Model file not found at {filepath}. Using random agent.")
            return
        
        checkpoint = torch.load(filepath)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])
        
        if 'policy_scheduler' in checkpoint:
            self.policy_scheduler.load_state_dict(checkpoint['policy_scheduler'])
        
        if self.use_baseline and 'value_net' in checkpoint:
            self.value_net.load_state_dict(checkpoint['value_net'])
            self.value_optimizer.load_state_dict(checkpoint['value_optimizer'])
            if 'value_scheduler' in checkpoint:
                self.value_scheduler.load_state_dict(checkpoint['value_scheduler'])
        
        if 'training_stats' in checkpoint:
            self.training_stats = checkpoint['training_stats']


# =============================================================================
# ENHANCED TRAINING LOOP
# =============================================================================

class EnhancedTrainingLoop:
    """Enhanced training loop with budget management focus."""
    
    def __init__(self, env: EmployeeTrainingEnv, agent: EnhancedREINFORCEAgent, config: TrainingConfig):
        self.env = env
        self.agent = agent
        self.config = config
        
        # Metrics tracking
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_costs = []
        self.policy_losses = []
        self.value_losses = []
        self.entropy_losses = []
        self.learning_rates = []
        self.skill_improvements = []
        self.budget_utilizations = []
        self.success_episodes = []
        self.budget_exceeded_episodes = []  # Track budget violations
        
        # Recent performance tracking
        self.recent_rewards = deque(maxlen=100)
        self.recent_lengths = deque(maxlen=100)
        self.recent_success = deque(maxlen=100)
        self.recent_entropy = deque(maxlen=100)  # Track entropy trend
        self.recent_budget_exceeded = deque(maxlen=100)  # Track budget discipline
        
        # Performance milestones
        self.first_positive_reward_episode = None
        self.first_80_percent_success = None
        self.first_budget_discipline = None  # NEW: When budget exceeded rate drops below 20%
        
    def run_episode(self) -> Dict[str, float]:
        """Run a single training episode with enhanced budget metrics."""
        state, _ = self.env.reset()
        episode_reward = 0
        episode_length = 0
        total_skill_improvement = 0
        budget_exceeded = False
        
        while True:
            # Select action
            action = self.agent.select_action(state)
            
            # Execute action
            next_state, reward, terminated, truncated, info = self.env.step(action)
            
            # Store reward
            self.agent.store_reward(reward)
            
            # Update metrics
            episode_reward += reward
            episode_length += 1
            total_skill_improvement += info.get('skill_improvement', 0)
            
            # Track budget exceeded
            if info.get('budget_exceeded', False):
                budget_exceeded = True
            
            # Update state
            state = next_state
            
            # Check termination
            if terminated or truncated:
                break
        
        # Update policy at end of episode
        losses = self.agent.update_policy(self.config.gamma)
        
        # Determine success (positive reward AND within budget)
        success = (episode_reward > 0 and not budget_exceeded)
        
        # Track milestones
        episode_num = len(self.episode_rewards) + 1
        if episode_reward > 0 and self.first_positive_reward_episode is None:
            self.first_positive_reward_episode = episode_num
            print(f"🎉 First positive reward achieved at episode {episode_num}!")
        
        return {
            'episode_reward': episode_reward,
            'episode_length': episode_length,
            'episode_cost': info.get('current_cost', 0),
            'skill_improvement': total_skill_improvement,
            'budget_utilization': info.get('budget_utilization', 0),
            'budget_exceeded': budget_exceeded,
            'success': success,
            'policy_loss': losses['policy_loss'],
            'value_loss': losses['value_loss'],
            'entropy_loss': losses['entropy_loss'],
            'learning_rate': losses['learning_rate']
        }
    
    def train(self) -> None:
        """Enhanced training loop with budget discipline tracking."""
        print("Starting Budget-Aware Training with Enhanced Cost Penalties...")
        print(f"Episodes: {self.config.num_episodes}")
        print(f"Environment: {self.env.D} skills, {self.env.K} training modules")
        print(f"Agent: {'Actor-Critic' if self.agent.use_baseline else 'REINFORCE'} with entropy regularization")
        print(f"Reward Strategy: {self.config.reward_strategy}")
        print(f"Learning Rate: {self.config.learning_rate} (StepLR: step={self.config.lr_step_size}, gamma={self.config.lr_gamma})")
        print(f"Cost Penalty: {self.config.cost_penalty} (5x increased for budget awareness)")
        print(f"Budget Penalties: base={self.config.base_budget_penalty}, max={self.config.max_budget_penalty}")
        print("-" * 90)
        
        start_time = time.time()
        best_avg_reward = float('-inf')
        episodes_since_improvement = 0
        
        for episode in range(self.config.num_episodes):
            # Run episode
            episode_metrics = self.run_episode()
            
            # Store metrics
            self.episode_rewards.append(episode_metrics['episode_reward'])
            self.episode_lengths.append(episode_metrics['episode_length'])
            self.episode_costs.append(episode_metrics['episode_cost'])
            self.skill_improvements.append(episode_metrics['skill_improvement'])
            self.budget_utilizations.append(episode_metrics['budget_utilization'])
            self.success_episodes.append(episode_metrics['success'])
            self.budget_exceeded_episodes.append(episode_metrics['budget_exceeded'])
            self.policy_losses.append(episode_metrics['policy_loss'])
            self.value_losses.append(episode_metrics['value_loss'])
            self.entropy_losses.append(episode_metrics['entropy_loss'])
            self.learning_rates.append(episode_metrics['learning_rate'])
            
            # Update recent performance
            self.recent_rewards.append(episode_metrics['episode_reward'])
            self.recent_lengths.append(episode_metrics['episode_length'])
            self.recent_success.append(episode_metrics['success'])
            self.recent_entropy.append(episode_metrics['entropy_loss'])
            self.recent_budget_exceeded.append(episode_metrics['budget_exceeded'])
            
            # Check for improvement and milestones
            if len(self.recent_rewards) >= 100:
                current_avg = np.mean(self.recent_rewards)
                current_success_rate = np.mean(self.recent_success)
                current_budget_exceeded_rate = np.mean(self.recent_budget_exceeded)
                
                if current_avg > best_avg_reward:
                    best_avg_reward = current_avg
                    episodes_since_improvement = 0
                else:
                    episodes_since_improvement += 1
                
                # Check for 80% success rate milestone
                if current_success_rate >= 0.8 and self.first_80_percent_success is None:
                    self.first_80_percent_success = episode + 1
                    print(f"🎯 80% success rate achieved at episode {episode + 1}!")
                
                # NEW: Check for budget discipline milestone
                if current_budget_exceeded_rate <= 0.2 and self.first_budget_discipline is None:
                    self.first_budget_discipline = episode + 1
                    print(f"💰 Budget discipline achieved (≤20% exceeded) at episode {episode + 1}!")
            
            # Logging with enhanced budget information
            if (episode + 1) % self.config.log_interval == 0:
                self._log_progress(episode + 1, episode_metrics)
            
            # Save model
            if (episode + 1) % self.config.save_interval == 0:
                self.agent.save_model(self.config.model_save_path)
                print(f"💾 Model saved at episode {episode + 1}")
        
        # Final save
        self.agent.save_model(self.config.model_save_path)
        
        # Training summary with budget focus
        training_time = time.time() - start_time
        final_success_rate = np.mean(self.recent_success) if len(self.recent_success) >= 10 else 0
        final_entropy = np.mean(self.recent_entropy) if len(self.recent_entropy) >= 10 else 0
        final_budget_exceeded_rate = np.mean(self.recent_budget_exceeded) if len(self.recent_budget_exceeded) >= 10 else 0
        
        print(f"\n{'='*90}")
        print("BUDGET-AWARE TRAINING COMPLETED")
        print(f"{'='*90}")
        print(f"Training time: {training_time:.2f} seconds")
        print(f"Best average reward (last 100): {best_avg_reward:.3f}")
        print(f"Final average reward (last 100): {np.mean(self.recent_rewards):.3f}")
        print(f"Final success rate (last 100): {final_success_rate:.2%}")
        print(f"Final budget exceeded rate (last 100): {final_budget_exceeded_rate:.2%} 🎯")
        print(f"Final entropy: {final_entropy:.3f} (max possible: {np.log(self.env.K):.3f})")
        print(f"Final learning rate: {self.learning_rates[-1]:.2e}")
        
        if self.first_positive_reward_episode:
            print(f"First positive reward: Episode {self.first_positive_reward_episode}")
        if self.first_80_percent_success:
            print(f"80% success milestone: Episode {self.first_80_percent_success}")
        if self.first_budget_discipline:
            print(f"Budget discipline milestone: Episode {self.first_budget_discipline}")
        
        # LR decay analysis
        lr_decays = len(self.agent.training_stats['lr_decay_steps'])
        print(f"Learning rate decayed {lr_decays} times during training")
        
    def _log_progress(self, episode: int, metrics: Dict[str, float]) -> None:
        """Enhanced progress logging with budget management focus."""
        avg_reward = np.mean(self.recent_rewards)
        avg_length = np.mean(self.recent_lengths)
        success_rate = np.mean(self.recent_success)
        avg_entropy = np.mean(self.recent_entropy)
        budget_exceeded_rate = np.mean(self.recent_budget_exceeded)
        
        # Trend indicators
        reward_trend = "📈" if len(self.recent_rewards) >= 50 and avg_reward > np.mean(list(self.recent_rewards)[:25]) else "📉"
        entropy_trend = "📉" if len(self.recent_entropy) >= 50 and avg_entropy < np.mean(list(self.recent_entropy)[:25]) else "📈"
        budget_trend = "📉" if len(self.recent_budget_exceeded) >= 50 and budget_exceeded_rate < np.mean(list(self.recent_budget_exceeded)[:25]) else "📈"
        
        print(f"Episode {episode:4d} | "
              f"Reward: {metrics['episode_reward']:6.2f} | "
              f"Avg: {avg_reward:6.2f} {reward_trend} | "
              f"Success: {success_rate:5.2%} | "
              f"Budget Exceeded: {budget_exceeded_rate:5.2%} {budget_trend} | "
              f"Entropy: {avg_entropy:5.3f} {entropy_trend} | "
              f"LR: {metrics['learning_rate']:.2e}")
    
    def evaluate(self, num_episodes: int = 100, render: bool = False) -> Dict[str, float]:
        """Enhanced evaluation with detailed budget analysis."""
        print(f"\nEvaluating budget-aware policy over {num_episodes} episodes...")
        
        eval_rewards = []
        eval_lengths = []
        eval_costs = []
        eval_skill_improvements = []
        eval_budget_utilizations = []
        eval_successes = []
        eval_budget_exceeded = []
        
        for episode in range(num_episodes):
            state, _ = self.env.reset()
            episode_reward = 0
            episode_length = 0
            total_skill_improvement = 0
            budget_exceeded = False
            
            while True:
                # Select action (deterministic for evaluation)
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0)
                    action_probs = self.agent.policy_net(state_tensor)
                    action = torch.argmax(action_probs).item()
                
                next_state, reward, terminated, truncated, info = self.env.step(action)
                
                episode_reward += reward
                episode_length += 1
                total_skill_improvement += info.get('skill_improvement', 0)
                
                if info.get('budget_exceeded', False):
                    budget_exceeded = True
                
                if render and episode == 0:
                    self.env.render()
                
                state = next_state
                
                if terminated or truncated:
                    break
            
            eval_rewards.append(episode_reward)
            eval_lengths.append(episode_length)
            eval_costs.append(info.get('current_cost', 0))
            eval_skill_improvements.append(total_skill_improvement)
            eval_budget_utilizations.append(info.get('budget_utilization', 0))
            eval_successes.append(episode_reward > 0 and not budget_exceeded)
            eval_budget_exceeded.append(budget_exceeded)
        
        results = {
            'mean_reward': np.mean(eval_rewards),
            'std_reward': np.std(eval_rewards),
            'mean_length': np.mean(eval_lengths),
            'mean_cost': np.mean(eval_costs),
            'mean_skill_improvement': np.mean(eval_skill_improvements),
            'mean_budget_utilization': np.mean(eval_budget_utilizations),
            'success_rate': np.mean(eval_successes),
            'budget_exceeded_rate': np.mean(eval_budget_exceeded),
            'reward_improvement': np.mean(eval_rewards) - np.mean(self.episode_rewards[:100]) if len(self.episode_rewards) >= 100 else 0
        }
        
        print("Budget-Aware Evaluation Results:")
        print(f"  Mean Reward: {results['mean_reward']:.3f} ± {results['std_reward']:.3f}")
        print(f"  Mean Length: {results['mean_length']:.2f}")
        print(f"  Mean Cost: {results['mean_cost']:.2f}")
        print(f"  Mean Skill Improvement: {results['mean_skill_improvement']:.3f}")
        print(f"  Mean Budget Utilization: {results['mean_budget_utilization']:.2%}")
        print(f"  Success Rate: {results['success_rate']:.2%} 🎯")
        print(f"  Budget Exceeded Rate: {results['budget_exceeded_rate']:.2%} 💰")
        print(f"  Reward Improvement from Start: {results['reward_improvement']:+.3f}")
        
        return results
    
    def plot_enhanced_training_curves(self, save_path: str = None) -> None:
        """Enhanced training curves with budget management analysis."""
        fig, axes = plt.subplots(3, 3, figsize=(20, 15))
        
        # Episode rewards with milestone markers
        axes[0, 0].plot(self.episode_rewards, alpha=0.7, label='Raw')
        axes[0, 0].plot(self._smooth_curve(self.episode_rewards, 50), 'r-', linewidth=2, label='Smoothed')
        
        # Mark milestones
        if self.first_positive_reward_episode:
            axes[0, 0].axvline(self.first_positive_reward_episode, color='green', linestyle='--', 
                             label=f'First Positive ({self.first_positive_reward_episode})')
        if self.first_80_percent_success:
            axes[0, 0].axvline(self.first_80_percent_success, color='blue', linestyle='--',
                             label=f'80% Success ({self.first_80_percent_success})')
        if self.first_budget_discipline:
            axes[0, 0].axvline(self.first_budget_discipline, color='purple', linestyle='--',
                             label=f'Budget Discipline ({self.first_budget_discipline})')
        
        axes[0, 0].set_title('Episode Rewards with Budget Milestones')
        axes[0, 0].set_xlabel('Episode')
        axes[0, 0].set_ylabel('Reward')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # Success rate with budget analysis
        success_rate = self._smooth_curve([int(s) for s in self.success_episodes], 50)
        budget_exceeded_rate = self._smooth_curve([int(b) for b in self.budget_exceeded_episodes], 50)
        axes[0, 1].plot(success_rate, 'g-', linewidth=2, label='Success Rate')
        axes[0, 1].plot(budget_exceeded_rate, 'r-', linewidth=2, label='Budget Exceeded Rate')
        axes[0, 1].axhline(0.2, color='orange', linestyle=':', label='Budget Discipline Target (20%)')
        axes[0, 1].set_title('Success vs Budget Management')
        axes[0, 1].set_xlabel('Episode')
        axes[0, 1].set_ylabel('Rate')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # Learning rate with decay markers
        axes[0, 2].plot(self.learning_rates, 'b-', linewidth=2)
        # Mark LR decay steps
        for decay_step in self.agent.training_stats.get('lr_decay_steps', []):
            if decay_step < len(self.learning_rates):
                axes[0, 2].axvline(decay_step, color='red', linestyle=':', alpha=0.7)
        axes[0, 2].set_title('Learning Rate Schedule')
        axes[0, 2].set_xlabel('Episode')
        axes[0, 2].set_ylabel('Learning Rate')
        axes[0, 2].set_yscale('log')
        axes[0, 2].grid(True)
        
        # Entropy evolution
        axes[1, 0].plot(self.entropy_losses, alpha=0.7, label='Raw')
        axes[1, 0].plot(self._smooth_curve(self.entropy_losses, 50), 'r-', linewidth=2, label='Smoothed')
        axes[1, 0].axhline(np.log(self.env.K), color='black', linestyle='--', 
                          label=f'Max Entropy ({np.log(self.env.K):.3f})')
        axes[1, 0].set_title('Policy Entropy Evolution')
        axes[1, 0].set_xlabel('Episode')
        axes[1, 0].set_ylabel('Entropy')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        
        # Budget utilization with improved emphasis
        axes[1, 1].plot(self.budget_utilizations, alpha=0.7, label='Raw')
        axes[1, 1].plot(self._smooth_curve(self.budget_utilizations, 50), 'r-', linewidth=2, label='Smoothed')
        axes[1, 1].axhline(1.0, color='red', linestyle='--', linewidth=2, label='Budget Limit (100%)')
        axes[1, 1].axhline(0.8, color='orange', linestyle=':', label='Efficient Target (80%)')
        axes[1, 1].set_title('Budget Utilization (Key Metric)')
        axes[1, 1].set_xlabel('Episode')
        axes[1, 1].set_ylabel('Budget Used %')
        axes[1, 1].legend()
        axes[1, 1].grid(True)
        
        # Skill improvements
        axes[1, 2].plot(self.skill_improvements, alpha=0.7)
        axes[1, 2].plot(self._smooth_curve(self.skill_improvements, 50), 'r-', linewidth=2)
        axes[1, 2].set_title('Skill Improvements')
        axes[1, 2].set_xlabel('Episode')
        axes[1, 2].set_ylabel('Total Skill Δ')
        axes[1, 2].grid(True)
        
        # Policy and value losses
        axes[2, 0].plot(self.policy_losses, alpha=0.7)
        axes[2, 0].plot(self._smooth_curve(self.policy_losses, 50), 'r-', linewidth=2)
        axes[2, 0].set_title('Policy Losses')
        axes[2, 0].set_xlabel('Episode')
        axes[2, 0].set_ylabel('Loss')
        axes[2, 0].grid(True)
        
        if self.agent.use_baseline:
            axes[2, 1].plot(self.value_losses, alpha=0.7)
            axes[2, 1].plot(self._smooth_curve(self.value_losses, 50), 'r-', linewidth=2)
            axes[2, 1].set_title('Value Losses')
            axes[2, 1].set_xlabel('Episode')
            axes[2, 1].set_ylabel('Loss')
            axes[2, 1].grid(True)
        else:
            axes[2, 1].text(0.5, 0.5, 'No Baseline Used', ha='center', va='center', 
                           transform=axes[2, 1].transAxes, fontsize=14)
            axes[2, 1].set_title('Value Losses')
        
        # Budget management effectiveness
        budget_effectiveness = []
        for i in range(len(self.episode_rewards)):
            if i >= 99:  # Calculate over last 100 episodes
                recent_success = np.mean(self.success_episodes[i-99:i+1])
                recent_budget_ok = 1 - np.mean(self.budget_exceeded_episodes[i-99:i+1])
                effectiveness = (recent_success + recent_budget_ok) / 2  # Combined metric
                budget_effectiveness.append(effectiveness)
        
        if budget_effectiveness:
            axes[2, 2].plot(range(99, len(self.episode_rewards)), budget_effectiveness, 'g-', linewidth=2)
            axes[2, 2].axhline(0.8, color='blue', linestyle='--', label='Target (80%)')
            axes[2, 2].set_title('Budget Management Effectiveness')
            axes[2, 2].set_xlabel('Episode')
            axes[2, 2].set_ylabel('Effectiveness Score')
            axes[2, 2].legend()
            axes[2, 2].grid(True)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Budget-aware training curves saved to {save_path}")
        
        plt.show()
    
    def _smooth_curve(self, values: List[float], window: int = 50) -> List[float]:
        """Apply moving average smoothing to a curve."""
        if len(values) < window:
            return values
        
        smoothed = []
        for i in range(len(values)):
            start = max(0, i - window + 1)
            end = i + 1
            smoothed.append(np.mean(values[start:end]))
        
        return smoothed


# =============================================================================
# PERFORMANCE COMPARISON FRAMEWORK
# =============================================================================

class PerformanceComparison:
    """Framework for comparing budget-aware configurations."""
    
    def __init__(self, base_config: TrainingConfig):
        self.base_config = base_config
        self.results = {}
        
    def run_comparison(self, configurations: Dict[str, TrainingConfig], episodes: int = 1500) -> Dict:
        """Run training with different budget-aware configurations."""
        print("=== Budget-Aware Performance Comparison Framework ===")
        print(f"Testing {len(configurations)} configurations with {episodes} episodes each")
        print("Focus: Budget management and cost-effectiveness")
        print("=" * 70)
        
        results = {}
        
        for config_name, config in configurations.items():
            print(f"\nTesting Configuration: {config_name}")
            print(f"  Reward Strategy: {config.reward_strategy}")
            print(f"  Cost Penalty: {config.cost_penalty}")
            print(f"  Budget Penalties: base={config.base_budget_penalty}, max={config.max_budget_penalty}")
            print(f"  Entropy Coefficient: {config.entropy_coefficient}")
            print("-" * 50)
            
            # Update episodes for comparison
            config.num_episodes = episodes
            config.log_interval = max(episodes // 20, 25)
            
            # Create environment and agent
            env = EmployeeTrainingEnv(config)
            agent = EnhancedREINFORCEAgent(config)
            training_loop = EnhancedTrainingLoop(env, agent, config)
            
            # Train
            start_time = time.time()
            training_loop.train()
            training_time = time.time() - start_time
            
            # Evaluate
            eval_results = training_loop.evaluate(num_episodes=50, render=False)
            
            # Store results with budget-focused metrics
            results[config_name] = {
                'config': config,
                'training_time': training_time,
                'final_rewards': training_loop.episode_rewards[-100:],
                'final_success_rate': np.mean(training_loop.success_episodes[-100:]),
                'final_budget_exceeded_rate': np.mean(training_loop.budget_exceeded_episodes[-100:]),
                'final_entropy': np.mean(training_loop.entropy_losses[-100:]),
                'budget_discipline_episode': training_loop.first_budget_discipline,
                'eval_results': eval_results,
                'training_loop': training_loop
            }
            
            print(f"  Training Time: {training_time:.2f}s")
            print(f"  Final Success Rate: {eval_results['success_rate']:.2%}")
            print(f"  Final Budget Exceeded Rate: {eval_results['budget_exceeded_rate']:.2%}")
            print(f"  Final Mean Reward: {eval_results['mean_reward']:.3f}")
        
        self.results = results
        return results
    
    def print_summary(self) -> None:
        """Print budget-focused comparison summary."""
        if not self.results:
            print("No results to summarize. Run comparison first.")
            return
        
        print("\n" + "="*100)
        print("BUDGET-AWARE PERFORMANCE COMPARISON SUMMARY")
        print("="*100)
        
        # Sort by success rate (primary) and low budget exceeded rate (secondary)
        sorted_results = sorted(self.results.items(), 
                              key=lambda x: (x[1]['eval_results']['success_rate'], 
                                           -x[1]['eval_results']['budget_exceeded_rate']), 
                              reverse=True)
        
        # Header
        print(f"{'Rank':<4} {'Configuration':<20} {'Success Rate':<12} {'Budget Exceed':<12} "
              f"{'Mean Reward':<12} {'Budget Discipline':<16}")
        print("-" * 100)
        
        # Results
        for i, (config_name, result) in enumerate(sorted_results, 1):
            eval_results = result['eval_results']
            budget_discipline = result['budget_discipline_episode'] or "Not achieved"
            
            print(f"{i:<4} {config_name:<20} {eval_results['success_rate']:<12.2%} "
                  f"{eval_results['budget_exceeded_rate']:<12.2%} {eval_results['mean_reward']:<12.3f} "
                  f"{str(budget_discipline):<16}")
        
        print("\n" + "="*100)
        print("WINNER (Best Budget Management):", sorted_results[0][0])
        winner_config = sorted_results[0][1]['config']
        winner_result = sorted_results[0][1]
        
        print(f"  Reward Strategy: {winner_config.reward_strategy}")
        print(f"  Cost Penalty: {winner_config.cost_penalty}")
        print(f"  Budget Penalties: base={winner_config.base_budget_penalty}, max={winner_config.max_budget_penalty}")
        print(f"  Success Rate: {winner_result['eval_results']['success_rate']:.2%}")
        print(f"  Budget Exceeded Rate: {winner_result['eval_results']['budget_exceeded_rate']:.2%}")
        
        if winner_result['budget_discipline_episode']:
            print(f"  Budget Discipline Achieved: Episode {winner_result['budget_discipline_episode']}")


# =============================================================================
# MAIN EXECUTION AND UTILITIES
# =============================================================================

def create_directories():
    """Create necessary directories for saving models and plots."""
    os.makedirs('models', exist_ok=True)
    os.makedirs('plots', exist_ok=True)
    os.makedirs('logs', exist_ok=True)


def get_budget_aware_configurations() -> Dict[str, TrainingConfig]:
    """
    Get refined budget-aware configurations to find the optimal trade-off 
    between high reward and strict budget management.
    """
    
    # --- Base parameters that work well ---
    base_params = {
        'learning_rate': 3e-4,
        'entropy_coefficient': 0.001,
        'lr_step_size': 1500, # Give more time to learn before LR decay
        'lr_gamma': 0.9,
        'reward_strategy': 'hybrid' # This strategy has proven effective
    }

    # --- Create a modified config for the fast learner ---
    fast_learner_params = base_params.copy()
    fast_learner_params['learning_rate'] = 5e-4 # Override the learning rate here

    return {
        # Hypothesis: A balanced approach. A modest increase in cost penalty 
        # might be enough to improve budget discipline without hurting rewards too much.
        'balanced_approach': TrainingConfig(
            **base_params,
            cost_penalty=0.01, # 5x the original
            base_budget_penalty=4.0,
            max_budget_penalty=8.0
        ),

        # Hypothesis: Slightly more aggressive on cost. This is the configuration
        # recommended from the last analysis. It should further reduce budget overruns.
        'strong_incentive': TrainingConfig(
            **base_params,
            cost_penalty=0.015, # 7.5x the original
            base_budget_penalty=5.0,
            max_budget_penalty=10.0
        ),

        # Hypothesis: Very conservative. This agent should almost never exceed the budget.
        # The key question is: will it be too scared to take actions, resulting in low skill gain?
        'ultra_conservative': TrainingConfig(
            **base_params,
            cost_penalty=0.025, # Over 10x the original penalty
            base_budget_penalty=6.0,
            max_budget_penalty=12.0
        ),

        # Hypothesis: Focus on the terminal reward. Maybe a huge bonus for finishing
        # with budget to spare is a better signal than a punishing per-step cost.
        'terminal_focus_strong': TrainingConfig(
            reward_strategy='terminal',
            cost_penalty=0.008, # Lower step penalty, as the focus is on the end
            entropy_coefficient=0.001,
            learning_rate=3e-4,
            base_budget_penalty=5.0,
            max_budget_penalty=10.0,
            terminal_bonus_multiplier=3.5 # Greatly increased terminal bonus
        ),
        
        # FIXED: Unpack the modified dictionary to avoid the TypeError
        'fast_learner_frugal': TrainingConfig(
            **fast_learner_params, # Use the modified params
            cost_penalty=0.02,  # High cost penalty
            base_budget_penalty=5.0,
            max_budget_penalty=10.0
        )
    }


def main():
    """Budget-aware main entry point with enhanced cost management."""
    parser = argparse.ArgumentParser(description='Budget-Aware Employee Training Optimization')
    
    # --- Existing Arguments ---
    parser.add_argument('--mode', choices=['train', 'evaluate', 'visualize', 'compare'], 
                       default='train', help='Run mode')
    parser.add_argument('--model', type=str, help='Path to saved model for evaluation')
    parser.add_argument('--episodes', type=int, default=3000, help='Number of training episodes')
    parser.add_argument('--reward-strategy', choices=['basic', 'terminal', 'efficiency', 'hybrid'], 
                       default='hybrid', help='Reward strategy to use')
    parser.add_argument('--cost-penalty', type=float, default=0.01, help='Cost penalty coefficient')
    parser.add_argument('--entropy-coef', type=float, default=0.001, help='Entropy regularization coefficient')
    parser.add_argument('--learning-rate', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--lr-step-size', type=int, default=750, help='Learning rate decay step size')
    parser.add_argument('--lr-gamma', type=float, default=0.9, help='Learning rate decay factor')
    parser.add_argument('--no-baseline', action='store_true', help='Disable baseline (Actor-Critic)')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    
    # --- FIXED: Add the missing arguments ---
    parser.add_argument('--base-budget-penalty', type=float, default=4.0, help='Base penalty for exceeding budget')
    parser.add_argument('--max-budget-penalty', type=float, default=8.0, help='Maximum penalty for exceeding budget')
    parser.add_argument('--terminal-bonus-multiplier', type=float, default=1.5, help='Multiplier for terminal bonus in reward function')

    args = parser.parse_args()
    
    # Set random seeds for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    
    # Create directories
    create_directories()
    
    print("Budget-Aware Employee Training Optimization System")
    print("=" * 80)
    print(f"Mode: {args.mode}")
    print(f"Seed: {args.seed}")
    print(f"User: Soumedhik")
    print("=" * 80)
    
    if args.mode == 'compare':
        # Performance comparison mode for budget-aware configurations
        print("Running performance comparison with budget-aware configurations...")
        
        base_config = TrainingConfig()
        comparison = PerformanceComparison(base_config)
        configurations = get_budget_aware_configurations()
        
        # Run the comparison
        results = comparison.run_comparison(configurations, episodes=args.episodes)
        
        # Summarize and present the results
        comparison.print_summary()
        
    else:
        # Create a single configuration for train, evaluate, or visualize modes
        # FIXED: Pass all the newly added arguments to the config
        config = TrainingConfig(
            num_episodes=args.episodes,
            reward_strategy=args.reward_strategy,
            cost_penalty=args.cost_penalty,
            entropy_coefficient=args.entropy_coef,
            learning_rate=args.learning_rate,
            lr_step_size=args.lr_step_size,
            lr_gamma=args.lr_gamma,
            use_baseline=not args.no_baseline,
            base_budget_penalty=args.base_budget_penalty,
            max_budget_penalty=args.max_budget_penalty,
            terminal_bonus_multiplier=args.terminal_bonus_multiplier
        )
        
        # Instantiate environment and agent
        env = EmployeeTrainingEnv(config)
        agent = EnhancedREINFORCEAgent(config)
        training_loop = EnhancedTrainingLoop(env, agent, config)
        
        print(f"Environment: {env.D} skills, {env.K} training modules")
        print(f"Agent: {'Enhanced Actor-Critic' if config.use_baseline else 'Enhanced REINFORCE'}")
        print(f"Reward Strategy: {config.reward_strategy}")
        print(f"Cost Penalty: {config.cost_penalty}")
        print(f"Learning Rate: {config.learning_rate} (StepLR: {config.lr_step_size}/{config.lr_gamma})")
        
        if args.mode == 'train':
            # Training mode
            print("\nStarting budget-aware training...")
            training_loop.train()
            
            # Plot the training curves
            training_loop.plot_enhanced_training_curves(
                save_path=os.path.join(config.plot_save_path, 'budget_aware_training_curves.png')
            )
            
            # Evaluate the newly trained model
            print("\nEvaluating trained model...")
            eval_results = training_loop.evaluate(
                num_episodes=config.eval_episodes,
                render=config.eval_render
            )
            
        elif args.mode == 'evaluate':
            # Evaluation mode
            if args.model:
                agent.load_model(args.model)
                print(f"Loaded model from {args.model}")
            else:
                print("Warning: No model specified for evaluation. Using a randomly initialized policy.")
            
            eval_results = training_loop.evaluate(
                num_episodes=config.eval_episodes,
                render=True
            )
            
        elif args.mode == 'visualize':
            # Visualization mode to inspect a trained agent's behavior
            if args.model:
                agent.load_model(args.model)
                print(f"Loaded model from {args.model}")
            else:
                print("Warning: No model specified for visualization. Using a randomly initialized policy.")

            print("\nGenerating visualizations for the agent's policy...")

            # --- 1. Visualize Skill Progression on Sample Episodes ---
            num_viz_episodes = 5
            skill_histories = []
            plt.figure(figsize=(12, 7))

            for i in range(num_viz_episodes):
                state, _ = env.reset()
                episode_skills = [state.copy()]
                
                while True:
                    with torch.no_grad():
                        action_probs = agent.policy_net(torch.FloatTensor(state).unsqueeze(0))
                        action = torch.argmax(action_probs).item()
                    
                    state, _, terminated, truncated, _ = env.step(action)
                    episode_skills.append(state.copy())

                    if terminated or truncated:
                        break
                
                skill_histories.append(np.array(episode_skills))

            # Plot average skill level over time for each sample episode
            for i, history in enumerate(skill_histories):
                avg_skill_per_step = np.mean(history, axis=1)
                plt.plot(avg_skill_per_step, alpha=0.8, label=f'Sample Run {i+1}')
            
            plt.title('Agent Policy: Average Skill Progression')
            plt.xlabel('Training Step in Episode')
            plt.ylabel('Average Skill Level')
            plt.legend()
            plt.grid(True)
            viz_path = os.path.join(config.plot_save_path, 'visualized_skill_progression.png')
            plt.savefig(viz_path, bbox_inches='tight')
            print(f"-> Skill progression plot saved to {viz_path}")
            plt.show()

            # --- 2. Visualize the Synergy Matrix ---
            plt.figure(figsize=(10, 8))
            skill_names = [
                'Coding', 'Debugging', 'Testing', 'Arch.',
                'Comm.', 'Leader.', 'Teamwork', 'Prob. Solv.'
            ]
            sns.heatmap(env.synergy_matrix, xticklabels=skill_names, yticklabels=skill_names,
                       annot=True, fmt='.2f', cmap='viridis', cbar_kws={'label': 'Synergy Coefficient'})
            plt.title('Environment: Cross-Attribute Synergy Matrix')
            synergy_path = os.path.join(config.plot_save_path, 'visualized_synergy_matrix.png')
            plt.savefig(synergy_path, bbox_inches='tight')
            print(f"-> Synergy matrix plot saved to {synergy_path}")
            plt.show()

            print("\nVisualizations generated and saved to 'plots/' directory.")

    print("\nEnhanced system execution completed.")

if __name__ == "__main__":
    main()

Writing single_file_rl_training_final.py


In [2]:
!python single_file_rl_training_final.py --mode compare --episodes 100000

Budget-Aware Employee Training Optimization System
Mode: compare
Seed: 42
User: Soumedhik
Running performance comparison with budget-aware configurations...
=== Budget-Aware Performance Comparison Framework ===
Testing 5 configurations with 100000 episodes each
Focus: Budget management and cost-effectiveness

Testing Configuration: balanced_approach
  Reward Strategy: hybrid
  Cost Penalty: 0.01
  Budget Penalties: base=4.0, max=8.0
  Entropy Coefficient: 0.001
--------------------------------------------------
Initialized budget-aware agent with LR=0.0003, cost_penalty=0.01, step_size=1500
Starting Budget-Aware Training with Enhanced Cost Penalties...
Episodes: 100000
Environment: 8 skills, 4 training modules
Agent: Actor-Critic with entropy regularization
Reward Strategy: hybrid
Learning Rate: 0.0003 (StepLR: step=1500, gamma=0.9)
Cost Penalty: 0.01 (5x increased for budget awareness)
Budget Penalties: base=4.0, max=8.0
----------------------------------------

In [3]:
!python single_file_rl_training_final.py --mode train \
    --episodes 250000 \
    --reward-strategy terminal \
    --cost-penalty 0.008 \
    --entropy-coef 0.001 \
    --base-budget-penalty 5.0 \
    --max-budget-penalty 10.0 \
    --terminal-bonus-multiplier 3.5

Budget-Aware Employee Training Optimization System
Mode: train
Seed: 42
User: Soumedhik
Initialized budget-aware agent with LR=0.0003, cost_penalty=0.008, step_size=750
Environment: 8 skills, 4 training modules
Agent: Enhanced Actor-Critic
Reward Strategy: terminal
Cost Penalty: 0.008
Learning Rate: 0.0003 (StepLR: 750/0.9)

Starting budget-aware training...
Starting Budget-Aware Training with Enhanced Cost Penalties...
Episodes: 250000
Environment: 8 skills, 4 training modules
Agent: Actor-Critic with entropy regularization
Reward Strategy: terminal
Learning Rate: 0.0003 (StepLR: step=750, gamma=0.9)
Cost Penalty: 0.008 (5x increased for budget awareness)
Budget Penalties: base=5.0, max=10.0
------------------------------------------------------------------------------------------
🎉 First positive reward achieved at episode 9!
Episode   25 | Reward:  -4.87 | Avg:  -2.14 📉 | Success: 8.00% | Budget Exceeded: 92.00% 📈 | Entropy: 1.382 📈 | LR: 3.00e-04
Episode   50