# Part I [Total: 60 points] - Implementing Advantage Actor Critic (A2C) and Solving Simple Environment

In [5]:
import torch
from torch import nn
import gymnasium as gym
from itertools import count

In [None]:
class Actor(nn.Module):

    def __init__(self,
                 n_states,
                 n_actions):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_states, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
            nn.Softmax(dim=-1),
            nn.max(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)

In [None]:
class Critic(nn.Module):

    def __init__(self,
                 n_states,
                 n_actions):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_states, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        return self.fc(x)

In [None]:
env = gym.make("CartPole-v1")
state, _ = env.reset()
del env

global_actor = Actor(n_states=len(state),
                     n_actions=env.action_space.n)
global_critic = Critic(n_states=len(state),
                       n_actions=env.action_space.n)

In [3]:
class Agent():
    def __int__(self,
                env,
                device,
                lr,
                discount_factor):
        self.env = gym.make(env)
        state, _ = self.env.reset()
        self.device = device
        self.discount_factor = discount_factor

        self.actor = Actor(n_states=len(state),
                           n_actions=env.action_space.n).to(device)
        self.critic = Critic(n_states=len(state),
                             n_actions=env.action_space.n).to(device)

        self.actor_loss_fn = nn.SmoothL1Loss()
        self.critic_loss_fn = nn.MSELoss()

        self.actor_optimizer = torch.optim.AdamW(self.actor.parameters(),
                                                 lr=lr,
                                                 amsgrad=True)
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(),
                                                  lr=lr,
                                                  amsgrad=True)

        global global_actor
        global global_critic

    def reset_gradients(self):
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()

    def synchronize_threads(self):
        self.actor.load_state_dict(global_actor.state_dict())
        self.critic.load_state_dict(global_critic.state_dict())

    def run_episode(self, state):
        episode_rewards = []
        # episode_observations = []
        episode_states = []

        for step in count():
                action = self.actor.forward(state)
                observation, reward, terminated, truncated, _ = self.env.step(action.item())

                reward = torch.tensor([[reward]],
                                      dtype=torch.float32,
                                      device=self.device)
                # if terminated:
                #     observation = None
                # else:
                #     observation = torch.tensor(observation,
                #                                dtype=torch.float32,
                #                                device=self.device)

                episode_rewards.append(reward)
                # episode_observations.append(observation)
                episode_states.append(state)

                if terminated:
                    R = torch.tensor(0)
                    break
                elif truncated:
                    R = self.critic.forward(observation)
                    break
                else:
                    state = torch.tensor(observation,
                                         dtype=torch.float32,
                                         device=self.device)
        
        return episode_states, episode_rewards, step, R

    def accumulate_gradients(self, episode_states, episode_rewards, episode_steps, R):
        for t in range(episode_steps, -1, -1):
                R *= self.discount_factor
                R += episode_rewards[t]

                state_value = self.critic.forward(episode_states[t])

                actor_loss = 0
                actor_loss.backward()
                # self.actor_optimizer.zero_grad()
                # self.actor_optimizer.step()

                critic_loss = self.critic_loss_fn(state_value, R)
                critic_loss.backward()
                # self.critic_optimizer.zero_grad()
                # self.critic_optimizer.step()
    

    def train(self, episodes):
        for episode in range(episodes):
            self.reset_gradients()
            self.synchronize_threads()

            state, _ = self.env.reset()
            state = torch.tensor(state,
                                 dtype=torch.float32,
                                 device=self.device)
            
            episode_states, episode_rewards, episode_steps, R = self.run_episode(state)
            self.accumulate_gradients(episode_states, episode_rewards, episode_steps, R)            

IndentationError: expected an indented block (1317544345.py, line 60)

In [7]:
a = torch.tensor(0)
a.item()

0