In [None]:
### Imports ###
import gym
import numpy as np
import warnings
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from keras.models import Model
from keras.layers import Conv2D, Flatten, Input, Dense, AveragePooling2D
from keras.callbacks import Callback
from rl.agents import SARSAAgent
from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy

In [None]:
### Functions ###
class CustomModelCheckpoint(Callback):
    """Periodically save model weights at the end of training episodes.

    Meant to be passed to ``agent.fit(..., callbacks=[...])``; the training loop
    calls ``on_episode_end`` after every episode.

    Args:
        model: Object whose ``save_weights(path, overwrite=True)`` is called
            (in this notebook: the agent's Keras model).
        path: Format string with one integer placeholder for the episode
            index, e.g. ``"run_weights_{:d}.h5"``.
        interval: Save whenever ``episode % interval == 0`` (episode 0 included).
            Must be a positive integer.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """

    def __init__(self, model, path, interval):
        super().__init__()
        if interval < 1:
            # Fail fast here instead of a ZeroDivisionError deep inside training.
            raise ValueError(f"interval must be a positive integer, got {interval!r}")
        # Keep a dedicated reference: the training loop may call set_model() on its
        # callbacks and replace ``self.model`` with another object (keras-rl appears
        # to pass the agent itself — confirm), so checkpointing must not rely on it.
        self.checkpoint_model = model
        self.model = model
        self.path = path
        self.interval = interval

    def on_episode_end(self, episode, logs=None):
        """Save weights to ``path.format(episode)`` on every ``interval``-th episode."""
        if episode % self.interval == 0:
            self.checkpoint_model.save_weights(self.path.format(episode), overwrite=True)


def plot_history(data, title: str, smoothing: bool = False, smoothing_window: int = 100,
                 xlabel: str = 'Episode'):
    """Plot one training metric from a keras-rl ``history.history`` entry.

    Args:
        data: 1-D sequence of metric values (one value per history entry).
        title: Used as the figure title and the y-axis label.
        smoothing: If True, plot a moving average instead of the raw values.
        smoothing_window: Moving-average width. It is clamped to ``len(data)``,
            so a history shorter than the window still yields a real average
            rather than ``np.convolve``'s swapped-argument partial sums.
        xlabel: x-axis label (history entries are recorded per episode).
    """
    values = np.asarray(data, dtype=float)
    x = np.arange(len(values))

    if smoothing:
        if len(values) > 0:
            window = max(1, min(int(smoothing_window), len(values)))
            values = np.convolve(values, np.ones(window) / window, mode='valid')
            # Output i of a 'valid' convolution averages data[i : i + window]; plot it
            # at the window's last index so the curve lines up with the raw x-axis.
            x = np.arange(window - 1, window - 1 + len(values))
        title = title + " (Smoothed)"

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(x, values)
    ax.set_title(title)
    ax.set_ylabel(title)
    ax.set_xlabel(xlabel)
    ax.legend(['Train'], loc='upper left')
    plt.show()


def rgb_to_gray(rgb):
    """Collapse the last (colour) axis of an image tensor to a single luminance value.

    Uses the weights 0.2989 R + 0.5870 G + 0.1140 B (≈ ITU-R BT.601 luma). Only the
    first three channels are used, so an alpha channel is ignored.

    Integer input (e.g. uint8 pixels) is cast to float32 first; otherwise
    ``tf.tensordot`` would fail on the int/float dtype mismatch with the weights.
    Floating input keeps its dtype.

    Returns:
        A tensor with the same shape as ``rgb`` minus its last axis.
    """
    rgb = tf.convert_to_tensor(rgb)
    if not rgb.dtype.is_floating:
        rgb = tf.cast(rgb, tf.float32)
    luma_weights = tf.constant([0.2989, 0.5870, 0.1140], dtype=rgb.dtype)
    return tf.tensordot(rgb[..., :3], luma_weights, axes=1)


def build_model_cnn(states, actions):
    """Build the convolutional Q-network used by the SARSA agent.

    Args:
        states: Input shape without the batch axis. In this notebook it is
            ``(1, 15, 15, 3)``: a leading size-1 axis followed by an image.
        actions: Number of discrete actions; the network emits one linear
            Q-value per action.

    Returns:
        An uncompiled Keras functional ``Model``.
    """
    observation = Input(shape=states)

    # The first Conv2D sees the size-1 leading axis (presumably handled as an extra
    # batch dimension by Keras); that axis is squeezed away right after.
    features = Conv2D(64, (3, 3), activation='leaky_relu')(observation)
    features = tf.squeeze(features, axis=1)
    features = AveragePooling2D((2, 2))(features)
    features = Conv2D(64, (3, 3), activation='leaky_relu')(features)
    features = AveragePooling2D((2, 2))(features)
    features = Flatten()(features)

    # Fully connected head: three leaky-ReLU hidden layers, then a linear Q-value per action.
    for width in (1500, 500, 64):
        features = Dense(width, activation='leaky_relu')(features)
    q_values = Dense(actions, activation='linear')(features)

    return Model(inputs=observation, outputs=q_values)


