In [1]:
from dataclasses import dataclass
import casadi as ca
import numpy as np
import time
import math
from typing import Dict, Optional
import logging
import datetime
import timeit
from tqdm import tqdm

from ocean_navigation_simulator.data_sources import OceanCurrentSource
from ocean_navigation_simulator.data_sources.SolarIrradiance.SolarIrradianceSource import SolarIrradianceSource
from ocean_navigation_simulator.data_sources.SeaweedGrowth.SeaweedGrowthSource import SeaweedGrowthSource
from ocean_navigation_simulator.utils import units
from ocean_navigation_simulator.environment.PlatformState import PlatformState
from ocean_navigation_simulator.environment.PlatformState import SpatialPoint
from ocean_navigation_simulator.environment.ArenaFactory import ArenaFactory
from ocean_navigation_simulator.environment.NavigationProblem import NavigationProblem
from ocean_navigation_simulator.controllers.hj_planners.HJReach2DPlanner import HJReach2DPlanner
from ocean_navigation_simulator.utils import units
import matplotlib.pyplot as plt
import os
os.chdir('/home/nicolas/documents/Master_Thesis_repo/OceanPlatformControl')
print(os.getcwd())

/home/nicolas/documents/Master_Thesis_repo/OceanPlatformControl


In [2]:
from ocean_navigation_simulator.data_sources.OceanCurrentField import OceanCurrentField
from ocean_navigation_simulator.data_sources.SeaweedGrowthField import SeaweedGrowthField
from ocean_navigation_simulator.data_sources.SolarIrradianceField import SolarIrradianceField
from typing import Dict, Optional, Union, Tuple, List, AnyStr, Literal, Callable
import yaml

scenario_name = 'gulf_of_mexico_HYCOM_hindcast_local'
with open(f'config/arena/{scenario_name}.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
casadi_cache_dict=config['casadi_cache_dict']
platform_dict=config['platform_dict']
ocean_dict= config['ocean_dict']
use_geographic_coordinate_system=config['use_geographic_coordinate_system']
spatial_boundary=config['spatial_boundary']

ocean_field = OceanCurrentField(
    casadi_cache_dict=casadi_cache_dict,
    hindcast_source_dict=ocean_dict['hindcast'],
    forecast_source_dict=ocean_dict['forecast'],
    use_geographic_coordinate_system=use_geographic_coordinate_system)

ocean_source = ocean_field.hindcast_data_source

INFO:arena.ocean_field:DataField: Create Hindcast Source (10.8s)
INFO:arena.ocean_field:DataField: Forecast is the same as Hindcast for OceanCurrents.


In [38]:
x_0 = PlatformState(lon=units.Distance(deg=-83.25), lat=units.Distance(deg=23.25),
                    date_time=datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc)) 
                    # define a platform approx. in the middle for interpolation and init
# Step 1: Create the intervals to query data for
t_interval, y_interval, x_interval, = ocean_source.convert_to_x_y_time_bounds(
    x_0=x_0.to_spatio_temporal_point(),
    x_T=x_0.to_spatial_point(),
    deg_around_x0_xT_box=ocean_source.source_config_dict["casadi_cache_settings"]["deg_around_x_t"],
    temp_horizon_in_s=ocean_source.source_config_dict["casadi_cache_settings"]["time_around_x_t"],
)
# Step 2: Get the data from itocean_source and update casadi_grid_dict
xarray = ocean_source.get_data_over_area(x_interval, y_interval, t_interval)
ocean_source.casadi_grid_dict = ocean_source.get_grid_dict_from_xr(xarray)

# Step 3: Set up the grid with time lat lon order as for the xarray dataset
grid = [
    units.get_posix_time_from_np64(xarray.coords["time"].values),  
    xarray.coords["lat"].values,
    xarray.coords["lon"].values,
]


In [4]:
import jax.numpy as jnp
import jax.scipy.ndimage
from jax import jit
from scipy import interpolate

