In [1]:
import cv2
import numpy as np
import random
import gymnasium as gym
from gymnasium.spaces import Box
from gymnasium import ObservationWrapper
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributions as distributions
from torch.distributions import Categorical

Creating the Architecture

In [5]:
class Network(nn.Module):
    def __init__(self, action_size) -> None:
        super(Network, self).__init__() # Activate inheritance
        self.conv1 = torch.nn.Conv2d(in_channels=4,out_channels=32, kernel_size=(3, 3), stride=2) # we will have a stack of 4 grey scale images
        self.conv2 = torch.nn.Conv2d(in_channels=32,out_channels=32, kernel_size=(3, 3), stride=2)
        self.conv3 = torch.nn.Conv2d(in_channels=32,out_channels=32, kernel_size=(3, 3), stride=2)
        self.flatten = torch.nn.Flatten()
        # 1st fully connected layer
        self.fc1 = torch.nn.Linear(512, 128)
        # we'll have 2 output layers, 1 for the action values (actor) and other for state value (critic)
        self.fc2a = torch.nn.Linear(128, action_size)
        self.fc2s = torch.nn.Linear(128, 1)

    def forward(self, state):
        x = self.conv1(state)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = F.relu(x)
        action_values = self.fc2a(x)
        state_value = self.fc2s(x)[0] # gets the value instead of the whole array

        return action_values, state_value

Setting up the Environment

In [9]:
from gymnasium.core import Env


class Preprocess(ObservationWrapper):
    def __init__(self, env, height = 42, width = 42, crop = lambda img: img, dim_order = 'pytorch', color = False, n_frames = 4):
        super(Preprocess, self).__init__(env)
        self.img_size = (height, width)
        self.crop = crop
        self.dim_order = dim_order
        self.color = color
        self.frame_stack = n_frames
        n_channels = 3 * n_frames if color else n_frames
        obs_shape = {'tensorflow': (height, width, n_channels), 'pytorch': (height, width, n_channels)}[dim_order]
        self.observation_space = Box(0.0, 1.0, obs_shape)
        self.frames = np.zeros(obs_shape, dtype=np.float32)

    def reset(self):
        self.frames = np.zeros_like(self.frames)
        obs, info = self.env.reset()
        self.update_buffer(obs)
        
        return self.frames, info
    
    def observation(self, img):
        img = self.crop(img)
        img = cv2.resize(img, self.img_size)
        if not self.color:
            if len(img.shape) == 3 and img.shap[2] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = img.astype('float32') / 255.
        
        if self.color:
            self.frames = np.roll(self.frames, shift=-3, axis=0)
            self.frames[-3: ] = img
        else:
            self.frames = np.roll(self.frames, shift=-1, axis=0)
            self.frames[-1: ] = img

        return self.frames
    
    def update_buffer(self, obs):
        self.frames = self.observation(obs)

def make_env():
    env = gym.make("KungFuMasterDeterministic-v4", render_mode="rgb_array")
    env = Preprocess(env, height=42, width=42, crop=lambda img: img, dim_order='pytorch', color=False, n_frames=4)
    
    return env

env = make_env()
state_shape = env.observation_space.shape
number_actions = env.action_space.n 

print('State Shape: ', state_shape)
print('Number of actions: ', number_actions)
print('Action Names: ', env.env.get_wrapper_attr('get_action_meanings'))

State Shape:  (42, 42, 4)
Number of actions:  14
Action Names:  <bound method AtariEnv.get_action_meanings of <shimmy.atari_env.AtariEnv object at 0x7fae8b5f3520>>


Initialize Hyperparameters

In [10]:
learning_rate = 1e-4
discount_factor = 0.9
number_environments = 10 # basically 10 different agents working seperately

Implementing A3C class

In [15]:
class Agent():
    def __init__(self, action_size):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.action_size = action_size
        # There is only one network, unlike DQN and DCQN which had a local and target network
        self.network = Network(action_size).to(self.device)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr = learning_rate)

    def act(self, state):
        # we always need an extra dimension for the batches
        if state.ndim == 3:
            state = [state] # because state variable is not a pytorch tensor yet
        # now convert the state variable to a torch tensor
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        action_values, _ = self.network(state)
        # implement Softman policy as the action-selection strategy
        policy = F.softmax(action_values, dim=-1)

        return np.array([np.random.choice(len(p), p = p) for p in policy.detach().cpu.numpy()])
    
    def step(self, state, action, reward, next_state, done): # will be done in batches
        batch_size = state.shape[0]
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        next_state = torch.tensor(next_state, dtype=torch.float32, device=self.device)
        reward = torch.tensor(reward, dtype=torch.float32, device=self.device)
        done = torch.tensor(done, dtype=torch.bool, device=self.device).to(torch.float32)
        action_values, state_value = self.network(state)
        _, next_state_value = self.network(next_state) # returns next action and next state, but we are not interested in the next action
        
        # Bellman equation
        target_state_value = reward + discount_factor * next_state_value * (1 - done)
        
        # Implement Advantage
        advamtage = target_state_value - state_value 
        
        # actor loss calculation 
        probs = F.softmax(action_values, dim=-1)
        logprobs = F.log_softmax(action_values, dim=-1)
        entropy = -torch.sum(probs * logprobs, axis=-1)
        batch_idx = np.arange(batch_size)
        logp_actions = logprobs[batch_idx, action]
        actor_loss = -(logp_actions * advamtage.detach()).mean() - 0.001 * entropy.mean() # the small value is multiplied to balance the importance of the entropy
        
        # critic loss (mse loss between target state value and the state value)
        critic_loss = F.mse_loss(target_state_value.detach(), state_value)

        total_loss = actor_loss + critic_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

Initialize the A3C Agent

In [16]:
agent = Agent(action_size=number_actions)

  from .autonotebook import tqdm as notebook_tqdm
