# 🐍 Snake RL with MCTS + PPO

AlphaZero-style training for the Snake game using:
- **PPO** (Proximal Policy Optimization) for policy learning
- **MCTS** (Monte Carlo Tree Search) for action selection

## Setup
Run all cells to train your agent. **GPU is recommended!**


In [None]:
# Install dependencies
!pip install torch -q

import torch
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
# Name the GPU that was picked up so CPU-only runs are easy to spot.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")


## Snake Environment
30-feature state representation with relative actions (straight/right/left)


In [None]:
import numpy as np
import random
from collections import deque

class SnakeEnv:
    """Snake environment with 30-feature state and relative actions.

    Positions are ``(x, y)`` tuples.  ``self.snake`` is a deque whose first
    element is the head and whose last element is the tail.  Actions are
    relative to the current heading: ``0`` keeps the heading, ``1`` selects
    ``(direction + 1) % 4`` and ``2`` selects ``(direction - 1) % 4``
    (straight / right / left).
    """
    DIRECTIONS = [(0, -1), (1, 0), (0, 1), (-1, 0)]

    def __init__(self, grid_size=10):
        """Create a ``grid_size`` x ``grid_size`` board and start an episode.

        Raises:
            ValueError: if ``grid_size`` is below 2, because no free cell
                would be left for the food.
        """
        if grid_size < 2:
            raise ValueError("grid_size must be at least 2")
        self.grid_size = grid_size
        self.action_space = 3
        self.observation_space = 30
        self.reset()

    def reset(self):
        """Start a new episode and return the initial 30-feature state."""
        self.snake = deque([(self.grid_size // 2, self.grid_size // 2)])
        self.direction = 1
        self.food = self._place_food()
        self.score = 0
        self.steps = 0
        self.max_steps = self.grid_size * self.grid_size * 20
        return self._get_state()

    @staticmethod
    def _turn(direction, action):
        """Translate a relative action (0/1/2) into an absolute direction index."""
        if action == 0:
            return direction
        if action == 1:
            return (direction + 1) % 4
        return (direction - 1) % 4

    def _in_bounds(self, pos):
        """Return True when ``pos`` lies inside the grid."""
        return 0 <= pos[0] < self.grid_size and 0 <= pos[1] < self.grid_size

    def _place_food(self):
        """Return a random cell not occupied by the snake.

        Returns ``None`` when the snake already covers the whole grid; without
        this guard the rejection-sampling loop below would never terminate.
        """
        if len(self.snake) >= self.grid_size * self.grid_size:
            return None
        while True:
            food = (random.randint(0, self.grid_size - 1), random.randint(0, self.grid_size - 1))
            if food not in self.snake:
                return food

    def _is_collision(self, pos):
        """Return True if ``pos`` is off the grid or on the snake, tail excluded.

        The tail cell is excluded because it is vacated by a non-eating move.
        """
        if not self._in_bounds(pos):
            return True
        return pos in list(self.snake)[:-1]

    def _get_depth(self, start_pos, direction_idx):
        """Steps from ``start_pos`` to the first wall or snake cell, divided by grid size."""
        dx, dy = self.DIRECTIONS[direction_idx]
        occupied = set(self.snake)  # one O(n) build instead of a deque scan per step
        x, y = start_pos
        distance = 0
        while True:
            x += dx; y += dy; distance += 1
            if not self._in_bounds((x, y)): break
            if (x, y) in occupied: break
        return distance / self.grid_size

    def _get_wall_distance(self, start_pos, direction_idx):
        """Steps from ``start_pos`` until leaving the grid, divided by grid size."""
        dx, dy = self.DIRECTIONS[direction_idx]
        x, y = start_pos
        distance = 0
        while True:
            x += dx; y += dy; distance += 1
            if not self._in_bounds((x, y)): break
        return distance / self.grid_size

    def _get_next_pos(self, action):
        """Return the cell next to the head in absolute direction ``action``."""
        head = self.snake[0]
        dx, dy = self.DIRECTIONS[action]
        return (head[0] + dx, head[1] + dy)

    def _manhattan_distance(self, pos1, pos2):
        """Return the L1 distance between two cells."""
        return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])

    def _is_pos_safe(self, pos, snake_body):
        """Return True if ``pos`` is on the grid and not in ``snake_body``."""
        return self._in_bounds(pos) and pos not in snake_body

    def _count_safe_moves(self, head_pos, direction, snake_body):
        """Count how many of the three relative moves from ``head_pos`` are safe."""
        safe_count = 0
        for turn in (0, 1, 2):
            dx, dy = self.DIRECTIONS[self._turn(direction, turn)]
            if self._is_pos_safe((head_pos[0] + dx, head_pos[1] + dy), snake_body):
                safe_count += 1
        return safe_count

    def _lookahead_safe_moves(self, action_dir, depth=2):
        """Look ahead after moving in absolute direction ``action_dir``.

        Returns ``(min_safe, max_safe)``: the smallest and largest number of
        safe moves available after each possible second move.  With
        ``depth <= 1`` both entries are the safe-move count after the first
        move.  Returns ``(0, 0)`` if the first move is unsafe or no second
        move is possible.
        """
        head = self.snake[0]
        dx, dy = self.DIRECTIONS[action_dir]
        next_head = (head[0] + dx, head[1] + dy)
        if not self._is_pos_safe(next_head, set(self.snake)):
            return (0, 0)
        new_snake = [next_head] + list(self.snake)
        if next_head != self.food:
            new_snake.pop()  # the tail only stays when food is eaten
        new_snake_set = set(new_snake)
        safe_after_first = self._count_safe_moves(next_head, action_dir, new_snake_set)
        if depth <= 1:
            return (safe_after_first, safe_after_first)
        min_safe, max_safe, paths_found = 3, 0, 0
        for turn in (0, 1, 2):
            dir2 = self._turn(action_dir, turn)
            dx2, dy2 = self.DIRECTIONS[dir2]
            pos2 = (next_head[0] + dx2, next_head[1] + dy2)
            if not self._is_pos_safe(pos2, new_snake_set):
                continue
            snake2 = [pos2] + new_snake
            if pos2 != self.food:
                snake2.pop()
            safe_after_second = self._count_safe_moves(pos2, dir2, set(snake2))
            min_safe = min(min_safe, safe_after_second)
            max_safe = max(max_safe, safe_after_second)
            paths_found += 1
        return (min_safe, max_safe) if paths_found > 0 else (0, 0)

    def _can_reach_tail(self, head_pos, snake_body_set, tail_pos):
        """Breadth-first search: can ``head_pos`` reach ``tail_pos``?

        Cells in ``snake_body_set`` and cells off the grid are impassable.
        """
        if head_pos == tail_pos:
            return True
        visited = {head_pos}
        queue = deque([head_pos])  # popleft() is O(1); list.pop(0) was O(n)
        while queue:
            pos = queue.popleft()
            for dx, dy in self.DIRECTIONS:
                new_pos = (pos[0] + dx, pos[1] + dy)
                if new_pos == tail_pos:
                    return True
                if new_pos not in visited and self._in_bounds(new_pos) and new_pos not in snake_body_set:
                    visited.add(new_pos)
                    queue.append(new_pos)
        return False

    def get_action_mask(self):
        """Return three booleans marking which relative actions are allowed.

        An action is allowed when it leaves at least one safe follow-up move,
        or when it eats the food.  Once the snake is longer than 30 and no
        action eats food, the mask is narrowed to the actions from which the
        tail stays reachable, if any exist.  At least one entry is always True.
        """
        head = self.snake[0]
        snake_set = set(self.snake)
        snake_len = len(self.snake)
        mask = [False, False, False]
        best_options = [-1, -1, -1]
        can_reach_tail_score = [0, 0, 0]
        food_adjacent = [False, False, False]
        for action in range(3):
            abs_dir = self._turn(self.direction, action)
            dx, dy = self.DIRECTIONS[abs_dir]
            next_pos = (head[0] + dx, head[1] + dy)
            if not self._is_pos_safe(next_pos, snake_set):
                best_options[action] = -100
                continue
            new_snake_list = [next_pos] + list(self.snake)
            eating_food = (next_pos == self.food)
            food_adjacent[action] = eating_food
            if not eating_food:
                new_snake_list.pop()
            new_snake_set = set(new_snake_list)
            new_tail = new_snake_list[-1]
            safe_after = self._count_safe_moves(next_pos, abs_dir, new_snake_set)
            best_options[action] = safe_after
            if eating_food:
                best_options[action] += 200
                can_reach_tail_score[action] = 100
            elif snake_len > 30:
                # The tail cell itself must be passable for the search to end on it.
                if self._can_reach_tail(next_pos, new_snake_set - {new_tail}, new_tail):
                    can_reach_tail_score[action] = 100
                    best_options[action] += 100
                else:
                    best_options[action] -= 50
            if safe_after >= 1:
                mask[action] = True
        for action in range(3):
            if food_adjacent[action] and best_options[action] > 0:
                mask[action] = True
        if snake_len > 30 and not any(food_adjacent):
            tail_reachable_mask = [score > 0 for score in can_reach_tail_score]
            if any(tail_reachable_mask):
                mask = tail_reachable_mask
        if not any(mask):
            # Every move looks bad: allow the least bad one so the mask is never empty.
            mask[max(range(3), key=lambda i: best_options[i])] = True
        return mask

    def _get_state(self):
        """Build the 30-feature observation as a float32 array.

        Layout (straight, right, left within each group of three):
        danger (3), food direction (3), food distance (1), wall distance (3),
        depth (3), would die (3), would eat (3), distance change (3),
        snake length (1), adjacent body (1), look-ahead min/max (6).
        """
        head = self.snake[0]
        relative_dirs = [self.direction, (self.direction + 1) % 4, (self.direction - 1) % 4]
        next_cells = [self._get_next_pos(abs_dir) for abs_dir in relative_dirs]
        blocked = [1 if self._is_collision(cell) else 0 for cell in next_cells]
        current_dist = self._manhattan_distance(head, self.food)
        dist_norm = 2 * self.grid_size

        features = []
        # Danger (3)
        features.extend(blocked)
        # Food direction (3)
        food_dx, food_dy = self.food[0] - head[0], self.food[1] - head[1]
        for abs_dir in relative_dirs:
            dir_dx, dir_dy = self.DIRECTIONS[abs_dir]
            features.append(1 if food_dx * dir_dx + food_dy * dir_dy > 0 else 0)
        # Food distance (1)
        features.append(current_dist / dist_norm)
        # Wall distance (3)
        features.extend(self._get_wall_distance(head, abs_dir) for abs_dir in relative_dirs)
        # Depth (3)
        features.extend(self._get_depth(head, abs_dir) for abs_dir in relative_dirs)
        # Would die (3) - same values as the danger group
        features.extend(blocked)
        # Would eat (3)
        features.extend(1 if cell == self.food else 0 for cell in next_cells)
        # Distance change (3)
        for cell, hit in zip(next_cells, blocked):
            if hit:
                features.append(1)
            else:
                features.append((self._manhattan_distance(cell, self.food) - current_dist) / dist_norm)
        # Snake length (1)
        features.append(len(self.snake) / (self.grid_size * self.grid_size))
        # Adjacent body (1); a neighbour is never the head itself, so the full set is fine
        occupied = set(self.snake)
        adjacent_body = sum(1 for dx, dy in self.DIRECTIONS if (head[0] + dx, head[1] + dy) in occupied)
        features.append(adjacent_body / 4)
        # Lookahead (6)
        for abs_dir in relative_dirs:
            min_safe, max_safe = self._lookahead_safe_moves(abs_dir, depth=2)
            features.append(min_safe / 3.0)
            features.append(max_safe / 3.0)
        return np.array(features, dtype=np.float32)

    def step(self, action):
        """Apply a relative action and return ``(state, reward, done, info)``.

        Rewards: -10 for hitting a wall or the body, +10 for eating, otherwise
        +1 / -1.5 for moving toward / away from the food plus length, safety
        and tail-reachability shaping terms.  Reaching ``max_steps`` ends the
        episode with -5.  ``info`` always holds ``"score"`` and, on terminal
        steps, a ``"reason"`` of ``"wall"``, ``"self"``, ``"timeout"`` or
        ``"win"`` (snake fills the grid).
        """
        self.steps += 1
        new_direction = self._turn(self.direction, action)
        self.direction = new_direction
        head = self.snake[0]
        dx, dy = self.DIRECTIONS[new_direction]
        new_head = (head[0] + dx, head[1] + dy)
        if not self._in_bounds(new_head):
            return self._get_state(), -10, True, {"score": self.score, "reason": "wall"}
        if new_head in list(self.snake)[:-1]:
            return self._get_state(), -10, True, {"score": self.score, "reason": "self"}
        self.snake.appendleft(new_head)
        if new_head == self.food:
            self.score += 1
            reward = 10
            new_food = self._place_food()
            if new_food is None:
                # Board is full: finish the episode.  self.food keeps its old
                # value so _get_state() can still compute food features.
                return self._get_state(), reward, True, {"score": self.score, "reason": "win"}
            self.food = new_food
        else:
            self.snake.pop()
            closer = self._manhattan_distance(new_head, self.food) < self._manhattan_distance(head, self.food)
            reward = 1 if closer else -1.5
            snake_len = len(self.snake)
            if snake_len > 10:
                reward += 0.1 * (snake_len / 10)
            snake_set = set(self.snake)
            safe_moves = self._count_safe_moves(new_head, new_direction, snake_set)
            if safe_moves == 0:
                reward -= 5
            elif safe_moves == 1 and snake_len > 5:
                reward -= 2
            elif safe_moves == 2 and snake_len > 10:
                reward -= 0.5
            if snake_len > 30:
                tail = self.snake[-1]
                if self._can_reach_tail(new_head, snake_set - {tail}, tail):
                    reward += 2.0
                else:
                    reward -= 5.0
        if self.steps >= self.max_steps:
            return self._get_state(), -5, True, {"score": self.score, "reason": "timeout"}
        return self._get_state(), reward, False, {"score": self.score}

# Cell-completion marker: shows in the notebook output that SnakeEnv is defined.
print("✓ SnakeEnv loaded")


## PPO Agent
Actor-Critic network with separate networks for policy and value


In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    """Policy and value estimator built from two independent MLPs.

    ``actor`` maps a state to a softmax distribution over ``action_dim``
    actions (Tanh hidden layers); ``critic`` maps a state to one scalar value
    (ReLU hidden layers).  The two networks share no parameters.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        policy_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        ]
        self.actor = nn.Sequential(*policy_layers)
        value_layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.critic = nn.Sequential(*value_layers)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal weights (gain sqrt(2)) and zero biases for every Linear layer.

        The last actor Linear layer is then re-initialised with gain 0.01.
        """
        hidden_gain = np.sqrt(2)

        def init_linears(network):
            for layer in network:
                if not isinstance(layer, nn.Linear):
                    continue
                nn.init.orthogonal_(layer.weight, gain=hidden_gain)
                nn.init.constant_(layer.bias, 0.0)

        init_linears(self.actor)
        # actor[-1] is the Softmax, so actor[-2] is the output Linear layer.
        policy_head = self.actor[-2]
        nn.init.orthogonal_(policy_head.weight, gain=0.01)
        init_linears(self.critic)

    def forward(self, state):
        """Return ``(action_probabilities, state_value)`` for ``state``."""
        probs = self.actor(state)
        value = self.critic(state)
        return probs, value

class PPOAgent:
    """PPO learner with a clipped surrogate objective and separate optimizers.

    Rollout data is buffered in the public lists ``states``, ``actions``,
    ``logprobs`` and ``state_values`` (filled by :meth:`select_action`) and
    ``rewards`` and ``is_terminals`` (filled by :meth:`store_transition`).
    :meth:`update` consumes and clears all six lists.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, eps_clip=0.2, k_epochs=10, device='cpu'):
        """Build the trainable policy, its frozen sampling copy and both optimizers.

        Args:
            state_dim: Size of the observation vector.
            action_dim: Number of discrete actions.
            lr: Initial learning rate for both optimizers; it decays toward ``0.1 * lr``.
            gamma: Discount factor for the Monte Carlo returns.
            eps_clip: PPO clipping range for the probability ratio.
            k_epochs: Gradient steps per call to :meth:`update`.
            device: Torch device for networks and buffers.
        """
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.device = device
        self.policy = ActorCritic(state_dim, action_dim).to(device)
        self.lr = lr
        self.min_lr = lr * 0.1
        self.actor_optimizer = optim.Adam(self.policy.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.policy.critic.parameters(), lr=lr)
        # policy_old is the behaviour policy used for sampling; it is synced
        # with policy at the end of every update().
        self.policy_old = ActorCritic(state_dim, action_dim).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.mse_loss = nn.MSELoss()
        self.update_count = 0
        self.states, self.actions, self.logprobs, self.rewards, self.is_terminals, self.state_values = [], [], [], [], [], []

    def select_action(self, state, action_mask=None):
        """Sample an action from ``policy_old`` and buffer it for the next update.

        Args:
            state: Observation vector for the current step.
            action_mask: Optional per-action booleans.  Masked-out actions get
                probability zero and the rest are renormalised.  If the mask
                removes all probability mass it is ignored.

        Returns:
            The sampled action as a Python int.
        """
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(self.device)
            action_probs, state_value = self.policy_old(state_tensor)
            probs_to_use = action_probs
            if action_mask is not None:
                mask_tensor = torch.FloatTensor(action_mask).to(self.device)
                masked_probs = action_probs * mask_tensor
                if masked_probs.sum() > 0:
                    probs_to_use = masked_probs / masked_probs.sum()
            dist = Categorical(probs_to_use)
            action = dist.sample()
            action_logprob = dist.log_prob(action)
            action = action.item()
        self.states.append(state_tensor)
        self.actions.append(action)
        self.logprobs.append(action_logprob)
        self.state_values.append(state_value)
        return action

    def store_transition(self, reward, is_terminal):
        """Buffer the reward and terminal flag for the most recent action."""
        self.rewards.append(reward)
        self.is_terminals.append(is_terminal)

    def update(self):
        """Run ``k_epochs`` PPO steps on the buffered rollout, then clear it.

        Does nothing when no rewards are buffered.  Every 50th update the
        learning rate of both optimizers is decayed, down to ``min_lr``.
        """
        if not self.rewards:
            return
        # Discounted returns, accumulated backwards and reset at episode ends.
        # Append and reverse is linear; repeated insert(0, ...) was quadratic.
        returns = []
        running_return = 0.0
        for reward, is_terminal in zip(reversed(self.rewards), reversed(self.is_terminals)):
            if is_terminal:
                running_return = 0.0
            running_return = reward + self.gamma * running_return
            returns.append(running_return)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        old_states = torch.stack(self.states).detach()
        old_actions = torch.tensor(self.actions, device=self.device)
        old_logprobs = torch.stack(self.logprobs).detach()
        # reshape(-1) keeps the batch dimension even for one transition; a bare
        # squeeze() gave a 0-d tensor that MSELoss had to broadcast.
        old_state_values = torch.stack(self.state_values).detach().reshape(-1)
        advantages = returns - old_state_values
        if advantages.numel() > 1:
            advantages = advantages / (advantages.std() + 1e-7)
        for _ in range(self.k_epochs):
            action_probs, state_values = self.policy(old_states)
            dist = Categorical(action_probs)
            action_logprobs = dist.log_prob(old_actions)
            dist_entropy = dist.entropy()
            ratios = torch.exp(action_logprobs - old_logprobs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            value_loss = self.mse_loss(state_values.reshape(-1), returns)
            policy_loss = -torch.min(surr1, surr2).mean()
            actor_loss = policy_loss - 0.05 * dist_entropy.mean()
            # Each optimizer owns a disjoint parameter set, so one backward
            # pass over the summed loss gives each network exactly the
            # gradient of its own loss, with no retain_graph needed.
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            (actor_loss + value_loss).backward()
            torch.nn.utils.clip_grad_norm_(self.policy.critic.parameters(), 0.5)
            self.actor_optimizer.step()
            self.critic_optimizer.step()
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.update_count += 1
        if self.update_count % 50 == 0:
            new_lr = max(self.min_lr, self.lr * (0.995 ** (self.update_count // 50)))
            for pg in self.actor_optimizer.param_groups: pg['lr'] = new_lr
            for pg in self.critic_optimizer.param_groups: pg['lr'] = new_lr
        self.states, self.actions, self.logprobs, self.rewards, self.is_terminals, self.state_values = [], [], [], [], [], []

    def save(self, filepath):
        """Write policy weights, both optimizer states and ``update_count`` to ``filepath``."""
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
                    # Stored so the learning-rate decay continues after a resume.
                    'update_count': self.update_count}, filepath)

    def load(self, filepath):
        """Restore a checkpoint written by :meth:`save`.

        Optimizer states and ``update_count`` are optional, so checkpoints
        that lack them still load.
        """
        checkpoint = torch.load(filepath, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.policy_old.load_state_dict(checkpoint['policy_state_dict'])
        if 'actor_optimizer_state_dict' in checkpoint:
            self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        if 'critic_optimizer_state_dict' in checkpoint:
            self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.update_count = checkpoint.get('update_count', self.update_count)

# Cell-completion marker: shows in the notebook output that PPOAgent is defined.
print("✓ PPOAgent loaded")


## MCTS
AlphaZero-style tree search


In [None]:
import math

class MCTSNode:
    """One node of the search tree.

    Holds the simulator state it represents, the action that led to it, the
    network prior for that action and the running visit statistics.
    """

    def __init__(self, state, parent=None, action=None, prior=0.0):
        # Identity and position in the tree.
        self.state = state
        self.parent = parent
        self.action = action
        self.prior = prior
        self.children = {}
        # Statistics accumulated during back-propagation.
        self.visit_count = 0
        self.value_sum = 0.0
        # Terminal flag and value are filled in by the parent's expand().
        self.is_terminal = False
        self.terminal_value = 0.0

    @property
    def value(self):
        """Mean backed-up value, or 0.0 for a node that was never visited."""
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def is_expanded(self):
        """Return True once the node has at least one child."""
        return bool(self.children)

    def select_child(self, c_puct=1.5):
        """Return ``(action, child)`` with the highest PUCT score.

        The score is ``Q + c_puct * prior * sqrt(N_parent + 1) / (1 + N_child)``.
        Ties go to the child inserted first.  Returns ``(None, None)`` when no
        child scores above negative infinity, e.g. when there are no children.
        """
        parent_factor = math.sqrt(self.visit_count + 1)
        winner = None
        top_score = float('-inf')
        for candidate in self.children.values():
            bonus = c_puct * candidate.prior * parent_factor / (1 + candidate.visit_count)
            total = candidate.value + bonus
            if total > top_score:
                top_score, winner = total, candidate
        if winner is None:
            return None, None
        return winner.action, winner

    def expand(self, action_priors, next_states, terminals, terminal_values):
        """Create a child for each action in ``action_priors`` that has none yet.

        All four arguments are mappings keyed by action.
        """
        for act, prior in action_priors.items():
            if act in self.children:
                continue
            node = MCTSNode(state=next_states[act], parent=self, action=act, prior=prior)
            node.is_terminal = terminals[act]
            node.terminal_value = terminal_values[act]
            self.children[act] = node

class SnakeSimulator:
    """Copy of the Snake rules that works on plain state dicts.

    Used by the tree search to step and featurise hypothetical positions
    without touching the live environment.  A state dict has the keys
    ``'snake'`` (list of ``(x, y)`` cells, head first), ``'food'``,
    ``'direction'``, ``'score'``, ``'steps'`` and ``'max_steps'``.
    """
    DIRECTIONS = [(0, -1), (1, 0), (0, 1), (-1, 0)]

    def __init__(self, grid_size=10):
        self.grid_size = grid_size

    def get_state_dict(self, env):
        """Snapshot the fields of ``env`` that the simulator needs."""
        return {'snake': list(env.snake), 'food': env.food, 'direction': env.direction,
                'score': env.score, 'steps': env.steps, 'max_steps': env.max_steps}

    @staticmethod
    def _turn(direction, action):
        """Translate a relative action (0/1/2) into an absolute direction index."""
        if action == 0:
            return direction
        if action == 1:
            return (direction + 1) % 4
        return (direction - 1) % 4

    def _in_bounds(self, pos):
        """Return True when ``pos`` lies inside the grid."""
        return 0 <= pos[0] < self.grid_size and 0 <= pos[1] < self.grid_size

    def simulate_action(self, state_dict, action):
        """Apply a relative action to ``state_dict`` without mutating it.

        Returns ``(next_state, reward, done, info)``.  On a wall or body
        collision the unchanged input dict is returned with reward -10.  Only
        the food reward (+10) and the toward/away shaping (+1 / -1.5) are
        modelled.  If eating leaves no free cell, the episode ends with reason
        ``'win'`` and the returned state keeps the old food position, so it
        can still be featurised.
        """
        snake = list(state_dict['snake'])
        food = state_dict['food']
        direction = state_dict['direction']
        score = state_dict['score']
        steps = state_dict['steps'] + 1
        max_steps = state_dict['max_steps']
        new_direction = self._turn(direction, action)
        head = snake[0]
        dx, dy = self.DIRECTIONS[new_direction]
        new_head = (head[0] + dx, head[1] + dy)
        if not self._in_bounds(new_head):
            return state_dict, -10, True, {'score': score, 'reason': 'wall'}
        if new_head in snake[:-1]:
            return state_dict, -10, True, {'score': score, 'reason': 'self'}
        new_snake = [new_head] + snake
        if new_head == food:
            score += 1
            reward = 10
            new_food = self._place_food(new_snake)
            if new_food is None:
                # Board is full: end the episode instead of storing food=None,
                # which get_features() cannot handle.
                won_state = {'snake': new_snake, 'food': food, 'direction': new_direction,
                             'score': score, 'steps': steps, 'max_steps': max_steps}
                return won_state, reward, True, {'score': score, 'reason': 'win'}
        else:
            new_snake.pop()
            new_food = food
            old_dist = abs(head[0] - food[0]) + abs(head[1] - food[1])
            new_dist = abs(new_head[0] - food[0]) + abs(new_head[1] - food[1])
            reward = 1 if new_dist < old_dist else -1.5
        next_state = {'snake': new_snake, 'food': new_food, 'direction': new_direction,
                      'score': score, 'steps': steps, 'max_steps': max_steps}
        if steps >= max_steps:
            return next_state, -5, True, {'score': score, 'reason': 'timeout'}
        return next_state, reward, False, {'score': score}

    def _place_food(self, snake):
        """Return a random cell not in ``snake``, or ``None`` if the grid is full."""
        snake_set = set(snake)
        empty = [(x, y) for x in range(self.grid_size) for y in range(self.grid_size) if (x, y) not in snake_set]
        return empty[np.random.randint(len(empty))] if empty else None

    def get_features(self, state_dict):
        """Return the 30-feature float32 observation for ``state_dict``.

        Same layout and formulas as ``SnakeEnv._get_state``.  That includes
        the six look-ahead features, which were previously constant 0.5
        placeholders here and so did not match the environment's observations.
        """
        snake = list(state_dict['snake'])
        food = state_dict['food']
        direction = state_dict['direction']
        head = snake[0]
        relative_dirs = [direction, (direction + 1) % 4, (direction - 1) % 4]
        next_cells = [(head[0] + self.DIRECTIONS[d][0], head[1] + self.DIRECTIONS[d][1]) for d in relative_dirs]
        blocked = [1 if self._is_collision(cell, snake) else 0 for cell in next_cells]
        current_dist = abs(head[0] - food[0]) + abs(head[1] - food[1])
        dist_norm = 2 * self.grid_size

        features = []
        # Danger (3)
        features.extend(blocked)
        # Food direction (3)
        food_dx, food_dy = food[0] - head[0], food[1] - head[1]
        for abs_dir in relative_dirs:
            dir_dx, dir_dy = self.DIRECTIONS[abs_dir]
            features.append(1 if food_dx * dir_dx + food_dy * dir_dy > 0 else 0)
        # Food distance (1)
        features.append(current_dist / dist_norm)
        # Wall distance (3)
        features.extend(self._get_wall_distance(head, abs_dir) for abs_dir in relative_dirs)
        # Depth (3)
        features.extend(self._get_depth(head, abs_dir, snake) for abs_dir in relative_dirs)
        # Would die (3) - same values as the danger group
        features.extend(blocked)
        # Would eat (3)
        features.extend(1 if cell == food else 0 for cell in next_cells)
        # Distance change (3)
        for cell, hit in zip(next_cells, blocked):
            if hit:
                features.append(1)
            else:
                new_dist = abs(cell[0] - food[0]) + abs(cell[1] - food[1])
                features.append((new_dist - current_dist) / dist_norm)
        # Snake length (1)
        features.append(len(snake) / (self.grid_size * self.grid_size))
        # Adjacent body (1)
        body = snake[1:]
        adjacent = sum(1 for dx, dy in self.DIRECTIONS if (head[0] + dx, head[1] + dy) in body)
        features.append(adjacent / 4)
        # Lookahead (6)
        for abs_dir in relative_dirs:
            min_safe, max_safe = self._lookahead_safe_moves(snake, food, abs_dir, depth=2)
            features.append(min_safe / 3.0)
            features.append(max_safe / 3.0)
        return np.array(features, dtype=np.float32)

    def _is_pos_safe(self, pos, snake_body):
        """Return True if ``pos`` is on the grid and not in ``snake_body``."""
        return self._in_bounds(pos) and pos not in snake_body

    def _count_safe_moves(self, head_pos, direction, snake_body):
        """Count how many of the three relative moves from ``head_pos`` are safe."""
        safe_count = 0
        for turn in (0, 1, 2):
            dx, dy = self.DIRECTIONS[self._turn(direction, turn)]
            if self._is_pos_safe((head_pos[0] + dx, head_pos[1] + dy), snake_body):
                safe_count += 1
        return safe_count

    def _lookahead_safe_moves(self, snake, food, action_dir, depth=2):
        """Return ``(min_safe, max_safe)`` after moving in ``action_dir``.

        These are the smallest and largest number of safe moves left after
        each possible second move.  Returns ``(0, 0)`` if the first move is
        unsafe or no second move is possible.
        """
        head = snake[0]
        dx, dy = self.DIRECTIONS[action_dir]
        next_head = (head[0] + dx, head[1] + dy)
        if not self._is_pos_safe(next_head, set(snake)):
            return (0, 0)
        new_snake = [next_head] + list(snake)
        if next_head != food:
            new_snake.pop()  # the tail only stays when food is eaten
        new_snake_set = set(new_snake)
        safe_after_first = self._count_safe_moves(next_head, action_dir, new_snake_set)
        if depth <= 1:
            return (safe_after_first, safe_after_first)
        min_safe, max_safe, paths_found = 3, 0, 0
        for turn in (0, 1, 2):
            dir2 = self._turn(action_dir, turn)
            dx2, dy2 = self.DIRECTIONS[dir2]
            pos2 = (next_head[0] + dx2, next_head[1] + dy2)
            if not self._is_pos_safe(pos2, new_snake_set):
                continue
            snake2 = [pos2] + new_snake
            if pos2 != food:
                snake2.pop()
            safe_after_second = self._count_safe_moves(pos2, dir2, set(snake2))
            min_safe = min(min_safe, safe_after_second)
            max_safe = max(max_safe, safe_after_second)
            paths_found += 1
        return (min_safe, max_safe) if paths_found > 0 else (0, 0)

    def _is_collision(self, pos, snake):
        """Return True if ``pos`` is off the grid or on the snake, tail excluded."""
        if not self._in_bounds(pos):
            return True
        return pos in list(snake)[:-1]

    def _get_wall_distance(self, pos, direction):
        """Steps from ``pos`` until leaving the grid, divided by grid size."""
        dx, dy = self.DIRECTIONS[direction]
        x, y = pos
        distance = 0
        while True:
            x += dx; y += dy; distance += 1
            if not self._in_bounds((x, y)): break
        return distance / self.grid_size

    def _get_depth(self, pos, direction, snake):
        """Steps from ``pos`` to the first wall or snake cell, divided by grid size."""
        dx, dy = self.DIRECTIONS[direction]
        occupied = set(snake)
        x, y = pos
        distance = 0
        while True:
            x += dx; y += dy; distance += 1
            if not self._in_bounds((x, y)): break
            if (x, y) in occupied: break
        return distance / self.grid_size

class MCTS:
    """Tree search guided by a policy/value network.

    Each simulation walks from the root by PUCT score to an unexpanded or
    terminal node.  It then backs up either the network's value estimate or
    the node's terminal value, unchanged, along the visited path.
    """
    NUM_ACTIONS = 3

    def __init__(self, policy_network, grid_size=10, c_puct=1.5, num_simulations=100, device='cpu'):
        """Store the search settings and build the rule simulator.

        Args:
            policy_network: Callable returning ``(action_probs, value)`` for a feature tensor.
            grid_size: Board size handed to the simulator.
            c_puct: Exploration constant of the PUCT score.
            num_simulations: Simulations run per call to :meth:`get_action_probs`.
            device: Torch device for the feature tensors.
        """
        self.policy_network = policy_network
        self.grid_size = grid_size
        self.c_puct = c_puct
        self.num_simulations = num_simulations
        self.device = device
        self.simulator = SnakeSimulator(grid_size)

    def get_action_probs(self, env, temperature=1.0):
        """Search from the current state of ``env``.

        Returns:
            ``(action_probs, best_action)``.  ``action_probs`` is proportional
            to ``visits ** (1 / temperature)``, or one-hot on the most visited
            action when ``temperature`` is 0 or nothing was visited.
            ``best_action`` is the most visited root action as an int.
        """
        state_dict = self.simulator.get_state_dict(env)
        root = MCTSNode(state=state_dict)
        for _ in range(self.num_simulations):
            node = root
            search_path = [node]
            # Selection: descend until a leaf or a terminal node is reached.
            while node.is_expanded() and not node.is_terminal:
                _, node = node.select_child(self.c_puct)
                if node is None:
                    break
                search_path.append(node)
            # Evaluation: terminal value, or expand and ask the network.
            if node is None:
                value = 0
            elif node.is_terminal:
                value = node.terminal_value
            else:
                value = self._expand_node(node)
            # Back-propagation: same value for every node on the path.
            for visited in search_path:
                visited.visit_count += 1
                visited.value_sum += value
        visits = np.array(
            [root.children[a].visit_count if a in root.children else 0 for a in range(self.NUM_ACTIONS)],
            dtype=np.float64)
        best_action = int(np.argmax(visits))
        if temperature == 0 or visits.sum() == 0:
            action_probs = np.zeros(self.NUM_ACTIONS)
            action_probs[best_action] = 1.0
        else:
            # Dividing by the maximum first keeps the power within [0, 1], so
            # small temperatures cannot overflow to inf/NaN.  The largest
            # entry is exactly 1, so the sum is never zero and the result
            # sums to 1 without an epsilon.
            scaled = (visits / visits.max()) ** (1.0 / temperature)
            action_probs = scaled / scaled.sum()
        return action_probs, best_action

    def _expand_node(self, node):
        """Create the children of ``node`` and return the network's value for it.

        Priors come from the network's action probabilities.  A child's
        terminal value is the simulated reward when the move ends the episode
        and 0 otherwise.
        """
        state_dict = node.state
        features = self.simulator.get_features(state_dict)
        with torch.no_grad():
            features_tensor = torch.FloatTensor(features).to(self.device)
            action_probs, value = self.policy_network(features_tensor)
            action_probs = action_probs.cpu().numpy()
            value = value.item()
        next_states, terminals, terminal_values, action_priors = {}, {}, {}, {}
        for action in range(self.NUM_ACTIONS):
            next_state, reward, done, _ = self.simulator.simulate_action(state_dict, action)
            next_states[action] = next_state
            terminals[action] = done
            terminal_values[action] = reward if done else 0
            action_priors[action] = action_probs[action]
        node.expand(action_priors, next_states, terminals, terminal_values)
        return value

    def select_action(self, env, temperature=1.0):
        """Return an action as an int.

        It is the most visited action when ``temperature`` is 0, otherwise a
        sample from the visit-count distribution.
        """
        action_probs, best_action = self.get_action_probs(env, temperature)
        if temperature == 0:
            return best_action
        return int(np.random.choice(self.NUM_ACTIONS, p=action_probs))

class HybridAgent:
    """Chooses actions with MCTS for long snakes and with plain PPO otherwise.

    Actions chosen by MCTS are still written into the PPO agent's rollout
    buffers, so every step contributes to the next PPO update.
    """

    def __init__(self, ppo_agent, mcts_threshold=50, num_simulations=100, grid_size=10):
        """Wrap ``ppo_agent`` with a tree search that uses its current policy network.

        Args:
            ppo_agent: The PPO agent whose networks and buffers are shared.
            mcts_threshold: Snake length from which MCTS selects the action.
            num_simulations: MCTS simulations per move.
            grid_size: Board size handed to the MCTS simulator; it must match
                the environment.
        """
        self.ppo_agent = ppo_agent
        self.mcts_threshold = mcts_threshold
        self.mcts = MCTS(policy_network=ppo_agent.policy, grid_size=grid_size,
                         num_simulations=num_simulations, device=ppo_agent.device)
        self.use_mcts_count = 0
        self.use_ppo_count = 0

    def select_action(self, state, env, action_mask=None):
        """Return the next action and record it in the PPO buffers.

        Uses MCTS (temperature 0.5) when ``len(env.snake)`` has reached
        ``mcts_threshold``; otherwise samples from the PPO policy.
        """
        snake_length = len(env.snake)
        if snake_length >= self.mcts_threshold:
            self.use_mcts_count += 1
            # int() so the buffered action is a plain Python int, not a NumPy scalar.
            action = int(self.mcts.select_action(env, temperature=0.5))
            self._record_action_to_ppo(state, action, action_mask)
            return action
        self.use_ppo_count += 1
        return self.ppo_agent.select_action(state, action_mask)

    def _record_action_to_ppo(self, state, mcts_action, action_mask=None):
        """Buffer an externally chosen action as if ``policy_old`` had produced it.

        The log-probability is taken under the masked, renormalised policy
        when that distribution gives the action non-zero probability.  MCTS
        does not see the mask, so it can pick a masked-out action.  In that
        case the unmasked policy is used, because a zero-probability entry
        would be recorded as ``log(eps)`` and produce extreme importance
        ratios in the PPO update.
        """
        device = self.ppo_agent.device
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).to(device)
            action_probs, state_value = self.ppo_agent.policy_old(state_tensor)
            probs_to_use = action_probs
            if action_mask is not None:
                mask_tensor = torch.FloatTensor(action_mask).to(device)
                masked_probs = action_probs * mask_tensor
                total = masked_probs.sum()
                if total > 0 and masked_probs[mcts_action] > 0:
                    probs_to_use = masked_probs / total
            dist = Categorical(probs_to_use)
            action_logprob = dist.log_prob(torch.tensor(mcts_action, device=device))
        self.ppo_agent.states.append(state_tensor)
        self.ppo_agent.actions.append(mcts_action)
        self.ppo_agent.logprobs.append(action_logprob)
        self.ppo_agent.state_values.append(state_value)

    def store_transition(self, reward, is_terminal):
        """Forward the reward and terminal flag to the PPO buffers."""
        self.ppo_agent.store_transition(reward, is_terminal)

    def update(self):
        """Run a PPO update on everything buffered so far."""
        self.ppo_agent.update()

    def save(self, filepath):
        """Save the PPO agent's checkpoint to ``filepath``."""
        self.ppo_agent.save(filepath)

    def load(self, filepath):
        """Load a PPO agent checkpoint from ``filepath``."""
        self.ppo_agent.load(filepath)

# Cell-completion marker: shows in the notebook output that the MCTS classes are defined.
print("✓ MCTS loaded")


## Configuration
Adjust parameters below


In [None]:
# ===== CONFIGURATION =====
# Hyperparameters read by the training-loop cell below.
MAX_EPISODES = 10000        # Total number of training episodes
MCTS_SIMULATIONS = 200      # Tree-search simulations per move (more = better but slower)
MCTS_THRESHOLD = 1          # Snake length from which MCTS picks the action (1 = MCTS always; larger = MCTS only once the snake is long)
UPDATE_FREQUENCY = 10       # Run a PPO update every this many episodes
SAVE_INTERVAL = 500         # Write a numbered checkpoint every this many episodes
LEARNING_RATE = 1e-3        # Initial learning rate passed to PPOAgent

# Load checkpoint? Set to None for fresh start
LOAD_CHECKPOINT = None  # e.g., "/content/drive/MyDrive/snake.pt"

print(f"Episodes: {MAX_EPISODES}, MCTS sims: {MCTS_SIMULATIONS}, threshold: {MCTS_THRESHOLD}")


## Training Loop
Run this to start training


In [None]:
from datetime import datetime

# Initialize
# The environment and the PPO learner are built once and reused for every episode.
# gamma=0.9 and k_epochs=4 override the PPOAgent defaults (0.99 and 10).
env = SnakeEnv(grid_size=10)
ppo_agent = PPOAgent(state_dim=30, action_dim=3, lr=LEARNING_RATE, gamma=0.9,
                     eps_clip=0.2, k_epochs=4, device=device)

# Optionally resume from a saved checkpoint before wrapping the agent.
if LOAD_CHECKPOINT:
    ppo_agent.load(LOAD_CHECKPOINT)
    print(f"Loaded: {LOAD_CHECKPOINT}")

# HybridAgent decides per step whether MCTS or plain PPO sampling picks the action.
agent = HybridAgent(ppo_agent=ppo_agent, mcts_threshold=MCTS_THRESHOLD, num_simulations=MCTS_SIMULATIONS)

scores = []          # final score of every finished episode
best_score = 0       # best score so far; a higher score triggers a "best" checkpoint
start_time = datetime.now()

print("=" * 60)
print(f"Training on {device} | MCTS: {MCTS_SIMULATIONS} sims")
print("=" * 60)

for episode in range(1, MAX_EPISODES + 1):
    state = env.reset()
    done = False
    
    # Play one full episode; every transition is buffered inside the agent.
    while not done:
        action_mask = env.get_action_mask()
        action = agent.select_action(state, env, action_mask)
        state, reward, done, info = env.step(action)
        agent.store_transition(reward, done)
    
    # `info` is the dict returned by the last step of the episode.
    score = info.get('score', 0)
    scores.append(score)
    
    # Transitions accumulate over UPDATE_FREQUENCY episodes before each PPO update.
    if episode % UPDATE_FREQUENCY == 0:
        agent.update()
    
    # Rolling average over the most recent 100 episodes.
    avg = np.mean(scores[-100:]) if scores else 0
    elapsed = (datetime.now() - start_time).total_seconds()
    reason = info.get('reason', '')
    
    print(f"Ep {episode:5d} | Score: {score:3d} | Max: {max(scores):3d} | Avg100: {avg:5.1f} | {elapsed:.0f}s | {reason}")
    
    # Keep a copy of the weights whenever a new high score is reached.
    if score > best_score:
        best_score = score
        agent.save('/content/snake_best.pt')
        print(f"  *** NEW BEST: {score} ***")
    
    # Periodic numbered checkpoint, independent of the score.
    if episode % SAVE_INTERVAL == 0:
        agent.save(f'/content/snake_ep{episode}.pt')
        print(f"  Checkpoint saved")

print("\n" + "=" * 60)
print(f"Done! Avg: {np.mean(scores):.1f}, Max: {max(scores)}")
print("=" * 60)


## Download Model


In [None]:
# Download the best checkpoint written by the training loop.
# Requires the google.colab package, i.e. a Colab runtime.
from google.colab import files
files.download('/content/snake_best.pt')


## (Optional) Google Drive


In [None]:
# Mount Google Drive at /content/drive.
# Requires the google.colab package, i.e. a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Save to Drive
# agent.save('/content/drive/MyDrive/snake_best.pt')
