In [1]:
import enviroments_package
import gymnasium


import os  # imported here so this configuration block is self-contained

# Root of the project checkout. Defaults to the original author's location; set the
# DRONE_TFG_ROOT environment variable to run the notebook from another machine
# without editing this cell. A trailing "/" is stripped to avoid "//" in the paths.
project_root = os.environ.get(
    "DRONE_TFG_ROOT", "/Users/jeste/Desktop/Clase/TFG/drone_tfg_juanes"
).rstrip("/")

# Webots world file passed to the environment as `simulation_path`.
world_dir = f"{project_root}/simulation_package/worlds/my_frst_webots_world.wbt"

# Reward configuration files (JSON) passed to the environment as `reward_json_path`.
json_take_off = f"{project_root}/configs/reward_package_config/takeoff.json"
json_basic = f"{project_root}/configs/reward_package_config/basic_no_roll.json"
json_use_motors = f"{project_root}/configs/reward_package_config/motors_use.json"

In [2]:
import gymnasium
from stable_baselines3.common.env_util import SubprocVecEnv
from enviroments_package import RemoveKeyObservationWrapper, ScaleRewardWrapper, ScaleActionWrapper


num_envs = 4  # Number of environments that will be created

def make_env():
    """Return a zero-argument factory that builds one wrapped drone environment.

    The environment itself is only created when the returned callable is
    invoked, which lets a vectorized-env constructor call it later.

    Returns:
        callable: factory producing the 'drone_tfg_juanes/Drone-v1' environment
        wrapped, innermost to outermost, by RemoveKeyObservationWrapper,
        ScaleRewardWrapper and ScaleActionWrapper.
    """
    def _build():
        # Base environment, configured from the module-level paths.
        base_env = gymnasium.make(
            'drone_tfg_juanes/Drone-v1',
            simulation_path=world_dir,
            reward_json_path=json_use_motors,
            no_render=False,
        )

        # Wrapper chain; the order matters because each wrapper receives the previous one.
        without_keys = RemoveKeyObservationWrapper(base_env, remove_keys=["camera", "gps"])
        scaled_reward = ScaleRewardWrapper(without_keys, scale_factor=0.1)
        return ScaleActionWrapper(scaled_reward, in_low=-1, in_high=1, out_low=0, out_high=576)

    return _build


# Build `num_envs` environment factories and hand them to SubprocVecEnv, which
# constructs the environments from them.
env = SubprocVecEnv([make_env() for _ in range(num_envs)])

In [3]:
from stable_baselines3.common.callbacks import BaseCallback