In [39]:
idx_point = jnp.array(# time lat lon order as for the xarray dataset
    [
        jnp.interp(x_0.date_time.timestamp(), grid[0], np.arange(len(grid[0]))),
        jnp.interp(x_0.lat.deg, grid[1], np.arange(len(grid[1]))),
        jnp.interp(x_0.lon.deg, grid[2], np.arange(len(grid[2]))),
    ]
).reshape(-1, 1)
print(idx_point)
print(grid[0])
print(grid[1])

[[ 1.  ]
 [13.25]
 [13.75]]
[1.6377516e+09 1.6377552e+09 1.6377588e+09 1.6377624e+09 1.6377660e+09
 1.6377696e+09 1.6377732e+09 1.6377768e+09 1.6377804e+09 1.6377840e+09
 1.6377876e+09 1.6377912e+09 1.6377948e+09 1.6377984e+09 1.6378020e+09
 1.6378056e+09 1.6378092e+09 1.6378128e+09 1.6378164e+09 1.6378200e+09
 1.6378236e+09 1.6378272e+09 1.6378308e+09 1.6378344e+09 1.6378380e+09
 1.6378416e+09 1.6378452e+09 1.6378488e+09 1.6378524e+09 1.6378560e+09
 1.6378596e+09 1.6378632e+09 1.6378668e+09 1.6378704e+09 1.6378740e+09
 1.6378776e+09 1.6378812e+09 1.6378848e+09 1.6378884e+09 1.6378920e+09
 1.6378956e+09 1.6378992e+09 1.6379028e+09 1.6379064e+09 1.6379100e+09
 1.6379136e+09 1.6379172e+09 1.6379208e+09 1.6379244e+09 1.6379280e+09
 1.6379316e+09]
[22.71999931 22.76000023 22.79999924 22.84000015 22.87999916 22.92000008
 22.95999908 23.         23.04000092 23.07999992 23.12000084 23.15999985
 23.20000076 23.23999977 23.28000069 23.31999969 23.36000061 23.39999962
 23.44000053 23.47999954 23

In [45]:
print(xarray["water_u"].values.shape)
print(xarray)
print(idx_point.shape)
u_curr_interp = jax.scipy.ndimage.map_coordinates(xarray["water_u"].values, idx_point, order=1)
print(u_curr_interp)

(51, 28, 28)
<xarray.Dataset>
Dimensions:  (time: 51, lat: 28, lon: 28)
Coordinates:
  * time     (time) datetime64[ns] 2021-11-24T11:00:00 ... 2021-11-26T13:00:00
    depth    float64 4.0
  * lat      (lat) float64 22.72 22.76 22.8 22.84 ... 23.68 23.72 23.76 23.8
  * lon      (lon) float64 -83.8 -83.76 -83.72 -83.68 ... -82.8 -82.76 -82.72
Data variables:
    water_u  (time, lat, lon) float32 nan nan nan nan ... 0.406 0.418 0.436
    water_v  (time, lat, lon) float32 nan nan nan nan ... 0.502 0.557 0.605
Attributes: (12/14)
    classification_level:      UNCLASSIFIED
    distribution_statement:    Approved for public release; distribution unli...
    downgrade_date:            not applicable
    classification_authority:  not applicable
    institution:               Naval Oceanographic Office
    source:                    HYCOM archive file
    ...                        ...
    Conventions:               CF-1.6 NAVO_netcdf_v1.1
    History:                   Translated to CF-1.0 C

In [86]:
water_uv = np.stack((xarray.water_u.values, xarray.water_v.values), axis=3)
print(water_uv.shape)
print(idx_point.shape)
#idx_points_stack = np.concatenate((idx_point, idx_point), axis=1)
#print(idx_points_stack.shape)
uv_curr_interp = jax.scipy.ndimage.map_coordinates(water_uv, idx_point, order=1)

(51, 28, 28, 2)
(3, 1)


ValueError: coordinates must be a sequence of length input.ndim, but 3 != 4

In [88]:
from typing import List
import dataclasses
from dataclasses import astuple
@dataclasses.dataclass
class PlatformStateSet:
    states:List[PlatformState]

    def __array__(self):
        return np.array(self.states) #rows are the number of platforms

    def __len__(self):
        return len(self.states)

    def __getitem__(self, platform_id):
        return np.array(self.states[platform_id])

In [89]:
def rand_platforms(nb_platforms):
    lon = np.random.uniform(low=-83.5, high=-83, size=(nb_platforms,))
    lat = np.random.uniform(low=23, high=23.5, size=(nb_platforms,))
    id = np.arange(start=1, stop=nb_platforms+1, step=1)
    t = np.repeat(datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc), nb_platforms)
    return lon,lat,id,t

