In [9]:
# allow to import modules from the project root directory
import sys
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

In [10]:
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm
import pickle
from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics
import matplotlib.pyplot as plt
import json

<img src="../../artifacts/images/Qlearning.png" alt="Q-learning algorithm" width="1000" />

In [11]:
class QLearningAgent:
    def __init__(self, env, learning_rate=0.5, discount_factor=0.99, epsilon=0.1, learning_rate_decay=0.8, epsilon_decay=0.9):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.learning_rate_decay = learning_rate_decay
        self.epsilon_decay = epsilon_decay
        self.q_table = defaultdict(lambda: np.zeros(env.action_space.n))

    def choose_action(self, state, Greedy=False):
        state = tuple(state.flatten())
        if Greedy:
            return np.argmax(self.q_table[state]), None
        if random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample(), None  # Explore
        else:
            return np.argmax(self.q_table[state]), None  # Exploit

    def update_q_value(self, state, action, reward, next_state):
        next_action = self.choose_action(next_state, Greedy=True)  # Choose next action using epsilon-greedy approach
        next_state = tuple(next_state.flatten())
        state = tuple(state.flatten())
        td_target = reward + self.discount_factor * self.q_table[next_state][next_action]
        td_error = td_target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * td_error

    def print_qtable(self):
        for state, actions in self.q_table.items():
            print(f"State: {state}")
            for action, value in enumerate(actions):
                print(f"  Action {action}: {value:.2f}")
            print()
    
    def save_table(self, table_path):
        with open(table_path, 'wb') as f:
            pickle.dump(dict(self.q_table), f)
    
    def load_table(self, table_path):
        with open(table_path, 'rb') as f:
            self.q_table = pickle.load(f)

    def train(self, num_episodes, save_plots=False, plots_path='plots.png'):
        self.episode_rewards = []
        self.steps_per_episode = []
        self.epsilon_values = []
        self.learning_rate_values = []
        self.average_q_updates = []

        for episode in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            state = self.env.reset()
            done = False
            total_reward = 0
            steps = 0
            q_update_magnitudes = []

            while not done:
                action, _ = self.choose_action(state)
                next_state, reward, done, _ = self.env.step(action)
                old_q_value = self.q_table[tuple(state.flatten())][action]
                self.update_q_value(state, action, reward, next_state)
                new_q_value = self.q_table[tuple(state.flatten())][action]
                q_update_magnitudes.append(abs(new_q_value - old_q_value))

                total_reward += reward
                state = next_state
                steps += 1

            # Logging metrics for the episode
            self.episode_rewards.append(total_reward)
            self.steps_per_episode.append(steps)
            self.epsilon_values.append(self.epsilon)
            self.learning_rate_values.append(self.learning_rate)
            self.average_q_updates.append(np.mean(q_update_magnitudes))

            # Decay epsilon and learning rate periodically
            if episode % 1000 == 0 and episode != 0:
                self.epsilon *= self.epsilon_decay
                self.learning_rate *= self.learning_rate_decay

        # Save plots after training
        if save_plots:
            self.save_plots(plots_path)

    def save_plots(self, plots_path):
        plots_dir = os.path.dirname(plots_path)
        os.makedirs(plots_dir, exist_ok=True)

        # Create a figure with subplots
        fig, axs = plt.subplots(3, 2, figsize=(15, 15))

        # Plot rewards
        axs[0, 0].plot(self.episode_rewards)
        axs[0, 0].set_xlabel('Episode')
        axs[0, 0].set_ylabel('Cumulative Reward')
        axs[0, 0].set_title('Rewards Over Episodes')

        # Plot steps per episode
        axs[0, 1].plot(self.steps_per_episode)
        axs[0, 1].set_xlabel('Episode')
        axs[0, 1].set_ylabel('Steps per Episode')
        axs[0, 1].set_title('Steps Over Episodes')

        # Plot epsilon values
        axs[1, 0].plot(self.epsilon_values)
        axs[1, 0].set_xlabel('Episode')
        axs[1, 0].set_ylabel('Epsilon')
        axs[1, 0].set_title('Epsilon Decay Over Episodes')

        # Plot learning rate values
        axs[1, 1].plot(self.learning_rate_values)
        axs[1, 1].set_xlabel('Episode')
        axs[1, 1].set_ylabel('Learning Rate')
        axs[1, 1].set_title('Learning Rate Decay Over Episodes')

        # Plot Q-value updates
        axs[2, 0].plot(self.average_q_updates)
        axs[2, 0].set_xlabel('Episode')
        axs[2, 0].set_ylabel('Average Q-Value Update')
        axs[2, 0].set_title('Q-Value Updates Over Episodes')

        # Hide the empty subplot (bottom right)
        fig.delaxes(axs[2, 1])

        # Adjust layout
        plt.tight_layout()

        # Save the figure
        plt.savefig(plots_path)
        plt.close()


In [12]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
agent = QLearningAgent(env, epsilon=0.1, discount_factor=0.9, learning_rate=0.9)

