In [1]:
%pip install pygame
%pip install numpy
%pip install gym



Note: you may need to restart the kernel to use updated packages.


In [2]:
import random
import pygame
from pygame.locals import *
import numpy as np
import numpy.linalg as LA
import math
from collections import defaultdict
import gym
from gym import spaces
import os
import sys
os.environ['KMP_DUPLICATE_LIB_OK']='True'

pygame 2.1.2 (SDL 2.0.18, Python 3.9.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
WHITE = (255, 255, 255)
RED = (255, 0, 0)
BLUE = (0, 0, 255)
DISP_WIDTH = 1000
DISP_HEIGHT = 1000
FPS = pygame.time.Clock()

---------- Default Config ----------

In [4]:
DEFAULT_CONFIG_DICT = {
    'time_step': 0.1, # hours per time step
    'plane_speed': 500, # mph
    'bullet_speed': 700, # mph
    'max_time': 10, # hours the epoch can last
    'show_viz': False, # show the pygame animation
    'step_turn': 30, # degrees to turn per step
    'hit_base_reward': 1000, # reward for shooting enemy base
    'hit_plane_reward': 1000, # reward for shooting enemy plane
    'miss_punishment': -5, # punishment for missing a shot
    'too_long_punishment': -1, # punishment for taking too long to end the game
    'closer_to_base_reward': 0, # reward for getting closer to enemy base
    'closer_to_plane_reward': 0, # reward for getting closer to enemy plane
    'turn_to_base_reward': 0, # reward for turning towards the enemy base
    'turn_to_plane_reward': 0 # reward for turning towards the enemy plane
}

---------- Helper Functions ----------

In [5]:
def get_angle(p1, p0):
    return math.degrees(math.atan2(p1[1]-p0[1],p1[0]-p0[0]))

def dist(p1,p0):
    return math.sqrt((p1[0]-p0[0])**2 + (p1[1]-p0[1])**2)

def blitRotate(image, pos, originPos, angle):

    # offset from pivot to center
    image_rect = image.get_rect(topleft = (pos[0] - originPos[0], pos[1]-originPos[1]))
    offset_center_to_pivot = pygame.math.Vector2(pos) - image_rect.center
    
    # roatated offset from pivot to center
    rotated_offset = offset_center_to_pivot.rotate(-angle)

    # rotated image center
    rotated_image_center = (pos[0] - rotated_offset.x, pos[1] - rotated_offset.y)

    # get a rotated image
    rotated_image = pygame.transform.rotate(image, angle)
    rotated_image_rect = rotated_image.get_rect(center = rotated_image_center)

    return rotated_image, rotated_image_rect

def calc_new_xy(old_xy, speed, time, angle):
    new_x = old_xy[0] + (speed*time*math.cos(-math.radians(angle)))
    new_y = old_xy[1] + (speed*time*math.sin(-math.radians(angle)))
    return (new_x, new_y)

---------- Plane Class ----------

In [6]:
class Plane:
    def __init__(self, team): 
        self.team = team
        self.image = pygame.image.load(f"Images/{team}_plane.png")
        self.w, self.h = self.image.get_size()
        self.heading = 0
        self.rect = self.image.get_rect()
        self.reset()

    def reset(self):
        if self.team == 'red':
            x = (DISP_WIDTH - self.w/2)/2 * random.random()
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)
            self.heading = 90 * random.random() if random.random() < .5 else 90 * random.random() + 270
            
        else:
            x = (DISP_WIDTH - self.w/2)/2 * random.random() + (DISP_WIDTH - self.w/2)/2
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)
            self.heading = 180 * random.random() + 90
        
    def rotate(self, angle):
        self.heading += angle

    def set_heading(self, heading):
        self.heading = heading

    def forward(self, speed, time):
        oldpos = self.rect.center
        self.rect.center = calc_new_xy(oldpos, speed, time, self.heading)

    def draw(self, surface):
        image, rect = blitRotate(self.image, self.rect.center, (self.w/2, self.h/2), self.heading)
        surface.blit(image, rect)

    def update(self):
        # Keep player on the screen
        if self.rect.left < 0:
            self.rect.left = 0
        if self.rect.right > DISP_WIDTH:
            self.rect.right = DISP_WIDTH
        if self.rect.top <= 0:
            self.rect.top = 0
        if self.rect.bottom >= DISP_HEIGHT:
            self.rect.bottom = DISP_HEIGHT

    def get_pos(self):
        return self.rect.center

----------- Base Class ----------

In [7]:
class Base:
    
    def __init__(self, team):
        self.team = team
        self.image = pygame.image.load(f"Images/{team}_base.png")
        self.w, self.h = self.image.get_size()
        self.rect = self.image.get_rect()
        self.reset()
        
    def reset(self):
        if self.team == 'red':
            x = (DISP_WIDTH - self.w/2)/2 * random.random()
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)
        else:
            x = (DISP_WIDTH - self.w/2)/2 * random.random() + (DISP_WIDTH - self.w/2)/2
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)

    def draw(self, surface):
        surface.blit(self.image, self.rect)
            
    def get_pos(self):
        return self.rect.center

---------- Bullet Class ----------

In [8]:
class Bullet(pygame.sprite.Sprite):
    def __init__(self, x, y, angle, speed, fteam, oteam):
        pygame.sprite.Sprite.__init__(self)
        self.off_screen = False
        self.image = pygame.Surface((8, 4), pygame.SRCALPHA)
        self.fteam = fteam
        self.color = RED if self.fteam == 'red' else BLUE
        self.oteam = oteam
        self.image.fill(self.color)
        self.rect = self.image.get_rect(center=(x, y))
        self.w, self.h = self.image.get_size()
        self.heading = angle + (random.random() * 10 - 5)
        self.pos = (x, y)
        self.speed = speed

    def update(self, screen_width, screen_height, time):
        oldpos = self.rect.center
        self.rect.center = calc_new_xy(oldpos, self.speed, time, self.heading)
        if self.rect.centerx > screen_width or self.rect.centerx < 0 or self.rect.centery > screen_height or self.rect.centery < 0:
            return 'miss'
        for plane in self.oteam['planes']:
            if self.rect.colliderect(plane.rect):
                return 'plane'
        if self.rect.colliderect(self.oteam['base'].rect):
            return 'base'
        return 'none'

    def draw(self, surface):
        image, rect = blitRotate(self.image, self.rect.center, (self.w/2, self.h/2), self.heading)
        surface.blit(image, rect)

----------- Battle Environment ----------

