# Christopher Morales
## EE 5830 - Neural Networks

In [1]:
# Importing Libraries
import pygame
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import namedtuple

pygame 2.5.2 (SDL 2.28.3, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


## Constants for the DQN model

In [2]:
# Screen Resolution
WIDTH, HEIGHT = 600, 400

# How many grids to have (x by x)
GRID_SIZE = 20

# Frames Per Second
FPS = 60

# The size of the replay memory used in experience replay (to store and sample past experience)
REPLAY_MEMORY_SIZE = 10000

# The number of samples (transitions) randomly sampled from the replay memory
BATCH_SIZE = 1000

# Lower gamma makes the agent focus on immediate rewards where higher gamma considered future rewards more
GAMMA = 0.95

# Higher the value then explore (exploration trade off parameter)
EPSILON_START = 1.0

# Sets the minimum value that epsilon can reach (exploration)
EPSILON_END = 0.01

# To allow the agent to tansistion from exploration to exploitation as it learns ()
EPSILON_DECAY = 0.99

# Number of episodes (training time essentially)
MAX_EPISODE_VALUE = 50000

In [3]:
# Creating the Transistion state (how the model can learn from past and current)
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

## Snake Class

In [4]:
class Snake:
    def __init__(self):
        # Set initial position in the grid
        self.position = [GRID_SIZE * 2, GRID_SIZE * 2] 
        
        # Initialize the snake's body with three segments, including the initial position
        self.body = [list(self.position), [self.position[0] - GRID_SIZE, self.position[1]], [self.position[0] - 2 * GRID_SIZE, self.position[1]]]
        
        # Set the initial direction of the snake to 'RIGHT'
        self.direction = 'RIGHT'
        
        # Set the initial change direction to the current direction
        self.change_to = self.direction

    def change_direction(self, new_direction):
        """
        Gets the model action input to determine what direction the snake should go

        :param new_direction: an integer value that will be represented as UP, DOWN, RIGHT, or LEFT
        return: N/A
        """
        # Checks if the new_direction is an integer
        if isinstance(new_direction, int):
            # If the value is 1 then
            if new_direction == 1:
                # Move RIGHT
                self.direction = 'RIGHT'
                
                # Verifying the condition input
                print("Moving right")
            
            # If the value is 1 then
            elif new_direction == 2:
                # Move LEFT
                self.direction = 'LEFT'

                # Verifying the condition input
                print("Moving left")
            
            # If the value is 1 then
            elif new_direction == 3:
                # Move UP
                self.direction = 'UP'
                
                # Verifying the condition input
                print("Moving up")
            
            # If the value is 1 then
            elif new_direction == 4:
                # Move DOWN
                self.direction = 'DOWN'
                
                # Verifying the condition input
                print("Moving down")

            # If the value is something else
            else:
                # Verifying the condition input
                print("Invalid direction")
        
        # If the new_direction is a different data type somehow
        else:
            print("Invalid direction format. Expected integer.")


    def move(self, pebble):
        """
        Move the snake in the current direction, update its position and body.

        :param pebble: Pebble object representing the food for the snake
        :return: True if the snake ate the pebble and False otherwise
        """
        # Move the snake to the right
        if self.direction == 'RIGHT':
            # Update the x-coordinate of the snake's position
            self.position[0] += GRID_SIZE
        
        # Move the snake to the left
        elif self.direction == 'LEFT':
            # Update the x-coordinate of the snake's position
            self.position[0] -= GRID_SIZE
        
        # Move the snake upwards
        elif self.direction == 'UP':
            # Update the y-coordinate of the snake's position
            self.position[1] -= GRID_SIZE
        
        # Move the snake downwards
        elif self.direction == 'DOWN':
            # Update the y-coordinate of the snake's position
            self.position[1] += GRID_SIZE

        # Update the body positions
        self.body.insert(0, list(self.position))
        
        # Check if the snake's position coincides with the pebble's position
        if self.position == pebble.position:
            # Snake ate the pebble, grow the body
            return True  
        
        # If the snake did not eat the pebble
        else:
            # Remove the last segment
            self.body.pop()

            # Snake did not eat the pebble
            return False  

    def check_collision(self):
        """
        Check if the snake has collided with the border or itself.

        :return: True if collision occurred, False otherwise
        """
        # Check if the snake's x-coordinate is outside the game window
        if (
            self.position[0] >= WIDTH
            or self.position[0] < 0
            or self.position[1] >= HEIGHT
            or self.position[1] < 0
        ):
            # Snake collided with the border
            return True  
        
        # Check if the snake collided with itself
        for segment in self.body[1:]:
            # Compare each body segment with the snake's current position
            if segment == self.position:
                # Snake collided with itself
                return True  
        # If no collision occurred
        return False

    def get_head_position(self):
        """
        Get the position of the snake's head.

        :return: List representing the x and y coordinates of the head position
        """
        # Return the current position of the snake's head
        return self.position

    def get_body_positions(self):
        """
        Get the positions of all segments in the snake's body.

        :return: List of lists representing x and y coordinates of each body segment
        """
        # Return the positions of all segments in the snake's body
        return self.body


## Pebble Class

In [5]:
class Pebble:
    def __init__(self):
        # Initialize the pebble's position randomly within the grid
        self.position = [random.randrange(1, (WIDTH//GRID_SIZE)) * GRID_SIZE,
                         random.randrange(1, (HEIGHT//GRID_SIZE)) * GRID_SIZE]

    def respawn(self):
        """
        Respawn the pebble at a new random position within the grid.
        
        return: None
        """
        # Set the pebble's position to a new random position within the grid
        self.position = [random.randrange(1, (WIDTH//GRID_SIZE)) * GRID_SIZE,
                         random.randrange(1, (HEIGHT//GRID_SIZE)) * GRID_SIZE]

    def get_position(self):
        """
        Get the current position of the pebble.
        
        return: List representing the pebble's position [x, y]
        """
        # Return the current position of the pebble
        return self.position


## DQN Model Class

In [6]:
class MyDQNModel(nn.Module):
    def __init__(self, n_observation, n_actions, n_input_channels, input_image_height, n_output_probs,
                 conv_layer_sizes=(32, 64, ), conv_kernel_sizes=(3, 3, ), act_func_maxpool=F.relu,
                 dense_layer_sizes=(100, 100, ), act_func_dense=F.relu, dropout=0.5):
        super(MyDQNModel, self).__init__()

        # Calculate the output size after the convolutional layers
        self.conv_out_size = self._calculate_conv_out_size(n_observation)

        # Initialize the convolutional network
        self.conv_network = nn.ModuleList()

        self.conv_network.append(
            nn.Conv2d(in_channels=1, out_channels=conv_layer_sizes[0], kernel_size=conv_kernel_sizes[0])
        )

        self.conv_network.append(nn.MaxPool2d(2))

        for i in range(len(conv_layer_sizes) - 1):
            self.conv_network.append(
                nn.Conv2d(in_channels=conv_layer_sizes[i], out_channels=conv_layer_sizes[i + 1],
                          kernel_size=conv_kernel_sizes[i + 1])
            )
            self.conv_network.append(nn.Dropout(p=dropout))
            self.conv_network.append(nn.MaxPool2d(2))

        # Calculate the number of inputs into the dense network
        self.dense_n_inputs = self._calc_dense_n_inputs(input_image_height, conv_kernel_sizes, conv_layer_sizes)

        # Initialize the dense network
        self.dense_network = nn.ModuleList()

        self.dense_network.append(nn.Linear(self.dense_n_inputs, dense_layer_sizes[0]))

        for i in range(len(dense_layer_sizes) - 1):
            self.dense_network.append(nn.Linear(dense_layer_sizes[i], dense_layer_sizes[i + 1]))

        # Output layer
        self.output = nn.Linear(dense_layer_sizes[-1], n_output_probs)

    def forward(self, x):
        # Assuming x is a 2D tensor with shape [batch_size, input_size]
        x = x.view(-1, 1, 1, x.size(1))

        for layer in self.conv_network:
            if isinstance(layer, nn.MaxPool2d):
                x = F.relu(layer(x))
            else:
                x = layer(x)

        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))

        for layer in self.dense_network:
            x = F.relu(layer(x))

        return F.softmax(self.output(x), dim=-1)

    def _calculate_conv_out_size(self, input_size):
        # dummy_input = torch.zeros((1, 1, *input_size), dtype=torch.float32)
        dummy_input = torch.zeros((1, 1, input_size, 1), dtype=torch.float32)
        conv_out = self._convolutional_layers(dummy_input)
        return conv_out.view(conv_out.size(0), -1).size(1)

    def _convolutional_layers(self, x):
        for layer in self.conv_network:
            if isinstance(layer, nn.MaxPool2d):
                x = F.relu(layer(x))
            else:
                x = layer(x)
        return x

    def _calc_dense_n_inputs(self, input_image_height, conv_kernel_sizes, conv_layer_sizes):
        final_size = input_image_height

        for conv_kernel_size in conv_kernel_sizes:
            final_size = np.floor(final_size - (conv_kernel_size - 1))
            final_size = np.floor((final_size - (2 - 1) - 1) / 2 + 1)

        return int(conv_layer_sizes[-1] * np.square(final_size))


## Replay Memory Class

In [7]:
class ReplayMemory:
    def __init__(self):
        # Initialize an empty list to store replay memory transitions
        self.memory = []

        # Initialize the position in the memory buffer
        self.position = 0

    def push(self, *args):
        """
        Add a transition to the replay memory.

        :param *args: A tuple representing a transition (state, action, next_state, reward).
        return: None
        """
        # If the memory size is less than the maximum allowed size, append None to the memory list
        if len(self.memory) < REPLAY_MEMORY_SIZE:
            self.memory.append(None)

        # Store the transition at the current position in the memory buffer
        self.memory[self.position] = Transition(*args)

        # Update the position in the memory buffer using modulo to create a circular buffer
        self.position = (self.position + 1) % REPLAY_MEMORY_SIZE

    def sample(self, batch_size):
        """
        Randomly sample a batch of transitions from the replay memory.

        :param batch_size: The number of transitions to sample.
        return: A list of sampled transitions.
        """
        # Use random.sample to randomly select a batch of transitions from the memory
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """
        Return the current size of the replay memory.

        return: The number of stored transitions in the replay memory.
        """
        # Return the length of the memory list
        return len(self.memory)


## Snake Population Agent

In [8]:
class SnakePopulationAgent:
    def __init__(self):
        pass

## Game Environment Class

In [9]:
class GameEnvironment:
    def __init__(self):
        # Initialize the snake object
        self.snake = Snake()

        # Initialize the pebble object
        self.pebble = Pebble()

    def reset(self):
        """
        Reset the game environment by creating a new snake and respawning the pebble.

        return: The initial state of the game.
        """
        # Create a new snake
        self.snake = Snake()

        # Respawn the pebble
        self.pebble.respawn()

        # Return the initial state of the game
        return self.get_state()

    def get_state(self):
        """
        Get the current state of the game.

        return: A tensor representing the current state of the game.
        """
        # Check if the snake has collided, return a consistent representation for the terminal state
        if self.snake.check_collision():
            return torch.zeros((1, 12), dtype=torch.float32)
        
        else:
            # Extract relevant information about the state
            state = [
                self.snake.position[0] / WIDTH,
                self.snake.position[1] / HEIGHT,
                self.pebble.position[0] / WIDTH,
                self.pebble.position[1] / HEIGHT,
            ]

            # Include body segments in the state representation
            for segment in self.snake.body:
                state.extend([segment[0] / WIDTH, segment[1] / HEIGHT])

            # Pad the state with zeros if needed to match the expected input size
            while len(state) < 12:
                state.append(0.0)

            # Return the state as a tensor
            return torch.tensor(state, dtype=torch.float32).view(1, -1)

    def step(self, action):
        """
        Take a step in the environment based on the given action.

        :param action: The action to be taken by the snake.
        return: The next state and the reward obtained from the step.
        """
        # Change the snake's direction based on the action
        self.snake.change_direction(action)

        # Move the snake and check if it ate the pebble
        pebble_eaten = self.snake.move(self.pebble)

        # Handle rewards based on the game state
        if pebble_eaten:
            reward = 1.0
            self.pebble.respawn()

        # When the snake dies
        elif self.snake.check_collision():
            reward = -1.0
            self.reset()
        
        # Nothing happens (still vibing)
        else:
            reward = 0.0

        # Get the next state after the step
        next_state = self.get_state()

        # Return the next state and the obtained reward
        return next_state, reward


## Environment GUI Version Class

In [10]:
class EnvironmentGUIVersion:
    def __init__(self, model=None, n_observations=12):
        # Initialize Pygame
        pygame.init()
        
        # Pygame clock for controlling frame rate
        self.clock = pygame.time.Clock()
        
        # Initialize the game environment
        self.game_env = GameEnvironment()
        
        # Set up the Pygame screen
        self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
        
        # Set up the font for displaying the score
        self.font = pygame.font.SysFont(None, 25)
        
        # Number of possible actions (UP, DOWN, LEFT, RIGHT)
        self.n_actions = 4
        
        # Use the provided model or create a new one
        # self.model = model if isinstance(model, MyDQNModel) else MyDQNModel(n_observations, self.n_actions)
        self.model = model if isinstance(model, MyDQNModel) else MyDQNModel(
                                        n_observations, self.n_actions, n_input_channels=1, input_image_height=1, n_output_probs=self.n_actions)

        
        # Exploration-exploitation trade-off parameter
        self.epsilon = EPSILON_START
        
        # Initialize optimizer if a model is provided
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001) if self.model is not None else None
        
        # Loss functions for training the model
        self.mse_loss = nn.MSELoss() if self.model is not None else None
        self.mae_loss = nn.L1Loss() if self.model is not None else None
        self.smooth_L1_loss = nn.SmoothL1Loss() if self.model is not None else None
        self.huber_loss = nn.HuberLoss(reduction='mean', delta=1.0) if self.model is not None else None
        
        # Replay memory for experience replay
        self.memory = ReplayMemory()

    def handle_events(self):
        """
        Handle Pygame events, e.g., window close.
        """
        # Iterate through all Pygame events
        for event in pygame.event.get():
            # Check if the user closed the window
            if event.type == pygame.QUIT:
                # Quit Pygame and exit the program
                pygame.quit()
                quit()

        # Use the lambda function to get the state for the model
        state = (lambda env: env.get_state())(self.game_env)

        # Get the action predicted by the model
        with torch.no_grad():
            action = self.model(state).argmax().item()

        # Print the action predicted by the model
        print(f"Model's Action: {action}")
        
        # Change the snake's direction based on the predicted action
        self.game_env.snake.change_direction(action)
    

    def draw_snake(self):
        """
        Draw the snake on the Pygame screen.
        """
        # Iterate through all snake body segments
        for segment in self.game_env.snake.get_body_positions():
            # Draw a green rectangle for each snake segment
            pygame.draw.rect(self.screen, (0, 255, 0), pygame.Rect(segment[0], segment[1], GRID_SIZE, GRID_SIZE))

    def draw_pebble(self):
        """
        Draw the pebble on the Pygame screen.
        """
        # Get the pebble's position
        position = self.game_env.pebble.get_position()
        
        # Draw a red rectangle for the pebble
        pygame.draw.rect(self.screen, (255, 0, 0), pygame.Rect(position[0], position[1], GRID_SIZE, GRID_SIZE))

    def draw_score(self, score):
        """
        Draw the current score on the Pygame screen.
        
        :param score: The current score to display.
        """
        # Render the score text
        score_text = self.font.render(f'Score: {score}', True, (255, 255, 255))
        
        # Blit the score text onto the screen
        self.screen.blit(score_text, (10, 10))

    def draw_grid(self):
        """
        Draw the grid lines on the Pygame screen.
        """
        # Draw vertical grid lines
        for x in range(0, WIDTH, GRID_SIZE):
            pygame.draw.line(self.screen, (50, 50, 50), (x, 0), (x, HEIGHT))
        
        # Draw horizontal grid lines
        for y in range(0, HEIGHT, GRID_SIZE):
            pygame.draw.line(self.screen, (50, 50, 50), (0, y), (WIDTH, y))

    def run(self):
        """
        Main loop for running the environment.
        """
        # Flag indicating whether the game is running
        running = True

        # Main loop iterating over episodes
        for episode in range(1, MAX_EPISODE_VALUE):
            # Reset the game environment for a new episode
            state = self.game_env.reset()
            
            # Total reward accumulated during the episode
            total_reward = 0

            # Episode loop
            while True:
                # Handle events outside the main loop
                self.handle_events()  

                # Exploration-exploitation strategy
                if random.random() < self.epsilon:
                    # Explore: Choose a random action
                    action = random.randint(0, 3)
                
                else:
                    # Exploit: Choose the action with the highest Q-value
                    with torch.no_grad():
                        action = self.model(state).argmax().item()

                # Take a step in the environment
                next_state, reward = self.game_env.step(action)
                
                # Update total reward
                total_reward += reward

                # Store the experience in the replay memory
                self.memory.push(state, action, next_state, reward)
                
                # Update the current state
                state = next_state

                # Train the model if enough experiences are stored in the memory
                if len(self.memory) > BATCH_SIZE:
                    # Sample a batch of experiences from the replay memory
                    batch = Transition(*zip(*self.memory.sample(BATCH_SIZE)))
                    
                    # Create a mask indicating non-final next states
                    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                                  dtype=torch.bool)
                    
                    # Determine the maximum length among next states
                    max_len = max(s.shape[1] if s is not None else 0 for s in batch.next_state)

                    # Pad non-final next states to have the same length
                    non_final_next_states = torch.cat([
                        torch.nn.functional.pad(s, (0, max_len - s.shape[1])) if s is not None else torch.zeros(1, max_len)
                        for s in batch.next_state
                    ])

                    # Correct input size
                    input_size = len(self.game_env.get_state()[0])

                    # Update the existing model's conv_out_size attribute
                    self.model.conv_out_size = self.model._calculate_conv_out_size(input_size)
                    
                    # Create a new model with the correct input size
                    # self.model = MyDQNModel(input_size, self.n_actions)
                    self.model = MyDQNModel(self.model.conv_out_size, self.n_actions)
                    
                    # Convert the batch data to tensors
                    state_batch = torch.cat(batch.state)
                    action_batch = torch.tensor(batch.action, dtype=torch.long).view(-1, 1)
                    reward_batch = torch.tensor(batch.reward, dtype=torch.float32).view(-1, 1)

                    # Compute Q-values for the current state and selected actions
                    Q_current = self.model(state_batch).gather(1, action_batch)
                    
                    # Initialize tensor for Q-values of next states
                    Q_next = torch.zeros(BATCH_SIZE)

                    # Update Q-values for non-final next states
                    Q_next[non_final_mask] = self.model(non_final_next_states).max(1)[0].detach()
                    
                    # Compute Q-target values for the Bellman equation
                    Q_target = reward_batch + (GAMMA * Q_next)

                    # Compute the Huber loss between current and target Q-values
                    loss = self.huber_loss(Q_current, Q_target)
                    
                    # Zero gradients, perform a backward pass, and update the weights
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # Decay epsilon for exploration-exploitation
                self.epsilon = max(EPSILON_END, self.epsilon * EPSILON_DECAY)
                print(f"Episode {episode}, Total Reward: {total_reward}, action state: {action}")

                # Update Pygame screen
                self.screen.fill((0, 0, 0))
                self.draw_snake()
                self.draw_pebble()
                self.draw_score(total_reward)
                self.draw_grid()
                pygame.display.flip()

                # Cap the frame rate
                self.clock.tick(FPS)

                # Exit the loop if the environment is no longer running
                if not running:
                    pygame.quit()
                    return


## Creating an instance for the model

In [11]:
# Calculate the input size based on the state representation in the environment
# input_size = len(EnvironmentGUIVersion().game_env.get_state()[0])
input_size = EnvironmentGUIVersion().game_env.get_state().size(0)


# Assuming 4 possible actions (UP, DOWN, LEFT, RIGHT)
n_actions = 4  

# Create an instance of the MyDQNModel class with the calculated input size and number of actions
model = MyDQNModel(
    n_observation=input_size,
    n_actions=n_actions,
    n_input_channels=1,  # Assuming grayscale images, adjust if needed
    input_image_height=1,  # Adjust if needed
    n_output_probs=n_actions  # Number of output probabilities should match the number of actions
)

AttributeError: 'MyDQNModel' object has no attribute 'conv_network'

## Main Cell (Combining everything)

In [None]:
gui = EnvironmentGUIVersion(model)
gui.run()

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 12]

: 