## Imports

In [34]:
%matplotlib inline

import time

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T

from collections import deque


## Initialization

In [35]:
# Probe CUDA availability once and derive both device flags from that result.
cuda_available = torch.cuda.is_available()
use_gpu = cuda_available
use_cpu = not cuda_available
print('Use GPU: {}'.format(use_gpu))

# Any truthy value enables the extra diagnostic prints guarded by `if debug`.
debug = 0

Use GPU: True


## Replay memory

In [36]:

class ReplayMemory():
    """Bounded FIFO store of (state, action, reward, done, next_state) tuples."""

    def __init__(self, capacity):
        """Create an empty memory that keeps at most `capacity` experiences."""
        # A deque with maxlen silently drops the oldest entry once it is full.
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, done, next_state):
        """Append one transition to the memory."""
        self.buffer.append((state, action, reward, done, next_state))

    def sample(self, batch_size):
        """Draw up to `batch_size` distinct experiences uniformly at random.

        When fewer than `batch_size` experiences are stored, all of them are
        returned (in random order).

        Returns:
            Tuple of numpy arrays
            (state_batch, action_batch, reward_batch, done_batch, next_state_batch),
            aligned row by row.
        """
        draw_size = min(batch_size, self.count())
        drawn = random.sample(self.buffer, draw_size)

        states, actions, rewards, dones, next_states = [], [], [], [], []
        for state, action, reward, done, next_state in drawn:
            # Each observation is converted individually before stacking.
            states.append(np.array(state))
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            next_states.append(np.array(next_state))

        return (np.array(states), np.array(actions), np.array(rewards),
                np.array(dones), np.array(next_states))

    def count(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

## Q-Network

In [37]:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        num_actions: Number of outputs (one Q-value per action).
        ch1: Number of input channels.
        ch2: Output channels of the first convolution.
        ch3: Output channels of the second convolution.
        ch4: Output channels of the third convolution.
        input_size: (height, width) of the input images. Defaults to (84, 84),
            which yields the 7 x 7 feature map the original code assumed.

    Raises:
        ValueError: If `input_size` is too small for the convolution stack.
    """

    # (kernel_size, stride) of conv1..conv3; all use padding=0.
    _CONV_SPECS = ((8, 4), (4, 2), (3, 1))

    def __init__(self, num_actions, ch1, ch2, ch3, ch4, input_size=(84, 84)):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(ch1, ch2, kernel_size=8, stride=4, padding=0)
        self.conv2 = nn.Conv2d(ch2, ch3, kernel_size=4, stride=2, padding=0)
        self.conv3 = nn.Conv2d(ch3, ch4, kernel_size=3, stride=1, padding=0)

        # Size the first linear layer from the real conv output instead of a
        # hard-coded 7 * 7 * 64, which only matched ch4 == 64.
        height, width = input_size
        for kernel, stride in self._CONV_SPECS:
            height = self._conv_out(height, kernel, stride)
            width = self._conv_out(width, kernel, stride)
        if height <= 0 or width <= 0:
            raise ValueError(
                'input_size {} is too small for the convolution stack'.format(input_size))

        self.fc1 = nn.Linear(height * width * ch4, 512)
        self.fc2 = nn.Linear(512, num_actions)

    @staticmethod
    def _conv_out(size, kernel, stride):
        """Output length of an unpadded, undilated convolution along one axis."""
        return (size - kernel) // stride + 1

    def forward(self, inputs):
        """Return Q-values of shape (batch, num_actions) for a batch of inputs."""
        out = F.relu(self.conv1(inputs))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))

        # Flatten everything except the batch dimension for the linear layers.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)

        return out

## Agent

In [38]:
import os
from atari_wrappers import wrap_dqn
import datetime

class PongAgent:
    """DQN agent: an online Q-network, a target Q-network and a replay memory.

    Despite the class name, the default environment is 'TennisDeterministic-v0';
    pass `env_name` to train on a different game.

    Args:
        env_name: Gym environment id handed to `gym.make` and then `wrap_dqn`.
        out_dir: Directory where model checkpoints are written (created if missing).
        buffer_capacity: Maximum number of transitions kept in the replay memory.
        gamma: Discount factor used when building Q-targets.
        learning_rate: Learning rate of the RMSprop optimizer.
    """

    def __init__(self, env_name='TennisDeterministic-v0', out_dir='./model_d_v0',
                 buffer_capacity=1000000, gamma=0.99, learning_rate=0.0001):
        self.env = wrap_dqn(gym.make(env_name))
        self.num_actions = self.env.action_space.n

        # ch1 is the number of input channels the network expects
        # (presumably the frame stack produced by wrap_dqn - TODO confirm);
        # ch2..ch4 are the output channels of the three conv layers.
        ch1, ch2, ch3, ch4 = 4, 32, 64, 64
        self.dqn = DQN(self.num_actions, ch1, ch2, ch3, ch4)
        self.target_dqn = DQN(self.num_actions, ch1, ch2, ch3, ch4)

        # Same flag as to_var(), so models and inputs always share a device.
        if use_gpu:
            self.dqn.cuda()
            self.target_dqn.cuda()

        self.buffer = ReplayMemory(buffer_capacity)

        self.gamma = gamma

        self.mse_loss = nn.MSELoss()
        self.optim = optim.RMSprop(self.dqn.parameters(), lr=learning_rate)

        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)

    def to_var(self, x):
        """Wrap tensor `x` in a Variable, moved to the GPU when `use_gpu` is set."""
        x_var = Variable(x)
        if use_gpu:
            return x_var.cuda()
        return x_var

    def _forward(self, network, states):
        """Run `network` on a numpy batch of states converted to a float tensor."""
        return network(self.to_var(torch.from_numpy(states).float()))

    def predict_q_values(self, states):
        """Return the online network's Q-values for a numpy batch of states."""
        return self._forward(self.dqn, states)

    def predict_q_target_values(self, states):
        """Return the target network's Q-values for a numpy batch of states."""
        return self._forward(self.target_dqn, states)

    def select_action(self, state, epsilon):
        """Epsilon-greedy action selection.

        With probability `epsilon` a uniformly random action is returned;
        otherwise the action with the highest online-network Q-value.
        `epsilon=1` never touches the network.
        """
        if np.random.random() < epsilon:
            return np.random.randint(self.num_actions)
        # The network expects a batch dimension, so add one for the single state.
        q_values = self.predict_q_values(np.expand_dims(state, 0))
        return np.argmax(q_values.data.cpu().numpy())

    def update(self, states, targets, actions):
        """Take one optimizer step on the MSE between Q(s, a) and `targets`.

        Args:
            states: numpy batch of states.
            targets: 1-D numpy array of Q-targets, one per state.
            actions: 1-D numpy array of the actions taken, one per state.
        """
        # Add a trailing axis so targets/actions line up with gather()'s output.
        targets = self.to_var(torch.unsqueeze(torch.from_numpy(targets).float(), -1))
        actions = self.to_var(torch.unsqueeze(torch.from_numpy(actions).long(), -1))

        predicted_values = self.predict_q_values(states)
        # Pick, for every row, the Q-value of the action that was actually taken.
        affected_values = torch.gather(predicted_values, 1, actions)
        loss = self.mse_loss(affected_values, targets)

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def get_epsilon(self, total_steps, max_epsilon_steps, epsilon_start, epsilon_final):
        """Return the exploration rate after `total_steps` environment steps.

        Epsilon decreases linearly from `epsilon_start` with slope
        1 / `max_epsilon_steps` and is never allowed below `epsilon_final`.
        """
        decayed = epsilon_start - total_steps / max_epsilon_steps
        return decayed if epsilon_final < decayed else epsilon_final

    def sync_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_dqn.load_state_dict(self.dqn.state_dict())

    def increament(self, var, value):
        """Return `var + value` (kept under its original name for existing callers)."""
        return var + value

    def calculate_q_targets(self, next_states, rewards, dones):
        """Build Q-targets: reward + gamma * max_a Q_target(next_state, a).

        The bootstrap term is zeroed for transitions flagged in `dones`.

        Args:
            next_states: numpy batch of successor states.
            rewards: 1-D numpy array of rewards.
            dones: 1-D numpy array of episode-termination flags.

        Returns:
            1-D numpy array of Q-targets (a new array; `rewards` is not modified).
        """
        target_q = self.predict_q_target_values(next_states).data.cpu().numpy()
        next_max_q_values = target_q.max(axis=1)
        next_max_q_values[dones == 1] = 0  # no future value once the game is over
        return rewards + self.gamma * next_max_q_values

    def _populate_replay_buffer(self, num_transitions):
        """Fill the replay memory with `num_transitions` random-policy transitions."""
        state = self.env.reset()
        for _ in range(num_transitions):
            action = self.select_action(state, 1)  # epsilon=1 forces a random action
            next_state, reward, done, _ = self.env.step(action)
            self.buffer.add(state, action, reward, done, next_state)
            # Continue from the fresh observation after an episode ends; the
            # original discarded reset()'s result and kept the terminal frame.
            state = self.env.reset() if done else next_state

    def _learn_from_replay(self, batch_size):
        """Sample one minibatch and apply a single Q-learning update."""
        s_batch, a_batch, r_batch, d_batch, next_s_batch = self.buffer.sample(batch_size)
        # Targets come from the target network; the update moves the online one.
        q_targets = self.calculate_q_targets(next_s_batch, r_batch, d_batch)
        self.update(s_batch, q_targets, a_batch)

    def _save_model(self, filename):
        """Write the online network's state dict to `filename`."""
        torch.save(self.dqn.state_dict(), filename)

    def train(self, replay_buffer_fill_len, batch_size, episodes, stop_reward,
              max_epsilon_steps, epsilon_start, epsilon_final, sync_target_net_freq, prev_reward):
        """Train the online network with epsilon-greedy DQN.

        Args:
            replay_buffer_fill_len: Random transitions collected before training.
            batch_size: Minibatch size sampled from the replay memory each step.
            episodes: Maximum number of episodes to play.
            stop_reward: Training stops once the running episode reward exceeds this.
            max_epsilon_steps: Denominator of the epsilon decay (see get_epsilon).
            epsilon_start: Initial exploration rate.
            epsilon_final: Lower bound of the exploration rate.
            sync_target_net_freq: Target network is synchronized whenever the
                global step count is a multiple of this value (including step 0).
            prev_reward: Unused; accepted only so existing callers keep working.
        """
        start_time = time.time()
        print('Start training at: ' + time.asctime(time.localtime(start_time)))

        # populate replay memory
        print('Populating replay buffer... \n')
        self._populate_replay_buffer(replay_buffer_fill_len)
        print('replay buffer populated with {} transitions, start training... \n'.format(self.buffer.count()))

        running_episode_reward = 0
        total_steps = 0

        # main loop - iterate over episodes
        for episode in range(1, episodes + 1):
            if debug:
                print("current episode is: ", episode)

            done = False
            state = self.env.reset()

            # reset episode reward and length
            episode_reward = 0
            episode_length = 0

            # play until the episode ends
            while not done:
                # synchronize the target network with the online network at
                # the required frequency
                steps_since_sync = total_steps % sync_target_net_freq
                if debug:
                    print("steps since last target sync: ", steps_since_sync)
                if steps_since_sync == 0:
                    print('synchronizing target network...\n')
                    self.sync_target_network()

                # calculate epsilon and select an epsilon-greedy action
                epsilon = self.get_epsilon(total_steps, max_epsilon_steps, epsilon_start, epsilon_final)
                action = self.select_action(state, epsilon)

                # execute the action in the environment and remember the outcome
                next_state, reward, done, _ = self.env.step(action)
                self.buffer.add(state, action, reward, done, next_state)

                # sample a random minibatch of transitions and learn from it
                self._learn_from_replay(batch_size)

                # set the state for the next action selection and update counters
                state = next_state
                total_steps += 1
                episode_length += 1
                episode_reward += reward

            # exponential moving average of the per-episode reward
            running_episode_reward = running_episode_reward * 0.9 + 0.1 * episode_reward
            stop_reached = running_episode_reward > stop_reward

            if (episode % 10) == 0 or stop_reached:
                print('global step: {}'.format(total_steps))
                print('episode: {}'.format(episode))
                print('running reward: {}'.format(round(running_episode_reward, 2)))
                print('current epsilon: {}'.format(round(epsilon, 2)))
                print('episode_length: {}'.format(episode_length))
                print('episode reward: {}'.format(episode_reward))
                print('\n')

            if (episode % 50) == 0 or stop_reached:
                curr_time = time.time()
                print('current time: ' + time.asctime(time.localtime(curr_time)))
                print('running for: ' + str(datetime.timedelta(seconds=curr_time - start_time)))
                print('saving model after {} episodes...'.format(episode))
                print('\n')
                self._save_model('{}/current_model_{}.pth'.format(self.out_dir, episode))

            if stop_reached:
                print('stop reward reached!')
                print('saving final model...')
                print('\n')
                self._save_model('{}/final_model_.pth'.format(self.out_dir))
                break

        # Report the actual finish time (the original printed start_time here).
        print('Finish training at: ' + time.asctime(time.localtime(time.time())))

In [39]:
agent = PongAgent()

In [40]:
agent.train(replay_buffer_fill_len=100, 
            batch_size=32, 
            episodes=10**5,
            stop_reward=19,
            max_epsilon_steps=10**5,
            epsilon_start=1.0,
            epsilon_final=0.02,
            sync_target_net_freq=10000,
            prev_reward=0)

Start training at: Mon Nov  4 18:16:07 2019
Populating replay buffer... 

replay buffer populated with 100 transitions, start training... 

synchronizing target network...



KeyboardInterrupt: 