# Deep Q Network

In [1]:
import sys
import argparse
import copy
import json
import random
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns; sns.set; sns.set_style("whitegrid")
import gym
from gym import wrappers, logger

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F

from collections import deque
import random
from utils import set_all_seeds
from tqdm import tqdm

<function seaborn.rcmod.set>

In [6]:
class NN(nn.Module):
    def __init__(self, inSize, outSize, layers=[]):
        super(NN, self).__init__()
        self.layers = nn.ModuleList([])
        for x in layers:
            self.layers.append(nn.Linear(inSize, x))
            inSize = x
        self.layers.append(nn.Linear(inSize, outSize))
        
    def forward(self, x):
        x = self.layers[0](x)
        for i in range(1, len(self.layers)):
            x = torch.nn.functional.leaky_relu(x)
            x = self.layers[i](x)
        return x

In [73]:
class DQL_ER(object):
    def __init__(self, inSize, outSize, env, envx, layers=[200], size_deque=300,
                 minibatch_size=200, eps=0.5, gamma=0.99, lr=0.001, C=4):
        self.D = deque(maxlen=size_deque)
        self.action_space = env.action_space
        self.Q = NN(inSize, outSize, layers)
        self.Q_hat = NN(inSize, outSize, layers)
        self.Q_hat.load_state_dict(self.Q.state_dict())
        self.optimizer = optim.Adam(self.Q.parameters(), lr=lr)
        self.loss = nn.MSELoss()  # nn.SmoothL1Loss()

        self.eps = eps
        self.C = C
        self.envx = envx
        self.gamma = gamma
        self.minibatch_size = minibatch_size

    def fit(self, M, graph=True, verbose=False):
        reward_list = []
        loss_list = []
        v = 0
        c = 1
        for m in range(M):
            state = env.reset()
            done = False
            r = 0
            mean_reward = []
            mean_loss = []
            while not done:
                # With probability eps select a random action
                action = self.select_action(state)

                # Execute action and observe reward r and next state
                next_state, reward, done, _ = env.step(action)
                r += reward

                # Si done = True alors 1 - done = 0 et on retrouve bien que
                # target = reward
                target = reward + (1 - done) * self.gamma * \
                    self.get_max(next_state)
                self.D.append([state, action, reward, next_state, target])

                # Gradient descent
                if c >= self.minibatch_size:
                    minibatch = random.sample(self.D, self.minibatch_size)
                    y = torch.tensor([mb[4] for mb in minibatch],
                                     dtype=torch.float32)
                    x = torch.tensor([mb[0] for mb in minibatch],
                                     dtype=torch.float32, requires_grad=True)
                    self.optimizer.zero_grad()
                    y_pred = self.Q(x)
                    loss = self.loss(torch.max(y_pred, 1)[0], y)
                    loss.backward()
                    self.optimizer.step()   
                    mean_reward.append(r)
                    mean_loss.append(loss)
                        
                    if done:
                        reward_list.append(sum(mean_reward)/len(mean_reward))
                        loss_list.append(sum(mean_loss)/len(mean_loss))
                    
                    if c % self.C == 0:
                        self.Q_hat.load_state_dict(self.Q.state_dict())
                                        
                state = next_state
                c += 1
                                                         
            if verbose:
                print("itération:", v, "reward:", r)

            v += 1
            self.update_eps()

            if c % 100 == 0:
                self.test(self.envx, 1, verbose=True, graph=False)
        results = pd.DataFrame(np.vstack((reward_list, list(map(lambda t: t.item(), loss_list)))).T, columns=['reward', 'loss'])
        results.to_csv("dqn_lunarlander.csv", index=False)
        
        if graph:
            x = [i for i in range(1, M + 1)]
            sns.set_style("whitegrid")
            plt.figure(figsize=(10, 6))
            plt.plot(range(len(reward_list)), reward_list, color='darkblue')
            plt.xlabel("Nombre de parties")
            plt.ylabel("Score")
            plt.title("Reward")
            plt.show()
            plt.figure(figsize=(10, 6))
            plt.plot(range(len(loss_list)), loss_list, color='darkred')
            plt.xlabel("Nombre de parties")
            plt.ylabel("Loss")
            plt.show()

    def select_action(self, state):
        if np.random.uniform() < self.eps:
            return self.action_space.sample()

        with torch.no_grad():
            pred = self.Q(torch.tensor(state, dtype=torch.float32))
        return torch.argmax(pred).numpy()

    def act(self, state):
        with torch.no_grad():
            pred = self.Q(torch.tensor(state, dtype=torch.float32))
        return torch.argmax(pred).numpy()

    def get_max(self, next_state):
        with torch.no_grad():
            #pred = self.Q(torch.tensor(next_state, dtype=torch.float32))
            pred = self.Q_hat(torch.tensor(next_state, dtype=torch.float32))
        maxi = torch.max(pred.detach())
        return maxi

    def update_eps(self):
        self.eps -= 0.005

    def test(self, envx, T, graph=True, verbose=False, demo_jeu=False):
        reward_list = []
        for i in range(T):
            if demo_jeu:
                envx.render(1)
            state = env.reset()
            done = False
            r = 0
            while not done:
                action = self.act(state)
                if demo_jeu:
                    envx.render()
                next_state, reward, done, _ = env.step(action)
                r += reward
                state = next_state

            if verbose:
                print(r)
            reward_list.append(r)
        if graph:
            plt.plot([i for i in range(1, T + 1)], reward_list)
            plt.xlabel("Nombre de parties")
            plt.ylabel("Score")
            plt.show()

