In [1]:
import torch as T
import torch.nn as nn
import gym
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from IPython.display import clear_output
import os
import collections
import cv2

## Recurrent dueling Q-network

In [2]:
class RNetwork(nn.Module):
    """Recurrent dueling Q-network: conv trunk -> GRUCell -> value / advantage heads.

    Args:
        lr: learning rate of the RMSprop optimizer exposed as ``self.optimizer``.
        n_actions: number of discrete actions (width of the advantage head ``A``).
        input_dims: observation shape ``(channels, height, width)``. A bare int is
            treated as the channel count of an 84x84 frame, which keeps the older
            call style ``RNetwork(..., env.observation_space.shape[0], ...)`` working.
        chkpt_dir: directory that holds the checkpoint file (created on save).
        name: checkpoint file name inside ``chkpt_dir``.
    """

    # Spatial size assumed when ``input_dims`` only gives a channel count.
    DEFAULT_FRAME_SIZE = (84, 84)

    def __init__(self, lr, n_actions, input_dims, chkpt_dir, name):
        super(RNetwork, self).__init__()

        self.input_shape = self._as_input_shape(input_dims)

        self.conv1 = nn.Conv2d(self.input_shape[0], 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 32, 4, stride=2)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=1)

        fc_input_dims = self.calculate_conv_output_dims(self.input_shape)

        self.gru = nn.GRUCell(fc_input_dims, 64)

        self.fc3 = nn.Linear(64, 32)
        self.V = nn.Linear(32, 1)
        self.A = nn.Linear(32, n_actions)

        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)

        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.double()
        # Created last so it is clearly bound to the final (device, dtype) parameters.
        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)

    @classmethod
    def _as_input_shape(cls, input_dims):
        """Normalize ``input_dims`` (int channel count or 3-sequence) to a ``(C, H, W)`` tuple."""
        if isinstance(input_dims, (int, np.integer)):
            return (int(input_dims),) + tuple(cls.DEFAULT_FRAME_SIZE)
        shape = tuple(int(d) for d in input_dims)
        if len(shape) != 3:
            raise ValueError(
                'input_dims must be (channels, height, width) or a channel count, '
                'got {!r}'.format(input_dims))
        return shape

    def forward(self, state, h=None):
        """Run one recurrent step.

        Args:
            state: batch of observations; the convs and ``view(batch, -1)`` imply
                ``(batch, C, H, W)`` in this module's dtype/device.
            h: previous GRU hidden state, or ``None`` (GRUCell then starts from zeros).

        Returns:
            ``(V, A, h)``: state value ``(batch, 1)``, advantages ``(batch, n_actions)``
            and the new hidden state ``(batch, 64)``.
        """
        conv1 = F.relu(self.conv1(state))
        conv2 = F.relu(self.conv2(conv1))
        conv3 = F.relu(self.conv3(conv2))

        conv_state = conv3.view(conv3.size()[0], -1)

        # NOTE(review): ReLU is applied to the GRU output and that rectified value is
        # what gets fed back as the next hidden state - kept as-is, confirm intent.
        h = F.relu(self.gru(conv_state, h))
        x = F.relu(self.fc3(h))

        V = self.V(x)
        A = self.A(x)

        return V, A, h

    def calculate_conv_output_dims(self, input_dims):
        """Return the flattened size of the conv trunk's output for one ``input_dims`` frame."""
        shape = self._as_input_shape(input_dims)
        weight = self.conv1.weight
        # Shape probe only: no autograd graph, and match the conv weights' dtype/device
        # so this also works after the module was converted with .double()/.to().
        with T.no_grad():
            probe = T.zeros(1, *shape, dtype=weight.dtype, device=weight.device)
            dims = self.conv3(self.conv2(self.conv1(probe)))
        return int(np.prod(dims.size()))

    def save_checkpoint(self):
        """Write ``state_dict()`` to ``self.checkpoint_file``, creating its directory if needed."""
        print('... saving checkpoint ...')
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load ``self.checkpoint_file``, remapping tensors onto this network's device."""
        print('... loading checkpoint ...')
        # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
    
