<a href="https://colab.research.google.com/github/KyPython/NeurlPS/blob/main/NeurlPS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install Dependencies

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()
!conda install -n base -c conda-forge mamba
!mamba create -n "lux-s3" "python==3.11"
!git clone https://github.com/Lux-AI-Challenge/Lux-Design-S3/
!pip install -e Lux-Design-S3/src

Test Run Agent

In [None]:
import os
os.environ['MPLBACKEND'] = 'Agg'  # Set the backend explicitly
import matplotlib.pyplot as plt

import matplotlib
matplotlib.use('Agg')  # Force the backend to Agg
import matplotlib.pyplot as plt
  # Ensure this is set for non-GUI environments
print(matplotlib.get_backend())  # Confirm that the backend is correctly set

# Run a match with the correct paths
!ls /content/Lux-Design-S3/kits/python/

!luxai-s3 /content/Lux-Design-S3/kits/python/main.py /content/Lux-Design-S3/kits/python/main.py --output replay.json

- Upload the generated `replay.json` to the Lux AI Season 3 visualizer linked below to watch the match:

https://s3vis.lux-ai.org/

In [None]:
!eval "$(mamba shell hook --shell)"
!conda install -n base -c conda-forge mamba
!mamba activate lux-s3

Gymnax Setup

In [None]:
!pip install git+https://github.com/RobertTLange/gymnax.git@main
import gymnax
import jax

# Check if gymnax is installed and can be imported without errors
print(f"Gymnax installed: {gymnax is not None}")

# Try to access a gymnax environment to further verify installation
try:
  env, env_params = gymnax.make("Catch-bsuite")
  print(f"Environment 'Catch-bsuite' created successfully")
except Exception as e:
  print(f"Error creating environment: {e}")

Development Environment

In [None]:
!git clone https://github.com/Lux-AI-Challenge/Lux-Design-S2.git
!ls /content/Lux-Design-S3
!git fetch --all
!git pull origin main
%cd /content/Lux-Design-S3/

Training Environment

In [None]:
# Change to the project root directory
%cd /content/Lux-Design-S3/src

# Install in editable mode
!pip install -e .
# Change directory to your training area
%cd /content/Lux-Design-S3/src/luxai_s3/
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))  # This adds src to sys.path

# Restart the kernel to ensure changes take effect. (You might need to run this manually)
try:
    import luxai_s3
except ModuleNotFoundError:
    !pip install -e Lux-Design-S3/src
    import luxai_s3  # Import after installation
# Import necessary modules
from env import LuxAIS3Env  # Import after installation
from params import EnvParams
import importlib
import sys
import os

# Ensure agent.py is in the correct path
# Updated agent_path to the correct location
agent_path = os.path.join(os.getcwd(), "..", "..", "kits", "python", "agent.py")

# If agent.py is not in the current directory, adjust this:
# agent_path = "/path/to/your/agent.py"

# Add the directory containing agent.py to sys.path
agent_dir = os.path.dirname(agent_path)
sys.path.append(agent_dir)

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..", "..")))  # This adds Lux-Design-S3 to sys.path

# Construct the absolute path to your module
module_path = os.path.abspath(os.path.join(os.getcwd(), "content/Lux-Design-S3/kits/python"))

# Add the path to sys.path if it's not already there
if module_path not in sys.path:
    sys.path.append(module_path)

# Import the module
agent_module = importlib.import_module("agent")  # Assuming 'agent.py' is the module

# Access the Agent class from the imported module
Agent = agent_module.Agent
# Create the LuxAI_S3 environment:
env = LuxAIS3Env(fixed_env_params=EnvParams())

# Initialize your agents:
agents = {
    "player_0": Agent(player="player_0", env_cfg=env.fixed_env_params),
    "player_1": Agent(player="player_1", env_cfg=env.fixed_env_params),
}

Policy-Gradient On-Policy Learning (PPO)

In [None]:
import sys
import os
import gymnasium as gym
import numpy as np
import jax
import jax.random
from stable_baselines3 import PPO
import torch as th
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv
import torch.nn as nn

# Change to the project root directory
%cd /content/Lux-Design-S3/src

# Install in editable mode
!pip install -e .

# Change directory to your training area
%cd /content/Lux-Design-S3/src/luxai_s3/
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))  # This adds src to sys.path

# Restart the kernel to ensure changes take effect. (You might need to run this manually)
try:
    import luxai_s3
except ModuleNotFoundError:
    !pip install -e /content/Lux-Design-S3/src
    import luxai_s3  # Import after installation

# Import necessary modules
from luxai_s3.spaces import MultiDiscrete  # Import MultiDiscrete
from luxai_s3.env import LuxAIS3Env
from luxai_s3.params import EnvParams
from luxai_s3 import utils

# Create the LuxAI_S3 environment:
env = LuxAIS3Env(fixed_env_params=EnvParams())

def flatten_observation(obs):
    """Return a one-dimensional copy of ``obs``.

    Args:
        obs: Observation to flatten; must be a ``numpy.ndarray``.

    Returns:
        The result of ``obs.flatten()``.

    Raises:
        ValueError: If ``obs`` is not a ``numpy.ndarray``.
    """
    # Guard clause: reject every non-array input up front.
    if not isinstance(obs, np.ndarray):
        raise ValueError("Expected a numpy.ndarray for observation")
    return obs.flatten()

# Inspect the raw action space once (method call on the environment).
original_action_space = env.action_space()
print("original_action_space:", original_action_space)
print("Type of original_action_space:", type(original_action_space))
# NOTE: the original cell also printed the type and shape of `obs` here, but no
# `obs` is defined before this point on a fresh kernel (NameError).
# LuxAIWrapper.__init__ prints the raw observation shape instead.
#
# The original cell repeated the wrapper, make_env, CustomCNN and PPO definitions
# three times; each later copy silently shadowed the previous one. They are
# defined exactly once below.


class LuxAIWrapper(gym.Env):
    """Gymnasium-style adapter that exposes the 'player_0' view of a Lux AI S3 env.

    The wrapped environment's action space is passed through unchanged. The
    observation space is a float32 Box whose shape is taken from one sample
    observation obtained at construction time.
    """

    def __init__(self, env):
        """Wrap ``env`` and derive the action and observation spaces from it."""
        super().__init__()
        self.env = env

        # The action space is passed through as-is; the branch only controls
        # what gets printed for inspection.
        original_action_space = env.action_space()
        if isinstance(original_action_space, gym.spaces.Dict):
            for key, space in original_action_space.spaces.items():
                print(f"Action space for {key}: {space}")
        else:
            print("original_action_space is not a Dict, it's of type:", type(original_action_space))
        self.action_space = original_action_space

        # One reset with a fixed key provides a sample observation to size the Box.
        sample_obs, _ = env.reset(jax.random.PRNGKey(0))
        reshaped_obs = self._player_observation(sample_obs, verbose=True)
        if reshaped_obs is None:
            raise ValueError("Observation reshaping failed. Please check your reshaping logic.")

        # NOTE(review): the leading axis is dropped here as a batch dimension, yet
        # reset()/step() return the reshaped array including that axis — confirm
        # against reshape_observation's output shape.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=reshaped_obs.shape[1:],  # Exclude batch dimension
            dtype=np.float32
        )

    def _player_observation(self, obs, verbose=False):
        """Extract, convert, flatten and reshape the 'player_0' observation.

        Shared by __init__, reset and step, which previously each carried their
        own copy of this pipeline.

        Raises:
            ValueError: If ``obs`` has no 'player_0' entry.
        """
        player_obs = obs.get('player_0', None)  # Adjust key for player_1 if necessary
        if player_obs is None:
            raise ValueError("Failed to extract player observation from the dictionary.")

        player_obs_np = utils.to_numpy(player_obs)
        if verbose:
            print(f"Raw observation shape for player_0: {player_obs_np.shape}")

        # NOTE(review): `reshape_observation` is not defined anywhere in the
        # visible part of this notebook; it must exist before the wrapper is built.
        return reshape_observation(flatten_observation(player_obs_np))

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return ``(observation, info)``.

        ``options`` is accepted for compatibility with the Gymnasium
        ``Env.reset(seed=..., options=...)`` signature and is otherwise unused.
        A missing seed falls back to PRNG key 0.
        """
        super().reset(seed=seed)
        key = jax.random.PRNGKey(seed if seed is not None else 0)
        obs, info = self.env.reset(key)
        return self._player_observation(obs), info

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, truncated, info)``."""
        lux_action = self.convert_action(action)
        # NOTE(review): assumes the wrapped env's step takes only the action and
        # returns five values — confirm against the LuxAIS3Env step signature.
        obs, reward, done, truncated, info = self.env.step(lux_action)
        return self._player_observation(obs), reward, done, truncated, info

    def convert_action(self, action):
        """
        Converts a Dict action to your LuxAIS3Env's action space.
        """
        # Convert the action for both players (or any other components). Adjust as needed.
        return {"player_0": action["player_0"].tolist(), "player_1": action["player_1"].tolist()}


