# Part I [Total: 60 points] - Implementing Advantage Actor Critic (A2C) and Solving Simple Environment

In [1]:
import torch
from torch import nn
import gymnasium as gym
from itertools import count
from torch import multiprocessing as mp
import math
import random

In [2]:
mp.set_start_method('fork')

In [3]:
class ActorCritic(nn.Module):

    def __init__(self,
                 n_states,
                 n_actions):
        super().__init__()
        self.shared_fc = nn.Sequential(
            nn.Linear(n_states, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        self.actor = nn.Sequential(
            nn.Linear(64, n_actions),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Linear(64, 1)

    def forward(self, x):
        x = self.shared_fc(x)
        return self.actor(x), self.critic(x)

In [4]:
class A3CAgent():
    def __init__(self,
                 env_name,
                 global_model,
                 num_workers,
                 device,
                 lr,
                 max_episodes,
                 discount_factor):
        self.env_name = env_name
        self.global_model = global_model
        self.num_workers = num_workers
        
        self.lr = lr
        self.discount_factor = discount_factor
        self.max_episodes = max_episodes
        self.device = device

        self.episode_rewards = []
        self.episode_steps = []
        self.episode_losses = []

        print("Finished initializing of A3C agent")
    
    def train(self):
        envs = [gym.make(self.env_name) for _ in range(self.num_workers)]
        worker_models = [ActorCritic(n_states=envs[0].observation_space.shape[0],
                                     n_actions=envs[0].action_space.n)
                                     for _ in range(self.num_workers)]
        
        for i, (env, worker_model) in enumerate(zip(envs, worker_models)):
            #  print("Entered multiprocess loop")
             worker = mp.Process(target=self.train_worker,
                                 args=(env,
                                       self.global_model,
                                       worker_model,
                                       self.lr,
                                       self.max_episodes,
                                       i))
             worker.start()
            #  print("Executed worker")

    def train_worker(self,
                     env,
                     global_model,
                     worker_model,
                     lr,
                     max_episodes,
                     worker_name):
        print(f"Worker {worker_name} initiated")
        ''' Kishorkumar Devasenapathy - 04-15-2024
        # This optimizer is only in the scope of this worker but has a shared copy of the global model's
        # parameters to perform gradient descent and update
        # torch.multiprocessing takes care of synchronizing this step across the workers
        '''
        optimizer = torch.optim.Adam(global_model.parameters(),
                                     lr=lr)
        print(f"Worker {worker_name} initialized optimizer for global model")
        actor_loss_fn = nn.LogSoftmax()
        critic_loss_fn = nn.MSELoss()

        state, _ = env.reset()
        state = torch.tensor(state,
                             dtype=torch.float32,
                             device=self.device).unsqueeze(0)
        print(f"Worker {worker_name} interacted first time with environment")
        episode_reward = 0
        episode_loss = 0

        for _ in range(max_episodes):
            print(f"Worker {worker_name} started an episode")
            worker_model.load_state_dict(global_model.state_dict()) # synchronize worker's model with that of global model

            state, _ = env.reset()
            state = torch.tensor(state,
                                 dtype=torch.float32,
                                 device=self.device)
            print(state.shape)

            for step in count():
                action_probs, current_state_value = worker_model.forward(state)
                print(action_probs.shape)
                action = torch.max(action_probs,
                                   dim=-1).indices
                print(action.shape)

                next_state, reward, terminated, truncated, _ = env.step(action.item())
                next_state = torch.tensor(state,
                                          dtype=torch.float32,
                                          device=self.device)
                _, next_state_value = worker_model.forward(next_state)

                episode_reward += reward

                # reward = torch.tensor([reward],
                #                       dtype=torch.float32,
                #                       device=self.device)
                # print(reward.shape)

                # reward *= 5

                if terminated:
                    target_current_state_value = reward
                else:
                    target_current_state_value = reward + self.discount_factor*next_state_value.item()

                # actor_loss = -torch.log_softmax(action_probs,
                #                                 dim=-1)
                # actor_loss = actor_loss_fn(action_probs)
                # critic_loss = critic_loss_fn(target_current_state_value, current_state_value.item())
                advantage = target_current_state_value - current_state_value

                actor_loss = -torch.log_softmax(action_probs, dim=-1)[action] * advantage
                critic_loss = 0.5 * advantage**2

                episode_loss += (actor_loss + critic_loss)
                
                state = next_state

                if terminated or truncated:
                    print(f"Worker {worker_name} completed an episode")
                    self.episode_rewards.append(episode_reward)
                    self.episode_steps.append(step+1)
                    self.episode_losses.append(episode_loss)
                    break
            
            optimizer.zero_grad()
            episode_loss.backward()
            optimizer.step()
        
        print(f"Worker {worker_name} completed")

           

In [5]:
ENV_NAME        = "CartPole-v1"
NUM_WORKERS     = 2

LR              = 1e-4
MAX_EPISODES    = 5

DISCOUNT_FACTOR = 0.6

if torch.cuda.is_available():
    DEVICE = "cuda"
# elif torch.backends.mps.is_available():
#     DEVICE = "mps"
else:
    DEVICE = "cpu"


In [6]:
env = gym.make(ENV_NAME)
global_model = ActorCritic(n_states=env.observation_space.shape[0],
                           n_actions=env.action_space.n)
global_model.share_memory()

trainer = A3CAgent(env_name=ENV_NAME,
                   global_model=global_model,
                   num_workers=NUM_WORKERS,
                   device=DEVICE,
                   lr=LR,
                   max_episodes=MAX_EPISODES,
                   discount_factor=DISCOUNT_FACTOR)

Finished initializing of A3C agent


In [None]:
trainer.train()

Worker 0 initiated

Worker 1 initiated

Worker 1 initialized optimizer for global modelWorker 0 initialized optimizer for global model

Worker 0 interacted first time with environmentWorker 1 interacted first time with environment

Worker 1 started an episodeWorker 0 started an episode

torch.Size([4])torch.Size([4])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])



  next_state = torch.tensor(state,
  next_state = torch.tensor(state,


torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

torch.Size([2])torch.Size([2])

torch.Size([])torch.Size([])

Worker 0 completed an episodetorch.Size([2])

torch.Size([])
torch.Size([2])
torch.Size([])
Worker 1 completed an episode


objc[39686]: +[MPSGraphObject initialize] may have been in progress in another thread when fork() was called.
objc[39686]: +[MPSGraphObject initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
objc[39687]: +[MPSGraphObject initialize] may have been in progress in another thread when fork() was called.
objc[39687]: +[MPSGraphObject initialize] may have been in progress in another thread when fork() was called. We cannot safely call it or ignore it in the fork() child process. Crashing instead. Set a breakpoint on objc_initializeAfterForkError to debug.
