In [1]:
# pip install pygame # 2.1.2
# pip install numpy # 1.21.5
# pip install gym # 2.0.0
# pip install torch==1.10.2 # 1.10.2
# pip install stable-baselines3[extra] optuna # 1.5.0 for stable-baselines3, 2.10.1 for optuna

In [2]:
import random
import pygame
from pygame.locals import *
import numpy as np
import numpy.linalg as LA
import math
from collections import defaultdict
import gym
from gym import spaces
import os
import sys
os.environ['KMP_DUPLICATE_LIB_OK']='True'

pygame 2.1.2 (SDL 2.0.18, Python 3.9.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
WHITE = (255, 255, 255)
RED = (138, 24, 26)
BLUE = (0, 93, 135)
BLACK = (0, 0, 0)
DISP_WIDTH = 1000
DISP_HEIGHT = 1000

---------- Helper Functions ----------

In [4]:
def get_angle(p0, p1):
    dx = p1[0] - p0[0]
    dy = p1[1] - p0[1]
    rads = math.atan2(-dy,dx)
    rads %= 2*math.pi
    degs = math.degrees(rads)
    return degs

def dist(p1,p0):
    return math.sqrt((p1[0]-p0[0])**2 + (p1[1]-p0[1])**2)

def blitRotate(image, pos, originPos, angle):

    # offset from pivot to center
    image_rect = image.get_rect(topleft = (pos[0] - originPos[0], pos[1]-originPos[1]))
    offset_center_to_pivot = pygame.math.Vector2(pos) - image_rect.center
    
    # roatated offset from pivot to center
    rotated_offset = offset_center_to_pivot.rotate(-angle)

    # rotated image center
    rotated_image_center = (pos[0] - rotated_offset.x, pos[1] - rotated_offset.y)

    # get a rotated image
    rotated_image = pygame.transform.rotate(image, angle)
    rotated_image_rect = rotated_image.get_rect(center = rotated_image_center)

    return rotated_image, rotated_image_rect

def calc_new_xy(old_xy, speed, time, angle):
    new_x = old_xy[0] + (speed*time*math.cos(-math.radians(angle)))
    new_y = old_xy[1] + (speed*time*math.sin(-math.radians(angle)))
    return (new_x, new_y)

---------- Plane Class ----------

In [5]:
class Plane:
    def __init__(self, team): 
        self.team = team
        self.image = pygame.image.load(f"Images/{team}_plane.png")
        self.w, self.h = self.image.get_size()
        self.direction = 0
        self.rect = self.image.get_rect()
        self.reset()

    def reset(self):
        if self.team == 'red':
            x = (DISP_WIDTH - self.w/2)/3 * random.random()
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)
            self.direction = 90 * random.random() if random.random() < .5 else 90 * random.random() + 270
            
        else:
            x = (DISP_WIDTH - self.w/2)/3 * random.random() + (DISP_WIDTH - self.w/2)/3*2
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)
            self.direction = 180 * random.random() + 90
        
    def rotate(self, angle):
        self.direction += angle
        while self.direction > 360:
            self.direction -= 360
        while self.direction < 0:
            self.direction += 360
        # Keep player on the screen
        if self.rect.left < 0:
            self.rect.left = 0
        if self.rect.right > DISP_WIDTH:
            self.rect.right = DISP_WIDTH
        if self.rect.top <= 0:
            self.rect.top = 0
        if self.rect.bottom >= DISP_HEIGHT:
            self.rect.bottom = DISP_HEIGHT

    def set_direction(self, direction):
        self.direction = direction

    def forward(self, speed, time):
        oldpos = self.rect.center
        self.rect.center = calc_new_xy(oldpos, speed, time, self.direction)
        # Keep player on the screen
        if self.rect.left < 0:
            self.rect.left = 0
        if self.rect.right > DISP_WIDTH:
            self.rect.right = DISP_WIDTH
        if self.rect.top <= 0:
            self.rect.top = 0
        if self.rect.bottom >= DISP_HEIGHT:
            self.rect.bottom = DISP_HEIGHT

    def draw(self, surface):
        image, rect = blitRotate(self.image, self.rect.center, (self.w/2, self.h/2), self.direction)
        surface.blit(image, rect)

    def update(self):
        # Keep player on the screen
        if self.rect.left < 0:
            self.rect.left = 0
        if self.rect.right > DISP_WIDTH:
            self.rect.right = DISP_WIDTH
        if self.rect.top <= 0:
            self.rect.top = 0
        if self.rect.bottom >= DISP_HEIGHT:
            self.rect.bottom = DISP_HEIGHT

    def get_pos(self):
        image, rect = blitRotate(self.image, self.rect.center, (self.w/2, self.h/2), self.direction)
        return (rect.centerx, rect.centery)
    
    def get_direction(self):
        return self.direction

