# In [None]:
import numpy as np
import random
import pygame
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple

# Constants
# Window size in pixels; the playfield is a grid of BLOCK_SIZE x BLOCK_SIZE cells.
WIDTH, HEIGHT = 600, 400
BLOCK_SIZE = 20  # Side length of one grid cell / snake segment, in pixels
SNAKE_COLOR = (0, 255, 0)  # (R, G, B)
FOOD_COLOR = (255, 0, 0)  # (R, G, B)
BACKGROUND_COLOR = (0, 0, 0)  # (R, G, B)
FPS = 12  # Passed to clock.tick(); caps the number of game steps per second

# Hyperparameters
GAMMA = 0.9  # Discount factor applied to the target network's next-state value
EPSILON = 0.05 # Probability of taking a random action; decayed after every game
EPSILON_MIN = 0.01  # Floor below which EPSILON is never decayed
EPSILON_DECAY = 0.99  # EPSILON is multiplied by this factor after every game
LEARNING_RATE = 0.001  # Adam learning rate
MEMORY_SIZE = 10000  # Maximum number of transitions kept in replay memory
BATCH_SIZE = 128  # Transitions sampled per optimisation step
TARGET_UPDATE_FREQUENCY = 10  # Games between target-network syncs

# PyTorch device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Updated DQN network with 8 hidden layers
class DQN(nn.Module):
    """Fully connected Q-network with eight 128-unit ReLU hidden layers.

    The layers are registered as ``fc1`` ... ``fc9`` (in that order), so the
    ``state_dict`` keys are ``fc1.weight``, ``fc1.bias``, ..., ``fc9.bias``.
    ``fc9`` is a plain linear head producing one value per output.
    """

    def __init__(self, input_size, output_size):
        """Build the layer stack.

        Args:
            input_size: Number of features in one input vector.
            output_size: Number of values produced per input vector.
        """
        super().__init__()
        hidden_width = 128
        # Consecutive pairs of this list give each layer's (in, out) sizes:
        # input -> 128, seven times 128 -> 128, then 128 -> output.
        widths = [input_size] + [hidden_width] * 8 + [output_size]
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{index}", nn.Linear(n_in, n_out))

    def forward(self, x):
        """Apply fc1..fc8 with ReLU, then the linear head fc9."""
        for index in range(1, 9):
            hidden_layer = getattr(self, f"fc{index}")
            x = torch.relu(hidden_layer(x))
        return self.fc9(x)

# Initialize DQNs
# State has 4 features (see SnakeGame.get_state); one output per entry in ACTIONS.
policy_net = DQN(input_size=4, output_size=4).to(device)
target_net = DQN(input_size=4, output_size=4).to(device)

# Load pre-trained model if available
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
try:
    policy_net.load_state_dict(torch.load("snake.pt", map_location=device))
    # Switch between eval()/train() here to run a loaded model in evaluation
    # or training mode.
    policy_net.eval()
    print("evaluation mode")
    print("Loaded pre-trained model 'snake.pt'")
except FileNotFoundError:
    print("No pre-trained model found, training from scratch.")

# Sync the target network only AFTER the optional checkpoint load, so that it
# starts from the same weights as the policy network instead of from an
# unrelated random initialisation.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Training and display setup: optimizer, replay memory, pygame window, action table.

# Optimizer and replay memory
# Only the policy network is optimised; the target network is updated by copying weights.
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
# Bounded buffer: once MEMORY_SIZE transitions are stored, the oldest are dropped.
memory = deque(maxlen=MEMORY_SIZE)
# One replay-memory record, as appended by the main loop after every step.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# Initialize pygame
pygame.init()
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("DQN Snake Game-new")
clock = pygame.time.Clock()  # Used by the main loop to cap the step rate at FPS

