# Compare speed of casADi for multiple points interpolation vs. JAX libraries

In [1]:
from dataclasses import dataclass
import casadi as ca
import numpy as np
import time
import math
from typing import Dict, Optional
import logging
import datetime
import timeit
from tqdm import tqdm

from ocean_navigation_simulator.data_sources import OceanCurrentSource
from ocean_navigation_simulator.data_sources.SolarIrradiance.SolarIrradianceSource import SolarIrradianceSource
from ocean_navigation_simulator.data_sources.SeaweedGrowth.SeaweedGrowthSource import SeaweedGrowthSource
from ocean_navigation_simulator.utils import units
from ocean_navigation_simulator.environment.PlatformState import PlatformState
from ocean_navigation_simulator.environment.PlatformState import SpatialPoint
from ocean_navigation_simulator.environment.ArenaFactory import ArenaFactory
from ocean_navigation_simulator.environment.NavigationProblem import NavigationProblem
from ocean_navigation_simulator.controllers.hj_planners.HJReach2DPlanner import HJReach2DPlanner
from ocean_navigation_simulator.utils import units
import matplotlib.pyplot as plt
import os
os.chdir('/home/nicolas/documents/Master_Thesis_repo/OceanPlatformControl')
print(os.getcwd())

/home/nicolas/documents/Master_Thesis_repo/OceanPlatformControl


+ Import the ocean source

In [2]:
from ocean_navigation_simulator.data_sources.OceanCurrentField import OceanCurrentField
from ocean_navigation_simulator.data_sources.SeaweedGrowthField import SeaweedGrowthField
from ocean_navigation_simulator.data_sources.SolarIrradianceField import SolarIrradianceField
from typing import Dict, Optional, Union, Tuple, List, AnyStr, Literal, Callable
import yaml

scenario_name = 'gulf_of_mexico_HYCOM_hindcast_local'
with open(f'config/arena/{scenario_name}.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
casadi_cache_dict=config['casadi_cache_dict']
platform_dict=config['platform_dict']
ocean_dict= config['ocean_dict']
use_geographic_coordinate_system=config['use_geographic_coordinate_system']
spatial_boundary=config['spatial_boundary']

ocean_field = OceanCurrentField(
    casadi_cache_dict=casadi_cache_dict,
    hindcast_source_dict=ocean_dict['hindcast'],
    forecast_source_dict=ocean_dict['forecast'],
    use_geographic_coordinate_system=use_geographic_coordinate_system)

ocean_source = ocean_field.hindcast_data_source

INFO:arena.ocean_field:DataField: Create Hindcast Source (5.6s)
INFO:arena.ocean_field:DataField: Forecast is the same as Hindcast for OceanCurrents.


+ Create the typical implementation of PlatformStateSet

In [3]:
from typing import List
import dataclasses
from dataclasses import astuple
@dataclasses.dataclass
class PlatformStateSet:
    states:List[PlatformState]

    def __array__(self):
        return np.array(self.states) #rows are the number of platforms

    def __len__(self):
        return len(self.states)

    def __getitem__(self, platform_id):
    #     return np.array(self.states[platform_id])
        return self.states[platform_id]

    def __post_init__(self):
        self.lon_arr = np.array(self.states)[:,0]
        self.lat_arr = np.array(self.states)[:,1]
        self.date_time_arr = np.array(self.states)[:,2]
        self.battery_charge = np.array(self.states)[:,3]
        self.seaweed_mass = np.array(self.states)[:,4]

+ Functions to generate random platform states

In [4]:
def rand_platforms(nb_platforms):
    lon = np.random.uniform(low=-83.5, high=-83, size=(nb_platforms,))
    lat = np.random.uniform(low=23, high=23.5, size=(nb_platforms,))
    id = np.arange(start=1, stop=nb_platforms+1, step=1)
    t = np.repeat(datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc), nb_platforms)
    return lon,lat,id,t
    