----------- Base Class ----------

In [6]:
class Base:
    
    def __init__(self, team):
        self.team = team
        self.image = pygame.image.load(f"Images/{team}_base.png")
        self.w, self.h = self.image.get_size()
        self.rect = self.image.get_rect()
        self.reset()
        
    def reset(self):
        if self.team == 'red':
            x = (DISP_WIDTH - self.w/2)/3 * random.random()
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)
        else:
            x = (DISP_WIDTH - self.w/2)/3 * random.random() + (DISP_WIDTH - self.w/2)/3*2
            y = (DISP_HEIGHT - self.h/2) * random.random()
            self.rect.center = (x, y)

    def draw(self, surface):
        surface.blit(self.image, self.rect)
            
    def get_pos(self):
        return self.rect.center

---------- Bullet Class ----------

In [7]:
class Bullet(pygame.sprite.Sprite):
    def __init__(self, x, y, angle, speed, fteam, oteam):
        pygame.sprite.Sprite.__init__(self)
        self.off_screen = False
        self.image = pygame.Surface((6, 3), pygame.SRCALPHA)
        self.fteam = fteam
        self.color = RED if self.fteam == 'red' else BLUE
        self.oteam = oteam
        self.image.fill(self.color)
        self.rect = self.image.get_rect(center=(x, y))
        self.w, self.h = self.image.get_size()
        self.direction = angle + (random.random() * 10 - 5)
        self.pos = (x, y)
        self.speed = speed

    def update(self, screen_width, screen_height, time):
        oldpos = self.rect.center
        self.rect.center = calc_new_xy(oldpos, self.speed, time, self.direction)
        if self.rect.centerx > screen_width or self.rect.centerx < 0 or self.rect.centery > screen_height or self.rect.centery < 0:
            return 'miss'
        for plane in self.oteam['planes']:
            if self.rect.colliderect(plane.rect):
                return 'plane'
        if self.rect.colliderect(self.oteam['base'].rect):
            return 'base'
        return 'none'

    def draw(self, surface):
        image, rect = blitRotate(self.image, self.rect.center, (self.w/2, self.h/2), self.direction)
        surface.blit(image, rect)

    def get_pos(self):
        return self.rect.center
    
    def get_direction(self):
        return self.direction

----------- Battle Environment ----------

