# Play the SpaceShooter game against a trained Agent limited to firing one bullet

## Import libraries

In [None]:
# Import libraries for SpaceShooter game
import pygame
import random
from enum import Enum
import time
import os
from random import randint
import numpy as np

# Import libraries for the Model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Import libraries for plot function 
import matplotlib.pyplot as plt
from IPython import display

# Import libraries for the Agent
from collections import deque

## Initialise settings 

In [None]:
# Initialise the pygame font and mixer modules before any font or sound is created
pygame.font.init()
pygame.mixer.init()

# Window size in pixels and the display surface everything is drawn onto
WIDTH, HEIGHT = 900, 500
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
# 10-pixel-wide vertical divider centred horizontally, spanning the full window height
BORDER = pygame.Rect(WIDTH//2 - 5, 0, 10, HEIGHT)
# Fonts for the health counters and for the end-of-game winner banner
HEALTH_FONT = pygame.font.SysFont('comicsans', 40)
WINNER_FONT = pygame.font.SysFont('comicsans', 100)
# Frame-rate cap passed to the game clock's tick() each step
FPS = 60

## Set colours

In [None]:
# RGB colour tuples used when drawing text, bullets and the border
BLACK, RED = (0, 0, 0), (255, 0, 0)
YELLOW, PURPLE = (255, 255, 0), (200, 0, 200)

## Initialise physics


In [None]:
# Pixels a spaceship moves per step
VEL = 25
# Pixels a bullet travels per step
BULLET_VEL = 10
# Not referenced by the code in this file; each ship has at most one bullet in flight
MAX_BULLETS = 4
# Size of a spaceship's rectangle in pixels
SPACESHIP_WIDTH = 55
SPACESHIP_HEIGHT = 40

## Initialise images and sounds

In [None]:
# Custom event ids for a hit on either ship (never posted by the code visible in this file)
YELLOW_HIT = pygame.USEREVENT + 1
RED_HIT = pygame.USEREVENT + 2

# Heart image, scaled to 40x40 and drawn next to each health counter
HEART_IMAGE = pygame.image.load(os.path.join('Assets', 'heart.png')).convert_alpha()
HEART = pygame.transform.scale(HEART_IMAGE, (40, 40))

# Yellow spaceship image, scaled to the ship size and then rotated by 90 degrees
YELLOW_SPACESHIP_IMAGE = pygame.image.load(os.path.join('Assets', 'spaceship_yellow.png'))
YELLOW_SPACESHIP = pygame.transform.rotate(pygame.transform.scale(YELLOW_SPACESHIP_IMAGE, (SPACESHIP_WIDTH, SPACESHIP_HEIGHT)), 90)

# Red spaceship image, scaled to the ship size and then rotated by 270 degrees
RED_SPACESHIP_IMAGE = pygame.image.load(os.path.join('Assets', 'spaceship_red.png'))
RED_SPACESHIP = pygame.transform.rotate(pygame.transform.scale(RED_SPACESHIP_IMAGE, (SPACESHIP_WIDTH, SPACESHIP_HEIGHT)), 270)

# Background image, scaled to fill the whole window
SPACE = pygame.transform.scale(pygame.image.load(os.path.join('Assets', 'space_background_2.jpg')), (WIDTH, HEIGHT))

# Game sounds (the play() calls in the game code are currently commented out)
BULLET_HIT_SOUND = pygame.mixer.Sound('Assets/Grenade+1.mp3')
BULLET_FIRE_SOUND = pygame.mixer.Sound('Assets/Gun+Silencer.mp3')

In [None]:
class Action(Enum):
    """Moves of the yellow spaceship that the game records as its last action."""

    # Members are numbered 1..4 in declaration order
    UP, DOWN, FIRE, NOTHING = 1, 2, 3, 4

## Build environment

In [None]:
class SpaceShooter:
    """SpaceShooter environment: the yellow ship is driven by ``step``'s action,
    the red ship by the keyboard.

    Red controls: I / K move up / down, J / L move left / right, SPACE or
    right CTRL fires. Each ship can have at most one bullet in flight.
    """

    def __init__(self):
        """Set the window caption, create the frame clock and reset the game state."""
        pygame.display.set_caption("SpaceShooterAI")
        self.BORDER = pygame.Rect(WIDTH//2 - 5, 0, 10, HEIGHT)
        self.clock = pygame.time.Clock()
        self.reset()

    def reset(self):
        """Put both ships back at their start positions with full health and no bullets."""
        self.frame_iteration = 0
        self.action = Action.NOTHING
        # Spaceships
        self.red = pygame.Rect(700, 230, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)
        self.yellow = pygame.Rect(100, 230, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)
        # Bullets: the rect is None and the flag is 0 while no bullet is in flight
        self.yellow_bullet = None
        self.red_bullet = None
        self.last_shot = 0
        self.YELLOW_BULLET_FLAG = 0
        self.RED_BULLET_FLAG = 0
        self.hits = 0
        # Health
        self.red_health = 10
        self.yellow_health = 10

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: One-hot move for the yellow ship, compared against
                ``[1, 0, 0]`` (up), ``[0, 1, 0]`` (down) and ``[0, 0, 1]`` (fire).

        Returns:
            Tuple ``(game_over, yellow_health, red_health)``.

        Raises:
            SystemExit: when the window is closed.
        """
        self.frame_iteration += 1
        self._handle_events()
        self._apply_yellow_action(action)
        self._advance_bullets()

        winner = self._find_winner()
        if winner is not None:
            self.winner_text, self.winner_colour = winner
            self.draw_winner()
            return True, self.yellow_health, self.red_health

        self._move_red()

        # Update UI and clock
        self.draw_window()
        self.clock.tick(FPS)
        return False, self.yellow_health, self.red_health

    def _handle_events(self):
        """Drain the pygame event queue once: window close and red's fire keys."""
        # pygame.event.get() empties the queue, so every event type has to be
        # handled in this single pass; a second get() later in the same frame
        # would almost never see the key press.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit
            if (event.type == pygame.KEYDOWN
                    and event.key in (pygame.K_SPACE, pygame.K_RCTRL)
                    and self.RED_BULLET_FLAG == 0):
                self.red_bullet = pygame.Rect(self.red.x, self.red.y + self.red.height//2, 10, 5)
                self.RED_BULLET_FLAG = 1
                # BULLET_FIRE_SOUND.play()  # uncomment if you want sound :)

    def _apply_yellow_action(self, action):
        """Perform the yellow ship's requested move if it is currently allowed."""
        # self.action is only updated when a move is actually performed; if the
        # move is blocked or the vector matches nothing, it keeps its old value.
        if np.array_equal(action, [1, 0, 0]) and self.yellow.y - VEL > 0:  # UP
            self.action = Action.UP
            self.yellow.y -= VEL
        elif np.array_equal(action, [0, 1, 0]) and self.yellow.y + VEL + self.yellow.height < HEIGHT - 15:  # DOWN
            self.action = Action.DOWN
            self.yellow.y += VEL
        elif (np.array_equal(action, [0, 0, 1]) and self.YELLOW_BULLET_FLAG == 0
                and time.time() - self.last_shot >= 0.1):  # FIRE
            self.action = Action.FIRE
            self.yellow_bullet = pygame.Rect(self.yellow.x, self.yellow.y + self.yellow.height//2, 10, 5)
            self.YELLOW_BULLET_FLAG = 1
            self.last_shot = time.time()
            # BULLET_FIRE_SOUND.play()

    def _advance_bullets(self):
        """Move both bullets one step and resolve hits and off-screen exits."""
        # Yellow bullet travels to the right
        if self.yellow_bullet is not None:
            self.yellow_bullet.x += BULLET_VEL
            if self.red.colliderect(self.yellow_bullet):
                self.red_health -= 1
                self.yellow_bullet = None
                self.YELLOW_BULLET_FLAG = 0
            elif self.yellow_bullet.x > WIDTH:
                self.yellow_bullet = None
                self.YELLOW_BULLET_FLAG = 0

        # Red bullet travels to the left
        if self.red_bullet is not None:
            self.red_bullet.x -= BULLET_VEL
            if self.yellow.colliderect(self.red_bullet):
                # A hit costs yellow one health point; without this the
                # "red wins" outcome could never be reached.
                self.yellow_health -= 1
                self.red_bullet = None
                self.RED_BULLET_FLAG = 0
            elif self.red_bullet.x <= 0:
                self.red_bullet = None
                self.RED_BULLET_FLAG = 0

    def _find_winner(self):
        """Return ``(banner_text, colour)`` for the winner, or None while the game is on."""
        if self.red_health <= 0:  # Yellow wins
            return "TriBlaster beast!", YELLOW
        if self.yellow_health <= 0:  # Red wins
            return "You Win!! :)", RED
        return None

    def _move_red(self):
        """Move the human-controlled red ship according to the keys currently held."""
        keys_pressed = pygame.key.get_pressed()
        if keys_pressed[pygame.K_j] and self.red.x - VEL > self.BORDER.x + self.BORDER.width:  # LEFT
            self.red.x -= VEL
        if keys_pressed[pygame.K_l] and self.red.x + VEL + self.red.width < WIDTH:  # RIGHT
            self.red.x += VEL
        if keys_pressed[pygame.K_i] and self.red.y - VEL > 0:  # UP
            self.red.y -= VEL
        if keys_pressed[pygame.K_k] and self.red.y + VEL + self.red.height < HEIGHT - 15:  # DOWN
            self.red.y += VEL

    # Render
    def draw_window(self):
        """Redraw background, border, health counters, ships and bullets, then update the display."""
        WIN.blit(SPACE, (0, 0))
        pygame.draw.rect(WIN, PURPLE, self.BORDER)

        # Render the health counters
        self.red_health_text = HEALTH_FONT.render(str(self.red_health), 1, RED)
        self.yellow_health_text = HEALTH_FONT.render(str(self.yellow_health), 1, RED)

        # Yellow's health on the left, red's on the right, each next to a heart
        WIN.blit(HEART, (10, 17))
        WIN.blit(self.yellow_health_text, (60, 10))
        WIN.blit(HEART, (WIDTH - 150, 17))
        WIN.blit(self.red_health_text, (WIDTH - 100, 10))

        # Draw spaceships
        WIN.blit(YELLOW_SPACESHIP, (self.yellow.x, self.yellow.y))
        WIN.blit(RED_SPACESHIP, (self.red.x, self.red.y))

        # Draw bullets
        if self.yellow_bullet is not None:
            pygame.draw.rect(WIN, YELLOW, self.yellow_bullet)
        if self.red_bullet is not None:
            pygame.draw.rect(WIN, RED, self.red_bullet)

        # Update display
        pygame.display.update()

    def draw_winner(self):
        """Draw the winner banner centred in the window and pause for 500 ms."""
        draw_text = WINNER_FONT.render(self.winner_text, 1, self.winner_colour)
        WIN.blit(draw_text, (WIDTH/2 - draw_text.get_width() /2, HEIGHT/2 - draw_text.get_height()/2))
        pygame.display.update()
        pygame.time.delay(500)

## Create the model used to train the Agent

In [None]:
class Linear_QNet(nn.Module):
    """Two-layer feed-forward network: linear -> ReLU -> linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers with the given layer sizes."""
        super().__init__()
        # The attribute names become the state_dict keys, so they must not change
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the output-layer values for input ``x``."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

## Load a trained Agent and let it play the SpaceShooter game

In [None]:
# Hyperparameters; none of them is referenced by the code visible in this file
MAX_MEMORY = 100_000
BATCH_SIZE = 1_000
LR = 0.001  # learning rate

class Agent:
    """Controls the yellow spaceship by always taking the highest-scoring move
    of a ``Linear_QNet`` loaded from a saved checkpoint."""

    def __init__(self):
        """Build the 11-256-3 network and load its weights from disk."""
        self.n_games = 0
        self.model = Linear_QNet(11, 256, 3)

        # Change this path to the checkpoint of the model you want to use
        path = os.path.join('SpaceShooter_model', 'spaceshooter_model_500.pth')
        # map_location lets a checkpoint saved from a GPU load on a CPU-only machine
        self.model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
        # The model is only used for inference here
        self.model.eval()

    def get_state(self, game):
        """Encode the current game as a vector of 11 zero/one features.

        Args:
            game: The ``SpaceShooter`` instance to read positions, bullets and
                the last yellow action from.

        Returns:
            ``numpy`` int array, in order: fire-would-hit, red below, red above,
            closer after up, closer after down, last action up / down / fire,
            hit in current position, hit after moving up, hit after moving down.
        """
        yellow_center = game.yellow.y + game.yellow.height//2
        red_center = game.red.y + game.red.height//2

        # Same boundary checks the game applies before moving the yellow ship
        can_move_up = game.yellow.y - VEL > 0
        can_move_down = game.yellow.y + VEL + game.yellow.height < HEIGHT - 15

        # Centre and rectangle of the yellow ship after each move
        # (unchanged when the move would be blocked)
        shift_up = -VEL if can_move_up else 0
        shift_down = VEL if can_move_down else 0
        yellow_center_up = yellow_center + shift_up
        yellow_center_down = yellow_center + shift_down
        yellow_up = pygame.Rect(game.yellow.x, game.yellow.y + shift_up, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)
        yellow_down = pygame.Rect(game.yellow.x, game.yellow.y + shift_down, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)

        # Last action recorded by the game
        action_up = game.action == Action.UP
        action_down = game.action == Action.DOWN
        action_fire = game.action == Action.FIRE

        # Red bullet now and one step ahead; with no bullet in flight a
        # placeholder at (700, 0) is used for both
        if game.red_bullet is not None:
            test_red_bullet_current = pygame.Rect(game.red_bullet.x, game.red_bullet.y, 10, 5)
            test_red_bullet = pygame.Rect(game.red_bullet.x-BULLET_VEL, game.red_bullet.y, 10, 5)
        else:
            test_red_bullet_current = pygame.Rect(700, 0, 10, 5)
            test_red_bullet = pygame.Rect(700, 0, 10, 5)

        # Hypothetical yellow bullet placed at red's x and at the height yellow
        # would fire from, to test vertical alignment with the red ship
        test_yellow_bullet = pygame.Rect(game.red.x, game.yellow.y + game.yellow.height//2, 10, 5)

        state = [
            # If yellow fires from its current position, will it hit red
            game.red.colliderect(test_yellow_bullet),

            # Vertical direction of the red spaceship from the yellow spaceship
            yellow_center < red_center,  # red spaceship is below yellow spaceship
            yellow_center > red_center,  # red spaceship is above yellow spaceship

            # Will yellow be closer to red after moving.
            # NOTE(review): the current distance uses game.yellow.y (top edge)
            # while the moved distance uses the ship centre. Kept unchanged on
            # purpose: the loaded checkpoint was presumably trained on exactly
            # this encoding -- confirm before "fixing".
            abs(yellow_center_up - red_center) < abs(game.yellow.y - red_center),  # After moving up
            abs(yellow_center_down - red_center) < abs(game.yellow.y - red_center),  # After moving down

            # Last action
            action_up,
            action_down,
            action_fire,

            # Will yellow be hit by the red bullet in its current position
            game.yellow.colliderect(test_red_bullet),

            # Will yellow be hit by the red bullet after moving
            (yellow_up.colliderect(test_red_bullet) or yellow_up.colliderect(test_red_bullet_current)),  # After moving up
            (yellow_down.colliderect(test_red_bullet) or yellow_down.colliderect(test_red_bullet_current)),  # After moving down
        ]

        return np.array(state, dtype=int)

    def get_action(self, state):
        """Return the one-hot move with the highest model output for ``state``.

        Args:
            state: Feature vector as produced by ``get_state``.

        Returns:
            List of three ints with a single 1; the index order matches the
            vectors ``SpaceShooter.step`` compares against (up, down, fire).
        """
        final_move = [0, 0, 0]
        state0 = torch.tensor(state, dtype=torch.float)
        # Inference only: no autograd graph is needed for picking a move
        with torch.no_grad():
            prediction = self.model(state0)
        move = torch.argmax(prediction).item()
        final_move[move] = 1
        return final_move

def train():
    """Run the play loop: the loaded Agent drives yellow, game after game.

    Despite the name, no training happens here: the loop only queries the
    agent for a move, steps the game, and prints the final health values
    whenever a game ends. The loop itself never returns.
    """
    player = Agent()
    env = SpaceShooter()

    while True:
        # Encode the current frame, pick a move and apply it
        observation = player.get_state(env)
        chosen_move = player.get_action(observation)
        finished, yellow_health, red_health = env.step(chosen_move)

        if not finished:
            continue

        # Game finished: count it, report the result and start a new one
        player.n_games += 1

        print("--------- Game number", player.n_games, "---------")
        print("Yellow health:", yellow_health)
        print("Red health:", red_health)

        env.reset()


if __name__ == '__main__':
    train()