def make_env():
    """Build a fresh wrapped environment; used as the DummyVecEnv factory."""
    env = LuxAIS3Env(fixed_env_params=EnvParams())
    return LuxAIWrapper(env)


# Use the wrapper to create vectorized environments for training
vec_env = DummyVecEnv([make_env])


class CustomCNN(BaseFeaturesExtractor):
    """Two 3x3 convolutions followed by a linear projection to ``features_dim``."""

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]

        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Run one sampled observation through the conv stack to learn the
        # flattened size required by the linear layer.
        with th.no_grad():
            sample_obs = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample_obs).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.ReLU()
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to ``features_dim`` features."""
        return self.linear(self.cnn(observations.float()))


# Define PPO model with CustomCNN for feature extraction.
# Gradient clipping is configured through PPO's own `max_grad_norm` argument: the
# original code replaced the optimizer and set an attribute named
# `clip_grad_norm_` on it, which nothing ever reads. `learning_rate` carries the
# 3e-4 that the replacement optimizer was given. `net_arch` uses the plain-dict
# form; the list-of-dict form is deprecated in Stable-Baselines3.
model = PPO(
    "CnnPolicy",
    vec_env,
    verbose=1,
    learning_rate=3e-4,
    max_grad_norm=1.0,
    policy_kwargs=dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=128),
        net_arch=dict(pi=[64, 64], vf=[64, 64]),
    ),
)

# Train the model
model.learn(total_timesteps=10000)  # Adjust total timesteps as needed.

# Save the trained model
model.save("ppo_luxai_model")

Running The Model In A Game

In [None]:
# Roll the trained policy out on a fresh wrapped environment.
# The original imported `LuxAI_S3` and then instantiated an undefined `LuxAI_S2`;
# the Gymnasium-style environment in this notebook is the one built by make_env().
env = make_env()
obs, info = env.reset()

for step in range(100):
    # predict() returns (action, recurrent_state); only the action is needed.
    action, _ = model.predict(obs)
    # The wrapper's step() returns five values: obs, reward, done, truncated, info.
    obs, reward, terminated, truncated, info = env.step(action)
    # NOTE(review): assumes both flags are plain booleans — confirm what the
    # wrapped environment returns.
    if terminated or truncated:
        obs, info = env.reset()

Submit Trained Model

In [None]:
kaggle competitions submit -c lux-ai-season-3 -f submission.py -m "PPO model for Lux AI"
tar -czvf submission.tar.gz *

Define Shared Experience Buffer

In [None]:
import numpy as np
from collections import deque

class SharedExperienceBuffer:
    """Fixed-capacity FIFO store of experiences shared by several agents.

    Once ``buffer_size`` items are stored, each new item evicts the oldest one
    (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, buffer_size, batch_size):
        """Create an empty buffer.

        Args:
            buffer_size: Maximum number of experiences kept.
            batch_size: Number of experiences returned by ``sample`` when at
                least that many are stored.
        """
        # Buffer stores (state, action, reward, next_state)
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.buffer = deque(maxlen=buffer_size)

    def store(self, experience):
        """Store experience in the buffer."""
        self.buffer.append(experience)

    def sample(self):
        """Sample a batch of distinct experiences from the buffer.

        Returns:
            A list of ``batch_size`` experiences drawn without replacement, or
            every stored experience (in random order) while the buffer holds
            fewer than ``batch_size`` items. An empty buffer yields ``[]``.
        """
        available = len(self.buffer)
        if available == 0:
            return []
        # Sampling without replacement cannot draw more items than exist; the
        # original raised ValueError until the buffer held a full batch.
        draw = min(self.batch_size, available)
        indices = np.random.choice(available, draw, replace=False)
        return [self.buffer[i] for i in indices]

    def size(self):
        """Return the current size of the buffer."""
        return len(self.buffer)

Integrate Shared Experience Buffer With Multiple Agents

