# A3C for Kung Fu

## Part 0 - Installing the required packages and importing the libraries

### Installing Gymnasium

In [2]:
!pip install gymnasium
!pip install "gymnasium[atari, accept-rom-license]"
!pip install ale-py
!apt-get install -y swig
!pip install gymnasium[box2d]

E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?


### Importing the libraries

In [3]:
import cv2
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributions as distributions
from torch.distributions import Categorical
import ale_py
import gymnasium as gym
from gymnasium.spaces import Box
from gymnasium import ObservationWrapper

## Part 1 - Building the AI

### Creating the architecture of the Neural Network

In [4]:
class Network(nn.Module):
    """
    Actor-Critic Network for Advantage Actor-Critic (A3C) algorithm.

    This neural network serves two purposes:
    1. **Policy estimation (Actor):** Outputs action probabilities based on the input state.
    2. **State-value estimation (Critic):** Outputs a scalar value representing the value of the input state.

    The network consists of convolutional layers for feature extraction followed by fully connected layers for computing
    the policy and value function.

    Attributes:
        conv1 (torch.nn.Conv2d): First convolutional layer with 4 input channels, 32 output channels, and a 3x3 kernel.
        conv2 (torch.nn.Conv2d): Second convolutional layer with 32 input channels, 32 output channels, and a 3x3 kernel.
        conv3 (torch.nn.Conv2d): Third convolutional layer with 32 input channels, 32 output channels, and a 3x3 kernel.
        flatten (torch.nn.Flatten): Flattens the output of the convolutional layers into a 1D tensor.
        fc1 (torch.nn.Linear): Fully connected layer that maps 512 features to 128 features.
        fc2a (torch.nn.Linear): Fully connected layer for action values, mapping 128 features to the action space.
        fc2s (torch.nn.Linear): Fully connected layer for state-value estimation, mapping 128 features to a single value.
    """

    def __init__(self, action_size):
        """
        Initializes the network layers.

        Args:
            action_size (int): The size of the action space. Determines the output dimension of the actor network.
        """
        super(Network, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3), stride=2)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=2)
        self.conv3 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=2)
        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(512, 128)
        self.fc2a = torch.nn.Linear(128, action_size)
        self.fc2s = torch.nn.Linear(128, 1)

    def forward(self, state):
        """
        Forward pass through the network.

        Args:
            state (torch.Tensor): Input state tensor with shape (batch_size, 4, height, width), where 4 corresponds
                                  to the number of input channels (e.g., stacked frames).

        Returns:
            tuple:
                action_values (torch.Tensor): Tensor of shape (batch_size, action_size) representing the action scores.
                states_value (torch.Tensor): Tensor of shape (batch_size, 1) representing the scalar value of the state.
        """
        x = self.conv1(state)
        x = F.relu(x)  # Apply ReLU activation
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.flatten(x)  # Flatten convolutional features
        x = self.fc1(x)
        x = F.relu(x)
        action_values = self.fc2a(x)  # Compute action values
        states_value = self.fc2s(x)[0]  # Compute state value
        return action_values, states_value


## Part 2 - Training the AI

### Setting up the environment

In [9]:
class PreprocessAtari(ObservationWrapper):

  def __init__(self, env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4):
    super(PreprocessAtari, self).__init__(env)
    self.img_size = (height, width)
    self.crop = crop
    self.dim_order = dim_order
    self.color = color
    self.frame_stack = n_frames
    n_channels = 3 * n_frames if color else n_frames
    obs_shape = {'tensorflow': (height, width, n_channels), 'pytorch': (n_channels, height, width)}[dim_order]
    self.observation_space = Box(0.0, 1.0, obs_shape)
    self.frames = np.zeros(obs_shape, dtype = np.float32)

  def reset(self):
    self.frames = np.zeros_like(self.frames)
    obs, info = self.env.reset()
    self.update_buffer(obs)
    return self.frames, info

  def observation(self, img):
    img = self.crop(img)
    img = cv2.resize(img, self.img_size)
    if not self.color:
      if len(img.shape) == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img.astype('float32') / 255.
    if self.color:
      self.frames = np.roll(self.frames, shift = -3, axis = 0)
    else:
      self.frames = np.roll(self.frames, shift = -1, axis = 0)
    if self.color:
      self.frames[-3:] = img
    else:
      self.frames[-1] = img
    return self.frames

  def update_buffer(self, obs):
    self.frames = self.observation(obs)