class TrainingCallback(BaseCallback):
    """Resets the environment at every rollout start and closes it when training ends.

    Args:
        env: environment to reset/close. It must expose ``reset()`` and ``close()``;
            to keep the model in sync after a reset it should be the very same
            vectorized environment the model is training on.
        verbose: verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, env, verbose=1):
        super().__init__(verbose)
        self.env = env

    def _on_step(self) -> bool:
        """Always return True so training is never interrupted by this callback."""
        return True

    def _on_rollout_start(self) -> None:
        """Reset the environment and hand the fresh observation to the model.

        The algorithm computes the first action of a rollout from its cached
        ``_last_obs``. Resetting the environment without refreshing that cache
        would pair a stale observation with a freshly reset simulation, and the
        recurrent state would not be flagged as starting a new episode.
        """
        import numpy as np  # local import: the notebook does not import numpy elsewhere

        fresh_obs = self.env.reset()

        # Only sync when the model really trains on this exact environment object;
        # otherwise the observation format / env count may not match the model's.
        model = getattr(self, "model", None)
        if model is not None and model.get_env() is self.env:
            model._last_obs = fresh_obs
            # Mark every sub-environment as starting a new episode.
            model._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)

    def _on_training_end(self) -> None:
        """Close the environment once training has finished."""
        print("Entrenamiento finalizado. Cerrando el entorno...")
        self.env.close()

# Callback instance bound to the training environment created above.
callback = TrainingCallback(env=env)

In [None]:
import os
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.logger import configure


timesteps = 40960  # total number of steps used to train the model
log_dir = "./logs/"
os.makedirs(log_dir, exist_ok=True)

# Logger writing to stdout plus CSV and plain-text log files inside `log_dir`.
new_logger = configure(log_dir, ["stdout", "csv", "log"])

# Resume from the saved checkpoint when one exists; otherwise build a new model.
if os.path.exists('./models/ppomodel.zip'):
    print("retrainning")

    model = RecurrentPPO.load("./models/ppomodel.zip", env=env)
else:
    print("first train")

    model = RecurrentPPO(
        "MultiInputLstmPolicy",
        env,
        verbose=1,           # print training information to the terminal
        n_steps=1024,        # size of the experience buffer collected before each policy update
        batch_size=64,       # the experience buffer is split into batches of this size
        learning_rate=1e-3,  # learning rate
        ent_coef=0.2,        # entropy coefficient (exploration)
    )

# Both paths share the same tail: attach the logger, train, and save the checkpoint.
model.set_logger(new_logger)
model.learn(total_timesteps=timesteps, callback=callback)
model.save('./models/ppomodel')

Logging to ./logs/
first train
Using cuda device


In [None]:
import os
import shutil
from datetime import datetime


def move_and_rename_csv(src_dir, dst_dir, new_name):
    """Copy one CSV file from ``src_dir`` into ``dst_dir`` under the name ``new_name``.

    Despite the function name, the source file is copied (``shutil.copy2``),
    not moved, so it stays in ``src_dir``.

    Args:
        src_dir: directory searched (non-recursively) for files ending in ``.csv``.
        dst_dir: destination directory; created if it does not exist.
        new_name: file name given to the copy inside ``dst_dir``.

    Returns:
        None. A message is printed both on success and when no CSV file is found.
    """
    # Sort so the selected file is deterministic when several CSVs are present
    # (os.listdir returns entries in arbitrary order).
    csv_files = sorted(f for f in os.listdir(src_dir) if f.endswith('.csv'))

    # Nothing to copy: report it and leave the destination untouched.
    if not csv_files:
        print("No se encontró ningún archivo CSV en el directorio de origen.")
        return

    # Make sure the destination exists; an empty dst_dir means the current directory.
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

    src_path = os.path.join(src_dir, csv_files[0])
    dst_path = os.path.join(dst_dir, new_name)

    # copy2 also preserves file metadata such as timestamps.
    shutil.copy2(src_path, dst_path)
    print(f"Archivo copiado y renombrado a {dst_path}")


# Archive this run's CSV training log under a timestamped file name.
src_directory = log_dir
dst_directory = './data_collected/'
new_filename = 'ppo{}.csv'.format(datetime.now().strftime("%Y%m%d_%H%M%S"))

move_and_rename_csv(
    src_dir=src_directory,
    dst_dir=dst_directory,
    new_name=new_filename,
)

import numpy as np  # local to this cell: needed for the episode-start flags below

# Build a single (non-vectorized) environment to run the trained policy.
init = make_env()
env = init()

try:
    observation, _ = env.reset()

    # A recurrent policy needs its LSTM state carried from one step to the next,
    # plus a flag telling it when a new episode starts so that state is reset.
    lstm_states = None
    episode_start = np.ones((1,), dtype=bool)

    for step_idx in range(100):
        action, lstm_states = model.predict(
            observation,
            state=lstm_states,
            episode_start=episode_start,
            deterministic=True,
        )
        observation, reward, terminated, truncated, _ = env.step(action)

        # An episode is over on termination OR truncation; stepping further
        # without a reset is not valid in either case.
        episode_over = bool(terminated or truncated)
        episode_start = np.array([episode_over], dtype=bool)

        if episode_over:
            observation, _ = env.reset()
finally:
    # Always release the environment, even if prediction or stepping fails.
    env.close()