In [1]:
import torch

def DQN(env, memory, q_net, t_net, optim, steps = 10000, eps = 1, disc_factor = 0.99, loss = torch.nn.MSELoss(), batch_sz = 128, tgt_update = 10, early = True,
        eps_decay = lambda eps, max_steps, step: eps - eps/max_steps,
        act = lambda s, eps, env: torch.tensor(env.action_space.sample()) if torch.rand(1) < eps else q_net(s).max(0)[1]):

    optimizer = optim(q_net.parameters(), lr = q_net.lr)
    ret = 0
    returns = []
    s = torch.tensor(env.reset(), dtype=torch.float32)  
    for step in range(steps):      
        a = act(s, eps, env)

        s_prime, r, done, _ = env.step(a.numpy())
        s_prime = torch.tensor(s_prime, dtype=torch.float32)
        eps = eps_decay(eps, steps, step)
        
        memory.push(s, a, r, s_prime, done)
        ret += r
        # Optimize
        if step >= batch_sz:
            s_, a_, r_, s_p, d_ = memory.sample(batch_sz)            
            y = r_ + disc_factor * q_net(s_p).max(1)[0] * (1 - d_)  
            predictions = q_net(s_).gather(1, a_.long()).flatten()          
            l = loss(y, predictions)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            
        if step % tgt_update == 0:
            t_net.load_state_dict(q_net.state_dict())
        
        # Test for early break
        if early and done:
            ret = 0
            for _ in range(100):
                done = False
                state = torch.tensor(env.reset(), dtype=torch.float32)
                while not done:
                    s, r, done, _ = env.step(torch.argmax(q_net(s)).numpy())
                    s = torch.tensor(s, dtype=torch.float32)
                    ret += r
            if 195 <= ret/100:
                print('converged in %i steps' %step)
                break
                    
        s = torch.tensor(env.reset(), dtype=torch.float32) if done else s_prime

        

# $\epsilon$-greedy

In [4]:
# Exponential decay


def exp_decay(s, max_steps, step):
    return 0.01 + (1 - 0.01) * math.exp(-1. * step / 500)
    

# Replay Memory

In [5]:
import random
import torch

class ReplayMemory():

    def __init__(self, capacity, state_dim, action_dim):
        self.capacity = capacity
        self.s = torch.zeros([capacity, state_dim])
        self.a = torch.zeros([capacity, 1])
        self.r = torch.zeros([capacity])
        self.s_prime = torch.zeros([capacity, state_dim])
        self.done = torch.zeros([capacity])
        self.mem_ptr = 0

    def push(self, s, a, r, s_prime, done):
        self.s[self.mem_ptr] = s
        self.a[self.mem_ptr] = a
        self.r[self.mem_ptr] = r
        self.s_prime[self.mem_ptr] = s_prime
        self.done[self.mem_ptr] = done
        self.mem_ptr = (self.mem_ptr + 1) % self.capacity
        
    def sample(self, batch_size):
        idx = torch.randperm(self.s.shape[0])[:batch_size]
        return self.s[idx], self.a[idx], self.r[idx], self.s_prime[idx], self.done[idx]

    def __len__(self):
        return len(self.memory)

# Testrun

In [48]:
from torch.distributions import Categorical
import gym
import torch.nn as nn
from py_inforce.generic.mlp import MLP
#from py_inforce.value_based.DQN import DQN
#from py_inforce.generic.Memory import ReplayMemory
import torch.optim as optim
import torch
import numpy as np
import math

env = gym.make('CartPole-v0')
in_dim = env.observation_space.shape[0] # 4
out_dim = env.action_space.n # 2
q_net = MLP([in_dim, 128, 128, out_dim], nn.ReLU, LEARN_RATE = 0.005)
t_net = MLP([in_dim, 128, 128, out_dim], nn.ReLU)
optimizer = optim.Adam
memory = ReplayMemory(1000, in_dim, out_dim)

DQN(env, memory, q_net, t_net, optimizer, steps = 10_000, eps = 1, disc_factor = 0.99, loss = torch.nn.MSELoss(), batch_sz = 32, tgt_update = 100)


converged in 2369 steps


In [49]:
ret = 0
for i in range(100):
    s = torch.tensor(env.reset(), dtype=torch.float32)
    done = False
    
    while not done:
        s, r, done, _ = env.step(torch.argmax(q_net(s)).numpy())
        s = torch.tensor(s, dtype=torch.float32)
        ret += r
ret/100

196.08

In [235]:
torch.tensor([1,2,3]) * torch.tensor([2, 4, 6])

tensor([ 2,  8, 18])

In [143]:
s = torch.tensor(env.reset(), dtype=torch.float32)
s_ = torch.stack([s,s,s,s,s])
q_net = MLP([in_dim, 128, 128, out_dim], nn.ReLU)
s_, q_net(s_), q_net(s_).max(1)[0]

(tensor([[-0.0449, -0.0448, -0.0150,  0.0103],
         [-0.0449, -0.0448, -0.0150,  0.0103],
         [-0.0449, -0.0448, -0.0150,  0.0103],
         [-0.0449, -0.0448, -0.0150,  0.0103],
         [-0.0449, -0.0448, -0.0150,  0.0103]]),
 tensor([[0.0197, 0.0036],
         [0.0197, 0.0036],
         [0.0197, 0.0036],
         [0.0197, 0.0036],
         [0.0197, 0.0036]], grad_fn=<AddmmBackward>),
 tensor([0.0197, 0.0197, 0.0197, 0.0197, 0.0197], grad_fn=<MaxBackward0>))

In [275]:
l = torch.nn.MSELoss()
l(torch.tensor([1,2,3,4,5,6,7.0]), torch.ones(7))

tensor(13.)