# Tabular Q-Learning Training Notebook

This notebook trains a **Tabular Q-Learning** agent on the Snake game.

**Grid Size:** 5×5 (optimal for tabular methods)

**Expected Performance:**
- Episodes to convergence: 500-1000
- Final average score: 8-12 apples
- Training time: 1-2 minutes

---

## 1. Setup and Imports

In [None]:
# For Google Colab
# !git clone https://github.com/MarinCervinschi/rl-snake.git
# %cd rl-snake

In [None]:
# Add parent directory to path for imports
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm


from game.config import GameConfig
from game.engine import SnakeGameEngine
from utils.metrics import TrainingMetrics

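# For Google Colab: the repository must be cloned BEFORE the imports above can
# succeed, so run these in a cell at the top of the notebook:
# !git clone https://github.com/MarinCervinschi/rl-snake.git
# %cd rl-snake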

## 2. Tabular Q-Learning Agent Implementation

In [None]:
import pickle
import random
from typing import Dict, Tuple
from game.entities import Action, State


class QLearningAgent:
    """
    Tabular Q-Learning agent backed by a dictionary Q-table.

    The table maps a hashable state key (see ``_state_to_key``) to a NumPy
    array holding one Q-value per action. States are added lazily, with all
    Q-values initialised to zero, the first time they are looked up.

    Exploration is epsilon-greedy; epsilon is multiplied by ``epsilon_decay``
    at the end of every episode and never drops below ``min_epsilon``.
    """

    # Length of every Q-table row (one entry per action).
    # NOTE(review): assumes ``Action`` has exactly 4 members whose values are
    # 0..3, since ``action.value`` is used as the row index — confirm against
    # game.entities.
    NUM_ACTIONS = 4

    def __init__(
        self,
        grid_size: int = 5,
        learning_rate: float = 0.1,
        discount_factor: float = 0.99,
        epsilon: float = 1.0,
        epsilon_decay: float = 0.995,
        min_epsilon: float = 0.01,
    ):
        """
        Args:
            grid_size: Side length of the board; stored and saved with the
                model but not otherwise used by the agent.
            learning_rate: Step size (alpha) of the Q-value update.
            discount_factor: Discount (gamma) applied to the next state's
                best Q-value.
            epsilon: Initial probability of picking a random action.
            epsilon_decay: Factor epsilon is multiplied by after each episode.
            min_epsilon: Lower bound for epsilon.
        """
        self.grid_size = grid_size
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon

        # Q-table: dictionary mapping state_key -> [Q(s,a) for each action]
        self.q_table: Dict[Tuple, np.ndarray] = {}
        self.updates_performed = 0

    def _q_values(self, state_key: Tuple) -> np.ndarray:
        """Return the Q-value row for ``state_key``, creating a zero row on first use."""
        row = self.q_table.get(state_key)
        if row is None:
            row = np.zeros(self.NUM_ACTIONS, dtype=np.float32)
            self.q_table[state_key] = row
        return row

    def get_action(self, state: "State") -> "Action":
        """
        Select an action using an epsilon-greedy policy.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the action with the highest Q-value for ``state``
        is returned (``np.argmax`` picks the lowest index on ties). The state
        is registered in the Q-table if it has not been seen before.
        """
        q_values = self._q_values(self._state_to_key(state))

        if random.random() < self.epsilon:
            return random.choice(list(Action))

        # Convert the NumPy integer to a plain int before the enum lookup.
        return Action(int(np.argmax(q_values)))

    def train(
        self,
        state: "State",
        action: "Action",
        reward: float,
        next_state: "State",
        done: bool,
    ) -> None:
        """
        Apply one Q-learning update for the transition (s, a, r, s').

        Q(s,a) <- Q(s,a) + alpha * (target - Q(s,a)), where the target is
        ``reward`` for a terminal transition and
        ``reward + gamma * max_a' Q(s',a')`` otherwise.

        When ``done`` is true, epsilon is also decayed once (clamped at
        ``min_epsilon``), so the exploration schedule advances per episode.
        """
        q_values = self._q_values(self._state_to_key(state))
        action_idx = action.value
        current_q = q_values[action_idx]

        if done:
            # Terminal transition: there is no future return to bootstrap
            # from, so the next state is not looked up (and therefore does not
            # get a never-used row in the Q-table).
            target_q = reward
        else:
            next_q_values = self._q_values(self._state_to_key(next_state))
            target_q = reward + self.discount_factor * np.max(next_q_values)

        q_values[action_idx] = current_q + self.learning_rate * (
            target_q - current_q
        )

        self.updates_performed += 1

        if done:
            # Clamp so a decay step can never push epsilon below the floor.
            self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

    def _state_to_key(self, state: "State") -> Tuple:
        """Convert State to hashable key for dictionary."""
        return state.to_position_tuple()

    def save(self, filepath: str = "models/tabular_q_learning.pkl") -> None:
        """
        Pickle the Q-table and the agent's hyperparameters to ``filepath``.

        Missing parent directories are created. A short confirmation with the
        number of stored states is printed.
        """
        save_path = Path(filepath)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        save_dict = {
            "q_table": self.q_table,
            "grid_size": self.grid_size,
            "epsilon": self.epsilon,
            "epsilon_decay": self.epsilon_decay,
            "min_epsilon": self.min_epsilon,
            "learning_rate": self.learning_rate,
            "discount_factor": self.discount_factor,
            "updates_performed": self.updates_performed,
        }

        with open(save_path, "wb") as f:
            pickle.dump(save_dict, f)

        print(f"üíæ Model saved to {filepath}")
        print(f"   Q-table size: {len(self.q_table):,} entries")

    def load(self, filepath: str = "models/tabular_q_learning.pkl") -> None:
        """
        Restore the Q-table and hyperparameters written by ``save``.

        If ``filepath`` does not exist a warning is printed and the agent is
        left unchanged. Fields that older save files do not contain
        (``epsilon_decay``, ``min_epsilon``, ``updates_performed``) fall back
        to the agent's current value, or 0 for the update counter.
        """
        load_path = Path(filepath)

        if not load_path.exists():
            print(f"‚ö†Ô∏è  No saved model found at {filepath}")
            return

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load model files that come from a trusted source.
        with open(load_path, "rb") as f:
            save_dict = pickle.load(f)

        self.q_table = save_dict["q_table"]
        self.grid_size = save_dict["grid_size"]
        self.epsilon = save_dict["epsilon"]
        self.epsilon_decay = save_dict.get("epsilon_decay", self.epsilon_decay)
        self.min_epsilon = save_dict.get("min_epsilon", self.min_epsilon)
        self.learning_rate = save_dict["learning_rate"]
        self.discount_factor = save_dict["discount_factor"]
        self.updates_performed = save_dict.get("updates_performed", 0)

        print(f"‚úÖ Model loaded from {filepath}")
        print(f"   States in Q-table: {len(self.q_table):,}")

## 3. Configuration

In [None]:
# --- Run settings ---
# Board side length and number of training episodes.
GRID_SIZE, EPISODES = 5, 10_000

# --- Q-learning hyperparameters ---
# Step size (alpha) and discount (gamma) of the Q-value update.
LEARNING_RATE, DISCOUNT_FACTOR = 0.1, 0.99
# Exploration schedule: initial epsilon, multiplicative decay, lower bound.
EPSILON_START, EPSILON_DECAY, EPSILON_MIN = 1.0, 0.995, 0.01

# --- Output locations ---
RESULTS_DIR = f"results/tabular_{GRID_SIZE}x{GRID_SIZE}"
MODEL_PATH = "models/tabular_q_learning.pkl"

# Echo the chosen setup so it is recorded in the notebook output.
print(f"üéÆ Training Tabular Q-Learning on {GRID_SIZE}√ó{GRID_SIZE} grid")
print(f"üìà Episodes: {EPISODES:,}")
print(f"üß† Hyperparameters: Œ±={LEARNING_RATE}, Œ≥={DISCOUNT_FACTOR}")

## 4. Initialize Environment and Agent

In [None]:
# Snake environment configured for a GRID_SIZE board.
config = GameConfig(grid_size=GRID_SIZE)
game = SnakeGameEngine(config)

# Learner: every hyperparameter comes from the configuration cell.
agent_settings = {
    "grid_size": GRID_SIZE,
    "learning_rate": LEARNING_RATE,
    "discount_factor": DISCOUNT_FACTOR,
    "epsilon": EPSILON_START,
    "epsilon_decay": EPSILON_DECAY,
    "min_epsilon": EPSILON_MIN,
}
agent = QLearningAgent(**agent_settings)

# Per-episode statistics are collected here, with RESULTS_DIR as save directory.
metrics = TrainingMetrics(save_dir=RESULTS_DIR)

print("‚úÖ Environment and agent initialized")

## 5. Training Loop

In [None]:
# Best score seen in any episode so far.
record_score = 0
pbar = tqdm(range(1, EPISODES + 1), desc="Training")

try:
    for episode in pbar:
        # Fresh game; per-episode counters start from zero.
        state = game.reset()
        episode_reward = 0
        steps = 0
        done = False

        # Play until the environment reports the episode is over. At least
        # one step is always taken.
        while True:
            action = agent.get_action(state)

            # Advance the game, then read back the resulting state.
            reward, done, score = game.step(action)
            next_state = game.get_state()

            # One Q-learning update per transition.
            agent.train(state, action, reward, next_state, done)

            episode_reward += reward
            steps += 1
            state = next_state

            if done:
                break

        metrics.record_episode(episode, score, steps, episode_reward)
        record_score = max(record_score, score)

        # Live readout next to the progress bar.
        recent_average = metrics.get_recent_average_score()
        pbar.set_postfix(
            {
                "Avg Score": f"{recent_average:.2f}",
                "Best": record_score,
                "Œµ": f"{agent.epsilon:.3f}",
                "Q-table": len(agent.q_table),
            }
        )
except KeyboardInterrupt:
    # Stopping the cell by hand keeps whatever the agent has learned so far.
    print("\n\n‚ö†Ô∏è  Training interrupted by user")

print("\n‚úÖ Training complete!")

## 6. Save Model

In [None]:
# Pickle the trained agent to MODEL_PATH. agent.save() creates missing parent
# directories and prints its own confirmation; the print below repeats the
# path as a notebook-level confirmation.
agent.save(MODEL_PATH)
print(f"\n‚úÖ Model saved to: {MODEL_PATH}")

## 7. Training Results and Visualization

In [None]:
# Print summary statistics for the episodes recorded during training
# (the exact contents are defined by TrainingMetrics.print_summary).
metrics.print_summary()

In [None]:
# Generate and display the training plots without writing them to disk.
# NOTE(review): plot(show=True) presumably already displays the figure, which
# would make the plt.show() below a harmless no-op — confirm in utils.metrics.
metrics.plot(show=True, save=False)
plt.show()

## 8. Analyze Q-Table

In [None]:
# Size of the learned table and how much training went into it.
visited_states = len(agent.q_table)
print(f"Q-table Statistics:")
print(f"  Total states visited: {visited_states:,}")
print(f"  Updates performed: {agent.updates_performed:,}")
print(f"  Final epsilon: {agent.epsilon:.4f}")

# Flatten every per-state row into one list of individual Q-values
# (also reused by the histogram cell).
all_q_values = [value for row in agent.q_table.values() for value in row]

# Summary statistics over all (state, action) entries.
q_mean = np.mean(all_q_values)
q_std = np.std(all_q_values)
q_min = np.min(all_q_values)
q_max = np.max(all_q_values)

print(f"\nQ-value Distribution:")
print(f"  Mean: {q_mean:.2f}")
print(f"  Std: {q_std:.2f}")
print(f"  Min: {q_min:.2f}")
print(f"  Max: {q_max:.2f}")

In [None]:
# Two side-by-side histograms: every Q-value (left) and the best Q-value of
# each visited state (right).
max_q_per_state = [np.max(row) for row in agent.q_table.values()]

# (subplot position, data, bar colour, x-axis label, title)
histogram_panels = (
    (1, all_q_values, 'blue', 'Q-value', 'Q-value Distribution'),
    (2, max_q_per_state, 'green', 'Max Q-value', 'Max Q-value per State Distribution'),
)

plt.figure(figsize=(10, 5))

for position, values, bar_color, x_label, panel_title in histogram_panels:
    plt.subplot(1, 2, position)
    plt.hist(values, bins=50, alpha=0.7, color=bar_color)
    plt.xlabel(x_label)
    plt.ylabel('Frequency')
    plt.title(panel_title)
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Test the Trained Agent

In [None]:
# Evaluate the greedy policy: with epsilon at 0 the agent never explores.
agent.epsilon = 0.0

test_episodes = 10
test_scores = []

print(f"Testing agent for {test_episodes} episodes...")

for ep in range(test_episodes):
    state = game.reset()

    # Play one full game; the loop ends only when the engine reports `done`.
    # NOTE(review): a purely greedy policy could circle forever if the engine
    # has no step limit — confirm SnakeGameEngine ends stalled games.
    while True:
        action = agent.get_action(state)
        reward, done, score = game.step(action)
        state = game.get_state()
        if done:
            break

    # `score` holds the value returned by the final step of the game.
    test_scores.append(score)
    print(f"  Episode {ep+1}: Score = {score}")

# Aggregate results over all evaluation games.
average_score = np.mean(test_scores)
best_score = max(test_scores)
worst_score = min(test_scores)

print(f"\nTest Results:")
print(f"  Average Score: {average_score:.2f}")
print(f"  Best Score: {best_score}")
print(f"  Worst Score: {worst_score}")

## 10. Export for Google Colab (Optional)

If running on Google Colab, you can download the trained model:

In [None]:
# Uncomment to download model in Colab
# from google.colab import files
# files.download(MODEL_PATH)