In [3]:
import numpy as np
import os
from collections import deque
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
import random
import gymnasium as gym

In [4]:
class Agent:
    """Deep Q-Network (DQN) agent with an epsilon-greedy policy and experience replay.

    A small fully connected network maps a state vector to one Q-value per
    action. Transitions are stored in a bounded replay buffer, and the network
    is fitted on random minibatches drawn from it.
    """

    def __init__(self, state_size=4, action_size=2):
        """Create the agent and build its Q-network.

        Args:
            state_size: Length of the observation vector fed to the network.
            action_size: Number of discrete actions (network outputs).
        """
        self.state_size = state_size
        self.action_size = action_size
        # Replay buffer; the oldest transitions are dropped once 2000 are stored.
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95  # discount factor for future rewards
        # Exploration rate. The misspelled attribute name is kept because notebook
        # cells read/write `agent.eplison` directly; `epsilon` is an alias for it.
        self.eplison = 1
        self.eps_decay = 0.995  # multiplicative decay applied after each train() call
        self.eps_min = 0.01  # lower bound for the exploration rate
        self.model = self.createModel()

    @property
    def epsilon(self):
        """Correctly spelled alias for `eplison`."""
        return self.eplison

    @epsilon.setter
    def epsilon(self, value):
        self.eplison = value

    def createModel(self):
        """Build and compile the Q-network.

        Architecture: input of length `state_size` -> two 24-unit ReLU layers ->
        `action_size` linear outputs. The model is trained with MSE loss and Adam.

        Returns:
            The compiled Keras model.
        """
        model = keras.Sequential()
        model.add(keras.Input(shape=(self.state_size,)))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))

        # Pass a Loss *instance*. A bare class is only accepted by Keras versions
        # that instantiate classes given as `loss`; others treat it as a function.
        model.compile(loss=keras.losses.MeanSquaredError(), optimizer=keras.optimizers.Adam())
        return model

    def remember(self, state, action, next_state, reward, done):
        """Append one transition to the replay buffer.

        Note the argument order: `next_state` comes before `reward`. `done`
        should be True only for true terminal states, because `train` stops
        bootstrapping from `next_state` when it is set.
        """
        self.memory.append((state, action, next_state, reward, done))

    def act(self, state):
        """Choose an action epsilon-greedily.

        Args:
            state: Observation batch holding a single row (e.g. shape `(1, state_size)`).

        Returns:
            A uniformly random action with probability `eplison`, otherwise the
            action with the highest predicted Q-value for the first row.
        """
        if np.random.random() <= self.eplison:
            return random.randrange(self.action_size)
        q_values = self.model.predict(state, verbose=0)
        return int(np.argmax(q_values[0]))

    def train(self, batch_size=32):
        """Fit the Q-network on one random minibatch from the replay buffer.

        Targets follow the Q-learning rule. For a terminal transition the target
        is `reward`. Otherwise it is `reward + gamma * max_a Q(next_state, a)`.
        The minibatch is predicted and fitted in single batched calls. The
        per-transition predict/fit loop spent most of its time in Keras call
        overhead. After fitting, the exploration rate decays but never drops
        below `eps_min`.

        Args:
            batch_size: Maximum number of transitions to sample. If the buffer
                holds fewer, all stored transitions are used. An empty buffer
                makes this a no-op.
        """
        if not self.memory:
            return

        minibatch = random.sample(self.memory, min(batch_size, len(self.memory)))
        states, actions, next_states, rewards, dones = zip(*minibatch)

        # Stack the stored (1, state_size) rows into (n, state_size) batches.
        states = np.vstack(states)
        next_states = np.vstack(next_states)
        actions = np.asarray(actions, dtype=int)
        rewards = np.asarray(rewards, dtype=np.float32)
        dones = np.asarray(dones, dtype=bool)

        target_f = self.model.predict(states, verbose=0)
        next_q = self.model.predict(next_states, verbose=0)
        targets = np.where(dones, rewards, rewards + self.gamma * np.max(next_q, axis=1))
        # Only the Q-value of the action actually taken is moved toward its target.
        target_f[np.arange(len(minibatch)), actions] = targets

        self.model.fit(states, target_f, epochs=1, verbose=0)

        if self.eplison > self.eps_min:
            self.eplison = max(self.eps_min, self.eplison * self.eps_decay)

