In [1]:
import os

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten

from scene import Scene

# Shared game environment used by every cell below (network input shape,
# state encoding depth, and the training loop all read from it).
# NOTE(review): the exact meaning of `using_cnn` / `init_randomly` is defined
# in scene.py, which is not visible here — confirm before changing the flags.
scene = Scene(using_cnn=True, init_randomly=True)

pygame 2.1.3 (SDL 2.0.22, Python 3.11.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Online Q-network: maps a one-hot encoded board of shape
# (scene.height, scene.width, scene.elements_count) to one Q-value per action.
model = keras.Sequential([
    # Only the first layer needs `input_shape`; later layers infer theirs.
    Conv2D(16, (3, 3), activation='relu', padding="same", kernel_initializer='he_normal',
           input_shape=(scene.height, scene.width, scene.elements_count)),
    Conv2D(16, (3, 3), activation='relu', padding="same", kernel_initializer='he_normal'),
    MaxPooling2D(2),
    Conv2D(32, (3, 3), activation='relu', padding="same", kernel_initializer='he_normal'),
    MaxPooling2D(2),
    Flatten(),
    Dense(16, activation='relu', kernel_initializer='he_normal'),
    # Linear head (no activation): Q-values are unbounded regression targets
    # of the form reward + discount * next_Q. A softmax here would squash the
    # four outputs into (0, 1) and force them to sum to 1, making it
    # impossible to fit negative or large returns.
    Dense(4)
])

# The custom training loop applies its own optimizer and MSE loss, so these
# settings only matter for the compiled state stored by `model.save`; they are
# kept consistent with a regression objective rather than classification.
model.compile(optimizer='adam', loss='mse')

# Target network: a structurally identical copy whose weights start equal to
# the online network's. `clone_model` re-initialises weights, hence the
# explicit `set_weights`.
target = keras.models.clone_model(model)
target.set_weights(model.get_weights())

In [3]:
def epsilon_greedy_policy(state, epsilon=0):
    """Choose an action index in [0, 4) with an epsilon-greedy rule.

    With probability `epsilon` a uniformly random action is returned;
    otherwise the action with the highest Q-value predicted by the global
    `model` for `state` is returned.

    Args:
        state: integer-coded board array; it is one-hot encoded to depth
            `scene.elements_count` before being fed to the network.
        epsilon: exploration probability. The default 0 is fully greedy.

    Returns:
        The selected action index.
    """
    if np.random.rand() < epsilon:
        return np.random.randint(4)

    # Add a batch axis, then one-hot encode the cell codes into channels.
    encoded_state = tf.one_hot(state[np.newaxis], scene.elements_count)
    # Call the model directly instead of `model.predict`: this runs once per
    # game step on a single sample, where predict()'s per-call batching
    # machinery is pure overhead (Keras recommends __call__ for small inputs
    # inside loops).
    Q_values = model(encoded_state, training=False).numpy()
    return np.argmax(Q_values[0])

In [4]:
from collections import deque

# Bounded FIFO buffer of transitions; each entry is the 5-tuple
# (state, action, reward, next_state, done). Once 2000 entries are stored,
# appending drops the oldest one.
replay_memory = deque(maxlen=2000)

def sample_experiences(batch_size):
    """Draw `batch_size` transitions uniformly at random, with replacement.

    Returns:
        Five numpy arrays (states, actions, rewards, next_states, dones),
        each built by stacking the corresponding field of the drawn
        transitions along a new leading axis.
    """
    picks = np.random.randint(len(replay_memory), size=batch_size)
    drawn = [replay_memory[pick] for pick in picks]

    # Regroup the list of 5-tuples into one array per tuple position.
    columns = []
    for field in range(5):
        columns.append(np.array([transition[field] for transition in drawn]))

    states, actions, rewards, next_states, dones = columns
    return states, actions, rewards, next_states, dones

In [5]:
def play_one_step(scene, state, epsilon):
    """Play a single epsilon-greedy move and record it in the replay memory.

    Args:
        scene: game environment; its snake is steered via
            `scene.snake.change_direction` and advanced via `scene.move()`.
        state: the observation the action is chosen from.
        epsilon: exploration probability forwarded to the policy.

    Returns:
        The `(next_state, reward, done)` triple produced by `scene.move()`.
    """
    chosen_action = epsilon_greedy_policy(state, epsilon)
    scene.snake.change_direction(chosen_action)

    step_result = scene.move()
    next_state, reward, done = step_result

    # Store the full transition so it can later be sampled for training.
    transition = (state, chosen_action, reward, next_state, done)
    replay_memory.append(transition)
    return next_state, reward, done

In [6]:
batch_size = 32        # transitions sampled per gradient step
discount_rate = 0.95   # factor applied to the bootstrapped next-state value
optimizer = keras.optimizers.legacy.Adam(learning_rate=1e-2)
loss_fn = keras.losses.mean_squared_error

def training_step(batch_size):
    """Perform one Double-DQN gradient update on the online network.

    A batch is sampled from the replay memory. The online `model` picks the
    best next action, the `target` network supplies that action's value, and
    the online network is regressed (MSE) towards
    `reward + (1 - done) * discount_rate * value`.

    Args:
        batch_size: number of transitions to sample for this update.
    """
    states, actions, rewards, next_states, dones = sample_experiences(batch_size)
    depth = scene.elements_count
    next_inputs = tf.one_hot(next_states, depth)

    # Action selection by the online network ...
    online_next_Q = model.predict(next_inputs, verbose=False)
    greedy_next_actions = np.argmax(online_next_Q, axis=1)
    greedy_mask = tf.one_hot(greedy_next_actions, 4).numpy()
    # ... action evaluation by the target network.
    target_next_Q = target.predict(next_inputs, verbose=False)
    bootstrap_values = (target_next_Q * greedy_mask).sum(axis=1)

    # Terminal transitions (done == 1) contribute the reward only.
    td_targets = (rewards + (1 - dones) * discount_rate * bootstrap_values).reshape(-1, 1)

    taken_mask = tf.one_hot(actions, 4)
    with tf.GradientTape() as tape:
        predicted_Q = model(tf.one_hot(states, depth))
        # Keep only the Q-value of the action that was actually taken.
        taken_Q = tf.reduce_sum(predicted_Q * taken_mask, axis=1, keepdims=True)
        loss = tf.reduce_mean(loss_fn(td_targets, taken_Q))

    weights = model.trainable_variables
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))

