In [183]:
from collections import namedtuple
from functools import partial
import random
from typing import List, NamedTuple

import gym

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

import matplotlib.pyplot as plt
import plotly.express as ple

from IPython.display import clear_output

In [184]:
Transition = namedtuple('Transition', 'state action reward next_state done')

In [460]:
class VectorizeWrapper(gym.Wrapper):
    def __init__(self, make_env, num_envs: int=1):
        super().__init__(make_env())
        self.num_envs = num_envs
        self.envs = [make_env() for env_index in range(self.num_envs)]
    
    def reset(self):
        return np.asarray([env.reset() for env in self.envs])
    
    def reset_at(self, idx: int):
        return self.envs[idx].reset()
    
    def step(self, actions):
        assert len(actions) == len(self.envs)
        next_states, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            next_state, reward, done, info = env.step(action)
            next_states.append(next_state)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return np.asarray(next_states), rewards, dones, infos

In [461]:
class DiscreteConverter(gym.ObservationWrapper):
    def __init__(self, env):
        super(DiscreteConverter, self).__init__(env)
        self.n = self.observation_space.n
        self.observation_space = gym.spaces.Box(0, 1, (self.n,))
    
    def obs(self, obs):
        new_obs = np.zeros(self.n)
        new_obs[obs] = 1
        return new_obs

In [462]:
class ReplayBuffer:
    def __init__(self, capacity: int = 100000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
        
    def push(self, transition):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size: int) -> List[NamedTuple]:
        if len(self.buffer) < batch_size:
            raise ValueError(f"Can't sample {batch_size} num elements from buffer of size {self.buffer}")
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

In [463]:
def create_model(state_dim: int, action_dim: int, hidden_sizes: List[int]):
    model = tf.keras.models.Sequential()
    model.add(layers.InputLayer(input_shape=state_dim))
    for hidden_size in hidden_sizes:
        model.add(layers.Dense(hidden_size))
        model.add(layers.ReLU())
    model.add(layers.Dense(action_dim))
    
    return model

def sync_models(model1, model2):
    model2.set_weights(model1.get_weights())

In [464]:
class EpsScheduler:
    def __init__(self, init_eps: float = 1.0, final_eps: float = 0.01, steps: int = 10_000):
        self.init_eps = init_eps
        self.steps = steps
        self.final_eps = final_eps
        self.step = (self.init_eps - self.final_eps) / self.steps
        self.cur_step = 0
        self_eps = init_eps
        
    def __call__(self, current_step = None):
        if current_step is None:
            current_step = self.cur_step
            self.cur_step += 1
        eps = max(self.final_eps, self.init_eps - self.step * current_step)
        self._eps = eps
        return eps
    
    @property
    def eps(self):
        return self._eps

In [465]:
class Agent:
    def __init__(self, state_dim: int, action_dim: int):
        self.state_dim = state_dim
        self.action_dim = action_dim
        
    def act(self, state):
        raise NotImplementedError()
        
    def update(self, transitions):
        raise NotImplementedError()

In [478]:
class DQN(Agent):
    def __init__(self, state_dim: int, action_dim: int, hidden_sizes: List[int], gamma: float=0.95):
        super(DQN, self).__init__(state_dim, action_dim)
        self.q_net = create_model(state_dim, action_dim, hidden_sizes)
        self.optimizer = tf.keras.optimizers.Adam(0.001)
        self.eps = EpsScheduler(steps=10_000)
        self.gamma = gamma
        
    def act(self, state):
        batch_size = state.shape[0]
        eps = self.eps()
        bern_samples = (np.random.rand(batch_size) < eps).astype(int)
        random_steps = np.random.randint(self.action_dim, size=batch_size)
        best_q_steps = np.argmax(self.q_net(state), axis=1)
        action = bern_samples * random_steps + (1 - bern_samples) * best_q_steps
        return action
    
    def _prepare_batches(self, transitions):
        state, action, reward, next_state, done = [], [], [], [], []
        for transition in transitions:
            state.append(transition.state)
            action.append(transition.action)
            reward.append(transition.reward)
            next_state.append(transition.next_state)
            done.append(transition.done)
        state = np.array(state)
        action = np.array(action)
        reward = np.array(reward)
        next_state = np.array(next_state)
        done = np.array(done)
        return state, action, reward, next_state, done
    
    def update(self, transitions):
        state, action, reward, next_state, done = self._prepare_batches(transitions)
        print(state)
        #print(state, action, reward, next_state, done)
        with tf.GradientTape() as tape:
            Q_pred = tf.reduce_sum(self.q_net(state)*tf.one_hot(action, self.action_dim, dtype=tf.float32), axis=1)
            Q_next = tf.stop_gradient(tf.reduce_max(self.q_net(next_state), axis=1))
            #Q_next = tf.reduce_max(self.q_net(next_state), axis=1)
            loss = tf.reduce_mean((reward + (1 - done) * self.gamma * Q_next - Q_pred)**2)
        gradients = tape.gradient(loss, self.q_net.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.q_net.trainable_variables))