In [13]:
num_episodes = 15000
agent.train(num_episodes=num_episodes)

  self.q_table[state][action] += self.learning_rate * td_error
Training: 100%|██████████| 15000/15000 [00:35<00:00, 427.81Episode/s]


In [14]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

table_name = f'q_learning_table_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'q-learning')
os.makedirs(model_weights_dir, exist_ok=True)
table_path = os.path.join(model_weights_dir, table_name)

agent.save_table(table_path)

In [None]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'q-learning')
os.makedirs(model_metrics_dir, exist_ok=True)

train_metrics_name = f'q_learning_train_metrics_{environment}_{num_episodes}_{grid_size}.png'
train_metrics_path = os.path.join(model_metrics_dir, train_metrics_name)
agent.save_plots(train_metrics_path)

num_simulations = 100
sim_metrics_name = f'q_learning_sim_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.json'
sim_metrics_path = os.path.join(model_metrics_dir, sim_metrics_name)
compute_metrics(agent, env, sim_metrics_path, num_simulations=num_simulations)


 81%|████████  | 81/100 [00:00<00:00, 373.11it/s]

Snake length: 1, Episode reward: -49
Snake length: 40, Episode reward: 3450
Snake length: 30, Episode reward: 2505
Snake length: 25, Episode reward: 2037
Snake length: 24, Episode reward: 1969
Snake length: 22, Episode reward: 1839
Snake length: 5, Episode reward: 302
Snake length: 9, Episode reward: 617
Snake length: 15, Episode reward: 1169
Snake length: 21, Episode reward: 1687
Snake length: 16, Episode reward: 1244
Snake length: 18, Episode reward: 1435
Snake length: 24, Episode reward: 1913
Snake length: 25, Episode reward: 1979
Snake length: 17, Episode reward: 1311
Snake length: 15, Episode reward: 1133
Snake length: 24, Episode reward: 1908
Snake length: 5, Episode reward: 284
Snake length: 26, Episode reward: 2205
Snake length: 12, Episode reward: 930
Snake length: 19, Episode reward: 1572
Snake length: 15, Episode reward: 1141
Snake length: 29, Episode reward: 2436
Snake length: 19, Episode reward: 1493
Snake length: 16, Episode reward: 1215
Snake length: 12, Episode reward: 

100%|██████████| 100/100 [00:00<00:00, 377.99it/s]

Snake length: 38, Episode reward: 3181
Snake length: 33, Episode reward: 2774
Snake length: 26, Episode reward: 2146
Snake length: 18, Episode reward: 1427
Snake length: 13, Episode reward: 954
Snake length: 7, Episode reward: 452
Snake length: 23, Episode reward: 1887
Snake length: 9, Episode reward: 640
Snake length: 15, Episode reward: 1186
Snake length: 6, Episode reward: 352
Snake length: 7, Episode reward: 470
Snake length: 23, Episode reward: 1845
Snake length: 26, Episode reward: 2190
Snake length: 33, Episode reward: 2768
Snake length: 23, Episode reward: 1838
Snake length: 19, Episode reward: 1543
Snake length: 24, Episode reward: 1939
Snake length: 22, Episode reward: 1742
Snake length: 31, Episode reward: 2628





{'snake_lengths': [1,
  40,
  30,
  25,
  24,
  22,
  5,
  9,
  15,
  21,
  16,
  18,
  24,
  25,
  17,
  15,
  24,
  5,
  26,
  12,
  19,
  15,
  29,
  19,
  16,
  12,
  25,
  36,
  23,
  14,
  29,
  4,
  22,
  11,
  36,
  27,
  3,
  35,
  27,
  36,
  39,
  30,
  29,
  24,
  5,
  22,
  33,
  26,
  25,
  29,
  17,
  29,
  30,
  5,
  14,
  25,
  18,
  7,
  12,
  28,
  10,
  21,
  13,
  24,
  16,
  23,
  19,
  25,
  34,
  4,
  16,
  15,
  37,
  14,
  7,
  37,
  22,
  10,
  27,
  34,
  6,
  38,
  33,
  26,
  18,
  13,
  7,
  23,
  9,
  15,
  6,
  7,
  23,
  26,
  33,
  23,
  19,
  24,
  22,
  31],
 'episode_rewards': [-49,
  3450,
  2505,
  2037,
  1969,
  1839,
  302,
  617,
  1169,
  1687,
  1244,
  1435,
  1913,
  1979,
  1311,
  1133,
  1908,
  284,
  2205,
  930,
  1572,
  1141,
  2436,
  1493,
  1215,
  906,
  2031,
  3033,
  1909,
  1101,
  2394,
  171,
  1780,
  844,
  3044,
  2252,
  120,
  2993,
  2256,
  2934,
  3290,
  2525,
  2404,
  1995,
  283,
  1834,
  2784,
  2164,
  205

In [16]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
while not done:
    action, _ = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: