In [1]:
# allow to import modules from the project root directory
import sys
import os
# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

In [2]:
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm
import pickle
from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics

<img src="../../artifacts/images/Qlearning.png" alt="Q-learning algorithm" width="1000" />

In [3]:
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm

class QLearningAgent:
    def __init__(self, env, learning_rate=0.5, discount_factor=0.99, epsilon=0.1, learning_rate_decay=0.8, epsilon_decay=0.9):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.learning_rate_decay = learning_rate_decay
        self.epsilon_decay = epsilon_decay
        self.q_table = defaultdict(lambda: np.zeros(env.action_space.n))

    def choose_action(self, state, Greedy=False):
        state = tuple(state.flatten())
        if Greedy:
            return np.argmax(self.q_table[state]), None
        if random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample(), None  # Explore
        else:
            return np.argmax(self.q_table[state]), None  # Exploit

    def update_q_value(self, state, action, reward, next_state):
        next_action = self.choose_action(next_state, Greedy=True)  # Choose next action using epsilon-greedy approach
        next_state = tuple(next_state.flatten())
        state = tuple(state.flatten())
        td_target = reward + self.discount_factor * self.q_table[next_state][next_action]
        td_error = td_target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * td_error

    def print_qtable(self):
        for state, actions in self.q_table.items():
            print(f"State: {state}")
            for action, value in enumerate(actions):
                print(f"  Action {action}: {value:.2f}")
            print()

    def train(self, num_episodes):
        for episode in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            state = self.env.reset()
            done = False
            while not done:
                action, _ = self.choose_action(state)
                next_state, reward, done, _ = self.env.step(action)
                self.update_q_value(state, action, reward, next_state)
                state = next_state
            # decaying    
            if episode % 1000 == 0 and episode != 0:
                self.epsilon *= self.epsilon_decay  # Decay epsilon
                self.learning_rate *= self.learning_rate_decay


In [4]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
agent = QLearningAgent(env, epsilon=0.1, discount_factor=0.9, learning_rate=0.9)

In [5]:
# Load the Q-table
# with open('q_learning_agent.pkl', 'rb') as f:
#     agent.q_table = defaultdict(lambda: np.zeros(env.action_space.n), pickle.load(f))

In [6]:
num_episodes = 50000
agent.train(num_episodes=num_episodes)

Training:   0%|          | 0/50000 [00:00<?, ?Episode/s]

  self.q_table[state][action] += self.learning_rate * td_error
Training: 100%|██████████| 50000/50000 [02:00<00:00, 413.52Episode/s]


In [14]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

table_name = f'q_learning_table_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'q-learning')
os.makedirs(model_weights_dir, exist_ok=True)
table_path = os.path.join(model_weights_dir, table_name)

with open(table_path, 'wb') as f:
        pickle.dump(dict(agent.q_table), f)

In [15]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

num_simulations = 100
metrics_name = f'q_learning_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.jsn'
model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'q-learning')
os.makedirs(model_metrics_dir, exist_ok=True)
metrics_path = os.path.join(model_metrics_dir, metrics_name)

compute_metrics(agent, env, metrics_path, num_simulations=num_simulations)
    

 53%|█████▎    | 53/100 [00:00<00:00, 264.17it/s]

Snake length: 35, Episode reward: 2952
Snake length: 18, Episode reward: 1426
Snake length: 29, Episode reward: 2457
Snake length: 26, Episode reward: 2095
Snake length: 17, Episode reward: 1287
Snake length: 27, Episode reward: 2157
Snake length: 11, Episode reward: 806
Snake length: 15, Episode reward: 1176
Snake length: 30, Episode reward: 2509
Snake length: 29, Episode reward: 2380
Snake length: 22, Episode reward: 1816
Snake length: 22, Episode reward: 1811
Snake length: 30, Episode reward: 2513
Snake length: 28, Episode reward: 2327
Snake length: 38, Episode reward: 3134
Snake length: 25, Episode reward: 2065
Snake length: 10, Episode reward: 702
Snake length: 39, Episode reward: 3352
Snake length: 22, Episode reward: 1846
Snake length: 35, Episode reward: 2898
Snake length: 24, Episode reward: 2000
Snake length: 30, Episode reward: 2542
Snake length: 38, Episode reward: 3246
Snake length: 16, Episode reward: 1288
Snake length: 29, Episode reward: 2498
Snake length: 17, Episode r

100%|██████████| 100/100 [00:00<00:00, 274.49it/s]

Snake length: 38, Episode reward: 3202
Snake length: 37, Episode reward: 3128
Snake length: 49, Episode reward: 4210
Snake length: 16, Episode reward: 1265
Snake length: 18, Episode reward: 1490
Snake length: 24, Episode reward: 1934
Snake length: 18, Episode reward: 1483
Snake length: 43, Episode reward: 3657
Snake length: 27, Episode reward: 2217
Snake length: 34, Episode reward: 2838
Snake length: 25, Episode reward: 2034
Snake length: 20, Episode reward: 1589
Snake length: 19, Episode reward: 1518
Snake length: 37, Episode reward: 3133
Snake length: 7, Episode reward: 478
Snake length: 20, Episode reward: 1564
Snake length: 46, Episode reward: 4009
Snake length: 28, Episode reward: 2259
Snake length: 25, Episode reward: 2076
Snake length: 17, Episode reward: 1317
Snake length: 13, Episode reward: 986
Snake length: 15, Episode reward: 1163
Snake length: 23, Episode reward: 1864
Snake length: 25, Episode reward: 2055
Snake length: 23, Episode reward: 1894
Snake length: 14, Episode re




{'snake_lengths': [35,
  18,
  29,
  26,
  17,
  27,
  11,
  15,
  30,
  29,
  22,
  22,
  30,
  28,
  38,
  25,
  10,
  39,
  22,
  35,
  24,
  30,
  38,
  16,
  29,
  17,
  9,
  34,
  18,
  42,
  29,
  45,
  33,
  17,
  43,
  20,
  23,
  26,
  24,
  32,
  35,
  42,
  23,
  29,
  28,
  36,
  33,
  25,
  27,
  31,
  20,
  24,
  25,
  32,
  15,
  33,
  31,
  31,
  38,
  37,
  49,
  16,
  18,
  24,
  18,
  43,
  27,
  34,
  25,
  20,
  19,
  37,
  7,
  20,
  46,
  28,
  25,
  17,
  13,
  15,
  23,
  25,
  23,
  14,
  32,
  23,
  35,
  26,
  17,
  23,
  25,
  4,
  38,
  13,
  19,
  2,
  31,
  14,
  28,
  12],
 'episode_rewards': [2952,
  1426,
  2457,
  2095,
  1287,
  2157,
  806,
  1176,
  2509,
  2380,
  1816,
  1811,
  2513,
  2327,
  3134,
  2065,
  702,
  3352,
  1846,
  2898,
  2000,
  2542,
  3246,
  1288,
  2498,
  1318,
  621,
  2909,
  1393,
  3607,
  2435,
  3877,
  2675,
  1323,
  3670,
  1587,
  1857,
  2115,
  1946,
  2679,
  2955,
  3568,
  1874,
  2374,
  2349,
  3036,
  

In [9]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
while not done:
    action, _ = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: -1
Reward: 1
Reward: 1
R