# Example: model = RNetwork(0.01, env.action_space.n, env.observation_space.shape[0], chkpt_dir='tmp', name='RDDQN')

## Agent

In [3]:
class Agent():
    """Recurrent dueling DQN agent that trains on one whole episode at a time.

    Transitions are buffered per episode. ``state_memory`` must hold exactly one
    more entry than the action/reward/done buffers: the training loop appends the
    initial observation before the first ``store_transition``, so
    ``state_memory[t]`` and ``state_memory[t + 1]`` are the states before and
    after ``actions_memory[t]``.

    Args:
        gamma: discount factor for the bootstrapped target.
        epsilon: initial exploration probability for ``choose_action``.
        lr: learning rate passed to both networks.
        n_actions: number of discrete actions.
        input_dims: observation shape, forwarded to ``RNetwork``.
        eps_min: lower bound for ``epsilon``.
        eps_dec: amount subtracted from ``epsilon`` after every ``learn`` call.
    """
    def __init__(self, gamma, epsilon, lr, n_actions, input_dims,
                 eps_min=0.5, eps_dec=5e-7):
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        # Networks are built from the constructor arguments (not a global env), and
        # each gets its own checkpoint file so saving q_next cannot clobber q_eval.
        self.q_eval = RNetwork(lr,
                               n_actions,
                               input_dims,
                               chkpt_dir='tmp',
                               name='RDDQN_q_eval_' + str(self.lr))
        self.q_next = RNetwork(lr,
                               n_actions,
                               input_dims,
                               chkpt_dir='tmp',
                               name='RDDQN_q_next_' + str(self.lr))

        # q_next is overwritten with q_eval's weights every this many learn() calls.
        self.replace_target_cnt = 50
        self.learn_step_counter = 0

        self.state_memory = []
        self.actions_memory = []
        self.done_memory = []
        self.reward_memory = []

    def clear_memory(self):
        """Drop all buffered transitions of the current episode."""
        self.state_memory = []
        self.actions_memory = []
        self.done_memory = []
        self.reward_memory = []

    def store_transition(self, state, action, reward, done):
        """Buffer one step; ``state`` is the observation reached after taking ``action``."""
        self.state_memory.append(state)
        self.actions_memory.append(action)
        self.reward_memory.append(reward)
        self.done_memory.append(done)

    @staticmethod
    def _to_batch(observation, device):
        """Convert one observation into a double tensor with a leading batch axis of 1."""
        return T.tensor(np.asarray(observation), dtype=T.double, device=device).unsqueeze(0)

    def choose_action(self, observation, h):
        """Epsilon-greedy action from q_eval's advantages; returns ``(action, new_h)``."""
        # Acting needs no gradients; without no_grad the hidden state would chain an
        # autograd graph through every step of the episode.
        with T.no_grad():
            state = self._to_batch(observation, self.q_eval.device)
            _, advantage, h = self.q_eval(state, h)

        if np.random.random() > self.epsilon:
            action = T.argmax(advantage).item()
        else:
            action = np.random.randint(self.n_actions)

        return action, h

    def _predict_sequence(self, network, states):
        """Unroll ``network`` over ``states`` in order, threading the recurrent state.

        Assumes ``network(state, h)`` returns ``(V, A, h)`` for a batch of one.
        Returns ``(advantages, values)`` stacked along a new leading time axis.
        """
        h = None
        adv = []
        val = []
        for observation in states:
            v, a, h = network(self._to_batch(observation, network.device), h)
            adv.append(a)
            val.append(v)
        return T.stack(adv).squeeze(1), T.stack(val).squeeze(1)

    def butch_predict(self):
        """q_eval outputs for the pre-action states; returns ``(advantages, values)``."""
        return self._predict_sequence(self.q_eval, self.state_memory[:-1])

    def butch_predict_next(self):
        """q_next outputs for the post-action states; returns ``(advantages, values)``."""
        return self._predict_sequence(self.q_next, self.state_memory[1:])

    def replace_target_network(self):
        """Copy q_eval's weights into q_next every ``replace_target_cnt`` learn steps."""
        if self.learn_step_counter % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def decrement_epsilon(self):
        """Lower epsilon by ``eps_dec`` without ever going below ``eps_min``."""
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def save_models(self):
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()

    def learn(self):
        """Take one optimizer step on the buffered episode, then clear the buffers.

        Raises:
            ValueError: if ``state_memory`` is not exactly one entry longer than the
                action buffer (the two would be silently misaligned otherwise).
        """
        if not self.actions_memory:
            # Nothing to learn from; reset so a stray initial state can't misalign
            # the next episode's buffers.
            self.clear_memory()
            return
        if len(self.state_memory) != len(self.actions_memory) + 1:
            raise ValueError(
                'state_memory must hold one more entry than actions_memory '
                '(got {} states for {} actions)'.format(
                    len(self.state_memory), len(self.actions_memory)))

        device = self.q_eval.device
        actions = T.tensor(self.actions_memory, dtype=T.long, device=device)
        dones = T.tensor(self.done_memory, dtype=T.bool, device=device)
        rewards = T.tensor(self.reward_memory, dtype=T.double, device=device)

        self.replace_target_network()

        self.q_eval.optimizer.zero_grad()

        indices = T.arange(len(self.actions_memory), device=device)

        # butch_predict* return (advantages, values): unpack in that order so the
        # dueling combination Q = V + (A - mean(A)) uses the right heads.
        A_s, V_s = self.butch_predict()
        q_pred = (V_s + (A_s - A_s.mean(dim=1, keepdim=True)))[indices, actions]

        # The bootstrapped target must not carry gradients into the target network.
        with T.no_grad():
            A_s_, V_s_ = self.butch_predict_next()
            q_next = (V_s_ + (A_s_ - A_s_.mean(dim=1, keepdim=True))).max(dim=1)[0]
            q_next[dones] = 0.0
            q_target = rewards + self.gamma * q_next

        loss = self.q_eval.loss(q_pred, q_target)
        loss.backward()
        self.q_eval.optimizer.step()

        self.decrement_epsilon()
        self.clear_memory()
        self.learn_step_counter += 1

