<a href="https://colab.research.google.com/github/ROBOTdingDONG/Training-Data-Collection/blob/main/Snake_AI_Simulation_(Q_Learning).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import random
import collections
import time
import json
import logging
import math
import sys
from collections import namedtuple

# --- Logging Setup ---
# Send log records both to a timestamped file and to the console.
log_filename = f"snake_rl_log_{time.strftime('%Y%m%d_%H%M%S')}.log"

# logging.basicConfig() silently does nothing when the root logger already has
# handlers. That is the case when this cell is re-run in the same notebook
# session, or when the hosting environment installs its own handlers first.
# Detach (and close) whatever is attached so the configuration below always
# takes effect and records really end up in `log_filename`.
_root_logger = logging.getLogger()
for _stale_handler in list(_root_logger.handlers):
    _root_logger.removeHandler(_stale_handler)
    _stale_handler.close()

logging.basicConfig(level=logging.INFO,  # Set level to INFO (or DEBUG for more detail)
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler(log_filename),  # Log to file
                        logging.StreamHandler(sys.stdout)   # Log to console
                    ])
logging.info("--- Snake RL Simulation Initialized ---")

# --- Game Components ---
Point = namedtuple('Point', ['x', 'y'])  # Simple (x, y) grid coordinate

# Absolute directions (indices map to unit vectors). y grows downward, which is
# the row order render_text() prints in, so (0, -1) is "Up" on screen.
DIRECTIONS = {
    0: Point(0, -1),  # Up
    1: Point(0, 1),   # Down
    2: Point(-1, 0),  # Left
    3: Point(1, 0)    # Right
}
# Relative actions for the agent (relative to its current direction)
# 0: Go Straight, 1: Turn Left, 2: Turn Right
RELATIVE_ACTIONS = [0, 1, 2]

