#### Notebook setup

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import notebook_setup

import matplotlib.pyplot as plt
import gym
import numpy as np
from numpy.linalg import norm
from tqdm.auto import tqdm, trange

from systems.trajplanning import TrajEnv
from systems.multirotor import MultirotorEnv
from systems.plotting import plot_env_response
from systems.multirotor import Multirotor, VP, SP, get_controller
from multirotor.helpers import DataLog
from multirotor.visualize import plot_datalog
from multirotor.coords import direction_cosine_matrix, inertial_to_body
from multirotor.env import SpeedsMultirotorEnv as LocalOctorotor
from multirotor.trajectories import Trajectory, GuidedTrajectory
from rl import learn_rl, transform_rl_policy
from xform import policy_transform, ab_xform_from_pseudo_matrix, pseudo_matrix_from_data

## Trajectory Adaptation

In [None]:
def plot_disturbance(fn, ax, xlims, ylims, n=10):
    """Draw a quiver plot of a 2D disturbance field sampled on an n-by-n grid.

    Parameters
    ----------
    fn : callable
        Called as ``fn(0, (x, y))`` for every grid point. The first two
        entries of its return value are used as the x and y components.
    ax : object with a ``quiver`` method, or None
        Axes to draw on. When None, ``plt.quiver`` is used instead.
    xlims, ylims : sequence
        Unpacked into ``np.linspace`` to define the sample points of each axis.
    n : int
        Number of sample points along each axis.

    Notes
    -----
    Every vector is multiplied by ``magnitude / largest magnitude`` before
    plotting, so arrows of weaker disturbances are shrunk further relative to
    the strongest one. A field that is zero everywhere is plotted as-is.
    """
    xs = np.linspace(*xlims, num=n)
    ys = np.linspace(*ylims, num=n)
    distx = np.zeros((n, n))
    disty = np.zeros_like(distx)
    # Rows index y and columns index x, matching the default np.meshgrid layout.
    for xi, x in enumerate(xs):
        for yi, y in enumerate(ys):
            vec = fn(0, (x, y))
            distx[yi, xi] = vec[0]
            disty[yi, xi] = vec[1]
    xx, yy = np.meshgrid(xs, ys)
    distmag = np.sqrt(distx**2 + disty**2)
    largest = np.max(distmag)
    # Guard the relative scaling: with an all-zero field `largest` is 0 and
    # the division would turn every component into NaN.
    if largest > 0:
        distx = distx * distmag / largest
        disty = disty * distmag / largest
    target = plt if ax is None else ax
    target.quiver(xx, yy, distx, disty, angles='xy')

