In [31]:
import cv2
import numpy as np
import random
import gymnasium as gym
from gymnasium.spaces import Box
from gymnasium import ObservationWrapper
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributions as distributions
from torch.distributions import Categorical
import tqdm
import cv2

Creating the Architecture

In [20]:
class Network(nn.Module):
    def __init__(self, action_size) -> None:
        super(Network, self).__init__() # Activate inheritance
        self.conv1 = torch.nn.Conv2d(in_channels=4,out_channels=32, kernel_size=(3, 3), stride=2) # we will have a stack of 4 grey scale images
        self.conv2 = torch.nn.Conv2d(in_channels=32,out_channels=32, kernel_size=(3, 3), stride=2)
        self.conv3 = torch.nn.Conv2d(in_channels=32,out_channels=32, kernel_size=(3, 3), stride=2)
        self.flatten = torch.nn.Flatten()
        # 1st fully connected layer
        self.fc1 = torch.nn.Linear(32*4*4, 128)
        # we'll have 2 output layers, 1 for the action values (actor) and other for state value (critic)
        self.fc2a = torch.nn.Linear(128, action_size)
        self.fc2s = torch.nn.Linear(128, 1)

    def forward(self, state):
        x = self.conv1(state)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = F.relu(x)
        action_values = self.fc2a(x)
        state_value = self.fc2s(x)[0] # gets the value instead of the whole array
        return action_values, state_value

Setting up the Environment

In [21]:
class PreprocessAtari(ObservationWrapper):

  def __init__(self, env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4):
    super(PreprocessAtari, self).__init__(env)
    self.img_size = (height, width)
    self.crop = crop
    self.dim_order = dim_order
    self.color = color
    self.frame_stack = n_frames
    n_channels = 3 * n_frames if color else n_frames
    obs_shape = {'tensorflow': (height, width, n_channels), 'pytorch': (n_channels, height, width)}[dim_order]
    self.observation_space = Box(0.0, 1.0, obs_shape)
    self.frames = np.zeros(obs_shape, dtype = np.float32)

  def reset(self):
    self.frames = np.zeros_like(self.frames)
    obs, info = self.env.reset()
    self.update_buffer(obs)
    return self.frames, info

  def observation(self, img):
    img = self.crop(img)
    img = cv2.resize(img, self.img_size)
    if not self.color:
      if len(img.shape) == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img.astype('float32') / 255.
    if self.color:
      self.frames = np.roll(self.frames, shift = -3, axis = 0)
    else:
      self.frames = np.roll(self.frames, shift = -1, axis = 0)
    if self.color:
      self.frames[-3:] = img
    else:
      self.frames[-1] = img
    return self.frames

  def update_buffer(self, obs):
    self.frames = self.observation(obs)

def make_env():
  env = gym.make("KungFuMasterDeterministic-v0", render_mode = 'rgb_array')
  env = PreprocessAtari(env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4)
  return env

env = make_env()

state_shape = env.observation_space.shape
number_actions = env.action_space.n
print("State shape:", state_shape)
print("Number actions:", number_actions)
print("Action names:", env.env.env.get_action_meanings())

State shape: (4, 42, 42)
Number actions: 14
Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']


Initialize Hyperparameters

In [22]:
learning_rate = 1e-4
discount_factor = 0.9
number_environments = 10 # basically 10 different agents working seperately

Implementing A3C class

In [23]:
class Agent():
    def __init__(self, action_size):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.action_size = action_size
        # There is only one network, unlike DQN and DCQN which had a local and target network
        self.network = Network(action_size).to(self.device)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr = learning_rate)

    def act(self, state):
        # we always need an extra dimension for the batches
        if state.ndim == 3:
            state = [state] # because state variable is not a pytorch tensor yet
        # now convert the state variable to a torch tensor
        state = torch.tensor(state, dtype = torch.float32, device = self.device)
        action_values, _ = self.network(state)
        # implement Softman policy as the action-selection strategy
        policy = F.softmax(action_values, dim = -1)

        return np.array([np.random.choice(len(p), p = p) for p in policy.detach().cpu().numpy()])
        
    def step(self, state, action, reward, next_state, done): # will be done in batches
        batch_size = state.shape[0]
        state = torch.tensor(state, dtype = torch.float32, device = self.device)
        next_state = torch.tensor(next_state, dtype = torch.float32, device = self.device)
        reward = torch.tensor(reward, dtype = torch.float32, device = self.device)
        done = torch.tensor(done, dtype = torch.bool, device = self.device).to(dtype = torch.float32)
        action_values, state_value = self.network(state)
        _, next_state_value = self.network(next_state) # returns next action and next state, but we are not interested in the next action
        
        # Bellman equation
        target_state_value = reward + discount_factor * next_state_value * (1 - done)
        
        # Implement Advantage
        advantage = target_state_value - state_value 
        
        # actor loss calculation 
        probs = F.softmax(action_values, dim = -1)
        logprobs = F.log_softmax(action_values, dim = -1)
        entropy = -torch.sum(probs * logprobs, axis = -1)
        batch_idx = np.arange(batch_size)
        logp_actions = logprobs[batch_idx, action]
        actor_loss = -(logp_actions * advantage.detach()).mean() - 0.001 * entropy.mean() # the small value is multiplied to balance the importance of the entropy
        
        # critic loss (mse loss between target state value and the state value)
        critic_loss = F.mse_loss(target_state_value.detach(), state_value)

        total_loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

