### A3C-Style Actor-Critic (Synchronous Implementation)

**Note:** This is a simplified synchronous implementation inspired by A3C architecture. True A3C uses asynchronous parallel workers updating a shared global model.

**Principle:** Multi-Worker Parallel Learning (Synchronous Variant)

**Definition:** Runs multiple agents in parallel, synchronously updating a shared model for diverse experience collection.

**Algorithm Description:** This notebook implements a synchronous variant of A3C (Asynchronous Advantage Actor-Critic). While true A3C runs multiple actor-learner agents asynchronously on separate threads/processes, this implementation collects experience from multiple parallel environments synchronously and performs batched updates. This approach still benefits from diverse parallel exploration while being simpler to implement and debug.

**Typical Use Cases:**
- Actor-critic method with parallel environment exploration
- Synchronous updates to a global model (simpler than true async)
- Educational demonstration of multi-worker training
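- A3C in general works with continuous and discrete action spaces; this notebook's implementation handles discrete actions only (the policy is a `Categorical` distribution)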

**Assumptions:**
- Continuous/discrete actions
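- Large amounts of environment interaction (there is no fixed dataset; on-policy experience is collected during training and discarded after each update)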
- On-policy learning
- Parallel environments (synchronous updates)

### 1. Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import torch.multiprocessing as mp

# Notebook-wide plotting defaults: seaborn's white grid theme and a wide
# default figure size for every figure created below.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

### 2. A3C (Asynchronous Advantage Actor-Critic)

A3C uses multiple parallel workers to collect experience:

**Key Features:**
1. **Multiple Workers:** Each worker has its own environment instance
2. **Asynchronous Updates:** In the original A3C, each worker computes gradients and updates the global network independently, without waiting for the others
3. **Shared Network:** All workers share the same global parameters

**Advantage:** Workers explore different parts of the state space at the same time, which decorrelates the training data and stabilizes on-policy learning

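**Note:** This is a simplified synchronous version for demonstration: the workers' environments are stepped one after another in a single process, and their experience is combined into one batched update of the shared network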

In [None]:
class A3CNetwork(nn.Module):
    """Actor-critic network with a shared trunk and two linear heads.

    The trunk is a single 128-unit ReLU layer. The actor head produces a
    probability distribution over ``action_dim`` actions and the critic head
    produces one scalar state-value estimate per input state.
    """

    def __init__(self, state_dim, action_dim):
        """Build the shared trunk, then the actor and critic heads.

        Args:
            state_dim: Size of the observation vector.
            action_dim: Number of discrete actions.
        """
        super().__init__()
        hidden_units = 128
        # Feature extractor shared by both heads.
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
        )
        self.actor = nn.Linear(hidden_units, action_dim)
        self.critic = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return ``(action_probs, state_value)`` for the input state(s).

        The probabilities are a softmax over the last dimension, so the
        network accepts both a single state and a batch of states.
        """
        features = self.fc(x)
        action_probs = F.softmax(self.actor(features), dim=-1)
        state_value = self.critic(features)
        return action_probs, state_value

In [None]:
class A3CAgent:
    """Simplified, synchronous A3C-style actor-critic agent.

    Several environments ("workers") are stepped one after another in the
    current process. Their short rollouts are concatenated into a single
    batch that drives one gradient step on the shared network.
    """

    # Loss weights and gradient-clipping threshold used by ``update``.
    VALUE_LOSS_COEF = 0.5
    ENTROPY_COEF = 0.01
    MAX_GRAD_NORM = 0.5

    def __init__(self, state_dim, action_dim, n_workers=4, lr=3e-4, gamma=0.99,
                 env_id='CartPole-v1', rollout_steps=5):
        """Create the shared network, its optimizer and the worker envs.

        Args:
            state_dim: Size of the observation vector.
            action_dim: Number of discrete actions.
            n_workers: Number of environment instances to collect from.
            lr: Adam learning rate.
            gamma: Discount factor.
            env_id: Gymnasium environment id instantiated for every worker.
            rollout_steps: Transitions collected per worker per update.
        """
        self.gamma = gamma
        self.n_workers = n_workers
        self.rollout_steps = rollout_steps

        # Global shared network. share_memory() only matters if the model is
        # later handed to worker processes; this class itself spawns none.
        self.global_network = A3CNetwork(state_dim, action_dim)
        self.global_network.share_memory()
        self.optimizer = optim.Adam(self.global_network.parameters(), lr=lr)

        # Create multiple environments (workers)
        self.envs = [gym.make(env_id) for _ in range(n_workers)]

        # Current observation of each environment, carried across rollouts so
        # that an episode continues where the previous rollout stopped.
        self._env_states = {}

    def worker_rollout(self, env, steps=5):
        """Collect ``steps`` transitions from one environment.

        The environment is reset only the first time it is seen and whenever
        an episode ends; otherwise the rollout resumes from the observation
        the previous call stopped at. Resetting on every call would restrict
        training to the first ``steps`` transitions of each episode.

        Args:
            env: Environment exposing the Gymnasium ``reset``/``step`` API.
            steps: Number of transitions to collect.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of lists
            with ``steps`` entries each. ``dones[i]`` is True only when the
            episode *terminated*; a time-limit truncation is recorded as
            False so the critic still bootstraps from ``next_states[i]``.
        """
        states, actions, rewards, next_states, dones = [], [], [], [], []

        state = self._env_states.get(env)
        if state is None:
            state, _ = env.reset()

        for _ in range(steps):
            state_t = torch.as_tensor(state, dtype=torch.float32)
            with torch.no_grad():
                probs, _value = self.global_network(state_t)
                action = Categorical(probs).sample().item()

            next_state, reward, terminated, truncated, _info = env.step(action)

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            # Only true termination ends the return; truncation does not.
            dones.append(bool(terminated))

            if terminated or truncated:
                state, _ = env.reset()
            else:
                state = next_state

        self._env_states[env] = state
        return states, actions, rewards, next_states, dones

    def update(self):
        """Collect one rollout per worker and take one optimizer step.

        The loss is the policy-gradient loss weighted by one-step TD
        advantages, plus ``VALUE_LOSS_COEF`` times the value MSE, minus
        ``ENTROPY_COEF`` times the policy entropy.

        Returns:
            The scalar loss as a Python float.
        """
        all_states, all_actions, all_rewards = [], [], []
        all_next_states, all_dones = [], []

        # Collect from all workers
        for env in self.envs:
            s, a, r, ns, d = self.worker_rollout(env, steps=self.rollout_steps)
            all_states.extend(s)
            all_actions.extend(a)
            all_rewards.extend(r)
            all_next_states.extend(ns)
            all_dones.extend(d)

        # Stack into contiguous arrays first: building a tensor directly from
        # a list of ndarrays is a slow element-by-element conversion.
        states = torch.as_tensor(np.asarray(all_states, dtype=np.float32))
        actions = torch.as_tensor(all_actions, dtype=torch.long)
        rewards = torch.as_tensor(np.asarray(all_rewards, dtype=np.float32))
        next_states = torch.as_tensor(np.asarray(all_next_states, dtype=np.float32))
        dones = torch.as_tensor(np.asarray(all_dones, dtype=np.float32))

        # Compute loss
        probs, values = self.global_network(states)
        # squeeze(-1) keeps the batch dimension even for a batch of one.
        values = values.squeeze(-1)
        dist = Categorical(probs)
        log_probs = dist.log_prob(actions)

        with torch.no_grad():
            _, next_values = self.global_network(next_states)
            # One-step TD target; (1 - dones) removes the bootstrap term for
            # terminal transitions.
            targets = rewards + self.gamma * next_values.squeeze(-1) * (1 - dones)
            advantages = targets - values

        actor_loss = -(log_probs * advantages).mean()
        critic_loss = F.mse_loss(values, targets)
        entropy = dist.entropy().mean()

        loss = (actor_loss
                + self.VALUE_LOSS_COEF * critic_loss
                - self.ENTROPY_COEF * entropy)

        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping for training stability
        torch.nn.utils.clip_grad_norm_(
            self.global_network.parameters(), max_norm=self.MAX_GRAD_NORM
        )

        self.optimizer.step()

        return loss.item()

### 3. Train A3C Agent

In [None]:
# CartPole observations have 4 components and there are 2 discrete actions.
state_dim, action_dim = 4, 2
agent = A3CAgent(state_dim, action_dim, n_workers=4)

# A separate environment used only for greedy evaluation runs.
eval_env = gym.make('CartPole-v1')
episode_rewards = []

print('Training A3C with 4 workers...')
for episode in range(500):
    # One batched update built from a rollout of every worker.
    agent.update()

    # Every 10th iteration, play one evaluation episode (capped at 500
    # steps) taking the most probable action at each step.
    if episode % 10 == 0:
        state, _ = eval_env.reset()
        total_reward = 0
        for _ in range(500):
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state)
                probs, _ = agent.global_network(state_tensor)
                action = probs.argmax().item()
            state, reward, terminated, truncated, _ = eval_env.step(action)
            total_reward += reward
            episode_over = terminated or truncated
            if episode_over:
                break
        episode_rewards.append(total_reward)

    # Progress report: mean of the five most recent evaluation returns.
    if (episode + 1) % 50 == 0:
        print(f'Episode {episode+1}, Avg Eval Reward: {np.mean(episode_rewards[-5:]):.2f}')

### 4. Visualize Results

In [None]:
# Plot the evaluation return recorded at each evaluation point.
plt.figure(figsize=(12, 5))
plt.plot(episode_rewards, marker='o', linewidth=2, label='Eval Reward')
# CartPole-v1 is registered with a reward threshold of 475; 195 is the
# threshold of the older CartPole-v0 task and would mark v1 "solved" too early.
plt.axhline(475, color='r', linestyle='--', label='Solved')
plt.xlabel('Evaluation Episode')
plt.ylabel('Reward')
plt.title('A3C Training Performance (4 Workers)')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

# Mean of the last ten evaluation returns.
print(f'\nFinal Performance: {np.mean(episode_rewards[-10:]):.2f}')

# Release the evaluation environment and every worker environment.
eval_env.close()
for env in agent.envs:
    env.close()