In [27]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from training import *
from models import *
from A2C_agent import *
from helpers import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# ---------------------------------------------------------------------------
# Experiment setup: RNG seeding, device, environment, hyperparameters, agent.
# ---------------------------------------------------------------------------

# Seed the global RNGs so that Restart & Run All repeats the same run.
# NOTE(review): this covers torch, NumPy and the environment only. If Agent
# draws random numbers from another source (e.g. Python's `random`), seed that
# too. GPU/MPS kernels may still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = device_selection()  # mps -> cuda -> cpu

# Initialize environment
env = gym.make('CartPole-v1')
# Seeding once is sufficient: later `env.reset()` calls without a seed keep
# drawing from the generator created here.
env.reset(seed=SEED)
env.action_space.seed(SEED)

# hyperparameters
batch_size = 256                # max number of transitions per update
gamma_ = 0.99                   # discount factor
lr_actor = 1e-5
lr_critic = 1e-3
eps = 0.1                       # exploration rate handed to the agent
num_workers = 1
num_episodes = 10000
total_steps_budget = 500000     # training stops once this many env steps are taken
max_steps_per_episode = 1000

# neural network structure
input_size = env.observation_space.shape[0]  # 4
hidden_size = 64
output_size_actor = env.action_space.n  # 2
output_size_critic = 1

# Per-update loss history, filled by the training loop
critic_losses = []
actor_losses = []

# Initialize agent
agent = Agent(
    input_size, hidden_size,
    output_size_actor, output_size_critic,
    eps, gamma_, lr_actor, lr_critic, num_workers,
    device=device,
)

# Initialize batch (list of (state, action, reward, next_state, flag) tuples)
batch = []

# ---------------------------------------------------------------------------
# Training loop (single worker).
#
# Transitions are collected into `batch`; the agent is updated whenever the
# batch reaches `batch_size` or the current episode is over. Training stops
# after `num_episodes` episodes or once `total_steps_budget` environment steps
# have been taken, whichever comes first.
# ---------------------------------------------------------------------------

# Placeholders so the periodic log line is always printable, even if no update
# has been performed yet.
critic_loss = actor_loss = float("nan")

for episode in range(num_episodes):
    state, _ = env.reset()
    state = torch.from_numpy(state).float().to(device)  # Convert state to a tensor

    episode_reward = 0

    for t in range(max_steps_per_episode):
        action = agent.select_action(state, worker_id=0, policy="eps-greedy")
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        agent.num_steps += 1

        next_state = torch.from_numpy(next_state).float().to(device)  # Convert next_state to a tensor
        done = terminated or truncated
        # The rollout also ends when the per-episode step cap is reached, even
        # if the environment itself did not signal the end of the episode.
        episode_over = done or t == max_steps_per_episode - 1
        episode_reward += reward

        # Add the experience to the batch.
        # Only `terminated` is stored as the terminal flag: a time-limit
        # truncation is not a true terminal state, so the value of `next_state`
        # should still be bootstrapped there.
        # NOTE(review): assumes agent.train uses this flag to mask the
        # bootstrap term -- confirm against A2C_agent.
        batch.append((state, action, reward, next_state, terminated))

        # Flush on a full batch or at the end of the rollout, so transitions
        # from one episode never leak into the next episode's batch.
        if len(batch) >= batch_size or episode_over:
            # this line is to make batch compatible with the training.
            # Should be a dict with keys as worker_ids when num_workers>1
            batches_dict = {0: batch}
            # Train the agent
            critic_loss, actor_loss = agent.train(batches_dict, agent.gamma, agent.lr_actor, agent.lr_critic, agent.device)

            critic_losses.append(critic_loss)
            actor_losses.append(actor_loss)
            # Clear the batch
            batch.clear()

        state = next_state
        if done:
            break

    # Periodic progress report
    if episode % 100 == 0:
        print(f"Episode {episode} finished after {t+1} steps with reward {episode_reward:.2f}")
        print(f"Actor loss: {actor_loss:.2f}, Critic loss: {critic_loss:.2f}",f"  Total steps: {agent.num_steps}")
        print("--------------------------------------------------")

    # The budget is checked between episodes, so the final episode always runs
    # to completion and the total may slightly exceed the budget.
    if agent.num_steps >= total_steps_budget:
        print(f"Reached total training budget of {total_steps_budget} steps ----> Stopping training at episode {episode}")
        break

agent.training_done()
agent.save("./A2C_cartpole")


Episode 0 finished after 10 steps with reward 10.00
Actor loss: 0.68, Critic loss: 1.02   Total steps: 10
--------------------------------------------------
Episode 100 finished after 10 steps with reward 10.00
Actor loss: 0.51, Critic loss: 3.48   Total steps: 1051
--------------------------------------------------
Episode 200 finished after 12 steps with reward 12.00
Actor loss: 0.36, Critic loss: 5.55   Total steps: 2019
--------------------------------------------------
Episode 300 finished after 8 steps with reward 8.00
Actor loss: 0.02, Critic loss: 4.99   Total steps: 2986
--------------------------------------------------
Episode 400 finished after 9 steps with reward 9.00
Actor loss: -0.03, Critic loss: 0.30   Total steps: 3968
--------------------------------------------------
Episode 500 finished after 9 steps with reward 9.00
Actor loss: -0.04, Critic loss: 0.08   Total steps: 4940
--------------------------------------------------
Episode 600 finished after 9 steps with re

KeyboardInterrupt: 

All workers share the same parameters and are updated with the same gradient at the same time. (If the updates were not synchronous, the workers would be collecting data under different policies by definition, but A2C is on-policy.) It does not matter that the workers (inevitably) end their episodes at different time steps. What matters is that they all follow the same policy, i.e. the same actor-critic parameters, while accumulating a batch.

In the vanilla case a batch has shape (1 x 1); in general it has shape (num_time_steps x num_workers).