In [1]:
from typing import Optional

import gym_line_follower  # NOTE(review): imported for its side effect, presumably registering 'LineFollower-v0' — confirm
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3 import DDPG, PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv

from wrappers import RenderRerun

In [2]:
def initialize_environment(filename: Optional[str] = None, skip_episodes: int = 1000):
    """Create a RenderRerun-wrapped LineFollower env and a single-env DummyVecEnv around it.

    Args:
        filename: Recording output file forwarded to ``RenderRerun``
            (this notebook passes ``.rrd`` paths). ``None`` is forwarded unchanged.
        skip_episodes: Forwarded to ``RenderRerun``. NOTE(review): presumably controls
            how many episodes are skipped between recordings — confirm in ``wrappers``.

    Returns:
        ``(env, vec_env)``. ``vec_env`` wraps the *same* ``env`` instance rather than a
        copy. Resetting or stepping one therefore affects the other. Do not use ``env``
        as a separate evaluation environment while training on ``vec_env``.
    """
    env = gym.make('LineFollower-v0', gui=False, render_mode='rgb_array')
    env = RenderRerun(env, filename=filename, skip_episodes=skip_episodes, viewer="notebook")

    # DummyVecEnv takes env *factories*. Returning the already-built wrapper keeps one
    # shared instance instead of constructing (and recording) a second environment.
    vec_env = DummyVecEnv([lambda: env])
    return env, vec_env

In [3]:
# Notebook configuration: whether to train, and which SB3 algorithm to use.
train_model = True
model_type = "ppo"  # "ppo" or "ddpg"

# Fail fast here: later cells only handle these two algorithm names.
assert model_type in ("ppo", "ddpg"), f"unsupported model_type: {model_type!r}"

# Base name used when saving/loading the trained agent.
model_name = model_type + "_line_follower"

In [None]:
env, vec_env = initialize_environment(filename="gym-line-follower_training.rrd", skip_episodes=500)

# Per-algorithm settings, kept in one place so training and the untrained fallback agree.
TOTAL_TIMESTEPS = {"ddpg": 10_000, "ppo": 100_000}
TENSORBOARD_LOGS = {
    "ddpg": "./ddpg_line_follower_tensorboard/",
    "ppo": "./ppo_line_follower_tensorboard/",
}


def build_model(model_type: str, venv):
    """Create a fresh, untrained SB3 model of type ``model_type`` attached to ``venv``.

    Raises:
        ValueError: if ``model_type`` is neither ``"ppo"`` nor ``"ddpg"``.
    """
    if model_type == "ddpg":
        # The noise objects for DDPG: Ornstein-Uhlenbeck exploration noise on every action dimension.
        n_actions = venv.action_space.shape[0]
        action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
        return DDPG("MlpPolicy", venv, action_noise=action_noise, verbose=1,
                    tensorboard_log=TENSORBOARD_LOGS["ddpg"])
    if model_type == "ppo":
        return PPO("MlpPolicy", venv, verbose=1, tensorboard_log=TENSORBOARD_LOGS["ppo"])
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'ppo' or 'ddpg'")


if train_model:
    # Evaluate on a separate env instance. `vec_env` wraps `env` itself, so handing `env`
    # to EvalCallback would reset the training env in the middle of a rollout.
    # This instance is not wrapped in RenderRerun, so evaluation episodes are not recorded.
    eval_env = gym.make('LineFollower-v0', gui=False, render_mode='rgb_array')

    # Stop training once more than 10 consecutive evaluations bring no new best mean
    # reward. Counting only starts after the first 5 evaluations.
    stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=10, min_evals=5, verbose=1)
    eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)

    model = build_model(model_type, vec_env)
    model.learn(total_timesteps=TOTAL_TIMESTEPS[model_type], callback=eval_callback)

    # Evaluate the trained agent on the (recording) training env.
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=100)
    print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

    # Save the trained agent under `model_name`.
    model.save(model_name)

    # Release both environments. The next cell creates a fresh environment for testing.
    # Closing the DummyVecEnv also closes the `env` it wraps.
    eval_env.close()
    vec_env.close()
else:
    # No training: just build an untrained model of the configured type on the environment.
    model = build_model(model_type, vec_env)



In [4]:
# Create a new environment for testing, saving a new recording file
env, vec_env = initialize_environment(filename="gym-line-follower_test.rrd", skip_episodes=0)

MAX_TEST_STEPS = 10  # length of the short test rollout

# Load the trained agent saved under `model_name` and do a test run.
# Dispatch explicitly so an unsupported model_type fails here instead of silently
# reusing a stale `model` left in the kernel by an earlier cell.
MODEL_CLASSES = {"ppo": PPO, "ddpg": DDPG}
if model_type not in MODEL_CLASSES:
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_CLASSES)}")
model = MODEL_CLASSES[model_type].load(model_name, env=vec_env)

# Reset and step the same VecEnv handle the model is attached to, so the observations fed
# to `predict` always come from the env being stepped.
test_vec_env = model.get_env()
obs = test_vec_env.reset()

steps = 0
for steps in range(1, MAX_TEST_STEPS + 1):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = test_vec_env.step(action)
    # A VecEnv returns one done flag per sub-environment. This DummyVecEnv wraps exactly one.
    if done[0]:
        print("Done in ", steps, " steps")
        break


options= 
made client


  gym.logger.warn(
  gym.logger.warn(


HTML(value='<div id="63c0b237-1682-4cc5-b5ad-546a9e0be17b"><style onload="eval(atob(\'KGFzeW5jIGZ1bmN0aW9uICgp…

Viewer()

In [5]:
# Shut down the test environment. `vec_env` is a DummyVecEnv around this single `env`,
# and DummyVecEnv.close() closes every env it wraps.
vec_env.close()