# --- Game Environment ---
class SnakeGame:
    """
    Snake game environment: holds the game state and applies rules and rewards.

    The agent steers with heading-relative actions (0 straight, 1 left,
    2 right). `reset()` and `step()` return the compact state tuple built by
    `_get_state()`.
    """

    # reset() lays out a 3-segment horizontal snake whose tail sits two cells to
    # the left of column width // 2, so narrower grids cannot hold it.
    _MIN_WIDTH = 4

    # Heading reached by a quarter turn from each absolute heading. The indices
    # in DIRECTIONS are NOT in rotational order (Up, Down, Left, Right), so
    # "index +/- 1" arithmetic would produce 180-degree reversals; explicit
    # tables are used instead.
    _TURN_LEFT = {0: 2, 2: 1, 1: 3, 3: 0}   # Up -> Left -> Down -> Right -> Up
    _TURN_RIGHT = {0: 3, 3: 1, 1: 2, 2: 0}  # Up -> Right -> Down -> Left -> Up

    def __init__(self, width=10, height=10):
        """
        Create a game on a `width` x `height` grid and start the first episode.

        Raises:
            ValueError: if the grid is too small to hold the initial snake.
        """
        if width < self._MIN_WIDTH or height < 1:
            raise ValueError(
                f"Grid must be at least {self._MIN_WIDTH}x1 to fit the initial snake, "
                f"got {width}x{height}."
            )
        self.width = width
        self.height = height
        self.reset()
        logging.info(f"Game environment created ({width}x{height} grid).")

    def reset(self):
        """Reset the game to its starting position and return the initial state."""
        # Place snake in the middle, moving right initially
        self.head = Point(self.width // 2, self.height // 2)
        self.snake = [self.head,
                      Point(self.head.x - 1, self.head.y),
                      Point(self.head.x - 2, self.head.y)]
        self.direction = 3  # Start moving right (index in DIRECTIONS)
        self.score = 0
        self.food = None
        self._place_food()  # Place initial food
        self.game_over = False
        self.steps_taken = 0
        self.steps_since_food = 0
        # Heuristic limit to prevent infinite loops if agent gets stuck
        self.max_steps_no_food = self.width * self.height * 2
        # Return the initial state representation for the agent
        return self._get_state()

    def _place_food(self):
        """
        Put the food on a random cell that the snake does not occupy.

        Returns:
            bool: True if food was placed. False if every cell is covered by
            the snake; `self.food` is then set to None. Picking from the list
            of free cells (rather than retrying random cells) is what keeps
            this from looping forever on a full board.
        """
        occupied = set(self.snake)
        free_cells = []
        for y in range(self.height):
            for x in range(self.width):
                cell = Point(x, y)
                if cell not in occupied:
                    free_cells.append(cell)

        if not free_cells:
            self.food = None
            return False
        self.food = random.choice(free_cells)
        return True

    def _heading_after(self, action_relative):
        """Return the absolute heading index that results from a relative action."""
        if action_relative == 1:  # Turn Left
            return self._TURN_LEFT[self.direction]
        if action_relative == 2:  # Turn Right
            return self._TURN_RIGHT[self.direction]
        return self.direction  # Go Straight (any other value)

    def _move(self, action_relative):
        """
        Turn according to a relative action and advance the head by one cell.

        action_relative: 0 (Straight), 1 (Left Turn), 2 (Right Turn)

        The new head is inserted at the front of `self.snake`; the tail is left
        in place (step() removes it when no food was eaten).
        """
        self.direction = self._heading_after(action_relative)
        move_vector = DIRECTIONS[self.direction]  # (dx, dy) for the move

        self.head = Point(self.head.x + move_vector.x, self.head.y + move_vector.y)
        self.snake.insert(0, self.head)

    def _is_collision(self, pt=None):
        """
        Tell whether `pt` (the current head when omitted) is a fatal position.

        A position is fatal when it lies outside the grid or coincides with any
        snake segment other than the head.
        """
        if pt is None:
            pt = self.head

        # Wall collision
        if not (0 <= pt.x < self.width and 0 <= pt.y < self.height):
            return True
        # Self collision (every segment except the head itself)
        if pt in self.snake[1:]:
            return True
        return False

    def step(self, action_relative):
        """
        Advance the game by one move.

        Args:
            action_relative: The relative action chosen by the agent (0, 1, or 2).
        Returns:
            tuple: (next_state, reward, game_over)
                   - next_state: The state representation after the action.
                   - reward: -100 on death/starvation, 20 for eating, -0.1 otherwise.
                   - game_over: Boolean indicating if the game ended. It is also
                     set when food is eaten and no free cell remains for new food.
        """
        self.steps_taken += 1
        self.steps_since_food += 1

        self._move(action_relative)

        # Game over conditions: collision or too many steps without eating
        collided = self._is_collision()
        starved = self.steps_since_food > self.max_steps_no_food
        if collided or starved:
            self.game_over = True
            logging.debug("Game Over. Collision: %s, Starvation: %s", collided, starved)
            return self._get_state(), -100, self.game_over

        if self.head == self.food:
            self.score += 1
            reward = 20  # Positive reward for eating food
            self.steps_since_food = 0  # Reset starvation counter
            # The tail is not popped here, which is how the snake grows.
            if not self._place_food():
                # The snake now covers the whole board: nothing left to play for.
                self.game_over = True
            logging.debug("Food eaten! Score: %s", self.score)
        else:
            # Plain move: drop the tail so the length stays the same
            self.snake.pop()
            # Small penalty per step to encourage finding food faster
            reward = -0.1

        return self._get_state(), reward, self.game_over

    def _get_state(self):
        """
        Build the hashable state tuple the Q-learning agent uses as a table key.

        State tuple components:
        1. Danger Straight: collision if moving straight? (0=No, 1=Yes)
        2. Danger Left: collision if turning left? (0=No, 1=Yes)
        3. Danger Right: collision if turning right? (0=No, 1=Yes)
        4. Food Direction X: food column vs. head (-1=Left, 0=Same Col, 1=Right)
        5. Food Direction Y: food row vs. head (-1=Up, 0=Same Row, 1=Down)

        The danger flags are relative to the current heading, while the food
        direction is expressed in absolute grid terms; the heading itself is
        not part of the state. When there is no food on the board both food
        components are 0.
        """
        head = self.snake[0]

        point_straight, point_left, point_right = self._get_relative_points()

        danger_straight = self._is_collision(point_straight)
        danger_left = self._is_collision(point_left)
        danger_right = self._is_collision(point_right)

        food_dir_x = 0
        food_dir_y = 0
        if self.food is not None:
            if self.food.x < head.x:
                food_dir_x = -1  # Food is to the Left
            elif self.food.x > head.x:
                food_dir_x = 1   # Food is to the Right

            if self.food.y < head.y:
                food_dir_y = -1  # Food is Up
            elif self.food.y > head.y:
                food_dir_y = 1   # Food is Down

        return (
            int(danger_straight), int(danger_left), int(danger_right),
            food_dir_x, food_dir_y,
        )

    def _get_relative_points(self):
        """
        Return the cells the head would enter for actions 0, 1 and 2.

        The result is `(point_straight, point_left, point_right)`. Headings are
        derived through `_heading_after`, the same rule `_move` uses, so each
        point is exactly where the corresponding action would put the head.
        """
        head = self.snake[0]
        points = []
        for action in (0, 1, 2):
            vector = DIRECTIONS[self._heading_after(action)]
            points.append(Point(head.x + vector.x, head.y + vector.y))
        return tuple(points)

    def render_text(self):
        """(Optional) Print a simple text view of the board and counters to the console."""
        grid = [['.' for _ in range(self.width)] for _ in range(self.height)]
        # Place food
        if self.food:
            if 0 <= self.food.y < self.height and 0 <= self.food.x < self.width:
                grid[self.food.y][self.food.x] = 'F'
        # Place snake (Head 'H', Body 'S'); segments outside the grid are skipped
        for i, segment in enumerate(self.snake):
            if 0 <= segment.y < self.height and 0 <= segment.x < self.width:
                grid[segment.y][segment.x] = 'H' if i == 0 else 'S'

        print("-" * (self.width * 2 + 1))  # Top border
        for row in grid:
            print("|" + " ".join(row) + "|")  # Rows with borders
        print("-" * (self.width * 2 + 1))  # Bottom border
        print(f"Score: {self.score} | Steps: {self.steps_taken} | Steps w/o Food: {self.steps_since_food}")


# --- Q-Learning Agent ---
class QLearningAgent:
    """
    Tabular Q-learning agent with an epsilon-greedy policy.

    Q-values live in `self.q_table`, a `defaultdict` mapping
    `state -> {action: value}`. A (state, action) pair that has never been
    updated has an implicit value of 0.0, and that default is honoured both
    when choosing actions and when bootstrapping from the next state.
    """

    def __init__(self, actions, learning_rate=0.1, discount_factor=0.95, exploration_rate=1.0, exploration_decay=0.9995, min_exploration_rate=0.01):
        self.actions = actions  # List of possible actions (e.g., [0, 1, 2])
        self.alpha = learning_rate       # Learning rate (how much new info overrides old)
        self.gamma = discount_factor     # Discount factor (importance of future rewards)
        self.epsilon = exploration_rate  # Current exploration probability
        self.epsilon_decay = exploration_decay   # Multiplicative decay applied by decay_epsilon()
        self.min_epsilon = min_exploration_rate  # Lower bound for epsilon
        # Q[state] -> {action1: value1, action2: value2, ...}
        self.q_table = collections.defaultdict(lambda: collections.defaultdict(float))
        self.training_mode = True  # Agent learns when True
        # Epsilon remembered while training is switched off (None = nothing saved)
        self._paused_epsilon = None
        logging.info("Q-Learning Agent created.")
        logging.info(f"  Learning Rate (alpha): {self.alpha}")
        logging.info(f"  Discount Factor (gamma): {self.gamma}")
        logging.info(f"  Initial Epsilon: {self.epsilon}")
        logging.info(f"  Epsilon Decay: {self.epsilon_decay}")
        logging.info(f"  Min Epsilon: {self.min_epsilon}")

    def set_training_mode(self, mode=True):
        """
        Enable or disable learning and exploration.

        Disabling sets epsilon to 0 (pure exploitation) but remembers the
        previous value, so enabling training again resumes exploration where
        it left off instead of staying at 0.
        """
        if mode:
            if self._paused_epsilon is not None:
                self.epsilon = self._paused_epsilon
                self._paused_epsilon = None
        else:
            if self._paused_epsilon is None:
                self._paused_epsilon = self.epsilon
            self.epsilon = 0  # No exploration if not training
        self.training_mode = mode
        logging.info(f"Agent training mode set to: {self.training_mode}")

    def _action_values(self, state):
        """
        Return `{action: Q}` for every action in `self.actions`.

        Actions without a stored value get 0.0. The lookup uses `dict.get`, so
        it never inserts an entry into the Q-table.
        """
        stored = self.q_table.get(state, {})
        return {action: stored.get(action, 0.0) for action in self.actions}

    def choose_action(self, state):
        """
        Select an action for `state` using the epsilon-greedy strategy.

        - With probability epsilon (training mode only): a random action.
        - Otherwise: an action with the highest Q-value, ties broken at random.
          Untried actions count as 0.0, so a single known-bad action is not
          preferred over alternatives that were never tried.
        """
        if self.training_mode and random.random() < self.epsilon:
            # --- Exploration ---
            action = random.choice(self.actions)
            logging.debug("State: %s -> Explore Action: %s", state, action)
            return action

        # --- Exploitation ---
        values = self._action_values(state)
        max_q_value = max(values.values())
        best_actions = [a for a, q in values.items() if q == max_q_value]
        action = random.choice(best_actions)
        if state in self.q_table:
            logging.debug("State: %s -> Exploit Action: %s (MaxQ=%.3f)", state, action, max_q_value)
        else:
            logging.debug("State: %s (unseen) -> Random Action (Exploit fallback): %s", state, action)
        return action

    def learn(self, state, action, reward, next_state, done):
        """
        Apply one Q-learning update for the transition just experienced.

        Q(s, a) <- Q(s, a) + alpha * [Target - Q(s, a)]
        where Target = reward + gamma * max_a'(Q(s', a'))  (if not done)
              Target = reward                              (if done)

        Args:
            state: The state before the action was taken.
            action: The action taken by the agent.
            reward: The reward received after taking the action.
            next_state: The state the agent transitioned to after the action.
            done: Boolean indicating if the episode ended after this action.

        Does nothing when the agent is not in training mode.
        """
        if not self.training_mode:
            return

        # Current estimate; defaults to 0.0 for a pair seen for the first time.
        current_q = self.q_table[state][action]

        if done:
            # Terminal transition: no future value to bootstrap from.
            target = reward
            logging.debug("Learn (End State): Target = Reward = %.2f", reward)
        else:
            # The max runs over ALL actions (untried ones are 0.0), not only
            # over those that happen to have an entry for next_state.
            max_next_q = max(self._action_values(next_state).values())
            target = reward + self.gamma * max_next_q
            logging.debug("Learn: R=%.1f, gamma=%s, max_next_Q=%.3f -> Target=%.3f",
                          reward, self.gamma, max_next_q, target)

        # TD error scaled by the learning rate
        update_amount = self.alpha * (target - current_q)
        new_q = current_q + update_amount
        self.q_table[state][action] = new_q

        logging.debug("  State=%s, Action=%s", state, action)
        logging.debug("  CurrentQ=%.3f, Target=%.3f, Update=%.3f -> NewQ=%.3f",
                      current_q, target, update_amount, new_q)

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor, never going below `min_epsilon` (training mode only)."""
        if self.training_mode:
            old_epsilon = self.epsilon
            self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
            if self.epsilon != old_epsilon:
                logging.debug("Epsilon decayed from %.4f to %.4f", old_epsilon, self.epsilon)

    def save_q_table(self, filename="q_table.json"):
        """
        Save the Q-table to a JSON file.

        State tuples are written as their `str()` form because JSON object keys
        must be strings. Failures are logged, not raised.
        """
        serializable_q_table = {str(state): dict(values) for state, values in self.q_table.items()}
        try:
            with open(filename, 'w') as f:
                json.dump(serializable_q_table, f, indent=4)
            logging.info(f"Q-table (size {len(self.q_table)}) saved to {filename}")
        except (OSError, TypeError, ValueError) as e:
            logging.error(f"Error saving Q-table to {filename}: {e}")

    @staticmethod
    def _parse_state_key(state_str):
        """Turn a key such as '(0, 1, 0, -1, 1)' back into a tuple of ints (ValueError if malformed)."""
        parts = [part.strip() for part in state_str.strip().strip('()').split(',')]
        if parts and parts[-1] == '':
            parts.pop()  # tolerate the trailing comma of a 1-tuple: '(3,)'
        if not parts:
            raise ValueError(f"empty state key: {state_str!r}")
        return tuple(int(part) for part in parts)

    def load_q_table(self, filename="q_table.json"):
        """
        Load a Q-table from a JSON file written by `save_q_table`.

        Malformed states or action entries are skipped with a warning. The
        current table is replaced only after the whole file has been parsed;
        if the file is missing or unreadable the table is left as it is.
        Failures are logged, not raised.
        """
        try:
            with open(filename, 'r') as f:
                loaded_q_table_str_keys = json.load(f)
        except FileNotFoundError:
            logging.warning(f"Q-table file '{filename}' not found. Starting with an empty table.")
            return
        except json.JSONDecodeError:
            logging.error(f"Error decoding JSON from Q-table file '{filename}'. Starting with empty table.")
            return
        except (OSError, ValueError) as e:
            logging.error(f"Error loading Q-table from {filename}: {e}. Starting with an empty table.")
            return

        if not isinstance(loaded_q_table_str_keys, dict):
            logging.error(f"Error loading Q-table from {filename}: top-level JSON value is not an object. Starting with an empty table.")
            return

        # Parse into a fresh table first so a bad file cannot leave a half-loaded one.
        new_table = collections.defaultdict(lambda: collections.defaultdict(float))
        loaded_count = 0
        for state_str, actions_dict in loaded_q_table_str_keys.items():
            try:
                state_tuple = self._parse_state_key(state_str)
            except ValueError:
                logging.warning(f"Skipping invalid state key format during load: {state_str}")
                continue
            if not isinstance(actions_dict, dict):
                logging.warning(f"Skipping invalid state key format during load: {state_str}")
                continue

            # JSON turned the integer action keys into strings; convert them back.
            inner_dict = collections.defaultdict(float)
            for action_str, q_value in actions_dict.items():
                try:
                    inner_dict[int(action_str)] = float(q_value)
                except (TypeError, ValueError):
                    logging.warning(f"Skipping invalid action/q-value format in state {state_str}: action={action_str}, value={q_value}")
                    continue

            new_table[state_tuple] = inner_dict
            loaded_count += 1

        self.q_table = new_table
        logging.info(f"Q-table loaded from {filename}. {loaded_count} states loaded. Total size: {len(self.q_table)} states.")


# --- Simulation Loop ---
def run_simulation(episodes=10000, render_every_n=0, log_every_n=100, save_q_table_every_n=5000, load_filename=None):
    """
    Run the training loop: play `episodes` games, learning after every step.

    Args:
        episodes: Total number of games (episodes) to simulate.
        render_every_n: Render the game board every N episodes (0 to disable).
        log_every_n: Log summary statistics every N episodes (0 to disable).
        save_q_table_every_n: Save the Q-table every N episodes (0 to disable).
        load_filename: Path to a Q-table file to load before starting, or None.

    Returns:
        list: The final score of each episode, in play order (e.g. for plotting).

    The final Q-table is always written to "q_table_final.json".
    """
    game = SnakeGame(width=10, height=10)
    agent = QLearningAgent(actions=RELATIVE_ACTIONS)

    # Load pre-trained Q-table if specified
    if load_filename:
        agent.load_q_table(load_filename)

    # Per-episode results for analysis/logging
    episode_scores = []
    episode_steps = []
    max_score_so_far = -1

    logging.info(f"--- Starting Simulation: {episodes} episodes ---")
    start_time = time.time()

    for episode in range(1, episodes + 1):
        state = game.reset()  # Start a new game
        done = False
        current_episode_steps = 0
        # Decided once per episode; it does not change between steps.
        render_this_episode = render_every_n > 0 and episode % render_every_n == 0

        # --- Single Episode Loop ---
        while not done:
            # 1. Agent chooses action based on current state
            action = agent.choose_action(state)

            # 2. Environment executes action and returns outcome
            next_state, reward, done = game.step(action)

            # 3. Agent learns from the transition
            agent.learn(state, action, reward, next_state, done)

            # 4. Update current state for the next iteration
            state = next_state
            current_episode_steps += 1

            # Optional: Render the game board
            if render_this_episode:
                game.render_text()
                print(f"Episode: {episode}, Step: {current_episode_steps}, Action: {action}, Reward: {reward:.1f}")
                print(f"State: {state}")
                time.sleep(0.05)  # Pause briefly to make rendering viewable

        # --- End of Episode ---
        agent.decay_epsilon()  # Decrease exploration chance for next episode
        episode_scores.append(game.score)
        episode_steps.append(current_episode_steps)
        if game.score > max_score_so_far:
            max_score_so_far = game.score

        # Log summary statistics periodically. The `> 0` guard makes 0 mean
        # "disabled", like the other *_every_n parameters, instead of raising
        # ZeroDivisionError on the modulo.
        if log_every_n > 0 and episode % log_every_n == 0:
            recent_scores = episode_scores[-log_every_n:]
            recent_steps = episode_steps[-log_every_n:]
            avg_score = sum(recent_scores) / len(recent_scores)
            avg_steps = sum(recent_steps) / len(recent_steps)
            logging.info(f"Ep {episode}/{episodes} | "
                         f"Last Score: {game.score} | Max Score: {max_score_so_far} | "
                         f"Avg Score ({log_every_n} ep): {avg_score:.2f} | "
                         f"Avg Steps ({log_every_n} ep): {avg_steps:.1f} | "
                         f"Epsilon: {agent.epsilon:.4f} | "
                         f"Q-States: {len(agent.q_table)}")

        # Save Q-table periodically
        if save_q_table_every_n > 0 and episode % save_q_table_every_n == 0:
            agent.save_q_table(f"q_table_ep_{episode}.json")

    # --- End of Simulation ---
    total_time = time.time() - start_time
    # Avoid dividing by zero when no episodes were requested.
    seconds_per_episode = total_time / episodes if episodes > 0 else 0.0
    logging.info(f"--- Simulation Finished ---")
    logging.info(f"Total Episodes: {episodes}")
    logging.info(f"Total Time: {total_time:.2f} seconds ({seconds_per_episode:.4f} sec/ep)")
    logging.info(f"Final Max Score: {max_score_so_far}")
    logging.info(f"Final Q-Table size: {len(agent.q_table)} states.")

    # Save the final Q-table
    agent.save_q_table("q_table_final.json")

    # Potential next step: Add plotting of scores over episodes using matplotlib
    return episode_scores


# --- Main Execution Guard ---
if __name__ == "__main__":
    # Training run settings; edit these values to change the experiment.
    NUM_EPISODES = 20000  # number of games played during training
    RENDER_EVERY = 0      # draw the board every N episodes (0 = never)
    LOG_EVERY = 100       # emit a summary log line every N episodes
    SAVE_EVERY = 5000     # write a Q-table snapshot every N episodes
    LOAD_FILE = None      # e.g. "q_table_final.json" to resume from a saved table

    # Gather the keyword arguments first, then launch the run in one call.
    simulation_settings = {
        "episodes": NUM_EPISODES,
        "render_every_n": RENDER_EVERY,
        "log_every_n": LOG_EVERY,
        "save_q_table_every_n": SAVE_EVERY,
        "load_filename": LOAD_FILE,
    }
    run_simulation(**simulation_settings)