In [None]:
class Agent:
    """Epsilon-greedy agent that learns from a buffer shared with other agents.

    Defining this class rebinds the name ``Agent``, which an earlier cell bound
    to the starter-kit Agent class.
    """

    def __init__(self, player: str, env_cfg, shared_buffer):
        """Store the player id, the configuration and the shared buffer.

        Args:
            player: Player identifier, e.g. "player_0".
            env_cfg: Mapping that provides "action_space_size".
            shared_buffer: Object with ``sample()`` and ``size()`` methods.
        """
        self.player = player
        self.env_cfg = env_cfg
        self.shared_buffer = shared_buffer
        self.epsilon = 0.1  # Exploration parameter for epsilon-greedy

    def act(self, obs):
        """Return a random action with probability ``epsilon``, else the policy's."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(0, self.env_cfg["action_space_size"])  # Random action
        return self.policy(obs)  # Determine action using learned policy

    def policy(self, obs):
        """Placeholder policy: a uniformly random action.

        ``act`` always called ``self.policy`` but the method did not exist, so the
        greedy branch raised AttributeError. Like ``update_policy``, this is a
        stub to be replaced by the learned policy.
        """
        return np.random.randint(0, self.env_cfg["action_space_size"])

    def train(self):
        """Run one update per experience in a batch drawn from the shared buffer.

        Does nothing until the buffer holds at least one full batch.
        """
        # Without this guard the buffer's sampler can be asked for more items
        # than it holds during the first episodes.
        required = getattr(self.shared_buffer, "batch_size", 1)
        if self.shared_buffer.size() < required:
            return

        for state, action, reward, next_state in self.shared_buffer.sample():
            # Apply update rule based on chosen RL algorithm (e.g., Q-learning, PPO, etc.)
            self.update_policy(state, action, reward, next_state)

    def update_policy(self, state, action, reward, next_state):
        """Hook for the RL update rule (e.g. Q-update); intentionally a no-op."""
        pass

Multi-Agent Training Process

In [None]:
# Create a shared experience buffer
shared_buffer = SharedExperienceBuffer(buffer_size=100000, batch_size=64)

# Define a placeholder for env_cfg (you need to replace this with your actual configuration)
# Example configuration:
env_cfg = {
    "action_space_size": 5  # Replace with the size of your action space
}

# Initialize agents
agents = [Agent(player="player_0", env_cfg=env_cfg, shared_buffer=shared_buffer),
          Agent(player="player_1", env_cfg=env_cfg, shared_buffer=shared_buffer)]

# Number of training episodes; the original loop read this name without defining it.
num_episodes = 100  # Adjust as needed

# Use the Gymnasium-style wrapper so reset()/step() have the shapes unpacked below.
# NOTE(review): the wrapper's convert_action expects a dict with "player_0" and
# "player_1" entries, whereas Agent.act returns a single integer — the action
# format still has to be reconciled.
env = make_env()

# Main training loop
for episode in range(num_episodes):
    for agent in agents:
        # Agents interact with the environment and store their experiences in the shared buffer
        state, _ = env.reset()  # reset() returns (observation, info)
        done = False
        while not done:
            action = agent.act(state)
            # step() returns five values; the episode ends on either flag.
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store the experience in the shared buffer
            shared_buffer.store((state, action, reward, next_state))
            state = next_state

        # Each agent trains on experiences from the shared buffer
        agent.train()

Define The Manager (Strategic Policy)

In [None]:
class Manager:
    """High-level (strategic) policy that picks a goal for the worker."""

    def __init__(self, env_cfg, resource_threshold=50):
        """Store the configuration and the resource threshold.

        Args:
            env_cfg: Environment configuration; stored but not read in this class.
            resource_threshold: Resource level below which gathering is chosen.
                Defaults to 50, the value previously hard-coded.
        """
        self.env_cfg = env_cfg
        self.resource_threshold = resource_threshold

    def act(self, global_state):
        """Return the strategy chosen for ``global_state``."""
        return self.select_strategy(global_state)

    def select_strategy(self, global_state):
        """Return "gather_resources" below the threshold, else "expand_city".

        Args:
            global_state: Mapping that provides a 'resources' value.
        """
        if global_state['resources'] < self.resource_threshold:
            return "gather_resources"
        return "expand_city"

    def update(self, state, reward, next_state):
        """Learning hook called after every environment step; a no-op for now.

        HRLAgent.train calls this method, which previously did not exist.
        """
        pass

Define The Worker (Tactical Policy)

In [None]:
class Worker:
    """Low-level (tactical) policy that carries out the manager's strategy."""

    def __init__(self, player, env_cfg, manager):
        """Store the player id, the configuration and the manager to consult."""
        self.player = player
        self.env_cfg = env_cfg
        self.manager = manager
        self.epsilon = 0.1  # Exploration parameter

    def act(self, current_state):
        """Ask the manager for a strategy and run the matching tactic.

        Any strategy other than "gather_resources" or "expand_city" falls back
        to ``defend``.
        """
        strategy = self.manager.act(current_state)
        tactics = {
            "gather_resources": self.gather_resources,
            "expand_city": self.expand_city,
        }
        return tactics.get(strategy, self.defend)(current_state)

    def gather_resources(self, state):
        """Tactic for gathering resources; not implemented yet."""
        pass

    def expand_city(self, state):
        """Tactic for expanding a city; not implemented yet."""
        pass

    def defend(self, state):
        """Tactic for defense; not implemented yet."""
        pass

    def update(self, state, reward, next_state):
        """Learning hook called after every environment step; a no-op for now.

        HRLAgent.train calls this method, which previously did not exist.
        """
        pass

Training with Hierarchical Reinforcement Learning (HRL)

In [None]:
class HRLAgent:
    """Hierarchical agent: a Manager chooses strategies, a Worker executes them."""

    # Used when neither the argument nor a module-level `num_episodes` is set.
    DEFAULT_EPISODES = 100

    def __init__(self, player, env_cfg):
        """Create the manager and a worker that consults it."""
        self.manager = Manager(env_cfg)
        self.worker = Worker(player, env_cfg, self.manager)

    def train(self, environment, num_episodes=None):
        """Run training episodes against ``environment``.

        Args:
            environment: Object with ``reset()`` and ``step(action)``. Both the
                four-value step result ``(obs, reward, done, info)`` and the
                five-value one ``(obs, reward, terminated, truncated, info)``
                are accepted, as is a reset that returns ``(obs, info)``.
            num_episodes: Number of episodes. When omitted, a module-level
                ``num_episodes`` is used if present (the original behaviour),
                otherwise ``DEFAULT_EPISODES``.
        """
        if num_episodes is None:
            num_episodes = globals().get("num_episodes", self.DEFAULT_EPISODES)

        for episode in range(num_episodes):
            reset_out = environment.reset()
            # A two-element tuple ending in a dict is treated as (obs, info).
            if isinstance(reset_out, tuple) and len(reset_out) == 2 and isinstance(reset_out[1], dict):
                state = reset_out[0]
            else:
                state = reset_out

            done = False
            while not done:
                # The worker makes tactical decisions based on the manager's strategy
                action = self.worker.act(state)

                # Simulate the action and get the reward
                step_out = environment.step(action)
                if len(step_out) == 5:
                    next_state, reward, terminated, truncated, _ = step_out
                    done = bool(terminated or truncated)
                else:
                    next_state, reward, done, _ = step_out

                # Update both levels when they provide an `update` hook.
                for learner in (self.manager, self.worker):
                    update = getattr(learner, "update", None)
                    if update is not None:
                        update(state, reward, next_state)

                # Advance the state: without this the worker acted on the initial
                # observation for the whole episode.
                state = next_state

Refine RL Policy Actions with Monte Carlo Tree Search (MCTS)

In [None]:
import random

class MCTSNode:
    """One node of the search tree: a state plus visit and value statistics."""

    def __init__(self, state, parent=None):
        self.state = state  # Game state at this node
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0

    def select_best_child(self):
        """Return the child with the highest average value (pure exploitation)."""
        return max(self.children, key=lambda child: child.value / (child.visits + 1e-6))  # Avoid division by zero

    def select_uct_child(self, exploration=1.4):
        """Return the child to explore next.

        Unvisited children are returned first, in insertion order. Otherwise the
        child maximising ``mean value + exploration * sqrt(ln(parent visits) /
        child visits)`` is returned.
        """
        import math

        for child in self.children:
            if child.visits == 0:
                return child

        log_parent = math.log(max(self.visits, 1))

        def score(child):
            return child.value / child.visits + exploration * math.sqrt(log_parent / child.visits)

        return max(self.children, key=score)

    def expand(self, action_space):
        """Append one child per action in ``action_space``."""
        for action in action_space:
            new_state = self.simulate_action(self.state, action)
            self.children.append(MCTSNode(new_state, parent=self))

    def simulate_action(self, state, action):
        """ Simulates a move and returns a new state """
        new_state = state.copy()
        new_state["action_taken"] = action
        return new_state

    def backpropagate(self, reward):
        """Add one visit and ``reward`` to this node and to every ancestor."""
        node = self
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent


def mcts_search(root_state, action_space, iterations=100, reward_fn=None, exploration=1.4):
    """Run Monte Carlo Tree Search and return the best root action found.

    The original selection step was pure exploitation: every unvisited child
    scored 0, so only the first action's subtree was ever explored. Selection
    now visits unvisited children first and otherwise balances value against
    exploration.

    Args:
        root_state: Dict-like state exposing ``copy()``.
        action_space: Iterable of candidate actions; must not be empty.
        iterations: Number of select/expand/evaluate/backpropagate rounds.
        reward_fn: Optional callable ``state -> float`` used to evaluate a node.
            Defaults to a uniform random reward in [0, 1], as before.
        exploration: Weight of the exploration term during selection.

    Returns:
        The "action_taken" of the root child with the highest average value.

    Raises:
        ValueError: If ``action_space`` is empty.
    """
    actions = list(action_space)
    if not actions:
        raise ValueError("action_space must contain at least one action")
    if reward_fn is None:
        def reward_fn(_state):
            return random.uniform(0, 1)  # Replace with real reward function

    root = MCTSNode(root_state)
    root.expand(actions)

    for _ in range(iterations):
        node = root

        # Selection: descend until an unvisited node or a leaf is reached.
        while node.children:
            node = node.select_uct_child(exploration)
            if node.visits == 0:
                break

        # Expansion: a leaf that was already evaluated gets children, and the
        # first of them is evaluated this round.
        if node.visits > 0 and not node.children:
            node.expand(actions)
            node = node.children[0]

        # Simulation and backpropagation
        node.backpropagate(reward_fn(node.state))

    return root.select_best_child().state["action_taken"]  # Best move found

CNN Policy Network

In [None]:
import torch  # forward() uses torch.relu / torch.softmax; this cell did not import torch
import torch.nn as nn

class CNNPolicy(nn.Module):
    """Small convolutional policy over a square, multi-channel map.

    Two 3x3 convolutions keep the spatial size (padding=1); the result is
    flattened and mapped through two linear layers to action probabilities.
    """

    def __init__(self, map_size, action_dim, in_channels=4):
        """Build the layers.

        Args:
            map_size: Side length of the square input map.
            action_dim: Number of actions in the output distribution.
            in_channels: Number of input channels. Defaults to 4, the value
                previously hard-coded.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * map_size * map_size, 128)
        self.fc2 = nn.Linear(128, action_dim)

    def forward(self, x):
        """Return action probabilities for ``x`` of shape (batch, in_channels, map_size, map_size)."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.flatten(start_dim=1)  # Flatten everything except the batch axis
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=-1)  # Output action probabilities

Graph Convolutional Networks

In [None]:
import torch.nn.functional as F
from dgl.nn import GraphConv

class GNNPolicy(nn.Module):
    """Two-layer graph-convolutional policy producing per-node action probabilities."""

    def __init__(self, input_dim, hidden_dim, action_dim):
        """Create the two GraphConv layers: input -> hidden -> action."""
        super().__init__()
        self.conv1 = GraphConv(input_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, action_dim)

    def forward(self, g, features):
        """Apply both graph convolutions to ``features`` on graph ``g``.

        Returns:
            Softmax over the last dimension of the second layer's output.
        """
        hidden = F.relu(self.conv1(g, features))
        logits = self.conv2(g, hidden)
        return F.softmax(logits, dim=-1)  # Action probabilities

CNN & GNN Representation

In [None]:
import numpy as np
import torch
import dgl
from dgl.nn import GraphConv
import torch.nn as nn
import torch.nn.functional as F

def encode_state_cnn(obs, team_id, env_cfg):
    """Convert game state into a multi-channel image representation for CNN input.

    Channels of the returned ``(4, map_size, map_size)`` float32 tensor:
        0: unit presence (+1 own unit, -1 opponent unit; opponents are
           written last and win when both share a tile)
        1: own-unit energy divided by 100
        2: relic node locations
        3: walls, only when ``obs`` has a ``"walls"`` key

    Entries with a negative coordinate are skipped. Without this, numpy's
    negative indexing would wrap them around and mark a tile on the far
    edge of the map. (Presumably -1 is the padding value for unseen
    units/relics; confirm against the environment's observation spec.)

    Args:
        obs: Observation dict with ``obs["units"]["position"]``,
            ``obs["units"]["energy"]`` (each indexed by team) and
            ``obs["relic_nodes"]``.
        team_id: Index of the controlled team (0 or 1).
        env_cfg: Config dict; ``env_cfg["map_width"]`` is used for both
            grid dimensions (square map assumed).

    Returns:
        ``torch.float32`` tensor of shape ``(4, map_size, map_size)``.
    """
    map_size = env_cfg["map_width"]  # Assuming square grid
    state_tensor = np.zeros((4, map_size, map_size), dtype=np.float32)  # 4 channels

    def _on_map(positions):
        # Yield (row_index, x, y) for every entry that is a real coordinate.
        for idx, pos in enumerate(positions):
            x, y = pos
            if x < 0 or y < 0:
                continue
            yield idx, x, y

    # Channel 1: Unit positions
    unit_positions = np.array(obs["units"]["position"][team_id])
    for _, x, y in _on_map(unit_positions):
        state_tensor[0, x, y] = 1  # Player's unit

    opponent_team_id = 1 - team_id
    opponent_positions = np.array(obs["units"]["position"][opponent_team_id])
    for _, x, y in _on_map(opponent_positions):
        state_tensor[0, x, y] = -1  # Opponent's unit

    # Channel 2: Energy levels (normalized). Flatten so that both (N,) and
    # (N, 1) energy layouts yield one scalar per unit.
    unit_energys = np.array(obs["units"]["energy"][team_id]).reshape(-1)
    for i, x, y in _on_map(unit_positions):
        state_tensor[1, x, y] = unit_energys[i] / 100.0  # Normalize (assuming max energy = 100)

    # Channel 3: Relic locations
    relic_positions = np.array(obs["relic_nodes"])
    for _, x, y in _on_map(relic_positions):
        state_tensor[2, x, y] = 1  # Mark relic nodes

    # Channel 4: Obstacles (if any)
    if "walls" in obs:
        wall_positions = np.array(obs["walls"])
        for _, x, y in _on_map(wall_positions):
            state_tensor[3, x, y] = 1  # Mark walls

    return torch.tensor(state_tensor, dtype=torch.float32)

def encode_state_gnn(obs, team_id):
    """Convert game state into a graph representation for GNN input.

    One node is created per entry of ``obs["units"]["position"][team_id]``.
    Two nodes are linked in both directions when the Euclidean distance
    between their positions is at most 3.

    Node data stored on the graph:
        ``"pos"``: float32 tensor built from the unit positions
        ``"energy"``: float32 tensor built from the unit energies

    NOTE(review): units with no neighbour in range get no incoming edge.
    DGL conv layers appear to reject zero-in-degree nodes by default, so
    confirm whether self-loops should be added before feeding this graph
    to ``GraphConv``/``GATConv``.

    Returns:
        The populated DGL graph.
    """
    units = obs["units"]
    positions = np.array(units["position"][team_id])
    energies = np.array(units["energy"][team_id])
    unit_count = len(positions)

    g = dgl.DGLGraph()
    g.add_nodes(unit_count)

    # Set node features (positions + energy)
    g.ndata["pos"] = torch.tensor(positions, dtype=torch.float32)
    g.ndata["energy"] = torch.tensor(energies, dtype=torch.float32)

    # Add edges based on proximity (within 3 tiles); each unordered pair is
    # visited once and inserted in both directions.
    for src in range(unit_count):
        for dst in range(src + 1, unit_count):
            separation = np.linalg.norm(positions[src] - positions[dst])
            if separation <= 3:
                g.add_edges(src, dst)
                g.add_edges(dst, src)  # Bi-directional edges

    return g

Hybrid CNN & GNN Model

In [None]:
class HybridPolicy(nn.Module):
    """Policy that fuses a CNN map encoding with a pooled GNN team encoding.

    Args:
        map_size: Side length of the square 4-channel input grid.
        action_dim: Number of discrete actions in the output distribution.
        hidden_dim: Width of the CNN, GNN and fusion embeddings.
    """

    def __init__(self, map_size, action_dim, hidden_dim=128):
        super(HybridPolicy, self).__init__()

        # CNN for spatial processing
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, padding=1)  # 4 input channels (state)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * map_size * map_size, hidden_dim)

        # GNN for relational reasoning
        self.gnn1 = GraphConv(3, hidden_dim)  # Input: (x, y, energy)
        self.gnn2 = GraphConv(hidden_dim, hidden_dim)

        # Fusion layer
        self.fusion_fc = nn.Linear(2 * hidden_dim, hidden_dim)

        # Output layer
        self.action_fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, cnn_state, g, g_features):
        """Return action probabilities of shape ``(batch, action_dim)``.

        Args:
            cnn_state: Batched grid tensor ``(batch, 4, map_size, map_size)``.
            g: Graph passed to both graph-convolution layers.
            g_features: Per-node features ``(num_nodes, 3)``.
        """
        # CNN processing
        x_cnn = F.relu(self.conv1(cnn_state))
        x_cnn = F.relu(self.conv2(x_cnn))
        x_cnn = x_cnn.view(x_cnn.shape[0], -1)  # Flatten
        x_cnn = F.relu(self.fc1(x_cnn))

        # GNN processing
        x_gnn = F.relu(self.gnn1(g, g_features))
        x_gnn = self.gnn2(g, x_gnn)

        # Aggregate node information into one graph-level vector. keepdim
        # keeps it 2-D so it can be concatenated with the batched CNN
        # embedding; a plain mean(dim=0) is 1-D and torch.cat rejects the
        # rank mismatch. An empty graph pools to zeros instead of NaN.
        if x_gnn.shape[0] == 0:
            x_graph = x_gnn.new_zeros(1, x_gnn.shape[-1])
        else:
            x_graph = x_gnn.mean(dim=0, keepdim=True)
        x_graph = x_graph.expand(x_cnn.shape[0], -1)  # same graph for every batch row

        # Fusion
        x_fused = torch.cat([x_cnn, x_graph], dim=-1)
        x_fused = F.relu(self.fusion_fc(x_fused))

        # Output action probabilities
        return F.softmax(self.action_fc(x_fused), dim=-1)

Action Selection

In [None]:
import torch

def select_action(model, obs, team_id, env_cfg):
    """Sample one action index from ``model`` for the current observation.

    The grid encoding gets a leading batch dimension of 1. Node features
    are the unit positions with the energy appended as a third column, so
    they match the ``(x, y, energy)`` input that ``HybridPolicy`` declares
    for its first graph layer.

    Args:
        model: Callable taking ``(cnn_state, g, g_features)`` and returning
            a ``(1, action_dim)`` tensor of probabilities.
        obs: Raw observation dict for the encoders.
        team_id: Index of the controlled team.
        env_cfg: Environment config forwarded to the CNN encoder.

    Returns:
        The sampled action as a Python int.
    """
    cnn_state = encode_state_cnn(obs, team_id, env_cfg).unsqueeze(0)  # Add batch dim
    g = encode_state_gnn(obs, team_id)

    # Build (x, y, energy) node features; reshape accepts energy stored as
    # either (N,) or (N, 1).
    g_features = torch.cat(
        [g.ndata["pos"], g.ndata["energy"].reshape(-1, 1)], dim=-1
    )

    with torch.no_grad():
        action_probs = model(cnn_state, g, g_features)

    # Choose action stochastically
    action = torch.multinomial(action_probs, 1).item()

    return action

Encode Observations

In [None]:
import numpy as np
import torch

def encode_unit_view(obs, unit_pos, env_cfg):
    """Encode a unit's local 5x5 view as a CNN tensor centred on the unit.

    The window covers ``unit_pos +/- 2`` on both axes. Tiles that fall
    outside the map stay zero, and the visible part is pasted at the
    matching offset so the unit is always at index ``(2, 2)``. Without the
    offset, a view clipped at the low map edge would be shifted to the
    window's top-left corner.

    Args:
        obs: Observation dict; ``obs["global_map"]`` is sliced as
            ``[:, x, y]`` with 4 leading channels.
        unit_pos: ``(x, y)`` map coordinates of the unit.
        env_cfg: Config dict providing ``"map_width"`` and ``"map_height"``.

    Returns:
        ``torch.float32`` tensor of shape ``(4, 5, 5)``.
    """
    view_size = 5  # 5x5 grid around the unit
    radius = view_size // 2
    state_tensor = np.zeros((4, view_size, view_size), dtype=np.float32)  # 4 channels

    # Map coordinates of the window's top-left corner (may be negative).
    x_origin = int(unit_pos[0]) - radius
    y_origin = int(unit_pos[1]) - radius

    # Clip the window to the map.
    x_min, x_max = max(0, x_origin), min(env_cfg["map_width"], x_origin + view_size)
    y_min, y_max = max(0, y_origin), min(env_cfg["map_height"], y_origin + view_size)

    # Paste the visible region at its offset inside the window; skip
    # entirely when the window does not overlap the map at all.
    if x_min < x_max and y_min < y_max:
        state_tensor[
            :,
            x_min - x_origin:x_max - x_origin,
            y_min - y_origin:y_max - y_origin,
        ] = obs["global_map"][:, x_min:x_max, y_min:y_max]

    return torch.tensor(state_tensor, dtype=torch.float32)

Encode Team Awareness (GNN)

In [None]:
import dgl
import torch
import torch.nn as nn
from dgl.nn import GATConv

def encode_team_graph(obs, team_id):
    """Build a graph where units are connected if they are within 3 tiles.

    One node is created per entry of ``obs["units"]["position"][team_id]``.
    Pairs whose Euclidean distance is at most 3 are linked in both
    directions. Node features are ``(x, y, energy)`` rows.

    Energies are reshaped to a column before stacking, so both ``(N,)``
    and ``(N, 1)`` layouts work (``np.hstack`` rejects mixing a 2-D
    position array with a 1-D energy array). An empty team yields a
    ``(0, 3)`` feature tensor.

    Returns:
        ``(g, node_features)``: the DGL graph, and the float32 feature
        tensor that is also stored as ``g.ndata["features"]``.
    """
    g = dgl.DGLGraph()

    # Get unit positions
    unit_positions = np.array(obs["units"]["position"][team_id])
    if unit_positions.size == 0:
        unit_positions = unit_positions.reshape(0, 2)  # keep 2 columns with no units
    num_units = len(unit_positions)
    g.add_nodes(num_units)

    # Add edges (if within 3 tiles)
    for i in range(num_units):
        for j in range(i + 1, num_units):
            if np.linalg.norm(unit_positions[i] - unit_positions[j]) <= 3:
                g.add_edges(i, j)
                g.add_edges(j, i)  # Bi-directional

    # Node features: (x, y, energy)
    unit_energys = np.array(obs["units"]["energy"][team_id]).reshape(-1, 1)
    node_features = torch.tensor(np.hstack([unit_positions, unit_energys]), dtype=torch.float32)

    g.ndata["features"] = node_features
    return g, node_features

Multiagent PPO Policy for Reinforcement Learning (RL)