In [9]:
class BattleEnvironment(gym.Env):
    def __init__(self, config: dict=DEFAULT_CONFIG_DICT):
        super(BattleEnvironment, self).__init__()
        self.width = DISP_WIDTH
        self.height = DISP_HEIGHT
        self.max_time = config['max_time']
        high = np.array( # Observation space: fplane_pos_x, fplane_pos_y, fplane_angle, dist_obase, dist_oplane, rel_angle_obase, rel_angle_oplane
            [
                self.width,
                self.height,
                720,
                math.sqrt(math.pow(self.width, 2) + math.pow(self.height, 2)),
                math.sqrt(math.pow(self.width, 2) + math.pow(self.height, 2)),
                720,
                720,
                self.max_time,
            ],
            dtype=np.float32,
        )
        # For the Agent, actions are turn left, turn right, turn to enemy, turn to target, go forward, or shoot
        # For the Random choice Agent, the actions are to enemy, to target, shoot enemy, or shoot target
        self.action_space = spaces.Discrete(6)
        self.random_action_space = [0, 1, 4, 5]
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
        self.team = {}
        self.team['red'] = {}
        self.team['blue'] = {}
        self.team['red']['base'] = Base('red')
        self.team['blue']['base'] = Base('blue')
        self.team['red']['planes'] = []
        self.team['red']['planes'].append(Plane('red'))
        self.team['blue']['planes'] = []
        self.team['blue']['planes'].append(Plane('blue'))
        self.team['red']['wins'] = 0
        self.team['blue']['wins'] = 0
        self.ties = 0
        self.bullets = []
        self.speed = config['plane_speed']
        self.bullet_speed = config['bullet_speed']
        self.total_time = 0 # in hours
        self.time_step = config['time_step']
        self.show = config['show_viz']
        self.step_turn = config['step_turn']
        self.hit_base_reward = config['hit_base_reward']
        self.hit_plane_reward = config['hit_plane_reward']
        self.miss_punishment = config['miss_punishment']
        self.too_long_punishment = config['too_long_punishment']
        self.closer_to_base_reward = config['closer_to_base_reward']
        self.closer_to_plane_reward = config['closer_to_plane_reward']
        self.turn_to_base_reward = config['turn_to_base_reward']
        self.turn_to_plane_reward = config['turn_to_plane_reward']

    def _get_observation(self):
        # This code will all have to be changed when adding multiple planes
        fplane = self.team['red']['planes'][0]
        oplane = self.team['blue']['planes'][0]
        obase = self.team['red']['base']

        fplane_pos = fplane.get_pos()
        fplane_angle = fplane.heading
        oplane_pos = oplane.get_pos()
        obase_pos = obase.get_pos()

        dist_oplane = dist(oplane_pos, fplane_pos)
        dist_obase = dist(obase_pos, fplane_pos)

        angle_to_oplane = get_angle(oplane_pos, fplane_pos)
        angle_to_obase = get_angle(obase_pos, fplane_pos)
        rel_angle_oplane = (angle_to_oplane - fplane_angle) % 360
        rel_angle_obase = (angle_to_obase - fplane_angle) % 360
        
        self.observation = (fplane_pos[0], fplane_pos[1], fplane_angle, dist_obase, dist_oplane, rel_angle_obase, rel_angle_oplane, self.total_time)
        return np.array(self.observation, dtype=np.float32)

        
    def reset(self): # return observation
        self.done = False
        self.winner = 'none'

        self.team['red']['base'].reset()
        self.team['blue']['base'].reset()

        for plane in self.team['red']['planes']:
            plane.reset()
        for plane in self.team['blue']['planes']:
            plane.reset()

        self.total_time = 0
        self.bullets = []

        if self.show:
            pygame.init()
            self.display = pygame.display.set_mode((DISP_WIDTH, DISP_HEIGHT))
            pygame.display.set_caption("Battlespace Simulator")

        return self._get_observation()

    def step(self, action): # return observation, reward, done, info

        reward = 0

        # Check if over time, if so, end game in tie
        self.total_time += self.time_step
        if self.total_time >= self.max_time:
            self.done = True
            self.ties += 1
            if self.show:
                self.render()
                print("Draw")
            return self._get_observation(), 0, self.done, {}

        # Red turn
        self.friendly = 'red'
        self.opponent = 'blue'
        reward += self._process_action(action, self.team[self.friendly], self.team[self.opponent])
        
        # Blue turn
        self.friendly = 'blue'
        self.opponent = 'red'
        self._process_action(self.random_action_space[random.randint(0, 3)], self.team[self.friendly], self.team[self.opponent])        

        # Check if bullets hit and move them
        for bullet in self.bullets:
            outcome = bullet.update(self.width, self.height, self.time_step)
            if outcome == 'miss':
                reward += self.miss_punishment
                self.bullets.pop(self.bullets.index(bullet))
            elif outcome == 'plane' or outcome == 'base': # If a bullet hit
                self.winner = bullet.fteam
                self.team[self.winner]['wins'] += 1
                self.done = True
                reward = reward + self.hit_base_reward if outcome == 'base' else reward + self.hit_plane_reward
                if self.show:
                    self.render()
                    print(f"{self.winner} wins")
                return self._get_observation(), reward, self.done, {}
            
        # Check if past half of max time and give punishment
        if (self.total_time > self.max_time//2):
            reward += self.too_long_punishment

        # Continue game
        if self.show:
            self.render()
        return self._get_observation(), reward, self.done, {}
    
    def _process_action(self, action, fteam, oteam): # friendly and opponent
        reward = 0

        fplane = fteam['planes'][0]
        oplane = oteam['planes'][0]
        obase = oteam['base']

        fplane_pos = fplane.get_pos()
        fplane_angle = fplane.heading
        oplane_pos = oplane.get_pos()
        obase_pos = obase.get_pos()

        dist_oplane = dist(oplane_pos, fplane_pos)
        dist_obase = dist(obase_pos, fplane_pos)

        angle_to_oplane = get_angle(oplane_pos, fplane_pos)
        angle_to_obase = get_angle(obase_pos, fplane_pos)
        rel_angle_oplane = (angle_to_oplane - fplane_angle) % 360
        rel_angle_obase = (angle_to_obase - fplane_angle) % 360

        # --------------- FORWARDS ---------------
        if action == 0: 
            fplane.forward(self.speed, self.time_step)

         # --------------- SHOOT ---------------
        elif action == 1:
            self.bullets.append(Bullet(fplane_pos[0], fplane_pos[1], fplane_angle, self.bullet_speed, self.friendly, oteam))
            fplane.forward(self.speed, self.time_step)
        
        # --------------- TURN RIGHT ---------------
        elif action == 2:
            fplane.rotate(-self.step_turn)
            fplane.forward(self.speed, self.time_step)

        # ---------------- TURN LEFT ----------------
        elif action == 3:
            fplane.rotate(self.step_turn)
            fplane.forward(self.speed, self.time_step)

        # ---------------- TURN TO OPLANE ----------------
        elif action == 4:
            if math.fabs(rel_angle_oplane) < self.step_turn: # within step_turn of base
                fplane.set_heading(angle_to_oplane)

            elif math.fabs(rel_angle_oplane) > 360 - self.step_turn: # within step_turn of base
                fplane.set_heading(angle_to_oplane)

            elif math.fabs(rel_angle_oplane) < 180: # turn right
                fplane.rotate(-self.step_turn)

            else: # turn left
                fplane.rotate(self.step_turn)
                
            fplane.forward(self.speed, self.time_step)

        # ---------------- TURN TO OBASE ----------------
        elif action == 5:
            if math.fabs(rel_angle_obase) < self.step_turn: # within step_turn of base
                fplane.set_heading(angle_to_obase)

            elif math.fabs(rel_angle_obase) > 360 - self.step_turn: # within step_turn of base
                fplane.set_heading(angle_to_obase)

            elif math.fabs(rel_angle_obase) < 180: # turn right
                fplane.rotate(-self.step_turn)

            else: # turn left
                fplane.rotate(self.step_turn)

            fplane.forward(self.speed, self.time_step)
        
        # ---------------- GIVE REWARDS IF CLOSER (DIST OR ANGLE) ----------------
        new_fplane_pos = fplane.get_pos()
        new_fplane_angle = fplane.heading
        new_oplane_pos = oplane.get_pos()
        new_obase_pos = obase.get_pos()

        new_dist_oplane = dist(new_oplane_pos, new_fplane_pos)
        new_dist_obase = dist(new_obase_pos, new_fplane_pos)

        new_angle_to_oplane = get_angle(new_oplane_pos, new_fplane_pos)
        new_angle_to_obase = get_angle(new_obase_pos, new_fplane_pos)
        new_rel_angle_oplane = (new_angle_to_oplane - new_fplane_angle) % 360
        new_rel_angle_obase = (new_angle_to_obase - new_fplane_angle) % 360

        if new_dist_oplane < dist_oplane: # If got closer to enemy plane
            reward += self.closer_to_plane_reward

        if new_dist_obase < dist_obase: # If got closer to enemy base
            reward += self.closer_to_base_reward

        if math.fabs(new_rel_angle_oplane) < math.fabs(rel_angle_oplane): # If aiming closer to enemy plane
            reward += self.turn_to_plane_reward

        if math.fabs(new_rel_angle_obase) < math.fabs(rel_angle_obase): # If aiming closer to enemy base
            reward += self.turn_to_base_reward

        return reward

    def draw_shot(self, hit, friendly_pos, target_pos, team):
        color = (0, 0, 0)
        color = (255, 0, 0) if team == 'red' else (0, 0, 255)
        if not hit: target_pos = (target_pos[0] + (random.random() * 2 - 1) * 100, target_pos[1] + (random.random() * 2 - 1) * 100)
        self.shot_history.append((hit, friendly_pos, target_pos, color))
    
    def winner_screen(self):
        if self.show:
            font = pygame.font.Font('freesansbold.ttf', 32)
            if self.winner != 'none':
                text = font.render(f"THE WINNER IS {self.winner.upper()}", True, (0, 0, 0))
                textRect = text.get_rect()
                textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
            else:
                text = font.render(f"THE GAME IS A TIE", True, (0, 0, 0))
                textRect = text.get_rect()
                textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
            self.display.blit(text, textRect)

    def show_wins(self):
        print("Wins by red:", self.team['red']['wins'])
        print("Wins by blue:", self.team['blue']['wins'])
        print("Tied games:", self.ties)

    def render(self):
        if self.show: # Just to ensure it won't render if self.show == False
            for event in pygame.event.get():
                # Check for KEYDOWN event
                if event.type == KEYDOWN:
                    # If the Esc key is pressed, then exit the main loop
                    if event.key == K_ESCAPE:
                        pygame.quit()
                        sys.exit()
                        return
                # Check for QUIT event. If QUIT, then set running to false.
                elif event.type == QUIT:
                    pygame.quit()
                    sys.exit()
                    return
                    
            # Fill background
            self.display.fill(WHITE)

            # Draw bullets
            for bullet in self.bullets:
                bullet.draw(self.display)
                    
            # Draw bases
            self.team['red']['base'].draw(self.display)
            self.team['blue']['base'].draw(self.display)

            # Draw planes
            for plane in self.team['red']['planes']:
                plane.update()
                plane.draw(self.display)
            for plane in self.team['blue']['planes']:
                plane.update()
                plane.draw(self.display)

            # Winner Screen
            if self.done:
                font = pygame.font.Font('freesansbold.ttf', 32)
                if self.winner != 'none':
                    text = font.render(f"THE WINNER IS {self.winner.upper()}", True, (0, 0, 0))
                    textRect = text.get_rect()
                    textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
                else:
                    text = font.render(f"THE GAME IS A TIE", True, (0, 0, 0))
                    textRect = text.get_rect()
                    textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
                self.display.blit(text, textRect)
                pygame.display.update()
                pygame.time.wait(3000)
                pygame.quit()
                return
            
            pygame.display.update()
            FPS.tick(7)

---------- Hyperparameter Tuning ----------

In [10]:
%pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio===0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
%pip install stable-baselines3[extra] optuna

^C
Note: you may need to restart the kernel to use updated packages.


In [None]:
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

In [None]:
LOG_DIR = './logs/'
OPT_DIR = './opt/'
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)
if not os.path.exists(OPT_DIR):
    os.makedirs(OPT_DIR)