## Environment wrappers and initialization

In [4]:
class RepeatActionAndMaxFrame(gym.Wrapper):
    """Repeat each action ``repeat`` times and return the pixel-wise max of the last two frames.

    Args:
        env: environment to wrap.
        repeat: number of times each action is replayed per ``step``.
        clip_reward: if True, each raw reward is clipped to [-1, 1] before summing.
        no_ops: upper bound on random no-op (action 0) steps after ``reset``.
        fire_first: if True, press FIRE (action 1) once after ``reset``.
    """
    def __init__(self, env=None, repeat=4, clip_reward=False, no_ops=0,
                 fire_first=False):
        super(RepeatActionAndMaxFrame, self).__init__(env)
        self.repeat = repeat
        self.shape = env.observation_space.low.shape
        self.frame_buffer = self._empty_frame_buffer()
        self.clip_reward = clip_reward
        self.no_ops = no_ops
        self.fire_first = fire_first

    def _empty_frame_buffer(self):
        """Zeroed ``(2, *frame_shape)`` buffer holding the two most recent raw frames."""
        # np.zeros_like((2, shape)) would build a ragged object array (or raise on
        # newer NumPy) instead of an actual image buffer.
        return np.zeros((2,) + tuple(self.shape), dtype=self.env.observation_space.dtype)

    def step(self, action):
        """Replay ``action`` up to ``repeat`` times; returns ``(max_frame, total_reward, done, info)``."""
        t_reward = 0.0
        done = False
        info = {}  # defined even if repeat == 0
        for i in range(self.repeat):
            obs, reward, done, info = self.env.step(action)
            if self.clip_reward:
                reward = np.clip(np.array([reward]), -1, 1)[0]
            t_reward += reward
            # Alternate slots so the buffer always holds the last two frames.
            self.frame_buffer[i % 2] = obs
            if done:
                break

        max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
        return max_frame, t_reward, done, info

    def reset(self):
        """Reset, optionally run random no-ops / FIRE, and return the resulting observation."""
        obs = self.env.reset()
        no_ops = np.random.randint(self.no_ops) + 1 if self.no_ops > 0 else 0
        for _ in range(no_ops):
            # Track the observation so the returned frame matches the env's real state.
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        if self.fire_first:
            assert self.env.unwrapped.get_action_meanings()[1] == 'FIRE'
            obs, _, _, _ = self.env.step(1)

        self.frame_buffer = self._empty_frame_buffer()
        self.frame_buffer[0] = obs

        return obs