In [479]:
class DDQN:
    def __init__(self, state_dim: int, action_dim: int, hidden_sizes: List[int], gamma: float=0.95):
        super(DDQN, self).__init__(state_dim, action_dim)
        self.q_net = create_model(state_dim, action_dim, hidden_sizes)
        self.target_net = create_model(state_dim, action_dim, hidden_sizes)
        sync_models(self.q_net, self.target_net)
        self.optimizer = tf.keras.optimizers.Adam()
    
    def synchronize(self):
        sync_models(self.q_net, self.target_net)
    
    def update(self):
        with tf.GradientTape as tape:
            pass

In [480]:
def train(env_name, num_steps: int = 20000, num_envs: int = 32, plot_every: int=300):
    env = VectorizeWrapper(partial(gym.make, env_name), num_envs)
    state_dim = env.observation_space.shape
    action_dim = env.action_space.n
    agent = DQN(state_dim, action_dim, hidden_sizes=[24, 24], gamma=0.95)
    
    rewards = []
    #replay_buffer = ReplayBuffer()
    episode_rewards = [0 for _ in range(num_envs)]
    
    state = env.reset()
    for t in range(num_steps):
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        transitions = [Transition(state=s, action=a, reward=r, next_state=n, done=d)
                       for s, a, r, n, d in zip(state, action, reward, next_state, done)]
        #for transition in transitions:
        #    replay_buffer.push(transition)
        
        #if len(replay_buffer) >= 32:
        #    transitions = replay_buffer.sample(32)
        agent.update(transitions)
        
        for i in range(num_envs):
            episode_rewards[i] += reward[i]
        
        for i in range(num_envs):
            if done[i]:
                rewards.append(episode_rewards[i])
                episode_rewards[i] = 0
                env.reset_at(i)
        if t % plot_every == 99:
            clear_output()
            print("Eps: ", agent.eps.eps)
            plt.plot(rewards)
            plt.show()
        state = next_state
    return rewards

In [481]:
rewards = train('CartPole-v0', num_steps=20000, num_envs=32)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

