## Double DQN with Prioritized Experience Replay on Atari Pong

In [None]:
from atari_wrappers import make_atari, wrap_deepmind
from PERMemory import Memory
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when several packages each bundle their own OpenMP library --
# confirm it is still needed in the target environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
# Build the Pong environment and apply the DeepMind-style wrappers in a single
# expression; scaling is turned off and frame stacking is turned on.
env = wrap_deepmind(
    make_atari('PongNoFrameskip-v4'),
    scale=False,
    frame_stack=True,
)

### Q-Network

In [None]:
class QNetwork(nn.Module):
    """Convolutional Q-function: a stack of frames in, one value per action out.

    Three convolutions are followed by three fully connected layers. ``fc1``
    expects a 7x7 feature map with 24 channels, which is what the conv stack
    produces for e.g. 84x84 inputs.

    Args:
        in_channel: number of channels of the input tensor.
        hidden_dim: width of the second fully connected layer.
        action_space: number of output units (one Q-value per action).
    """

    def __init__(self, in_channel, hidden_dim, action_space):
        super().__init__()
        # Layer names and creation order are kept so that state_dict keys and
        # seeded initialisation stay the same.
        self.conv1 = nn.Conv2d(in_channel, 12, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(12, 24, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(24, 24, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 24, 256)
        self.fc2 = nn.Linear(256, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_space)

    def forward(self, x):
        """Return Q-values of shape (batch, action_space) for input ``x``."""
        # For an 84x84 input the feature maps are, in (batch, C, H, W) layout:
        # (bs, 12, 20, 20) -> (bs, 24, 9, 9) -> (bs, 24, 7, 7).
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        flat = features.view(features.size(0), -1)  # (bs, 7*7*24)
        hidden = F.relu(self.fc1(flat))             # (bs, 256)
        hidden = F.relu(self.fc2(hidden))           # (bs, hidden_dim)
        return self.fc3(hidden)                     # (bs, action_space)

### Agent

In [None]:
## Agent

class DQNAgent():
    """Double-DQN agent that trains a ``QNetwork`` from a replay memory.

    The behaviour network picks the next action and the target network
    evaluates it (Double DQN). Transitions are stored in ``Memory``, which is
    presumably a prioritized replay buffer: it returns importance-sampling
    weights from ``sample`` and accepts new absolute TD errors through
    ``batch_update``.

    Args:
        env: environment exposing ``reset``, ``step``, ``render``,
            ``action_space`` and ``observation_space``.
        hidden_dim: forwarded to ``QNetwork``.
        capacity: forwarded to ``Memory``.
        batch_size: number of transitions sampled per gradient step.
        in_channel: number of input channels of the Q-networks.
        gamma: discount factor.
        epsilon: probability of taking a random action.
        decay_rate: factor applied to ``epsilon`` after every episode.
        learning_rate: Adam learning rate.
        init: when true, pre-fill the memory with random-policy episodes.
    """

    def __init__(self, env, hidden_dim, capacity, batch_size, in_channel = 4,
                 gamma = 0.9, epsilon = 0.1, decay_rate = 1, learning_rate = 1e-4, init = True):
        self.env = env
        self.action_space = self.env.action_space
        self.obs_space = self.env.observation_space.shape
        self.action_len = int(self.action_space.n)
        self.memory = Memory(capacity=capacity)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.behaviour_QNetwork = QNetwork(in_channel, hidden_dim, self.action_len).to(self.device)
        self.target_QNetwork = QNetwork(in_channel, hidden_dim, self.action_len).to(self.device)
        # Start both networks from the same weights; otherwise the first
        # episode bootstraps from an unrelated, randomly initialised target.
        self.target_QNetwork.load_state_dict(self.behaviour_QNetwork.state_dict())
        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.behaviour_QNetwork.parameters(), lr=learning_rate)

        self.batch_size = batch_size
        self.epsilon = epsilon
        self.decay_rate = decay_rate
        self.gamma = gamma

        if init:
            print("********** buffer memory **********")
            self._warm_up(100)

    def _warm_up(self, episodes):
        """Fill the memory by playing ``episodes`` episodes with random actions."""
        for _episode in tqdm(range(episodes)):
            s0 = self.stateTran(self.env.reset())
            is_end = False
            while not is_end:
                a0 = self.action_space.sample()
                s1, reward, is_end, _info = self.env.step(a0)
                s1 = self.stateTran(s1)
                self.memory.store([s0, a0, reward, s1, is_end])
                s0 = s1

    def stateTran(self, state):
        """Convert an env observation to a channels-first array.

        Assumes ``state`` exposes ``_force()`` returning an (H, W, C) array
        (as a lazy frame-stack object would) -- TODO confirm against
        ``atari_wrappers``.
        """
        return state._force().transpose(2, 0, 1)

    def policy(self, state, epsilon):
        """Pick an action epsilon-greedily.

        Args:
            state: a single channels-first observation.
            epsilon: probability of sampling a random action.

        Returns:
            The chosen action index.
        """
        if np.random.random() < epsilon:
            return self.action_space.sample()
        obs = torch.as_tensor(np.asarray(state), dtype=torch.float32,
                              device=self.device).unsqueeze(0)
        # Pure inference: no autograd graph is needed for action selection.
        with torch.no_grad():
            score = self.behaviour_QNetwork(obs)
        return torch.argmax(score).item()

    def _collate(self, transitions):
        """Turn sampled ``[s, a, r, s_next, done]`` records into batch tensors."""
        def column(i):
            return [t[i] for t in transitions]

        # Stack with NumPy first: building a tensor straight from a list of
        # ndarrays is a known slow path in PyTorch.
        b_s = torch.as_tensor(np.stack(column(0)), dtype=torch.float32, device=self.device)
        b_a = torch.as_tensor(np.asarray(column(1)), dtype=torch.long,
                              device=self.device).reshape(-1, 1)
        b_r = torch.as_tensor(np.asarray(column(2), dtype=np.float32),
                              device=self.device).reshape(-1, 1)
        b_s_next = torch.as_tensor(np.stack(column(3)), dtype=torch.float32, device=self.device)
        b_e = torch.as_tensor(np.asarray(column(4), dtype=np.float32),
                              device=self.device).reshape(-1, 1)
        return b_s, b_a, b_r, b_s_next, b_e

    def _train_step(self):
        """Sample one minibatch, refresh its priorities, take one Adam step.

        Returns:
            The scalar loss of this step as a Python float.
        """
        b_idx, b_memory, ISWeights = self.memory.sample(self.batch_size)
        b_s, b_a, b_r, b_s_next, b_e = self._collate(b_memory)

        # Double DQN target: the behaviour net selects the action, the target
        # net evaluates it. The target is a constant w.r.t. the optimisation,
        # so it is built without autograd; this also avoids accumulating
        # never-used gradients on the target network.
        with torch.no_grad():
            max_a = torch.argmax(self.behaviour_QNetwork(b_s_next), 1).reshape(-1, 1)
            Q_next = self.target_QNetwork(b_s_next).gather(1, max_a)
            Q_target = b_r + self.gamma * Q_next * (1 - b_e)
        Q_behaviour = self.behaviour_QNetwork(b_s).gather(1, b_a)
        td_error = Q_target - Q_behaviour

        # Hand the new absolute TD errors back to the memory.
        abs_err = td_error.detach().abs().cpu().numpy().reshape(-1)
        self.memory.batch_update(b_idx, abs_err)

        # Importance-sampling weighted squared error. The weights are forced
        # into a column so a 1-D weight array cannot broadcast to (bs, bs).
        weights = torch.as_tensor(np.asarray(ISWeights), dtype=torch.float32,
                                  device=self.device).reshape(-1, 1)
        loss = torch.sum(weights * td_error ** 2)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def learn(self, display = False):
        """Play one episode, training after every environment step.

        Args:
            display: when true, render the environment after reset and after
                each step.

        Returns:
            Tuple ``(episode_reward, mean_loss)`` for the episode.
        """
        s0 = self.stateTran(self.env.reset())
        if display:
            self.env.render()
        is_end = False
        episode_reward = 0
        losses = []

        while not is_end:
            # Act, observe, remember.
            a0 = self.policy(s0, epsilon=self.epsilon)
            s1, reward, is_end, _info = self.env.step(a0)
            s1 = self.stateTran(s1)
            if display:
                self.env.render()
            self.memory.store([s0, a0, reward, s1, is_end])

            losses.append(self._train_step())

            s0 = s1
            episode_reward += reward

        # The target network follows the behaviour network once per episode.
        self.target_QNetwork.load_state_dict(self.behaviour_QNetwork.state_dict())
        self.epsilon *= self.decay_rate

        return episode_reward, np.mean(losses)

    def save_model(self):
        """Pickle the behaviour network to ``saved_model/DQN``.

        The path is built with ``os.path.join`` instead of a hard-coded
        backslash, so it is valid on every platform and no longer relies on
        the invalid ``\\D`` escape. The directory is created when missing.
        """
        save_dir = 'saved_model'
        os.makedirs(save_dir, exist_ok=True)
        # NOTE: this pickles the whole module, so loading needs the QNetwork
        # class to be importable; kept as-is for compatibility with loaders.
        torch.save(self.behaviour_QNetwork, os.path.join(save_dir, 'DQN'))
            

In [None]:
MAX_EPISODE = 50
episode_reward, episode_loss = [], []

# Hyper-parameters of this run, gathered in one place.
agent_kwargs = dict(
    env=env, hidden_dim=6, capacity=10000, batch_size=64,
    gamma=0.9, epsilon=0.1, decay_rate=0.99999, learning_rate=1e-3, init=True,
)
agent = DQNAgent(**agent_kwargs)

print("**********begin training***********")

for e in tqdm(range(MAX_EPISODE)):
    reward, loss = agent.learn(display=False)
    episode_reward.append(reward)
    episode_loss.append(loss)
    # Report progress on every tenth episode only.
    if e % 10:
        continue
    print("episode:", e, "reward:", reward, "loss:", round(loss,4))

In [None]:
# Plot the total reward collected in each training episode.
plt.plot(episode_reward)
plt.title("Pong with PER")
plt.xlabel("episode")
plt.ylabel("reward")

In [None]:
# Plot the mean training loss of each episode (as returned by agent.learn).
plt.plot(episode_loss)
plt.title("loss with PER")
plt.xlabel("episode")
plt.ylabel("loss")

In [None]:
# Save the learning curve (per-episode rewards) as curve/DQN.npy.
# The path is built with os.path.join rather than a literal backslash: the
# old 'curve\DQN' relied on the invalid escape '\D' and only denoted a
# sub-directory on Windows. The directory is created if it does not exist.
curve_dir = 'curve'
os.makedirs(curve_dir, exist_ok=True)
dqn_curve = np.array(episode_reward)
np.save(os.path.join(curve_dir, 'DQN'), dqn_curve)