In [5]:
class PreprocessFrame(gym.ObservationWrapper):
    """Convert RGB frames to grayscale, resize them, and scale pixel values to [0, 1].

    Args:
        shape: target frame as ``(rows, cols, channels)``; stored channel-first as
            ``self.shape = (channels, rows, cols)``.
        env: environment to wrap.
    """
    def __init__(self, shape, env=None):
        super(PreprocessFrame, self).__init__(env)
        self.shape = (shape[2], shape[0], shape[1])
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=self.shape, dtype=np.float32)

    def observation(self, obs):
        """Return the preprocessed frame with shape ``self.shape``."""
        new_frame = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows); passing
        # (rows, cols) would transpose the target size for non-square shapes.
        _, rows, cols = self.shape
        resized_screen = cv2.resize(new_frame, (cols, rows),
                                    interpolation=cv2.INTER_AREA)
        new_obs = np.array(resized_screen, dtype=np.uint8).reshape(self.shape)
        new_obs = new_obs / 255.0

        return new_obs

In [6]:
class StackFrames(gym.ObservationWrapper):
    """Keep the ``repeat`` most recent observations and emit them joined along axis 0.

    The observation space repeats the wrapped space's bounds ``repeat`` times
    along the first axis.
    """
    def __init__(self, env, repeat):
        super(StackFrames, self).__init__(env)
        inner_space = env.observation_space
        stacked_low = inner_space.low.repeat(repeat, axis=0)
        stacked_high = inner_space.high.repeat(repeat, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high,
                                                dtype=np.float32)
        self.stack = collections.deque(maxlen=repeat)

    def _stacked(self):
        """Current deque contents as one array shaped like the stacked space."""
        return np.array(self.stack).reshape(self.observation_space.low.shape)

    def reset(self):
        """Reset the env and fill the whole stack with copies of the first frame."""
        self.stack.clear()
        first_frame = self.env.reset()
        self.stack.extend([first_frame] * self.stack.maxlen)
        return self._stacked()

    def observation(self, observation):
        """Push the newest frame (evicting the oldest) and return the stack."""
        self.stack.append(observation)
        return self._stacked()

In [7]:
def make_env(env_name, shape=(84, 84, 1), repeat=4, clip_rewards=False,
             no_ops=0, fire_first=False):
    """Build the wrapper pipeline: action repeat / frame max -> preprocessing -> frame stack.

    Args:
        env_name: id passed to ``gym.make``.
        shape: target frame shape for ``PreprocessFrame``.
        repeat: used both as the action-repeat count and the number of stacked frames.
        clip_rewards, no_ops, fire_first: forwarded to ``RepeatActionAndMaxFrame``.
    """
    raw_env = gym.make(env_name)
    skipped = RepeatActionAndMaxFrame(raw_env, repeat, clip_rewards, no_ops, fire_first)
    preprocessed = PreprocessFrame(shape, skipped)
    return StackFrames(preprocessed, repeat)

In [8]:
# Pong with 4-step action repeat and 4 stacked frames; raw (unclipped) rewards.
env = make_env(env_name='PongNoFrameskip-v4', repeat=4,
               clip_rewards=False, no_ops=0, fire_first=False)


A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]