In [None]:
def plot_planning(env, agent, pos, title=''):
    """Roll out `agent` in `env` starting from `pos` and plot the episode.

    The left panel shows the travelled path, the straight line between the
    origin and ``env._start_pos`` and the disturbance field. The right panel
    shows the velocities on the left axis and the actions and positions on a
    twin axis.

    Parameters
    ----------
    env : object
        Environment providing ``reset``, ``step``, ``t``, ``dt``,
        ``_start_pos`` and ``disturbance``.
    agent : object
        Provides ``predict(obs, deterministic=True)``; the first element of
        its return value is used as the action.
    pos : sequence of two numbers
        Start position. It is negated before being passed to ``env.reset``,
        and the first two observation entries are negated again for plotting.
    title : str
        Title of the path panel.
    """
    x = env.reset([-pos[0], -pos[1], 0, 0])
    shortest_path = np.asarray([[0, 0], env._start_pos]).T
    positions, velocities, rewards, actions = [], [], [], []
    done = False
    # Run until the environment reports completion or env.t reaches 20.
    while not done and env.t < 20:
        action = agent.predict(x, deterministic=True)[0]
        x, r, done, *_ = env.step(action)
        actions.append(action)
        positions.append(-x[:2])
        velocities.append(x[2:4])
        rewards.append(r)
    # Transpose so that row 0 holds the x series and row 1 the y series.
    actions = np.asarray(actions).T
    positions = np.asarray(positions).T
    velocities = np.asarray(velocities).T
    rewards = np.asarray(rewards)

    plt.figure(figsize=(8, 4))

    # Left panel: path in the plane with the disturbance field.
    plt.subplot(1, 2, 1)
    plt.plot(positions[0], positions[1])
    plt.plot(shortest_path[0], shortest_path[1], ls='--', c='k')
    plot_disturbance(env.disturbance, None, (pos[0], 0), (pos[1], 0))
    plt.axis('equal')
    plt.xlabel('x /m', c='r')
    plt.ylabel('y /m', c='g')
    plt.title(title)

    # Right panel: time series, one sample every env.dt.
    plt.subplot(1, 2, 2)
    t = np.arange(velocities.shape[1]) * env.dt
    lx, = plt.plot(t, velocities[0], label='vel-x', c='r')
    ly, = plt.plot(t, velocities[1], label='vel-y', c='g')
    plt.ylabel('Velocity / m/s')
    plt.xlabel('Time /s')
    plt.title('Reward: %.2f' % sum(rewards))
    plt.twinx()
    # Raw strings keep the backslash of the mathtext command without relying
    # on Python passing an unrecognised escape sequence through.
    plt.plot(t, actions[0], label=r'$\Delta$x', ls=':', c='r')
    plt.plot(t, actions[1], label=r'$\Delta$y', ls=':', c='g')
    plt.plot(t, positions[0], label='x', ls='-.', c='r')
    plt.plot(t, positions[1], label='y', ls='-.', c='g')
    plt.ylabel('Waypoint / m')
    # One legend for both y axes: velocity lines first, then every line on
    # the twin axis. Unpacking works whatever container type `lines` is.
    plt.legend(handles=[lx, ly, *plt.gca().lines])
    plt.xlabel('Time / s')
    plt.tight_layout()

In [None]:
# Nominal case
class DummyAgent:
    """Stand-in agent whose action is always a zero vector of length two."""

    def predict(self, u, *args, **kwargs):
        """Ignore all inputs and return ``(zero action, None)``."""
        null_action = np.zeros(2, dtype=np.float32)
        return null_action, None

def disturbance(t, x):
    """Return the constant float32 vector [1, 0]; `t` and `x` are ignored."""
    return np.array([1, 0], dtype=np.float32)

In [None]:
# Train the nominal agent in an environment built with `disturbance`.
agent = learn_rl(
    TrajEnv(disturbance),
    steps=100_000,
    n_steps=1200,
    gamma=0.99,
    batch_size=300,
    tensorboard_log='TrajEnv/ConstantXWind',
)

In [None]:
# Adapt to new behavior
def new_disturbance(t, pos):
    """Return the constant float32 vector [-1, 0]; `t` and `pos` are ignored."""
    return np.array([-1, 0.], dtype=np.float32)

# n*: norm of each tracked matrix from the previous iteration.
# d*: ratio of the current norm to that previous norm.
# Starting every n* at 1 makes the first ratio equal to the first norm itself.
nA_s, nB_s, nF_A, nF_B, nxxform, nuxform = 1,1,1,1,1,1
dA_s, dB_s, dF_A, dF_B, dxxform, duxform = 1,1,1,1,1,1

