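# Reinforcement Learning with Atari Games

This notebook loads run settings from the project's YAML config, builds a Rainbow DQN agent and a PPO agent on `ALE/SpaceInvaders-v5`, and trains each briefly. It also inspects each agent's environment details and logged metrics. The training runs shown here are short demos, not the full `NUM_EPISODES` budget.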

## 1. Initial Setup

In [1]:
from core.create import create_model, get_utility_params
from utils.helper import set_device
from utils.model_utils import load_model

import torch

In [2]:
# Pull the run-length settings out of the utility section of the YAML config and
# expose them as notebook-level constants (echoed so the run log records them).
util_params = get_utility_params()
NUM_EPISODES, SAVE_EVERY = (util_params[key] for key in ('num_episodes', 'save_every'))
print(f'NUM_EPISODES={NUM_EPISODES}, SAVE_EVERY={SAVE_EVERY}')

NUM_EPISODES=100000, SAVE_EVERY=1000


In [3]:
# Select the torch device via the project helper; its printed message reports the choice
# (GPU 'cuda:0' in this run). Presumably falls back to CPU when CUDA is unavailable —
# TODO confirm in utils.helper.set_device. `device` is passed to both create_model calls.
device = set_device()

CUDA available. Device set to GPU -> 'cuda:0'


## 2. Model Creation and Training

### 2a. Rainbow Deep Q-Network (RDQN)

In [4]:
# Create Rainbow DQN instance on `device` (model type selected by the 'rainbow' key).
# NOTE(review): the truncated `logger.warn(` output is a warning raised during creation,
# presumably from the gym/ALE environment setup — TODO confirm and silence or explain it.
dqn = create_model('rainbow', device)

  logger.warn(


In [5]:
# Environment configuration the Rainbow agent was built with: gym id, observation/action
# spaces, frame size and frame-stack size (shown as this cell's output).
dqn.env_details

{'gym_name': 'ALE/SpaceInvaders-v5', 'name': 'SpaceInvaders', 'obs_space': Box(0, 255, (4, 128, 128), uint8), 'action_space': Discrete(6), 'input_shape': (4, 128, 128), 'n_actions': 6, 'img_size': 128, 'stack_size': 4, 'capture_video': False, 'record_every': 1000}

In [6]:
# Short smoke-test run of the Rainbow agent (4 episodes, not the full NUM_EPISODES budget),
# logging every episode; with save_count=2 the log shows checkpoints at episodes 2 and 4.
rainbow_demo_kwargs = {'num_episodes': 4, 'print_every': 1, 'save_count': 2}
dqn.train(**rainbow_demo_kwargs)

Training agent on SpaceInvaders with 4 episodes.
Buffer size: 1k, batch size: 32, max timesteps: 1k, num network updates: 4, replay period: 100.
(1/4)  Episode Score: 120,   Train Loss: 3.65424,  Time taken: 8.28 secs.
(2/4)  Episode Score: 120,   Train Loss: 3.78992,  Time taken: 9.46 secs.
Saved model at episode 2 as: 'rainbow_batch32_buffer1000_ep2.pt'.
Saved logger data to 'saved_models/rainbow_logger_data.tar.gz'. Total size: 609 bytes
(3/4)  Episode Score: 110,   Train Loss: 3.53139,  Time taken: 9.13 secs.
(4/4)  Episode Score: 100,   Train Loss: 2.42060,  Time taken: 5.96 secs.
Saved model at episode 4 as: 'rainbow_batch32_buffer1000_ep4.pt'.
Saved logger data to 'saved_models/rainbow_logger_data.tar.gz'. Total size: 657 bytes
Training complete. Access metrics from 'logger' attribute. 

In [7]:
# List the metrics recorded during training (the logger's repr names its attributes).
dqn.logger

Available attributes: '['avg_returns', 'actions', 'train_losses', 'ep_scores']'

In [8]:
# Action-selection counts from training: a list of Counter objects mapping
# action index (0..n_actions-1) -> number of times that action was taken.
dqn.logger.actions

[Counter({4: 4720, 0: 928, 5: 2522, 1: 1552, 2: 974, 3: 568})]

### 2b. Proximal Policy Optimization (PPO)

In [9]:
# Create PPO instance
# Create PPO instance on the same `device` (model type selected by the 'ppo' key).
ppo = create_model('ppo', device)

In [10]:
# Environment configuration for the PPO agent; in this run it matches the Rainbow
# agent's env_details output (same game, spaces, frame and stack sizes).
ppo.env_details

{'gym_name': 'ALE/SpaceInvaders-v5', 'name': 'SpaceInvaders', 'obs_space': Box(0, 255, (4, 128, 128), uint8), 'action_space': Discrete(6), 'input_shape': (4, 128, 128), 'n_actions': 6, 'img_size': 128, 'stack_size': 4, 'capture_video': False, 'record_every': 1000}

In [11]:
# Hand cached-but-unused CUDA memory back to the driver before PPO training.
# Only unoccupied cache blocks are released; tensors still referenced (e.g. by the
# live `dqn` agent) keep their memory. Skipped entirely when CUDA is unavailable.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [12]:
# Here a PPO "episode" budget appears to be measured in environment steps:
# rollout_size steps for each of num_agents parallel agents per training iteration.
# Scale the shared NUM_EPISODES budget accordingly (kept for a full training run).
PPO_NUM_EPISODES = ppo.params.rollout_size * ppo.params.num_agents * NUM_EPISODES

# Demo budget: a handful of training iterations. Computed directly from the
# per-iteration step count instead of (PPO_NUM_EPISODES / NUM_EPISODES) * 4, which
# took a needless float round-trip and raised ZeroDivisionError if NUM_EPISODES == 0.
# NOTE(review): PPO.train appears to convert num_episodes back into iterations by
# dividing by rollout_size * num_agents (its log reports 4 iterations for this value).
DEMO_ITERATIONS = 4
demo_episodes = int(ppo.params.rollout_size * ppo.params.num_agents * DEMO_ITERATIONS)

In [13]:
# Short demo run of PPO: `demo_episodes` corresponds to 4 training iterations, logged each one.
# NOTE(review): the log shows a checkpoint after iteration 2 but none after iteration 4
# (and no "Training complete" line), unlike the Rainbow run with the same save_count —
# check the save schedule and completion message in PPO.train.
ppo_demo_kwargs = {'num_episodes': demo_episodes, 'print_every': 1, 'save_count': 2}
ppo.train(**ppo_demo_kwargs)

Training agent on SpaceInvaders with 3K episodes.
Surrogate clipping size: 0.1, rollout size: 100, num agents: 8, num network updates: 4, batch size: 800, training iterations: 4.
(1/4) Episodic Return: 0.54303,  Approx KL: 0.00003,  Total Loss: 0.13992,  Policy Loss: -0.00433,  Value Loss: 0.32427,  Entropy Loss: 1.78845,  Time taken: 1.55 secs.
(2/4) Episodic Return: 0.62296,  Approx KL: 0.69779,  Total Loss: -0.13326,  Policy Loss: -0.23963,  Value Loss: 0.23830,  Entropy Loss: 1.27855,  Time taken: 1.50 secs.
Saved model at episode 2 as: 'ppo_rollout100_agents8_ep2.pt'.
Saved logger data to 'saved_models/ppo_logger_data.tar.gz'. Total size: 719 bytes
(3/4) Episodic Return: 0.50526,  Approx KL: 0.01363,  Total Loss: 0.02596,  Policy Loss: -0.00246,  Value Loss: 0.05697,  Entropy Loss: 0.00628,  Time taken: 1.63 secs.
(4/4) Episodic Return: 0.35854,  Approx KL: -0.00000,  Total Loss: 0.00774,  Policy Loss: 0.00000,  Value Loss: 0.01548,  Entropy Loss: 0.00001,  Time taken: 1.66 secs.


In [14]:
# List the metrics PPO recorded during training (the logger's repr names its attributes;
# PPO tracks per-loss-component histories and approx KL in addition to returns).
ppo.logger

Available attributes: '['actions', 'avg_rewards', 'avg_returns', 'policy_losses', 'value_losses', 'entropy_losses', 'total_losses', 'approx_kl']'

In [15]:
# Action-selection counts from PPO training (list of Counter: action index -> count).
# NOTE(review): action 5 accounts for 7588 of 12800 selections, and the training log's
# entropy loss fell to ~0 by iteration 4 — the policy appears to have collapsed onto
# one action in this short demo; worth checking the entropy coefficient / learning rate.
ppo.logger.actions

[Counter({2: 952, 3: 1148, 0: 960, 4: 1028, 1: 1124, 5: 7588})]