In [1]:
from PIL import Image
import numpy as np
import random
import gymnasium as gym
from collections import deque
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.layers import Dense, Activation, Flatten, Convolution2D,Permute
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import cv2

In [2]:
# Print the kernel's current working directory (IPython automagic).
# NOTE(review): the saved output of this cell reveals an absolute local path;
# consider deleting this cell (and its output) before sharing the notebook.
pwd

'/Users/taesan/Udemy/reinforcement'

In [3]:
# Create the Snake environment and build the online and target Q-networks.
env = gym.make("snake:snake-v0", render_mode="human")
env.reset()

num_observations = env.observation_space.shape[0]
num_actions = env.action_space.n
IMG_SHAPE = (200, 200)  # spatial size of each preprocessed frame fed to the network
WINDOW_LENGTH = 4       # number of consecutive frames stacked into one state
BATCH_SIZE = 32         # replay minibatch size

model = Sequential([
    # States are stacked channels-first (frames, H, W); Conv2D expects channels-last.
    Permute((2, 3, 1), input_shape=(WINDOW_LENGTH, IMG_SHAPE[0], IMG_SHAPE[1])),
    Convolution2D(32, (8, 8), strides=(4, 4), kernel_initializer="he_normal"),
    Activation("relu"),

    Convolution2D(64, (4, 4), strides=(2, 2), kernel_initializer="he_normal"),
    Activation("relu"),

    Convolution2D(64, (4, 4), strides=(1, 1), kernel_initializer="he_normal"),
    Activation("relu"),

    Flatten(),
    Dense(512),
    Activation("relu"),

    Dense(num_actions),
    Activation("linear"),  # one unbounded Q-value per action
])

# summary() prints the table itself and returns None, so don't wrap it in print().
model.summary()

