# Use a trained agent to hit a stationary target

## Import libraries

In [None]:
# Import libraries for SpaceShooter game
import pygame
import random
from enum import Enum
import time
import os
from random import randint
import numpy as np

# Import libraries for the Model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Import libraries for plot function 
import matplotlib.pyplot as plt
from IPython import display

# Import libraries for the Agent
from collections import deque

## Initialise settings 

In [None]:
# Initialise the pygame sub-systems used below (fonts for text, mixer for sounds).
pygame.font.init()
pygame.mixer.init()

# Window size in pixels and the display surface everything is drawn onto.
WIDTH, HEIGHT = 900, 500
WIN = pygame.display.set_mode((WIDTH, HEIGHT))
# 10-pixel-wide vertical bar centred horizontally, spanning the full window height.
BORDER = pygame.Rect(WIDTH//2 - 5, 0, 10, HEIGHT)
# Fonts for the score counter and the end-of-game banner respectively.
HEALTH_FONT = pygame.font.SysFont('comicsans', 40)
WINNER_FONT = pygame.font.SysFont('comicsans', 100)
# Frame-rate cap passed to the game clock on every step.
FPS = 60

## Set colours

In [None]:
# Colours as (red, green, blue) tuples with components in the range 0-255.
BLACK = (0, 0, 0)
RED = (255, 0, 0)
YELLOW = (255, 255, 0)
PURPLE = (200, 0, 200)

## Initialise physics


In [None]:
# Vertical distance (pixels) the yellow spaceship moves per UP/DOWN step.
VEL = 20
# Horizontal distance (pixels) a bullet travels per step.
BULLET_VEL = 50
# Size used for the spaceship collision rectangles and for scaling the sprites.
SPACESHIP_WIDTH, SPACESHIP_HEIGHT = 55, 40

## Initialise images and sounds

In [None]:
# Custom pygame event ids.
# NOTE(review): neither id is posted or handled in the code shown in this notebook.
YELLOW_HIT = pygame.USEREVENT + 1
RED_HIT = pygame.USEREVENT + 2

# Heart image (scaled to 40x40, drawn next to the score).
# convert_alpha() needs the display created above, so this must run after set_mode().
HEART_IMAGE = pygame.image.load(os.path.join('Assets', 'heart.png')).convert_alpha()
HEART = pygame.transform.scale(HEART_IMAGE, (40, 40))

# Yellow spaceship image: scaled to the spaceship size, then rotated by 90 degrees.
YELLOW_SPACESHIP_IMAGE = pygame.image.load(os.path.join('Assets', 'spaceship_yellow.png'))
YELLOW_SPACESHIP = pygame.transform.rotate(pygame.transform.scale(YELLOW_SPACESHIP_IMAGE, (SPACESHIP_WIDTH, SPACESHIP_HEIGHT)), 90)

# Red spaceship image: scaled to the spaceship size, then rotated by 270 degrees.
RED_SPACESHIP_IMAGE = pygame.image.load(os.path.join('Assets', 'spaceship_red.png'))
RED_SPACESHIP = pygame.transform.rotate(pygame.transform.scale(RED_SPACESHIP_IMAGE, (SPACESHIP_WIDTH, SPACESHIP_HEIGHT)), 270)

# Background image, stretched to fill the whole window.
SPACE = pygame.transform.scale(pygame.image.load(os.path.join('Assets', 'space_background_2.jpg')), (WIDTH, HEIGHT))

# Game sounds (paths are relative to the current working directory).
# NOTE(review): only referenced from a commented-out line in the code shown here.
BULLET_HIT_SOUND = pygame.mixer.Sound('Assets/Grenade+1.mp3')
BULLET_FIRE_SOUND = pygame.mixer.Sound('Assets/Gun+Silencer.mp3')
EZPZ = pygame.mixer.Sound('Assets/ezpz.mp3')

In [None]:
class Action(Enum):
    """Last move the game recorded for the yellow spaceship.

    The game stores one of these in ``SpaceShooter.action`` and the agent
    reads it back when it builds its state vector.
    """

    UP = 1        # ship moved up by one step
    DOWN = 2      # ship moved down by one step
    FIRE = 3      # ship fired a bullet
    NOTHING = 4   # initial value after a reset

## Build the environment

In [None]:
class SpaceShooter:
    """Frame-by-frame target-practice game driven through :meth:`step`.

    The yellow spaceship is moved by the caller's actions. The red spaceship
    never moves on its own; it is re-positioned vertically every time a
    yellow bullet hits it. ``red_health`` counts those hits and is the score
    returned by :meth:`step`.
    """

    def __init__(self):
        """Set the window caption, create the clock and start a fresh game."""
        pygame.display.set_caption("SpaceShooter")
        self.BORDER = pygame.Rect(WIDTH//2 - 5, 0, 10, HEIGHT)
        self.clock = pygame.time.Clock()
        self.reset()

    def reset(self):
        """Put both spaceships back at their start positions and clear the score."""
        # Change 2nd argument to 300 when using models from the models folder
        self.red = pygame.Rect(700, 230, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)
        # Change 2nd argument to 300 when using models from the models folder
        self.yellow = pygame.Rect(100, 230, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)
        self.yellow_bullet = None       # at most one yellow bullet exists at a time
        self.red_health = 0             # number of hits on red, i.e. the score
        self.yellow_health = 1
        self.frame_iteration = 0        # steps taken since the last reset
        self.YELLOW_BULLET_FLAG = 0     # 1 while a yellow bullet is in flight
        self.RED_BULLET_FLAG = 0
        self.action = Action.NOTHING

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: One-hot sequence ``[up, down, fire]``. Anything that is
                not exactly one of the three one-hot patterns, or a move
                that is currently not allowed, leaves the ship where it is.

        Returns:
            tuple: ``(game_over, score)`` where ``score`` is the current
            ``red_health``.

        Raises:
            SystemExit: If the window received a ``pygame.QUIT`` event.
        """
        self.frame_iteration += 1

        # 1. Handle window events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                # The `quit()` builtin is injected by the `site` module for
                # interactive sessions and is not guaranteed to exist;
                # raising SystemExit is the equivalent that always works.
                raise SystemExit

        # 2. Apply the requested action
        self._apply_action(action)

        # 3. Move the bullet and resolve hit / miss
        if self.yellow_bullet is not None:
            self.yellow_bullet.x += BULLET_VEL

            if self.red.colliderect(self.yellow_bullet):
                self._respawn_red()
                self.red_health += 1

                # The bullet is consumed, so the ship may fire again.
                self.yellow_bullet = None
                self.YELLOW_BULLET_FLAG = 0
                # A hit frame returns before the redraw and clock tick below.
                return False, self.red_health

            if self.yellow_bullet.x > WIDTH:
                # The bullet left the window without hitting: a miss ends the game.
                return True, self.red_health

        # 4. Check the remaining end-of-game conditions
        if self.red_health >= 1000:
            self.winner_text = "Yellow Wins!"
            self.winner_colour = YELLOW
            self.draw_winner()
            return True, self.red_health

        # Time limit: the frame budget grows by 50 for every hit scored.
        if self.frame_iteration > 50*(self.red_health+1):
            return True, self.red_health

        # 5. Update UI and clock
        self.draw_window()
        self.clock.tick(FPS)
        return False, self.red_health

    def _apply_action(self, action):
        """Move the yellow ship or fire, if the one-hot ``action`` is allowed."""
        # NOTE(review): when no branch fires (move blocked at the window edge,
        # or FIRE while a bullet is in flight) ``self.action`` keeps its
        # previous value. Left unchanged on purpose; confirm whether the
        # trained model relies on this before altering it.
        if np.array_equal(action, [1, 0, 0]) and self.yellow.y - VEL > 0:  # UP
            self.action = Action.UP
            self.yellow.y -= VEL
        elif np.array_equal(action, [0, 1, 0]) and self.yellow.y + VEL + self.yellow.height < HEIGHT - 15:  # DOWN
            self.action = Action.DOWN
            self.yellow.y += VEL
        elif np.array_equal(action, [0, 0, 1]) and self.YELLOW_BULLET_FLAG == 0:  # FIRE
            self.action = Action.FIRE
            # The bullet starts at the ship's right edge, at mid-height.
            self.yellow_bullet = pygame.Rect(self.yellow.x + self.yellow.width, self.yellow.y + self.yellow.height//2, 10, 5)
            self.YELLOW_BULLET_FLAG = 1
            # BULLET_FIRE_SOUND.play() uncomment if you want sound :)

    def _respawn_red(self):
        """Move the red ship to a random height snapped to a multiple of 20."""
        self.red_pseudo_y = randint(40, 400)
        self.red_new_y = 20 * round(self.red_pseudo_y/20)
        self.red = pygame.Rect(self.red.x, self.red_new_y, SPACESHIP_WIDTH, SPACESHIP_HEIGHT)

    # Render
    def draw_window(self):
        """Redraw background, border, score, both ships and the bullet."""
        WIN.blit(SPACE, (0, 0))
        pygame.draw.rect(WIN, PURPLE, self.BORDER)

        self.red_health_text = HEALTH_FONT.render(str(self.red_health), 1, RED)
        # NOTE(review): rendered but never blitted in the code shown here.
        self.yellow_health_text = HEALTH_FONT.render(str(self.yellow_health), 1, RED)

        # Draw score
        WIN.blit(HEART, (WIDTH - 200, 17))
        WIN.blit(self.red_health_text, (WIDTH - 150, 10))

        # Draw spaceships
        WIN.blit(YELLOW_SPACESHIP, (self.yellow.x, self.yellow.y))
        WIN.blit(RED_SPACESHIP, (self.red.x, self.red.y))

        # Draw yellow bullet
        if self.yellow_bullet is not None:
            pygame.draw.rect(WIN, YELLOW, self.yellow_bullet)

        pygame.display.update()

    def draw_winner(self):
        """Show ``self.winner_text`` centred in the window for half a second."""
        draw_text = WINNER_FONT.render(self.winner_text, 1, self.winner_colour)
        WIN.blit(draw_text, (WIDTH/2 - draw_text.get_width() /2, HEIGHT/2 - draw_text.get_height()/2))
        pygame.display.update()
        pygame.time.delay(500)

## Create the model used to train the Agent

In [None]:
class Linear_QNet(nn.Module):
    """Two-layer fully connected network: linear -> ReLU -> linear.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the hidden layer.
        output_size: Number of values produced per input vector.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names double as state_dict keys, so they must stay as they are.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the output-layer values for ``x``."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

## Create Plot function to display training data

In [None]:
# Interactive mode, so drawing a figure does not block the game loop.
plt.ion()


def plot(scores, mean_scores):
    """Redraw the score chart in the notebook output.

    Args:
        scores: Score of every finished game, oldest first. Must be non-empty.
        mean_scores: Running mean after each game. Must be non-empty.
    """
    # Replace the previous figure in the cell output instead of stacking figures.
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()

    plt.title('Plot of scores achieved')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')

    curves = (scores, mean_scores)
    for curve in curves:
        plt.plot(curve)
    plt.ylim(ymin=0)

    # Label the most recent point of each curve with its value.
    for curve in curves:
        latest = curve[-1]
        plt.text(len(curve)-1, latest, str(latest))

    plt.show(block=False)
    plt.pause(.1)

## Use a trained Agent to play the SpaceShooter game

In [None]:
class Agent:
    """Greedy player backed by a pre-trained ``Linear_QNet``.

    Args:
        model_path: Path of the saved ``state_dict`` to load. Defaults to
            ``model/model_250.pth`` relative to the working directory.
    """

    def __init__(self, model_path=None):
        self.n_games = 0
        self.model = Linear_QNet(8, 256, 3)
        # Change path to directory of the model you want to use
        if model_path is None:
            model_path = os.path.join('model', 'model_250.pth')
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        # The model is only used for prediction here, never trained.
        self.model.eval()

    def get_state(self, game):
        """Build the 8-element 0/1 state vector for the current game frame.

        Args:
            game: Object exposing ``yellow`` and ``red`` rectangles (with
                ``y`` and ``height``) and the last recorded ``action``.

        Returns:
            numpy.ndarray: Integer array of length 8 (see inline comments).
        """
        # Vertical centres of both spaceships
        yellow_center = game.yellow.y + game.yellow.height//2
        red_center = game.red.y + game.red.height//2

        # Centre of the yellow spaceship after one step up / down
        position_up = yellow_center - VEL
        position_down = yellow_center + VEL

        # Last action recorded by the game
        action_up = game.action == Action.UP
        action_down = game.action == Action.DOWN
        action_fire = game.action == Action.FIRE

        # NOTE(review): the "closer" features compare a centre-based future
        # position against the top-edge-based current one (game.yellow.y,
        # not yellow_center). Left unchanged because the saved weights were
        # presumably trained on exactly these features; confirm before fixing.
        state = [
            # Last action was FIRE and the two centres are level
            (action_fire and yellow_center == red_center),

            # Will yellow spaceship be closer to red spaceship after taking an action
            (action_up and abs(position_up - red_center) < abs(game.yellow.y - red_center)),  # Action up
            (action_down and abs(position_down - red_center) < abs(game.yellow.y - red_center)),  # Action down

            # Current action
            action_up,
            action_down,
            action_fire,

            # Direction of the red spaceship from yellow spaceship
            yellow_center < red_center,  # red spaceship is below yellow spaceship
            yellow_center > red_center,  # red spaceship is above yellow spaceship
        ]

        return np.array(state, dtype=int)

    def get_action(self, state):
        """Return the one-hot ``[up, down, fire]`` move the model rates highest.

        Args:
            state: State vector as produced by :meth:`get_state`.

        Returns:
            list[int]: Three-element list with a single ``1``.
        """
        state0 = torch.tensor(state, dtype=torch.float)
        # No gradients are needed for prediction; skip building the autograd graph.
        with torch.no_grad():
            prediction = self.model(state0)
        move = torch.argmax(prediction).item()

        final_move = [0, 0, 0]
        final_move[move] = 1
        return final_move

# Use the trained model to predict each move
def train():
    """Let the pre-trained agent play game after game, plotting every result.

    Despite its name, this function does no learning: each frame it asks the
    agent for a move and applies it. When a game ends, the score and the
    running mean are plotted and printed and a new game begins. The loop
    never returns.
    """
    score_history = []
    mean_history = []
    score_sum = 0
    best_score = 0
    agent = Agent()
    game = SpaceShooter()

    while True:
        # Observe, choose a move, apply it
        observation = agent.get_state(game)
        chosen_move = agent.get_action(observation)
        done, score = game.step(chosen_move)

        if not done:
            continue

        # Game finished: start a new one and report the result
        game.reset()
        agent.n_games += 1
        best_score = max(best_score, score)

        score_history.append(score)
        score_sum += score
        mean_history.append(score_sum / agent.n_games)
        plot(score_history, mean_history)

        print('Game', agent.n_games, 'Score', score, 'Record:', best_score)

# Start the game loop only when this code runs as the main program,
# not when it is imported as a module.
if __name__ == '__main__':
    train()