# A3C for Kung Fu

## Part 0 - Installing the required packages and importing the libraries

### Installing Gymnasium

In [None]:
# Pin Gymnasium below 1.0: release 1.0 removed both the `accept-rom-license` extra and the
# `gymnasium.wrappers.monitoring.video_recorder` module that Part 3 imports.
# Box2D (and the swig toolchain needed to build it) is not used by this Atari notebook.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
%pip install "gymnasium[atari,accept-rom-license]<1.0"

  and should_run_async(code)


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
swig is already the newest version (4.0.2-1ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 45 not upgraded.


### Importing the libraries

In [None]:
import cv2
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributions as distributions
from torch.distributions import Categorical
import gymnasium as gym
from gymnasium import ObservationWrapper
from gymnasium.spaces import Box

## Part 1 - Building the AI

### Creating the architecture of the Neural Network

In [None]:
class Network(nn.Module):
  """Shared actor-critic convolutional network.

  Three 3x3, stride-2 convolutions and a 128-unit hidden layer feed two heads:
  an actor head (`fc2a`) producing one logit per action, and a critic head (`fc2s`)
  producing one scalar state value per input.

  Args:
    action_size: number of discrete actions (output size of the actor head).

  Note:
    `fc1` expects 512 flattened features (32 channels x 4 x 4). That is what the three
    unpadded stride-2 convolutions produce from 42 x 42 inputs (42 -> 20 -> 9 -> 4).
  """

  def __init__(self, action_size):
    super(Network, self).__init__()
    self.conv1 = torch.nn.Conv2d(in_channels = 4, out_channels = 32, kernel_size = (3,3), stride = 2)
    self.conv2 = torch.nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = (3,3), stride = 2)
    self.conv3 = torch.nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = (3,3), stride = 2)
    self.flatten = torch.nn.Flatten()
    self.fc1 = torch.nn.Linear(512, 128)
    self.fc2a = torch.nn.Linear(128, action_size)
    self.fc2s = torch.nn.Linear(128, 1)

  def forward(self, state):
    """Compute action logits and state values for a batch of stacked frames.

    Args:
      state: float tensor of shape (batch, 4, H, W); H = W = 42 matches `fc1`.

    Returns:
      action_values: logits of shape (batch, action_size).
      state_value: critic estimates of shape (batch,), one per batch element.
    """
    x = F.relu(self.conv1(state))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))

    x = self.flatten(x)

    x = F.relu(self.fc1(x))
    action_values = self.fc2a(x)
    # fc2s yields (batch, 1). Keep the value of every row: indexing with [0] would keep
    # only the first sample and silently broadcast it over the whole batch in the loss.
    state_value = self.fc2s(x)[:, 0]

    return action_values, state_value


## Part 2 - Training the AI

### Setting up the environment

In [None]:
class PreprocessAtari(ObservationWrapper):
  """Resize, optionally grayscale, normalise and frame-stack Atari observations.

  Each raw frame is passed through `crop` and resized to (height, width). Unless `color`
  is True it is converted to a single grayscale channel. It is then scaled to [0, 1] and
  pushed onto a rolling buffer that holds the last `n_frames` frames. That stacked buffer
  is the observation returned by `reset` and `step`.

  Args:
    env: environment to wrap.
    height, width: output frame size.
    crop: callable applied to the raw frame before resizing.
    dim_order: 'pytorch' stacks frames on the first axis (C, H, W); 'tensorflow' stacks
      them on the last axis (H, W, C).
    color: keep 3 colour channels per frame instead of 1 grayscale channel.
    n_frames: number of frames kept in the stack.
  """

  def __init__(self, env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4):
    super(PreprocessAtari, self).__init__(env)
    self.img_size = (height, width)
    self.crop = crop
    self.dim_order = dim_order
    self.color = color
    self.frame_stack = n_frames
    n_channels = 3 * n_frames if color else n_frames
    obs_shape = {'tensorflow': (height, width, n_channels), 'pytorch': (n_channels, height, width)}[dim_order]
    self.observation_space = Box(0.0, 1.0, obs_shape)
    self.frames = np.zeros(obs_shape, dtype = np.float32)

  def reset(self, *, seed = None, options = None):
    """Clear the frame stack, reset the wrapped env and return (stacked_frames, info).

    `seed` and `options` are forwarded to the wrapped env, as in gymnasium's Wrapper API.
    """
    self.frames = np.zeros_like(self.frames)
    obs, info = self.env.reset(seed = seed, options = options)
    self.update_buffer(obs)
    return self.frames, info

  def observation(self, img):
    """Preprocess one raw frame, push it onto the stack and return the new stack.

    The newest frame occupies the last slot(s) along the stacking axis; a new array is
    produced on every call (np.roll copies).
    """
    img = self.crop(img)
    height, width = self.img_size
    # cv2.resize takes dsize as (width, height), the reverse of numpy's (rows, cols).
    img = cv2.resize(img, (width, height))
    if not self.color and img.ndim == 3 and img.shape[2] == 3:
      # Gymnasium's Atari frames are RGB; COLOR_BGR2GRAY would swap the R/B weights.
      img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img = img.astype('float32') / 255.
    n_new = 3 if self.color else 1
    channel_axis = 0 if self.dim_order == 'pytorch' else -1
    frames = np.roll(self.frames, shift = -n_new, axis = channel_axis)
    if channel_axis == 0:
      # (H, W, 3) -> (3, H, W) for colour frames, (H, W) -> (1, H, W) for grayscale.
      frames[-n_new:] = np.transpose(img, (2, 0, 1)) if self.color else img[np.newaxis]
    else:
      frames[..., -n_new:] = img if self.color else img[..., np.newaxis]
    self.frames = frames
    return self.frames

  def update_buffer(self, obs):
    """Push `obs` onto the frame stack (used by `reset`)."""
    self.frames = self.observation(obs)

def make_env():
  """Build the Kung Fu Master environment used throughout this notebook.

  The raw env renders RGB arrays and is wrapped in `PreprocessAtari`. The wrapper
  produces a stack of 4 grayscale 42x42 frames in PyTorch (channels-first) order.
  """
  base_env = gym.make("KungFuMasterDeterministic-v0", render_mode = 'rgb_array')
  return PreprocessAtari(
    base_env,
    height = 42,
    width = 42,
    crop = lambda img: img,
    dim_order = 'pytorch',
    color = False,
    n_frames = 4,
  )

env = make_env()

state_shape = env.observation_space.shape
number_actions = env.action_space.n
print("Observation shape:", state_shape)
print("Number actions:", number_actions)
# `unwrapped` reaches the innermost environment however many wrappers gym.make and
# PreprocessAtari stacked on top; `env.env.env` silently depended on that exact depth.
print("Action names:", env.unwrapped.get_action_meanings())

Observation shape: (4, 42, 42)
Number actions: 14
Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']


  logger.deprecation(
  logger.warn(


### Initializing the hyperparameters

In [None]:
# Step size for the Adam optimizer that trains the shared network.
learning_rate = 0.0001
# Discount (gamma) applied to the bootstrapped next-state value in the TD target.
discount_factor = 0.99
# How many environments EnvBatch steps together during training.
number_environments = 10

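### Implementing the A3C class

Note: `Agent.step` updates a single shared network synchronously, using a batch of transitions collected from all environments at once. Strictly speaking, this makes it a synchronous advantage actor-critic (A2C-style) rather than asynchronous A3C.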

In [None]:
class Agent():
  """Advantage actor-critic agent acting on a batch of environments.

  Holds one `Network` and an Adam optimizer. `act` samples actions from the softmax
  policy. `step` performs a single one-step-TD actor-critic update on a batch of
  transitions.

  Args:
    action_size: number of discrete actions.
  """

  def __init__(self, action_size):
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.action_size = action_size
    self.network = Network(action_size).to(self.device)
    self.optimizer = torch.optim.Adam(self.network.parameters(), lr = learning_rate)

  def act(self, state):
    """Sample one action per state from the current policy.

    Args:
      state: a single stacked observation (3-D) or a batch of them (4-D).

    Returns:
      np.ndarray of sampled action indices, one per state (length 1 for a single state).
    """
    state = np.asarray(state, dtype = np.float32)
    if state.ndim == 3:
      state = state[np.newaxis]
    # Acting never back-propagates, so skip autograd bookkeeping. Building the tensor
    # from one contiguous array avoids torch's slow list-of-ndarrays path.
    with torch.no_grad():
      state = torch.as_tensor(state, device = self.device)
      action_values, _ = self.network(state)
      policy = F.softmax(action_values, dim = -1).cpu().numpy()
    return np.array([np.random.choice(len(p), p = p) for p in policy])

  def step(self, state, action, reward, next_state, done):
    """Apply one actor-critic gradient update on a batch of transitions.

    Args:
      state, next_state: batches of stacked observations, first axis = batch.
      action: integer actions taken, one per transition.
      reward: rewards received, one per transition.
      done: episode-end flags; where true the bootstrapped next-state value is dropped.
    """
    state = torch.tensor(state, dtype = torch.float32, device = self.device)
    next_state = torch.tensor(next_state, dtype = torch.float32, device = self.device)
    reward = torch.tensor(reward, dtype = torch.float32, device = self.device)
    done = torch.tensor(done, dtype = torch.bool, device = self.device).to(dtype = torch.float32)
    action = torch.as_tensor(action, dtype = torch.long, device = self.device)
    action_values, state_value = self.network(state)
    # The bootstrap target is treated as a constant, so it needs no autograd graph.
    with torch.no_grad():
      _, next_state_value = self.network(next_state)
    target_state_value = reward + discount_factor * next_state_value * (1 - done)
    advantage = target_state_value - state_value
    probs = F.softmax(action_values, dim = -1)
    logprobs = F.log_softmax(action_values, dim = -1)
    entropy = -torch.sum(probs * logprobs, dim = -1)
    logp_actions = logprobs.gather(1, action.unsqueeze(1)).squeeze(1)
    # Subtracting 0.001 * entropy rewards higher-entropy (more exploratory) policies.
    actor_loss = -(logp_actions * advantage.detach()).mean() - 0.001 * entropy.mean()
    critic_loss = F.mse_loss(state_value, target_state_value.detach())
    total_loss = actor_loss + critic_loss
    self.optimizer.zero_grad()
    total_loss.backward()
    self.optimizer.step()

### Initializing the A3C agent

In [None]:
# A single agent is shared by the training loop (via EnvBatch) and by evaluation.
agent = Agent(action_size = number_actions)

### Evaluating our A3C agent over one or more episodes

In [None]:
def evaluate(agent, env, n_episodes = 1):
  """Play `n_episodes` full episodes with `agent` and return each episode's total reward.

  An episode ends when the env reports either termination or truncation. Gymnasium's
  `step` returns (obs, reward, terminated, truncated, info); ignoring `truncated` would
  loop forever on a time-limited episode.

  Args:
    agent: object whose `act(state)` returns an array of actions (the first is used).
    env: environment with gymnasium's `reset`/`step` API.
    n_episodes: number of episodes to play.

  Returns:
    list of per-episode summed rewards, in play order.
  """
  episodes_rewards = []
  for _ in range(n_episodes):
    state, _ = env.reset()
    total_reward = 0
    while True:
      action = agent.act(state)
      state, reward, terminated, truncated, _ = env.step(action[0])
      total_reward += reward
      if terminated or truncated:
        break
    episodes_rewards.append(total_reward)
  return episodes_rewards

### Running one agent on multiple environments at the same time

In [None]:
class EnvBatch():
  """A list of independent environments stepped together in lock-step.

  An environment whose episode ends (terminated or truncated) is reset immediately.
  For that environment, the returned next state is therefore the first observation of
  the new episode.

  Args:
    n_envs: number of environments to create with `make_env`.
  """

  def __init__(self, n_envs = 10):
    self.envs = [make_env() for _ in range(n_envs)]

  def reset(self):
    """Reset every environment and return their initial observations as one array."""
    return np.array([env.reset()[0] for env in self.envs])

  def step(self, actions):
    """Step each environment with its own action.

    Args:
      actions: one action per environment, in `self.envs` order.

    Returns:
      next_states: array of next observations (reset observations where an episode ended).
      rewards: array of per-environment rewards.
      dones: boolean array, True where the episode terminated or was truncated.
      infos: list of the per-environment info dicts.
    """
    results = [env.step(a) for env, a in zip(self.envs, actions)]
    # gymnasium's step returns (obs, reward, terminated, truncated, info).
    next_states = np.array([r[0] for r in results])
    rewards = np.array([r[1] for r in results])
    terminated = np.array([r[2] for r in results], dtype = bool)
    truncated = np.array([r[3] for r in results], dtype = bool)
    infos = [r[4] for r in results]
    # A truncated episode must be reset too; otherwise it would be stepped again
    # without a reset on the next call.
    dones = terminated | truncated
    for i, env in enumerate(self.envs):
      if dones[i]:
        next_states[i] = env.reset()[0]
    return next_states, rewards, dones, infos

### Training the A3C agent

In [None]:
import tqdm

# Training schedule and reward scaling used by this run.
N_TRAINING_STEPS = 50001
EVAL_EVERY = 1000
N_EVAL_EPISODES = 10
REWARD_SCALE = 0.01  # rewards are multiplied by this before the actor-critic update

env_batch = EnvBatch(number_environments)
batch_states = env_batch.reset()

with tqdm.trange(0, N_TRAINING_STEPS) as progress_bar:
  for i in progress_bar:
    batch_actions = agent.act(batch_states)
    batch_next_states, batch_rewards, batch_dones, _ = env_batch.step(batch_actions)
    # Out-of-place multiply: an in-place `*=` raises if the rewards array is integer-typed.
    batch_rewards = batch_rewards * REWARD_SCALE
    agent.step(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones)
    batch_states = batch_next_states
    if i % EVAL_EVERY == 0:
      average_reward = np.mean(evaluate(agent, env, n_episodes = N_EVAL_EPISODES))
      # progress_bar.write prints above the bar instead of breaking its line.
      progress_bar.write(f"Average agent reward:  {average_reward}")

  critic_loss = F.mse_loss(target_state_value.detach(), state_value)
  0%|          | 4/50001 [00:36<95:14:21,  6.86s/it] 

Average agent reward:  250.0


  2%|▏         | 1005/50001 [01:50<33:44:05,  2.48s/it]

Average agent reward:  330.0


  4%|▍         | 2005/50001 [02:59<28:28:04,  2.14s/it]

Average agent reward:  310.0


  6%|▌         | 3005/50001 [04:13<32:08:49,  2.46s/it]

Average agent reward:  680.0


  8%|▊         | 4005/50001 [05:24<27:24:21,  2.14s/it]

Average agent reward:  470.0


 10%|█         | 5005/50001 [06:26<21:01:09,  1.68s/it]

Average agent reward:  120.0


 12%|█▏        | 6004/50001 [07:38<33:33:33,  2.75s/it]

Average agent reward:  650.0


 14%|█▍        | 7005/50001 [08:52<32:42:59,  2.74s/it]

Average agent reward:  710.0


 16%|█▌        | 8006/50001 [09:56<22:10:14,  1.90s/it]

Average agent reward:  380.0


 18%|█▊        | 9005/50001 [11:07<26:13:56,  2.30s/it]

Average agent reward:  920.0


 20%|██        | 10004/50001 [12:16<32:30:10,  2.93s/it]

Average agent reward:  440.0


 22%|██▏       | 11005/50001 [13:24<22:58:27,  2.12s/it]

Average agent reward:  670.0


 24%|██▍       | 12005/50001 [14:32<21:48:47,  2.07s/it]

Average agent reward:  810.0


 26%|██▌       | 13005/50001 [15:39<25:15:38,  2.46s/it]

Average agent reward:  440.0


 28%|██▊       | 14005/50001 [16:48<20:19:34,  2.03s/it]

Average agent reward:  540.0


 30%|███       | 15005/50001 [17:59<23:36:26,  2.43s/it]

Average agent reward:  590.0


 32%|███▏      | 16004/50001 [19:04<22:04:38,  2.34s/it]

Average agent reward:  450.0


 34%|███▍      | 17005/50001 [20:07<16:16:07,  1.77s/it]

Average agent reward:  130.0


 36%|███▌      | 18005/50001 [21:11<18:24:11,  2.07s/it]

Average agent reward:  270.0


 38%|███▊      | 19005/50001 [22:25<22:17:00,  2.59s/it]

Average agent reward:  710.0


 40%|████      | 20004/50001 [23:36<19:04:00,  2.29s/it]

Average agent reward:  560.0


 42%|████▏     | 21005/50001 [24:47<18:01:29,  2.24s/it]

Average agent reward:  580.0


 44%|████▍     | 22005/50001 [25:51<13:15:07,  1.70s/it]

Average agent reward:  330.0


 46%|████▌     | 23005/50001 [26:59<14:47:14,  1.97s/it]

Average agent reward:  560.0


 48%|████▊     | 24004/50001 [28:09<18:36:57,  2.58s/it]

Average agent reward:  630.0


 50%|█████     | 25004/50001 [29:26<22:05:16,  3.18s/it]

Average agent reward:  1190.0


 52%|█████▏    | 26005/50001 [30:44<16:50:37,  2.53s/it]

Average agent reward:  930.0


 54%|█████▍    | 27005/50001 [31:50<11:54:57,  1.87s/it]

Average agent reward:  390.0


 56%|█████▌    | 28005/50001 [33:08<17:00:58,  2.78s/it]

Average agent reward:  1060.0


 58%|█████▊    | 29005/50001 [34:24<13:50:43,  2.37s/it]

Average agent reward:  1080.0


 60%|██████    | 30005/50001 [35:34<12:29:31,  2.25s/it]

Average agent reward:  790.0


 62%|██████▏   | 31005/50001 [36:43<14:22:06,  2.72s/it]

Average agent reward:  440.0


 64%|██████▍   | 32004/50001 [37:57<12:24:06,  2.48s/it]

Average agent reward:  880.0


 66%|██████▌   | 33005/50001 [39:18<14:51:23,  3.15s/it]

Average agent reward:  1320.0


 68%|██████▊   | 34005/50001 [40:26<9:11:40,  2.07s/it] 

Average agent reward:  590.0


 70%|███████   | 35005/50001 [41:31<9:08:13,  2.19s/it] 

Average agent reward:  540.0


 72%|███████▏  | 36005/50001 [42:35<7:56:53,  2.04s/it] 

Average agent reward:  340.0


 74%|███████▍  | 37004/50001 [43:41<8:13:08,  2.28s/it] 

Average agent reward:  400.0


 76%|███████▌  | 38005/50001 [44:46<6:37:00,  1.99s/it]

Average agent reward:  310.0


 78%|███████▊  | 39005/50001 [45:53<7:11:32,  2.35s/it] 

Average agent reward:  480.0


 80%|████████  | 40004/50001 [46:56<5:58:57,  2.15s/it]

Average agent reward:  170.0


 82%|████████▏ | 41005/50001 [48:06<6:11:49,  2.48s/it]

Average agent reward:  510.0


 84%|████████▍ | 42005/50001 [49:19<5:56:49,  2.68s/it]

Average agent reward:  740.0


 86%|████████▌ | 43004/50001 [50:29<6:08:27,  3.16s/it]

Average agent reward:  670.0


 88%|████████▊ | 44005/50001 [51:35<3:48:27,  2.29s/it]

Average agent reward:  800.0


 90%|█████████ | 45005/50001 [52:43<3:11:56,  2.31s/it]

Average agent reward:  470.0


 92%|█████████▏| 46005/50001 [53:48<2:47:29,  2.51s/it]

Average agent reward:  250.0


 94%|█████████▍| 47005/50001 [54:58<2:03:01,  2.46s/it]

Average agent reward:  860.0


 96%|█████████▌| 48005/50001 [56:01<1:02:02,  1.86s/it]

Average agent reward:  320.0


 98%|█████████▊| 49003/50001 [57:16<52:12,  3.14s/it]  

Average agent reward:  800.0


100%|██████████| 50001/50001 [58:26<00:00, 14.26it/s]

Average agent reward:  650.0





## Part 3 - Visualizing the results

In [None]:
import glob
import io
import base64
import imageio
from IPython.display import HTML, display
from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder

def show_video_of_model(agent, env, filename = 'video.mp4', fps = 30):
  """Play one episode with `agent`, capture `env.render()` each step and save a video.

  The episode stops on termination or truncation, so a time-limited episode cannot loop
  forever. `env` is closed afterwards, even if an error interrupts the episode.

  Args:
    agent: object whose `act(state)` returns an array of actions (the first is used).
    env: environment whose `render()` returns image arrays (make_env uses
      render_mode='rgb_array').
    filename: path of the video file to write.
    fps: frames per second of the saved video.
  """
  state, _ = env.reset()
  frames = []
  try:
    done = False
    while not done:
      frames.append(env.render())
      action = agent.act(state)
      state, reward, terminated, truncated, _ = env.step(action[0])
      done = terminated or truncated
  finally:
    env.close()
  imageio.mimsave(filename, frames, fps = fps)

# Record one episode of the trained agent to video.mp4.
show_video_of_model(agent = agent, env = env)

def show_video(pattern = '*.mp4'):
    """Embed the first video file matching `pattern` in the notebook output.

    Matches are sorted so the choice is deterministic. The chosen file is read as bytes
    and inlined as a base64 data URI inside an HTML <video> element. If nothing matches,
    a message is printed instead.

    Args:
      pattern: glob pattern for candidate videos; defaults to every .mp4 file in the
        current working directory.
    """
    mp4list = sorted(glob.glob(pattern))
    if not mp4list:
        print("Could not find video")
        return
    mp4 = mp4list[0]
    # Open read-only and close promptly: the previous io.open(..., 'r+b') leaked the
    # handle and needlessly required write permission on the file.
    with open(mp4, 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())
    display(HTML(data='''<video alt="test" autoplay
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))

# Embed the recorded episode (the first *.mp4 file found) in the notebook output.
show_video()

