In [1]:
from SnakeEnv import SnakeEnv
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm

In [2]:
class SarsaAgent:
    """Tabular SARSA agent with an epsilon-greedy behaviour policy.

    Q-values are stored in a ``defaultdict`` keyed by state; each entry is a
    NumPy array holding one value per action (``env.action_space.n`` slots),
    created zero-filled the first time a state is looked up. States therefore
    have to be hashable (``train`` converts observations to flat tuples).

    Args:
        env: Environment object. This class reads ``env.action_space.n``,
            and calls ``env.action_space.sample()``, ``env.reset()`` and
            ``env.step(action)``.
        learning_rate: Step size applied to the TD error.
        discount_factor: Weight of the bootstrapped next-step value.
        epsilon: Probability of taking a random (exploratory) action.
    """

    def __init__(self, env, learning_rate=0.1, discount_factor=0.99, epsilon=0.1):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.q_table = defaultdict(lambda: np.zeros(env.action_space.n))

    def choose_action(self, state):
        """Pick an action for ``state`` using the epsilon-greedy rule.

        With probability ``epsilon`` the action comes from
        ``env.action_space.sample()``; otherwise it is the argmax of the
        state's Q-values (ties resolve to the lowest index, as ``np.argmax``
        does).
        """
        if random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample()  # Explore
        return np.argmax(self.q_table[state])  # Exploit

    def update_q_value(self, state, action, reward, next_state, next_action, done=False):
        """Apply one SARSA update to ``Q(state, action)``.

        Args:
            state: Hashable key of the state the action was taken in.
            action: Index of the action that was taken.
            reward: Reward received for the transition.
            next_state: Hashable key of the resulting state.
            next_action: Action selected for ``next_state``. Ignored when
                ``done`` is true.
            done: Whether the transition ended the episode. When true the
                target is just ``reward``: there is no future return to
                bootstrap from, and ``next_state`` is not looked up (so it is
                not inserted into the table either).
        """
        if done:
            td_target = reward
        else:
            td_target = reward + self.discount_factor * self.q_table[next_state][next_action]
        td_error = td_target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * td_error

    def train(self, num_episodes):
        """Run ``num_episodes`` episodes of on-policy SARSA on ``self.env``.

        Each observation returned by the environment is flattened into a tuple
        so it can be used as a ``q_table`` key. Progress is shown with tqdm.
        """
        for _episode in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            state = tuple(self.env.reset().flatten())  # hashable dictionary key
            action = self.choose_action(state)
            done = False
            while not done:
                next_state, reward, done, _ = self.env.step(action)
                next_state = tuple(next_state.flatten())  # hashable dictionary key
                # No action is needed in a terminal state, and the update must
                # not bootstrap from it.
                next_action = None if done else self.choose_action(next_state)
                self.update_q_value(state, action, reward, next_state, next_action, done=done)
                state, action = next_state, next_action
                

In [4]:
# Demo: train a SARSA agent on a 10x10 Snake grid, then replay one episode.
if __name__ == "__main__":
    env = SnakeEnv(grid_size=10)
    agent = SarsaAgent(env)
    agent.train(num_episodes=10000)

    # Replay a single episode, drawing the board and printing the reward after
    # every step. Actions come from agent.choose_action, which is
    # epsilon-greedy, so the replay can still contain random moves.
    state = env.reset()
    while True:
        action = agent.choose_action(tuple(state.flatten()))
        state, reward, done, _ = env.step(action)
        env.render()
        print(f"Reward: {reward}")
        if done:
            break

Training: 100%|██████████| 10000/10000 [00:08<00:00, 1176.81Episode/s]

. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . S . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. F . . . . . . . . 
Reward: 0
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . S . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. F . . . . . . . . 
Reward: 0
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . S . . 
. . . . . . . . . . 
. . . . . . . . . . 
. F . . . . . . . . 
Reward: 0
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . S . . 
. . . . . . . . . . 
. F . . . . . . . . 
Reward: 0
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . .


