In [1]:
import gymnasium as gym
import ale_py
import torch
import time

from agents.a2c import A2C
from envs.ale_utils import FrameStack, setup_training_dir, load_checkpoint, save_checkpoint, eval_model, generate_video, save_plots

  from pkg_resources import resource_stream, resource_exists


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) 

cpu


In [3]:
num_envs = 8
n_frame_stack = 4
n_steps = 10
max_timesteps = 10000000
gamma = .99
lr = 2.5e-4
c_actor = 1
c_critic = .25
c_entropy = .01
max_grad_norm = .5
checkpoint_frequency = 10000
video_frequency = 10000
eval_frequency = 5000
n_episodes_eval = 10

In [4]:
envs = gym.make_vec("ALE/Breakout-v5", num_envs=num_envs, vectorization_mode="sync")

In [5]:
resume_training = True
version = "v3"
checkpoint = f"training/a2c/{version}/training3/2350000.pth"
training_number = setup_training_dir(resume_training, "a2c", version)

max_training_time = .1 #h

In [6]:
model = A2C(input_channels=n_frame_stack, n_actions=4, 
            gamma=gamma, max_grad_norm=max_grad_norm, 
            c_actor=c_actor, c_critic=c_critic, c_entropy=c_entropy, device=device)
framestack = FrameStack(num_envs, n_frame_stack, 84, 84, device)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = lr)
ev_states = torch.load("ev_states/breakout_ev_states.pt")

In [7]:
if resume_training:
    training_vars = load_checkpoint(model, optimizer, checkpoint, device)
    timestep_start, losses, avg_returns = training_vars
else:
    timestep_start = 0
    losses = []
    avg_returns = []

In [8]:
obs, infos = envs.reset()
state = framestack.reset(obs)

current_lives = infos['lives'] + 1 # Set the number of lives to play FIRE on first frame
logits_buffer, log_probs_buffer, values_buffer = [], [], []
rewards_buffer, next_values_buffer, dones_buffer = [], [], []

start_time = time.time()
for timestep in range(timestep_start, max_timesteps):

  
    actor_logits, value = model(state)
    m = torch.distributions.Categorical(logits=actor_logits)
    action = m.sample()
    
    log_prob = m.log_prob(action)

    obs, reward, terminated, truncated, infos = envs.step(action)
    next_state = framestack.step(obs)
    
    done = terminated | truncated

    with torch.no_grad():
        _, next_value = model(next_state) 

    logits_buffer.append(actor_logits) # tensor (n_env, n_actions)
    log_probs_buffer.append(log_prob) # tensor (n_env)
    values_buffer.append(value.squeeze(-1)) # tensor (n_env)
    rewards_buffer.append(reward) # np_array (n_env)
    next_values_buffer.append(next_value.squeeze(-1)) # tensor (n_env) Detach from computational graph because only used for bootstraping
    dones_buffer.append(done.astype(float)) # np_array (n_env)

    state = next_state
    
    if (timestep + 1) % n_steps == 0:
        update_losses = model.update(optimizer, logits_buffer, log_probs_buffer, values_buffer, rewards_buffer, next_values_buffer, dones_buffer)
        # losses.append(update_losses)
        logits_buffer, log_probs_buffer, values_buffer = [], [], [] # Clear buffers
        rewards_buffer, next_values_buffer, dones_buffer = [], [], []

    if (timestep + 1) % eval_frequency == 0:
        avg_return = eval_model(model, "ALE/Breakout-v5", n_episodes_eval, device)
        avg_returns.append(avg_return)
        print(f"Average return after {timestep + 1} timesteps : {avg_return}")
        save_plots([], avg_returns, f"training/a2c/{version}/training{training_number}", timestep+1, eval_frequency, plot_losses=False)
    
    if (timestep + 1) % video_frequency == 0:
        generate_video(model, "ALE/Breakout-v5", f"training/a2c/{version}/training{training_number}/{timestep+1}.mp4", device)

    if (timestep + 1) % checkpoint_frequency == 0:
        save_checkpoint(model, optimizer, timestep, losses, avg_returns, f"training/a2c/{version}/training{training_number}/{timestep+1}.pth")

    if time.time() - start_time > 3600 * max_training_time:
        print(f"Maximum training time of {max_training_time}h exceeded. Interrupting training after {timestep} timesteps.")
        break 


