## Multi-Agent Example and Implementation Ideas

In [1]:
from dataclasses import dataclass
import casadi as ca
import numpy as np
import time
import math
from typing import Dict, Optional
import logging
import datetime
from tqdm import tqdm

from ocean_navigation_simulator.data_sources import OceanCurrentSource
from ocean_navigation_simulator.data_sources.SolarIrradiance.SolarIrradianceSource import SolarIrradianceSource
from ocean_navigation_simulator.data_sources.SeaweedGrowth.SeaweedGrowthSource import SeaweedGrowthSource
from ocean_navigation_simulator.utils import units
from ocean_navigation_simulator.environment.PlatformState import PlatformState
from ocean_navigation_simulator.environment.PlatformState import SpatialPoint
from ocean_navigation_simulator.environment.ArenaFactory import ArenaFactory
from ocean_navigation_simulator.environment.NavigationProblem import NavigationProblem
from ocean_navigation_simulator.controllers.hj_planners.HJReach2DPlanner import HJReach2DPlanner
from ocean_navigation_simulator.utils import units
import matplotlib.pyplot as plt
import os
# Work from the repository root so relative paths such as `config/arena/...` resolve.
# Set OCEAN_PLATFORM_CONTROL_ROOT to point at your own checkout; the previous
# hard-coded location is kept as the default for backward compatibility.
os.chdir(os.environ.get('OCEAN_PLATFORM_CONTROL_ROOT',
                        '/home/nicolas/documents/Master_Thesis_repo/OceanPlatformControl'))
print(os.getcwd())


/home/nicolas/documents/Master_Thesis_repo/OceanPlatformControl


+ Create three platforms as usual with the `PlatformState` class.

In [2]:
# Three platforms that all start at the same UTC time; only their positions differ.
# TODO: PlatformState could also carry an id property (e.g. id = int(nb)).
x_0_1, x_0_2, x_0_3 = (
    PlatformState(
        lon=units.Distance(deg=lon_start),
        lat=units.Distance(deg=lat_start),
        date_time=datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc),
    )
    for lon_start, lat_start in ((-82.5, 23.7), (-82.6, 23.8), (-82.4, 23.6))
)
# Single target point used as the end region for all platforms.
x_T = SpatialPoint(lon=units.Distance(deg=-80.3), lat=units.Distance(deg=24.6))