# Re-estimate the transforms with an increasing `steps` value (100, 300, ...,
# 3900) until the norms of all tracked matrices settle between iterations.
# `steps` is presumably the amount of data collected - TODO confirm.
for steps in trange(100, 4000, 200, leave=False):
    # Same agent in both cases; P_s uses the original disturbance, P_t the new one.
    P_s = pseudo_matrix_from_data(TrajEnv(disturbance), steps, agent, 'steps')
    P_t = pseudo_matrix_from_data(TrajEnv(new_disturbance), steps, agent, 'steps')
    A_s, B_s, A_t, B_t, F_A, F_B = ab_xform_from_pseudo_matrix(P_s, P_t, TrajEnv().dt)
    state_xform, action_xform = policy_transform((A_s, B_s), xformA=F_A, xformB=F_B)
    # The right-hand side is evaluated before either name is rebound, so each
    # ratio divides the new norm by the norm stored in the previous iteration.
    dA_s, nA_s = norm(A_s) / nA_s, norm(A_s)
    dB_s, nB_s = norm(B_s) / nB_s, norm(B_s)
    dF_A, nF_A = norm(F_A) / nF_A, norm(F_A)
    dF_B, nF_B = norm(F_B) / nF_B, norm(F_B)
    dxxform, nxxform = norm(state_xform) / nxxform, norm(state_xform)
    duxform, nuxform = norm(action_xform) / nuxform, norm(action_xform)
    print('A_s %5.2f, B_s %5.2f, F_A %5.2f, F_B %5.2f, Xxform %5.2f, Uxform %5.2f' \
          % (dA_s, dB_s, dF_A, dF_B, dxxform, duxform))
    # Stop as soon as every ratio is within 0.1 of 1.
    if all([abs(1-d) <= 0.1 for d in (dA_s, dB_s, dF_A, dF_B, dxxform, duxform)]):
        break

# Build the adapted agent from the transforms of the last iteration that ran.
agent_new = transform_rl_policy(agent, state_xform, action_xform)

In [None]:
# Print the norm of each transform and target-system matrix, one per line.
for label, matrix in (
    ('state_xform', state_xform),
    ('action_xform', action_xform),
    ('F_A', F_A),
    ('F_B', F_B),
    ('A_t', A_t),
    ('B_t', B_t),
):
    print(label, np.linalg.norm(matrix))

In [None]:
# Position (x, y) handed to plot_planning in the cells below.
pos = (5, 5)

In [None]:
# Agent tuned on original disturbance
%matplotlib inline
# Baseline: the environment uses the same disturbance function the agent was trained with.
plot_planning(TrajEnv(disturbance), agent, pos, 'Optimized on eastwards wind')

In [None]:
# Agent applied to new disturbance
%matplotlib inline
# Same unadapted agent, but the environment is now built with new_disturbance.
plot_planning(TrajEnv(new_disturbance), agent, pos, 'Applied to west wind')

In [None]:
# Adapted agent on new disturbance
%matplotlib inline
# agent_new is the policy returned by transform_rl_policy; evaluate it under new_disturbance.
plot_planning(TrajEnv(new_disturbance), agent_new, pos, 'Adapted on west wind')

## Simulation

In [None]:
def wind(i, m):
    """Return a sinusoidal wind vector rotated into the body frame, and 0.

    The inertial-frame vector has an x component of amplitude 20 with a
    period of 1000 steps, a y component of amplitude 10 with a period of 500
    steps, and a zero z component. `i` is the step index and `m` must expose
    `orientation`, which is unpacked into `direction_cosine_matrix`.
    """
    gust_x = 20 * np.sin(i * 2 * np.pi / 1000)
    gust_y = 10 * np.sin(i * 2 * np.pi / 500)
    w_inertial = np.asarray([gust_x, gust_y, 0])
    rotation = direction_cosine_matrix(*m.orientation)
    w_body = inertial_to_body(w_inertial, rotation)
    return w_body, 0
def motor_failure(i, m, motors=(2,)):
    """Set the torque constant of the selected motors to 0.005 after step 1000.

    For every index in `motors`, `m.propellers[index].motor.params.k_torque`
    is overwritten once `i` exceeds 1000; earlier calls leave `m` untouched.
    Always returns `(0, 0)`.
    """
    no_disturbance = (0, 0)
    if not i > 1000:
        return no_disturbance
    # Earlier variants of this fault pinned the motor/propeller speed to 400
    # instead of changing the torque constant.
    for idx in motors:
        m.propellers[idx].motor.params.k_torque = 0.005
    return no_disturbance