# def build_model(states, actions):
#     inputs = Input(shape=(1,) + states)
#     x = Lambda(rgb_to_gray)(inputs)  # Apply rgb_to_gray function to input
#     #x = Flatten()(inputs)
#     x = Flatten()(x)
#     x = Dense(1000, activation='leaky_relu')(x)
#     x = Dense(400, activation='leaky_relu')(x)
#     x = Dense(200, activation='leaky_relu')(x)
#     x = Dense(actions, activation='leaky_relu')(x)
#     outputs = x
#     model = Model(inputs=inputs, outputs=outputs)
#     return model


def build_agent(model, actions, nb_steps=1000000, nb_steps_warmup=1):
    """Create a keras-rl SARSA agent with a linearly annealed epsilon-greedy policy.

    Args:
        model: Keras Q-network producing one Q-value per action.
        actions: Number of discrete actions (must match the model's output size).
        nb_steps: Number of steps over which epsilon is annealed linearly from
            1.0 down to 0.01.
        nb_steps_warmup: Warm-up step count passed to ``SARSAAgent``.

    Returns:
        An uncompiled ``SARSAAgent``; call ``agent.compile(optimizer, ...)`` before use.
    """
    # Explore fully at first (eps=1.0), decaying linearly to eps=0.01.
    # NOTE(review): value_test only matters if this policy is used in test mode;
    # SARSAAgent appears to use its own separate test_policy at test time — confirm.
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.01,
                                  value_test=.01, nb_steps=nb_steps)

    agent = SARSAAgent(model=model, policy=policy,
                       nb_actions=actions, nb_steps_warmup=nb_steps_warmup)
    return agent

In [None]:
# Silence every warning category for the rest of the kernel session.
warnings.simplefilter('ignore')

# Prefix for every weights/model file this training run writes.
run_name = "sarsa_cnn_9"

In [None]:
### Main ###

# Run configuration.
GRID_SIZE = (15, 15)
LEARNING_RATE = 0.0001
NB_TRAIN_STEPS = 10000000
CHECKPOINT_INTERVAL = 1000  # episodes between weight checkpoints

# Create the Snake environment
env = gym.make('snake-v0', n_foods=1, unit_size=1, unit_gap=0,
               grid_size=list(GRID_SIZE), snake_size=3, n_snakes=1)

# Model input: a leading size-1 axis, then a GRID_SIZE x 3-channel image.
# (Assumes unit_size=1 makes the observation one pixel per grid cell — TODO confirm.)
states = (1,) + tuple(GRID_SIZE) + (3,)
actions = env.action_space.n

# Build the SARSA agent
model = build_model_cnn(states, actions)
agent = build_agent(model, actions)
# Use `learning_rate`: recent tf.keras optimizers drop the deprecated `lr` keyword
# with only a logged warning, which would silently train at the default rate.
agent.compile(Adam(learning_rate=LEARNING_RATE), metrics=['mae'])

# Save the model weights every CHECKPOINT_INTERVAL episodes
checkpoint_callback = CustomModelCheckpoint(
    agent.model, path=run_name + "_weights_{:d}.h5", interval=CHECKPOINT_INTERVAL)

history = agent.fit(env, nb_steps=NB_TRAIN_STEPS, visualize=False, verbose=1,
                    nb_max_start_steps=1, log_interval=10000, callbacks=[checkpoint_callback])

# Save the final trained Keras model
agent.model.save(run_name + '.h5')

In [None]:
%matplotlib inline

print("gespeicherte Metriken:", history.history.keys())

# (metric key, moving-average window); None means plot the raw series.
plot_specs = [
    ('episode_reward', 100),
    ('episode_reward', None),
    ('nb_episode_steps', 500),
    ('nb_steps', None),
]
for metric_name, window in plot_specs:
    if window is None:
        plot_history(history.history[metric_name], metric_name, smoothing=False)
    else:
        plot_history(history.history[metric_name], metric_name, smoothing=True, smoothing_window=window)

In [None]:
%matplotlib qt

# Evaluate the freshly trained agent for three episodes with rendering enabled.
agent.test(env,
           nb_episodes=3,
           verbose=1,
           visualize=True)

In [None]:
%matplotlib qt


# Load checkpointed weights from an earlier run and watch the agent play.
# Separate names (load_run_name, test_env, loaded_model) keep this cell from
# clobbering the training cell's `run_name`, `env`, `model` and `agent`. Otherwise,
# re-running training afterwards would write checkpoints under the old run's prefix
# and overwrite the very files loaded here.
load_run_name = "sarsa_cnn_7"
checkpoint_episode = 50000

# Same settings as the training environment.
# NOTE(review): assumes sarsa_cnn_7 was trained with this config — confirm.
test_env = gym.make('snake-v0', n_foods=1, unit_size=1, unit_gap=0,
                    grid_size=[15, 15], snake_size=3, n_snakes=1)
test_env.reset()

loaded_states = (1,) + (15, 15, 3)
loaded_actions = test_env.action_space.n

# Only weights were checkpointed, so rebuild the architecture and load them into it.
loaded_model = build_model_cnn(loaded_states, loaded_actions)
loaded_model.load_weights(load_run_name + '_weights_{:d}.h5'.format(checkpoint_episode))

# Build an agent around the loaded model
new_agent = build_agent(loaded_model, loaded_actions)
new_agent.compile(Adam(learning_rate=0.001), metrics=['mae'])

# Run a trial with the new agent
new_agent.test(test_env, nb_episodes=5, visualize=True)