In [1]:
import gymnasium as gym
import os
import random
import string

from PyFlyt.gym_envs import FlattenWaypointEnv
from stable_baselines3.common.env_util import make_vec_env

from stable_baselines3 import PPO, A2C, DDPG, TD3, SAC

def train(env_name, train_env, algorithm_name, n_episodes=10000, logging=False, n_runs=10, model_name=None):
    """Train (or continue training) a Stable-Baselines3 agent and checkpoint it.

    Args:
        env_name: Label for the environment; only used to build the
            ``models/<algorithm>_<env_name>`` directory and the TensorBoard run name.
        train_env: Environment handed to the algorithm. It is closed before
            this function returns, including when training is interrupted.
        algorithm_name: One of "PPO", "A2C", "DDPG", "TD3" or "SAC".
        n_episodes: ``total_timesteps`` passed to every ``learn`` call. Despite
            the name, this is a number of timesteps, not episodes.
        logging: If true, a newly created model logs to TensorBoard under
            ``data``. A model loaded via ``model_name`` is not reconfigured.
        n_runs: Number of consecutive ``learn`` calls; a checkpoint is saved
            after each one.
        model_name: Optional checkpoint file name inside the models directory
            to resume from instead of creating a new model.

    Returns:
        The trained model, or ``None`` when ``algorithm_name`` is not recognised.
    """
    if algorithm_name == "PPO":
        algorithm = PPO
    elif algorithm_name == "A2C":
        algorithm = A2C
    elif algorithm_name == "DDPG":
        algorithm = DDPG
    elif algorithm_name == "TD3":
        algorithm = TD3
    elif algorithm_name == "SAC":
        algorithm = SAC
    else:
        print("Error: Invalid DRL Algorithm specified")
        return

    full_id = algorithm_name + '_' + env_name

    models_dir = os.path.join("models", full_id)
    logdir = "data"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(logdir, exist_ok=True)

    if model_name:
        # os.path.join keeps the checkpoint path valid on every OS; a
        # hard-coded backslash separator only resolves on Windows.
        model = algorithm.load(os.path.join(models_dir, model_name))
    else:
        model = algorithm("MlpPolicy", train_env, verbose=1, tensorboard_log=logdir if logging else None)

    model.set_env(train_env)
    try:
        for _ in range(n_runs):
            model.learn(total_timesteps=n_episodes, reset_num_timesteps=False, tb_log_name=full_id)
            # num_timesteps is the model's running step counter, so the file
            # name is the number of steps actually trained so far.
            model.save(os.path.join(models_dir, str(model.num_timesteps)))
    finally:
        # Release the environment even if training is interrupted.
        train_env.close()
    return model