def battery_degrade(i, m):
    """Halve the battery's maximum voltage and set its voltage to half of that.

    The halving is applied on every call, so repeated calls keep shrinking
    `m.battery.params.max_voltage`. `i` is unused. Always returns `(0, 0)`.
    """
    battery = m.battery
    battery.params.max_voltage /= 2
    battery.voltage = battery.params.max_voltage / 2
    return 0, 0

In [None]:
def run_sim(env, traj, steps=60_000, disturbance=None):
    """Drive `env.vehicle` along `traj` with a controller and record a DataLog.

    Parameters
    ----------
    env : object
        Environment exposing `vehicle` and
        `step(action, disturb_forces=..., disturb_torques=...)`.
    traj : iterable
        Yields `(pos, feed_forward_vel)` pairs, one per simulation step.
    steps : int
        Maximum number of simulated steps; also the progress-bar total.
    disturbance : callable or None
        When given, called as `disturbance(i, env.vehicle)` on every step and
        expected to return `(force, torque)`. When None, a force of 0. and a
        torque of 0 are passed to `env.step` on every step.

    Returns
    -------
    The DataLog, after `done_logging()` has been called on it.
    """
    ctrl = get_controller(env.vehicle, max_velocity=5.)

    # Extra variables recorded in addition to whatever DataLog tracks itself.
    log = DataLog(env.vehicle, ctrl,
                  other_vars=('speeds','target', 'alloc_errs', 'att_err',
                              'rate_target', 'att_target',
                              'leash', 'currents', 'voltages'))
    # Defaults used on every step when no disturbance callback is supplied.
    disturb_force, disturb_torque = 0., 0
    for i, (pos, feed_forward_vel) in tqdm(
        enumerate(traj), leave=False, total=steps
    ):
        # The step budget is checked after the next trajectory item was already drawn.
        if i==steps: break
        # Generate reference for controller
        # (position plus a trailing 0.; presumably the yaw reference - TODO confirm)
        ref = np.asarray([*pos, 0.])
        # Get prescribed dynamics for system as thrust and torques
        dynamics = ctrl.step(ref, feed_forward_velocity=feed_forward_vel)
        thrust, torques = dynamics[0], dynamics[1:]
        # Allocate control: Convert dynamics into motor rad/s
        action = env.vehicle.allocate_control(thrust, torques)
        # get any disturbances
        if disturbance is not None:
            disturb_force, disturb_torque = disturbance(i, env.vehicle)
        # Send speeds to environment
        state, *_ = env.step(
            action, disturb_forces=disturb_force, disturb_torques=disturb_torque
        )
        # Requested [thrust, torques] minus what the allocation matrix yields
        # for the squared motor speeds.
        alloc_errs = np.asarray([thrust, *torques]) - env.vehicle.alloc @ action**2

        # Controller internals are read after the step, so they reflect this iteration.
        log.log(speeds=action, target=pos, alloc_errs=alloc_errs,
                leash=ctrl.ctrl_p.leash,
                att_err=ctrl.ctrl_a.err,
                att_target = ctrl.ctrl_v.action[::-1],
                rate_target=ctrl.ctrl_a.action,
                currents=[p.motor.current for p in env.vehicle.propellers],
                voltages=[p.motor.voltage for p in env.vehicle.propellers])

        # Abort when either of the first two orientation components exceeds
        # pi/6 in magnitude (presumably roll and pitch in radians - TODO confirm).
        if np.any(np.abs(env.vehicle.orientation[:2]) > np.pi/6): break

    log.done_logging()
    return log

