In [1]:
# Allow importing modules from the project root directory
import sys
import os
# Add the project root directory (two levels above the working directory) to sys.path.
# Guarded so that re-running this cell does not append duplicate entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '../..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm
import pickle
from src.ParObsSnakeEnv import ParObsSnakeEnv
from src.FullObsSnakeEnv import FullObsSnakeEnv

<img src="../../artifacts/images/Qlearning.png" alt="Q-learning algorithm" width="1000" />

In [3]:
import numpy as np
import random
from collections import defaultdict
from tqdm import tqdm

class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Q-values are stored in a ``defaultdict`` keyed by hashable states; an
    unseen state is lazily initialised to a zero vector with one entry per
    action (``env.action_space.n``).

    Args:
        env: Environment exposing ``action_space.n``, ``action_space.sample()``,
            ``reset()`` and ``step(action)``.
        learning_rate: Step size applied to the TD error.
        discount_factor: Discount applied to the bootstrapped next-state value.
        epsilon: Probability of sampling a random action while exploring.
        learning_rate_decay: Multiplicative factor applied to ``learning_rate``
            at every decay step during training.
        epsilon_decay: Multiplicative factor applied to ``epsilon`` at every
            decay step during training.
    """

    def __init__(self, env, learning_rate=0.5, discount_factor=0.99, epsilon=0.1, learning_rate_decay=0.8, epsilon_decay=0.9):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.learning_rate_decay = learning_rate_decay
        self.epsilon_decay = epsilon_decay
        # Unseen states start with an all-zero action-value vector.
        self.q_table = defaultdict(lambda: np.zeros(env.action_space.n))

    def choose_action(self, state, Greedy=False):
        """Return an action for ``state``.

        Args:
            state: Hashable state key.
            Greedy: When true, always return the highest-valued action;
                otherwise act epsilon-greedily.

        Returns:
            The selected action index. Ties between equal Q-values resolve to
            the lowest index (``np.argmax`` semantics).
        """
        if not Greedy and random.random() < self.epsilon:
            return self.env.action_space.sample()  # Explore
        return np.argmax(self.q_table[state])  # Exploit

    def update_q_value(self, state, action, reward, next_state, done=False):
        """Apply one Q-learning update for the transition that was just observed.

        Args:
            state: Hashable key of the state the action was taken in.
            action: Index of the action taken.
            reward: Reward received for the transition.
            next_state: Hashable key of the resulting state.
            done: Whether the transition ended the episode. When true the
                target is the reward alone, because no further return follows
                a terminal state. Defaults to ``False``, which reproduces the
                always-bootstrapping update for callers that do not pass it.
        """
        if done:
            # Do not bootstrap from a terminal state: its table entry may be
            # shared with a non-terminal state that produces the same key.
            td_target = reward
        else:
            # Off-policy target: value of the greedy action in the next state.
            td_target = reward + self.discount_factor * np.max(self.q_table[next_state])
        td_error = td_target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * td_error

    def print_qtable(self):
        """Print every stored state followed by its per-action Q-values."""
        for state, actions in self.q_table.items():
            print(f"State: {state}")
            for action, value in enumerate(actions):
                print(f"  Action {action}: {value:.2f}")
            print()

    def train(self, num_episodes, decay_interval=10000):
        """Run Q-learning for ``num_episodes`` episodes.

        Args:
            num_episodes: Number of episodes to play.
            decay_interval: Every this many episodes (skipping episode 0),
                ``epsilon`` and ``learning_rate`` are multiplied by their decay
                factors. A falsy value disables decay.

        The environment is expected to return an object with ``.flatten()``
        from ``reset()`` and a 4-tuple ``(next_state, reward, done, info)``
        from ``step()``.
        """
        for episode in tqdm(range(num_episodes), desc='Training', unit='Episode'):
            state = self.env.reset()
            state = tuple(state.flatten())  # Convert state to a tuple for dictionary key
            done = False
            while not done:
                action = self.choose_action(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = tuple(next_state.flatten())  # Convert state to a tuple for dictionary key
                # NOTE(review): `done` is treated as true termination; if the
                # environment also sets it on a step limit, confirm that
                # dropping the bootstrap term is desired in that case.
                self.update_q_value(state, action, reward, next_state, done)
                state = next_state
            if decay_interval and episode % decay_interval == 0 and episode != 0:
                self.epsilon *= self.epsilon_decay  # Decay epsilon
                self.learning_rate *= self.learning_rate_decay


In [None]:
# Grid size passed to the environment constructor (and reused in later cells).
grid_size = 10
# env = FullObsSnakeEnv(grid_size=grid_size, interact=False)
env = ParObsSnakeEnv(interact=False, grid_size=grid_size)
# Hyperparameters given explicitly here override the QLearningAgent defaults.
agent = QLearningAgent(
    env,
    learning_rate=0.9,
    discount_factor=0.9,
    epsilon=0.1,
)

In [5]:
# Load the Q-table
# with open('q_learning_agent.pkl', 'rb') as f:
#     agent.q_table = defaultdict(lambda: np.zeros(env.action_space.n), pickle.load(f))

In [None]:
# Number of training episodes; the value is also embedded in the saved Q-table's file name.
num_episodes = 50_000
agent.train(num_episodes)

Training:   2%|▏         | 202423/10000000 [04:39<3:45:13, 725.04Episode/s]


KeyboardInterrupt: 

In [None]:
# Tag the output file with the kind of environment the agent was trained on.
# (The previous literal 'full ' carried a trailing space that leaked into the file name.)
environment = 'full' if isinstance(env, FullObsSnakeEnv) else 'part'

# Store the Q-table as a plain dict: the defaultdict's lambda factory cannot be pickled.
with open(f'q_learning_table_{environment}_{num_episodes}_{grid_size}.pkl', 'wb') as f:
    pickle.dump(dict(agent.q_table), f)

In [None]:
# Choose the environment for the rollout: a partially observed agent gets a
# fresh environment with twice the grid size, while a fully observed one keeps
# its environment and only has `interact` switched on.
if not isinstance(env, FullObsSnakeEnv):
    env = ParObsSnakeEnv(grid_size=2*grid_size)
else:
    env.interact = True

# Play a single episode, rendering each step and printing its reward.
# NOTE(review): choose_action is called without Greedy=True, so epsilon-random
# actions can still occur during this rollout — confirm that is intended.
state = env.reset()
while True:
    action = agent.choose_action(tuple(state.flatten()))
    state, reward, done, _ = env.step(action)
    env.render()
    print(f"Reward: {reward}")
    if done:
        break

Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 76
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 76
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: -1
Reward: 1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: -1
Reward: 1
Reward: -1
Reward: -75