In [2]:
# Waypoint task wrapped by FlattenWaypointEnv with a context length of 1.
# (Unwrapped alternative: env = gym.make("PyFlyt/QuadX-Waypoints-v1"))
env = FlattenWaypointEnv(
    gym.make("PyFlyt/QuadX-Waypoints-v1"),
    context_length=1,
)
# `obs` holds whatever reset() returns, unmodified.
obs = env.reset()

  logger.warn(


[A                             [A


In [2]:
# Waypoint variant, kept for reference:
# env = make_vec_env(lambda: FlattenWaypointEnv(gym.make("PyFlyt/QuadX-Waypoints-v1"), context_length=1), n_envs=1)
env_name = "QuadX-Hover-v1"
env = make_vec_env(lambda: gym.make("PyFlyt/QuadX-Hover-v1"))
# One PPO run of 100k timesteps; to resume from a checkpoint, also pass
# e.g. model_name="50720.zip".
m = train(
    env_name,
    env,
    "PPO",
    n_episodes=100_000,
    n_runs=1,
)

Using cpu device

KeyboardInterrupt: 

In [3]:
# Load a saved PPO checkpoint. The path is assembled with os.path.join so it
# resolves on any OS; a backslash-separated literal only works on Windows.
m = PPO.load(os.path.join("models", "PPO_QuadX-Waypoints-v1_sGHuVNzoGSoWUMxHNeXJ", "250000.zip"))

In [4]:

# Roll out the loaded policy `m` for two rendered episodes.
for _ in range(2):
    # A fresh GUI environment is created for every episode and closed afterwards.
    render_env = FlattenWaypointEnv(gym.make("PyFlyt/QuadX-Waypoints-v1", render_mode="human"), context_length=1)
    #render_env = gym.make("PyFlyt/QuadX-Hover-v1", render_mode="human")
    try:
        # reset() returns (observation, info); only the observation is fed to the policy.
        obs, info = render_env.reset()
        done = False
        while not done:
            action, _states = m.predict(obs)
            obs, rewards, terminated, truncated, info = render_env.step(action)
            # An episode ends on either termination or truncation.
            done = terminated or truncated
            render_env.render()
    finally:
        # Close the environment even when the rollout raises or is interrupted.
        render_env.close()

  logger.warn(


[A                             [A


  logger.warn(


[A                             [A


In [2]:
# Waypoint environment without the flattening wrapper, opened in "human" render mode.
render_env = gym.make(
    "PyFlyt/QuadX-Waypoints-v1",
    render_mode="human",
)
# Keep the complete reset() result in `obs`; later cells index it as obs[0][...].
obs = render_env.reset()

[A                             [A


In [18]:
# Call the method on the base environment: attribute forwarding through
# Gymnasium wrappers is deprecated and no longer available in Gymnasium 1.0.
render_env.unwrapped.compute_auxiliary()

  logger.warn(


array([0.05140749, 0.05087556, 0.05333369, 0.04924014])

In [16]:
# Call the method on the base environment: attribute forwarding through
# Gymnasium wrappers is deprecated and no longer available in Gymnasium 1.0.
render_env.unwrapped.compute_attitude()
# Returned order: ang_vel, ang_pos, lin_vel, lin_pos, quaternion

(array([ 0.00125275,  0.00190909, -0.00077839]),
 array([8.51291471e-05, 2.40113936e-04, 4.88786758e-06]),
 array([ 1.84097661e-04, -6.52745857e-05, -7.63067482e-01]),
 array([ 1.26732446e-08, -4.39806075e-09,  9.68138822e-01]),
 (4.256427980806423e-05,
  0.00012005707183480889,
  2.4388235954353965e-06,
  0.9999999918843169))

In [17]:
# Show the "attitude" entry of the first element of `obs`.
# NOTE(review): presumably `obs` is the (observation, info) tuple from the
# unflattened render_env.reset() cell — confirm that cell was run first.
obs[0]["attitude"]

array([ 1.25275156e-03,  1.90908641e-03, -7.78390410e-04,  4.25642798e-05,
        1.20057072e-04,  2.43882360e-06,  9.99999992e-01,  1.84097661e-04,
       -6.52745857e-05, -7.63067482e-01,  1.26732446e-08, -4.39806075e-09,
        9.68138822e-01,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  5.14074940e-02,  5.08755614e-02,  5.33336916e-02,
        4.92401436e-02])

In [10]:
# Show the "target_deltas" entry of the first element of `obs`.
# NOTE(review): presumably `obs` is the (observation, info) tuple from the
# unflattened render_env.reset() cell — confirm that cell was run first.
obs[0]["target_deltas"]

array([[ 0.68285049, -1.04936727,  3.16790635],
       [-0.65618658, -0.19189936,  1.34415656],
       [-0.21819811, -0.55136075,  0.57843157],
       [ 1.0706794 ,  0.41072173, -0.39996766]])

In [10]:
# Reset the environment and display the value reset() returns.
render_env.reset()

[A                             [A


(array([-6.83655089e-03, -9.86821756e-03, -2.47449865e-03, -1.64292553e-04,
        -1.48964536e-04, -2.98982793e-05,  9.99999975e-01, -2.28135780e-04,
         2.52316654e-04, -7.63178290e-01, -9.85124380e-09,  3.00167640e-08,
         9.68132543e-01,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  5.27517654e-02,  5.03095138e-02,  5.27818093e-02,
         5.08577356e-02,  1.19553976e+00,  1.54473374e+00, -6.08256510e-01]),
 {'out_of_bounds': False,
  'collision': False,
  'env_complete': False,
  'num_targets_reached': 0})