In [1]:
import os
import time
from datetime import datetime
import argparse
import gymnasium as gym
import numpy as np
import torch as th
import pandas as pd
import csv

from stable_baselines3 import PPO,SAC,TD3
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.evaluation import evaluate_policy

from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.envs.HoverAviary import HoverAviary
from gym_pybullet_drones.envs.MultiHoverAviary import MultiHoverAviary
from gym_pybullet_drones.utils.utils import sync, str2bool
from gym_pybullet_drones.utils.enums import ObservationType, ActionType, Physics

from policies import GaussianMLPPolicy
#from server import Federated_RL

# ---- Run-level defaults ------------------------------------------------------
# DEFAULT_MA, DEFAULT_OUTPUT_FOLDER, DEFAULT_GUI, DEFAULT_COLAB and
# DEFAULT_RECORD_VIDEO are used as default argument values of train();
# DEFAULT_OBS, DEFAULT_ACT and DEFAULT_AGENTS are read inside its body.
# NOTE(review): DEFAULT_DYNAMICS, DEFAULT_WIND and DEFAULT_MASS are not
# referenced by train() — confirm whether another cell or module still needs them.
DEFAULT_GUI = True
DEFAULT_RECORD_VIDEO = True
DEFAULT_OUTPUT_FOLDER = 'results'
DEFAULT_COLAB = False
DEFAULT_DYNAMICS = Physics('pyb') # pyb: Pybullet dynamics; dyn: Explicit Dynamics specified in BaseAviary.py
DEFAULT_WIND = np.array([0, 0.05, 0]) # units are in induced newtons
DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'
DEFAULT_ACT = ActionType('one_d_rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
DEFAULT_AGENTS = 4 # number of drones created when train(multiagent=True)
DEFAULT_MA = False # single-agent training unless overridden
DEFAULT_MASS = 0.037 # Actual default is 0.027

# ---- Domain randomisation (DR) settings --------------------------------------
# These three values are forwarded unchanged to EvalCallback by train().
DR = True
MASS_RANGE = [0.027, 0.042] # Maximum recommended payload is 15g
WIND_RANGE = 0.005 # Inspired by literature

# Maintain consistent network structures: a single policy_kwargs dict is shared
# by every model built in this notebook (Tanh activations; layer sizes listed
# per network under the 'pi' and 'qf' keys).
policy_kwargs = dict(activation_fn=th.nn.Tanh,
                     net_arch=dict(pi=[512, 512, 256, 128], qf=[32, 32]))

pybullet build time: Jun 24 2024 15:23:59


In [2]:
def train(multiagent=DEFAULT_MA, output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, colab=DEFAULT_COLAB, record_video=DEFAULT_RECORD_VIDEO, local=True):
    """Train a TD3 agent on the hover task and save it under ``TD3_test_run/``.

    Builds one vectorised training environment and one evaluation environment
    (``HoverAviary`` or ``MultiHoverAviary``), trains with periodic evaluation
    that stops early once the target reward is reached, and finally writes
    ``TD3_test_run/final_model.zip``. Both environments are closed before the
    function returns, including when training raises or is interrupted.

    Args:
        multiagent: When true, use ``MultiHoverAviary`` with ``DEFAULT_AGENTS``
            drones and the multi-agent reward targets; otherwise ``HoverAviary``.
        output_folder: Accepted for interface compatibility; not used by this
            function (results always go to ``TD3_test_run/``).
        gui: Accepted for interface compatibility; not used.
        plot: Accepted for interface compatibility; not used.
        colab: Accepted for interface compatibility; not used.
        record_video: Accepted for interface compatibility; not used.
        local: When true, train for ``1e7`` timesteps; otherwise ``1e2``
            (shorter training in GitHub Actions pytest).

    Returns:
        None. Results are written to disk and progress is printed.
    """
    filename = 'TD3_test_run'
    # exist_ok=True replaces the separate exists() check, so re-running the
    # cell is harmless and there is no check-then-create race.
    os.makedirs(filename, exist_ok=True)

    #### Pick the environment class and its constructor arguments ####
    if not multiagent:
        env_class = HoverAviary
        env_kwargs = dict(obs=DEFAULT_OBS, act=DEFAULT_ACT)
    else:
        env_class = MultiHoverAviary
        env_kwargs = dict(num_drones=DEFAULT_AGENTS, obs=DEFAULT_OBS, act=DEFAULT_ACT)

    train_env = None
    eval_env = None
    try:
        train_env = make_vec_env(env_class,
                                 env_kwargs=dict(env_kwargs),
                                 n_envs=1,
                                 seed=0
                                 )
        eval_env = env_class(**env_kwargs)

        #### Check the environment's spaces ########################
        print('[INFO] Action space:', train_env.action_space)
        print('[INFO] Observation space:', train_env.observation_space)

        #### Train the model #######################################
        # NOTE(review): `local_iterations` is not visible in this notebook;
        # presumably it is added by the project's modified stable_baselines3
        # — confirm before running against the upstream package.
        model = TD3('MlpPolicy',
                    train_env,
                    # tensorboard_log=filename+'/tb/',
                    policy_kwargs=policy_kwargs,
                    local_iterations=2,
                    verbose=1)

        #### Target cumulative rewards (problem-dependent) ##########
        if DEFAULT_ACT == ActionType.ONE_D_RPM:
            target_reward = 474.15 if not multiagent else 949.5
        else:
            target_reward = 467. if not multiagent else 920.
        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=target_reward,
                                                         verbose=1)
        # NOTE(review): DR / mass_range / wind_range are likewise presumably
        # extensions from the modified stable_baselines3 — confirm.
        eval_callback = EvalCallback(eval_env,
                                     callback_on_new_best=callback_on_best,
                                     verbose=1,
                                     best_model_save_path=filename+'/',
                                     log_path=filename+'/',
                                     eval_freq=int(1000),
                                     deterministic=True,
                                     DR=DR,
                                     mass_range=MASS_RANGE,
                                     wind_range=WIND_RANGE,
                                     render=False)

        model.learn(total_timesteps=int(1e7) if local else int(1e2), # shorter training in GitHub Actions pytest
                    callback=eval_callback,
                    log_interval=100)

        #### Save the model ########################################
        model.save(filename+'/final_model.zip')
        print(filename)
    finally:
        # Release both environments even if training fails or is interrupted,
        # so repeated runs in one kernel do not accumulate open simulators.
        for env in (eval_env, train_env):
            if env is not None:
                env.close()