In [9]:
# epsilon drops by eps_dec on every agent.learn() call until it hits eps_min.
agent = Agent(gamma=0.97,
              epsilon=0.7,
              lr=0.005,
              n_actions=env.action_space.n,
              input_dims=env.observation_space.shape,
              eps_min=0.05,
              eps_dec=5e-3)

## Smoke test

In [10]:
# state = env.observation_space.sample()
# state = T.tensor([state], dtype = T.double).to('cuda')

In [11]:
# agent.q_eval(state)

## Training loop

In [None]:
best_score = -np.inf
scores, eps_history = [], []

num_games = 2000

for game in range(num_games):
    # The initial observation goes straight into the agent's state buffer so it
    # stays one entry longer than the action/reward/done buffers.
    observation = env.reset()
    agent.state_memory.append(observation)
    h = None  # fresh recurrent state for every episode
    score = 0
    done = False

    while not done:
        action, h = agent.choose_action(observation, h)
        next_observation, reward, done, info = env.step(action)
        agent.store_transition(state=next_observation, action=action,
                               reward=reward, done=done)
        score += reward
        observation = next_observation

    # One learning step per completed episode.
    agent.learn()

    scores.append(score)
    eps_history.append(agent.epsilon)

    # Checkpoint whenever the mean of the last (up to) 100 scores improves.
    avg_score = np.mean(scores[-100:])
    if avg_score > best_score:
        agent.save_models()
        best_score = avg_score

    if game % 10 == 0:
        clear_output()
    print(game, score, avg_score, agent.epsilon)

250 -21.0 -20.96 0.05
251 -21.0 -20.96 0.05
252 -21.0 -20.97 0.05
253 -21.0 -20.97 0.05
254 -21.0 -20.97 0.05
255 -21.0 -20.97 0.05
256 -21.0 -20.97 0.05


In [None]:
# states  = T.tensor(agent.state_memory, dtype = T.double).to(agent.q_eval.device)

In [None]:
# adv = []
# val = []

# h = None
# for i in agent.state_memory:
#     state = T.tensor([i], dtype = T.double).to(agent.q_eval.device)
#     # print(state)
#     v,a,h = agent.q_eval(state, h)
#     adv.append(a)
#     val.append(v)
# T.stack(adv).squeeze(1),T.stack(val).squeeze(1)

## Plot

In [None]:
import matplotlib.pyplot as plt

def plot_learning_curve(x, scores, epsilons, filename, lines=None, window=20):
    """Plot epsilon (left axis) and a running-average score (right axis), then save.

    Args:
        x: x-axis values, one per entry of ``scores`` / ``epsilons``.
        scores: per-episode scores.
        epsilons: epsilon recorded after each episode.
        filename: output path passed to ``Figure.savefig``.
        lines: optional x positions at which to draw vertical lines.
        window: point ``t`` averages ``scores[max(0, t - window) : t + 1]``; the
            default of 20 reproduces the previous fixed window.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, label="1")
    ax2 = fig.add_subplot(111, label="2", frame_on=False)

    ax.plot(x, epsilons, color="C0")
    ax.set_xlabel("Num games", color="C0")
    ax.set_ylabel("Epsilon", color="C0")
    ax.tick_params(axis='x', colors="C0")
    ax.tick_params(axis='y', colors="C0")

    N = len(scores)
    running_avg = np.empty(N)
    for t in range(N):
        running_avg[t] = np.mean(scores[max(0, t - window):(t + 1)])

    ax2.scatter(x, running_avg, color="C1")
    ax2.axes.get_xaxis().set_visible(False)
    ax2.yaxis.tick_right()
    ax2.set_ylabel('Score', color="C1")
    ax2.yaxis.set_label_position('right')
    ax2.tick_params(axis='y', colors="C1")

    if lines is not None:
        # Draw on this figure's top axes explicitly rather than whatever pyplot
        # considers "current".
        for line in lines:
            ax2.axvline(x=line)

    fig.savefig(filename)

In [None]:
# Episodes are numbered starting at 1 on the x axis.
x = list(range(1, len(scores) + 1))
plot_learning_curve(x, scores, eps_history, filename='RDDQN.png')