# Wind Impact on Cardinal Performance

How robust is the RL architecture trained on only cardinal wind to all wind?
Many of these functions are the same as in the `PID Wind Impact` notebook.

In [1]:
from systems.long_multirotor import LongTrajEnv

from typing import Union, Iterable, List
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm.autonotebook import tqdm, trange
import optuna

from rl import learn_rl, transform_rl_policy, evaluate_rl, PPO, load_agent
from multirotor.simulation import Multirotor
from multirotor.helpers import DataLog
from multirotor.visualize import plot_datalog
from multirotor.controller import Controller
from multirotor.trajectories import Trajectory, GuidedTrajectory
from multirotor.controller.scurves import SCurveController
from multirotor.coords import body_to_inertial, inertial_to_body, direction_cosine_matrix, euler_to_angular_rate
from systems.multirotor import MultirotorTrajEnv, VP
from multirotor.controller import (
    AltController, AltRateController,
    PosController, AttController,
    VelController, RateController,
    Controller
)
from scripts.opt_pidcontroller import (
    get_controller, make_disturbance_fn,
    apply_params as apply_params_pid, get_study as get_study_pid
)
from scripts.opt_multirotorenv import apply_params, get_study, get_established_controller

In [2]:
def get_env(wind_ranges, scurve=False, **kwargs):  
    kw = dict(
        safety_radius=5,
        vp=VP,get_controller_fn=lambda m: get_established_controller(m),
        steps_u=kwargs['steps_u'],
        scaling_factor=kwargs['scaling_factor'],
        wind_ranges=wind_ranges,
        proximity=5, 
        seed=0)
    return MultirotorTrajEnv(**kw)

In [3]:
log_root_path = './tensorboard/MultirotorTrajEnv/optstudy/%s/'
def get_study_agent_params(name):
    study = get_study(name)
    best_trial = study.best_trial.number
    best_agent = load_agent((log_root_path + '%03d/run_1/agent') % (name, best_trial)) # 63 is cardinal low, best_trial is cardinal high
    best_params = study.best_params
    return study, best_agent, best_params

In [4]:
study, best_agent, best_params = get_study_agent_params('cardinal@vel')

[I 2023-08-24 18:43:40,646] Using an existing study with name 'cardinal@vel' instead of creating a new one.


In [5]:
best_params

{'bounding_rect_length': 5,
 'steps_u': 31,
 'scaling_factor': 2.0,
 'learning_rate': 0.00015208522535908772,
 'n_epochs': 3,
 'n_steps': 304,
 'batch_size': 160,
 'training_interactions': 200000}

In [6]:
square_np = np.array([[100,0,0], [100,100,0], [0,100,0], [0,0,0]])
square_traj = Trajectory(None, points=square_np, resolution=best_params['bounding_rect_length'])
square_wpts = square_traj.generate_trajectory(curr_pos=np.array([0,0,0]))

In [7]:
def get_tte(initial_pos: tuple, waypoints: np.ndarray, x: np.ndarray, y:np.ndarray, z:np.ndarray) -> np.ndarray:
        """
        Calculates the trajectory tracking error. 
        The distance between the current point and the vector between previous and next wp. Uses ||v1 x v2|| / ||v1||.

        Parameters
        ----------
        initial_pos : tuple  
            the initial position of the UAV.
        waypoints : np.ndarray 
            the reference positions at each point in time.
        x : np.ndarray 
            the x positions of the UAV.
        y : np.ndarray 
            the y positions of the UAV.
        z : np.ndarray
            the z positions of the UAV.

        Returns
        -------
        np.ndarray 
            the trajectory tracking error at each point in time.
        """
        ttes = []
        prev = initial_pos
        for i, waypoint in enumerate(waypoints):
            if i > 0 and not np.array_equal(waypoints[i-1], waypoints[i]):
                prev = waypoints[i-1]

            v1 = waypoint - prev
            v2 = np.array([x[i],y[i],z[i]]) - prev
            tte = np.linalg.norm(np.cross(v1, v2)) / np.linalg.norm(v1)
            ttes.append(tte)
                
        return np.array(ttes)

In [8]:
def toc(tte: np.ndarray):
    corridor = 5
    return len(tte[tte > corridor]) / 100

In [9]:
def completed_mission(waypoints: np.ndarray, x: np.ndarray, y: np.ndarray, z: np.ndarray, radius: float = 0.65):
        for waypoint in waypoints:
            reached_waypoint = False

            for position in zip(x,y,z):
                dist = np.linalg.norm(waypoint - position)

                if dist <= radius:
                    reached_waypoint = True
                    break

            if not reached_waypoint:
                return False
            
        return True

