## A3C (Asynchronous Advantage Actor-Critic algorithm)
This is basically Advantage Actor-Critic but with multiple parallel agents.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import gymnasium as gym
import numpy as np
import time
from collections import deque
from Environment import TexasHoldem6PlayerEnv

In [None]:
# Define Actor-Critic Network
class ActorCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super(ActorCritic, self).__init__()
        
        self.shared = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
        )
        
        self.actor = nn.Sequential(
            nn.Linear(256, action_dim),
            nn.Softmax(dim=-1)
        )
        
        self.critic = nn.Linear(256, 1)

    def forward(self, x):
        x = self.shared(x)
        return self.actor(x), self.critic(x)

In [46]:
# Worker function for parallel training
def worker(global_model, optimizer, rank, env_name, gamma, update_steps):
    env = gym.make(env_name)
    local_model = ActorCritic(env.observation_space.shape[0], env.action_space.n)
    local_model.load_state_dict(global_model.state_dict())
    
    while True:
        state, _ = env.reset()  # state is initialized
        state = torch.tensor(state, dtype=torch.float32)
        log_probs, values, rewards = [], [], []
        done = False
        episode_reward = 0
        
        for _ in range(update_steps):
            probs, value = local_model(state)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            
            next_state, reward, done, _, _ = env.step(action.item())  #done
            log_probs.append(dist.log_prob(action))
            values.append(value)
            rewards.append(reward)
            
            state = torch.tensor(next_state, dtype=torch.float32)
            
            if done:
                break
            
        print(f"Worker {rank}: Episode Reward = {episode_reward}") 
        
       #Compute advantages and returns after done is defined 
        R = 0 if done else local_model(state)[1].item()
        returns = []
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        values = torch.cat(values)
        log_probs = torch.cat(log_probs)
        
        advantage = returns - values
        
        # Compute loss
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()
        loss = actor_loss + 0.5 * critic_loss
        
        # Update global model
        optimizer.zero_grad()
        loss.backward()
        for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
            global_param._grad = local_param.grad
        optimizer.step()
        
        local_model.load_state_dict(global_model.state_dict())

        print(f"Worker {rank}: Actor Loss = {actor_loss.item():.4f}, Critic Loss = {critic_loss.item():.4f}")


In [47]:
# Register the environment (Only Needs to be Done Once)
try:
    gym.envs.registration.register(
        id="TexasHoldem6PlayerEnv-v0",  
        entry_point="Environment:TexasHoldem6PlayerEnv",
    )
except gym.error.Error:
    pass  # Ignore if it's already registered

In [None]:
# Main function to launch parallel training
def train_a3c(env_name="TexasHoldem6PlayerEnv-v0", num_workers=4, gamma=0.99, update_steps=50):
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    global_model = ActorCritic(input_dim, action_dim)
    global_model.share_memory()
    optimizer = optim.Adam(global_model.parameters(), lr=1e-4)
    
    processes = []
    for rank in range(num_workers):
        p = mp.Process(target=worker, args=(global_model, optimizer, rank, env_name, gamma, update_steps))
        p.start()
        processes.append(p)
    
    for p in processes:
        p.join()

if __name__ == "__main__":
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass  
    train_a3c()

    print("Training complete!")

Training complete!


In [None]:
def test_agent(env_name="TexasHoldem6PlayerEnv-v0", model_path="trained_a3c_model.pth"):
    env = gym.make(env_name)
    state, _ = env.reset()
    state = torch.tensor(state, dtype=torch.float32)

    input_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    model = ActorCritic(input_dim, action_dim)
    model.load_state_dict(torch.load(model_path))  # Load trained weights
    model.eval()  # Set model to evaluation mode

    total_reward = 0
    done = False

    print("\nStarting Test Episode...\n")

    while not done:
        with torch.no_grad():
            probs, _ = model(state)  # Get action probabilities
            action = torch.argmax(probs).item()  # Choose best action

        next_state, reward, done, _, _ = env.step(action)  # Take action
        total_reward += reward  # Accumulate reward
        state = torch.tensor(next_state, dtype=torch.float32)

    print(f"\nTest Completed: Total Reward = {total_reward}")

    test_agent() 
