In [1]:
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm
import pickle
# allow to import modules from the project root directory
import sys
import os

# Add the project root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv
from src.utils import compute_metrics

<img src="../../artifacts/images/SARSA.png" alt="SARSA algorithm" width="500" />

In [None]:
class SarsaAgent:
    def __init__(self, env, learning_rate=0.5, discount_factor=0.99, epsilon=0.1, learning_rate_decay=0.8, epsilon_decay=0.9):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.learning_rate_decay = learning_rate_decay
        self.epsilon_decay = epsilon_decay
        self.q_table = defaultdict(lambda: np.zeros(env.action_space.n))

    def choose_action(self, state):
        state = tuple(state.flatten())
        if random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample(), None  # Explore
        else:
            return np.argmax(self.q_table[state]), None  # Exploit

    def update_q_value(self, state, action, reward, next_state, next_action):
        next_state = tuple(next_state.flatten())
        state = tuple(state.flatten())
        td_target = reward + self.discount_factor * self.q_table[next_state][next_action]
        td_error = td_target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * td_error

    def train(self, num_episodes):
        for episode in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            state = self.env.reset()
            action = self.choose_action(state)
            done = False
            while not done:
                next_state, reward, done, _ = self.env.step(action)
                next_action, _ = self.choose_action(next_state)
                self.update_q_value(state, action, reward, next_state, next_action)
                state = next_state
                action = next_action
            # Decaying    
            if episode % 1000 == 0 and episode != 0:
                self.epsilon *= self.epsilon_decay  # Decay epsilon
                self.learning_rate *= self.learning_rate_decay
                

In [3]:
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(grid_size=grid_size, interact=False)
agent = SarsaAgent(env, epsilon=0.1, discount_factor=0.9, learning_rate=0.9)

In [4]:
num_episodes = 50000
agent.train(num_episodes=num_episodes)

Training: 100%|██████████| 50000/50000 [01:08<00:00, 732.91Episode/s]


In [5]:
environment = 'full 'if isinstance(env, FullObsSnakeEnv) else 'par'

table_name = f'sarsa_table_{environment}_{num_episodes}_{grid_size}.pkl'
model_weights_dir = os.path.join('../..', 'models', 'sarsa')
os.makedirs(model_weights_dir, exist_ok=True)
table_path = os.path.join(model_weights_dir, table_name)

with open(table_path, 'wb') as f:
        pickle.dump(dict(agent.q_table), f)

In [6]:
if isinstance(env, ParObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size, interact=False)

num_simulations = 100
metrics_name = f'sarsa_metrics_{environment}_{num_episodes}_{env.grid_size}_{num_simulations}.jsn'
model_metrics_dir = os.path.join('../..', 'artifacts', 'models_stats', 'sarsa')
os.makedirs(model_metrics_dir, exist_ok=True)
metrics_path = os.path.join(model_metrics_dir, metrics_name)

compute_metrics(agent, env, metrics_path, num_simulations=num_simulations)

 51%|█████     | 51/100 [00:00<00:00, 254.42it/s]

Snake length: 25, Episode reward: 1992
Snake length: 27, Episode reward: 2253
Snake length: 26, Episode reward: 2160
Snake length: 20, Episode reward: 1609
Snake length: 35, Episode reward: 2943
Snake length: 30, Episode reward: 2499
Snake length: 26, Episode reward: 2204
Snake length: 30, Episode reward: 2475
Snake length: 16, Episode reward: 1249
Snake length: 28, Episode reward: 2361
Snake length: 33, Episode reward: 2777
Snake length: 22, Episode reward: 1755
Snake length: 20, Episode reward: 1589
Snake length: 44, Episode reward: 3774
Snake length: 40, Episode reward: 3312
Snake length: 28, Episode reward: 2298
Snake length: 28, Episode reward: 2323
Snake length: 14, Episode reward: 1065
Snake length: 18, Episode reward: 1401
Snake length: 10, Episode reward: 707
Snake length: 23, Episode reward: 1809
Snake length: 17, Episode reward: 1316
Snake length: 24, Episode reward: 1958
Snake length: 34, Episode reward: 2779
Snake length: 43, Episode reward: 3729
Snake length: 25, Episode 

100%|██████████| 100/100 [00:00<00:00, 267.71it/s]

Snake length: 19, Episode reward: 1500
Snake length: 30, Episode reward: 2477
Snake length: 33, Episode reward: 2758
Snake length: 23, Episode reward: 1845
Snake length: 21, Episode reward: 1698
Snake length: 24, Episode reward: 1952
Snake length: 35, Episode reward: 2845
Snake length: 39, Episode reward: 3285
Snake length: 41, Episode reward: 3485
Snake length: 21, Episode reward: 1743
Snake length: 40, Episode reward: 3395
Snake length: 27, Episode reward: 2210
Snake length: 11, Episode reward: 806
Snake length: 46, Episode reward: 3984
Snake length: 18, Episode reward: 1435
Snake length: 26, Episode reward: 2153
Snake length: 17, Episode reward: 1348
Snake length: 24, Episode reward: 1963
Snake length: 25, Episode reward: 2031
Snake length: 20, Episode reward: 1579
Snake length: 37, Episode reward: 3158
Snake length: 22, Episode reward: 1790
Snake length: 24, Episode reward: 1948
Snake length: 32, Episode reward: 2645
Snake length: 35, Episode reward: 2967
Snake length: 11, Episode 




{'snake_lengths': [25,
  27,
  26,
  20,
  35,
  30,
  26,
  30,
  16,
  28,
  33,
  22,
  20,
  44,
  40,
  28,
  28,
  14,
  18,
  10,
  23,
  17,
  24,
  34,
  43,
  25,
  14,
  8,
  17,
  29,
  35,
  38,
  21,
  25,
  47,
  38,
  27,
  21,
  12,
  24,
  17,
  26,
  32,
  28,
  43,
  20,
  18,
  30,
  35,
  32,
  34,
  27,
  20,
  16,
  20,
  15,
  36,
  12,
  24,
  42,
  35,
  16,
  23,
  31,
  17,
  18,
  44,
  47,
  17,
  35,
  19,
  30,
  33,
  23,
  21,
  24,
  35,
  39,
  41,
  21,
  40,
  27,
  11,
  46,
  18,
  26,
  17,
  24,
  25,
  20,
  37,
  22,
  24,
  32,
  35,
  11,
  39,
  18,
  18,
  20],
 'episode_rewards': [1992,
  2253,
  2160,
  1609,
  2943,
  2499,
  2204,
  2475,
  1249,
  2361,
  2777,
  1755,
  1589,
  3774,
  3312,
  2298,
  2323,
  1065,
  1401,
  707,
  1809,
  1316,
  1958,
  2779,
  3729,
  2028,
  1025,
  515,
  1326,
  2422,
  2975,
  3236,
  1687,
  2038,
  4072,
  3100,
  2189,
  1705,
  895,
  1948,
  1334,
  2143,
  2606,
  2323,
  3601,
  1589,

In [9]:
if isinstance(env, FullObsSnakeEnv):
    env.interact = True
else:
    env = ParObsSnakeEnv(grid_size=2*grid_size)
    
state = env.reset()
done = False
while not done:
    action, _ = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Rew