In [3]:
import gym
import random
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

2023-04-18 08:22:04.203076: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=20000)
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.2
        self.epsilon_decay = 0.999
        self.learning_rate = 0.0001
        self.model = self._build_model()

    def _build_model(self):
        model = Sequential()
        model.add(Dense(128, input_dim=self.state_size, activation='relu'))
        model.add(Dense(128, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                target = (reward + self.gamma * np.amax(self.model.predict(next_state)[0]))
            target_f = self.model.predict(state)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)


In [9]:
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DQNAgent(state_size, action_size)
print('state size', state_size)
print('action size', action_size)

done = False
batch_size = 32
EPISODES = 1000

state size 4
action size 2


In [None]:
for e in range(EPISODES):
    state = env.reset()
    print(state)
    state = np.reshape(state, [1, state_size])
    print('new state', state)
    for time in range(500):
        # env.render()
        action = agent.act([state])
        next_state, reward, done, _ = env.step(action)
        reward = reward if not done else -10
        next_state = np.reshape(next_state, [1, state_size])
        agent.remember(state, action, reward, next_state, done)
        state = next_state
        if done:
            print("episode: {}/{}, score: {}, e: {:.2}".format(e, EPISODES, time, agent.epsilon))
            break
    if len(agent.memory) > batch_size:
        agent.replay(batch_size)
    if e % 100 == 0:
        agent.save("lunar_lander-dqn.h5") # lol wrong name here

[ 0.01646276 -0.03200918  0.04143762  0.01632202]
new state [[ 0.01646276 -0.03200918  0.04143762  0.01632202]]
episode: 0/1000, score: 24, e: 1.0
[ 0.00860815 -0.01877899 -0.02964288  0.01685256]
new state [[ 0.00860815 -0.01877899 -0.02964288  0.01685256]]
episode: 1/1000, score: 14, e: 1.0
[-0.00264927 -0.00089259  0.00946645  0.00933072]
new state [[-0.00264927 -0.00089259  0.00946645  0.00933072]]
episode: 2/1000, score: 30, e: 1.0
[ 0.04659918  0.03501585 -0.016171    0.04161484]
new state [[ 0.04659918  0.03501585 -0.016171    0.04161484]]
episode: 3/1000, score: 13, e: 1.0


[-0.00211357  0.00587764 -0.04813442 -0.02420195]
new state [[-0.00211357  0.00587764 -0.04813442 -0.02420195]]
episode: 4/1000, score: 21, e: 1.0
[ 0.01953335 -0.04364393 -0.03995696 -0.0087153 ]
new state [[ 0.01953335 -0.04364393 -0.03995696 -0.0087153 ]]
episode: 5/1000, score: 34, e: 1.0


[-0.04459907  0.04881554  0.00385285  0.01619631]
new state [[-0.04459907  0.04881554  0.00385285  0.01619631]]
episode: 6/1000, score: 9, e: 1.0
[ 0.04282948 -0.01674461 -0.02480228  0.04432823]
new state [[ 0.04282948 -0.01674461 -0.02480228  0.04432823]]
episode: 7/1000, score: 12, e: 0.99
[ 0.02778559  0.04817513 -0.00258144  0.00558985]
new state [[ 0.02778559  0.04817513 -0.00258144  0.00558985]]
episode: 8/1000, score: 24, e: 0.99


[-0.02298704  0.03677144  0.04272695 -0.00465129]
new state [[-0.02298704  0.03677144  0.04272695 -0.00465129]]
episode: 9/1000, score: 25, e: 0.99
[0.03192025 0.02329931 0.04922074 0.01713847]
new state [[0.03192025 0.02329931 0.04922074 0.01713847]]
episode: 10/1000, score: 36, e: 0.99


[ 0.02840015 -0.02754864  0.03451644  0.04858935]
new state [[ 0.02840015 -0.02754864  0.03451644  0.04858935]]
episode: 11/1000, score: 55, e: 0.99
[-0.01005673  0.0195688   0.01640136  0.04550994]
new state [[-0.01005673  0.0195688   0.01640136  0.04550994]]
episode: 12/1000, score: 43, e: 0.99


[ 0.04323788 -0.01067193 -0.0306183  -0.01023367]
new state [[ 0.04323788 -0.01067193 -0.0306183  -0.01023367]]
episode: 13/1000, score: 19, e: 0.99
[-0.01083941  0.00753363 -0.00167132 -0.02033372]
new state [[-0.01083941  0.00753363 -0.00167132 -0.02033372]]
episode: 14/1000, score: 19, e: 0.99
[-0.03036311 -0.04226918  0.01852961  0.02241991]
new state [[-0.03036311 -0.04226918  0.01852961  0.02241991]]
episode: 15/1000, score: 36, e: 0.99


[ 0.02935049 -0.02552028  0.03843107 -0.03806436]
new state [[ 0.02935049 -0.02552028  0.03843107 -0.03806436]]
episode: 16/1000, score: 20, e: 0.99
[-0.0016357   0.00652879  0.04405004  0.02417414]
new state [[-0.0016357   0.00652879  0.04405004  0.02417414]]
episode: 17/1000, score: 48, e: 0.98


[ 0.04918692 -0.04643186 -0.02164506  0.04421364]
new state [[ 0.04918692 -0.04643186 -0.02164506  0.04421364]]
episode: 18/1000, score: 14, e: 0.98
[ 0.03857005 -0.02397459 -0.02341094 -0.0354482 ]
new state [[ 0.03857005 -0.02397459 -0.02341094 -0.0354482 ]]
episode: 19/1000, score: 93, e: 0.98
[-0.00698499  0.01512515 -0.00216446  0.04821585]
new state [[-0.00698499  0.01512515 -0.00216446  0.04821585]]


episode: 20/1000, score: 30, e: 0.98
[0.01276033 0.04917678 0.01741035 0.02795635]
new state [[0.01276033 0.04917678 0.01741035 0.02795635]]
episode: 21/1000, score: 17, e: 0.98
[ 0.04325175 -0.02334804  0.04457727 -0.03362882]
new state [[ 0.04325175 -0.02334804  0.04457727 -0.03362882]]
episode: 22/1000, score: 15, e: 0.98


[-0.03015918 -0.02665718 -0.02491799 -0.03046013]
new state [[-0.03015918 -0.02665718 -0.02491799 -0.03046013]]
episode: 23/1000, score: 21, e: 0.98
[ 0.04037614 -0.02258519 -0.01761171 -0.02598657]
new state [[ 0.04037614 -0.02258519 -0.01761171 -0.02598657]]
episode: 24/1000, score: 23, e: 0.98


[-0.00983841  0.01620471  0.02858497 -0.00243957]
new state [[-0.00983841  0.01620471  0.02858497 -0.00243957]]
episode: 25/1000, score: 47, e: 0.98
[ 0.02628166  0.04559146 -0.02040808  0.02753219]
new state [[ 0.02628166  0.04559146 -0.02040808  0.02753219]]
episode: 26/1000, score: 12, e: 0.98
[ 0.0402048  -0.02128769 -0.03702277  0.0391637 ]
new state [[ 0.0402048  -0.02128769 -0.03702277  0.0391637 ]]
episode: 27/1000, score: 13, e: 0.97


[ 0.0056063   0.02753637 -0.03577342  0.00086803]
new state [[ 0.0056063   0.02753637 -0.03577342  0.00086803]]
episode: 28/1000, score: 27, e: 0.97
[ 0.04063189  0.03397553 -0.0162732  -0.02999967]
new state [[ 0.04063189  0.03397553 -0.0162732  -0.02999967]]
episode: 29/1000, score: 35, e: 0.97


[-0.02672876  0.01932293 -0.03811054  0.00771385]
new state [[-0.02672876  0.01932293 -0.03811054  0.00771385]]
episode: 30/1000, score: 18, e: 0.97
[-0.00681095  0.03566371 -0.00053287 -0.04593549]
new state [[-0.00681095  0.03566371 -0.00053287 -0.04593549]]
episode: 31/1000, score: 20, e: 0.97


[-0.01073595  0.01019452  0.03642305 -0.00483799]
new state [[-0.01073595  0.01019452  0.03642305 -0.00483799]]
episode: 32/1000, score: 36, e: 0.97
[ 0.02541576  0.03324082 -0.01045462  0.01668026]
new state [[ 0.02541576  0.03324082 -0.01045462  0.01668026]]
episode: 33/1000, score: 14, e: 0.97
[-0.00393205 -0.03280021 -0.0164567  -0.02997033]
new state [[-0.00393205 -0.03280021 -0.0164567  -0.02997033]]
episode: 34/1000, score: 24, e: 0.97


[-0.01472629  0.01346639 -0.04620692 -0.04480565]
new state [[-0.01472629  0.01346639 -0.04620692 -0.04480565]]
episode: 35/1000, score: 7, e: 0.97
[-0.03055283 -0.04033756 -0.00060058 -0.00143364]
new state [[-0.03055283 -0.04033756 -0.00060058 -0.00143364]]
episode: 36/1000, score: 21, e: 0.97


[-0.03154907  0.01968534 -0.03230795  0.04593696]
new state [[-0.03154907  0.01968534 -0.03230795  0.04593696]]
episode: 37/1000, score: 16, e: 0.96
[0.01102946 0.02460185 0.04980158 0.03788225]
new state [[0.01102946 0.02460185 0.04980158 0.03788225]]
episode: 38/1000, score: 17, e: 0.96


[ 0.03421929 -0.04285293 -0.01333477 -0.04377784]
new state [[ 0.03421929 -0.04285293 -0.01333477 -0.04377784]]
episode: 39/1000, score: 13, e: 0.96
[ 0.01668195  0.03720065 -0.0335441  -0.01171618]
new state [[ 0.01668195  0.03720065 -0.0335441  -0.01171618]]
episode: 40/1000, score: 76, e: 0.96
[ 0.01413283 -0.04206751  0.0473081   0.04215635]
new state [[ 0.01413283 -0.04206751  0.0473081   0.04215635]]
episode: 41/1000, score: 10, e: 0.96


[-0.03090731 -0.03804468 -0.01126964  0.04231154]
new state [[-0.03090731 -0.03804468 -0.01126964  0.04231154]]
episode: 42/1000, score: 14, e: 0.96
[ 0.01154293  0.01411008 -0.01937384  0.03746197]
new state [[ 0.01154293  0.01411008 -0.01937384  0.03746197]]
episode: 43/1000, score: 17, e: 0.96


[-0.00192994 -0.02963998 -0.03143064 -0.04720044]
new state [[-0.00192994 -0.02963998 -0.03143064 -0.04720044]]
episode: 44/1000, score: 23, e: 0.96
[ 0.04842616 -0.00304348 -0.02251    -0.01283614]
new state [[ 0.04842616 -0.00304348 -0.02251    -0.01283614]]
episode: 45/1000, score: 21, e: 0.96


[-0.01472645 -0.01407276  0.04570655  0.02608153]
new state [[-0.01472645 -0.01407276  0.04570655  0.02608153]]
episode: 46/1000, score: 11, e: 0.96
[-0.02650365  0.02973826  0.01023548  0.04865973]
new state [[-0.02650365  0.02973826  0.01023548  0.04865973]]
episode: 47/1000, score: 38, e: 0.96
[ 0.02843166 -0.03377319 -0.03964832  0.03201047]
new state [[ 0.02843166 -0.03377319 -0.03964832  0.03201047]]
episode: 48/1000, score: 17, e: 0.95


[-0.00941349  0.02796272 -0.03550189 -0.01611776]
new state [[-0.00941349  0.02796272 -0.03550189 -0.01611776]]
episode: 49/1000, score: 20, e: 0.95
[ 0.00118752  0.00437707 -0.01536883  0.00506385]
new state [[ 0.00118752  0.00437707 -0.01536883  0.00506385]]
episode: 50/1000, score: 15, e: 0.95


[ 0.03644073  0.04777988 -0.00272979 -0.02850156]
new state [[ 0.03644073  0.04777988 -0.00272979 -0.02850156]]
episode: 51/1000, score: 17, e: 0.95
[-0.04694821  0.00201641  0.01254763 -0.0298079 ]
new state [[-0.04694821  0.00201641  0.01254763 -0.0298079 ]]
episode: 52/1000, score: 18, e: 0.95


[-0.01070622  0.00092327 -0.03739711  0.03304345]
new state [[-0.01070622  0.00092327 -0.03739711  0.03304345]]
episode: 53/1000, score: 14, e: 0.95
[-0.01135794  0.00181781  0.00186671  0.03864685]
new state [[-0.01135794  0.00181781  0.00186671  0.03864685]]
episode: 54/1000, score: 20, e: 0.95
[-0.02563969  0.04001642 -0.04692454  0.02721323]
new state [[-0.02563969  0.04001642 -0.04692454  0.02721323]]
episode: 55/1000, score: 13, e: 0.95


[-0.04190527  0.03358785  0.04299444  0.04179275]
new state [[-0.04190527  0.03358785  0.04299444  0.04179275]]
episode: 56/1000, score: 15, e: 0.95
[ 0.04741172 -0.04864063  0.04989779  0.01197331]
new state [[ 0.04741172 -0.04864063  0.04989779  0.01197331]]
episode: 57/1000, score: 28, e: 0.95


[-0.00811679  0.04218794 -0.01422328 -0.04260173]
new state [[-0.00811679  0.04218794 -0.01422328 -0.04260173]]
episode: 58/1000, score: 9, e: 0.94
[ 0.0188599  -0.0464956  -0.00724291 -0.01181556]
new state [[ 0.0188599  -0.0464956  -0.00724291 -0.01181556]]
episode: 59/1000, score: 50, e: 0.94
[-0.04237885  0.03325386  0.04728063 -0.04095373]
new state [[-0.04237885  0.03325386  0.04728063 -0.04095373]]
episode: 60/1000, score: 15, e: 0.94


[-0.03175172 -0.03597973  0.00248829  0.01949846]
new state [[-0.03175172 -0.03597973  0.00248829  0.01949846]]
episode: 61/1000, score: 36, e: 0.94
[-0.04014236 -0.0399562  -0.02798967 -0.00618239]
new state [[-0.04014236 -0.0399562  -0.02798967 -0.00618239]]
episode: 62/1000, score: 13, e: 0.94


[ 0.0167656   0.00370527 -0.04398123 -0.04268166]
new state [[ 0.0167656   0.00370527 -0.04398123 -0.04268166]]
episode: 63/1000, score: 20, e: 0.94
[0.02073648 0.0428431  0.0300123  0.03333713]
new state [[0.02073648 0.0428431  0.0300123  0.03333713]]
episode: 64/1000, score: 8, e: 0.94


[ 0.00617635 -0.0477563   0.01356061 -0.02805784]
new state [[ 0.00617635 -0.0477563   0.01356061 -0.02805784]]
episode: 65/1000, score: 12, e: 0.94
[-0.00486757 -0.01780252  0.00666952 -0.0296965 ]
new state [[-0.00486757 -0.01780252  0.00666952 -0.0296965 ]]
episode: 66/1000, score: 23, e: 0.94
[-0.03562479  0.00152817 -0.01193096 -0.02841585]
new state [[-0.03562479  0.00152817 -0.01193096 -0.02841585]]
episode: 67/1000, score: 20, e: 0.94


[-0.02105985  0.03841231  0.02812487 -0.02589363]
new state [[-0.02105985  0.03841231  0.02812487 -0.02589363]]
episode: 68/1000, score: 14, e: 0.94
[-0.00635043 -0.03488942  0.04154727 -0.01133327]
new state [[-0.00635043 -0.03488942  0.04154727 -0.01133327]]
episode: 69/1000, score: 12, e: 0.93


[ 0.03557399  0.03326568 -0.00050776  0.03177371]
new state [[ 0.03557399  0.03326568 -0.00050776  0.03177371]]
episode: 70/1000, score: 11, e: 0.93
[-0.01312411  0.02249191 -0.03004391 -0.0038504 ]
new state [[-0.01312411  0.02249191 -0.03004391 -0.0038504 ]]
episode: 71/1000, score: 11, e: 0.93


[-0.04037897 -0.04930321 -0.03231727 -0.03769181]
new state [[-0.04037897 -0.04930321 -0.03231727 -0.03769181]]
episode: 72/1000, score: 27, e: 0.93
[0.00891693 0.01945199 0.01113726 0.04534762]
new state [[0.00891693 0.01945199 0.01113726 0.04534762]]
episode: 73/1000, score: 14, e: 0.93
[-0.00039189  0.04737795  0.0110444   0.01123747]
new state [[-0.00039189  0.04737795  0.0110444   0.01123747]]
episode: 74/1000, score: 81, e: 0.93


[ 0.0163474   0.00724788 -0.04939806  0.01357725]
new state [[ 0.0163474   0.00724788 -0.04939806  0.01357725]]
episode: 75/1000, score: 13, e: 0.93
[ 0.02443535  0.03205543 -0.03206044 -0.01762398]
new state [[ 0.02443535  0.03205543 -0.03206044 -0.01762398]]
episode: 76/1000, score: 28, e: 0.93


[-0.02396427  0.02914927  0.03677386  0.04940961]
new state [[-0.02396427  0.02914927  0.03677386  0.04940961]]
episode: 77/1000, score: 7, e: 0.93
[-0.00395555  0.04025981  0.02424108  0.00982595]
new state [[-0.00395555  0.04025981  0.02424108  0.00982595]]
episode: 78/1000, score: 29, e: 0.93


[-4.1906361e-02 -1.9999921e-02 -7.7132205e-03  1.7762031e-05]
new state [[-4.1906361e-02 -1.9999921e-02 -7.7132205e-03  1.7762031e-05]]
episode: 79/1000, score: 19, e: 0.92
[-0.01868773 -0.00696086 -0.01873736 -0.04336971]
new state [[-0.01868773 -0.00696086 -0.01873736 -0.04336971]]
episode: 80/1000, score: 24, e: 0.92
[ 0.04066906  0.03946738  0.00513637 -0.04994571]
new state [[ 0.04066906  0.03946738  0.00513637 -0.04994571]]
episode: 81/1000, score: 15, e: 0.92


[-0.00733723 -0.00945642  0.04312839  0.00908117]
new state [[-0.00733723 -0.00945642  0.04312839  0.00908117]]
episode: 82/1000, score: 20, e: 0.92
[-0.03807164  0.02213339 -0.03487335  0.00649799]
new state [[-0.03807164  0.02213339 -0.03487335  0.00649799]]
episode: 83/1000, score: 14, e: 0.92


[ 0.02721028 -0.01832957 -0.00966586  0.03823278]
new state [[ 0.02721028 -0.01832957 -0.00966586  0.03823278]]
episode: 84/1000, score: 10, e: 0.92
[-0.00015507 -0.0142639   0.02900752 -0.03078469]
new state [[-0.00015507 -0.0142639   0.02900752 -0.03078469]]
episode: 85/1000, score: 8, e: 0.92


[ 0.04674198 -0.01594263  0.03584139 -0.04359256]
new state [[ 0.04674198 -0.01594263  0.03584139 -0.04359256]]
episode: 86/1000, score: 25, e: 0.92
[ 0.02520522 -0.00423639  0.01990399  0.00743822]
new state [[ 0.02520522 -0.00423639  0.01990399  0.00743822]]
episode: 87/1000, score: 10, e: 0.92
[0.02747931 0.01791384 0.04488048 0.03397488]
new state [[0.02747931 0.01791384 0.04488048 0.03397488]]
episode: 88/1000, score: 12, e: 0.92


[-0.0121652  -0.02451534  0.01651127 -0.0059973 ]
new state [[-0.0121652  -0.02451534  0.01651127 -0.0059973 ]]
episode: 89/1000, score: 16, e: 0.92
[ 0.01943829 -0.02467397  0.03556503  0.04798231]
new state [[ 0.01943829 -0.02467397  0.03556503  0.04798231]]
episode: 90/1000, score: 21, e: 0.91


[-0.03190531  0.02951874 -0.00582385 -0.01976667]
new state [[-0.03190531  0.02951874 -0.00582385 -0.01976667]]
episode: 91/1000, score: 11, e: 0.91
[ 0.04811237  0.02758899 -0.01596922 -0.04454935]
new state [[ 0.04811237  0.02758899 -0.01596922 -0.04454935]]
episode: 92/1000, score: 28, e: 0.91


[-0.00515252  0.02014639  0.02814632  0.01800953]
new state [[-0.00515252  0.02014639  0.02814632  0.01800953]]
episode: 93/1000, score: 25, e: 0.91
[ 0.01085204 -0.03964662 -0.04235183  0.00863945]
new state [[ 0.01085204 -0.03964662 -0.04235183  0.00863945]]
episode: 94/1000, score: 19, e: 0.91
[-0.01577126 -0.04265667 -0.01541331 -0.02151093]
new state [[-0.01577126 -0.04265667 -0.01541331 -0.02151093]]
episode: 95/1000, score: 10, e: 0.91


[ 0.04464263  0.0319691  -0.04143246  0.01314071]
new state [[ 0.04464263  0.0319691  -0.04143246  0.01314071]]
episode: 96/1000, score: 77, e: 0.91
[-0.02246374 -0.02147756 -0.00950146  0.00744935]
new state [[-0.02246374 -0.02147756 -0.00950146  0.00744935]]
episode: 97/1000, score: 18, e: 0.91


[ 0.02005563 -0.04485001  0.02107627  0.02045572]
new state [[ 0.02005563 -0.04485001  0.02107627  0.02045572]]
episode: 98/1000, score: 30, e: 0.91
[-0.01222692 -0.04686414 -0.04103479 -0.02322042]
new state [[-0.01222692 -0.04686414 -0.04103479 -0.02322042]]
episode: 99/1000, score: 17, e: 0.91


[ 0.02476844 -0.04470912 -0.02672893 -0.01620481]
new state [[ 0.02476844 -0.04470912 -0.02672893 -0.01620481]]
episode: 100/1000, score: 29, e: 0.91
[ 0.01398906 -0.04109595  0.04000726  0.04238008]
new state [[ 0.01398906 -0.04109595  0.04000726  0.04238008]]
episode: 101/1000, score: 17, e: 0.9


[-0.00054234 -0.02413111  0.043026    0.00118368]
new state [[-0.00054234 -0.02413111  0.043026    0.00118368]]
episode: 102/1000, score: 13, e: 0.9
[-0.03153818 -0.04257079 -0.00643428 -0.02216462]
new state [[-0.03153818 -0.04257079 -0.00643428 -0.02216462]]
episode: 103/1000, score: 17, e: 0.9
[-0.00097222 -0.04649778  0.01122827 -0.01873366]
new state [[-0.00097222 -0.04649778  0.01122827 -0.01873366]]
episode: 104/1000, score: 9, e: 0.9




In [None]:
env = gym.make('CartPole-v1')
state = env.reset()
print(state)
state = np.reshape(state[0], [1, state_size])
print('new state', state)


In [None]:
action = agent.act([state])

In [None]:
env.step(action)