In [90]:
def get_platform(nb_platforms):
    lon,lat,id,t = rand_platforms(nb_platforms)
    platforms_list = [PlatformState(lon=units.Distance(deg=lon[k]), lat=units.Distance(deg=lat[k]), date_time=t[k]) \
                      for k in range(nb_platforms)]
    # create the Set class for the platforms (version 1 with list of Platform states)
    platforms_set = PlatformStateSet(states=platforms_list)
    return platforms_set

In [160]:
nb_platforms = 10
platforms_set = get_platform(nb_platforms)

In [161]:
@jit
def init_lin_interpo_3D_fields(states_np, u_curr, v_curr, grid_eval):
    idx_point = jnp.array(# time lat lon order as for the xarray dataset
    [
        jnp.interp(states_np[:,2], grid_eval[0], np.arange(len(grid[0]))),
        jnp.interp(states_np[:,1], grid_eval[1], np.arange(len(grid[1]))),
        jnp.interp(states_np[:,0], grid_eval[2], np.arange(len(grid[2]))),
    ]
    )
    u_curr_interp = jax.scipy.ndimage.map_coordinates(u_curr, idx_point, order=1)
    v_curr_interp = jax.scipy.ndimage.map_coordinates(v_curr, idx_point, order=1)
    return u_curr_interp, v_curr_interp

In [162]:
from functools import partial
get_current_interp = partial(init_lin_interpo_3D_fields, u_curr=xarray.water_u.values, v_curr = xarray.water_v.values, \
                            grid_eval = grid)

In [163]:
u_curr, v_curr = init_lin_interpo_3D_fields(np.array(platforms_set), xarray.water_u.values, xarray.water_v.values, grid)

In [164]:
u_curr, v_curr = get_current_interp(np.array(platforms_set))
print(u_curr, v_curr)

[1.0595863  0.749789   1.1416535  1.1602628  1.0976499  0.9721538
 0.86925066 0.9937061  0.8094971  0.7509687 ] [ 0.1838933  -0.0660328   0.41365105  0.33383605  0.23608424  0.28022778
  0.16997062  0.27953756  0.17870289  0.17353368]


In [165]:
import timeit
# from ocean_navigation_simulator.data_sources.OceanCurrentSource import initialize
%timeit u_curr, v_curr = get_current_interp(np.array(platforms_set))

55.1 µs ± 1.95 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [166]:
ocean_source.initialize_casadi_functions(grid, xarray)

In [167]:
# print(np.array(platforms_set))
# print(np.array(platforms_set)[:,[2,1,0]].T)
lon_lat_t = np.array(platforms_set)[:,[2,1,0]].T #extract time, lat, lon with platforms as columns for casADi interpolation
%timeit u_curr = ocean_source.u_curr_func(lon_lat_t)

20.2 µs ± 2.49 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [168]:
%%timeit
u_curr = ocean_source.u_curr_func(lon_lat_t)
v_curr = ocean_source.v_curr_func(lon_lat_t)

40.8 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