# Disabled "print training progression" snippet, kept as a bare triple-quoted
# string instead of comments. It is a module-level expression (not part of
# train()), and because it is the last expression in the cell the notebook
# echoes the string back as the cell's output.
'''
    #### Print training progression ############################
    with np.load(filename+'/evaluations.npz') as data:
        for j in range(data['timesteps'].shape[0]):
            print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
            '''

'\n    #### Print training progression ############################\n    with np.load(filename+\'/evaluations.npz\') as data:\n        for j in range(data[\'timesteps\'].shape[0]):\n            print(str(data[\'timesteps\'][j])+","+str(data[\'results\'][j][0]))\n            '

In [3]:
# Run training with every default argument: single-agent (DEFAULT_MA is False)
# and the long schedule selected by local=True.
train()

[INFO] BaseAviary.__init__() loaded parameters from the drone's .urdf:
[INFO] m 0.027000, L 0.039700,
[INFO] ixx 0.000014, iyy 0.000014, izz 0.000022,
[INFO] kf 0.000000, km 0.000000,
[INFO] t2w 2.250000, max_speed_kmh 30.000000,
[INFO] gnd_eff_coeff 11.368590, prop_radius 0.023135,
[INFO] drag_xy_coeff 0.000001, drag_z_coeff 0.000001,
[INFO] dw_coeff_1 2267.180000, dw_coeff_2 0.160000, dw_coeff_3 -0.110000
[INFO] BaseAviary.__init__() loaded parameters from the drone's .urdf:
[INFO] m 0.027000, L 0.039700,
[INFO] ixx 0.000014, iyy 0.000014, izz 0.000022,
[INFO] kf 0.000000, km 0.000000,
[INFO] t2w 2.250000, max_speed_kmh 30.000000,
[INFO] gnd_eff_coeff 11.368590, prop_radius 0.023135,
[INFO] drag_xy_coeff 0.000001, drag_z_coeff 0.000001,
[INFO] dw_coeff_1 2267.180000, dw_coeff_2 0.160000, dw_coeff_3 -0.110000
[INFO] Action space: Box(-1.0, 1.0, (1, 1), float32)
[INFO] Observation space: Box([[-inf -inf   0. -inf -inf -inf -inf -inf -inf -inf -inf -inf  -1.  -1.
   -1.  -1.  -1.  -1.  

  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


[DR] Parameters for next episode: Mass 0.0397575567700256, Wind [0.0001946  0.002788   0.00164111]
[DR] Parameters for next episode: Mass 0.03608129981558965, Wind [0.00024259 0.00471997 0.00251289]
[DR] Parameters for next episode: Mass 0.029108596936224092, Wind [0.0029434  0.00206453 0.00397356]
[DR] Parameters for next episode: Mass 0.031680063916576, Wind [0.00028374 0.00191036 0.00125229]
[DR] Parameters for next episode: Mass 0.028723782937018643, Wind [0.00291476 0.00287723 0.00380578]
[DR] Parameters for next episode: Mass 0.03135578325620557, Wind [0.00215136 0.00351914 0.00037614]
[DR] Parameters for next episode: Mass 0.03676755908979388, Wind [0.00166258 0.00053949 0.00359754]
[DR] Parameters for next episode: Mass 0.040345952066650606, Wind [0.00456268 0.00058261 0.00236219]
[DR] Parameters for next episode: Mass 0.03152885531613361, Wind [0.00110908 0.00071049 0.00363601]
[DR] Parameters for next episode: Mass 0.03124409050114929, Wind [0.00319097 0.00148906 0.00197793]


In [23]:
# Checkpoint to inspect. The default is the original author's machine-specific
# path; set the SAC_MODEL_PATH environment variable to load the file from
# another location without editing this cell.
# NOTE(review): train() in this notebook saves a TD3 model under
# 'TD3_test_run/', so this SAC checkpoint comes from a different run — confirm
# it is the one intended for the inspection cells that follow.
sac_model_path = os.environ.get(
    'SAC_MODEL_PATH',
    '/Users/kevinhan/opt/anaconda3/envs/drones/lib/python3.12/site-packages/Federated_RL/SAC_test_run/final_model.zip',
)
model = SAC.load(sac_model_path)

In [11]:
# List every entry of the model's 'policy' parameter group with its shape.
# The parameter dict is fetched once and reused for the whole loop.
policy_state = model.get_parameters()['policy']
for param_name, param_tensor in policy_state.items():
    print(param_name + ': ' + str(param_tensor.shape))

log_std: torch.Size([1])
mlp_extractor.policy_net.0.weight: torch.Size([512, 27])
mlp_extractor.policy_net.0.bias: torch.Size([512])
mlp_extractor.policy_net.2.weight: torch.Size([512, 512])
mlp_extractor.policy_net.2.bias: torch.Size([512])
mlp_extractor.policy_net.4.weight: torch.Size([256, 512])
mlp_extractor.policy_net.4.bias: torch.Size([256])
mlp_extractor.policy_net.6.weight: torch.Size([128, 256])
mlp_extractor.policy_net.6.bias: torch.Size([128])
mlp_extractor.value_net.0.weight: torch.Size([32, 27])
mlp_extractor.value_net.0.bias: torch.Size([32])
mlp_extractor.value_net.2.weight: torch.Size([32, 32])
mlp_extractor.value_net.2.bias: torch.Size([32])
action_net.weight: torch.Size([1, 128])
action_net.bias: torch.Size([1])
value_net.weight: torch.Size([1, 32])
value_net.bias: torch.Size([1])


In [12]:
# Show the value stored under the 'action_net.bias' key of the 'policy'
# parameter group.
# NOTE(review): this cell and the shape listing carry execution counts 11-12,
# earlier than the SAC.load cell (23), so the recorded outputs were presumably
# produced by a different `model`; verify that the key exists after a fresh
# Restart & Run All.
print(model.get_parameters()['policy']['action_net.bias'])

tensor([0.0088])
