In [1]:
import sys, os
# Put the directory two levels above the current working directory on the
# import path. The project-local `networks` and `utils` imports further down
# presumably resolve from there -- TODO confirm against the repository layout.
sys.path.append(os.path.abspath(os.path.join(os.pardir, os.pardir)))

In [2]:
import torch
import torch.optim as optim
import gym
from torch.utils.tensorboard import SummaryWriter
from itertools import count

from networks.dqn_atari import DQN
from utils.memory import StandardReplayMemory
from utils.optimization import standard_optimization
from utils.atari_utils import select_action, get_state, eps_decay, transform_reward

In [3]:
# --- Training configuration -------------------------------------------------
# Upper bound on the number of episodes the training loop iterates over.
n_episodes = 200000
# An optimization step is run every POLICY_UPDATE environment steps.
POLICY_UPDATE = 4
# The target network is re-synchronised with the policy network every
# TARGET_UPDATE environment steps.
TARGET_UPDATE = 4000
# Number of environment steps taken before optimization starts.
INITIAL_MEMORY = 500
# Only referenced by the commented-out evaluation code in the training cell.
REWARD_UPDATE = 100
# Passed to the replay memory constructor and reused as the EPS_DECAY argument
# of eps_decay().
MEMORY_SIZE = 20 * INITIAL_MEMORY
# Learning rate handed to the Adam optimizer.
lr = 0.00025
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Create the Gym environment registered under the id "PongDeterministic-v4".
env = gym.make("PongDeterministic-v4")
# Size of the environment's action space; passed as `n_actions` to the DQN
# constructors and to select_action() in the cells below.
n_actions = env.action_space.n

In [5]:
# Policy network: the one being optimised, together with its Adam optimizer.
policy_net = DQN(n_actions=n_actions).to(device)
optimizer = optim.Adam(policy_net.parameters(), lr=lr)

# Target network: same architecture, initialised as an exact copy of the
# policy network's weights.
target_net = DQN(n_actions=n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Replay memory constructed with MEMORY_SIZE as its single argument.
memory = StandardReplayMemory(MEMORY_SIZE)

In [6]:
# Global count of environment steps taken so far, shared across episodes.
steps_done = 0
# TensorBoard writer, created with its default arguments.
writer = SummaryWriter()

In [7]:
# Main DQN training loop: the outer loop runs once per episode, the inner loop
# once per environment step. Relies on env, policy_net, target_net, optimizer,
# memory, writer and steps_done created in the cells above.
for episode in range(n_episodes):
  obs = env.reset()
  state = get_state(obs)
  total_reward = 0.0
  # Start the frame stack by repeating the first frame three times.
  states_list = [state, state, state]
  for t in count():
    # Pass the decayed value to select_action() instead of a fixed 0.02, so the
    # value logged under 'Other/epsilon' is the one actually driving action
    # selection. NOTE(review): assumes the third positional argument of
    # select_action is the epsilon threshold -- confirm in utils.atari_utils.
    eps_threshold = eps_decay(steps_done, EPS_DECAY=MEMORY_SIZE)
    cur_states = torch.cat(states_list).unsqueeze(0)
    action = select_action(policy_net, cur_states, eps_threshold, n_actions=n_actions)
    steps_done += 1

    obs, reward, done, info = env.step(action)

    # Report the raw environment reward; store the transformed one.
    total_reward += reward
    reward = transform_reward(reward)

    if not done:
      # Slide the window: drop the oldest frame, append the newest.
      states_list = states_list[1:] + [get_state(obs)]
      next_states = torch.cat(states_list).unsqueeze(0)
    else:
      # Terminal transition: no successor state is stored.
      next_states = None

    reward = torch.Tensor([reward])

    memory.push(cur_states, action.to('cpu'), next_states, reward.to('cpu'))

    # Optimise only after the first INITIAL_MEMORY steps, and then only on
    # every POLICY_UPDATE-th step.
    if steps_done > INITIAL_MEMORY and steps_done % POLICY_UPDATE == 0:
      loss = standard_optimization(policy_net, target_net, memory, optimizer)
      writer.add_scalar('Performance/loss', loss, steps_done)

    # Periodically copy the policy weights into the target network.
    if steps_done % TARGET_UPDATE == 0:
      target_net.load_state_dict(policy_net.state_dict())

    if done:
      break

  # Per-episode metrics are written once, after the episode has finished, so
  # that total_reward includes the reward of the terminal step and each
  # episode index gets exactly one point per tag.
  writer.add_scalar('Other/episode', episode, episode)
  writer.add_scalar('Other/epsilon', eps_decay(steps_done, EPS_DECAY=MEMORY_SIZE), episode)
  writer.add_scalar('Performance/reward', total_reward, episode)
    
#   if episode % REWARD_UPDATE:
#     torch.save(policy_net, "models/dqn_expert_breakout_model")
    
#     total_reward = 0.0
#     for _ in range(10):
#       obs = env.reset()
#       state = get_state(obs)
#       for t in count():
#         action = select_action(policy_net, state, 0.02, n_actions=n_actions)
#         obs, reward, done, info = env.step(action)
#         total_reward += reward
      
#         if not done:
#           next_state = get_state(obs)
#         else:
#           next_state = None
    
#         state = next_state
#         if done:
#           break
          
#     total_reward /= 10.0

KeyboardInterrupt: 

In [8]:
# Persist the whole policy network object (not just its state_dict) so that it
# can be restored with a plain torch.load(). The output directory is created
# first so the save does not fail when "models/" is missing.
os.makedirs("models", exist_ok=True)
torch.save(policy_net, "models/dqn_expert_pong_model")

In [None]:
# Restore the policy network saved by the previous cell. map_location remaps
# the stored tensors onto the configured device, so a checkpoint written on a
# GPU machine can also be loaded where `device` is 'cpu'.
# NOTE(security): torch.load unpickles arbitrary Python objects -- only load
# checkpoint files from a trusted source.
policy_net = torch.load("models/dqn_expert_pong_model", map_location=device)