In [10]:
def run_trajectory(wind_ranges: np.ndarray, agent, params):
    env = LongTrajEnv(
        waypoints = square_wpts,
        base_env = get_env(wind_ranges, **params),
        initial_waypoints = square_np,
        random_cardinal_wind=False
    )
    done = False
    state = env.reset()
    log = DataLog(env.base_env.vehicle, env.base_env.ctrl,
                      other_vars=('reward',))
    while not done:
        action = agent.predict(state, deterministic=True)[0]
        state, reward, done, info = env.step(action)
        log.log(reward=reward)

    log.done_logging()
    return log, info

In [11]:
wind_range_dict = {
    'zero': [(0,0), (0,0), (0,0)],
    'n5': [(0,0), (5,5), (0,0)],
    'n7': [(0,0), (7,7), (0,0)],
    'n10': [(0,0), (10,10), (0,0)],
    's5': [(0,0), (-5,-5), (0,0)],
    's7': [(0,0), (-7,-7), (0,0)],
    's10': [(0,0), (-10,-10), (0,0)],
    'e5': [(5,5), (0,0), (0,0)],
    'e7': [(7,7), (0,0), (0,0)],
    'e10': [(10,10), (0,0), (0,0)],
    'w5': [(-5,-5), (0,0), (0,0)],
    'w7': [(-7,-7), (0,0), (0,0)],
    'w10': [(-10,-10), (0,0), (0,0)],
    'nw5': [(-3.53553391,-3.53553391), (3.53553391,3.53553391), (0,0)],
    'nw7': [(-4.94974747,-4.94974747), (4.94974747,4.94974747), (0,0)],
    'nw10': [(-7.07106781,-7.07106781), (7.07106781,7.07106781), (0,0)],
    'sw5': [(-3.53553391,-3.53553391), (-3.53553391,-3.53553391), (0,0)],
    'sw7': [(-4.94974747,-4.94974747), (-4.94974747,-4.94974747), (0,0)],
    'sw10': [(-7.07106781,-7.07106781), (-7.07106781,-7.07106781), (0,0)],
    'ne5': [(3.53553391,3.53553391), (3.53553391,3.53553391), (0,0)],
    'ne7': [(4.94974747,4.94974747), (4.94974747,4.94974747), (0,0)],
    'ne10': [(7.07106781,7.07106781), (7.07106781,7.07106781), (0,0)],
    'se5': [(3.53553391,3.53553391), (-3.53553391,-3.53553391), (0,0)],
    'se7': [(4.94974747,4.94974747), (-4.94974747,-4.94974747), (0,0)],
    'se10': [(7.07106781,7.07106781), (-7.07106781,-7.07106781), (0,0)],
}

In [12]:
wind_results = pd.DataFrame(columns=['Wind', 'Total TTE', 'Mean TTE', 'Completed Mission', 'Reward', 'Time Outside Corridor'])

In [13]:
def run_wind_sweep(results, wind_dict, agent, params):
    for wind in tqdm(wind_dict.keys()):
        log, info = run_trajectory(wind_dict[wind], agent, params)
        traj_err = get_tte(np.array([0,0,0]), log.states[:,12:], log.x, log.y, log.z)
        new_result = {
            'Wind': wind,
            'Mean TTE': np.mean(traj_err),
            'Total TTE': np.sum(traj_err),
            'Completed Mission': completed_mission(square_wpts, log.x, log.y, log.z, radius=5),
            'Reward': np.sum(log.reward),
            'Time Outside Corridor': toc(traj_err)
        }
        results = pd.concat([results, pd.DataFrame([new_result])], ignore_index=True)

    return results

In [14]:
wind_results = run_wind_sweep(wind_results, wind_range_dict, best_agent, best_params)

  0%|          | 0/25 [00:00<?, ?it/s]

In [17]:
wind_results.to_csv('./data/cardinal_wind_vel.csv')

In [18]:
wind_results = pd.read_csv('./data/cardinal_wind_vel.csv') # or read cardinal_wind.csv for the Cardinal High agent

In [19]:
wind_results

Unnamed: 0.1,Unnamed: 0,Wind,Total TTE,Mean TTE,Completed Mission,Reward,Time Outside Corridor
0,0,zero,687.037243,1.36047,True,6000.0,0.0
1,1,n5,577.63363,1.176443,True,6000.0,0.0
2,2,n7,712.468336,1.40804,True,6000.0,0.0
3,3,n10,2431.602736,3.706711,True,2718.0,1.06
4,4,s5,576.076136,1.16379,True,6000.0,0.0
5,5,s7,674.410028,1.324971,True,6000.0,0.0
6,6,s10,1253.122143,2.225794,True,4922.0,0.35
7,7,e5,790.942483,1.529869,True,6000.0,0.0
8,8,e7,949.796675,1.778645,True,6000.0,0.0
9,9,e10,3584.229848,4.823997,True,2381.0,1.17