In [5]:
# Training configuration.
SEED = 42
n_episode = 1000  # number of training episodes
batch_size = 32  # replay minibatch size passed to agent.train()

# Seed Python's `random`, NumPy and the Keras backend so that exploration,
# replay sampling and weight initialisation are reproducible. This must run
# before the Agent (and therefore its network) is created.
keras.utils.set_random_seed(SEED)
agent = Agent(state_size=4, action_size=2)
env = gym.make('CartPole-v1', render_mode='rgb_array')
# Seed the environment once; later unseeded reset() calls continue from this RNG.
env.reset(seed=SEED)

In [6]:
scores = []  # per-episode step counts, kept for inspection/plotting after training

for e in range(n_episode):
    state, info = env.reset()
    state = state.reshape((1, -1))

    # CartPole-v1 is registered with a 500-step time limit, so 500 iterations suffice.
    for i in range(500):
        action = agent.act(state)
        next_state, reward, terminated, truncated, info = env.step(action)
        # Penalise failure; hitting the time limit (truncation) is not a failure.
        reward = reward if not terminated else -100
        next_state = next_state.reshape((1, -1))
        # Store `terminated` (not `truncated`) so train() only stops bootstrapping
        # at genuine terminal states.
        agent.remember(state, action, next_state, reward, terminated)
        state = next_state

        if terminated or truncated:
            scores.append(i + 1)
            print(f'Episode: {e}/{n_episode} Score: {i+1}')
            break

    if len(agent.memory) >= batch_size:
        # Use the configured batch size rather than train()'s default.
        agent.train(batch_size)

Episode: 0/1000 Score: 27
Episode: 1/1000 Score: 9
Episode: 2/1000 Score: 15
Episode: 3/1000 Score: 20
Episode: 4/1000 Score: 43
Episode: 5/1000 Score: 32
Episode: 6/1000 Score: 16
Episode: 7/1000 Score: 26
Episode: 8/1000 Score: 26
Episode: 9/1000 Score: 16
Episode: 10/1000 Score: 21
Episode: 11/1000 Score: 27
Episode: 12/1000 Score: 33
Episode: 13/1000 Score: 11
Episode: 14/1000 Score: 10
Episode: 15/1000 Score: 15
Episode: 16/1000 Score: 11
Episode: 17/1000 Score: 18
Episode: 18/1000 Score: 43
Episode: 19/1000 Score: 17
Episode: 20/1000 Score: 31
Episode: 21/1000 Score: 27
Episode: 22/1000 Score: 20
Episode: 23/1000 Score: 11
Episode: 24/1000 Score: 13
Episode: 25/1000 Score: 43
Episode: 26/1000 Score: 14
Episode: 27/1000 Score: 10
Episode: 28/1000 Score: 14
Episode: 29/1000 Score: 15
Episode: 30/1000 Score: 22
Episode: 31/1000 Score: 15
Episode: 32/1000 Score: 20
Episode: 33/1000 Score: 24
Episode: 34/1000 Score: 14
Episode: 35/1000 Score: 35
Episode: 36/1000 Score: 22
Episode: 37/

KeyboardInterrupt: 

In [None]:
# Release the training environment's resources once training has finished or been interrupted.
env.close()

In [8]:
# Evaluate the trained agent greedily in a rendered window.
# A separate env keeps the training `env` untouched. Evaluation transitions
# are NOT added to the replay buffer. The agent's exploration rate is restored
# afterwards, so further training is unaffected.
eval_env = gym.make('CartPole-v1', render_mode='human')
saved_epsilon = agent.eplison
agent.eplison = 0  # act purely greedily during evaluation
try:
    state, info = eval_env.reset()
    state = state.reshape((1, -1))

    for i in range(500):
        action = agent.act(state)
        next_state, reward, terminated, truncated, info = eval_env.step(action)
        state = next_state.reshape((1, -1))

        if terminated or truncated:
            print(f'Score: {i+1}')
            break
finally:
    agent.eplison = saved_epsilon
    eval_env.close()  # close the human-render window

Score: 139