In [None]:
class MultiAgentPolicy(nn.Module):
    """Per-unit policy: CNN over the unit's 5x5 view plus a pooled GAT team embedding.

    Args:
        action_dim: Number of discrete actions in the output distribution.
        hidden_dim: Width of the CNN, GAT and fusion embeddings.
    """

    def __init__(self, action_dim, hidden_dim=128):
        super(MultiAgentPolicy, self).__init__()

        # CNN for local vision
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 5 * 5, hidden_dim)

        # GAT for communication
        self.gat1 = GATConv(3, hidden_dim, num_heads=4)  # Input: (x, y, energy)
        self.gat2 = GATConv(hidden_dim * 4, hidden_dim, num_heads=1)

        # Fusion layer
        self.fusion_fc = nn.Linear(2 * hidden_dim, hidden_dim)

        # Output action layer
        self.action_fc = nn.Linear(hidden_dim, action_dim)

    def forward(self, unit_cnn_state, g, g_features):
        """Return action probabilities of shape ``(batch, action_dim)``.

        Args:
            unit_cnn_state: Batched local views ``(batch, 4, 5, 5)``.
            g: Team graph passed to both attention layers.
            g_features: Per-node features ``(num_nodes, 3)``.
        """
        # CNN processing
        x_cnn = F.relu(self.conv1(unit_cnn_state))
        x_cnn = F.relu(self.conv2(x_cnn))
        x_cnn = x_cnn.view(x_cnn.shape[0], -1)  # Flatten
        x_cnn = F.relu(self.fc1(x_cnn))

        # GNN processing. Per the DGL docs, GATConv returns one slice per
        # attention head: (num_nodes, num_heads, out_feats). The heads are
        # concatenated so the input width matches gat2's hidden_dim * 4.
        x_gnn = F.relu(self.gat1(g, g_features))
        x_gnn = x_gnn.flatten(start_dim=1)
        x_gnn = self.gat2(g, x_gnn)
        x_gnn = x_gnn.flatten(start_dim=1)  # single head -> (num_nodes, hidden_dim)

        # Aggregate graph info into one 2-D team vector and repeat it for
        # every batch row; an empty graph pools to zeros instead of NaN.
        if x_gnn.shape[0] == 0:
            x_team = x_gnn.new_zeros(1, x_gnn.shape[-1])
        else:
            x_team = x_gnn.mean(dim=0, keepdim=True)
        x_team = x_team.expand(x_cnn.shape[0], -1)

        # Fusion
        x_fused = torch.cat([x_cnn, x_team], dim=-1)
        x_fused = F.relu(self.fusion_fc(x_fused))

        # Output action probabilities
        return F.softmax(self.action_fc(x_fused), dim=-1)

Action Selection

In [None]:
def select_multi_agent_actions(model, obs, team_id, env_cfg):
    """Sample one action per unit of ``team_id``.

    The team graph is encoded once and shared by all units. Each unit
    contributes its own local view, batched with a leading dimension of 1.

    Args:
        model: Callable taking ``(unit_cnn_state, g, g_features)`` and
            returning a ``(1, action_dim)`` tensor of probabilities.
        obs: Raw observation dict for the encoders.
        team_id: Index of the controlled team.
        env_cfg: Environment config forwarded to the view encoder.

    Returns:
        Dict mapping each unit's index to its sampled action (int).
    """
    team_graph, node_features = encode_team_graph(obs, team_id)

    def _sample_for(unit_pos):
        local_view = encode_unit_view(obs, unit_pos, env_cfg).unsqueeze(0)  # Add batch dim
        with torch.no_grad():
            probs = model(local_view, team_graph, node_features)
        return torch.multinomial(probs, 1).item()  # Sample an action

    return {
        unit_id: _sample_for(unit_pos)
        for unit_id, unit_pos in enumerate(obs["units"]["position"][team_id])
    }

Reward Shaping

In [None]:
class LuxEnv(gym.Env):
    """Gym-style single-player wrapper around a ``Game`` with a shaped reward.

    Observations are ``(map_width, map_height, 4)`` uint8 grids. The action
    space is ``Discrete(5)``; each index is expanded into a list of
    per-unit commands by ``convert_action``.

    NOTE(review): ``convert_action`` calls ``get_move_targets``,
    ``get_harvest_units``, ``get_building_targets``,
    ``get_exploration_units`` and ``get_energy_units``. None of them is
    defined in this class, so they must come from a subclass or be added
    before ``step`` is usable.

    Args:
        env_cfg: Config dict; ``"map_width"``, ``"map_height"`` and
            ``"max_turns"`` are read by this class.
        player: Name of the controlled player (``"player_0"`` or
            ``"player_1"``); the opponent name is derived from it.
    """

    def __init__(self, env_cfg, player: str):
        super(LuxEnv, self).__init__()

        self.env_cfg = env_cfg
        self.player = player
        self.opp_player = "player_1" if player == "player_0" else "player_0"

        self.game = Game()
        self.game_state = None

        # Define action space and observation space
        self.action_space = spaces.Discrete(5)  # e.g., move, harvest, build, etc.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env_cfg["map_width"], self.env_cfg["map_height"], 4),
            dtype=np.uint8,
        )

        # Initialize the game state
        self.reset()

    def reset(self):
        """Reset the environment to the initial state and return the initial observation."""
        self.game_state = self.game.reset(self.env_cfg)
        return self.get_observation()

    def get_observation(self):
        """Build the observation grid for the controlled player.

        Channel 0 marks the player's unit positions and channel 1 marks
        relic nodes. Channels 2 and 3 are left at zero.
        """
        map_grid = np.zeros(
            (self.env_cfg["map_width"], self.env_cfg["map_height"], 4), dtype=np.uint8
        )

        for unit in self.game_state.units[self.player]:
            map_grid[unit.pos[0], unit.pos[1], 0] = 1  # Mark unit positions

        for relic_node in self.game_state.relic_nodes:
            map_grid[relic_node[0], relic_node[1], 1] = 1  # Mark relic nodes

        return map_grid

    def step(self, action):
        """Take an action and return ``(observation, reward, done, info)``."""
        actions = self.convert_action(action)
        self.game_state = self.game.step(self.game_state, actions)

        # Observation, reward and termination are all computed from the
        # post-step game state.
        observation = self.get_observation()
        reward = self.calculate_reward()
        done = self.is_done()

        # Info dictionary (currently always empty)
        info = {}

        return observation, reward, done, info

    def calculate_reward(self):
        """Return the shaped reward for the current game state.

        Sum of: every own unit's current energy, +10 while the turn counter
        is below ``max_turns``, and +100 when the game is over and the
        controlled player is the winner.
        """
        reward = 0

        # Resource term: current energy of each unit (absolute value, not
        # the change since the previous step).
        for unit in self.game_state.units[self.player]:
            reward += unit.energy

        # Survival bonus for every turn before the turn limit
        if self.game_state.turns < self.env_cfg["max_turns"]:
            reward += 10

        # Winning: large reward for winning the game
        if self.game_state.is_game_over() and self.game_state.winner == self.player:
            reward += 100

        return reward

    def is_done(self):
        """Return True when the turn limit is reached or the game reports it is over."""
        if self.game_state.turns >= self.env_cfg["max_turns"]:
            return True
        return self.game_state.is_game_over()

    def convert_action(self, action):
        """Convert a discrete RL action into a list of per-unit command tuples.

        Unknown action indices produce an empty list.
        """
        if action == 0:
            return [("move", unit_id, target_pos) for unit_id, target_pos in self.get_move_targets()]
        if action == 1:
            return [("harvest", unit_id) for unit_id in self.get_harvest_units()]
        if action == 2:
            return [("build", unit_id, building_type) for unit_id, building_type in self.get_building_targets()]
        if action == 3:
            return [("explore", unit_id) for unit_id in self.get_exploration_units()]
        if action == 4:
            return [("collect_energy", unit_id) for unit_id in self.get_energy_units()]
        return []

    def render(self, mode='human'):
        """Print both teams' current points; ``mode`` is accepted but not used."""
        print(f"Player points: {self.game_state.team_points[self.player]} - Opponent points: {self.game_state.team_points[self.opp_player]}")