# Action mappings: 0 - Up, 1 - Left, 2 - Right, 3 - Down
# Each entry is a (dx, dy) pixel offset of one grid cell; "Up" is negative y.
ACTIONS = [(0, -BLOCK_SIZE), (-BLOCK_SIZE, 0), (BLOCK_SIZE, 0), (0, BLOCK_SIZE)]
# Snake environment. step() returns (next_state, reward, done); next_state is never None.
class SnakeGame:
    """Snake environment on a BLOCK_SIZE grid inside a WIDTH x HEIGHT window.

    Attributes:
        snake: List of (x, y) pixel positions, head first.
        direction: Current (dx, dy) movement offset in pixels.
        food: (x, y) pixel position of the food.
        score: Number of food items eaten in the current game.
        done: True once the snake has hit a wall or itself.
        start_distance: Manhattan distance from head to food measured when
            the current food was spawned.
    """

    def __init__(self):
        # Font for the score overlay; created on first render and then reused.
        self._font = None
        self.reset()

    def reset(self):
        """Start a new game and return the initial state vector."""
        self.snake = [(WIDTH // 2, HEIGHT // 2)]
        self.direction = (0, -BLOCK_SIZE)
        self.food = self.spawn_food()
        self.score = 0
        self.done = False
        self.start_distance = self.calculate_distance(self.snake[0], self.food)  # Distance to food at start
        return self.get_state()

    def spawn_food(self):
        """Return a random grid-aligned position that is not on the snake."""
        while True:
            x = random.randint(0, (WIDTH - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
            y = random.randint(0, (HEIGHT - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
            if (x, y) not in self.snake:
                return (x, y)

    def get_state(self):
        """Return the observation as a float32 array of four values.

        The values are (head_x - food_x, head_y - food_y, dir_x, dir_y),
        all in pixels.
        """
        head_x, head_y = self.snake[0]
        food_x, food_y = self.food
        state = np.array([head_x - food_x, head_y - food_y, self.direction[0], self.direction[1]], dtype=np.float32)
        return state

    def calculate_distance(self, head, food):
        """Return the Manhattan distance between two (x, y) positions."""
        return np.abs(head[0] - food[0]) + np.abs(head[1] - food[1])

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: Index into ACTIONS selecting the requested direction.

        Returns:
            A ``(next_state, reward, done)`` tuple. The reward is -15 on a
            wall or self collision (which ends the game), otherwise +1 / -1
            for moving closer to / farther from the food, plus 10 when the
            food is eaten.

        Note:
            If the requested move would run into the snake's own body, the
            environment replaces it with a random direction that avoids the
            body (walls are not considered), so the executed move can differ
            from ``action``.
        """
        head_x, head_y = self.snake[0]

        # Check for the new direction based on the chosen action
        new_direction = ACTIONS[action]
        potential_new_head = (head_x + new_direction[0], head_y + new_direction[1])

        # Check if the new direction would lead to self-collision
        if potential_new_head in self.snake:
            # Find all possible safe directions that avoid self-collisions
            safe_directions = [
                d for d in ACTIONS
                if (head_x + d[0], head_y + d[1]) not in self.snake
            ]
            # Choose a safe direction at random; if there is none, keep the
            # current direction.
            if safe_directions:
                new_direction = random.choice(safe_directions)
            else:
                new_direction = self.direction

        # Update the snake's direction
        self.direction = new_direction
        new_head = (head_x + self.direction[0], head_y + self.direction[1])

        # Check for final wall or self-collision after the move
        if (new_head in self.snake or
                new_head[0] < 0 or new_head[0] >= WIDTH or
                new_head[1] < 0 or new_head[1] >= HEIGHT):
            self.done = True
            return self.get_state(), -15, self.done  # High penalty for collision

        # Distance to the food *before* this move. Using the distance recorded
        # when the food was spawned would reward or punish moves relative to
        # the spawn-time position rather than relative to the previous step.
        old_distance = self.calculate_distance((head_x, head_y), self.food)

        # Move the snake to the new position
        self.snake.insert(0, new_head)
        new_distance = self.calculate_distance(new_head, self.food)

        reward = 0
        if new_distance < old_distance:
            reward = 1  # Reward for getting closer to food
        elif new_distance > old_distance:
            reward = -1  # Penalty for moving farther from food

        # Check if the snake has found food
        if new_head == self.food:
            self.score += 1
            reward += 10  # Reward for food
            self.food = self.spawn_food()  # Spawn new food after eating
            self.start_distance = self.calculate_distance(new_head, self.food)  # Reset the distance to new food
        else:
            self.snake.pop()  # Remove the last segment if no food is eaten

        next_state = self.get_state()
        return next_state, reward, self.done

    def render(self):
        """Draw the snake, the food and the score, then flip the display."""
        screen.fill(BACKGROUND_COLOR)
        for segment in self.snake:
            pygame.draw.rect(screen, SNAKE_COLOR, (*segment, BLOCK_SIZE, BLOCK_SIZE))
        pygame.draw.rect(screen, FOOD_COLOR, (*self.food, BLOCK_SIZE, BLOCK_SIZE))
        # Build the font once instead of on every frame.
        if self._font is None:
            self._font = pygame.font.Font(None, 36)
        score_text = self._font.render(f'Score: {self.score}', True, (255, 255, 255))
        screen.blit(score_text, (10, 10))
        pygame.display.flip()

# Helper functions
def choose_action(state):
    """Pick an action index for ``state`` with an epsilon-greedy policy.

    With probability EPSILON a uniformly random action in 0..3 is returned;
    otherwise the action with the highest value predicted by ``policy_net``.
    """
    explore = random.uniform(0, 1) < EPSILON
    if explore:
        return random.randint(0, 3)  # Take a random action

    # Greedy branch: no gradients are needed for action selection.
    with torch.no_grad():
        observation = torch.tensor(state, device=device)
        return policy_net(observation).argmax().item()

def optimize_model():
    """Run one DQN gradient step on a random minibatch from replay memory.

    Does nothing until at least BATCH_SIZE transitions have been stored.
    A transition whose ``next_state`` is ``None`` is treated as terminal:
    its target is the reward alone, with no bootstrapped future value.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = random.sample(memory, BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    state_batch = torch.tensor(np.array(batch.state), dtype=torch.float32, device=device)
    action_batch = torch.tensor(batch.action, dtype=torch.int64, device=device)
    reward_batch = torch.tensor(batch.reward, dtype=torch.float32, device=device)

    # Terminal transitions carry next_state=None; they keep a future value of 0.
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], dtype=torch.bool, device=device
    )
    next_q_values = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)
    if non_final_mask.any():
        non_final_next_states = torch.tensor(
            np.array([s for s in batch.next_state if s is not None]),
            dtype=torch.float32,
            device=device,
        )
        # The target network only provides fixed targets; no gradients flow through it.
        with torch.no_grad():
            next_q_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    # Q(s, a) for the actions that were actually stored.
    q_values = policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
    expected_q_values = reward_batch + GAMMA * next_q_values

    loss = nn.MSELoss()(q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


total_score = 0
wrong_move_penalty = 0
collision_penalty = 0

# Main game loop
game = SnakeGame()
total_wrong_move_penalty = 0  # Track wrong move penalty
total_collision_penalty = 0   # Track collision penalty
games_played = 0

while True:
    state = game.reset()
    game_over = False

    # Reset per-game penalty counters
    wrong_move_penalty = 0
    collision_penalty = 0

    while not game_over:
        # Allow the window to be closed while a game is running.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        action = choose_action(state)
        next_state, reward, game_over = game.step(action)

        # Track penalty types
        if reward == -1:
            wrong_move_penalty += reward  # Increment wrong move penalty
        elif reward == -15:  # Collision penalty returned by SnakeGame.step
            collision_penalty += reward  # Increment collision penalty

        # Store None as next_state for the terminal transition so that
        # optimize_model() does not bootstrap from a state after game over.
        stored_next_state = None if game_over else next_state
        memory.append(Transition(state, action, stored_next_state, reward))
        state = next_state

        optimize_model()
        game.render()
        clock.tick(FPS)

    # After each game, accumulate and print penalties
    total_wrong_move_penalty += wrong_move_penalty
    total_collision_penalty += collision_penalty
    total_score += game.score  # Add score of this game to total score

    print(f'Game {games_played + 1} ended with score {game.score} wrong moves ({wrong_move_penalty}) colision ({collision_penalty}) .')

    # games_played is still the zero-based index of the game that just ended,
    # so the target network is synced after games 1, 11, 21, ...
    if games_played % TARGET_UPDATE_FREQUENCY == 0:
        target_net.load_state_dict(policy_net.state_dict())

    games_played += 1
    EPSILON = max(EPSILON_MIN, EPSILON * EPSILON_DECAY)

    if games_played >= 100:
        # The trained weights are not saved here; call
        # torch.save(policy_net.state_dict(), "snake.pt") to persist them.
        print(f'Total Score: {total_score}, Average Score: {total_score / games_played:.2f}')
        break

pygame.quit()