# clone_model copies the architecture only and creates freshly initialised
# weights, so sync explicitly: the target network must start identical to the
# online network.
target_model = clone_model(model)
target_model.set_weights(model.get_weights())

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 permute (Permute)           (None, 200, 200, 4)       0         
                                                                 
 conv2d (Conv2D)             (None, 49, 49, 32)        8224      
                                                                 
 activation (Activation)     (None, 49, 49, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 23, 23, 64)        32832     
                                                                 
 activation_1 (Activation)   (None, 23, 23, 64)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 20, 20, 64)        65600     
                                                                 
 activation_2 (Activation)   (None, 20, 20, 64)        0

  logger.deprecation(
  logger.deprecation(
  logger.warn(


In [4]:
# --- Training hyperparameters ----------------------------------------------
EPISODES = 1000        # number of training episodes
LEARNING_RATE = 0.001  # Adam step size for the online network
GAMMA = 0.95           # discount factor applied to bootstrapped future value

# --- Epsilon-greedy exploration ---------------------------------------------
# Start fully random; the training loop multiplies epsilon by EPSILON_REDUCE
# after every episode.
epsilon = 1.0
EPSILON_REDUCE = 0.995

# --- Buffers ----------------------------------------------------------------
replay_buffer = deque(maxlen=50_000)          # experience replay memory
image_buffer = deque(maxlen=WINDOW_LENGTH)    # rolling window of recent frames

# Copy online weights into the target network every `update_target_model` episodes.
update_target_model = 10

model.compile(optimizer=Adam(learning_rate=LEARNING_RATE), loss="mse")



In [5]:
#process observation
def process_observation(observation):
    """Convert an RGB frame into a normalised grayscale image.

    Args:
        observation: RGB image array from the environment (presumably ``uint8``
            of shape ``(H, W, 3)`` -- TODO confirm against the env).

    Returns:
        2-D ``float32`` grayscale frame divided by 255.
    """
    gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
    # float32 rather than numpy's default float64. Keras computes in float32
    # anyway, and these frames end up stacked into every stored replay
    # transition (state and next state), so this halves that memory.
    return gray.astype(np.float32) / np.float32(255.0)

def process_reward(reward):
    """Clip a reward (scalar or array) into the closed interval [-1, 1]."""
    lower_bounded = np.maximum(reward, -1.0)
    return np.minimum(lower_bounded, 1.0)

#action selection
def epsilon_greedy_action_selection(model, epsilon, observation, action_space=None):
    """Choose an action epsilon-greedily with respect to the Q-network.

    Args:
        model: Q-network; only ``model.input_shape`` and ``model.predict`` are used.
        epsilon: exploration probability. A random action is taken when
            ``np.random.random() <= epsilon``; otherwise the arg-max Q action.
        observation: one stacked state. It is reshaped to a batch of one using
            the model's per-sample input shape, so it must hold exactly that
            many values.
        action_space: space to sample from when exploring. ``None`` means the
            global training ``env.action_space`` (looked up only when exploring).

    Returns:
        int: the selected action index.
    """
    # Exploitation: act greedily on the predicted Q-values.
    if np.random.random() > epsilon:
        # Batch of one shaped like the network input (drop the None batch dim),
        # instead of relying on the WINDOW_LENGTH / IMG_SHAPE globals.
        batch = np.reshape(observation, (1, *model.input_shape[1:]))
        q_values = model.predict(batch, verbose=0)
        return int(np.argmax(q_values))

    # Exploration: uniform random action.
    if action_space is None:
        action_space = env.action_space
    return int(action_space.sample())
    
# reduce epsilon
def reduce_epsilon(epsilon, epoch, min_epsilon=0.01, max_epsilon=1.0, decay_rate=0.005):
    """Exponentially decaying exploration rate for a given epoch.

    ``min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * epoch)``

    The original referenced undefined globals ``min_epsilon``, ``max_epsilon``
    and ``decay_rate``, so every call raised NameError. They are now keyword
    parameters. The defaults are choices made here: ``max_epsilon=1.0`` matches
    the notebook's initial epsilon, and ``decay_rate=0.005`` brings epsilon
    close to the floor over about 1000 epochs. Tune them as needed.

    Args:
        epsilon: accepted for call compatibility but unused; the schedule
            depends only on ``epoch``.
        epoch: non-negative episode / epoch index.
        min_epsilon: asymptotic lower bound of the schedule.
        max_epsilon: value at ``epoch == 0``.
        decay_rate: exponential decay constant per epoch.

    Returns:
        The exploration rate for ``epoch``.
    """
    return min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * epoch)

def update_model_handler(epoch, update_target_model, model, target_model):
    """Periodically sync the target network with the online network.

    Copies ``model``'s weights into ``target_model`` when ``epoch`` is a
    positive multiple of ``update_target_model``; otherwise does nothing.
    """
    is_sync_epoch = epoch > 0 and epoch % update_target_model == 0
    if not is_sync_epoch:
        return
    target_model.set_weights(model.get_weights())
        

In [6]:
def replay(replay_buffer, batch_size, model, target_model, gamma=None):
    """Fit ``model`` on one random minibatch of stored transitions (DQN update).

    Each buffer entry is ``(state, action, reward, next_state, terminated, truncated)``.
    For each sampled transition, the regression target is the online network's
    current Q-values for ``state``. The entry for the taken action is replaced:

    * with ``reward`` when the episode terminated;
    * with ``reward + gamma * max(target_model(next_state))`` otherwise. This
      includes time-limit truncation, which is not a true terminal state.

    Args:
        replay_buffer: sequence of transitions (the notebook passes a deque).
        batch_size: number of transitions to sample. The call is a no-op while
            the buffer holds fewer than this, or when it is not positive.
        model: online Q-network; provides ``Q(state)`` and is fitted in place.
        target_model: periodically synced copy, used only for next-state values.
        gamma: discount factor; ``None`` means use the global ``GAMMA``.
    """
    if batch_size <= 0 or len(replay_buffer) < batch_size:
        return
    if gamma is None:
        gamma = GAMMA

    samples = random.sample(replay_buffer, batch_size)
    states, actions, rewards, next_states, terminated, _truncated = (
        np.array(column) for column in zip(*samples)
    )

    # Start from the online network's own estimates, so actions that were not
    # taken have zero error and are unaffected by the fit.
    targets = np.array(model.predict(states, verbose=False))
    # Bootstrap next-state values from the target network; this is what
    # stabilises the regression target between syncs.
    next_q = np.max(target_model.predict(next_states, verbose=False), axis=1)

    td_targets = np.where(terminated.astype(bool), rewards, rewards + gamma * next_q)
    targets[np.arange(batch_size), actions.astype(int)] = td_targets

    model.fit(states, targets, verbose=False)
    


In [None]:
# Train the DQN:
#   - epsilon-greedy rollouts with one replay update per environment step;
#   - epsilon decay and periodic target-network sync once per episode.
MIN_EPSILON = 0.01  # floor so the agent never stops exploring entirely

# Best episode score so far. Scores may be negative, so start below any score.
best_so_far = float("-inf")

# NOTE(review): each stored transition holds two full WINDOW_LENGTH x IMG_SHAPE
# frame stacks. With replay_buffer's maxlen of 50000 this can need tens of GB;
# consider storing single frames or shrinking the buffer.

for episode in range(EPISODES):
    observation, _ = env.reset()

    # Fresh frame window for every episode: WINDOW_LENGTH copies of the first
    # frame, so frames from the previous episode never leak into this one.
    first_frame = process_observation(observation)
    image_buffer.clear()
    image_buffer.extend([first_frame] * WINDOW_LENGTH)
    state_stack = np.stack(list(image_buffer), axis=0)

    done = False
    truncated = False
    points = 0

    while not (done or truncated):
        action = epsilon_greedy_action_selection(model, epsilon, state_stack)

        next_observation, reward, done, truncated, info = env.step(action)

        # Slide the window forward by the frame just observed to form the next state.
        image_buffer.append(process_observation(next_observation))
        next_state_stack = np.stack(list(image_buffer), axis=0)

        # Learn from the clipped reward; the raw reward only feeds the score.
        replay_buffer.append(
            (state_stack, action, process_reward(reward), next_state_stack, done, truncated)
        )

        observation = next_observation
        state_stack = next_state_stack
        points += reward

        replay(replay_buffer, BATCH_SIZE, model, target_model)

    epsilon = max(epsilon * EPSILON_REDUCE, MIN_EPSILON)

    update_model_handler(episode, update_target_model, model, target_model)

    best_so_far = max(best_so_far, points)
    if episode % 25 == 0:
        print("Episode {0:} , Best so far: {1:}".format(episode, best_so_far))
        

  logger.warn(
  logger.warn(


Episode 0 , Best so far: 0
Episode 25 , Best so far: 0
Episode 50 , Best so far: 0
Episode 75 , Best so far: 0


In [None]:
# Greedy evaluation: run the trained network (no exploration) for a bounded
# number of steps.
MAX_TEST_STEPS = 300

test_env = gym.make("snake:snake-v0", render_mode="human")
state, _ = test_env.reset()

# Seed the frame window from the *test* episode's first frame, not from the
# stale training observation.
preprocessed_obs = process_observation(state)
for _ in range(WINDOW_LENGTH):
    image_buffer.append(preprocessed_obs)
image_buffer_numpy = np.stack(list(image_buffer), axis=0).reshape(
    1, WINDOW_LENGTH, IMG_SHAPE[0], IMG_SHAPE[1]
)

point = 0
done = False
truncated = False
for steps in range(MAX_TEST_STEPS):
    # verbose=0 keeps Keras from printing a progress bar on every step.
    action = int(np.argmax(model.predict(image_buffer_numpy, verbose=0)))

    state, reward, done, truncated, info = test_env.step(action)
    # Push the newly observed frame so the network actually sees the game advance.
    preprocessed_obs = process_observation(state)
    image_buffer.append(preprocessed_obs)
    image_buffer_numpy = np.stack(list(image_buffer), axis=0).reshape(
        1, WINDOW_LENGTH, IMG_SHAPE[0], IMG_SHAPE[1]
    )

    point += reward
    if done or truncated:
        print("done")
        break

print("Test score:", point)

# Close the evaluation window (previously left open) as well as the training env.
test_env.close()
env.close()