[[ 0.03773911 -0.01753327 -0.04134456 -0.01408053]
 [-0.02640657  0.0342235   0.0101877  -0.01205414]
 [ 0.04297993 -0.00646969 -0.03191558  0.01064766]
 [-0.04628796 -0.01666316 -0.04803373  0.01202409]
 [ 0.00908547 -0.02783308  0.04316197 -0.04640822]
 [-0.00825338  0.04272905 -0.00192835  0.03838726]
 [ 0.04525715  0.04420695 -0.0067588   0.00326507]
 [ 0.03234303 -0.02482226 -0.02468527  0.00791439]
 [-0.01933219 -0.03730451  0.00564883 -0.01719497]
 [-0.01435952  0.0320611  -0.0027177   0.03380491]
 [ 0.00078912  0.00278776  0.00022677  0.03051471]
 [-0.0121811  -0.01772245  0.0306801  -0.00233361]
 [ 0.04313494 -0.00246167  0.00909308 -0.00166253]
 [ 0.01332417 -0.01957167 -0.018591 

[[ 0.07342271  0.19139945 -0.13066936 -0.61618914]
 [ 0.06904181 -0.14842384 -0.13738432  0.00561857]
 [ 0.0266319  -0.1941339  -0.02653068  0.13887532]
 [-0.05688578 -0.97696802 -0.06710439  1.14147698]
 [ 0.00886713 -0.22813466  0.03672853  0.36077016]
 [-0.07570151  0.22784582  0.13665274 -0.03401233]
 [ 0.01206306  0.18608257 -0.03774558 -0.3640627 ]
 [-0.13032283 -1.39670145  0.21505308  2.23134437]
 [-0.11521552 -0.62945897  0.13753219  1.0189035 ]
 [-0.06047933 -0.5541688   0.08414114  0.93354715]
 [ 0.06372461  0.97821908 -0.08050306 -1.43216318]
 [ 0.05981466  0.35775189 -0.05070195 -0.66926812]
 [ 0.0893035   0.1953201  -0.06115246 -0.35360807]
 [-0.04606517 -0.60360581  0.04521405  0.80768887]
 [-0.12133331 -1.38391461  0.1217226   2.00047326]
 [-0.03936441  0.23383989 -0.04069672 -0.42407662]
 [-0.09769984 -0.61771127  0.03982055  0.80885676]
 [ 0.15471611  1.40300661 -0.18533521 -2.27891231]
 [ 0.04219344  0.63277097 -0.04678753 -0.82416505]
 [-0.03263914 -0.18425895 -0.00

[[ 4.59963640e-02  6.01201838e-01 -6.08631731e-02 -9.84192217e-01]
 [-4.45359677e-02 -9.98060595e-01  9.58542578e-02  1.52000259e+00]
 [-5.14392406e-02 -4.39252116e-01  4.70163472e-04  5.48483187e-01]
 [-2.68149026e-01 -5.96120713e-01  1.47456483e-01  7.67332516e-01]
 [ 4.19737884e-03  1.48415023e-01  8.89581118e-02  7.75732809e-02]
 [ 2.54900979e-02 -3.45825214e-02  3.69791992e-02  1.05545631e-01]
 [-4.08909813e-02 -6.30643101e-01  1.65024538e-01  1.06727285e+00]
 [ 4.56371863e-02 -3.57804460e-01  2.02143299e-02  7.21887923e-01]
 [ 1.88509308e-02  4.26754003e-01  7.91992614e-02 -3.11868178e-01]
 [ 1.05018360e-01  8.10026029e-01 -1.76345949e-01 -1.33713067e+00]
 [-4.19017750e-02 -3.28570310e-03  6.40166258e-02  8.94610244e-02]
 [-2.19485770e-02 -3.48992056e-02 -4.22016269e-02 -8.42157633e-02]
 [-1.03648689e-01 -1.74831641e+00  1.66223504e-01  2.43138665e+00]
 [ 1.21511220e-01  9.50730795e-01 -6.99950864e-02 -1.38360575e+00]
 [ 2.25257992e-02 -3.80449666e-01  5.92645258e-02  6.77416766e

[[-2.28129818e-02 -4.35251311e-01  2.93103561e-03  5.56862176e-01]
 [-1.06293195e-01 -1.20664386e+00  1.42195929e-01  1.89239997e+00]
 [-2.08495691e-02  3.40440439e-01 -6.93409407e-02 -6.05940598e-01]
 [-3.15737497e-02  3.50623297e-01  1.51531251e-02 -6.17613591e-01]
 [ 7.97724237e-03  3.92021818e-01  4.42189976e-03 -5.65739429e-01]
 [ 8.28264601e-02  1.12329063e+00  5.90090990e-03 -1.36731587e+00]
 [ 9.87340665e-02  7.85455431e-01 -2.30562170e-01 -1.56624242e+00]
 [ 7.00582620e-02  2.11650308e-02  7.51175999e-02  3.86143136e-01]
 [ 9.09133578e-02  2.27865792e-02  1.18911508e-01  5.80993930e-01]
 [-2.15788313e-03  1.49547319e-01  8.38652772e-03 -2.43492183e-01]
 [-1.46410541e-02 -6.22158908e-01 -5.84497962e-03  8.17560706e-01]
 [ 3.32937185e-04  3.74492373e-01 -1.65744444e-01 -1.10522385e+00]
 [-5.28338500e-02 -9.57364753e-01  2.19669938e-01  1.85979001e+00]
 [-8.55847052e-02 -2.33027558e-01  3.51566233e-02  2.12677915e-01]
 [ 6.78192529e-03 -4.04681446e-01  1.84743691e-01  1.22941503e

[[ 4.69265404e-02 -1.53233779e-01  2.06518039e-02  3.00663359e-01]
 [-5.22398454e-03  2.04879137e-01  1.80606768e-02 -3.08623299e-01]
 [-8.77194458e-02 -8.11008508e-01  1.54319186e-01  1.33690239e+00]
 [-5.37471009e-02 -1.99270685e-01 -6.31966109e-03  2.86530713e-01]
 [ 1.03369303e-01  7.95091661e-01 -1.55934528e-01 -1.44865421e+00]
 [ 5.23271102e-02  3.44869668e-01 -7.90970858e-02 -6.66983127e-01]
 [ 2.86853698e-02  6.00684538e-01 -1.00096192e-02 -9.09447090e-01]
 [ 4.34256416e-02 -1.75630969e-02 -1.35761715e-02  2.18040489e-02]
 [ 2.96116460e-01  1.16046845e+00  1.09692109e-01 -4.52971899e-01]
 [ 2.16236985e-02  1.45746062e-01 -1.98881599e-02 -1.59660646e-01]
 [-8.21515657e-02  1.46745865e-01  8.10921066e-02 -9.81059588e-02]
 [ 1.85830312e-01  9.35213595e-01 -1.64395371e-01 -1.54088248e+00]
 [-7.13364719e-02 -7.50904114e-01  9.64626900e-02  1.23330671e+00]
 [-9.10490338e-02  5.40148157e-01  2.29221469e-02 -7.97267949e-01]
 [ 4.65089383e-02  1.57491989e-01 -4.59016883e-02 -3.60581668e

[[ 2.73936784e-02 -3.56150718e-01  8.93356159e-02  7.67828500e-01]
 [-4.91849922e-02 -7.72580348e-01  8.48964342e-02  1.19889523e+00]
 [-1.29105236e-02  5.06273387e-02 -1.97925850e-02 -1.01105119e-01]
 [ 3.88088930e-02  5.79800101e-01 -2.66747889e-02 -9.30207994e-01]
 [-4.65834883e-02 -2.49772100e-01  1.52563891e-01  6.07066634e-01]
 [-5.79535516e-02 -1.37485880e+00  5.15716140e-02  1.90496755e+00]
 [ 4.57905231e-03  5.62958211e-01 -3.76194917e-02 -8.25716856e-01]
 [ 1.09785492e-02  5.65642526e-01  3.08691369e-02 -8.09230358e-01]
 [ 5.74704703e-01  1.33029395e+00  1.45268742e-01 -1.86757201e-01]
 [-3.26483280e-02 -5.74774343e-02  8.82196226e-02  3.12879048e-01]
 [-4.69695638e-02 -6.02379710e-01  4.95521446e-02  8.98222685e-01]
 [ 3.70488042e-02  4.02001816e-01  6.63151433e-03 -5.66420179e-01]
 [-1.00145352e-01 -4.29878121e-01  1.32874272e-01  7.99644131e-01]
 [ 5.28130862e-02  7.51968622e-01 -2.17548312e-01 -1.48865814e+00]
 [ 4.55334773e-02 -1.62249086e-02 -1.33490682e-01 -5.44193351e

[[-6.23191475e-02  2.18867519e-01 -1.81604559e-02 -3.74343646e-01]
 [-7.26921498e-04  1.83933989e-01 -6.00450456e-02 -3.47213165e-01]
 [-3.74915514e-02 -1.48056993e-01 -3.37985983e-02  2.13084950e-01]
 [ 1.01822323e-01 -1.78577707e-01 -1.89846051e-01 -2.51330622e-01]
 [ 2.11594547e-01  1.28692756e+00 -5.43380290e-02 -1.20696685e+00]
 [-3.75560317e-02 -3.30390887e-02  9.32747587e-02  2.70133288e-01]
 [ 4.67915009e-02 -1.96872224e-01 -1.47746523e-01 -1.11428149e-01]
 [-5.80636564e-02 -2.19266084e-01  2.04858179e-02  2.90956200e-01]
 [ 3.72633977e-02 -2.10656022e-01 -1.88710257e-03  3.21571935e-01]
 [ 5.46267712e-02 -7.95695147e-02  1.03431779e-01  8.02685030e-01]
 [ 6.13907120e-02 -1.23180598e-02 -6.56111847e-02 -1.21108225e-01]
 [ 5.77786909e-02 -7.56827514e-01 -4.06176988e-02  9.28375475e-01]
 [-3.46411301e-03 -1.44687871e-01 -4.18601852e-02  2.66379819e-01]
 [ 9.45353706e-02  8.12792610e-01 -1.96413463e-01 -1.42722646e+00]
 [-3.77478668e-02 -1.87914473e-01  3.18897426e-02  3.43162792e

AttributeError: 'DQN' object has no attribute 'synchronize'

In [475]:
ple.line(x=range(len(rewards)), y=rewards, title="Q-learning")