## Proximal Policy Optimisation in Snake

<center><img src="../assets/snake.png"/></center>

### 1. Environment
The environment is a fully vectorised Snake game implemented in NumPy, with an optional Pygame-based GUI. The game area is represented as a 2-dimensional integer array: blank positions have a value of 0, the food position is marked with -1, and the snake is represented as a sequence of ascending integers starting at 2 for the head. The snake moves by adding 1 to each of its segments and writing the head's index (2) at the head's next position.
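
The move rule can be sketched in NumPy as follows. This is a minimal illustration, not the project's `game` module: the function name `move` is hypothetical, and clearing the segment that ages past the snake's length (unless food was eaten) is an assumption about how the tail is removed.

```python
import numpy as np

def move(grid: np.ndarray, new_head: tuple[int, int], ate_food: bool) -> np.ndarray:
    """Advance the snake one step on a grid where 0 = blank, -1 = food and values >= 2 = snake."""
    grid = grid.copy()
    body = grid >= 2
    length = int(body.sum())
    grid[body] += 1  # every segment ages by one step, so the old tail now holds length + 2
    if not ate_food:
        # Assumption: without food the old tail disappears, keeping the length constant.
        grid[grid > length + 1] = 0
    grid[new_head] = 2  # the head always carries the smallest snake index
    return grid

# A 3x4 grid with a snake of length 3 moving right, and food in the top-right corner.
grid = np.array([
    [0, 0, 0, -1],
    [4, 3, 2,  0],
    [0, 0, 0,  0],
])
print(move(grid, (1, 3), ate_food=False))
```

Because segment values encode their age, no list of coordinates has to be maintained: a single increment and a threshold update the whole body at once.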

```python
from game import Snake

# Launch a playable instance of Snake. Use the arrow keys to control the snake.
Snake.play(width = 12, height = 9, fps = 3)
```

#### 1.1 Action Space
While there are technically four possible movement directions ("Left", "Right", "Up", "Down"), the game is played from the snake's perspective. This results in an action space of three discrete actions, "Turn Left", "Move Forward", and "Turn Right", which slightly reduces the number of learnable parameters and makes it impossible for the snake to reverse into its own second segment.

Both actions and directions are represented as integer enums, with values chosen so that the result of a direction-action pair can be computed arithmetically instead of through a manual mapping. Directions and actions are ordered clockwise with ascending indices (0 through 3 and 0 through 2, respectively), so the next direction is the current direction's index plus the action's index minus 1, taken modulo 4.
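
A minimal sketch of this arithmetic, assuming the directions are ordered clockwise starting at `UP` (the enum members and the function name are illustrative, not taken from the project):

```python
from enum import IntEnum

class Direction(IntEnum):
    # Assumed clockwise order starting at UP.
    UP = 0
    RIGHT = 1
    DOWN = 2
    LEFT = 3

class Action(IntEnum):
    TURN_LEFT = 0
    FORWARD = 1
    TURN_RIGHT = 2

def next_direction(direction: Direction, action: Action) -> Direction:
    # action - 1 maps the actions to -1 (counter-clockwise), 0 (straight) and +1 (clockwise).
    return Direction((direction + action - 1) % 4)

print(next_direction(Direction.UP, Action.TURN_LEFT).name)
```

Because `action - 1` is always -1, 0 or +1, the new direction can never be the opposite of the current one, which is what rules out reversing into the second segment.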

#### 1.2 Observation Space
Conventionally, Snake is played from a top-down view of a fully observable game area. To reduce the number of learnable parameters compared to a convolutional neural network over the full grid, and to improve convergence speed, a different approach was chosen: instead of receiving a tensor representation of the entire game state, the agent plays from the snake's perspective, which makes the game area only partially observable. The agent operates within a low-dimensional observation space containing the following components:
- 5 real values representing the perceived danger in the left, left-forward, forward, right-forward, and right directions, each ranging from 0 to 2.
- A 4-dimensional one-hot vector of the snake's current movement direction.
- A 4-dimensional one-hot vector of the food's relative position.
- The fraction of the available game time already used up, which introduces a sense of urgency similar to a per-step penalty.

The snake perceives danger as follows:
1. Generate a matrix where 1s mark positions that would harm the snake by the time it reaches them. Since the snake cannot leave the game area, this "danger map" extends slightly beyond the game area and is framed by danger values.
2. Create an integer array containing the Manhattan (L1) distance between the snake's head and each position on the danger map.
3. Normalise the danger map by dividing it by the squared L1 distance (with a lower bound of 1 to avoid dividing by zero).
4. Compute the perceived danger in each direction as the sum of normalised danger map values in that direction (excluding the head).

<center><img src="../assets/danger_map.png"/></center>
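
The four steps above can be sketched in NumPy for a single, non-vectorised game. Two details are assumptions rather than documented behaviour: a body cell counts as dangerous if its segment will still exist when the head could first reach it (assuming no food is eaten on the way), and each directional value is summed along a single ray from the head. The map is framed with a one-cell border, and `perceive_danger` is a hypothetical name.

```python
import numpy as np

def perceive_danger(grid: np.ndarray, heading: tuple[int, int]) -> np.ndarray:
    """Danger in the left, left-forward, forward, right-forward and right directions.

    grid uses 0 = blank, -1 = food, values >= 2 = snake (2 = head);
    heading is a (row, col) step, e.g. (-1, 0) for up.
    """
    length = int((grid >= 2).sum())
    head_r, head_c = np.argwhere(grid == 2)[0]
    rows, cols = np.indices(grid.shape)
    dist = np.abs(rows - head_r) + np.abs(cols - head_c)

    # Step 1 (assumed rule): a segment is dangerous if it still exists when the head can first arrive.
    danger = ((grid >= 2) & (grid + dist <= length + 1)).astype(float)
    danger = np.pad(danger, 1, constant_values=1.0)  # frame the map with walls
    head_r, head_c = head_r + 1, head_c + 1          # head position in the padded map

    # Step 2: L1 distance from the head to every cell of the padded map.
    rows, cols = np.indices(danger.shape)
    dist = np.abs(rows - head_r) + np.abs(cols - head_c)

    # Step 3: normalise by the squared distance (at least 1 to avoid division by zero).
    weighted = danger / np.maximum(dist, 1) ** 2

    # Step 4 (assumed rays): sum along a ray in each relative direction, skipping the head.
    dr, dc = heading
    left, right = (-dc, dr), (dc, -dr)
    rays = [left, (dr + left[0], dc + left[1]), (dr, dc), (dr + right[0], dc + right[1]), right]
    result = []
    for step_r, step_c in rays:
        r, c, total = head_r + step_r, head_c + step_c, 0.0
        while 0 <= r < weighted.shape[0] and 0 <= c < weighted.shape[1]:
            total += weighted[r, c]
            r, c = r + step_r, c + step_c
        result.append(total)
    return np.array(result)

# A snake of length 1 in the centre of a 5x5 grid, heading up: only the walls are dangerous.
grid = np.zeros((5, 5), dtype=int)
grid[2, 2] = 2
print(perceive_danger(grid, (-1, 0)))
```

In this example the straight directions see a wall at distance 3 (danger 1/9) and the diagonals see a corner at distance 6 (danger 1/36).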

This representation offers the added advantage of being independent of the game's grid size. A trained agent that performs well on a 12x9 grid can be expected to perform similarly on smaller or larger grids, as neither the observation dimensions nor the value ranges change: one-hot values have a fixed range, and the danger observations, as implemented, form a convergent series bounded by 2.
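
Assuming each directional value is summed along a single ray (so the $k$-th cell on the ray lies at an L1 distance of at least $k$ from the head) and every danger-map entry is at most 1, the bound of 2 follows from a telescoping comparison:
$$
\sum_{k=1}^{\infty} \frac{1}{k^2} \;<\; 1 + \sum_{k=2}^{\infty} \frac{1}{k(k-1)} \;=\; 1 + \sum_{k=2}^{\infty} \left( \frac{1}{k-1} - \frac{1}{k} \right) \;=\; 2
$$
The exact value of the left-hand side is $\pi^2/6 \approx 1.645$, and diagonal rays, whose cells lie at even distances, contribute at most a quarter of that.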

#### 1.3 Rewards
Suitable rewards are crucial for successful reinforcement learning, as they substantially impact both convergence speed and the agent's strategy after training. The following rewards are used:
- **+1**: The snake eats food.
- **-2**: The snake collides with itself or runs out of time (i.e., exceeds a number of moves dependent on its current length and the game's grid size).
- A small penalty per game step.
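
The reward scheme can be expressed as a small function. The penalty size `step_penalty = 0.01` is a hypothetical value (the section only calls it "small"), and applying it only on ordinary moves, rather than on top of the other rewards, is an assumption:

```python
def reward(ate_food: bool, collided: bool, timed_out: bool, step_penalty: float = 0.01) -> float:
    """Reward for one game step (step_penalty is a hypothetical value)."""
    if collided or timed_out:
        return -2.0           # terminal penalty
    if ate_food:
        return 1.0            # food reward
    return -step_penalty      # small per-step penalty, assumed not to stack with the others

print(reward(ate_food=True, collided=False, timed_out=False))
```

Making the terminal penalty larger in magnitude than the food reward means a risky move that might eat food is not worth dying for.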

### 2. Model Definition
The model to be trained is an actor-critic policy network comprising the following components:
- **Shared linear layers:** Transform the input observation into a shared feature vector.
- **Value stream:** Predicts the value of the observation using the shared features.
- **Policy stream:** Predicts the log-probabilities of the three actions based on the shared features.
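
To make the architecture concrete, here is a forward-pass sketch in NumPy rather than the project's own `ACN` class (whose framework is not shown; the `.pth` checkpoints suggest PyTorch). The input size of 14 follows from the observation components in Section 1.2 and the hidden size of 128 from the training cell; the two shared ReLU layers and the weight initialisation are assumptions:

```python
import numpy as np

OBS_SIZE = 14      # 5 danger values + 4 direction one-hot + 4 food one-hot + 1 time fraction
HIDDEN_SIZE = 128  # hidden_size used in the training cell
N_ACTIONS = 3      # turn left, move forward, turn right

rng = np.random.default_rng(0)

def init_layer(n_in: int, n_out: int) -> tuple[np.ndarray, np.ndarray]:
    """Random weights scaled by 1/sqrt(n_in) and zero biases."""
    return rng.normal(0.0, n_in ** -0.5, size=(n_in, n_out)), np.zeros(n_out)

params = {
    "shared_1": init_layer(OBS_SIZE, HIDDEN_SIZE),
    "shared_2": init_layer(HIDDEN_SIZE, HIDDEN_SIZE),
    "value": init_layer(HIDDEN_SIZE, 1),
    "policy": init_layer(HIDDEN_SIZE, N_ACTIONS),
}

def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def forward(obs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (action log-probabilities, state values) for a batch of observations."""
    features = obs
    for name in ("shared_1", "shared_2"):
        weights, bias = params[name]
        features = np.maximum(features @ weights + bias, 0.0)  # ReLU
    value_w, value_b = params["value"]
    policy_w, policy_b = params["policy"]
    values = (features @ value_w + value_b)[:, 0]
    log_probs = log_softmax(features @ policy_w + policy_b)
    return log_probs, values
```

Both heads read the same feature vector, which is why an update from either loss also shifts the shared layers.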

### 3. Training Algorithm
Training will be done using OpenAI's Proximal Policy Optimisation (PPO) algorithm, a policy-gradient method renowned for its robustness and ease of implementation.

#### 3.1 Advantage Estimation
To stabilise training and improve convergence speed, each action is judged by its advantage: how much better it performed than the agent's average performance in that state, as estimated by the value function. This yields a lower-variance measure of action quality than raw returns.  
A popular method for calculating advantages is **Generalised Advantage Estimation (GAE)**, defined as:  
$$
A^{\text{GAE}}_t = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}
$$
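where $\gamma$ is the discount factor, $\lambda \in [0, 1]$ trades off bias against variance, and the temporal-difference (TD) errors are computed as:  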
$$
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
$$
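
A minimal NumPy sketch of GAE for a single finished game, computed backwards with the recursion $A_t = \delta_t + \gamma\lambda A_{t+1}$. It assumes the state after the last step is terminal (value 0) and uses a hypothetical default of $\lambda = 0.95$, since the training cell only exposes `gamma`. The returned targets $G_t = A_t + V(s_t)$ are one common choice of value targets, not necessarily the project's:

```python
import numpy as np

def compute_gae(rewards, values, gamma: float = 0.99, lam: float = 0.95):
    """Return (advantages, return targets) for one complete game."""
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    advantages = np.zeros_like(rewards)
    next_value, running = 0.0, 0.0  # the state after the final step is terminal
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]  # TD error delta_t
        running = delta + gamma * lam * running                # A_t = delta_t + gamma * lambda * A_{t+1}
        advantages[t] = running
        next_value = values[t]
    return advantages, advantages + values
```

Setting $\lambda = 0$ reduces each advantage to the one-step TD error $\delta_t$, while $\lambda = 1$ gives the discounted return minus the value baseline.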

#### 3.2 Training Data
Each training step simulates an entire game, tracking the following features:
- **Observed game states:** The states encountered during gameplay.
- **Sampled actions:** Actions drawn from the model's policy logits, as predicted by the policy stream.
- **Rewards:** The rewards received for taking the above actions.
- **Action log-probabilities:** The log probabilities of the sampled actions under the policy that generated them, as predicted by the policy stream.
- **Predicted state values:** The state values predicted by the value stream, which are combined with the rewards to compute the GAE advantages and value targets.

Typically, a rollout would require an iteration limit. Here, the Snake environment enforces its own limit through the time-out described in Section 1.3, making a separate limit redundant.
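
One step of data collection can be sketched as follows, assuming the policy stream outputs log-probabilities for the three actions. `Trajectory` and `sample_action` are hypothetical names used for illustration, not part of the project's `agent` package:

```python
from dataclasses import dataclass, field

import numpy as np

@dataclass
class Trajectory:
    """Per-step records of one game, in the order listed above."""
    observations: list = field(default_factory=list)
    actions: list = field(default_factory=list)
    rewards: list = field(default_factory=list)
    log_probs: list = field(default_factory=list)
    values: list = field(default_factory=list)

def sample_action(log_probs: np.ndarray, rng: np.random.Generator) -> tuple[int, float]:
    """Draw an action from the policy and return it with its log-probability."""
    probs = np.exp(log_probs)
    probs /= probs.sum()  # guard against rounding drift
    action = int(rng.choice(len(probs), p=probs))
    return action, float(log_probs[action])

rng = np.random.default_rng(0)
action, log_prob = sample_action(np.log(np.array([0.2, 0.5, 0.3])), rng)
```

After each environment step, the observation, sampled action, reward, log-probability and value would be appended to the matching lists; once the game ends, `rewards` and `values` are exactly the inputs GAE needs.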

#### 3.3 Training Objectives
In this implementation, the policy and value functions are optimised separately, each with its own training objective and optimiser:

##### 3.3.1 Value Loss
The value loss is the mean squared error between the state values predicted by the current model and the return targets $G_t$:  
$$
L^{\text{value}}(\theta) = \mathbb{E}_t \left[ \left( V_\theta(s_t) - G_t \right)^2 \right]
$$  

##### 3.3.2 Policy Loss
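The policy loss keeps updates stable by removing the incentive for large policy changes. It is derived from the **Clipped Surrogate Objective**, which is maximised during training (the policy loss is its negative); here $A_t$ is the advantage estimate and $\epsilon$ is the clip range (`ppo_clip_value` in the training cell):  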
$$
L^{\text{clip}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right]
$$  

Here, the probability ratio between the new and old policies is defined as:  
$$
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}
$$  
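
The clipped objective can be sketched in NumPy as a loss to minimise (the negative of $L^{\text{clip}}$). The function name is illustrative, and the ratio is computed from log-probabilities as $r_t(\theta) = \exp(\log \pi_\theta - \log \pi_{\theta_\text{old}})$, which is numerically safer than dividing probabilities:

```python
import numpy as np

def clipped_surrogate_loss(
    new_log_probs: np.ndarray,
    old_log_probs: np.ndarray,
    advantages: np.ndarray,
    clip: float = 0.2,
) -> float:
    """Negative clipped surrogate objective, averaged over the batch."""
    ratio = np.exp(new_log_probs - old_log_probs)  # r_t(theta)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip, 1.0 + clip) * advantages
    # Take the pessimistic (smaller) term, then negate because optimisers minimise.
    return float(-np.mean(np.minimum(unclipped, clipped)))
```

For example, with $\epsilon = 0.2$ and a ratio of 1.5, a positive advantage of 1 contributes only 1.2, whereas a negative advantage of -1 contributes the full -1.5, so the objective never rewards moving far from the old policy.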

##### 3.3.3 Clipping Mechanism
PPO's clipping mechanism keeps policy updates within a trusted range, which makes it safe to take multiple gradient steps on the same data. For each game, up to `max_policy_updates` policy updates are performed, and the remaining updates are skipped once the **Kullback–Leibler (KL) divergence** between the old and new action distributions exceeds a predefined threshold (`target_kl_div`). The KL divergence is defined as:  
$$
D_{\text{KL}}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}
$$
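
The early-stopping check can be sketched as follows, computing the exact KL divergence $D_{\text{KL}}(\pi_{\theta_\text{old}} \| \pi_\theta)$ between the full three-action distributions and averaging it over the states of one game. Both the direction of the divergence and the function names are assumptions; many PPO implementations instead use a sample-based approximation from the log-probabilities of the taken actions.

```python
import numpy as np

def mean_kl(old_log_probs: np.ndarray, new_log_probs: np.ndarray) -> float:
    """Mean KL(old || new) over a batch of categorical distributions (one row per state)."""
    old_probs = np.exp(old_log_probs)
    return float(np.mean(np.sum(old_probs * (old_log_probs - new_log_probs), axis=-1)))

def should_stop(old_log_probs: np.ndarray, new_log_probs: np.ndarray, target_kl: float = 0.02) -> bool:
    """True once the updated policy has drifted further than target_kl (target_kl_div in the trainer)."""
    return mean_kl(old_log_probs, new_log_probs) > target_kl
```

Inside the update loop, `should_stop` would be evaluated after each policy step and the loop broken once it returns `True`, with `max_policy_updates` (30 in the training cell) as the upper bound.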

### 4. Training Loop
The model is trained using the following process:
1. The agent plays the game until it reaches a terminal state.
2. Training data is collected, and the Generalised Advantage Estimate (GAE) is computed.
3. The data is randomly shuffled to enhance training stability.
4. Separate gradient steps are performed for the policy and value objectives, each with its own learning rate.

During training, the agent's high score is monitored. If it surpasses the previous best, a checkpoint of the agent is saved. After the final game, a final checkpoint is created, concluding the process.

```python
from agent.training import PPOTrainer # The trainer class contains all optimization logic for the PPO algorithm.
from agent.training import train      # The training function runs the training loop for a given environment and trainer.
from agent import ACN                 # The model is an actor-critic network that predicts action probabilities and state values.
from game.env import SnakeEnv         # The environment class wraps a vectorised game instance, computing observations and rewards.
from game.gui import render_replay    # The render function creates a gif of a game replay.

# Create a new game instance and agent model.
env = SnakeEnv(
    width  = 5,
    height = 5,
    gui    = False,
)
model = ACN(
    observation_space = env.observation_space,
    action_space      = env.action_space,
    hidden_size       = 128,
)
trainer = PPOTrainer(
    model              = model,
    ppo_clip_value     = 0.2,  # PPO clip value. Clamps the ratio of old and new policy to prevent large updates.
    max_policy_updates = 30,   # Maximum number of policy updates per game. Prevents overfitting.
    target_kl_div      = 0.02, # Maximum allowed KL divergence. Specifies the safe range for policy updates.
    policy_lr          = 2e-5, # Learning rate for the policy network. Higher values can skip local minima, but may overshoot.
    value_lr           = 2e-6, # Learning rate for the value network. Lower than the policy rate to reduce feature layer drift.
    gamma              = 0.99, # Discount factor. Determines the importance of future rewards.
)

# Train the agent.
best_game = train(
    env            = env,
    trainer        = trainer,
    episodes       = 100_000,
    checkpoint_dir = '../assets/models/',
)
# Finally, render a gif of the best game.
render_replay(best_game, output_file = '../assets/gameplay.gif')
```

```text
EPISODE    100 | Mean:  0.1, Highscore:  1 |▕█▏
EPISODE    200 | Mean:  0.3, Highscore:  2 |▕█ ▏
EPISODE    300 | Mean:  0.4, Highscore:  3 |▕█░ ▏
EPISODE    400 | Mean:  0.2, Highscore:  3 |▕█ ▏
EPISODE    500 | Mean:  0.6, Highscore:  3 |▕█░ ▏
EPISODE    600 | Mean:  0.4, Highscore:  3 |▕█░▏
EPISODE    700 | Mean:  0.5, Highscore:  3 |▕█░▏
EPISODE    800 | Mean:  0.8, Highscore:  3 |▕░█░▏
EPISODE    900 | Mean:  0.9, Highscore:  3 |▕░█░▏
EPISODE   1000 | Mean:  2.3, Highscore:  9 |▕ ░█░     ▏
EPISODE   1100 | Mean:  4.5, Highscore: 11 |▕   ░█░░    ▏
EPISODE   1200 | Mean:  8.8, Highscore: 21 |▕      ░░░█░░░        ▏
EPISODE   1300 | Mean: 16.9, Highscore: 32 |  ▕          ░░░░█░░░░░         ▏
EPISODE   1400 | Mean: 21.6, Highscore: 35 |▕                 ░░░░░█░░░        ▏
EPISODE   1500 | Mean: 25.1, Highscore: 39 | ▕                     ░░█░░░░         ▏
EPISODE   1600 | Mean: 25.3, Highscore: 39 |      ▕                ░░░█░░░        ▏
EPISODE   1700 | Mean: 26.0, Highscore: 39 |
```

### 5. Results
After training for 100,000 episodes, the best agent had developed a strategy that achieved a high score of 75 on a 12x9 grid, reaching a total length of 79 segments. This covers roughly 73% of the 108-cell game area, despite the limited, snake-centric observation space.  
<center><img src="../assets/gameplay.gif"/></center>

```python
from game.env import SnakeEnv
from agent import ACN

# Create a new 12x9 game instance (set gui = True to watch the agent play at 3 fps) and load the best model.
env = SnakeEnv(
    width      = 12,
    height     = 9,
    gui        = False,
    fps        = 3,
)
model = ACN(
    observation_space = env.observation_space,
    action_space      = env.action_space,
    hidden_size       = 128,
).load('../assets/models/preview/snake_agent_ppo_75.pth')

# Let the agent play 150 games and report the mean score and high score.
best_game = env.evaluate(model, num_games = 150)
```

```text
Mean: 36.84 | Highscore: 66
```

### 6. References
This project was inspired by the following video tutorials:
- Patrick Loeber (2020). *[Teach AI to Play Snake](https://www.youtube.com/playlist?list=PLqnslRFeH2UrDh7vUmJ60YrmWd64mTTKV)*. Available on YouTube.
- Edan Meyer (2021). *[Let’s Code Proximal Policy Optimization](https://www.youtube.com/watch?v=HR8kQMTO8bk)*. Available on YouTube.