Average return after 2350000 timesteps : 90.2
Average return after 2355000 timesteps : 106.3
Average return after 2360000 timesteps : 113.5
Maximum training time of 0.1h exceeded. Interrupting training after 2361282 timesteps.


After 2350000 timesteps : Out of Memory because of losses
-> Stop saving ang logging them

Time per timestep : (with 4 envs and n_step=10)
- Eval (+ plots) (/eval_frequency) : 4s for 10 bad ep
- Backward (*1/n_step) : 1e-1
- Forward (*2) : 2e-3
- Step : 3e-3

In [9]:
assert 1 == 0

AssertionError: 

## Benchmark

Run 1 episode to estimate time of different actions each frame

In [None]:
# step_times = []
# forward_times = []
# backward_times = []


# tot_t_start = time.time()

# done = False
# doFire = True # Start game by using FIRE
# ep_return = 0
# last_frames = deque(maxlen=frame_stack)
# logits, log_probs, values, rewards, next_values, dones = [], [], [], [], [], []

# frame, info = env.reset()
# current_lives = info['lives']
# phi_frame = preprocess_frame(frame)

# # Initially, fill the last_frames buffer with the first frame
# for _ in range(frame_stack):
#     last_frames.append(phi_frame)

# state = get_state(last_frames, device)

# while not done:
    
#     tic = time.time()
#     actor_logits, value = model(state)
#     forward_times.append(time.time() - tic)
#     m = torch.distributions.Categorical(logits=actor_logits)
#     action = m.sample()
#     log_prob = m.log_prob(action)

#     if doFire: # Do FIRE action if just lost a life to launch back the game
#         action = torch.tensor([1])
#         log_prob = m.log_prob(action)
#         doFire = False

#     tic = time.time()
#     frame, reward, done, truncated, info = env.step(action.item())
#     step_times.append(time.time() - tic)

#     if info['lives'] > current_lives: # Do FIRE next frame if just lost a life
#         doFire = True
#         current_lives = info['lives']

#     phi_frame = preprocess_frame(frame)
#     last_frames.append(phi_frame) # Automatically removes the oldest frame
#     next_state = get_state(last_frames)

#     ep_return += reward

#     with torch.no_grad():
#         _, next_value = model(next_state)

#     logits.append(actor_logits.squeeze(0))
#     log_probs.append(log_prob)
#     values.append(value)
#     rewards.append(reward)
#     next_values.append(next_value)
#     dones.append(float(done))

#     state = next_state

#     if len(log_probs) == batch_size or done:
        
#         tic = time.time()
#         update_network(optimizer, logits, log_probs, values, rewards, next_values, dones, gamma, c_actor, c_critic, c_entropy)
#         backward_times.append(time.time() - tic)
#         logits, log_probs, values, rewards, next_values, dones = [], [], [], [], [], [] # Clear buffers

# returns.append(ep_return)

# model.eval()
# with torch.no_grad():
#     _, values = model(torch.cat(ev_states, dim=0)) # Evaluate average value on evaluation states
# avg_values.append(values.mean().item())
# model.train()

# tot_time = time.time() - tot_t_start    

# print(f"Total episode time : {tot_time}")
# print(f"Average time by forward pass : {np.mean(forward_times)}")
# print(f"Average time by backward pass : {np.mean(backward_times)}")
# print(f"Average time by env step : {np.mean(step_times)}")

The actions take approximately on my CPU (seconds) : 

- 1e-1 / optimizer step
- 1e-3 / forward pass
- 5e-4 / env step (emulator)  

Clearly, the backward pass and optimization steps are the most time-consuming training operations, even though they are only performed every 20 steps. Using a GPU would definitely speed up the agent training.

---

In [None]:
assert 1 == 0

AssertionError: 

Colab Code

In [None]:
!pip install gymnasium[atari,accept-rom-license] ale-py torch torchvision imageio

In [None]:
!git clone https://github.com/LucasSchummer/RL_ALE

In [None]:
cd RL_ALE

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
print(torch.cuda.is_available())  # should print True
print(torch.cuda.get_device_name(0))