In [None]:
from stable_baselines3 import PPO
from stable_baselines.common.policies import FeedForwardPolicy, register_policy
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_checker import check_env
from robot import Robot
import os

In [None]:
# Instantiate the gym environment that wraps the robot interface.
robot = Robot()

# Run the stable_baselines3 environment checker against it, with warnings
# enabled, before the environment is used for anything else.
check_env(robot, warn=True)

In [None]:
# test the robot interface
# Smoke test: drive the environment with randomly sampled actions for a few
# episodes and print the accumulated reward of each one.
episodes = 5
for episode in range(1, episodes + 1):
    state = robot.reset()
    done = False
    score = 0

    while not done:
        robot.render()
        action = robot.action_space.sample()
        # step() returns (obs, reward, done, info) under the classic gym API
        # and (obs, reward, terminated, truncated, info) under Gymnasium.
        # Star-unpacking accepts both; the episode ends when any flag is set.
        n_state, reward, *end_flags, info = robot.step(action)
        done = any(end_flags)
        score += reward
    print('Episode:{} Score:{}'.format(episode, score))
robot.close()

In [None]:
# create the model

# Hidden-layer sizes for the separate policy (pi) and value (vf) networks.
# stable_baselines3 takes a custom MLP layout through `policy_kwargs`; the
# FeedForwardPolicy / register_policy helpers belong to the older
# `stable_baselines` package, whose registry SB3's PPO does not consult, so
# a policy registered there cannot be looked up by name here.
# NOTE(review): SB3 releases before 1.8 expect the list-wrapped form
# `net_arch=[dict(pi=[2], vf=[2])]` — confirm against the installed version.
policy_kwargs = dict(net_arch=dict(pi=[2], vf=[2]))

# TensorBoard logs go to ./logs. `os.path` is a module, not a directory, so
# it cannot be passed to os.path.join as a path component.
log_path = os.path.join(os.getcwd(), 'logs')

model = PPO(policy='MlpPolicy', env=robot, policy_kwargs=policy_kwargs,
            verbose=1, tensorboard_log=log_path)

In [None]:
# Train the model for a fixed budget of environment steps.
TOTAL_TIMESTEPS = 4000
model.learn(total_timesteps=TOTAL_TIMESTEPS)

In [None]:
# Persist the trained model under the name 'PPO'.
model.save('PPO')

# Evaluate the model on the robot environment over 10 episodes without
# rendering; the result is left as the cell's last expression so that it is
# displayed as the cell output.
evaluation_result = evaluate_policy(
    model,
    robot,
    n_eval_episodes=10,
    render=False,
)
evaluation_result