In [None]:
# Colab setup: install dependencies and mount Google Drive at /content/drive,
# which is where the Q-table is saved to and loaded from later in the notebook.
# NOTE(review): tensorflow is never imported or used by any cell shown here
# (the agent is a NumPy Q-table); this install looks unnecessary — confirm before removing.
%pip install tensorflow
from google.colab import drive
drive.mount('/content/drive')

In [20]:
import collections
import random
import numpy as np

board_size = 4  # side length of the square board
max_steps = 2 * board_size ** 2  # moves allowed between fruits before an episode is cut off

class SnakeEnv:
    """Snake on a square grid, sized for tabular Q-learning.

    Board cell codes: 0 = empty, 1 = snake head, 2 = snake body, 3 = fruit.
    Actions: 0 = up, 1 = down, 2 = left, 3 = right.
    """

    def __init__(self, board_size, max_steps):
        """Store the configuration and start the first episode.

        Args:
            board_size: Side length of the square board.
            max_steps: Moves allowed without eating before the episode is cut
                off; the counter is reset every time a fruit is eaten.
        """
        self.board_size = board_size
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        """Start a new episode: a length-1 snake at (0, 0) plus one fruit.

        Returns:
            The board array (the same object that ``step`` later mutates).
        """
        # Size the board from the instance, not the notebook-level global, so an
        # environment created with a different board_size gets a matching board.
        self.board = np.zeros((self.board_size, self.board_size), dtype=int)
        self.snake = collections.deque([(0, 0)])
        self.fruit_pos = self._generate_fruit()
        self.steps = 0
        self.game_over = False
        self.board[self.snake[0]] = 1
        return self.board

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 = up, 1 = down, 2 = left, 3 = right.

        Returns:
            ``(board, reward, done, info)``. ``board`` is the live board array
            (mutated in place by later calls). ``reward`` is +10 for eating,
            -10 for a collision or running out of steps, and otherwise +1 / -1 / 0
            for moving closer to / farther from / equally far from the fruit.
            ``info`` is always an empty dict. Once the episode is over, further
            calls return reward 0 and ``done=True``.

        Raises:
            ValueError: If ``action`` is not one of 0..3.
        """
        if self.game_over:
            return self.board, 0, True, {}
        self.steps += 1
        head_r, head_c = self.snake[0]
        if action == 0:  # Up
            next_head = (head_r - 1, head_c)
        elif action == 1:  # Down
            next_head = (head_r + 1, head_c)
        elif action == 2:  # Left
            next_head = (head_r, head_c - 1)
        elif action == 3:  # Right
            next_head = (head_r, head_c + 1)
        else:
            raise ValueError("Invalid action")

        if self._is_collision(next_head) or self.steps >= self.max_steps:
            self.game_over = True
            reward = -10  # Collision or max steps reached penalty
            return self.board, reward, self.game_over, {}

        prev_dist = self._fruit_distance((head_r, head_c))
        ate_fruit = False
        self.board[self.snake[0]] = 2  # old head becomes body
        self.board[next_head] = 1  # new head (overwrites the fruit marker when eating)
        self.snake.appendleft(next_head)
        if next_head == self.fruit_pos:
            ate_fruit = True
            self.fruit_pos = self._generate_fruit()
            self.steps = 0  # Reset the step budget on eating fruit
            if self.fruit_pos is None:
                # The snake now covers every cell, so no fruit can be placed.
                # End the episode here instead of forcing a guaranteed collision
                # on the next move.
                self.game_over = True
        else:
            # No growth: vacate the tail cell.
            self.board[self.snake[-1]] = 0
            self.snake.pop()
        current_dist = self._fruit_distance(next_head)
        reward = self._calculate_reward(prev_dist, current_dist, ate_fruit)
        return self.board, reward, self.game_over, {}

    def _fruit_distance(self, pos):
        """Manhattan distance from ``pos`` to the fruit, or 0 when there is no fruit."""
        if self.fruit_pos is None:
            return 0
        return abs(pos[0] - self.fruit_pos[0]) + abs(pos[1] - self.fruit_pos[1])

    def _generate_fruit(self):
        """Place a fruit (code 3) on a random cell not occupied by the snake.

        Returns:
            The fruit position ``(row, col)``, or None when every cell is occupied.
        """
        all_cells = set((r, c) for r in range(self.board_size) for c in range(self.board_size))
        snake_cells = set(self.snake)
        empty_cells = list(all_cells - snake_cells)
        if not empty_cells:
            return None  # No empty cells
        pos = random.choice(empty_cells)
        self.board[pos] = 3
        return pos

    def _is_collision(self, head):
        """Return True if ``head`` is off the board or on any snake cell.

        NOTE: the current tail cell counts as occupied, even though the tail
        would move away on a non-eating step.
        """
        r, c = head
        if r < 0 or r >= self.board_size or c < 0 or c >= self.board_size:
            return True
        if head in self.snake:
            return True
        return False

    def _calculate_reward(self, prev_dist, current_dist, ate_fruit):
        """Shaped reward for a non-terminal move.

        Returns 10 for eating, otherwise +1 / -1 / 0 when the Manhattan distance
        to the fruit shrank / grew / stayed the same.
        """
        if ate_fruit:
            return 10
        elif current_dist < prev_dist:
            return 1  # Moving closer to fruit
        elif current_dist > prev_dist:
            return -1  # Moving away from fruit
        else:
            return 0  # No change in distance


In [21]:

# One row per (fruit_r, fruit_c, head_r, head_c) combination -> board_size ** 4 states.
state_space_size = board_size ** 4
action_space_size = 4  # 0 = Up, 1 = Down, 2 = Left, 3 = Right

# Every Q-value starts at 0.0.
q_table = np.zeros((state_space_size, action_space_size))
print("Q-table shape:", q_table.shape)

Q-table shape: (256, 4)


In [6]:
def get_state_index(fruit_pos, snake, board_size):
    """
    Map the (fruit, snake head) configuration to a single integer state id.

    The four coordinates fruit_r, fruit_c, head_r, head_c are read as the
    digits of a base-``board_size`` number, most significant first, so for
    in-range coordinates the result lies in ``[0, board_size ** 4)``.
    Only the head of the snake contributes; the body is ignored.

    Args:
        fruit_pos: ``(row, col)`` of the fruit.
        snake: Sequence of ``(row, col)`` cells whose first element is the head.
        board_size: Side length of the square board.

    Returns:
        The integer state index.
    """
    head_row, head_col = snake[0]
    fruit_row, fruit_col = fruit_pos

    # Horner's rule: ((fruit_r * b + fruit_c) * b + head_r) * b + head_c
    index = 0
    for digit in (fruit_row, fruit_col, head_row, head_col):
        index = index * board_size + digit
    return index

# Example usage:
# Assuming env is an instance of SnakeEnv
# state_index = get_state_index(env.fruit_pos, env.snake, env.board_size)
# print(f"State index: {state_index}")

In [22]:
def epsilon_greedy_policy(q_table, state_index, epsilon, action_space_size):
    """
    Selects an action with the epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random action is returned.
    Otherwise the result is one of the actions with the highest Q-value for
    ``state_index``, with ties broken uniformly at random. Random tie-breaking
    matters with a zero-initialised Q-table: ``np.argmax`` would always pick
    action 0 for every unvisited state.

    Args:
        q_table: 2-D array indexed as ``q_table[state_index, action]``.
        state_index: The index of the current state.
        epsilon: Exploration probability in [0, 1].
        action_space_size: Number of actions to sample from when exploring.

    Returns:
        The selected action as a Python ``int``.
    """
    if random.uniform(0, 1) < epsilon:
        # Explore: choose a random action.
        return random.randrange(action_space_size)

    # Exploit: choose among all actions sharing the highest Q-value.
    q_values = q_table[state_index, :]
    best_actions = np.flatnonzero(q_values == np.max(q_values)).tolist()
    if not best_actions:
        # Only reachable when the row contains NaN (nothing equals the max);
        # fall back to argmax as before.
        return int(np.argmax(q_values))
    return random.choice(best_actions)

# Example usage (assuming q_table, state_index, epsilon, and action_space_size are defined)
# selected_action = epsilon_greedy_policy(q_table, state_index, epsilon, action_space_size)
# print(f"Selected action: {selected_action}")

In [None]:
# Q-learning hyperparameters
alpha = 0.1  # Learning rate
gamma = 0.6  # Discount factor
epsilon = 1.0  # Exploration rate (reset here, so re-running the cell restarts the schedule)
max_epsilon = 1.0  # Exploration probability at start
min_epsilon = 0.01  # Minimum exploration probability
# Per-episode multiplicative decay: epsilon = min + (max - min) * rate ** episode.
# The previous exp(-rate * episode) form with rate = 0.999 drove epsilon to its
# floor within about 5 of the 100000 episodes, leaving almost no exploration.
epsilon_decay_rate = 0.999
num_episodes = 100000  # Number of training episodes

# NOTE: q_table is not re-created here, so re-running this cell keeps training
# the existing table. Re-run the Q-table initialisation cell for a fresh start.
env = SnakeEnv(board_size, max_steps)

for episode in range(num_episodes):
    state = env.reset()
    state_index = get_state_index(env.fruit_pos, env.snake, env.board_size)
    done = False

    while not done:
        action = epsilon_greedy_policy(q_table, state_index, epsilon, action_space_size)
        next_state, reward, done, info = env.step(action)

        # A full board leaves nowhere to place a fruit, so treat it as terminal.
        if env.fruit_pos is None:
            done = True

        # Q(s, a) <- Q(s, a) + alpha * (target - Q(s, a)); the target bootstraps
        # from max_a' Q(s', a') only for non-terminal transitions.
        current_q = q_table[state_index, action]
        if done:
            target = reward
        else:
            next_state_index = get_state_index(env.fruit_pos, env.snake, env.board_size)
            target = reward + gamma * np.max(q_table[next_state_index, :])
        q_table[state_index, action] = current_q + alpha * (target - current_q)

        if not done:
            state_index = next_state_index

    epsilon = min_epsilon + (max_epsilon - min_epsilon) * epsilon_decay_rate ** episode

    if (episode + 1) % 1000 == 0:  # Print progress every 1000 episodes
        print(f"Episode {episode + 1}/{num_episodes} completed. Epsilon: {epsilon:.2f}")

print("Training finished.")

In [24]:
import time
from IPython.display import clear_output

def _print_board(env):
    """Print the current board: H = snake head, S = snake body, F = fruit, . = empty."""
    head = env.snake[0]
    body = set(list(env.snake)[1:])  # built once per render instead of once per cell
    for r in range(env.board_size):
        row_display = ""
        for c in range(env.board_size):
            cell = (r, c)
            if cell == head:
                row_display += " H "  # Snake head
            elif cell in body:
                row_display += " S "  # Snake body
            elif cell == env.fruit_pos:
                row_display += " F "  # Fruit
            else:
                row_display += " . "  # Empty cell
        print(row_display)


def evaluate_agent(q_table, env, num_eval_episodes, visualize=False, delay=0.1):
    """Play greedy (no-exploration) episodes with a learned Q-table and report statistics.

    Args:
        q_table: Array indexed as ``q_table[state_index, action]``, where state
            indices come from ``get_state_index``.
        env: Environment exposing ``reset``, ``step``, ``snake``, ``fruit_pos``
            and ``board_size`` (as ``SnakeEnv`` does).
        num_eval_episodes: Number of episodes to play.
        visualize: If True, redraw the board in the notebook output before every move.
        delay: Seconds to sleep between rendered frames when ``visualize`` is True.

    Returns:
        Three lists with one entry per episode: total rewards, fruits eaten,
        and final snake lengths.
    """
    episode_rewards = []
    fruits_eaten_per_episode = []
    final_snake_length_per_episode = []

    print(f"\nStarting evaluation for {num_eval_episodes} episodes...")

    for episode in range(num_eval_episodes):
        state = env.reset()
        if env.fruit_pos is None:
            # With SnakeEnv this only happens when the board has no free cell at
            # reset; there is no valid state index, so record an empty episode.
            print(f"Warning: Fruit is None after reset in evaluation episode {episode + 1}")
            episode_rewards.append(0)
            fruits_eaten_per_episode.append(0)
            final_snake_length_per_episode.append(len(env.snake))
            continue

        state_index = get_state_index(env.fruit_pos, env.snake, env.board_size)
        done = False
        total_reward = 0
        fruits_eaten = 0

        while not done:
            if visualize:
                clear_output(wait=True)
                print(f"Evaluation Episode {episode + 1}/{num_eval_episodes}")
                print("Current State:")
                _print_board(env)
                print(f"Total Reward: {total_reward}")
                print(f"Fruits eaten: {fruits_eaten}")
                print(f"Snake length: {len(env.snake)}")
                time.sleep(delay)

            # Greedy action from the Q-table (exploitation only).
            action = np.argmax(q_table[state_index, :])

            length_before = len(env.snake)
            next_state, reward, done, info = env.step(action)
            total_reward += reward
            state = next_state

            # The snake grows exactly when it eats, so count fruit by growth
            # rather than by matching a particular reward value.
            if len(env.snake) > length_before:
                fruits_eaten += 1

            if not done:
                if env.fruit_pos is not None:
                    state_index = get_state_index(env.fruit_pos, env.snake, env.board_size)
                else:
                    # Board full: no fruit left to place, so end the episode.
                    done = True
                    print(f"Evaluation Episode {episode + 1}/{num_eval_episodes} ended: Board Full")

        episode_rewards.append(total_reward)
        fruits_eaten_per_episode.append(fruits_eaten)
        final_snake_length_per_episode.append(len(env.snake))

        if visualize:
            clear_output(wait=True)
            print(f"Evaluation Episode {episode + 1}/{num_eval_episodes} finished with total reward: {total_reward}, fruits eaten: {fruits_eaten}, final length: {len(env.snake)}")
            print("Final State:")
            _print_board(env)
            print(f"Final Total Reward: {total_reward}, Fruits eaten: {fruits_eaten}, Final length: {len(env.snake)}")

    if not episode_rewards:
        # np.mean of an empty list would print nan with a RuntimeWarning.
        print("No evaluation episodes were run.")
        return episode_rewards, fruits_eaten_per_episode, final_snake_length_per_episode

    avg_reward = np.mean(episode_rewards)
    avg_fruits_eaten = np.mean(fruits_eaten_per_episode)
    avg_final_snake_length = np.mean(final_snake_length_per_episode)

    print(f"\nAverage reward over {num_eval_episodes} evaluation episodes: {avg_reward:.2f}")
    print(f"Average fruits eaten per episode: {avg_fruits_eaten:.2f}")
    print(f"Average final snake length per episode: {avg_final_snake_length:.2f}")

    return episode_rewards, fruits_eaten_per_episode, final_snake_length_per_episode

# Example usage after training:
# eval_rewards, eval_fruits, eval_lengths = evaluate_agent(q_table, env, num_eval_episodes=100, visualize=True)

In [None]:
# Play the greedy policy and render every move (slow: the renderer sleeps between frames).
num_eval_episodes = 100
eval_rewards, eval_fruits, eval_lengths = evaluate_agent(
    q_table,
    env,
    num_eval_episodes,
    visualize=True,
)

### 儲存訓練好的 Q-table

In [28]:
import os

import numpy as np

# Destination on the Google Drive mounted at /content/drive (change the filename/path as needed).
q_table_save_path = '/content/drive/MyDrive/Colab Notebooks/snake_by_Qlearning1/snake_by_Qlearning_some_state_table.npy'  # You can change the filename and path

# np.save does not create missing directories, so make sure the target folder exists.
save_dir = os.path.dirname(q_table_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

np.save(q_table_save_path, q_table)

print(f"Q-table saved successfully to {q_table_save_path}")

Q-table saved successfully to /content/drive/MyDrive/Colab Notebooks/snake_by_Qlearning1/snake_by_Qlearning_some_state_table.npy


### 載入儲存的 Q-table

In [29]:
# Path to load the Q-table from; must match the path used when saving.
q_table_load_path = '/content/drive/MyDrive/Colab Notebooks/snake_by_Qlearning1/snake_by_Qlearning_some_state_table.npy'  # Make sure this matches the save path

try:
    loaded_q_table = np.load(q_table_load_path)
except FileNotFoundError:
    # Best-effort: keep whatever q_table is already in memory.
    print(f"Error: Q-table file not found at {q_table_load_path}")
    print("Please ensure the file exists or train a new Q-table.")
    # To start from scratch instead: q_table = np.zeros((state_space_size, action_space_size))
else:
    # A table trained for a different board size would index states wrongly
    # (or out of range), so only accept one matching the current state/action space.
    expected_shape = (state_space_size, action_space_size)
    if loaded_q_table.shape != expected_shape:
        print(f"Error: loaded Q-table has shape {loaded_q_table.shape}, expected {expected_shape}; keeping the current Q-table.")
    else:
        q_table = loaded_q_table
        print(f"Q-table loaded successfully from {q_table_load_path}")
        print("Current Q-table shape:", q_table.shape)

Q-table loaded successfully from /content/drive/MyDrive/Colab Notebooks/snake_by_Qlearning1/snake_by_Qlearning_some_state_table.npy
Current Q-table shape: (256, 4)
