# PPO Training Script

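A template script for training a PPO agent (Stable-Baselines3) on the Animal-AI environment using raycast observations.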

In [1]:
from stable_baselines3 import PPO

import torch as th

import random

from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from animalai.envs.environment import AnimalAIEnvironment


In [None]:
def train_agent_single_config(
    configuration_file,
    env_path,
    results_path,
    log_bool=False,
    aai_seed=2023,
    watch=False,
    num_saves=100,
    num_steps=10000,
    runname="competition_raycast_ppo",
):
    """Train a PPO agent on a single Animal-AI arena configuration.

    Launches an Animal-AI Unity environment with raycast observations (no
    camera), wraps it as a Gym environment and trains a Stable-Baselines3
    ``MlpPolicy`` PPO model with ReLU activations. Training runs in
    ``num_saves`` consecutive chunks of ``num_steps`` timesteps. After each
    chunk a checkpoint is written to
    ``<results_path>/modelsaves/<runname>/model_<(i + 1) * num_steps>``.

    Args:
        configuration_file: Arena configuration file, passed to Animal-AI as
            ``arenas_configurations``.
        env_path: Path to the Animal-AI executable.
        results_path: Root directory for TensorBoard logs, checkpoints and
            (optionally) player logs.
        log_bool: If True, Animal-AI player logs go to
            ``<results_path>/player_logs``; otherwise an empty log folder is
            passed.
        aai_seed: Seed forwarded to the Animal-AI environment.
        watch: Forwarded as ``inference`` to Animal-AI. It also enables
            graphics, since nothing can be watched with rendering disabled.
        num_saves: Number of train-then-checkpoint rounds.
        num_steps: Timesteps requested per round. SB3 collects whole
            rollouts, so the actual count per round may be slightly higher
            than the number used in the checkpoint filename.
        runname: Sub-directory name used for TensorBoard logs and checkpoints.
    """
    # Random port offset, so an environment from a previous run that is still
    # shutting down does not collide with this one.
    port = 5005 + random.randint(0, 1000)

    log_folder_path = results_path + "/player_logs" if log_bool else ""

    aai_env = AnimalAIEnvironment(
        seed=aai_seed,
        file_name=env_path,
        log_folder=log_folder_path,
        arenas_configurations=configuration_file,
        play=False,
        base_port=port,
        inference=watch,
        useCamera=False,
        useRayCasts=True,
        # Raycast observations train fine without rendering; graphics are
        # only switched on when the caller asked to watch.
        no_graphics=not watch,
        raysPerSide=15,
        rayMaxDegrees=30,
        timescale=1,
    )

    # If wrapping fails, shut the Unity process down before propagating.
    try:
        env = UnityToGymWrapper(
            aai_env,
            uint8_visual=False,
            allow_multiple_obs=False,
            flatten_branched=True,
        )
    except BaseException:
        aai_env.close()
        raise

    try:
        model = PPO(
            "MlpPolicy",
            env,
            policy_kwargs=dict(activation_fn=th.nn.ReLU),
            verbose=1,
            tensorboard_log=results_path + "/tensor_log/" + runname,
        )
        save_prefix = results_path + "/modelsaves/" + runname + "/model_"
        for i in range(num_saves):
            # Only the first chunk resets the timestep counter; later chunks
            # continue the same run so the logged step numbers keep increasing.
            model.learn(num_steps, reset_num_timesteps=(i == 0))
            model.save(save_prefix + str((i + 1) * num_steps))
    finally:
        # Always release the Unity process/port, even if training is
        # interrupted or raises.
        env.close()


In [None]:
# Paths for this run: arena configuration, Animal-AI executable, and the
# directory where logs and checkpoints are written.
configuration_file = "../configs/aai-competition-curriculum.yml"
env_path = "../env/AnimalAI.exe"
model_results_path = "../modelsaves"

# Train for 10 rounds of 100k timesteps each, checkpointing after every round.
train_agent_single_config(
    configuration_file=configuration_file,
    env_path=env_path,
    results_path=model_results_path,
    watch=False,
    num_saves=10,
    num_steps=100_000,
)

In [None]:
"""

Run:
tensorboard --logdir ../modelsaves/tensor_log

from the command line (inside the conda environment) to monitor training metrics live while the agent trains.
(Change the path to wherever your TensorBoard logs are stored; the training function writes them to <results_path>/tensor_log.)
"""