Dynamic Opponent Difficulty

In [None]:
class CurriculumTrainer:
    """Trains an agent against opponents whose difficulty grows with performance.

    Difficulty moves through ``"easy"`` -> ``"medium"`` -> ``"hard"`` ->
    ``"expert"``, one stage per call in which the reported win rate is at
    least 0.8.

    Args:
        env_cfg: Environment config handed to the agent.
        player: Player identifier handed to the agent.
    """

    def __init__(self, env_cfg, player):
        self.env_cfg = env_cfg
        self.player = player
        self.current_stage = 1
        self.opponent_difficulty = "easy"  # Initial opponent difficulty

    def update_opponent_difficulty(self, agent_performance):
        """Advance one stage when the agent's win rate is at least 0.8.

        Args:
            agent_performance: Either the win rate as a number, or a metrics
                dict with a ``"win_rate"`` entry (the shape returned by
                ``evaluate_performance``).
        """
        # Accept the metrics dict as well as a bare number; comparing a
        # dict with a float would raise TypeError.
        if isinstance(agent_performance, dict):
            agent_performance = agent_performance["win_rate"]

        if agent_performance >= 0.8:  # If agent is performing well, increase difficulty
            self.current_stage += 1
            if self.current_stage == 2:
                self.opponent_difficulty = "medium"
            elif self.current_stage == 3:
                self.opponent_difficulty = "hard"
            elif self.current_stage >= 4:
                self.opponent_difficulty = "expert"

    def train(self):
        """Play one game against the current-difficulty opponent, then update the stage."""
        # Only the class for the active difficulty is looked up, so the
        # other opponent classes need not exist yet.
        if self.opponent_difficulty == "easy":
            opponent = EasyOpponent()
        elif self.opponent_difficulty == "medium":
            opponent = MediumOpponent()
        elif self.opponent_difficulty == "hard":
            opponent = HardOpponent()
        else:
            opponent = ExpertOpponent()

        # Train agent against selected opponent
        agent = Agent(self.player, self.env_cfg)
        game = Game(agent, opponent)
        game.run()

        # Track agent's performance
        agent_performance = agent.evaluate_performance()
        self.update_opponent_difficulty(agent_performance)

Adjust Opponent Based on Performance

In [None]:
def train_with_curriculum(self, agent):
    """Train the agent with progressively harder opponents.

    Runs 100 games. After each game the reported ``'win_rate'`` is passed
    to ``self.update_opponent_difficulty`` and ``self.opponent_difficulty``
    is then re-derived from ``self.current_stage``.

    ``opponent`` is neither a parameter nor a local here: it is looked up
    at module scope when the function runs, so it must be defined there
    beforehand.

    Args:
        self: Trainer object providing ``update_opponent_difficulty``,
            ``current_stage`` and ``opponent_difficulty``.
        agent: Agent passed as the first player of every game.
    """
    total_episodes = 100  # 100 training episodes as an example

    for _episode in range(total_episodes):
        outcome = Game(agent, opponent).run()

        # Track performance metrics and let the trainer advance its stage
        self.update_opponent_difficulty(outcome['win_rate'])

        # Map the stage reached so far onto a difficulty label; stage 1
        # leaves the current label untouched.
        stage = self.current_stage
        if stage > 3:
            self.opponent_difficulty = "expert"
        elif stage > 2:
            self.opponent_difficulty = "hard"
        elif stage > 1:
            self.opponent_difficulty = "medium"

Implement Self-Play in Lux AI

In [None]:
class SelfPlayTrainer:
    """Self-play loop that pits each epoch's agent against the previous epoch's agent.

    Args:
        env_cfg: Environment config handed to every agent.
        agent_cls: Callable ``agent_cls(player, env_cfg)`` building an agent.
        player: Player identifier handed to every agent.
        num_epochs: Number of self-play games run by ``train``.
    """

    def __init__(self, env_cfg, agent_cls, player, num_epochs=100):
        self.env_cfg = env_cfg
        self.agent_cls = agent_cls
        self.player = player
        self.num_epochs = num_epochs
        self.previous_agents = []  # List to hold previous agent versions

    def save_agent(self, agent):
        """Save the current agent to a list for future self-play."""
        self.previous_agents.append(agent)

    def select_opponent(self, epoch):
        """Return the opponent for ``epoch``.

        In epoch 0, or whenever no agent has been saved yet, a freshly
        constructed ``agent_cls`` instance is used. Otherwise the most
        recently saved agent is returned. The empty-history check avoids an
        IndexError when this is called with ``epoch > 0`` before any
        ``save_agent`` call.
        """
        if epoch == 0 or not self.previous_agents:
            return self.agent_cls(self.player, self.env_cfg)
        return self.previous_agents[-1]  # Select the most recent version of the agent

    def train(self):
        """Run ``num_epochs`` self-play games, saving each epoch's agent afterwards."""
        for epoch in range(self.num_epochs):
            print(f"Epoch {epoch + 1} Training Begins...")

            # Create the agent for this epoch
            agent = self.agent_cls(self.player, self.env_cfg)

            # Select opponent: previous agent, or a fresh one if none exists
            opponent = self.select_opponent(epoch)

            # Run a self-play game; its result is not used here because
            # performance is taken from the agent's own evaluation below.
            Game(agent, opponent).run()

            # Evaluate the agent's performance
            agent_performance = agent.evaluate_performance()

            # Save the agent after this training epoch for future self-play
            self.save_agent(agent)

            print(f"Epoch {epoch + 1} completed. Agent Performance: {agent_performance}")

Self-Play Game Loop

In [None]:
class Game:
    """Drives one match between two agents on a fresh ``LuxEnvironment``.

    Args:
        agent1: First agent; must expose ``player`` and ``act(step, obs)``.
        agent2: Second agent with the same interface.
    """

    def __init__(self, agent1, agent2):
        self.agent1, self.agent2 = agent1, agent2
        self.env = LuxEnvironment()  # Lux AI environment
        self.current_step = 0

    def run(self):
        """Play until the environment reports game over and return its result."""
        while True:
            if self.env.is_game_over():
                break

            contenders = (self.agent1, self.agent2)

            # Both observations are fetched before either agent acts.
            observations = [self.env.get_observation(a.player) for a in contenders]
            chosen = [
                a.act(self.current_step, view)
                for a, view in zip(contenders, observations)
            ]

            # Step the environment forward
            self.env.step(chosen[0], chosen[1])

            self.current_step += 1

        # At the end of the game, return the result
        return self.env.get_game_result()

Evaluation Metric for Self-Play

In [None]:
def evaluate_performance(self):
    """Collect the agent's performance metrics into a dict.

    Returns:
        Dict with ``'win_rate'`` (from ``self.calculate_win_rate()``) and
        ``'resource_efficiency'`` (from
        ``self.calculate_resource_efficiency()``), computed in that order.
    """
    metrics = {}
    metrics['win_rate'] = self.calculate_win_rate()
    metrics['resource_efficiency'] = self.calculate_resource_efficiency()
    return metrics