def get_platform(nb_platforms):
    lon,lat,id,t = rand_platforms(nb_platforms)
    platforms_list = [PlatformState(lon=units.Distance(deg=lon[k]), lat=units.Distance(deg=lat[k]), date_time=t[k]) \
                      for k in range(nb_platforms)]
    # create the Set class for the platforms (version 1 with list of Platform states)
    platforms_set = PlatformStateSet(states=platforms_list)
    return platforms_set

**Rewrite** `convert_to_x_y_time_bounds` to support multiple platform states, making use of the methods implemented in PlatformStateSet

In [5]:
nb_platforms = 20
platforms_set = get_platform(nb_platforms)

In [6]:
def convert_to_x_y_time_bounds(
    states: PlatformStateSet,
    x_T: SpatialPoint,
    deg_around_x0_xT_box: float,
    temp_horizon_in_s: float,
):
    """Helper function for spatio-temporal subsetting
    Args:
        x_0: SpatioTemporalPoint
        x_T: SpatialPoint goal locations
        deg_around_x0_xT_box: buffer around the box in degree
        temp_horizon_in_s: maximum temp_horizon to look ahead of x_0 time in seconds

    Returns:
        t_interval: if time-varying: [t_0, t_T] as utc datetime objects
                    where t_0 and t_T are the start and end respectively
        lat_bnds: [y_lower, y_upper] in degrees
        lon_bnds: [x_lower, x_upper] in degrees
    """

    t_interval = [datetime.datetime.fromtimestamp(min(platforms_set.date_time_arr), tz=datetime.timezone.utc), \
                  datetime.datetime.fromtimestamp(max(platforms_set.date_time_arr), tz=datetime.timezone.utc) + datetime.timedelta(seconds=temp_horizon_in_s)]
    lon_bnds = [
        min(min(states.lon_arr), x_T.lon.deg) - deg_around_x0_xT_box,
        max(max(states.lon_arr), x_T.lon.deg) + deg_around_x0_xT_box,
    ]
    lat_bnds = [
        min(min(states.lat_arr), x_T.lat.deg) - deg_around_x0_xT_box,
        max(max(states.lat_arr), x_T.lat.deg) + deg_around_x0_xT_box,
    ]

    return t_interval, lat_bnds, lon_bnds

+ Create random platforms

In [7]:
# Step 1: Create the intervals to query data for
t_interval, y_interval, x_interval, = convert_to_x_y_time_bounds(
    states=platforms_set,
    x_T=platforms_set[0].to_spatial_point(),
    deg_around_x0_xT_box=ocean_source.source_config_dict["casadi_cache_settings"]["deg_around_x_t"],
    temp_horizon_in_s=ocean_source.source_config_dict["casadi_cache_settings"]["time_around_x_t"],
)
print(t_interval, x_interval, y_interval)
# Step 2: Get the data from itocean_source and update casadi_grid_dict
xarray = ocean_source.get_data_over_area(x_interval, y_interval, t_interval)
ocean_source.casadi_grid_dict = ocean_source.get_grid_dict_from_xr(xarray)

# Step 3: Set up the grid with time lat lon order as for the xarray dataset
grid = [
    units.get_posix_time_from_np64(xarray.coords["time"].values),  
    xarray.coords["lat"].values,
    xarray.coords["lon"].values,
]

[datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc), datetime.datetime(2021, 11, 26, 12, 0, tzinfo=datetime.timezone.utc)] [-83.96089419441918, -82.55584921681174] [22.53840774842949, 23.989400836068118]


### CasADi Interpolation

In [8]:
ocean_source.initialize_casadi_functions(grid, xarray)
lon_lat_t = np.array(platforms_set)[:,[2,1,0]].T #extract time, lat, lon with platforms as columns for casADi interpolation
u_curr_cas = ocean_source.u_curr_func(lon_lat_t)
v_curr_cas = ocean_source.v_curr_func(lon_lat_t)
print("u_curr interpolation results: ", u_curr_cas,"\nv_curr interpolation results: ", v_curr_cas)