We can create a new class that holds the platform states in a list. By overriding the `__array__()` method, we can extract the states of all platforms as a numpy array. Its rows are [lon, lat, date_time, battery_charge, seaweed_mass], and it has one column per platform (the column index is the platform's id).

In [3]:
from typing import List
import dataclasses
from dataclasses import astuple
@dataclasses.dataclass
class PlatformStateSet:
    """Container for the states of several platforms, indexed by platform id.

    ``np.array(state_set)`` converts every state with ``np.array(state)`` and
    stacks the results as columns. For ``PlatformState`` (which converts to
    ``[lon, lat, date_time, battery_charge, seaweed_mass]``) the result has
    shape ``(5, n_platforms)``.
    """
    states: List["PlatformState"]

    def __array__(self, dtype=None, copy=None):
        """Return the states as an array with one column per platform.

        Accepts the ``dtype`` / ``copy`` arguments of the NumPy ``__array__``
        protocol, so ``np.array(state_set, dtype=...)`` and NumPy 2 calls work.

        Raises:
            ValueError: if ``copy is False``. A new array is always built, so a
                no-copy conversion cannot be honoured.
        """
        if copy is False:
            raise ValueError("PlatformStateSet cannot be converted to an array without copying")
        per_platform = [np.array(state) for state in self.states]
        return np.array(per_platform, dtype=dtype).T  # columns: number of platforms

    def __len__(self):
        """Number of platforms in the set."""
        return len(self.states)

    def __getitem__(self, platform_id):
        """Return the state of platform ``platform_id`` (0-based) as a 1-D numpy array."""
        return np.array(self.states[platform_id])

# Bundle the three start states; np.array() stacks them column-wise (one column per platform).
x_set = PlatformStateSet(states=[x_0_1, x_0_2, x_0_3])
np_x_set = np.array(x_set)
print("States numpy array= ", np_x_set, "of shape", np_x_set.shape)
print("number of platforms = ", len(x_set))
# Indexing is 0-based, so x_set[1] is the second platform (x_0_2).
print("get platform 1 state", x_set[1])

States numpy array=  [[-8.2500000e+01 -8.2600000e+01 -8.2400000e+01]
 [ 2.3700000e+01  2.3800000e+01  2.3600000e+01]
 [ 1.6377552e+09  1.6377552e+09  1.6377552e+09]
 [ 1.0000000e+02  1.0000000e+02  1.0000000e+02]
 [ 1.0000000e+02  1.0000000e+02  1.0000000e+02]] of shape (5, 3)
number of platforms =  3
get platform 1 state [-8.2600000e+01  2.3800000e+01  1.6377552e+09  1.0000000e+02
  1.0000000e+02]


Load a known scenario configuration to initialize `config`.

In [4]:
import yaml
scenario_name = 'gulf_of_mexico_HYCOM_hindcast_local'
# The path is relative, so it resolves against the working directory set in the first cell.
with open(f'config/arena/{scenario_name}.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Unpack the config sections that the rest of the notebook uses.
casadi_cache_dict, platform_dict, ocean_dict = (
    config['casadi_cache_dict'],
    config['platform_dict'],
    config['ocean_dict'],
)
use_geographic_coordinate_system, spatial_boundary = (
    config['use_geographic_coordinate_system'],
    config['spatial_boundary'],
)

This lets us define a `NavigationProblem` whose start state is now of type `PlatformStateSet`. Note that this notebook-local class shadows the `NavigationProblem` imported in the first cell.

In [5]:
from ocean_navigation_simulator.environment.Problem import Problem
@dataclasses.dataclass
class NavigationProblem(Problem):
    """Navigation problem whose start state is a set of platforms.

    ``start_state`` is a ``PlatformStateSet`` (one state per platform) and all
    platforms share the single ``end_region`` point.

    NOTE: this notebook-local definition shadows the ``NavigationProblem``
    imported from ``ocean_navigation_simulator.environment.NavigationProblem``.
    """
    start_state: PlatformStateSet
    end_region: SpatialPoint
    target_radius: float  # presumably in the same units as end_region (deg) — TODO confirm
    timeout: Optional[datetime.timedelta] = dataclasses.field(default=None)
    platform_dict: Optional[dict] = dataclasses.field(default=None)  # platform parameters from the scenario config
    x_range: Optional[List] = dataclasses.field(default=None)
    y_range: Optional[List] = dataclasses.field(default=None)
    extra_info: Optional[dict] = dataclasses.field(default=None)

    def plot(self, ax):
        """Placeholder: plotting is not implemented for multi-platform problems yet."""
        return None

# Build the multi-platform problem. As the last expression of the cell it is
# displayed, and its repr lists the fields in declaration order.
NavigationProblem(
    platform_dict=platform_dict,
    timeout=datetime.timedelta(days=2),
    target_radius=0.1,
    end_region=x_T,
    start_state=x_set,
)

NavigationProblem(start_state=PlatformStateSet(states=[Platform State[lon: -82.5 deg, lat: 23.7 deg, date_time: 2021-11-24 12:00:00+00:00, battery_charge: 100.0 Joule, seaweed_mass: 100.0 kg], Platform State[lon: -82.6 deg, lat: 23.8 deg, date_time: 2021-11-24 12:00:00+00:00, battery_charge: 100.0 Joule, seaweed_mass: 100.0 kg], Platform State[lon: -82.4 deg, lat: 23.6 deg, date_time: 2021-11-24 12:00:00+00:00, battery_charge: 100.0 Joule, seaweed_mass: 100.0 kg]]), end_region=[-80.300000°,24.60000°], target_radius=0.1, timeout=datetime.timedelta(days=2), platform_dict={'battery_cap_in_wh': 400.0, 'u_max_in_mps': 0.1, 'motor_efficiency': 1.0, 'solar_panel_size': 0.5, 'solar_efficiency': 0.2, 'drag_factor': 675.0, 'dt_in_s': 600.0}, x_range=None, y_range=None, extra_info=None)

### For this example, we only consider one data source, `ocean_source` (the most important one for the simulation)

In [6]:
from ocean_navigation_simulator.data_sources.OceanCurrentField import OceanCurrentField
from ocean_navigation_simulator.data_sources.SeaweedGrowthField import SeaweedGrowthField
from ocean_navigation_simulator.data_sources.SolarIrradianceField import SolarIrradianceField
from typing import Dict, Optional, Union, Tuple, List, AnyStr, Literal, Callable

# Build the ocean current field from the scenario config. For this scenario the
# log reports that the forecast is the same as the hindcast.
ocean_field = OceanCurrentField(
    use_geographic_coordinate_system=use_geographic_coordinate_system,
    casadi_cache_dict=casadi_cache_dict,
    hindcast_source_dict=ocean_dict['hindcast'],
    forecast_source_dict=ocean_dict['forecast'],
)

# Only the hindcast source is used in this example.
ocean_source = ocean_field.hindcast_data_source

INFO:arena.ocean_field:DataField: Create Hindcast Source (20.8s)
INFO:arena.ocean_field:DataField: Forecast is the same as Hindcast for OceanCurrents.


For now, we only update/initialize the CasADi cache around one of the platforms, because all three platforms start very close to each other. Later, we will rewrite `update_casadi_dynamics` so that it takes a `PlatformStateSet` argument instead of a `PlatformState`.

Interpolation works as usual; for an ocean data source, it is defined in `OceanCurrentSource.py`.

In [7]:
# Build the current interpolants around platform x_0_1 only. The platforms start close
# together, so this cache presumably covers all of them for now — TODO confirm bounds.
ocean_source.update_casadi_dynamics(x_0_1)
print(ocean_source.u_curr_func, "\n", ocean_source.v_curr_func)

u_curr:(x[3])->(f) LinearInterpolant 
 v_curr:(x[3])->(f) LinearInterpolant


We deliberately skip the planner and the arena observation. We will likely also have an `ArenaObservationSet`, and run the HJ planner iteratively over this list of observations to return a `PlatformActionSet`, which could be implemented as below
(here the actions are hard-coded for simplicity):

In [9]:
from ocean_navigation_simulator.environment.Platform import Platform, PlatformAction
@dataclasses.dataclass
class PlatformActionSet:
    """Container for the actions of several platforms, indexed by platform id.

    ``np.array(action_set)`` converts every action with ``np.array(action)`` and
    stacks the results as columns. For ``PlatformAction`` this yields rows
    ``[magnitude, direction]`` and one column per platform.
    """
    action_set: List["PlatformAction"]

    def __array__(self, dtype=None, copy=None):
        """Return the actions as an array with one column per platform.

        Accepts the ``dtype`` / ``copy`` arguments of the NumPy ``__array__``
        protocol, so ``np.array(action_set, dtype=...)`` and NumPy 2 calls work.

        Raises:
            ValueError: if ``copy is False``. A new array is always built, so a
                no-copy conversion cannot be honoured.
        """
        if copy is False:
            raise ValueError("PlatformActionSet cannot be converted to an array without copying")
        action_list = [np.array(action) for action in self.action_set]
        return np.array(action_list, dtype=dtype).T  # columns: number of platforms

    def __len__(self):
        """Number of platforms (actions) in the set, mirroring PlatformStateSet."""
        return len(self.action_set)

    def __getitem__(self, platform_id):
        """Return the action of platform ``platform_id`` (0-based) as a 1-D numpy array."""
        return np.array(self.action_set[platform_id])

# Hard-coded (magnitude, direction) action for each platform, in platform order.
# NOTE(review): the dynamics later multiply the magnitude directly by u_max.mps, so 50
# means 50 x u_max rather than 50 % — confirm the intended range of PlatformAction.magnitude.
action_x_0_1, action_x_0_2, action_x_0_3 = (
    PlatformAction(magnitude=mag, direction=heading)
    for mag, heading in ((50, 1), (25, 0.5), (75, 0))
)
action_x_set = PlatformActionSet(action_set=[action_x_0_1, action_x_0_2, action_x_0_3])
print("actions to numpy array ready for casadi first row = mag, second row = dir \n", np.array(action_x_set))

actions to numpy array ready for casadi first row = mag, second row = dir 
 [[50.  25.  75. ]
 [ 1.   0.5  0. ]]


#### The aim is to test CasADi's ability to vectorize, so we first implement the computations directly, outside of a function and without symbolic variables:

In [10]:
import casadi as ca
from ocean_navigation_simulator.utils import units

# One explicit Euler step for all platforms at once, computed numerically (no symbols)
# to check that the CasADi interpolants and arithmetic vectorize over platform columns.
nb_platforms = len(x_set) # number of platforms
dt_in_s = platform_dict['dt_in_s']
print("Get platform states as a numpy multi dim array, where each column is a platform: \n", np.array(x_set))
# Rows of np.array(x_set) are [lon, lat, time, battery, seaweed]; each row is 1-D (nb_platforms,).
lon_deg, lat_deg, t,_,_ = np.array(x_set) #extract only relevant states for our simplified problem
# The current interpolants expect rows ordered [time, lat, lon].
mat_state = ca.vertcat(t.reshape(1,nb_platforms), lat_deg.reshape(1,nb_platforms), lon_deg.reshape(1,nb_platforms)) # platforms as columns
print("platform state matrices shape", mat_state.shape) 

# Interpolation supports vectorization
u_curr = ocean_source.u_curr_func(mat_state) #needs to be 3 (time, lat,lon) x nb_platforms
v_curr = ocean_source.v_curr_func(mat_state)
print("current interpolation u = ", u_curr, "\ncurrent interpolation v = ", v_curr)

# Rows of np.array(action_x_set) are [magnitude, direction].
u_mag, u_angle = np.array(action_x_set)
# NOTE: u_max is also reused by the get_casadi_dynamics cell further down.
u_max = units.Velocity(mps=platform_dict['u_max_in_mps'])
# Velocity in m/s = actuation (magnitude scaled by u_max as-is) + ocean current.
lon_delta_meters_per_s = ca.cos(u_angle.reshape(1,nb_platforms))*u_mag.reshape(1, nb_platforms)*u_max.mps+u_curr
lat_delta_meters_per_s = ca.sin(u_angle.reshape(1,nb_platforms))*u_mag.reshape(1, nb_platforms)*u_max.mps+v_curr
# m/s -> deg/s on a spherical Earth of radius 6371000 m; longitude degrees shrink with cos(lat).
lon_delta_deg_per_s = 180 * lon_delta_meters_per_s / math.pi / 6371000 / ca.cos(math.pi * lat_deg.reshape(1,nb_platforms) / 180)
lat_delta_deg_per_s = 180 * lat_delta_meters_per_s / math.pi / 6371000

# Equations for next states using the intermediate variables from above
# lon_deg is 1-D numpy while the deltas are (1, nb_platforms) CasADi matrices; the
# printed result below is a (1, nb_platforms) array.
lon_next = lon_deg + dt_in_s * lon_delta_deg_per_s
lat_next = lat_deg + dt_in_s * lat_delta_deg_per_s
print("next lon position for the platforms =", lon_next, "\nnext lat position for the platforms = ", lat_next)

Get platform states as a numpy multi dim array, where each column is a platform: 
 [[-8.2500000e+01 -8.2600000e+01 -8.2400000e+01]
 [ 2.3700000e+01  2.3800000e+01  2.3600000e+01]
 [ 1.6377552e+09  1.6377552e+09  1.6377552e+09]
 [ 1.0000000e+02  1.0000000e+02  1.0000000e+02]
 [ 1.0000000e+02  1.0000000e+02  1.0000000e+02]]
platform state matrices shape (3, 3)
current interpolation u =  [[0.87875, 0.598949, 0.669905]] 
current interpolation v =  [[1.0565, 0.883948, 0.861923]]
next lon position for the platforms = [[-82.47890179 -82.58352898 -82.35189216]] 
next lat position for the platforms =  [[23.72840339 23.81123709 23.60465088]]


The interpolation function also returns multiple current values if we provide multiple platform locations!

+ Now we are ready to implement the platform dynamics update, which applies each platform's control action, with a CasADi `Function` and symbolic variables:

In [11]:
def get_casadi_dynamics(nb_platforms, ocean_source, u_max):
    """Build a CasADi function that advances all platforms by one time step.

    The returned ``ca.Function`` named ``'F_x_next'`` is called as
    ``F_x_next(states, controls, dt) -> next_states``, where

    * ``states`` is a ``5 x nb_platforms`` matrix with rows
      ``[lon, lat, time, battery, seaweed]`` (the layout of ``np.array(PlatformStateSet)``),
    * ``controls`` is a ``2 x nb_platforms`` matrix with rows ``[thrust, angle]``
      (the layout of ``np.array(PlatformActionSet)``),
    * ``dt`` is a scalar time step in seconds,
    * ``next_states`` is a ``3 x nb_platforms`` matrix with rows ``[lon, lat, time]``.

    The battery and seaweed rows are accepted but not propagated.

    Args:
        nb_platforms: Number of platforms (columns). It is fixed when the function is
            built, so the function must be rebuilt if the number of platforms changes.
        ocean_source: Data source whose ``u_curr_func`` / ``v_curr_func`` interpolants
            take a ``3 x N`` matrix with rows ``[time, lat, lon]``.
        u_max: Maximum actuation speed; only ``u_max.mps`` is read.

    Returns:
        The CasADi function ``F_x_next``.
    """
    state_row_dict = {"lon": 0, "lat": 1, "time": 2, "battery": 3, "seaweed": 4}
    action_row_dict = {"thrust": 0, "angle": 1}

    # ca.Function inputs must be purely symbolic, so whole matrices are declared as
    # inputs and the individual rows are sliced out of them below.
    sym_dt = ca.MX.sym('dt')  # in s
    sym_inputs_states = ca.MX.sym('inputs_states', 5, nb_platforms)
    sym_inputs_ctrl = ca.MX.sym('inputs_ctrl', 2, nb_platforms)

    # The m -> deg conversion below assumes geographic coordinates (degrees).
    sym_lon_degree = sym_inputs_states[state_row_dict["lon"], :]
    sym_lat_degree = sym_inputs_states[state_row_dict["lat"], :]
    sym_time = sym_inputs_states[state_row_dict["time"], :]       # in posix
    # Rows state_row_dict["battery"] / ["seaweed"] are part of the input layout but unused here.
    sym_u_thrust = sym_inputs_ctrl[action_row_dict["thrust"], :]  # multiplied by u_max.mps as-is
    sym_u_angle = sym_inputs_ctrl[action_row_dict["angle"], :]    # in radians

    # Currents at every platform; the interpolants expect rows ordered [time, lat, lon].
    sym_interp_query = ca.vertcat(sym_time, sym_lat_degree, sym_lon_degree)
    u_curr = ocean_source.u_curr_func(sym_interp_query)
    v_curr = ocean_source.v_curr_func(sym_interp_query)

    # Velocity in m/s = actuation + ocean current.
    sym_lon_delta_meters_per_s = ca.cos(sym_u_angle) * sym_u_thrust * u_max.mps + u_curr
    sym_lat_delta_meters_per_s = ca.sin(sym_u_angle) * sym_u_thrust * u_max.mps + v_curr
    # m/s -> deg/s on a spherical Earth of radius 6371000 m; longitude degrees shrink with cos(lat).
    sym_lon_delta_deg_per_s = 180 * sym_lon_delta_meters_per_s / math.pi / 6371000 / ca.cos(math.pi * sym_lat_degree / 180)
    sym_lat_delta_deg_per_s = 180 * sym_lat_delta_meters_per_s / math.pi / 6371000

    # Explicit Euler step for position; time simply advances by dt.
    sym_lon_next = sym_lon_degree + sym_dt * sym_lon_delta_deg_per_s
    sym_lat_next = sym_lat_degree + sym_dt * sym_lat_delta_deg_per_s
    sym_time_next = sym_time + sym_dt

    return ca.Function('F_x_next', [sym_inputs_states, sym_inputs_ctrl, sym_dt],
                       [ca.vertcat(sym_lon_next, sym_lat_next, sym_time_next)])
# Build the symbolic dynamics for the current number of platforms (u_max is defined in the numeric-check cell).
F_x_next = get_casadi_dynamics(nb_platforms, ocean_source, u_max)
# Evaluate one step for all platforms at once; lon/lat should match lon_next / lat_next from the numeric check.
next_state_np = np.array(F_x_next(np.array(x_set), np.array(action_x_set), dt_in_s))
print("Next state: rows [lon, lat, time] and columns [platforms] \n", next_state_np) # row lon, lat, time

Next state: rows [lon, lat, time] and columns [platforms] 
 [[-8.24789018e+01 -8.25835290e+01 -8.23518922e+01]
 [ 2.37284034e+01  2.38112371e+01  2.36046509e+01]
 [ 1.63775580e+09  1.63775580e+09  1.63775580e+09]]


Here `ca.vertcat(sym_lon_degree, sym_lat_degree, sym_time, sym_battery, sym_seaweed_mass)` cannot be used as a function input when `sym_lon_degree` etc. are symbolic matrices: CasADi throws an error that the input argument is not purely symbolic. A workaround is to define a symbolic matrix `sym_inputs_states` and then split it into lon, lat, time, etc. (and the same for the control). This way we can still pass the state as a numpy array through `np.array(x_set)`, where `x_set` is of type `PlatformStateSet`.