In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from training import *
from models import *
from A2C_agent import *
from helpers import *

%load_ext autoreload
%autoreload 2

In [None]:
# --- Reproducibility ---------------------------------------------------------
# Seed every RNG that is in scope so a Restart & Run All gives comparable runs.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# --- Optimisation hyperparameters --------------------------------------------
batch_size = 256   # number of transitions gathered before each call to agent.train
gamma_ = 0.99      # discount factor handed to the agent
lr_actor = 1e-5
lr_critic = 1e-3
eps = 0.1          # handed to Agent; presumably the eps-greedy exploration rate -- TODO confirm

# --- Network sizes and run length --------------------------------------------
input_size = 4           # CartPole-v1 observations have 4 components
hidden_size = 64
output_size_actor = 2    # CartPole-v1 has 2 discrete actions
output_size_critic = 1
num_workers = 10         # NOTE: only worker 0 is actually driven by the loop below
num_episodes = 10000
max_steps_per_episode = 1000

device = device_selection() # mps -> cuda -> cpu

# Initialize environment
env = gym.make('CartPole-v1')

# Initialize agent
agent = Agent(
    input_size, hidden_size,
    output_size_actor, output_size_critic,
    eps, gamma_, lr_actor, lr_critic, num_workers,
    device=device,
)

# Transitions accumulate here across episode boundaries until batch_size is reached.
batch = []

for episode in range(num_episodes):
    # Seed the environment once, on the very first reset; later resets continue
    # from the environment's own RNG stream.
    state, _ = env.reset(seed=seed if episode == 0 else None)
    state = torch.from_numpy(state).float().to(device)  # Convert state to a tensor

    episode_reward = 0

    for t in range(max_steps_per_episode):
        action = agent.select_action(state, worker_id=0, policy="eps-greedy")
        next_state, reward, terminated, truncated, _ = env.step(action.item())

        next_state = torch.from_numpy(next_state).float().to(device)  # Convert next_state to a tensor
        # NOTE(review): `done` merges true termination with time-limit truncation.
        # If agent.train uses this flag to switch off bootstrapping, truncated
        # episodes are treated as real terminal states -- confirm this is intended.
        done = terminated or truncated
        episode_reward += reward

        # Add the experience to the batch
        batch.append((state, action, reward, next_state, done))

        if len(batch) >= batch_size:
            # Train the agent on the collected transitions (single worker, id 0)
            worker_losses = agent.train({0: batch}, agent.gamma, agent.lr_actor, agent.lr_critic, agent.device)
            # Keep a *copy* of the batch that was just trained on for later
            # inspection. A plain `xx = batch` would alias the live list, which
            # is emptied on the next line and then refilled by later rollouts.
            xx = list(batch)
            # Clear the batch
            batch.clear()

        state = next_state
        if done:
            break

    if episode % 100 == 0:
        print(f"Episode {episode} finished after {t+1} steps with reward {episode_reward:.2f}")



Episode 0 finished after 38 steps with reward 38.00
Episode 100 finished after 40 steps with reward 40.00
Episode 200 finished after 24 steps with reward 24.00
Episode 300 finished after 44 steps with reward 44.00
Episode 400 finished after 24 steps with reward 24.00
Episode 500 finished after 19 steps with reward 19.00
Episode 600 finished after 25 steps with reward 25.00
Episode 700 finished after 30 steps with reward 30.00
Episode 800 finished after 43 steps with reward 43.00
Episode 900 finished after 19 steps with reward 19.00
Episode 1000 finished after 29 steps with reward 29.00
Episode 1100 finished after 16 steps with reward 16.00
Episode 1200 finished after 31 steps with reward 31.00
Episode 1300 finished after 58 steps with reward 58.00
Episode 1400 finished after 21 steps with reward 21.00
Episode 1500 finished after 19 steps with reward 19.00
Episode 1600 finished after 44 steps with reward 44.00
Episode 1700 finished after 23 steps with reward 23.00
Episode 1800 finished 

## TO IMPLEMENT:

- Fix training for multiple workers (the target and advantage computations are probably wrong).
- Add methods for plotting and monitoring training (cf. 2.3 of the project PDF).
- Add methods for saving the agent.

In [15]:
# Show the action (index 1 of each stored transition tuple) for every entry in `xx`.
[transition[1] for transition in xx]

[tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([1], device='mps:0'),
 tensor([1], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0], device='mps:0'),
 tensor([0