In [27]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [28]:
class CompetitiveHiringAGIEnv(gym.Env):
    """Competitive AGI Race Environment with Two Parties, Team Hiring, and Collaboration"""

    def __init__(self, team_size=5, s0=4.0, alpha=1.5):
        super(CompetitiveHiringAGIEnv, self).__init__()

        # Team management parameters
        self.team_size = team_size
        self.max_skill = team_size
        self.s0 = s0  # AGI threshold skill
        self.alpha = alpha  # Sharpness of AGI probability curve
        
        # Defining collaboration modes
        self.COLLAB_MODES = {
            'INDEPENDENT': 0,   # Both parties working independently
            'COLLABORATIVE': 1, # Both parties actively collaborating
            'POST_COLLAB': 2,   # Both previously collaborated but now independent
        }

        # Defining party statuses
        self.PARTY_STATUS = {
            'EXPLORING': 0,     # Actively exploring
            'RETREATED': 1,     # Retreated
            'FOUND_AGI': 2,     # Found AGI
        }

        # Collaboration mode (applies to both parties)
        self.collab_mode = self.COLLAB_MODES['INDEPENDENT']

        # Individual party statuses
        self.party_1_status = self.PARTY_STATUS['EXPLORING']
        self.party_2_status = self.PARTY_STATUS['EXPLORING']

        # Track which party has broken collaboration
        self.party_1_broke_collab = False
        self.party_2_broke_collab = False

        # Track resources and teams
        self.party_1_resources = 100
        self.party_2_resources = 100
        
        # Initialize team members for each party
        self.party_1_team = []
        self.party_2_team = []
        self.party_1_candidate = None
        self.party_2_candidate = None
        
        # Team statistics
        self.party_1_skill = 0
        self.party_2_skill = 0
        self.party_1_salary = 0
        self.party_2_salary = 0

        # Actions for each agent:
        # 0 = Retreat (stop exploring)
        # 1 = Explore independently
        # 2 = Initiate/continue collaboration
        # 3 = Break collaboration
        # 4+ = Hire candidate (fire team member at index action-4)
        self.action_space = spaces.Tuple((
            spaces.Discrete(4 + team_size),  # Party 1 actions
            spaces.Discrete(4 + team_size)   # Party 2 actions
        ))

        # The observation space includes:
        # - Collaboration mode
        # - Party statuses
        # - Party resources
        # - Team skills and salaries
        # - Candidates
        self.observation_space = spaces.Dict({
            "collab_mode": spaces.Discrete(3),
            "party_1_status": spaces.Discrete(3),
            "party_2_status": spaces.Discrete(3),
            "party_1_resources": spaces.Box(low=0, high=float('inf'), shape=(1,), dtype=np.float32),
            "party_2_resources": spaces.Box(low=0, high=float('inf'), shape=(1,), dtype=np.float32),
            "party_1_team_skills": spaces.Box(low=0.0, high=1.0, shape=(team_size,), dtype=np.float32),
            "party_1_team_salaries": spaces.Box(low=0.0, high=1.0, shape=(team_size,), dtype=np.float32),
            "party_2_team_skills": spaces.Box(low=0.0, high=1.0, shape=(team_size,), dtype=np.float32),
            "party_2_team_salaries": spaces.Box(low=0.0, high=1.0, shape=(team_size,), dtype=np.float32),
            "party_1_candidate": spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32),
            "party_2_candidate": spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32),
        })

        # Probabilities and rewards
        self.independent_agi_prob_factor = 0.01  # Base prob, will be multiplied by team skill
        self.collaborative_agi_prob_factor = 0.03
        self.post_collab_agi_prob_factor = 0.015

        # Resource dynamics
        self.explore_independent_cost = 1
        self.explore_collaborative_cost = 3
        self.retreat_reward = 10
        self.collaboration_initiation_cost = 5
        self.agi_reward = 100
        self.salary_cost_factor = 0.1  # Salary cost per step

        # Competition penalty: if the opponent finds AGI first
        self.competition_penalty = -50

        # Max environment steps
        self.max_steps = 1000
        self.current_step = 0

        ##########

        # Defining the transition probabilities and rewards for each party
        # Structure: {(party_status, collab_mode): {action: [{next_status, prob, reward, effects}, ...]}}
        self.party_1_transitions = {
            (self.PARTY_STATUS['EXPLORING'], self.COLLAB_MODES['INDEPENDENT']): {
                1: [  # Explore independently
                    {"next_status": self.PARTY_STATUS['FOUND_AGI'], "prob": lambda: self._get_agi_probability(1), 
                        "reward": self.agi_reward, "resource_change": -self.explore_independent_cost,
                        "effects": [("party_2_reward", self.competition_penalty, 
                                    lambda: self.party_2_status == self.PARTY_STATUS['EXPLORING'])]},
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": lambda: 1 - self._get_agi_probability(1), 
                        "reward": -self.party_1_salary * self.salary_cost_factor, "resource_change": -self.explore_independent_cost, "effects": []}
                ],
                0: [  # Retreat
                    {"next_status": self.PARTY_STATUS['RETREATED'], "prob": 1.0, 
                        "reward": self.retreat_reward, "resource_change": self.retreat_reward, "effects": []}
                ],
                2: [  # Try to initiate collaboration
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": 1.0, 
                        "reward": -self.collaboration_initiation_cost, "resource_change": -self.collaboration_initiation_cost, "effects": []}
                ],
                # Hiring actions will be handled separately
            },
            (self.PARTY_STATUS['EXPLORING'], self.COLLAB_MODES['COLLABORATIVE']): {
                0: [  # Retreat
                    {"next_status": self.PARTY_STATUS['RETREATED'], "prob": 1.0, 
                        "reward": self.retreat_reward, "resource_change": self.retreat_reward, "effects": []}
                ],
                2: [  # Explore collaboratively
                    {"next_status": self.PARTY_STATUS['FOUND_AGI'], "prob": lambda: self._get_agi_probability(1), 
                        "reward": self.agi_reward, "resource_change": -self.explore_collaborative_cost,
                        "effects": [("party_2_status", self.PARTY_STATUS['FOUND_AGI'], 
                                    lambda: self.party_2_status == self.PARTY_STATUS['EXPLORING']),
                                ("party_2_reward", self.agi_reward, 
                                    lambda: self.party_2_status == self.PARTY_STATUS['EXPLORING'])]},
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": lambda: 1 - self._get_agi_probability(1), 
                        "reward": -(self.party_1_salary + self.party_2_salary) * self.salary_cost_factor , "resource_change": -self.explore_collaborative_cost, "effects": []}
                ],
                3: [  # Break collaboration
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": 1.0, 
                        "reward": 0, "resource_change": 0, "effects": [("party_1_broke_collab", True, lambda: True)]}
                ]
            },
            (self.PARTY_STATUS['EXPLORING'], self.COLLAB_MODES['POST_COLLAB']): {
                1: [  # Explore after collaboration
                    {"next_status": self.PARTY_STATUS['FOUND_AGI'], "prob": lambda: self._get_agi_probability(1), 
                        "reward": self.agi_reward, "resource_change": -self.explore_independent_cost,
                        "effects": [("party_2_reward", self.competition_penalty, 
                                    lambda: self.party_2_status == self.PARTY_STATUS['EXPLORING'])]},
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": lambda: 1 - self._get_agi_probability(1), 
                        "reward": -(self.party_1_salary * self.salary_cost_factor), "resource_change": -self.explore_independent_cost, "effects": []}
                ],
                0: [  # Retreat
                    {"next_status": self.PARTY_STATUS['RETREATED'], "prob": 1.0, 
                        "reward": self.retreat_reward, "resource_change": self.retreat_reward, "effects": []}
                ],
            }
        }

        # Create similar transitions for party 2
        # Create similar transitions for party 2 (properly swapping party references)
        self.party_2_transitions = {
            (self.PARTY_STATUS['EXPLORING'], self.COLLAB_MODES['INDEPENDENT']): {
                1: [  # Explore independently
                    {"next_status": self.PARTY_STATUS['FOUND_AGI'], "prob": lambda: self._get_agi_probability(2),  # Changed 1→2
                        "reward": self.agi_reward, "resource_change": -self.explore_independent_cost,
                        "effects": [("party_1_reward", self.competition_penalty,  # Changed party_2→party_1
                                    lambda: self.party_1_status == self.PARTY_STATUS['EXPLORING'])]},  # Changed party_2→party_1
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": lambda: 1 - self._get_agi_probability(2),  # Changed 1→2
                        "reward": -self.party_2_salary * self.salary_cost_factor, "resource_change": -self.explore_independent_cost, "effects": []}
                ],
                0: [  # Retreat
                    {"next_status": self.PARTY_STATUS['RETREATED'], "prob": 1.0,
                        "reward": self.retreat_reward, "resource_change": self.retreat_reward, "effects": []}
                ],
                2: [  # Try to initiate collaboration
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": 1.0,
                        "reward": -(self.collaboration_initiation_cost), "resource_change": -self.collaboration_initiation_cost, "effects": []}
                ],
                # Hiring actions will be handled separately
            },
            (self.PARTY_STATUS['EXPLORING'], self.COLLAB_MODES['COLLABORATIVE']): {
                0: [  # Retreat
                    {"next_status": self.PARTY_STATUS['RETREATED'], "prob": 1.0,
                        "reward": self.retreat_reward, "resource_change": self.retreat_reward, "effects": []}
                ],
                2: [  # Explore collaboratively
                    {"next_status": self.PARTY_STATUS['FOUND_AGI'], "prob": lambda: self._get_agi_probability(2),  # Changed 1→2
                        "reward": self.agi_reward, "resource_change": -self.explore_collaborative_cost,
                        "effects": [("party_1_status", self.PARTY_STATUS['FOUND_AGI'],  # Changed party_2→party_1
                                    lambda: self.party_1_status == self.PARTY_STATUS['EXPLORING']),  # Changed party_2→party_1
                                ("party_1_reward", self.agi_reward,  # Changed party_2→party_1
                                    lambda: self.party_1_status == self.PARTY_STATUS['EXPLORING'])]},  # Changed party_2→party_1
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": lambda: 1 - self._get_agi_probability(2),  # Changed 1→2
                        "reward": -(self.party_1_salary + self.party_2_salary) * self.salary_cost_factor, "resource_change": -self.explore_collaborative_cost, "effects": []}
                ],
                3: [  # Break collaboration
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": 1.0,
                        "reward": 0, "resource_change": 0, "effects": [("party_2_broke_collab", True, lambda: True)]}  # Changed 1->2
                ]
            },
            (self.PARTY_STATUS['EXPLORING'], self.COLLAB_MODES['POST_COLLAB']): {
                1: [  # Explore after collaboration
                    {"next_status": self.PARTY_STATUS['FOUND_AGI'], "prob": lambda: self._get_agi_probability(2),  # Changed 1→2
                        "reward": self.agi_reward, "resource_change": -self.explore_independent_cost,
                        "effects": [("party_1_reward", self.competition_penalty,  # Changed party_2→party_1
                                    lambda: self.party_1_status == self.PARTY_STATUS['EXPLORING'])]},  # Changed party_2→party_1
                    {"next_status": self.PARTY_STATUS['EXPLORING'], "prob": lambda: 1 - self._get_agi_probability(2),  # Changed 1→2
                        "reward": -self.party_2_salary * self.salary_cost_factor, "resource_change": -self.explore_independent_cost, "effects": []}
                ],
                0: [  # Retreat
                    {"next_status": self.PARTY_STATUS['RETREATED'], "prob": 1.0,
                        "reward": self.retreat_reward, "resource_change": self.retreat_reward, "effects": []}
                ],
            }
        }

        #######

        self.reset()

    def _generate_worker(self, team_skill=0, skill_bias=0.5, skill_uncertainty=0.5, skill_mean=0.5, salary_noise=0.2):
        """Generate a worker with randomized skill and salary."""
        base_skill = np.clip(
            np.random.normal(
                loc=skill_mean + skill_bias * (team_skill / self.max_skill),
                scale=skill_uncertainty
            ), 0, 1
        )

        # Add noise to salary, making it only loosely correlated with skill
        salary = np.clip(
            base_skill + np.random.normal(loc=0.0, scale=salary_noise),
            0.0,
            1.0
        )

        return {"skill": base_skill, "salary": salary}

    def _update_team_stats(self):
        """Update team statistics."""
        self.party_1_skill = sum(w['skill'] for w in self.party_1_team)
        self.party_1_salary = sum(w['salary'] for w in self.party_1_team)
        
        self.party_2_skill = sum(w['skill'] for w in self.party_2_team)
        self.party_2_salary = sum(w['salary'] for w in self.party_2_team)

    # def _get_agi_probability(self, party_id):
    #     """Calculate probability of finding AGI based on team skill and collaboration mode."""
    #     team_skill = self.party_1_skill if party_id == 1 else self.party_2_skill
    #     skill_factor = 1 / (1 + np.exp(-self.alpha * (team_skill - self.s0)))
        
    #     if self.collab_mode == self.COLLAB_MODES['INDEPENDENT']:
    #         return skill_factor * self.independent_agi_prob_factor
    #     elif self.collab_mode == self.COLLAB_MODES['COLLABORATIVE']:
    #         return skill_factor * self.collaborative_agi_prob_factor
    #     else:  # POST_COLLAB
    #         return skill_factor * self.post_collab_agi_prob_factor
    
    def _get_agi_probability(self, party_id):
        """Calculate probability of finding AGI based on team skill and collaboration mode."""
        team_skill = self.party_1_skill if party_id == 1 else self.party_2_skill
        skill_factor = 1 / (1 + np.exp(-self.alpha * (team_skill - self.s0)))
        
        if self.collab_mode == self.COLLAB_MODES['INDEPENDENT']:
            return skill_factor * self.independent_agi_prob_factor
        elif self.collab_mode == self.COLLAB_MODES['COLLABORATIVE']:
            return skill_factor * self.collaborative_agi_prob_factor
        else:  # POST_COLLAB
            return skill_factor * self.post_collab_agi_prob_factor

    def _process_collaboration_actions(self, action_1, action_2):
        """Process collaboration dynamics based on both parties' actions."""
        if self.collab_mode == self.COLLAB_MODES['INDEPENDENT']:
            # Both parties agree to collaborate
            if action_1 == 2 and action_2 == 2:
                self.collab_mode = self.COLLAB_MODES['COLLABORATIVE']
                self.party_1_resources -= self.collaboration_initiation_cost
                self.party_2_resources -= self.collaboration_initiation_cost
                return True
                
        elif self.collab_mode == self.COLLAB_MODES['COLLABORATIVE']:
            # Breaking collaboration
            if action_1 == 3 or action_2 == 3:
                self.collab_mode = self.COLLAB_MODES['POST_COLLAB']
                if action_1 == 3:
                    self.party_1_broke_collab = True
                if action_2 == 3:
                    self.party_2_broke_collab = True
                return True
            
            # Continuing collaboration
            elif action_1 == 2 and action_2 == 2:
                # Stay in collaboration mode
                return True
            # else:                                                                                 #COMMENTED THIS CODE BECAUSE OTHER ACTIONS SHOULDN'T BE ALLOWED IN THIS STAGE
            #     # Implicit break by not continuing
            #     self.collab_mode = self.COLLAB_MODES['POST_COLLAB']
            #     return True
                
        return False

    def _decode_action(self, action):
        """Decode an action into explore/retreat/collaborate/hire operations."""
        if action < 4:
            return {"action_type": action, "fire_index": None}
        else:
            return {"action_type": 4, "fire_index": action - 4}

    def _process_party_actions(self, party_id, action_dict, reward):
        """Process a party's action and update state/reward using transition dictionaries."""
        status = self.party_1_status if party_id == 1 else self.party_2_status
        
        # Only process if party is still exploring
        if status != self.PARTY_STATUS['EXPLORING']:
            return reward
        
        action_type = action_dict["action_type"]
        
        # Hiring action - handle separately since it's not in the transitions
        if action_type == 4:
            fire_index = action_dict["fire_index"]
            if party_id == 1 and 0 <= fire_index < self.team_size:
                old_skill = self.party_1_team[fire_index]["skill"]
                # Replace team member and update
                self.party_1_team[fire_index] = self.party_1_candidate
                new_skill = self.party_1_candidate["skill"]
                self.party_1_candidate = self._generate_worker(self.party_1_skill)
                self._update_team_stats()
                return reward + (new_skill - old_skill)
                
            elif party_id == 2 and 0 <= fire_index < self.team_size:
                old_skill = self.party_2_team[fire_index]["skill"]
                # Replace team member and update
                self.party_2_team[fire_index] = self.party_2_candidate
                new_skill = self.party_2_candidate["skill"]
                self.party_2_candidate = self._generate_worker(self.party_2_skill)
                self._update_team_stats()
                return reward + (new_skill - old_skill)
            


        
        # For exploration/collaboration/retreat, use transition dictionaries
        transitions = None
        if party_id == 1:
            state_key = (self.party_1_status, self.collab_mode)
            if state_key in self.party_1_transitions and action_type in self.party_1_transitions[state_key]:
                transitions = self.party_1_transitions[state_key][action_type]
        else:
            state_key = (self.party_2_status, self.collab_mode)
            if state_key in self.party_2_transitions and action_type in self.party_2_transitions[state_key]:
                transitions = self.party_2_transitions[state_key][action_type]
        
        if transitions != None:
            # Select transition based on probabilities
            probs = [t["prob"]() if callable(t["prob"]) else t["prob"] for t in transitions]        #MADE IT INTO A CALLABLE FUNCTION BECAUSE THERE IS LAMBDA FUNCTION THAT USES GET_AGI_PROBABILITY THIS TIME
            transition_idx = np.random.choice(len(transitions), p=probs)
            transition = transitions[transition_idx]
            
            # Apply transition
            if party_id == 1:
                self.party_1_status = transition["next_status"]
                self.party_1_resources += transition["resource_change"]
            else:
                self.party_2_status = transition["next_status"]
                self.party_2_resources += transition["resource_change"]
            
            reward += transition["reward"]
            
            # Apply side effects
            for effect in transition["effects"]:
                target, value, condition = effect
                if condition():
                    if target == "party_1_reward" and party_id == 2:
                        reward += value
                    elif target == "party_2_reward" and party_id == 1:
                        # Apply to other party in the next step
                        if self.party_2_status == self.PARTY_STATUS['EXPLORING']:
                            self.party_2_resources += value
                    elif target == "party_1_status" and party_id == 2:
                        self.party_1_status = value
                    elif target == "party_2_status" and party_id == 1:
                        self.party_2_status = value
                    elif target == "party_1_broke_collab":
                        self.party_1_broke_collab = value
                    elif target == "party_2_broke_collab":
                        self.party_2_broke_collab = value
        
        return reward

    def _get_observation(self):
        """Return the current observation (state)."""
        return {
            "collab_mode": self.collab_mode,
            "party_1_status": self.party_1_status,
            "party_2_status": self.party_2_status,
            "party_1_resources": np.array([self.party_1_resources], dtype=np.float32),
            "party_2_resources": np.array([self.party_2_resources], dtype=np.float32),
            "party_1_team_skills": np.array([w['skill'] for w in self.party_1_team], dtype=np.float32),
            "party_1_team_salaries": np.array([w['salary'] for w in self.party_1_team], dtype=np.float32),
            "party_2_team_skills": np.array([w['skill'] for w in self.party_2_team], dtype=np.float32),
            "party_2_team_salaries": np.array([w['salary'] for w in self.party_2_team], dtype=np.float32),
            "party_1_candidate": np.array([self.party_1_candidate['skill'], self.party_1_candidate['salary']], dtype=np.float32),
            "party_2_candidate": np.array([self.party_2_candidate['skill'], self.party_2_candidate['salary']], dtype=np.float32),
        }

    def _is_terminal(self):
        """Check if the episode has reached a terminal state."""
        return (self.party_1_status != self.PARTY_STATUS['EXPLORING'] or
                self.party_2_status != self.PARTY_STATUS['EXPLORING'] or
                self.current_step >= self.max_steps or
                self.party_1_resources <= 0 or
                self.party_2_resources <= 0)

    def _get_terminal_info(self):
        """Gather information about how the episode ended."""
        info = {}

        if self.party_1_resources <= 0:
            info["bankrupt"] = "party_1"
        if self.party_2_resources <= 0:
            info["bankrupt"] = "party_2" if "bankrupt" not in info else "both"

        if self.current_step >= self.max_steps:
            info["timeout"] = True

        if self.party_1_status == self.PARTY_STATUS['FOUND_AGI'] and self.party_2_status == self.PARTY_STATUS['FOUND_AGI']:
            info["winner"] = "both"
        elif self.party_1_status == self.PARTY_STATUS['FOUND_AGI']:
            info["winner"] = "party_1"
        elif self.party_2_status == self.PARTY_STATUS['FOUND_AGI']:
            info["winner"] = "party_2"

        if "winner" not in info and "bankrupt" not in info and not info.get("timeout", False):
            # Someone retreated
            if self.party_1_status == self.PARTY_STATUS['RETREATED'] and self.party_2_status == self.PARTY_STATUS['RETREATED']:
                info["both_retreated"] = True
            elif self.party_1_status == self.PARTY_STATUS['RETREATED']:
                info["retreated"] = "party_1"
            elif self.party_2_status == self.PARTY_STATUS['RETREATED']:
                info["retreated"] = "party_2"

        # Add team statistics
        info["party_1_skill"] = self.party_1_skill
        info["party_2_skill"] = self.party_2_skill
        info["party_1_salary"] = self.party_1_salary
        info["party_2_salary"] = self.party_2_salary

        return info

    def step(self, action):
        """Take a step in the environment with actions from both parties."""
        action_1, action_2 = action
        
        # Validate actions
        assert 0 <= action_1 < 4 + self.team_size, f"Invalid action for party 1: {action_1}"
        assert 0 <= action_2 < 4 + self.team_size, f"Invalid action for party 2: {action_2}"

        # If already in terminal state, return without changes
        if self._is_terminal():
            return self._get_observation(), (0, 0), True, False, self._get_terminal_info()

        # Process collaboration transitions first
        self._process_collaboration_actions(action_1, action_2)

        # Decode actions
        action_dict_1 = self._decode_action(action_1)
        action_dict_2 = self._decode_action(action_2)

        #REMOVED initialised reward that were doing self.party_1_salary * self.salary_cost_factor as it is already handled in the transition matrix
        reward_1 = 0
        reward_2 = 0

        # Process party actions and update rewards
        reward_1 = self._process_party_actions(1, action_dict_1, reward_1)
        reward_2 = self._process_party_actions(2, action_dict_2, reward_2)

        # Apply competition penalty if one finds AGI and the other doesn't
        if (self.party_1_status == self.PARTY_STATUS['FOUND_AGI'] and 
            self.party_2_status == self.PARTY_STATUS['EXPLORING']):
            reward_2 += self.competition_penalty
            
        if (self.party_2_status == self.PARTY_STATUS['FOUND_AGI'] and 
            self.party_1_status == self.PARTY_STATUS['EXPLORING']):
            reward_1 += self.competition_penalty

        # Increment step counter
        self.current_step += 1

        # Determine if episode has ended
        done = self._is_terminal()

        # Gather additional info for terminal states
        info = {}
        if done:
            info.update(self._get_terminal_info())

        return self._get_observation(), (reward_1, reward_2), done, False, info

    def reset(self, seed=None, options=None):
        """Reset the environment to initial state."""
        if seed is not None:
            np.random.seed(seed)

        self.collab_mode = self.COLLAB_MODES['INDEPENDENT']
        self.party_1_status = self.PARTY_STATUS['EXPLORING']
        self.party_2_status = self.PARTY_STATUS['EXPLORING']
        self.party_1_broke_collab = False
        self.party_2_broke_collab = False
        self.party_1_resources = 100
        self.party_2_resources = 100
        self.current_step = 0
        
        # Initialize teams
        self.party_1_team = [self._generate_worker() for _ in range(self.team_size)]
        self.party_2_team = [self._generate_worker() for _ in range(self.team_size)]
        
        # Generate initial candidates
        self.party_1_candidate = self._generate_worker()
        self.party_2_candidate = self._generate_worker()
        
        # Update team statistics
        self._update_team_stats()

        return self._get_observation(), {}

    def render(self):
        """Print the current state of the environment."""
        print(f"Step: {self.current_step} | Collaboration Mode: {self.collab_mode}")
        print(f"Party 1 Status: {self.party_1_status} | Resources: {self.party_1_resources:.2f}")
        print(f"Party 2 Status: {self.party_2_status} | Resources: {self.party_2_resources:.2f}")
        print(f"Party 1 Team Skill: {self.party_1_skill:.2f} | Salary: {self.party_1_salary:.2f}")
        print(f"Party 2 Team Skill: {self.party_2_skill:.2f} | Salary: {self.party_2_salary:.2f}")
        print("Party 1 Team:")
        for i, worker in enumerate(self.party_1_team):
            print(f"  Worker {i}: Skill={worker['skill']:.2f}, Salary={worker['salary']:.2f}")
        print("Party 2 Team:")
        for i, worker in enumerate(self.party_2_team):
            print(f"  Worker {i}: Skill={worker['skill']:.2f}, Salary={worker['salary']:.2f}")
        print(f"Party 1 Candidate: Skill={self.party_1_candidate['skill']:.2f}, Salary={self.party_1_candidate['salary']:.2f}")
        print(f"Party 2 Candidate: Skill={self.party_2_candidate['skill']:.2f}, Salary={self.party_2_candidate['salary']:.2f}")

    def close(self):
        """Clean up resources."""
        pass

