In [1]:
import sys
# Add the parent directory to the module search path. The project packages
# imported in the next cell (agents, utils, environments, networks) presumably
# live there -- confirm against the repository layout.
sys.path.append('..')

In [2]:
from agents.ddqn import *
from utils.replay import *
from environments.wrappers import *
from networks.flexnet import *
from utils.train import *
from utils.logger import *
from utils.render import *

In [3]:
import gym
import torch

### Network, Environment, Agent

In [4]:
# Layer list for the Q-network: three Conv2d+ReLU stages, a flatten, and a
# two-layer fully connected head.
#
# Every layer is referenced through `torch.nn` (torch is imported explicitly
# above) so this cell does not depend on a bare `nn` name happening to leak in
# from one of the wildcard project imports.
network_params = [
    # 4 input channels -- matches num_stacked_frames = 4 used for the wrapper.
    torch.nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=0),
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    # 3136 = 64 channels * 7 * 7: an 84x84 input shrinks 84 -> 20 -> 9 -> 7
    # through the three unpadded convolutions above.
    torch.nn.Linear(3136, 512),
    torch.nn.ReLU(),
    # 4 outputs -- presumably one Q-value per action; confirm against the
    # environment's action_space.n.
    torch.nn.Linear(512, 4),
]

In [5]:
# parameters
img_size = (84, 84)
num_stacked_frames = 4

raw_env = gym.make('BreakoutNoFrameskip-v4')
env = AtariWrapper(raw_env, k=num_stacked_frames, img_size=img_size)

In [6]:
# Spaces handed to the agent are read from the *unwrapped* environment.
# NOTE(review): training runs on the AtariWrapper-ed `env`, not `raw_env` --
# confirm DQNAgent does not rely on the raw observation shape.
observation_space = raw_env.observation_space
action_space = raw_env.action_space

# Keyword arguments for DQNAgent; the meaning of each key is defined by the
# agent implementation, which is not shown in this notebook.
params = {
    'epsilon': 1.0,
    'epsilon_min': 0.1,
    'epsilon_decay': None,
    'eps_ff': 1000000,
    'eps_interval': 0.9,
    'eps_start': 1.0,
    'gamma': 0.99,
    'alpha': 2.5e-5,
    'network_params': network_params,
    'memory_size': 150000,
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU
    # so the notebook still runs on machines without a CUDA device.
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    'batch_size': 32,
    'target_net_updates': 10000,
}

agent = DQNAgent(observation_space, action_space, **params)

In [7]:
# Display the agent's network so the layer stack can be checked against
# network_params (bare last expression -> rendered as the cell output).
agent.network

DeepmindCNN(
  (network): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=512, bias=True)
    (8): ReLU()
    (9): Linear(in_features=512, out_features=4, bias=True)
  )
)

### Standard Train

In [8]:
# Logger instance that is handed to standard_train through
# training_params['logger']. The meaning of the 'training_info' argument is
# defined by the project's Logger class -- TODO confirm (name vs. output path).
logger = Logger('training_info')

In [9]:
# Directory forwarded to standard_train as `save_dir`.
save_dir = './models/breakout/'

# Keyword arguments for standard_train, one per line for readability.
# e_verbose=50000 matches the 50,000-step spacing of the progress lines in the
# recorded training output.
training_params = dict(
    total_steps=20000000,
    logger=logger,
    save_freq=500000,
    e_verbose=50000,
    file_name='breakout ddqn',
    save_dir=save_dir,
)

In [None]:
# Start training the agent on the wrapped environment with the settings in
# training_params. Progress lines (steps, average reward, memory length,
# optimizer steps, elapsed time, target-network updates) are printed below.
standard_train(agent, env, **training_params)

Steps : 50000, Average Reward: 1.1214285714285714, Memory Length: 50000, Optimizer Steps: 50000, Time Elapsed: 286.32413387298584, Target Q Updates: 5
Steps : 100000, Average Reward: 1.151624548736462, Memory Length: 100000, Optimizer Steps: 100000, Time Elapsed: 362.72354674339294, Target Q Updates: 10
Steps : 150000, Average Reward: 0.8610169491525423, Memory Length: 150000, Optimizer Steps: 150000, Time Elapsed: 456.6999294757843, Target Q Updates: 15


### Watch the trained agent

In [None]:
# Render the agent in the wrapped environment using a saved checkpoint.
# NOTE(review): the checkpoint name here is 'atari ddqn.pth', while training
# above is configured with file_name 'breakout ddqn' in the same directory --
# confirm this path matches what standard_train actually writes.
# The trailing 2 is presumably the number of episodes to render -- TODO confirm.
render_agent(agent, env, './models/breakout/atari ddqn.pth', 2)