In [None]:
def optimize_ppo(trial): 
    return {
        'n_steps':trial.suggest_int('n_steps', 2048, 8192),
        'gamma':trial.suggest_loguniform('gamma', 0.8, 0.9999),
        'learning_rate':trial.suggest_loguniform('learning_rate', 1e-5, 1e-4),
        'clip_range':trial.suggest_uniform('clip_range', 0.1, 0.4),
        'gae_lambda':trial.suggest_uniform('gae_lambda', 0.8, 0.99)
    }

In [None]:
def optimize_agent(trial):
    try:
        model_params = optimize_ppo(trial) 

        # Create environment 
        env = BattleEnvironment()
        env = Monitor(env, LOG_DIR)
        env = DummyVecEnv([lambda: env])
        env = VecFrameStack(env, 4, channels_order='last')

        model = PPO('MlpPolicy', env, tensorboard_log=LOG_DIR, verbose=0, **model_params)
        model.learn(total_timesteps=400000)

        # Evaluate model 
        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)

        SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number))
        model.save(SAVE_PATH)

        return mean_reward
    except Exception as e:
        return -1000

In [30]:
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=15)

[32m[I 2022-06-26 16:19:21,952][0m A new study created in memory with name: no-name-f83b9314-4689-48e0-8fa3-274453f9450d[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=4786 and n_envs=1)
[32m[I 2022-06-26 16:29:44,321][0m Trial 0 finished with value: 102.8 and parameters: {'n_steps': 4786, 'gamma': 0.971550686333748, 'learning_rate': 1.5097327356026387e-05, 'clip_range': 0.34148511908874346, 'gae_lambda': 0.842651713694691}. Best is trial 0 with value: 102.8.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=2778 and n_envs=1)
[32m[I 2022-06-26 16:40:05,190][0m Trial 1 finished with value: 287.6 and parameters: {'n_steps': 2778, 'gamma': 0.8183254905816083, 'learning_rate': 4.953550427923822e-05, 'clip_range': 0.3742727667812864, 'gae_lambda': 0.8694423217102631}. Best is trial 1 with value: 287.6.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=27

We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=6354 and n_envs=1)
[32m[I 2022-06-26 18:30:47,112][0m Trial 12 finished with value: 307.0 and parameters: {'n_steps': 6354, 'gamma': 0.8241762204693597, 'learning_rate': 3.3158699298144134e-05, 'clip_range': 0.3974894031940379, 'gae_lambda': 0.9119210197111699}. Best is trial 12 with value: 307.0.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=6983 and n_envs=1)
[32m[I 2022-06-26 18:40:32,991][0m Trial 13 finished with value: 111.6 and parameters: {'n_steps': 6983, 'gamma': 0.8301071033743885, 'learning_rate': 4.302573460543227e-05, 'clip_range': 0.3214519635552193, 'gae_lambda': 0.916256539152556}. Best is trial 12 with value: 307.0.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=6269 and n_envs=1)
[32m[I 2022-06-26 18:50:00,962][0m Trial 14 finished with value: 101.4 and parameters: {'n_steps': 6269,

[32m[I 2022-06-26 20:47:59,558][0m Trial 24 finished with value: 93.0 and parameters: {'n_steps': 2141, 'gamma': 0.8194509027145852, 'learning_rate': 6.4124667986393e-05, 'clip_range': 0.39876008874893154, 'gae_lambda': 0.8634409960276019}. Best is trial 12 with value: 307.0.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=3469 and n_envs=1)
[32m[I 2022-06-26 21:01:49,518][0m Trial 25 finished with value: 298.8 and parameters: {'n_steps': 3469, 'gamma': 0.8867738799751667, 'learning_rate': 5.158510478106449e-05, 'clip_range': 0.358259884757668, 'gae_lambda': 0.8964363894791995}. Best is trial 12 with value: 307.0.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=3329 and n_envs=1)
[32m[I 2022-06-26 21:01:59,110][0m Trial 26 finished with value: -1000.0 and parameters: {'n_steps': 3329, 'gamma': 0.9159567912682718, 'learning_rate': 6.884460972555864e-05, 'clip_range': 0.26453495097319635, 'gae

-------- Load the Best Model ----------

In [31]:
study.best_params
study.best_trial
best_model = 'trial_12_best_model.zip'

In [32]:
model = PPO.load(os.path.join(OPT_DIR, best_model))

---------- Callback ----------

In [33]:
from stable_baselines3.common.callbacks import BaseCallback

In [34]:
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)
        return True

In [35]:
CHECKPOINT_DIR = './train/'
callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)

---------- Train the Model ----------

In [36]:
env = BattleEnvironment()
env = Monitor(env, LOG_DIR)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')

In [37]:
model_params = study.best_params
model_params['n_steps'] = (model_params['n_steps'] // 64) * 64 # set to a factor of 64
model = PPO('MlpPolicy', env, tensorboard_log=LOG_DIR, verbose=1, **model_params)

Using cuda device


In [None]:
model.load(os.path.join(OPT_DIR, best_model))
model.learn(total_timesteps=4000000, callback=callback)

Logging to ./logs/PPO_34
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 89.1     |
|    ep_rew_mean     | 58.3     |
| time/              |          |
|    fps             | 541      |
|    iterations      | 1        |
|    time_elapsed    | 11       |
|    total_timesteps | 6336     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 89.9        |
|    ep_rew_mean          | 56.4        |
| time/                   |             |
|    fps                  | 400         |
|    iterations           | 2           |
|    time_elapsed         | 31          |
|    total_timesteps      | 12672       |
| train/                  |             |
|    approx_kl            | 0.019567937 |
|    clip_fraction        | 0.00775     |
|    clip_range           | 0.397       |
|    entropy_loss         | -1.78       |
|    explained_variance   | 0.00188     |
|    

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 95.5         |
|    ep_rew_mean          | 158          |
| time/                   |              |
|    fps                  | 344          |
|    iterations           | 11           |
|    time_elapsed         | 202          |
|    total_timesteps      | 69696        |
| train/                  |              |
|    approx_kl            | 0.0051663397 |
|    clip_fraction        | 0.00229      |
|    clip_range           | 0.397        |
|    entropy_loss         | -0.217       |
|    explained_variance   | 0.0123       |
|    learning_rate        | 3.32e-05     |
|    loss                 | 12.2         |
|    n_updates            | 100          |
|    policy_gradient_loss | -0.00337     |
|    value_loss           | 1.75e+03     |
------------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 97.4         |
|    ep_rew_mean          | 143          |
| time/                   |              |
|    fps                  | 338          |
|    iterations           | 21           |
|    time_elapsed         | 393          |
|    total_timesteps      | 133056       |
| train/                  |              |
|    approx_kl            | 0.0002479873 |
|    clip_fraction        | 0.000316     |
|    clip_range           | 0.397        |
|    entropy_loss         | -0.0388      |
|    explained_variance   | 0.0329       |
|    learning_rate        | 3.32e-05     |
|    loss                 | 6.83         |
|    n_updates            | 200          |
|    policy_gradient_loss | -0.000857    |
|    value_loss           | 375          |
------------------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_l

---------- Evaluation ---------

In [None]:
env = BattleEnvironment()
env = Monitor(train_env, LOG_DIR)
env = DummyVecEnv([lambda: env])
env = VecFrameStack(train_env, 4, channels_order='last')
model = PPO.load(f"{OPT_DIR}{best_model}")

In [None]:
episodes = 100
for episode in range(episodes): # Evaluates the model n times
    state = env.reset()
    score = 0
    while not env.done:
        action, _states = model.predict(state)
        n_state, red_reward, done, info = env.step(action)
        score += red_reward
    print('Episode:{} Score:{}'.format(episode+1, score))
env.show_wins()

--------- Render One Episode ----------

In [None]:
env.show = True
state = env.reset()
score = 0
while not train_env.done:
    env.render()
    action, _states = model.predict(state)
    n_state, red_reward, done, info = env.step(action)
    score += red_reward
print('Score:{}'.format(score))