### MultiAgentDictObsPreprocessor() code

In [29]:
# Observation preprocessor for the integrated environment
class MultiAgentDictObsPreprocessor(nn.Module):
    def __init__(self, team_size):
        super().__init__()
        self.team_size = team_size

    def forward(self, obs_dict):
        # Handle discrete observations with one-hot encoding
        collab_mode = F.one_hot(torch.tensor(obs_dict["collab_mode"], dtype=torch.long), num_classes=3).float()
        party_1_status = F.one_hot(torch.tensor(obs_dict["party_1_status"], dtype=torch.long), num_classes=3).float()
        party_2_status = F.one_hot(torch.tensor(obs_dict["party_2_status"], dtype=torch.long), num_classes=3).float()
        
        # Convert continuous observations to tensors
        def ensure_2d(x, dtype=torch.float32):
            x = torch.tensor(x, dtype=dtype)
            if x.dim() == 1:
                x = x.unsqueeze(0)  # convert [D] → [1, D]
            return x

        party_1_resources = ensure_2d(obs_dict["party_1_resources"])
        party_2_resources = ensure_2d(obs_dict["party_2_resources"])
        party_1_team_skills = ensure_2d(obs_dict["party_1_team_skills"])
        party_1_team_salaries = ensure_2d(obs_dict["party_1_team_salaries"])
        party_2_team_skills = ensure_2d(obs_dict["party_2_team_skills"])
        party_2_team_salaries = ensure_2d(obs_dict["party_2_team_salaries"])
        party_1_candidate = ensure_2d(obs_dict["party_1_candidate"])
        party_2_candidate = ensure_2d(obs_dict["party_2_candidate"])
        
        # Concatenate all features
        return torch.cat([
            collab_mode,
            party_1_status,
            party_2_status,
            party_1_resources,
            party_2_resources,
            party_1_team_skills,
            party_1_team_salaries,
            party_2_team_skills,
            party_2_team_salaries,
            party_1_candidate,
            party_2_candidate
        ], dim=-1)

