In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import random
import gym

In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Create the CartPole-v0 environment; it is handed to the agent and stepped by the training loop below.
env = gym.make('CartPole-v0')

In [4]:
def plot(frame_idx, rewards):
    """Replace the cell output with a fresh line plot of ``rewards``.

    Args:
        frame_idx: Value printed after "frame" in the figure title.
        rewards: Non-empty sequence of rewards. The whole sequence is drawn
            and its last entry is printed in the title.
    """
    # wait=True postpones the clearing until new output is available.
    clear_output(wait=True)

    plt.figure(figsize=(20, 5))
    plt.subplot(1, 3, 1)  # first panel of a 1x3 grid

    latest_reward = rewards[-1]
    heading = 'frame %s. reward: %s' % (frame_idx, latest_reward)
    plt.title(heading)

    plt.plot(rewards)
    plt.show()

In [35]:
class Memory():
    """Fixed-capacity ring buffer of transitions for experience replay.

    Transitions are kept in pre-allocated NumPy arrays. Once ``max_mem``
    transitions have been stored, new ones overwrite the oldest entries.

    Args:
        max_mem: Maximum number of transitions held at any time.
        batch_size: Number of transitions returned by :meth:`sample`.
        state_dims: Length of a single state vector.
        action_dims: Accepted for interface compatibility; not used here
            (one scalar action is stored per transition).
    """

    def __init__(self,max_mem,batch_size,state_dims,action_dims):

        self.max_mem = max_mem
        self.batch_size = batch_size
        self.mem_ctr = 0              # index of the next slot to write
        self.mem_full_flag = False    # True once the buffer has wrapped at least once

        self.state = np.zeros((self.max_mem,state_dims))
        self.action = np.zeros((self.max_mem))
        self.reward = np.zeros((self.max_mem))
        self.next_state = np.zeros((self.max_mem,state_dims))
        self.done = np.zeros((self.max_mem))


    def store(self,state,action,reward,next_state,done):
        """Write one transition, overwriting the oldest entry when full."""

        # The cursor ran off the end: wrap to the start and remember that
        # every slot now holds valid data.
        if self.mem_ctr == self.max_mem:
            self.mem_full_flag = True
            self.mem_ctr = 0

        self.state[self.mem_ctr] = state
        self.action[self.mem_ctr] = action
        self.reward[self.mem_ctr] = reward
        self.next_state[self.mem_ctr] = next_state
        self.done[self.mem_ctr] = done

        self.mem_ctr += 1

    def sample(self):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(state, action, reward, next_state, done)`` of NumPy
            arrays, each with ``batch_size`` rows.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored
                (sampling is without replacement).
        """
        # Only sample from slots that have actually been written.
        current_mem = len(self)

        batch = np.random.choice(current_mem, self.batch_size, replace = False)
        state_batch = self.state[batch]
        action_batch = self.action[batch]
        reward_batch = self.reward[batch]
        done_batch = self.done[batch]
        next_state_batch = self.next_state[batch]

        return state_batch,action_batch,reward_batch,next_state_batch,done_batch

    def __len__(self):
        """Number of valid transitions currently stored."""
        if self.mem_full_flag:
            return self.max_mem
        else:
            return self.mem_ctr

In [36]:
class DQN(nn.Module):
    """MLP with two hidden ReLU layers that outputs one value per action.

    Besides the layers, the module owns its loss function (MSE), an Adam
    optimizer over its own parameters, and the device it lives on; it moves
    itself to that device at construction time.

    Args:
        l_r: Learning rate handed to Adam.
        input_dims: Size of the input (state) vector.
        n_actions: Number of output values.
    """

    def __init__(self, l_r, input_dims, n_actions):
        super().__init__()

        hidden_units = 128
        # Layers are created in the order in which forward() applies them.
        self.linear1 = nn.Linear(input_dims, hidden_units)
        self.linear2 = nn.Linear(hidden_units, hidden_units)
        self.action_value = nn.Linear(hidden_units, n_actions)

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=l_r)

        # Prefer the GPU when one is visible, otherwise stay on the CPU.
        if torch.cuda.is_available():
            device_name = 'cuda'
        else:
            device_name = 'cpu'
        self.device = torch.device(device_name)
        self.to(self.device)

    def forward(self, x):
        """Return the action-value outputs for input ``x``."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        return self.action_value(hidden)