u_curr interpolation results:  [[0.785773, 1.19858, 0.742993, 1.1934, nan, 0.73964, 1.18849, 1.18441, 1.04891, 0.639861, 1.09112, 1.05455, 1.098, 1.0263, 0.875028, 0.723886, 0.730927, 1.17588, 1.11193, 1.02551]] 
v_curr interpolation results:  [[0.293336, 0.361847, 0.0298195, 0.457341, nan, 0.0672256, 0.331299, 0.440071, 0.384896, -0.272115, 0.395701, 0.196144, 0.191956, 0.264372, 0.0488227, 0.0217272, 0.0214983, 0.312047, 0.241093, 0.0693245]]


In [9]:
import timeit

In [10]:
%%timeit
u_curr = ocean_source.u_curr_func(lon_lat_t)
v_curr = ocean_source.v_curr_func(lon_lat_t)

63.2 µs ± 8.54 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### JAX Interpolation

In [11]:
import jax.numpy as jnp
import jax.scipy.ndimage
from jax import jit
from scipy import interpolate

In [12]:
@jit
def init_lin_interpo_3D_fields(states_np, u_curr, v_curr, grid_eval):
    idx_point = jnp.array(# time lat lon order as for the xarray dataset
    [
        jnp.interp(states_np[:,2], grid_eval[0], np.arange(len(grid[0]))),
        jnp.interp(states_np[:,1], grid_eval[1], np.arange(len(grid[1]))),
        jnp.interp(states_np[:,0], grid_eval[2], np.arange(len(grid[2]))),
    ]
    )
    u_curr_interp = jax.scipy.ndimage.map_coordinates(u_curr, idx_point, order=1)
    v_curr_interp = jax.scipy.ndimage.map_coordinates(v_curr, idx_point, order=1)
    return u_curr_interp, v_curr_interp

In [13]:
from functools import partial
get_current_interp = partial(init_lin_interpo_3D_fields, u_curr=xarray.water_u.values, v_curr = xarray.water_v.values, \
                            grid_eval = grid)

In [14]:
u_curr_jax, v_curr_jax = get_current_interp(np.array(platforms_set))
print("u_curr interpolation results: ", u_curr_jax, "\nv_curr interpolation results: ", v_curr_jax)
print("diff jax - casADi\n", "u_curr", u_curr_jax-u_curr_cas.T, "v_curr", v_curr_jax-v_curr_cas.T )



u_curr interpolation results:  [0.7857713  1.1985755  0.7429931  1.1934026         nan 0.7396395
 1.1884897  1.1844103  1.0489087  0.63986397 1.0911233  1.0545527
 1.0980009  1.0263036  0.8750268  0.7238873  0.73092705 1.1758838
 1.1119236  1.025509  ] 
v_curr interpolation results:  [ 0.29333332  0.36184558  0.02981923  0.4573411          nan  0.06722391
  0.33129844  0.4400713   0.38488737 -0.27210727  0.3957012   0.19613881
  0.1919581   0.2643722   0.04882056  0.02173114  0.02149608  0.3120486
  0.24108893  0.0693258 ]
diff jax - casADi
 u_curr [-2.03377e-06, -6.52945e-07, 3.2205e-07, 9.70939e-08, nan, -7.55515e-07, -7.69353e-07, -6.57877e-07, 3.03647e-06, 2.63617e-06, 3.46762e-07, -1.23693e-06, 2.69162e-06, 2.64201e-06, -1.4551e-06, 1.40103e-06, -2.42133e-07, 6.35927e-07, -1.43355e-06, 1.74766e-06] v_curr [-2.36763e-06, -1.66946e-06, -2.56267e-07, 2.98024e-08, nan, -1.69816e-06, -8.63365e-07, 4.94964e-07, -8.96302e-06, 7.79124e-06, -6.67959e-08, -4.83039e-06, 1.99395e-06, -1.09414

In [15]:
%timeit u_curr, v_curr = get_current_interp(np.array(platforms_set))

118 µs ± 25.4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