In [None]:
episodes_count = 8000
warmup_episodes = 50        # episodes played before the first gradient step
report_interval = 50        # episodes between target syncs / reward reports
checkpoint_interval = 1000  # episodes between saved model snapshots
reward_per_batch = 0

# Exploration schedule, computed once instead of once per episode:
# epsilon decays from 1 down to 8 ** (-1/3) = 0.5 over the whole run.
epsilon_schedule = 1 / (np.linspace(1, 8, episodes_count) ** (1 / 3))

# Make sure the checkpoint directory exists up front, so the first save does
# not fail after a thousand episodes of training.
os.makedirs("models", exist_ok=True)

for episode in range(1, episodes_count + 1):
    state = scene.scene_as_matrix()
    epsilon = epsilon_schedule[episode - 1]
    done = False

    # Play one full episode, accumulating reward for the periodic report.
    while not done:
        state, reward, done = play_one_step(scene, state, epsilon)
        reward_per_batch += reward

    # Let the replay memory fill during the warm-up before training starts.
    if episode > warmup_episodes:
        training_step(batch_size)

    # Report and reset every `report_interval` episodes, including during the
    # warm-up, so each printed average covers exactly that many episodes.
    # (Previously the first report summed 100 episodes but divided by 50.)
    # Syncing the target at the end of warm-up is a no-op: no training has
    # happened yet, so both networks still hold identical weights.
    if episode % report_interval == 0:
        print("Episode number: ", episode)
        target.set_weights(model.get_weights())
        print("Average reward: ", reward_per_batch / report_interval)
        reward_per_batch = 0

    if episode % checkpoint_interval == 0:
        model.save(f"models/CNN_model{episode}.h5")