In [None]:
%matplotlib inline
# Override `dt` on the shared SP parameter object before the vehicle is built
# (presumably the simulation time step - TODO confirm).
SP.dt = 1e-2
# Every coordinate of the base waypoints, including the third column, is doubled.
waypoints = 2 * np.asarray([
    [0, 50, 2],
    [25, 60, 2],
    [50, 50, 2],
    [60, 25, 2],
    [50, 0, 2],
    [25, -10, 2],
    [0, 0, 2],
])
env = LocalOctorotor(vehicle=Multirotor(VP, SP))
# traj = GuidedTrajectory(env.vehicle, waypoints, proximity=2)
traj = Trajectory(env.vehicle, waypoints, proximity=2, resolution=None)

# Simulate at most 10000 steps under the `wind` disturbance and plot the log.
log = run_sim(env, traj, steps=10_000, disturbance=wind)
plot_datalog(log)

In [None]:
# 3D Trajectory
%matplotlib notebook
# New figure holding a single 3D axes.
ax = plt.figure().add_subplot(projection='3d')
# Logged x, y, z series of the simulation.
ax.plot(log.x, log.y, log.z)
# First three columns of the logged targets, drawn as a dotted black line.
ax.plot(log.target[:,0], log.target[:,1], log.target[:,2], c='k', ls=':')
ax.set_xlabel('x /m', c='r')
ax.set_ylabel('y /m', c='g')
ax.set_zlabel('z /m', c='b')
# Disabled: would set the box aspect from the peak-to-peak range of each axis.
# ax.set_box_aspect(list(map(np.ptp, (log.x, log.y, log.z))))

### Other plots

In [None]:
# Attitude control
%matplotlib inline
# Each series is scaled by 180/pi, i.e. converted to degrees assuming the
# logged values are in radians - TODO confirm.
plt.plot(log.att_err[:,1] * 180 / np.pi, label='Roll error')
plt.plot(log.att_target[:,1] * 180 / np.pi, label='Roll target')
# NOTE(review): this plots log.pitch but labels it 'Roll'; confirm whether the
# label or the plotted quantity (and column index 1 above) is the intended one.
plt.plot(log.pitch * 180 / np.pi, label='Roll')
plt.ylabel('Attitudes', c='b')
plt.legend()
# plt.twinx()
# plt.plot(log.att_action[:, 0], c='r', label='Roll rate')
# plt.ylabel('Rate', c='r')
# plt.xlabel('Time /ms')

In [None]:
# Leashing
%matplotlib inline
# Clip the logged leash values to the range [0, 20] before plotting.
plt.plot(np.clip(log.leash, 0, 20))
plt.ylabel('Leash length /m')
# NOTE(review): no x values are passed, so the axis is the sample index; the
# 'ms' unit only holds if one logged sample corresponds to 1 ms - confirm.
plt.xlabel('Time /ms')
plt.title('Leashing position error')

In [None]:
%matplotlib inline
from multirotor.controller.pid import sqrt_control

# Evaluate sqrt_control on 100 evenly spaced errors between 0 and 10, keeping
# its remaining three arguments fixed, and plot the result against the error.
err = np.linspace(0, 10, num=100)
acc = 1
k_p = 1.
corr = [sqrt_control(e, k_p, acc, 0.001) for e in err]
plt.plot(err, corr, c='orange')
plt.ylabel('$K_P$')
plt.xlabel('Position Error /m')
plt.title('Square root proportional control')

In [None]:
# Nine rows of three values each; presumably (x, y, z) waypoint coordinates
# in metres - TODO confirm against the code that consumes `wp`.
wp = np.asarray([
    [0.0, 0.0, 30.0],
    [164.0146725649829, -0.019177722744643688, 30.0],
    [165.6418055187678, 111.5351051245816, 30.0],
    [127.3337449710234, 165.73576059611514, 30.0],
    [-187.28170707810204, 170.33217775914818, 45.0],
    [-192.03130502498243, 106.30660058604553, 45.0],
    [115.89920266153058, 100.8644210617058, 30.0],
    [114.81859536317643, 26.80923518165946, 30.0],
    [-21.459931490011513, 32.60508110653609, 30.0]
])