In [1]:
# %pip install "gymnasium[atari, accept-rom-license, other]"

In [2]:
import gymnasium as gym
from gymnasium.utils.play import play
from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv
from gymnasium.wrappers import AtariPreprocessing, FrameStack
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np

from utils import DQN, ReplayMemory

from hyperparameters import (
    NUM_ENVS,
    BATCH_SIZE,
    LEARNING_RATE,
    TRAIN_STEPS,
    MIN_REPLAY_MEMORY_SIZE,
    MAX_REPLAY_MEMORY_SIZE,
    UPDATE_TARGET_NETWORK,
    MAX_EPSILON,
    MIN_EPSILON,
    EPSILON_DECAY,
    PRINT_LOGS_STEPS,
    DEVICE,
    SAVE_PATH,
    SAVE_INTERVAL
)

In [3]:
def make_env():
    def _init():
        env = gym.make("ALE/Breakout-v5")
        env = AtariPreprocessing(
            env,
            frame_skip=1,  # ALE/Breakout-v5 already uses frame_skip=4
            screen_size=84,
            grayscale_obs=True,
            noop_max=30,
        )
        env = FrameStack(env, 4)
        return env

    return _init


if DEVICE == "cpu":
    envs = SyncVectorEnv([make_env() for _ in range(NUM_ENVS)])
else:
    envs = AsyncVectorEnv([make_env() for _ in range(NUM_ENVS)])


policy_network = DQN((4, 84, 84), 4)
policy_network.to(DEVICE)
target_network = DQN((4, 84, 84), 4)
target_network.to(DEVICE)
target_network.load_state_dict(policy_network.state_dict())
target_network.eval()

replay_memory = ReplayMemory(MAX_REPLAY_MEMORY_SIZE, DEVICE)
optimizer = optim.Adam(policy_network.parameters(), lr=LEARNING_RATE)

summary_writer = SummaryWriter("./logs/atari_vanilla")

In [4]:
# Initialize Replay Buffer with random experiences
states, _ = envs.reset()
for _ in tqdm(range(MIN_REPLAY_MEMORY_SIZE)):
  actions = envs.action_space.sample()
  result = envs.step(actions)
  new_states, rewards, terminated, truncated, _ = result
  
  for state, action, reward, new_state, ter, trunc in zip(states, actions, rewards, new_states, terminated, truncated):
    experience = (state, action, reward, new_state, ter or trunc)
    replay_memory.append(experience)

  states = new_states
    

In [5]:
policy_network.train()
    
reward_logs = [[] for _ in range(NUM_ENVS)]
total_reward = [0 for _ in range(NUM_ENVS)]
loss_logs = []

states, _ = envs.reset()

step = 0
while step < TRAIN_STEPS:
  epsilon = np.interp(step, [0, EPSILON_DECAY], [MAX_EPSILON, MIN_EPSILON])

  # perform action and store experience in replay Memory
  actions = policy_network.act(states, epsilon)

  # perform and observe
  new_states, rewards, terminated, truncated, _ = envs.step(actions)
  
  for i, (state, action, reward, new_state, ter, trunc) in enumerate(zip(states, actions, rewards, new_states, terminated, truncated)):
    experience = (state, action, reward, new_state, ter or trunc)
    replay_memory.append(experience)
    
    total_reward[i] += reward
    if ter or trunc:
      reward_logs[i].append(total_reward[i])
      total_reward[i] = 0
      
  states = new_states

  """ TRAIN DQN """
  batch = replay_memory.sample(BATCH_SIZE)
  loss = policy_network.compute_loss(batch, target_network)
  
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  loss_logs.append(loss.item())

  
  if step % UPDATE_TARGET_NETWORK < NUM_ENVS:
    target_network.load_state_dict(policy_network.state_dict())


  if step % PRINT_LOGS_STEPS < NUM_ENVS and step != 0:
    avg_reward = np.mean([item for sublist in reward_logs for item in sublist[-PRINT_LOGS_STEPS:]])
    avg_reward = avg_reward if not np.isnan(avg_reward) else 0

    avg_loss = np.mean(loss_logs[-PRINT_LOGS_STEPS:])

    print(f"Episode {step}/{TRAIN_STEPS}) reward: {avg_reward:.4f} - loss: {avg_loss:.4f}")

    summary_writer.add_scalar("AVG Reward", avg_reward, global_step=step)
  
  if step % SAVE_INTERVAL < NUM_ENVS and step != 0:
    print("Saving...")
    torch.save(policy_network.state_dict(), SAVE_PATH)
  
  step += 4

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Episode 100/10000) 0.0000
Episode 200/10000) 0.0000
Episode 300/10000) 0.0000
Episode 400/10000) 0.0000
Episode 500/10000) 0.0000
Episode 600/10000) 0.0000
Episode 700/10000) 0.0000
Episode 800/10000) 1.2500
Episode 900/10000) 1.2500
Episode 1000/10000) 1.2500
Episode 1100/10000) 1.2500
Episode 1200/10000) 1.0000
Episode 1300/10000) 0.8333
Episode 1400/10000) 0.8571
Episode 1500/10000) 0.8571
Episode 1600/10000) 1.0000
Episode 1700/10000) 0.8889


KeyboardInterrupt: 