In [30]:
def test_integrated_environment():
    """Test the integrated environment functionality."""
    env = CompetitiveHiringAGIEnv(team_size=3)
    obs, _ = env.reset()
    
    print("Initial State:")
    env.render()
    print("\nAction Space:", env.action_space)
    print("Observation Space:", env.observation_space)
    
    # Test a few steps with random actions
    for step in range(5):
        action = (
            np.random.randint(0, 7),  # Random action for party 1 (0-6 for team_size=3)
            np.random.randint(0, 7)   # Random action for party 2
        )
        next_obs, rewards, done, _, info = env.step(action)
        
        print(f"\nStep {step+1}:")
        print(f"Actions: Party 1 = {action[0]}, Party 2 = {action[1]}")
        print(f"Rewards: Party 1 = {rewards[0]:.2f}, Party 2 = {rewards[1]:.2f}")
        print(f"Done: {done}")
        env.render()
        
        if done:
            print("\nEpisode ended early.")
            print("Terminal Info:", info)
            break

    # Test collaboration scenario
    print("\n\nTesting collaboration scenario:")
    obs, _ = env.reset()
    
    # Both parties initiate collaboration
    action = (2, 2)
    next_obs, rewards, done, _, info = env.step(action)
    print("After collaboration initiation:")
    print(f"Collaboration Mode: {next_obs['collab_mode']}")
    print(f"Rewards: Party 1 = {rewards[0]:.2f}, Party 2 = {rewards[1]:.2f}")
    env.render()
    
    # Test hiring action
    print("\n\nTesting hiring action:")
    obs, _ = env.reset()
    
    # Party 1 hires a new worker (replacing worker 0)
    action = (4, 1)  # Party 1: Hire (fire index 0), Party 2: Explore
    next_obs, rewards, done, _, info = env.step(action)
    print("After hiring action:")
    print(f"Rewards: Party 1 = {rewards[0]:.2f}, Party 2 = {rewards[1]:.2f}")
    env.render()