def calculate_win_rate(self):
    """Return the agent's win rate.

    Placeholder implementation: always reports 0.8 (as if 80 of 100 games
    were won) and does not read any state from ``self``.
    """
    placeholder_win_rate = 0.8
    return placeholder_win_rate

Hybrid PPO + Q-Learning for Lux AI

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from stable_baselines3 import PPO
from collections import deque

class HybridAgent:
    """Agent that alternates between a PPO policy and tabular Q-learning.

    Args:
        player: Player identifier stored on the agent.
        env_cfg: Config dict; ``"state_space_size"``, ``"action_space_size"``
            and ``"state_bins"`` are read by this class.
        alpha: Q-learning step size.
        gamma: Discount factor.
        epsilon: Exploration probability for epsilon-greedy selection.
    """

    def __init__(self, player: str, env_cfg, alpha=0.001, gamma=0.99, epsilon=0.1):
        self.player = player
        self.env_cfg = env_cfg
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

        # Initialize PPO agent
        # NOTE(review): env_cfg is passed in the position where
        # stable-baselines3 PPO takes its environment. Confirm this is
        # meant to be an env instance or id, not the config dict.
        self.ppo_agent = PPO("MlpPolicy", env_cfg, verbose=0)

        # Initialize Q-learning table: one row per state, one column per action
        self.q_table = np.zeros((env_cfg["state_space_size"], env_cfg["action_space_size"]))

    def select_action_ppo(self, obs):
        """Select action using PPO (on-policy)."""
        action, _ = self.ppo_agent.predict(obs)
        return action

    def select_action_qlearning(self, state):
        """Select an action epsilon-greedily from the Q-table (off-policy)."""
        if np.random.rand() < self.epsilon:
            # Exploration: Random action
            return np.random.choice(self.env_cfg["action_space_size"])
        # Exploitation: Choose the best action based on Q-table
        return np.argmax(self.q_table[state])

    def update_q_table(self, state, action, reward, next_state, done=False):
        """Apply one Q-learning update to ``q_table[state, action]``.

        Args:
            state: Index of the state the action was taken in.
            action: Index of the action taken.
            reward: Reward received for the transition.
            next_state: Index of the resulting state.
            done: True when the transition ended the episode. The target is
                then the reward alone, with no bootstrapped value from
                ``next_state``. Defaults to False, which keeps the plain
                bootstrapped update.
        """
        if done:
            td_target = reward
        else:
            best_next_action = np.argmax(self.q_table[next_state])
            td_target = reward + self.gamma * self.q_table[next_state, best_next_action]
        td_error = td_target - self.q_table[state, action]
        self.q_table[state, action] += self.alpha * td_error

    def train(self, env, steps=1000):
        """Run ``steps`` episodes, alternating the controller per episode.

        Even-numbered episodes act with PPO and call ``ppo_agent.learn``
        after every environment step. Odd-numbered episodes act with
        epsilon-greedy Q-learning and update the Q-table after every step.
        """
        for episode in range(steps):
            use_ppo = episode % 2 == 0  # Alternate between PPO and Q-learning
            obs = env.reset()
            done = False
            while not done:
                if use_ppo:
                    action = self.select_action_ppo(obs)
                else:
                    state = self.get_state_from_obs(obs)  # Convert observation to state
                    action = self.select_action_qlearning(state)

                next_obs, reward, done, _info = env.step(action)

                if use_ppo:
                    self.ppo_agent.learn(total_timesteps=1)  # Simulate learning
                else:
                    next_state = self.get_state_from_obs(next_obs)
                    # Pass done so terminal transitions do not bootstrap.
                    self.update_q_table(state, action, reward, next_state, done=done)

                obs = next_obs

    def get_state_from_obs(self, obs):
        """Convert the environment observation to a state for Q-learning.

        Bins ``obs`` with ``np.digitize`` against ``env_cfg["state_bins"]``.
        """
        # NOTE(review): for a non-scalar obs this returns an array of bin
        # indices, not a single row index. Confirm obs is scalar here.
        return np.digitize(obs, self.env_cfg["state_bins"])

Quantization in PyTorch

In [None]:
import torch
from torch.quantization import quantize_dynamic

# Load a pre-trained model.
# The checkpoint must hold the whole pickled nn.Module (not just a state_dict),
# because the loaded object is handed straight to quantize_dynamic.
# weights_only=False is explicit because newer PyTorch releases default to
# weights_only=True, which refuses to unpickle a full module.
# SECURITY: this unpickles arbitrary Python objects, so only load checkpoint
# files from a trusted source.
model = torch.load("lux_ai_model.pth", weights_only=False)

# Apply dynamic quantization (applies to weights only, keeping activations as float)
quantized_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Save the quantized model
torch.save(quantized_model, "lux_ai_quantized_model.pth")

Model Distillation in PyTorch

In [None]:
import os
import sys

# Add the Lux AI environment to your path
sys.path.append(os.path.abspath('Lux-Design-S3/src'))

import gym
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from luxai_s3.lux_env import LuxEnv

# 1. Define the Teacher Model Architecture

class TeacherModel(nn.Module):
    """Three-layer MLP mapping an observation vector to one logit per action.

    Args:
        observation_space: Object whose ``shape[0]`` gives the input width.
        action_space: Object whose ``n`` gives the number of output logits.
    """

    def __init__(self, observation_space, action_space):
        super().__init__()
        input_width = observation_space.shape[0]
        action_count = action_space.n
        self.fc1 = nn.Linear(input_width, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, action_count)

    def forward(self, x):
        """Return raw (un-normalised) action logits for ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        # No activation on the output layer: callers receive logits.
        return self.fc3(hidden)

# 2. Create the Lux AI Environment
# NOTE(review): this cell imports LuxEnv from luxai_s3.lux_env, which rebinds
# the name over any LuxEnv class defined in an earlier cell. Confirm the
# imported class accepts the keyword arguments used below.
env = LuxEnv(
    configs={
        "seed": 562124210,
        "loglevel": 2,
        "annotations": True,
        "width": 12,
        "height": 12,
        "max_episode_length": 1000,  # number of steps in the environment
    },
    learning_agent="player_0",
    opponent_agent="player_1",
    verbose=2,
)

# 3. Instantiate the Teacher Model
# Layer sizes are taken from the environment's observation and action spaces.
teacher_model = TeacherModel(env.observation_space, env.action_space)

# 4. Set up the Optimizer
# The optimizer is created here but never stepped in the loop below.
optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.001)

# 5. Training Loop (Example - Replace with your training logic)
# NOTE(review): num_episodes is not defined in this cell; it must already be
# set by an earlier cell, otherwise this loop raises NameError.
for episode in range(num_episodes):
    obs = env.reset()
    done = False
    while not done:
        # Get the teacher's action: the model returns raw logits and the
        # index of the largest one is taken greedily.
        action = teacher_model(torch.tensor(obs, dtype=torch.float32))
        action = torch.argmax(action).item()  # Choose the action with highest probability

        # Take the action in the environment
        next_obs, reward, done, info = env.step(action)

        # Calculate loss and update the teacher model
        # ... (your loss calculation and optimization logic here) ...
        # Until that logic is added, no gradient step happens and the model
        # weights stay at their initial values.

        obs = next_obs

# 6. Save the Trained Teacher Model
# Only the state_dict (parameters) is saved, not the pickled module.
torch.save(teacher_model.state_dict(), "teacher_model.pth")