In [8]:
class BattleEnvironment(gym.Env):
    def __init__(self, show=False, hit_base_reward=100, hit_plane_reward=100, miss_punishment=-5, too_long_punishment=0, closer_to_base_reward=0, 
        closer_to_plane_reward=0, lose_punishment=-50, fps=20):
        super(BattleEnvironment, self).__init__()

        self.width = DISP_WIDTH
        self.height = DISP_HEIGHT
        self.max_time = 10
        self.max_bullets = 10
        self.action_space = spaces.Discrete(4) # Forward, Turn to plane, Turn to base, Shoot

        # ---------- Observation Space ----------
        obs_space = {
            'fplane_x': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'fplane_y': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'fplane_direction': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'rel_angle_oplane': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'rel_angle_obase': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'fbase_x': spaces.Box(-1, 1, shape=(1,)),
            'fbase_y': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'oplane_x': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'oplane_y': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'oplane_direction': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'obase_x': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'obase_y': spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
            'time': spaces.Box(-1, 1, dtype=np.float32, shape=(1,))
        }

        for idx in range(self.max_bullets):
            obs_space[f'bullet_{idx}_shot'] = spaces.Box(-1, 1, dtype=np.int16, shape=(1,))
            obs_space[f'bullet_{idx}_x'] = spaces.Box(-1, 1, dtype=np.float32, shape=(1,))
            obs_space[f'bullet_{idx}_y'] = spaces.Box(-1, 1, dtype=np.float32, shape=(1,))
            obs_space[f'bullet_{idx}_direction'] = spaces.Box(-1, 1, dtype=np.float32, shape=(1,))
            obs_space[f'bullet_{idx}_dist_to_plane'] = spaces.Box(-1, 1, dtype=np.float32, shape=(1,))
            obs_space[f'bullet_{idx}_facing_plane'] = spaces.Box(-1, 1, dtype=np.float32, shape=(1,))

        mins = np.array([x.low[0] for x in obs_space.values()])
        maxs = np.array([x.high[0] for x in obs_space.values()])

        self.observation_space = spaces.Box(mins, maxs, dtype=np.float32)
        
        # ---------- Initialize values ----------
        self.team = {}
        self.team['red'] = {}
        self.team['blue'] = {}
        self.team['red']['base'] = Base('red')
        self.team['blue']['base'] = Base('blue')
        self.team['red']['planes'] = []
        self.team['red']['planes'].append(Plane('red'))
        self.team['blue']['planes'] = []
        self.team['blue']['planes'].append(Plane('blue'))
        self.team['red']['wins'] = 0
        self.team['blue']['wins'] = 0
        self.ties = 0
        self.bullets = {}
        self.speed = 400 # mph
        self.bullet_speed = 1000 # mph
        self.total_time = 0 # in hours
        self.time_step = 0.1 # hours per time step
        self.show = show # show the pygame animation
        self.step_turn = 25 # degrees to turn per step
        self.hit_base_reward = hit_base_reward
        self.hit_plane_reward = hit_plane_reward
        self.miss_punishment = miss_punishment
        self.too_long_punishment = too_long_punishment
        self.closer_to_base_reward = closer_to_base_reward
        self.closer_to_plane_reward = closer_to_plane_reward
        self.lose_punishment = lose_punishment
        self.fps = fps

        self.reset()
    
    def _get_observation(self):
        # Observation: Observation: fplane_pos, fplane_angle, obase_pos, oplane_pos, time
        # This code will all have to be changed when adding multiple planes
        fplane = self.team['red']['planes'][0]
        fbase = self.team['red']['base']
        oplane = self.team['blue']['planes'][0]
        obase = self.team['red']['base']

        fplane_pos = fplane.get_pos()
        fplane_direction = fplane.get_direction()
        fbase_pos = fbase.get_pos()
        oplane_pos = oplane.get_pos()
        oplane_direction = oplane.get_direction()
        obase_pos = obase.get_pos()

        angle_to_oplane = get_angle(fplane_pos, oplane_pos)
        angle_to_obase = get_angle(fplane_pos, obase_pos)
        rel_angle_oplane = (angle_to_oplane - fplane_direction)
        if rel_angle_oplane < -180: rel_angle_oplane += 360
        rel_angle_obase = (angle_to_obase - fplane_direction)
        if rel_angle_obase < -180: rel_angle_obase += 360

        dct = {
            'fplane_x': (fplane_pos[0] / self.width) * 2 - 1,
            'fplane_y': (fplane_pos[1] / self.height) * 2 - 1,
            'fplane_direction': (fplane_direction / 360) * 2 - 1,
            'rel_angle_oplane': rel_angle_oplane / 360,
            'rel_angle_obase': rel_angle_obase,
            'fbase_x': (fbase_pos[0] / self.width) * 2 - 1,
            'fbase_y': (fbase_pos[1] / self.height) * 2 - 1,
            'oplane_x': (oplane_pos[0] / self.width) * 2 - 1,
            'oplane_y': (oplane_pos[1] / self.width) * 2 - 1,
            'oplane_direction': (oplane_direction / 360) * 2 - 1,
            'obase_x': (obase_pos[0] / self.width) * 2 - 1,
            'obase_y': (obase_pos[1] / self.width) * 2 - 1,
            'time': self.total_time / self.max_time
        }

        for idx, bullet in self.bullets.items():
            if bullet is not None:
                bullet_pos = bullet.get_pos()
                bullet_direction = bullet.get_direction()
                dct[f'bullet_{idx}_shot'] = 1
                dct[f'bullet_{idx}_x'] = (bullet_pos[0] / self.width) * 2 - 1
                dct[f'bullet_{idx}_y'] = (bullet_pos[1] / self.height) * 2 - 1
                dct[f'bullet_{idx}_direction'] = (bullet_direction / self.height) * 2 - 1
                dct[f'bullet_{idx}_dist_to_plane'] = (dist(bullet_pos, fplane_pos)) * 2 - 1
                if math.fabs(bullet_direction - get_angle(bullet_pos, fplane_pos)) < self.step_turn:
                    dct[f'bullet_{idx}_facing_plane'] = 1
                else:
                    dct[f'bullet_{idx}_facing_plane'] = -1
            else:
                dct[f'bullet_{idx}_shot'] = -1
                dct[f'bullet_{idx}_x'] = -1
                dct[f'bullet_{idx}_y'] = -1
                dct[f'bullet_{idx}_direction'] = 0
                dct[f'bullet_{idx}_dist_to_plane'] = 0
                dct[f'bullet_{idx}_facing_plane'] = 0
        return np.array([x for x in dct.values()], dtype=np.float32)

    def reset(self): # return observation
        self.done = False
        self.winner = 'none'

        self.team['red']['base'].reset()
        self.team['blue']['base'].reset()

        for plane in self.team['red']['planes']:
            plane.reset()
        for plane in self.team['blue']['planes']:
            plane.reset()

        self.total_time = 0
        self.bullets = {n: None for n in range(self.max_bullets)}

        if self.show:
            pygame.init()
            pygame.font.init()
            self.clock = pygame.time.Clock()
            self.display = pygame.display.set_mode((DISP_WIDTH, DISP_HEIGHT))
            pygame.display.set_caption("Battlespace Simulator")
            pygame.time.wait(1000)

        return self._get_observation()

    def step(self, action): # return observation, reward, done, info
        reward = 0

        # Check if over time, if so, end game in tie
        self.total_time += self.time_step
        if self.total_time >= self.max_time:
            self.done = True
            self.ties += 1
            if self.show:
                self.render()
                print("Draw")
            return self._get_observation(), 0, self.done, {}

        # Red turn
        self.friendly = 'red'
        self.opponent = 'blue'
        reward += self._process_action(action, self.friendly, self.opponent)
        
        # Blue turn
        self.friendly = 'blue'
        self.opponent = 'red'
        self._process_action(random.randint(0, 3), self.friendly, self.opponent)        

        # Check if bullets hit and move them
        for key, bullet in self.bullets.items():
            if bullet is not None:
                outcome = bullet.update(self.width, self.height, self.time_step)
                if outcome == 'miss':
                    reward = reward + self.miss_punishment if bullet.fteam == 'red' else 0
                    self.bullets[key].kill()
                    self.bullets[key] = None
                elif outcome == 'base' or outcome == 'plane': # If a bullet hit
                    self.winner = bullet.fteam
                    self.team[self.winner]['wins'] += 1
                    if bullet.fteam == 'red':
                        reward = reward + self.hit_base_reward if outcome == 'base' else reward + self.hit_plane_reward
                    else:
                        reward += self.lose_punishment
                    self.done = True
                    if self.show:
                        self.render()
                        print(f"{self.winner} wins")
                    return self._get_observation(), reward, self.done, {}
            
        # Check if past half of max time and give punishment
        if (self.total_time > self.max_time//2):
            reward += self.too_long_punishment

        # Continue game
        if self.show:
            self.render()
        return self._get_observation(), reward, self.done, {}
    
    def _process_action(self, action, fteam, oteam): # friendly and opponent teams
        reward = 0

        fplane = self.team[fteam]['planes'][0]
        obase = self.team[oteam]['base']
        oplane = self.team[oteam]['planes'][0]

        fplane_pos = fplane.get_pos()
        fplane_angle = fplane.get_direction()
        obase_pos = obase.get_pos()
        oplane_pos = oplane.get_pos()

        dist_oplane = dist(oplane_pos, fplane_pos)
        dist_obase = dist(obase_pos, fplane_pos)

        angle_to_oplane = get_angle(fplane_pos, oplane_pos)
        angle_to_obase = get_angle(fplane_pos, obase_pos)
        rel_angle_oplane = (angle_to_oplane - fplane_angle)
        if rel_angle_oplane < -180: rel_angle_oplane += 360
        rel_angle_obase = (angle_to_obase - fplane_angle)
        if rel_angle_obase < -180: rel_angle_obase += 360

        # --------------- FORWARDS ---------------
        if action == 0: 
            fplane.forward(self.speed, self.time_step)

         # --------------- SHOOT ---------------
        elif action == 1:
            if any([v is None for v in self.bullets.values()]):
                for k, v in self.bullets.items():
                    if v is None:
                        self.bullets[k] = Bullet(fplane_pos[0], fplane_pos[1], fplane_angle, self.bullet_speed, fteam, self.team[oteam])
                        break
            fplane.forward(self.speed, self.time_step)
        
        # --------------- TURN TO ENEMY PLANE ---------------
        elif action == 2:
            if math.fabs(rel_angle_oplane) < self.step_turn:
                fplane.rotate(rel_angle_oplane)
            elif rel_angle_oplane < 0:
                fplane.rotate(-self.step_turn)
            else:
                fplane.rotate(self.step_turn)
            fplane.forward(self.speed, self.time_step)

        # ---------------- TURN TO ENEMY BASE ----------------
        elif action == 3:
            if math.fabs(rel_angle_obase) < self.step_turn:
                fplane.rotate(rel_angle_obase)
            elif rel_angle_obase < 0:
                fplane.rotate(-self.step_turn)
            else:
                fplane.rotate(self.step_turn)
            fplane.forward(self.speed, self.time_step)

        # ---------------- GIVE REWARDS IF CLOSER (DIST OR ANGLE) ----------------
        new_fplane_pos = fplane.get_pos()
        new_oplane_pos = oplane.get_pos()
        new_obase_pos = obase.get_pos()

        new_dist_oplane = dist(new_oplane_pos, new_fplane_pos)
        new_dist_obase = dist(new_obase_pos, new_fplane_pos)

        if new_dist_oplane < dist_oplane: # If got closer to enemy plane
            reward += self.closer_to_plane_reward

        if new_dist_obase < dist_obase: # If got closer to enemy base
            reward += self.closer_to_base_reward

        return reward

    def draw_shot(self, hit, friendly_pos, target_pos, team):
        color = BLACK
        color = RED if team == 'red' else BLUE
        if not hit: target_pos = (target_pos[0] + (random.random() * 2 - 1) * 100, target_pos[1] + (random.random() * 2 - 1) * 100)
        self.shot_history.append((hit, friendly_pos, target_pos, color))
    
    def winner_screen(self):
        if self.show:
            font = pygame.font.Font('freesansbold.ttf', 32)
            if self.winner != 'none':
                text = font.render(f"THE WINNER IS {self.winner.upper()}", True, BLACK)
                textRect = text.get_rect()
                textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
            else:
                text = font.render(f"THE GAME IS A TIE", True, BLACK)
                textRect = text.get_rect()
                textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
            self.display.blit(text, textRect)

    def show_wins(self):
        print("Wins by red:", self.team['red']['wins'])
        print("Wins by blue:", self.team['blue']['wins'])
        print("Tied games:", self.ties)

    def render(self, mode="human"):
        if self.show: # Just to ensure it won't render if self.show == False
            for event in pygame.event.get():
                # Check for KEYDOWN event
                if event.type == KEYDOWN:
                    # If the Esc key is pressed, then exit the main loop
                    if event.key == K_ESCAPE:
                        pygame.quit()
                        sys.exit()
                # Check for QUIT event. If QUIT, then set running to false.
                elif event.type == QUIT:
                    pygame.quit()
                    self.done = True
                    sys.exit()
                    
            # Fill background
            self.display.fill(WHITE)

            # Draw bullets
            for bullet in self.bullets.values():
                if bullet is not None:
                    bullet.draw(self.display)
                    
            # Draw bases
            self.team['red']['base'].draw(self.display)
            self.team['blue']['base'].draw(self.display)

            # Draw planes
            for plane in self.team['red']['planes']:
                plane.update()
                plane.draw(self.display)
            for plane in self.team['blue']['planes']:
                plane.update()
                plane.draw(self.display)

            # Winner Screen
            if self.done:
                font = pygame.font.Font('freesansbold.ttf', 32)
                if self.winner != 'none':
                    text = font.render(f"THE WINNER IS {self.winner.upper()}", True, BLACK)
                    textRect = text.get_rect()
                    textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
                else:
                    text = font.render(f"THE GAME IS A TIE", True, BLACK)
                    textRect = text.get_rect()
                    textRect.center = (DISP_WIDTH//2, DISP_HEIGHT//2)
                self.display.blit(text, textRect)
                pygame.display.update()
                pygame.time.wait(1000)
                pygame.quit()
                return
            
            pygame.display.update()
            self.clock.tick(self.fps)

----------- Hyperparam Tuning ----------

In [9]:
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback
# from stable_baselines3.common import env_checker
# env_checker.check_env(BattleEnvironment())

In [10]:
FOLDER = 'PPO(3)'
LOG_DIR = f'{FOLDER}/logs/'
OPT_DIR = f'{FOLDER}/opt/'
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)
if not os.path.exists(OPT_DIR):
    os.makedirs(OPT_DIR)

In [11]:
def optimize_ppo(trial): 
    return {
        'n_steps':trial.suggest_int('n_steps', 2048, 8192),
        'gamma':trial.suggest_loguniform('gamma', 0.8, 0.9999),
        'learning_rate':trial.suggest_loguniform('learning_rate', 1e-5, 1e-4),
        'clip_range':trial.suggest_uniform('clip_range', 0.1, 0.4),
        'gae_lambda':trial.suggest_uniform('gae_lambda', 0.8, 0.99)
    }

In [12]:
def optimize_agent(trial):
    try:
        model_params = optimize_ppo(trial)
        env = BattleEnvironment()
        model = PPO('MlpPolicy', env, tensorboard_log=LOG_DIR, verbose=0, **model_params)
        model.learn(total_timesteps=200000)
        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
        SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}'.format(trial.number))
        model.save(SAVE_PATH)
        return mean_reward
    except Exception as e:
        return -1000