if __name__ == "__main__":
    test_integrated_environment()

Initial State:
Step: 0 | Collaboration Mode: 0
Party 1 Status: 0 | Resources: 100.00
Party 2 Status: 0 | Resources: 100.00
Party 1 Team Skill: 1.57 | Salary: 1.28
Party 2 Team Skill: 0.53 | Salary: 0.86
Party 1 Team:
  Worker 0: Skill=0.91, Salary=0.56
  Worker 1: Skill=0.66, Salary=0.72
  Worker 2: Skill=0.00, Salary=0.00
Party 2 Team:
  Worker 0: Skill=0.00, Salary=0.05
  Worker 1: Skill=0.00, Salary=0.15
  Worker 2: Skill=0.53, Salary=0.66
Party 1 Candidate: Skill=0.38, Salary=0.65
Party 2 Candidate: Skill=0.48, Salary=0.38

Action Space: Tuple(Discrete(7), Discrete(7))
Observation Space: Dict('collab_mode': Discrete(3), 'party_1_candidate': Box(0.0, 1.0, (2,), float32), 'party_1_resources': Box(0.0, inf, (1,), float32), 'party_1_status': Discrete(3), 'party_1_team_salaries': Box(0.0, 1.0, (3,), float32), 'party_1_team_skills': Box(0.0, 1.0, (3,), float32), 'party_2_candidate': Box(0.0, 1.0, (2,), float32), 'party_2_resources': Box(0.0, inf, (1,), float32), 'party_2_status': Discret