# Imports

In [14]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import time

In [15]:
# %pip install stable-baselines3[extra]

In [16]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.policies import ActorCriticCnnPolicy

# Settings

In [25]:
parameters = {
    "device" : torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    "total_time_steps" : 1000000,
    "checkpoint_freq" : 200000,
    "eval_freq" : 50000,
    "n_steps" : 2048,
    "batch_size" : 64,
    "gae_lambda" : 0.95,
    "ent_coef" : 0.01,
    "gamma" : 0.99,
    "verbose" : 0,
    "clip_range" : 0.2
}

In [26]:
parameters['device']

device(type='cuda')

# Initial Model

In [27]:
env_id = 'AssaultNoFrameskip-v4'
env = make_atari_env(env_id, n_envs=1, seed=0)

In [28]:
# Wrap the environment to stack frames and normalize observations
env = VecFrameStack(env, n_stack=4)

In [29]:
tensorboard_log_dir = "./ppo_assault_tensorboard/"

In [30]:
# Create the PPO model
#model = PPO('CnnPolicy', env, verbose=0, tensorboard_log=tensorboard_log_dir) # Change verbose to 1 for info messages and 2 for debug messages

# Fine tuned model with custom actor-critic policy

In [31]:
class CustomCnnPolicy(ActorCriticCnnPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomCnnPolicy, self).__init__(*args, **kwargs,
            net_arch=[dict(pi=[256, 256], vf=[256, 256])])


In [32]:
learning_rate_schedule = get_linear_fn(start=3e-4, end=1e-6, end_fraction=0.9)
model = PPO(CustomCnnPolicy, env, learning_rate=learning_rate_schedule, verbose=parameters['verbose'], 
            tensorboard_log=tensorboard_log_dir, n_steps=parameters['n_steps'], 
            batch_size=parameters['batch_size'], clip_range=parameters['clip_range'], gae_lambda=parameters['gae_lambda'], 
            ent_coef=parameters['ent_coef'], gamma=parameters['gamma'])



# Training

In [33]:
# Callbacks for evaluation and saving models
#checkpoint_callback = CheckpointCallback(save_freq=parameters['checkpoint_freq'], save_path='./logs/', name_prefix='ppo_assault_2m') # Save checkpoint trained state every 10k time steps. Might need to remove
eval_callback = EvalCallback(env, best_model_save_path='./logs/best_model/assault_1m_steps_tuned',
                             log_path='./logs/results', eval_freq=parameters['eval_freq'])

In [10]:
start_time = time.time()
model.learn(total_timesteps=parameters['total_time_steps'], callback=[eval_callback]) #[checkpoint_callback, eval_callback])

  return F.conv2d(input, weight, bias, self.stride,


Eval num_timesteps=50000, episode_reward=302.40 +/- 43.24
Episode length: 2583.20 +/- 579.39
New best mean reward!
Eval num_timesteps=100000, episode_reward=449.40 +/- 60.28
Episode length: 3048.40 +/- 395.93
New best mean reward!
Eval num_timesteps=150000, episode_reward=390.60 +/- 16.80
Episode length: 3124.00 +/- 359.69
Eval num_timesteps=200000, episode_reward=260.40 +/- 99.92
Episode length: 2900.00 +/- 1006.90
Eval num_timesteps=250000, episode_reward=260.40 +/- 54.11
Episode length: 2565.20 +/- 321.76
Eval num_timesteps=300000, episode_reward=348.60 +/- 148.85
Episode length: 2766.60 +/- 559.65
Eval num_timesteps=350000, episode_reward=411.60 +/- 75.83
Episode length: 3026.20 +/- 453.08
Eval num_timesteps=400000, episode_reward=568.80 +/- 93.18
Episode length: 4570.80 +/- 1031.50
New best mean reward!
Eval num_timesteps=450000, episode_reward=415.60 +/- 197.79
Episode length: 3401.20 +/- 1213.71
Eval num_timesteps=500000, episode_reward=525.00 +/- 75.13
Episode length: 3455.00 +

<stable_baselines3.ppo.ppo.PPO at 0x252141a4a50>

* Time Elapsed 1m steps: 84m
* Time Elapsed 2m steps: 140m

In [12]:
model.save("ppo_assault_1m_tuned")

# Fine tuning

# Loading and evaluating the model

* TODO: Load policy model and run for evaluation