def make_env():
  env = gym.make("KungFuMasterDeterministic-v0", render_mode = 'rgb_array')
  env = PreprocessAtari(env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4)
  return env

env = make_env()

state_shape = env.observation_space.shape
number_actions = env.action_space.n
print("State shape:", state_shape)
print("Number actions:", number_actions)
print("Action names:", env.env.env.env.get_action_meanings())

  logger.deprecation(
A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


State shape: (4, 42, 42)
Number actions: 14
Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']


### Initializing the hyperparameters

In [5]:
learning_rate = 1e-4
discount_factor = 0.99
number_environments = 10

### Implementing the A3C class

In [7]:
class Agent():
    """
    The Agent class implements the Actor-Critic architecture for the A3C algorithm. It interacts with the environment,
    predicts actions, and updates the network based on the observed transitions.

    Attributes:
        device (torch.device): The device (CPU or GPU) on which computations will be performed.
        action_size (int): The size of the action space.
        network (Network): The neural network used for policy and value predictions.
        optimizer (torch.optim.Optimizer): The optimizer used to update the network parameters.
    """

    def __init__(self, action_size):
        """
        Initializes the agent with a neural network, optimizer, and device configuration.

        Args:
            action_size (int): The size of the action space. Determines the output dimension of the policy network.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.action_size = action_size
        self.network = Network(action_size).to(self.device)  # Instantiate the policy-value network
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=learning_rate)  # Adam optimizer

    def act(self, state):
        """
        Selects actions for a given state using the policy network.

        Args:
            state (np.ndarray): The input state(s) of shape (height, width, channels) or 
                                (batch_size, height, width, channels).

        Returns:
            np.ndarray: An array of selected actions for each state in the batch.
        """
        if state.ndim == 3:  # Single state case
            state = [state]  # Convert to batch with one state
        state = torch.tensor(state, dtype=torch.float32, device=self.device)  # Convert state to tensor
        action_values, _ = self.network(state)  # Get action values (policy logits) and state value
        policy = F.softmax(action_values, dim=-1)  # Convert logits to probabilities
        return np.array([
            np.random.choice(len(p), p=p) for p in policy.detach().cpu().numpy()
        ])  # Sample actions from the policy

    def step(self, state, action, reward, next_state, done):
        """
        Updates the network parameters using the observed transition.

        Args:
            state (np.ndarray): Current state(s) with shape (batch_size, height, width, channels).
            action (np.ndarray): Array of actions taken with shape (batch_size,).
            reward (np.ndarray): Array of rewards received with shape (batch_size,).
            next_state (np.ndarray): Next state(s) with shape (batch_size, height, width, channels).
            done (np.ndarray): Array of done flags (0 or 1) with shape (batch_size,).

        Process:
        - Converts inputs to PyTorch tensors.
        - Computes the target state value using rewards, next state value, and the discount factor.
        - Calculates the advantage by subtracting the predicted state value from the target state value.
        - Computes the actor loss (policy gradient with entropy regularization) and critic loss (mean squared error).
        - Optimizes the total loss (actor + critic).

        Notes:
        - The entropy term encourages exploration by regularizing the policy to avoid being overly deterministic.
        """
        batch_size = state.shape[0]
        
        # Convert numpy arrays to PyTorch tensors
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        next_state = torch.tensor(next_state, dtype=torch.float32, device=self.device)
        reward = torch.tensor(reward, dtype=torch.float32, device=self.device)
        done = torch.tensor(done, dtype=torch.bool, device=self.device).to(dtype=torch.float32)
        
        # Forward pass for current and next states
        action_values, state_value = self.network(state)
        _, next_state_value = self.network(next_state)
        
        # Compute target state value
        target_state_value = reward + discount_factor * next_state_value * (1 - done)
        
        # Calculate advantage
        advantage = target_state_value - state_value
        
        # Compute policy (logits to probabilities) and entropy
        probs = F.softmax(action_values, dim=-1)
        logprobs = F.log_softmax(action_values, dim=-1)
        entropy = -torch.sum(probs * logprobs, axis=-1)
        
        # Compute actor loss
        batch_idx = np.arange(batch_size)
        logp_actions = logprobs[batch_idx, action]
        actor_loss = -(logp_actions * advantage.detach()).mean() - 0.001 * entropy.mean()
        
        # Compute critic loss
        critic_loss = F.mse_loss(target_state_value.detach(), state_value)
        
        # Total loss and optimization
        total_loss = actor_loss + critic_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()


### Initializing the A3C agent

In [11]:
agent = Agent(number_actions)

### Evaluating our A3C agent on a single episode

In [13]:
def evaluate(agent, env, n_episodes=1):
    """
    Evaluates the performance of a trained agent in a given environment.

    The function runs the agent for a specified number of episodes in the environment,
    collects the total rewards achieved in each episode, and returns them.

    Args:
        agent (Agent): The trained agent to be evaluated. Must implement the `act` method.
        env (gym.Env): The environment in which the agent will be evaluated. Must implement `reset` and `step` methods.
        n_episodes (int, optional): The number of episodes to run for evaluation. Defaults to 1.

    Returns:
        list[float]: A list containing the total rewards achieved in each episode. The length of the list
                     equals the value of `n_episodes`.

    Process:
    - Resets the environment to get the initial state.
    - Repeatedly selects an action using the agent's policy until the episode ends.
    - Tracks the cumulative reward for each episode and appends it to the results list.
    """
    episodes_rewards = []  # List to store total rewards for each episode

    for _ in range(n_episodes):
        state, _ = env.reset()  # Reset the environment and get the initial state
        total_reward = 0  # Initialize the total reward for this episode
        
        while True:
            action = agent.act(state)  # Use the agent to select an action
            # Step in the environment using the selected action
            state, reward, done, info, _ = env.step(action[0])
            total_reward += reward  # Accumulate the reward

            if done:  # Check if the episode has ended
                break
        
        episodes_rewards.append(total_reward)  # Store the total reward for this episode

    return episodes_rewards  # Return the list of total rewards


### Testing multiple agents on multiple environments at the same time

In [15]:
class EnvBatch:
    """
    A batch environment wrapper to manage and interact with multiple environments simultaneously.

    This class allows parallel interactions with multiple environments, useful for training
    reinforcement learning agents that benefit from diverse experiences.

    Attributes:
        envs (list): A list of individual environments, created using the `make_env` function.
    """

    def __init__(self, n_envs=10):
        """
        Initializes the batch environment.

        Args:
            n_envs (int, optional): The number of environments to create. Defaults to 10.
        """
        self.envs = [make_env() for _ in range(n_envs)]  # Create a list of environments

    def reset(self):
        """
        Resets all environments in the batch and returns their initial states.

        Returns:
            np.ndarray: A NumPy array of initial states from all environments. The shape is (n_envs, ...),
                        where `n_envs` is the number of environments.
        """
        _states = []
        for env in self.envs:
            _states.append(env.reset()[0])  # Reset each environment and get the initial state
        return np.array(_states)  # Return initial states as a NumPy array

    def step(self, actions):
        """
        Executes a batch of actions across all environments and collects the resulting transitions.

        Args:
            actions (iterable): A list or array of actions to execute in each environment. The length
                                of `actions` must match the number of environments.

        Returns:
            tuple:
                next_states (np.ndarray): A NumPy array of next states for all environments.
                rewards (np.ndarray): A NumPy array of rewards received from all environments.
                dones (np.ndarray): A NumPy array of boolean flags indicating whether each environment
                                    has finished the episode.
                infos (np.ndarray): A NumPy array of info dictionaries returned by each environment.

        Process:
        - For each environment and corresponding action, execute `env.step(action)` to get the
          transition data.
        - If an environment is done (`dones[i]` is True), reset that environment and replace its
          next state with the reset state.
        """
        # Perform one step in each environment with the corresponding action
        next_states, rewards, dones, infos, _ = map(
            np.array,
            zip(*[env.step(a) for env, a in zip(self.envs, actions)]),
        )  # * unpacks results from each environment

        # Reset environments that have completed their episodes
        for i in range(len(self.envs)):
            if dones[i]:
                next_states[i] = self.envs[i].reset()[0]  # Replace with reset state

        return next_states, rewards, dones, infos


### Training the A3C agent

In [17]:
import tqdm

# Create a batch of environments
env_batch = EnvBatch(number_environments)  # Initialize the batch environment with the specified number of environments
batch_states = env_batch.reset()  # Reset all environments and get the initial states

# Initialize the progress bar for 3001 iterations
with tqdm.trange(0, 3001) as progress_bar:
    for i in progress_bar:
        # Generate actions for the current batch of states using the agent's policy
        batch_actions = agent.act(batch_states)
        
        # Perform one step in the batch of environments with the chosen actions
        batch_next_states, batch_rewards, batch_dones, _ = env_batch.step(batch_actions)
        
        # Scale rewards for stabilization (e.g., normalization or reducing variance)
        batch_rewards *= 0.01
        
        # Update the agent using the observed transitions
        # Includes both actor (policy) and critic (value function) updates
        agent.step(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones)
        
        # Update the current states for the next iteration
        batch_states = batch_next_states
        
        # Periodically evaluate the agent's performance
        if i % 1000 == 0:
            avg_reward = np.mean(evaluate(agent, env, n_episodes=10))  # Evaluate over 10 episodes
            print("Average agent reward: ", avg_reward)


  logger.deprecation(
  critic_loss = F.mse_loss(target_state_value.detach(), state_value)
  state = torch.tensor(state, dtype=torch.float32, device=self.device)  # Convert state to tensor
  0%|▎                                                                              | 10/3001 [00:23<1:23:32,  1.68s/it]

Average agent reward:  510.0


 34%|██████████████████████████▌                                                    | 1009/3001 [00:58<23:29,  1.41it/s]

Average agent reward:  750.0


 67%|████████████████████████████████████████████████████▉                          | 2009/3001 [01:35<11:20,  1.46it/s]

Average agent reward:  970.0


100%|███████████████████████████████████████████████████████████████████████████████| 3001/3001 [02:14<00:00, 22.36it/s]

Average agent reward:  1140.0





## Part 3 - Visualizing the results

In [25]:
import glob
import io
import base64
import imageio
from IPython.display import HTML, display
import os

# Set the path to FFmpeg explicitly
os.environ["IMAGEIO_FFMPEG_EXE"] = "/home/reefk/miniforge3/envs/ml-gpu-env/bin/ffmpeg"

def show_video_of_model(agent, env, filename='kung_fu_video.mp4'):
    """
    Records a video of the agent interacting with the specified environment.

    Args:
        agent (object): The agent to evaluate.
        env (gym.Env): The environment object to interact with.
        filename (str): Name of the output video file. Defaults to 'kung_fu_video.mp4'.
    """
    state, _ = env.reset()
    done = False
    frames = []
    while not done:
        frame = env.render()
        frames.append(frame)
        action = agent.act(state)
        state, reward, done, _, _ = env.step(action[0])
    env.close()
    imageio.mimsave(filename, frames, fps=30)  # Save the video to the specified filename

def show_video(filename='kung_fu_video.mp4'):
    """
    Displays a video saved as an MP4 file in the Jupyter Notebook.

    Args:
        filename (str): Name of the video file to display.
    """
    if os.path.exists(filename):  # Check if the file exists
        with open(filename, 'rb') as video_file:
            video = video_file.read()
        encoded = base64.b64encode(video)
        display(HTML(data=f'''
            <video alt="test" autoplay loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{encoded.decode('ascii')}" type="video/mp4" />
            </video>
        '''))
    else:
        print(f"Could not find video: {filename}")

# Record a video for the Kung Fu environment
show_video_of_model(agent, env, filename='kung_fu_video.mp4')

# Display the recorded video
show_video(filename='kung_fu_video.mp4')