## CartPole

In [None]:
inSize = 4
outSize = 2
M = 700

env_id = 'CartPole-v1'
outdir = 'cartpole-v0/random-agent-results'
envx = gym.make(env_id)
envx.verbose = True
env = envx
env = wrappers.Monitor(envx, directory=outdir, force=True,
                       video_callable=False)
set_all_seeds(env, seed=1)

dql = DQL_ER(inSize, outSize, env, envx)
dql.fit(M, graph=False, verbose=True)

## LunarLander 

In [74]:
eps = 0.1
lr = 0.001
M = 1000
minibatch_size = 32
size_deque = 1000000
layers = [20, 20]
gamma = 0.9


env_id = 'LunarLander-v2'
outdir = 'LunarLander-v2/random-agent-results'
envx = gym.make(env_id)
envx.verbose = True
env = envx
env = wrappers.Monitor(envx, directory=outdir, force=True,
                       video_callable=False)
set_all_seeds(env, seed=0)

inSize = env.observation_space.shape[0]
outSize = env.action_space.n


dql = DQL_ER(inSize, outSize, env, envx, eps=eps, lr=lr, size_deque=size_deque, layers=layers,
             minibatch_size=minibatch_size, gamma=gamma)
dql.fit(M, verbose=True, graph=False)

itération: 0 reward: -166.51774379966204
itération: 1 reward: -178.00211686165625
itération: 2 reward: -165.84237934706584
itération: 3 reward: -105.15512657656268
itération: 4 reward: -104.14215029712724
itération: 5 reward: -124.51097775319576
itération: 6 reward: -12.496357330648536
itération: 7 reward: -113.74187502073629
itération: 8 reward: -118.69478970383526
itération: 9 reward: -145.47021113342925
itération: 10 reward: -141.82036552336103
itération: 11 reward: -145.63526210010173
itération: 12 reward: -101.10689307044467
itération: 13 reward: -187.36825899654832
itération: 14 reward: -264.0566585160358
itération: 15 reward: -378.59188442090715
itération: 16 reward: -297.5208808993656
itération: 17 reward: -104.36071701455361
itération: 18 reward: -160.44392873290047
-137.133239821772
itération: 19 reward: 8.020490484952205
itération: 20 reward: -173.54690042360534
itération: 21 reward: -136.1416559511919
itération: 22 reward: -116.52501216255246
itération: 23 reward: -140.0311

itération: 194 reward: -91.30469468913006
itération: 195 reward: -121.24231348242377
itération: 196 reward: -145.66972176415894
itération: 197 reward: -137.67576302354234
itération: 198 reward: -110.2139617638349
itération: 199 reward: -137.78911061027884
itération: 200 reward: -129.57020585198376
itération: 201 reward: -118.28664648185433
itération: 202 reward: -266.51590140120123
itération: 203 reward: -115.96856227771119
itération: 204 reward: -188.98415218489515
itération: 205 reward: -90.43165890132127
itération: 206 reward: -55.12409585708269
itération: 207 reward: -150.38132886854004
itération: 208 reward: -153.14558364738576
itération: 209 reward: -150.69650584027434
itération: 210 reward: -125.46991697926654
itération: 211 reward: -301.2516188156862
itération: 212 reward: -178.84618766619792
itération: 213 reward: -140.8529315326693
itération: 214 reward: -167.81796710337267
itération: 215 reward: -156.4879306214823
itération: 216 reward: -147.20645705113193
itération: 217 rew