In [76]:
class DQNAgent():
    """DQN agent with an epsilon-greedy policy, replay memory and target network.

    Args:
        epsilon: Initial probability of taking a random action.
        eps_decay: Amount subtracted from ``epsilon`` by :meth:`epsilon_decay`.
        epsilon_min: Lower bound for ``epsilon``.
        gamma: Discount factor used in the TD target.
        l_r: Learning rate passed to both networks.
        state_dims: Length of a state vector.
        action_dims: Number of discrete actions (network output size).
        max_mem: Capacity of the replay memory.
        batch_size: Number of transitions used per learning step.
        target_update: Target-sync period; stored for the caller, not used
            by the agent itself.
        env: Environment whose ``action_space.sample()`` supplies random actions.
    """

    def __init__(self,epsilon,eps_decay,epsilon_min,gamma,l_r,state_dims,action_dims,max_mem,batch_size,target_update,env):

        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.epsilon_min = epsilon_min
        self.gamma = gamma
        self.env = env

        self.target_update = target_update

        self.batch_size = batch_size

        self.memory = Memory(max_mem,batch_size,state_dims,action_dims)

        self.Q_net = DQN(l_r,state_dims,action_dims)
        self.targetQ_net = DQN(l_r,state_dims,action_dims)
        # Both networks are initialised independently; start the target as
        # an exact copy so early TD targets are not based on unrelated weights.
        self.update_target()

    def store(self,state,action,reward,next_state,terminal):
        """Add one transition to the replay memory."""
        self.memory.store(state,action,reward,next_state,terminal)

    def epsilon_greedy(self,state):
        """Pick a random action with probability ``epsilon``, else the greedy one.

        Args:
            state: A single state (array-like of floats).

        Returns:
            The chosen action: a sample from ``env.action_space`` when
            exploring, otherwise the integer index of the largest Q-value.
        """
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()

        # The tensor is only needed on the greedy branch.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.Q_net.device)
        with torch.no_grad():
            q_val = self.Q_net(state)
        return torch.argmax(q_val).item()

    def epsilon_decay(self):
        """Lower ``epsilon`` by ``eps_decay`` without going below ``epsilon_min``.

        Returns:
            The updated ``epsilon``.
        """
        if self.epsilon > self.epsilon_min:
            # Clamp so the last step cannot overshoot the floor.
            self.epsilon = max(self.epsilon_min, self.epsilon - self.eps_decay)
        return self.epsilon

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.targetQ_net.load_state_dict(self.Q_net.state_dict())

    def improve(self):
        """Run one gradient step on a sampled mini-batch.

        Does nothing while the memory holds fewer than ``batch_size``
        transitions. The regression target equals the network's own output
        for every action except the one taken, which is replaced by
        ``reward + gamma * max_a' Q_target(next_state, a') * (1 - done)``,
        so only the taken action contributes to the MSE loss.
        """
        if len(self.memory) < self.batch_size:
            return

        device = self.Q_net.device
        state, action, reward, next_state, done = self.memory.sample()

        state      = torch.as_tensor(state, dtype=torch.float32, device=device)
        # Actions are used as indices below, so make them integer explicitly.
        action     = torch.as_tensor(action, dtype=torch.long, device=device)
        reward     = torch.as_tensor(reward, dtype=torch.float32, device=device)
        next_state = torch.as_tensor(next_state, dtype=torch.float32, device=device)
        done       = torch.as_tensor(done, dtype=torch.float32, device=device)
        batch_index = torch.arange(state.shape[0], device=device)

        q_val = self.Q_net(state)

        # The target is a constant for this step: build it without autograd so
        # backward() neither walks the target network nor accumulates
        # gradients on its parameters.
        with torch.no_grad():
            next_q_max = self.targetQ_net(next_state).max(dim=1)[0]
            q_target = q_val.detach().clone()
            q_target[batch_index, action] = reward + self.gamma * next_q_max * (1 - done)

        loss = self.Q_net.criterion(q_val, q_target)

        self.Q_net.optimizer.zero_grad()
        loss.backward()
        self.Q_net.optimizer.step()
    

In [77]:
# Build the agent; one hyperparameter per line so each is easy to tune.
agent = DQNAgent(
    epsilon=1,
    eps_decay=0.05,
    epsilon_min=0.01,
    gamma=0.99,
    l_r=0.0003,
    state_dims=4,
    action_dims=2,
    max_mem=100000,
    batch_size=32,
    target_update=10,
    env=env,
)

In [78]:
# Run configuration and bookkeeping for the training loop.
n_games, max_steps = 400, 200  # episodes to play / step cap per episode
frame_idx = 0                  # initialised here; the loop in this notebook does not update it
rewards = list()               # receives one total reward per episode

In [79]:
# Training loop: play `n_games` episodes of at most `max_steps` steps each,
# learning from replayed transitions after every environment step.
for i in range(n_games):
    # Sync the target network once every `target_update` episodes. Keeping
    # this check inside the step loop would re-copy the weights on every
    # step of those episodes and skip the copy on the terminal step.
    if i % agent.target_update == 0:
        agent.update_target()

    state = env.reset()
    episode_reward = 0

    for s in range(max_steps):
        action = agent.epsilon_greedy(state)
        next_state, reward, done, _ = env.step(action)

        agent.store(state, action, reward, next_state, done)
        agent.improve()

        state = next_state
        episode_reward += reward
        if done:
            break

    rewards.append(episode_reward)
    # Exploration is reduced once per episode, not once per step.
    agent.epsilon_decay()
    print(episode_reward)

25.0
8.0
16.0
22.0
18.0
11.0
13.0
9.0
16.0
9.0
11.0
18.0
10.0
10.0
9.0
13.0
11.0
12.0
10.0
9.0
9.0
10.0
9.0
10.0
10.0
9.0
8.0
9.0
9.0
8.0
9.0
10.0
10.0
9.0
9.0
8.0
10.0
9.0
9.0
8.0
10.0
9.0
10.0
10.0
10.0
11.0
10.0
12.0
9.0
9.0
10.0
9.0
10.0
11.0
9.0
10.0
12.0
11.0
13.0
12.0
11.0
9.0
14.0
17.0
11.0
14.0
45.0
70.0
61.0
31.0
36.0
21.0
17.0
14.0
20.0
19.0
12.0
13.0
52.0
88.0
62.0
200.0
57.0
67.0
68.0
178.0
158.0
200.0
119.0
200.0
135.0
181.0
149.0
185.0
176.0
143.0
146.0
177.0
124.0
154.0
200.0
200.0
184.0
200.0
200.0
200.0
200.0
200.0
200.0
200.0
200.0
200.0
179.0
200.0
200.0
200.0
200.0


KeyboardInterrupt: 