In [13]:
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=20)

[32m[I 2022-07-02 13:25:09,844][0m A new study created in memory with name: no-name-6ab95736-1635-41c7-b3bc-8374bf456b55[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=7649 and n_envs=1)
[32m[I 2022-07-02 13:39:30,626][0m Trial 0 finished with value: -6.2 and parameters: {'n_steps': 7649, 'gamma': 0.8946663787485506, 'learning_rate': 1.4266799712631003e-05, 'clip_range': 0.18642251944832786, 'gae_lambda': 0.8477375470890243}. Best is trial 0 with value: -6.2.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=7206 and n_envs=1)
[32m[I 2022-07-02 13:55:02,176][0m Trial 1 finished with value: -34.4 and parameters: {'n_steps': 7206, 'gamma': 0.8352397690147373, 'learning_rate': 3.235842211828897e-05, 'clip_range': 0.1011473276955016, 'gae_lambda': 0.8136795498947424}. Best is trial 0 with value: -6.2.[0m
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=303

In [17]:
model_params = study.best_params
model_params['n_steps'] = (model_params['n_steps'] // 64) * 64 # set to a factor of 64
study.best_trial

FrozenTrial(number=16, values=[64.6], datetime_start=datetime.datetime(2022, 7, 2, 15, 23, 45, 688205), datetime_complete=datetime.datetime(2022, 7, 2, 15, 30, 26, 898660), params={'n_steps': 2076, 'gamma': 0.866331689775563, 'learning_rate': 6.604645883688687e-05, 'clip_range': 0.286690354522357, 'gae_lambda': 0.9183168288996518}, distributions={'n_steps': IntUniformDistribution(high=8192, low=2048, step=1), 'gamma': LogUniformDistribution(high=0.9999, low=0.8), 'learning_rate': LogUniformDistribution(high=0.0001, low=1e-05), 'clip_range': UniformDistribution(high=0.4, low=0.1), 'gae_lambda': UniformDistribution(high=0.99, low=0.8)}, user_attrs={}, system_attrs={}, intermediate_values={}, trial_id=16, state=TrialState.COMPLETE, value=None)

---------- Training the Best Model ----------

In [18]:
class TrainAndLoggingCallback(BaseCallback):
    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)
        return True

CHECKPOINT_DIR = f'{FOLDER}/train/'
save_freq = 50000
timesteps = 1000000
saved_timesteps = timesteps // save_freq * save_freq

In [19]:
callback = TrainAndLoggingCallback(check_freq=save_freq, save_path=CHECKPOINT_DIR)
env = BattleEnvironment()
model = PPO('MlpPolicy', env, tensorboard_log=LOG_DIR, verbose=1, **model_params)
model.load(os.path.join(OPT_DIR, f'trial_16.zip'))
model.learn(total_timesteps=timesteps, callback=callback)
del model

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to PPO(3)/logs/PPO_21
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 26.2     |
|    ep_rew_mean     | 20.1     |
| time/              |          |
|    fps             | 329      |
|    iterations      | 1        |
|    time_elapsed    | 6        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 23.6         |
|    ep_rew_mean          | 15.9         |
| time/                   |              |
|    fps                  | 310          |
|    iterations           | 2            |
|    time_elapsed         | 13           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0009034141 |
|    clip_fraction        | 0            |
|    clip_range     

----------- Evaluation ----------

In [20]:
model = PPO.load(f'{CHECKPOINT_DIR}best_model_{saved_timesteps}')
env = BattleEnvironment(show=True, fps=120)
episodes = 100
for episode in range(episodes): # Evaluates the model n times
    state = env.reset()
    score = 0
    while not env.done:
        env.render()
        action, _states = model.predict(state)
        n_state, reward, done, info = env.step(action)
        score += reward
    print('Episode:{} Score:{}'.format(episode+1, score))
env.show_wins()

red wins
Episode:1 Score:100
blue wins
Episode:2 Score:-32
blue wins
Episode:3 Score:-23
red wins
Episode:4 Score:100
red wins
Episode:5 Score:100
red wins
Episode:6 Score:100
red wins
Episode:7 Score:100
blue wins
Episode:8 Score:-23
Draw
Episode:9 Score:-57
blue wins
Episode:10 Score:-20
red wins
Episode:11 Score:82
blue wins
Episode:12 Score:-20
red wins
Episode:13 Score:97
red wins
Episode:14 Score:100
red wins
Episode:15 Score:94
blue wins
Episode:16 Score:-20
blue wins
Episode:17 Score:-32
red wins
Episode:18 Score:100
red wins
Episode:19 Score:97
red wins
Episode:20 Score:94
blue wins
Episode:21 Score:-20
red wins
Episode:22 Score:94
red wins
Episode:23 Score:100
red wins
Episode:24 Score:97
blue wins
Episode:25 Score:-20
red wins
Episode:26 Score:100
blue wins
Episode:27 Score:-20
blue wins
Episode:28 Score:-29
blue wins
Episode:29 Score:-29
blue wins
Episode:30 Score:-20
blue wins
Episode:31 Score:-32
blue wins
Episode:32 Score:-20
blue wins
Episode:33 Score:-35
red wins
Episo

---------- Visualize Games ----------

In [22]:
env = BattleEnvironment(show=True)
episodes = 10
for episode in range(episodes): # Evaluates the model n times
    state = env.reset()
    score = 0
    while not env.done:
        env.render()
        action, _states = model.predict(state)
        # action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward
    print('Episode:{} Score:{}'.format(episode+1, score))

blue wins
Episode:1 Score:-20
red wins
Episode:2 Score:100
red wins
Episode:3 Score:100
blue wins
Episode:4 Score:-20
red wins
Episode:5 Score:100
red wins
Episode:6 Score:100
blue wins
Episode:7 Score:-26
blue wins
Episode:8 Score:-20
blue wins
Episode:9 Score:-26
red wins
Episode:10 Score:73
