# Christopher Morales
## EE 5830 - Neural Networks

In [9]:
# Importing Libraries
import random
import threading
from collections import namedtuple, deque

import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Enabling GPU 

In [2]:
# Default to the CPU and switch to CUDA only when torch reports a usable GPU
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

## Constants for the Model

In [13]:
# Initialize Pygame
pygame.init()

# Define a named tuple for transitions
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# Screen Resolution
WIDTH, HEIGHT = 600, 400

# How many grids to have
GRID_SIZE = 20

# Frames Per Second
FPS = 10

# Colors for the game
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)

# Defining the directions
UP = (0, -1)
DOWN = (0, 1)
LEFT = (-1, 0)
RIGHT = (1, 0)

## Snake Class

In [4]:
class Snake:
    """
    The snake: an ordered list of grid cells plus a heading.

    ``positions[0]`` is the head. ``length`` is the number of cells the body may
    occupy, so the snake grows when a caller increments ``length``.
    """

    def __init__(self):
        """
        Initializing the attributes
        """
        # Color of the snake
        self.color = GREEN

        # Length, starting position and heading are exactly the respawn values
        self.reset()

    def get_head_position(self):
        """
        Gets the current head of the snake

        return: the (x, y) tuple stored at the front of ``positions``
        """
        return self.positions[0]

    def update(self, new_head_position=None):
        """
        Moves the snake one step and applies the self-collision rule

        :param new_head_position: explicit (x, y) for the new head; when None the
            head advances one grid cell along ``direction``, wrapping around the
            window edges
        return: the requested new head position. If the move caused a collision the
            snake has been reset, so this may differ from ``get_head_position()``
        """
        # Getting the current position of the snake (head)
        cur = self.get_head_position()

        # Read the heading up front so the debug print below always has x and y,
        # even when the caller supplies the new head position explicitly
        x, y = self.direction

        # If nothing was supplied, step one cell in the current direction
        if new_head_position is None:
            # The modulo wraps the snake to the opposite edge of the window
            new = (((cur[0] + (x * GRID_SIZE)) % WIDTH), (cur[1] + (y * GRID_SIZE)) % HEIGHT)
        else:
            new = new_head_position

        # Print the raw data
        print(f"x: {x}, y: {y}, cur[0]: {cur[0]}, GRID_SIZE: {GRID_SIZE}, new: {new}")

        # Either resets the snake (collision) or commits the move
        self.check_collision(new)

        return new

    def check_collision(self, new):
        """
        Commits a move, or resets the snake if the move eats its own body

        :param new: the (x, y) cell the head is about to enter
        """
        # The head and the cell right behind it are skipped by the check, so only
        # cells from index 2 onwards count as a self-collision
        if len(self.positions) > 2 and new in self.positions[2:]:
            # Resets from the beginning
            self.reset()

        # Else, continue updating
        else:
            # The new cell becomes the head
            self.positions.insert(0, new)

            # Drop the tail cell unless the snake is allowed to be longer
            if len(self.positions) > self.length:
                self.positions.pop()

    def reset(self):
        """
        When the player dies: back to a one-cell snake in the middle of the window
        """
        # Initialize the length at 1
        self.length = 1

        # Our usual middle spot
        self.positions = [((WIDTH // 2), (HEIGHT // 2))]

        # Pick a random direction
        self.direction = random.choice([UP, DOWN, LEFT, RIGHT])

    def render(self, surface):
        """
        Draw every cell of the snake on the given surface

        :param surface: the pygame surface to draw on
        """
        for p in self.positions:
            # One GRID_SIZE x GRID_SIZE square per body cell
            pygame.draw.rect(surface, self.color, (p[0], p[1], GRID_SIZE, GRID_SIZE))


## Pebble Class

In [5]:
class Pebble:
    """
    The pebble the snake collects: a single grid cell with a color.
    """

    def __init__(self):
        """
        Set the color, then place the pebble on a random cell
        """
        # Picking the color of the pebble
        self.color = RED

        # Placeholder until a random cell is drawn just below
        self.position = (0, 0)
        self.randomize_position()

    def randomize_position(self):
        """
        Move the pebble to a random grid-aligned cell inside the window
        """
        # Index of the last column / row that still fits inside the window
        last_column = (WIDTH // GRID_SIZE) - 1
        last_row = (HEIGHT // GRID_SIZE) - 1

        # Draw a cell index per axis and scale it to pixel coordinates
        x = random.randint(0, last_column) * GRID_SIZE
        y = random.randint(0, last_row) * GRID_SIZE
        self.position = (x, y)

    def render(self, surface):
        """
        Draw the pebble on the given surface

        :param surface: the pygame surface to draw on
        """
        # One GRID_SIZE x GRID_SIZE square at the pebble's position
        cell = (self.position[0], self.position[1], GRID_SIZE, GRID_SIZE)
        pygame.draw.rect(surface, self.color, cell)

## Draws the Grid

In [7]:
def draw_grid(surface):
    """
    Outline every grid cell on the given surface

    Cells are GRID_SIZE x GRID_SIZE squares tiled over the WIDTH x HEIGHT area,
    visited row by row and drawn as white outlines of width 1.

    :param surface: the pygame surface to draw on
    """
    # Left edge of every column; identical for each row
    column_starts = range(0, WIDTH, GRID_SIZE)

    for top in range(0, HEIGHT, GRID_SIZE):
        for left in column_starts:
            # A line width of 1 draws only the outline of the cell
            pygame.draw.rect(surface, WHITE, pygame.Rect(left, top, GRID_SIZE, GRID_SIZE), 1)

## Normal Snake Game Class (user input only)

In [8]:
class NormalSnakeGame:
    """
    The Snake game driven by the arrow keys (user input only).
    """

    def __init__(self):
        """
        Create the window, the drawing surface and the game objects
        """
        # Create an object to help limit the frame rate
        self.clock = pygame.time.Clock()

        # Setting up the resolution
        self.screen = pygame.display.set_mode((WIDTH, HEIGHT), 0, 32)

        # Off-screen surface with the same size as the window
        self.surface = pygame.Surface(self.screen.get_size())

        # Convert the surface to the pixel format of the display
        self.surface = self.surface.convert()

        # Calling the two classes (the snake and pebble)
        self.snake = Snake()
        self.pebble = Pebble()

    def main(self):
        """
        Run the game loop until the window is closed

        Each frame handles the pending events, moves the snake, grows it when it
        reaches the pebble, and redraws everything.

        raises SystemExit: when the player closes the window
        """
        # Arrow key -> heading; any other key leaves the heading untouched
        key_to_direction = {
            pygame.K_UP: UP,
            pygame.K_DOWN: DOWN,
            pygame.K_LEFT: LEFT,
            pygame.K_RIGHT: RIGHT,
        }

        # An infinite loop
        while True:
            for event in pygame.event.get():
                # Letting the user quit the game when they hit the "X"
                if event.type == pygame.QUIT:
                    # End the game environment
                    pygame.quit()

                    # Same effect as sys.exit(), without depending on the sys
                    # module (which this file never imported)
                    raise SystemExit

                # Else, steer the snake
                if event.type == pygame.KEYDOWN:
                    direction = key_to_direction.get(event.key)
                    if direction is not None:
                        self.snake.direction = direction

            # Update the game
            self.snake.update()

            # Check if the snake eats the pebble
            if self.snake.get_head_position() == self.pebble.position:
                # Increase the length of the snake
                self.snake.length += 1

                # Generate a new pebble on to the game
                self.pebble.randomize_position()

            # Fill the background black
            self.surface.fill((0, 0, 0))

            # Draw the grid, the snake and the pebble on the off-screen surface
            draw_grid(self.surface)
            self.snake.render(self.surface)
            self.pebble.render(self.surface)

            # Copy the finished frame to the window and show it
            self.screen.blit(self.surface, (0, 0))
            pygame.display.update()

            # Limit the Frames Per Second
            self.clock.tick(FPS)

## Replay Memory Class

In [10]:
class ReplayMemory:
    """
    Bounded store of the transitions the agent has observed, used for training the DQN.

    Once ``buffer_size`` transitions are held, appending another one drops the
    oldest. ``sample`` draws ``batch_size`` of them at random for reuse.
    """

    def __init__(self, buffer_size, batch_size):
        # Number of transitions returned by every call to sample()
        self.batch_size = batch_size

        # The deque discards its oldest entry once maxlen is reached
        self.memory = deque(maxlen=buffer_size)

        # Record type for a single transition in our environment
        self.Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])

    def push(self, *args):
        """Store one transition built from (state, action, next_state, reward)."""
        transition = self.Transition(*args)
        self.memory.append(transition)

    def sample(self):
        """Return ``batch_size`` randomly chosen transitions from memory."""
        return random.sample(self.memory, k=self.batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


## DQN Class

In [11]:
class DQN(nn.Module):
    """
    Fully connected Q-network: observation vector in, one Q-value per action out.

    The exploration and optimization hyperparameters used by the training loop are
    stored on the model as plain attributes.
    """

    # Reduction modes accepted by the configurable loss helpers
    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, n_observations, n_actions):
        """
        :param n_observations: size of the observation vector (input features)
        :param n_actions: number of discrete actions (output features)
        """
        super(DQN, self).__init__()
        self.n_observations = n_observations
        self.n_actions = n_actions

        # Discount factor for future rewards
        self.gamma = 0.9

        # Exploration rate for "epsilon-greedy" strategy (START)
        self.epsilon = 1.0

        # Minimum exploration rate (END)
        self.epsilon_min = 0.01

        # Decay rate for the exploration rate
        self.epsilon_decay = 0.995

        # Learning rate for the neural network
        self.learning_rate = 0.001

        # Build the Q network
        self.Build_QNetwork()

    def Build_QNetwork(self):
        """
        Creates the three linear layers: n_observations -> 128 -> 128 -> n_actions
        """
        self.layer_one = nn.Linear(self.n_observations, 128)
        self.layer_two = nn.Linear(128, 128)
        self.layer_three = nn.Linear(128, self.n_actions)

    def forward(self, x):
        """
        Called with either one observation to determine the next action, or a batch
        during optimization.

        :param x: tensor whose last dimension has n_observations entries
        return: tensor with n_actions Q-values in its last dimension
        """
        # ReLU after the two hidden layers; the output layer stays linear
        x = torch.nn.functional.relu(self.layer_one(x))
        x = torch.nn.functional.relu(self.layer_two(x))
        return self.layer_three(x)

    def _check_reduction(self, reduction_type):
        """
        Rejects anything that is not a supported reduction mode

        raises ValueError: if reduction_type is not "none", "mean" or "sum"
        """
        if reduction_type not in self._REDUCTIONS:
            raise ValueError("Incorrect Reduction Type!!! Try none, mean, or sum")

    def mean_absolute_error(self, input, target):
        """
        Applies the Mean Absolute Error (MAE) loss function
        Link: https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html#torch.nn.L1Loss

        :param input: Any number of dimensions
        :param target: Same shape as the input
        return: The MAE loss function output
        """
        return torch.nn.L1Loss()(input, target)

    def mean_squared_error(self, input, target):
        """
        Applies the Mean Squared Error (MSE) loss function
        Link: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss

        :param input: Any number of dimensions
        :param target: Same shape as the input
        return: The MSE loss function output
        """
        return torch.nn.MSELoss()(input, target)

    def smoothness(self, input, target, reduction_type: str="none", beta: float=1.0):
        """
        Applies the Smooth L1 loss function: quadratic for errors below beta, linear
        (slope 1) above it.
        Link: https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html

        :param input: Any number of dimensions
        :param target: Same shape as the input
        :param reduction_type: Reduction applied to the output: "none", "mean" or "sum"
        :param beta: Threshold at which the loss changes from quadratic to linear
        return: The Smooth L1 loss function output
        raises ValueError: if reduction_type is not a supported mode
        """
        self._check_reduction(reduction_type)

        # reduction and beta configure the loss module; its forward() only accepts
        # (input, target)
        loss = torch.nn.SmoothL1Loss(reduction=reduction_type, beta=beta)

        return loss(input, target)

    def huber_loss_function(self, input, target, reduction_type: str="none", delta: float=1.0):
        """
        Applies the Huber loss function: quadratic for errors below delta, linear
        (slope delta) above it.
        Link: https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html

        :param input: Any number of dimensions
        :param target: Same shape as the input
        :param reduction_type: Reduction applied to the output: "none", "mean" or "sum"
        :param delta: Threshold at which the loss changes from quadratic to linear
        return: The Huber loss function output
        raises ValueError: if reduction_type is not a supported mode
        """
        self._check_reduction(reduction_type)

        # reduction and delta configure the loss module; its forward() only accepts
        # (input, target)
        loss = torch.nn.HuberLoss(reduction=reduction_type, delta=delta)

        return loss(input, target)

## Snake Environment Class

In [12]:
class SnakeEnvironment:
    """
    Wraps the snake and the pebble behind a reset / play interface for the agent,
    with optional rendering on a background thread.
    """

    def __init__(self):
        # Calling all the classes
        self.snake = Snake()
        self.pebble = Pebble()
        self.state = self.get_state()
        self.render_thread = None

        # Set by stop_rendering() to ask the background render loop to exit
        self._stop_event = threading.Event()

    def get_state(self):
        """
        Return the current state of the game

        return: numpy array [head x, head y, pebble x, pebble y]
        """
        snake_head = self.snake.get_head_position()
        state = np.concatenate([np.array(snake_head), np.array(self.pebble.position)])
        return state

    def reset(self):
        """
        Resets the game to its initial state

        return: the state after the reset (also stored in ``self.state``)
        """
        # Calling the classes to reset
        self.snake.reset()
        self.pebble.randomize_position()

        # Keep the cached state in sync with the freshly reset game
        self.state = self.get_state()
        return self.state

    def render(self):
        """
        Opens a window and starts drawing the game on a background thread

        return: the daemon thread running ``render_loop``
        """
        # Create a Pygame window to visualize the game
        pygame.init()
        screen = pygame.display.set_mode((WIDTH, HEIGHT))

        # The loop never returns on its own, so it must only run on the thread;
        # calling it here as well would block the caller forever
        # NOTE(review): pygame event handling off the main thread may be unreliable
        # on some platforms - confirm on the target OS
        self._stop_event.clear()
        self.render_thread = threading.Thread(target=self.render_loop, args=(screen,))
        self.render_thread.daemon = True
        self.render_thread.start()
        return self.render_thread

    def render_loop(self, screen):
        """
        Redraws the game at FPS frames per second until the window is closed or
        ``stop_rendering`` is called

        :param screen: the display surface to draw on
        raises SystemExit: when the window is closed
        """
        # One clock for the whole loop so the frame cap accounts for drawing time
        clock = pygame.time.Clock()

        while not self._stop_event.is_set():
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()

                    # Same effect as sys.exit(), without depending on the sys
                    # module (which this file never imported)
                    raise SystemExit

            # Fill the background black
            screen.fill((0, 0, 0))

            # Draw the grid, the snake and the pebble on the screen
            draw_grid(screen)
            self.snake.render(screen)
            self.pebble.render(screen)

            # Update the display
            pygame.display.flip()

            # Limit the Frames Per Second
            clock.tick(FPS)

    def stop_rendering(self):
        """
        Stops the rendering thread when the game ends
        """
        if self.render_thread:
            # Let the loop finish its current frame before pygame is shut down,
            # otherwise the thread would keep drawing on a closed display
            self._stop_event.set()
            self.render_thread.join()
            pygame.quit()

    def play(self, action):
        """
        Applies one action to the game

        :param action: one of UP, DOWN, LEFT, RIGHT; any other value keeps the
            current heading
        return: (state, reward, any_collision) where reward is 1 when the pebble was
            eaten and 0 otherwise, and any_collision tells whether this move made
            the snake run into itself (in which case the snake has been reset)
        """
        # Steer the snake
        if action in (UP, DOWN, LEFT, RIGHT):
            self.snake.direction = action

        # The cell the head is about to enter (wrapping around the window edges)
        current_position = self.snake.get_head_position()
        x, y = self.snake.direction
        new = (((current_position[0] + (x * GRID_SIZE)) % WIDTH), (current_position[1] + (y * GRID_SIZE)) % HEIGHT)

        # Same rule as Snake.check_collision. It has to be evaluated BEFORE the
        # update: afterwards the body has already moved (or been reset), so the
        # flag would no longer describe this move
        body = self.snake.positions
        any_collision = len(body) > 2 and new in body[2:]

        # Move the snake
        new_head_position = self.snake.update()

        # When the snake head eats a pebble (a reward); a move that ended in a
        # collision does not count, because the snake was reset instead of moved
        if not any_collision and new_head_position == self.pebble.position:
            # Update the body
            self.snake.length += 1

            # Regenerate the pebble elsewhere
            self.pebble.randomize_position()

            # Then reward the model
            reward = 1

        # Else, reward it with nothing
        else:
            reward = 0

        # Get the current state
        self.state = self.get_state()

        return self.state, reward, any_collision

## Main Cell

In [14]:
# Set up the environment and DQN model
env = SnakeEnvironment()
n_observations = len(env.get_state())

# The network predicts an index into this tuple: UP, DOWN, LEFT, RIGHT
actions = (UP, DOWN, LEFT, RIGHT)
n_actions = len(actions)

# Getting our custom model
model = DQN(n_observations, n_actions)

# Applying the optimization algorithm
optimizer = torch.optim.Adam(model.parameters(), lr=model.learning_rate)

# Replay memory the training batches are sampled from
memory = ReplayMemory(buffer_size=10000, batch_size=64)

# Applying the training loop (an episode is one complete run of the agent in its environment)
for episode in range(1000):
    # Reset the entire environment and state
    env.reset()
    state = env.get_state()

    # The maximum time steps per episode
    for time_steps in range(500):
        # Epsilon-greedy strategy: explore with probability epsilon, else exploit
        if random.random() < model.epsilon:
            action = random.randint(0, n_actions - 1)
        else:
            with torch.no_grad():
                # Calling the module (not .forward) keeps torch's hooks working
                q_values = model(torch.as_tensor(state, dtype=torch.float32))
                # The action with the highest predicted Q-value
                action = torch.argmax(q_values).item()

        # Play the chosen action and get the next state, reward, and collision info
        next_state, reward, collision = env.play(actions[action])
        print(f"next_state: {next_state}, reward: {reward}, collision: {collision}")
        print(f"up: {UP}, down: {DOWN}, left: {LEFT}, right: {RIGHT}")

        # Store the transition in replay memory. A collision ends the episode, so
        # its next state is recorded as None and never bootstrapped from
        memory.push(state, action, None if collision else next_state, reward)

        # Move to the next state
        state = next_state

        # Sample a random batch from replay memory and perform a Q learning update
        if len(memory) > memory.batch_size:
            # Get a random sample and turn the list of transitions into a
            # transition of tuples
            transitions = memory.sample()
            batch = Transition(*zip(*transitions))

            # Convert batch data to PyTorch tensors
            state_batch = torch.as_tensor(np.array(batch.state), dtype=torch.float32)
            action_batch = torch.tensor(batch.action, dtype=torch.long).view(-1, 1)
            reward_batch = torch.tensor(batch.reward, dtype=torch.float32)

            # Q-values of the actions that were actually taken
            state_action_values = model(state_batch).gather(1, action_batch)

            # True for every transition that did not end the episode
            non_final_mask = torch.tensor([s is not None for s in batch.next_state], dtype=torch.bool)

            # Terminal transitions keep a next-state value of zero
            next_state_values = torch.zeros(memory.batch_size)
            if non_final_mask.any():
                non_final_next_states = torch.as_tensor(
                    np.array([s for s in batch.next_state if s is not None]),
                    dtype=torch.float32,
                )
                # The target is a constant for the optimizer: no gradient needed
                with torch.no_grad():
                    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]

            # Calculate the expected Q-values for the current state
            expected_q_values = (next_state_values * model.gamma) + reward_batch

            # Compute the smooth L1 loss between predicted and expected Q-values
            loss = F.smooth_l1_loss(state_action_values, expected_q_values.view(-1, 1))

            # Zero the gradients, perform a backward pass, and update the model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The snake ran into itself: this episode is over
        if collision:
            break

    # Decay exploration rate
    model.epsilon = max(model.epsilon_min, model.epsilon * model.epsilon_decay)

x: 0, y: -1, cur[0]: 300, GRID_SIZE: 20, new: (300, 180)
next_state: [300 180 180 260], reward: 0, collision: False
up: (0, -1), down: (0, 1), left: (-1, 0), right: (1, 0)
x: 1, y: 0, cur[0]: 300, GRID_SIZE: 20, new: (320, 180)
next_state: [320 180 180 260], reward: 0, collision: False
up: (0, -1), down: (0, 1), left: (-1, 0), right: (1, 0)
x: 0, y: -1, cur[0]: 320, GRID_SIZE: 20, new: (320, 160)
next_state: [320 160 180 260], reward: 0, collision: False
up: (0, -1), down: (0, 1), left: (-1, 0), right: (1, 0)
x: 1, y: 0, cur[0]: 320, GRID_SIZE: 20, new: (340, 160)
next_state: [340 160 180 260], reward: 0, collision: False
up: (0, -1), down: (0, 1), left: (-1, 0), right: (1, 0)
x: -1, y: 0, cur[0]: 340, GRID_SIZE: 20, new: (320, 160)
next_state: [320 160 180 260], reward: 0, collision: False
up: (0, -1), down: (0, 1), left: (-1, 0), right: (1, 0)
x: 0, y: -1, cur[0]: 320, GRID_SIZE: 20, new: (320, 140)
next_state: [320 140 180 260], reward: 0, collision: False
up: (0, -1), down: (0, 1),