In [5]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gym
from gym import wrappers
env = gym.make('CartPole-v0') # Create environment


In [6]:
class DQN(nn.Module): 
    def __init__(self, input_dim, hidden_dim, output_dim): 
        super(DQN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), 
            nn.ReLU(), 
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU(), 
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )        
        
    def forward(self, data): 
        x = self.layers(data)
        return F.softmax(x, dim = 0)

in_dim = env.observation_space.shape[0]
action = env.action_space.n    

model_base = DQN(in_dim, 256, action)
model_target = DQN(in_dim, 256, action)

In [7]:
import random
class Buffer: 
    
    mem = []
    
    def __init__(self, length): 
        self.length = length
        
    def add(self, observation):
        self.mem.append(observation)
        if len(self.mem) > self.length:
            self.mem = self.mem[-self.length:]
    
    def __len__(self): 
        return len(self.mem)

    def __iter__(self): 
        for k in self.mem: 
            yield k
            
    def sample(self, batch_size = 32):
        return random.sample(self.mem, batch_size)

In [8]:
class Agent: 
    
    def __init__(self, model, epsilon, e_decay, optimizer, epsilon_bound = 0.01, decay_frequency = 0.5):
        self.model_base = model
        self.model_target = model
        self.model_target.load_state_dict(model_base.state_dict())
        self.epsilon = epsilon
        self.e_decay = e_decay
        self.epsilon_bound = epsilon_bound
        self.optimizer = optimizer
        self.decay_frequency = decay_frequency

    def act(self, state): 
        tmp = random.uniform(0,1)
        if tmp < self.epsilon: 
            a = env.action_space.sample()
        else:
            logs = self.model_base(torch.tensor(state, dtype=torch.float))
            a = torch.multinomial(logs, 1).item()
        tmp = random.uniform(0,1)
        if tmp < self.decay_frequency:
            self.decay()
        return a
    
    def update_network(self):
        self.model_target.load_state_dict(model_base.state_dict())
        
    def loss(self):
        
        
    def train(self, batch):
        optimizer.zero_grad()
        states = batch[:,0]
        rewards = batch[:, 1]
        actions = batch[:, 2]
        next_states = batch[:, 3]
        
        action_indices = self.model_base(states).gather(1,actions)
        
        
        
    def decay(self): 
        tmp_epsilon = self.epsilon * self.e_decay
        if self.epsilon_bound < tmp_epsilon: 
            self.epsilon = tmp_epsilon
    

In [9]:
epsilon = 0.05
optimizer = torch.optim.Adam(model_base.parameters(), lr = 0.005)

agent = Agent(model_base, 1, 0.99, optimizer, epsilon)

In [16]:
def train(optimizer): 
    b = Buffer(1000)
    batch_size = 32
    s = env.reset()
    
    episodes = 100
    max_length_episode = 1000
    
    for i in range(episodes): 
               
        #if i % 100 == 0:
        for j in range(max_length_episode): 
            
            a = agent.act(s)
            if(len(b) >= batch_size):
                agent.train(b.sample(batch_size))
                #print(b.sample(batch_size))
            
            
            s1, r, done, _ = env.step(a)
            b.add((s, a, r, s1))
            s = s1
            if done:
                s = env.reset()
                break
                
train(optimizer)



[(array([ 0.2888586 ,  1.77270442, -0.57158922, -3.48702391]), 1, 0.0, array([ 0.32431269,  1.96384117, -0.6413297 , -3.88720003])), (array([ 0.72391678,  2.95803   , -1.47430715, -6.68444369]), 0, 0.0, array([ 0.78307738,  2.73692249, -1.60799603, -6.94512407])), (array([ 0.61470539,  2.65202198, -1.22750198, -5.99219294]), 1, 0.0, array([ 0.66774583,  2.80854735, -1.34734583, -6.34806587])), (array([ 0.24956608,  1.96462592, -0.49961189, -3.59886693]), 0, 0.0, array([ 0.2888586 ,  1.77270442, -0.57158922, -3.48702391])), (array([-0.02158224,  0.58616519, -0.04115698, -0.8898227 ]), 0, 1.0, array([-0.00985894,  0.39162514, -0.05895343, -0.61035647])), (array([ 0.06475541,  1.17897822, -0.18035946, -1.95723115]), 1, 1.0, array([ 0.08833498,  1.37549794, -0.21950409, -2.29996786])), (array([-0.03329374, -0.00026287, -0.02387952,  0.01213699]), 1, 1.0, array([-0.03329899,  0.19519326, -0.02363678, -0.28798353])), (array([ 0.02938601,  0.78601366, -0.12203111, -1.29407828]), 1, 1.0, array

[(array([ 2.01455056,  2.48500374, -5.34248325, -4.87073652]), 1, 0.0, array([ 2.06425064,  2.68256215, -5.43989798, -4.80780108])), (array([ 0.21415037,  1.77078571, -0.43538719, -3.2112347 ]), 1, 0.0, array([ 0.24956608,  1.96462592, -0.49961189, -3.59886693])), (array([-0.02158224,  0.58616519, -0.04115698, -0.8898227 ]), 0, 1.0, array([-0.00985894,  0.39162514, -0.05895343, -0.61035647])), (array([ 1.17761416,  1.956691  , -2.89370064, -9.48199939]), 1, 0.0, array([ 1.21674798,  2.12613334, -3.08334062, -9.30774132])), (array([ 1.08537594,  2.42440975, -2.5410914 , -8.58557112]), 0, 0.0, array([ 1.13386414,  2.18750105, -2.71280282, -9.04489079])), (array([-0.03329374, -0.00026287, -0.02387952,  0.01213699]), 1, 1.0, array([-0.03329899,  0.19519326, -0.02363678, -0.28798353])), (array([ 0.45336402,  2.51660788, -0.89868662, -5.0903259 ]), 1, 0.0, array([ 0.50369617,  2.69111717, -1.00049314, -5.48336802])), (array([ 0.78307738,  2.73692249, -1.60799603, -6.94512407]), 1, 0.0, array

[(array([ 1.03956366,  2.29061403, -2.37059862, -8.52463878]), 1, 0.0, array([ 1.08537594,  2.42440975, -2.5410914 , -8.58557112])), (array([ 0.83781583,  2.87443755, -1.74689851, -7.23124916]), 0, 0.0, array([ 0.89530458,  2.64302805, -1.89152349, -7.58151429])), (array([ 3.30096618,  2.49036144, -7.30110852, -4.92809825]), 0, 0.0, array([ 3.35077341,  2.29199668, -7.39967049, -5.02204566])), (array([ 1.48847926,  2.6955834 , -4.16111673, -7.93202855]), 1, 0.0, array([ 1.54239092,  2.93659631, -4.3197573 , -7.49222853])), (array([  4.17966252,   1.86728707, -10.5709516 ,  -6.19843023]), 1, 0.0, array([  4.21700826,   2.08850782, -10.6949202 ,  -5.79383233])), (array([ 0.04510628,  0.98245653, -0.14791267, -1.62233946]), 1, 1.0, array([ 0.06475541,  1.17897822, -0.18035946, -1.95723115])), (array([ 0.2888586 ,  1.77270442, -0.57158922, -3.48702391]), 1, 0.0, array([ 0.32431269,  1.96384117, -0.6413297 , -3.88720003])), (array([  4.13936564,   2.01484398, -10.44429271,  -6.3329444 ]), 0