Initialize the A3C Agent

In [24]:
agent = Agent(number_actions)

Evaulate agent on one episode

In [25]:
def evaluate(agent, env, n_episodes = 1):
    episodes_rewards = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        while True: # seuence of instructions for agent
            # step 1 -> play an action
            action = agent.act(state)
            # step 2 -> get the next state, reward and status (done or not) and update total reward
            state, reward, done, info, _ = env.step(action[0])
            total_reward += reward
            # step 3 -> if done then finish episode
            if done:
                break
        
        episodes_rewards.append(total_reward)
    
    return episodes_rewards

Managing Multiple Environments Simultaenously

In [26]:
class EnvBatch:
  # initializing multiple environments
  def __init__(self, n_envs = 10):
    self.envs = [make_env() for _ in range(n_envs)]

  # resetting multiple environments
  def reset(self):
    _states = []
    for env in self.envs:
      _states.append(env.reset()[0])
    return np.array(_states)

  # stepping multiple agents in multiple environments
  def step(self, actions):
    next_states, rewards, dones, infos, _ = map(np.array, zip(*[env.step(a) for env, a in zip(self.envs, actions)]))
    for i in range(len(self.envs)):
      if dones[i]:
        next_states[i] = self.envs[i].reset()[0] # we only want the state
    return next_states, rewards, dones, infos

Training multiple agents in multiple environments (Asynchronous)

In [27]:
import tqdm

env_batch = EnvBatch(number_environments)
batch_states = env_batch.reset()

with tqdm.trange(0, 3001) as progress_bar:
  for i in progress_bar:
    batch_actions = agent.act(batch_states)
    batch_next_states, batch_rewards, batch_dones, _ = env_batch.step(batch_actions)
    batch_rewards *= 0.01
    agent.step(batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones)
    batch_states = batch_next_states
    if i % 1000 == 0:
      print("Average agent reward: ", np.mean(evaluate(agent, env, n_episodes = 10)))

  critic_loss = F.mse_loss(target_state_value.detach(), state_value)
  1%|          | 19/3001 [00:10<19:01,  2.61it/s] 

Average agent reward:  320.0


 34%|███▍      | 1027/3001 [00:27<05:20,  6.15it/s]

Average agent reward:  580.0


 67%|██████▋   | 2023/3001 [00:48<02:58,  5.47it/s] 

Average agent reward:  530.0


100%|██████████| 3001/3001 [01:09<00:00, 43.01it/s] 

Average agent reward:  680.0





Visualizing the results

In [32]:
# Reset the environments for visualization
batch_states = env_batch.reset()

# Record observations during agent's interaction for visualization
recorded_observations = []

# Run the agent in the environment for visualization
for i in range(500):  # Assuming you want to visualize the first 500 steps
    batch_actions = agent.act(batch_states)
    batch_next_states, batch_rewards, batch_dones, _ = env_batch.step(batch_actions)
    batch_states = batch_next_states

    # Record the observations for visualization
    for obs in batch_next_states:
        recorded_observations.append(obs)

# Use OpenCV to visualize recorded observations
for obs in recorded_observations:
    # Render the observation as an image
    obs_image = cv2.cvtColor(obs.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)
    cv2.imshow('Kung Fu Environment', obs_image)
    cv2.waitKey(50)  # Adjust the wait time (in milliseconds) as needed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Clean up OpenCV windows
cv2.destroyAllWindows()