itération: 385 reward: -130.542538244893
itération: 386 reward: -144.19427232561077
itération: 387 reward: -117.34307567989522
itération: 388 reward: -114.69095646171226
itération: 389 reward: -144.4083315813874
itération: 390 reward: 19.21100463462618
itération: 391 reward: -114.4001802036704
itération: 392 reward: -94.55812542754929
itération: 393 reward: -97.1588335570082
itération: 394 reward: -160.58473883086276
itération: 395 reward: -133.24853778583045
itération: 396 reward: -138.08812612158607
itération: 397 reward: -97.94999353896236
itération: 398 reward: -212.0540895737069
itération: 399 reward: -124.17739080588737
itération: 400 reward: -115.30948998619317
itération: 401 reward: -121.74586368780771
itération: 402 reward: -101.29730087774664
itération: 403 reward: -125.03013809246733
itération: 404 reward: -94.87334675915451
itération: 405 reward: -123.51983417365895
-90.40646446500149
itération: 406 reward: -126.26326877969251
itération: 407 reward: -83.54743037327836
itéra

itération: 577 reward: -97.05497580110551
itération: 578 reward: -112.07063221742175
itération: 579 reward: -101.56326801298431
itération: 580 reward: -112.22553879045134
itération: 581 reward: -108.21945158468097
itération: 582 reward: -41.57069119042097
itération: 583 reward: -146.7667760923136
itération: 584 reward: -124.89526119815977
itération: 585 reward: -95.92556856970293
itération: 586 reward: -118.34555060831721
itération: 587 reward: -121.84898089701016
itération: 588 reward: -152.5712544435944
itération: 589 reward: -112.65208179742413
itération: 590 reward: -89.17968957029774
itération: 591 reward: -63.14486396936594
itération: 592 reward: -123.3185339505859
itération: 593 reward: -149.352050822354
itération: 594 reward: -55.76397162371173
itération: 595 reward: -120.32810490968083
itération: 596 reward: -134.63960031929437
itération: 597 reward: -256.1205903043065
itération: 598 reward: -137.81853087476412
itération: 599 reward: -232.942503010729
itération: 600 reward: -1

itération: 770 reward: -61.53221839741582
itération: 771 reward: -291.6353895073464
itération: 772 reward: -124.38996015216804
itération: 773 reward: -67.91391787516898
itération: 774 reward: -110.37415004895726
itération: 775 reward: -83.84338358060603
itération: 776 reward: -177.1504594887187
itération: 777 reward: -134.05714950294995
itération: 778 reward: -138.04276933687257
itération: 779 reward: -111.18558613787187
itération: 780 reward: -121.27217211762482
itération: 781 reward: -99.8798918349046
itération: 782 reward: -58.900341156126444
itération: 783 reward: -125.2025627821244
itération: 784 reward: -157.76565101389505
itération: 785 reward: -123.82547517538512
itération: 786 reward: -156.15570510192495
itération: 787 reward: -279.86452191497096
itération: 788 reward: -118.19860688563952
itération: 789 reward: -77.98555311888589
itération: 790 reward: 27.180440620028122
itération: 791 reward: -132.957245533274
itération: 792 reward: -85.42953644119721
itération: 793 reward: -

itération: 963 reward: -195.42976868244097
itération: 964 reward: -229.02264450441618
itération: 965 reward: -252.02840697214384
itération: 966 reward: -82.54457342595347
itération: 967 reward: -153.92489372356908
itération: 968 reward: -65.83826053413726
itération: 969 reward: -301.5629508343907
itération: 970 reward: -59.72377251952423
itération: 971 reward: -137.07658730230762
itération: 972 reward: -116.11744509418651
itération: 973 reward: -233.46542917305112
itération: 974 reward: -92.09690198811634
itération: 975 reward: -150.72974060998916
itération: 976 reward: -213.70434526036317
itération: 977 reward: -263.45509428356513
itération: 978 reward: -239.57814235373078
itération: 979 reward: -278.08080949875733
itération: 980 reward: -152.65723200487324
itération: 981 reward: -119.74025879550337
itération: 982 reward: -129.7192037569617
itération: 983 reward: -271.7080523937923
itération: 984 reward: -114.95773724602765
itération: 985 reward: -200.12642475555865
itération: 986 rew