# Investigate how different nearest neighbor algorithms map from ungridded to the model grid

In [None]:
import pyresample
import numpy as np
from os.path import expanduser,join,isdir
import sys
user_home_dir = expanduser('~')

import ecco_v4_py as ecco
import ecco_access as ea
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import pyresample as pr

from scipy.spatial import KDTree

def latlon_ungridded_to_gridded(ds, ungridded_lat, ungridded_lon, nn=1):
    """Map ungridded points to their nearest model grid points ("naive" method).

    The search is a plain Euclidean KD-tree query in (latitude, longitude)
    degree space. It therefore ignores the longitude wrap at the dateline and
    the convergence of meridians towards the poles; it is kept this simple on
    purpose so it can be compared against grid-aware methods.

    Parameters
    ----------
    ds : dataset-like
        Object exposing ``XC`` (model longitude) and ``YC`` (model latitude),
        each with a ``.values`` array of identical shape.
    ungridded_lat, ungridded_lon : array-like
        Latitudes and longitudes of the points to map, one entry per point.
    nn : int, optional
        Number of nearest grid points to return for every input point.

    Returns
    -------
    gridded_lat, gridded_lon : numpy.ndarray
        Latitude and longitude of the selected grid points. Following
        ``KDTree.query``, the shape is ``(npoints,)`` for ``nn=1`` and
        ``(npoints, nn)`` otherwise.

    Raises
    ------
    ValueError
        If the tree could not supply ``nn`` neighbours for some point (for
        example when ``nn`` exceeds the number of grid points).
    """
    # get gridded/llc coordinates, flattened once and reused below
    yc_flat = ds.YC.values.ravel()
    xc_flat = ds.XC.values.ravel()
    n_grid = yc_flat.size

    gridded_coords = np.c_[yc_flat, xc_flat]
    ungridded_coords = np.c_[ungridded_lat, ungridded_lon]

    kd_tree = KDTree(gridded_coords)
    _, nearest_grid_idx = kd_tree.query(ungridded_coords, k=nn)

    # KDTree.query marks a missing neighbour with the index ``n`` (the number
    # of points in the tree), so the check has to be ``>=``; a strict ``>``
    # never fires and the lookup below would fail with a bare IndexError.
    if np.any(nearest_grid_idx >= n_grid):
        raise ValueError(
            f'could not find nn={nn} neighbours for every point '
            f'(the grid has {n_grid} points)'
        )

    gridded_lat = yc_flat[nearest_grid_idx]
    gridded_lon = xc_flat[nearest_grid_idx]
    return gridded_lat, gridded_lon

### Load any ECCO dataset that comes with the grid
Loading a full dataset is overkill here: we only need the model grid's longitude and latitude fields.

In [None]:
# access_mode = 's3_open_fsspec'


# # ECCO_dir specifies parent directory of all ECCOv4r4 downloads
# # ECCO_dir = None downloads to default path ~/Downloads/ECCO_V4r4_PODAAC/
# ECCO_dir = join('/efs_ecco','ECCO_V4r4_PODAAC')

# # for access_mode = 's3_open_fsspec', need to specify the root directory 
# # containing the jsons
# jsons_root_dir = join('/efs_ecco','mzz-jsons')


# ShortNames_list = ["ECCO_L4_TEMP_SALINITY_LLC0090GRID_MONTHLY_V4R4"]
# ShortNames_list = ['MZZ_LLC0090GRID_GEOMETRY']
# # retrieve files
# StartDate = '2010-01'
# EndDate = '2010-12'
# grid_ds = ea.ecco_podaac_to_xrdataset(ShortNames_list,\
# #                                 StartDate=StartDate,EndDate=EndDate,\
#                                  mode=access_mode,\
#                                  download_root_dir=ECCO_dir,\
#                                  max_avail_frac=0.5,\
#                                  jsons_root_dir=jsons_root_dir)

# Clone the [ECCO-obs-pipeline](https://github.com/ECCO-GROUP/ECCO-obs-pipeline) repo to use fancy grid-aware interpolation

In [None]:
# Put the ECCO-obs-pipeline processing utilities on the import path.
# The absolute path is hard-coded, so edit it to match wherever the repo was cloned.
sys.path.append('/home/jovyan/efs_ecco/mgoldber/ECCO-obs-pipeline/ecco_pipeline/utils/processing_utils/')
# NOTE(review): this wildcard import presumably supplies
# find_mappings_from_source_to_target, which get_interp_points calls but this
# notebook never defines -- confirm against transformation_utils.
from transformation_utils import *

- Source grid: model grid
- Target grid: a set of arbitrary lat-lon points

TO DO: remove land

In [None]:
import xarray as xr

# Folder inside the ECCO-obs-pipeline checkout that holds the grid files
# (hard-coded absolute path, ends with a slash).
grid_dir = '/home/jovyan/efs_ecco/mgoldber/ECCO-obs-pipeline/ecco_pipeline/grids/'

# Model grid used throughout the notebook (its XC, YC and
# effective_grid_radius fields are read later on).
ecco_grid = xr.open_dataset(f'{grid_dir}ECCO_llc90.nc')
#ecco_grid = xr.open_dataset('/efs_ecco/ECCO/V4/r4/ECCO_L4_GEOMETRY_LLC0090GRID_V4R4/GRID_GEOMETRY_ECCO_V4r4_native_llc0090.nc')

### Grid-aware nearest neighbors

In [None]:
def get_interp_points(ungridded_lat, ungridded_lon,
                      grid_ds, nneighbours=4,
                      source_grid_min_L= 20e3,
                      source_grid_max_L = 111e3,
                      max_target_grid_radius = int(15e4),
                     ):
    """Find the nearest model grid points for a set of ungridded points.

    Builds pyresample swath definitions for the model grid (source) and the
    ungridded points (target), runs the ECCO-obs-pipeline helper
    ``find_mappings_from_source_to_target`` (expected to come from the
    ``transformation_utils`` wildcard import), and returns the neighbour
    information computed by ``pr.kd_tree.get_neighbour_info``.

    Parameters
    ----------
    ungridded_lat, ungridded_lon : array-like
        Latitudes and longitudes of the target points.
    grid_ds : dataset-like
        Model grid exposing ``XC``, ``YC`` and ``effective_grid_radius``.
    nneighbours : int, optional
        Number of neighbours requested from the KD-tree search.
    source_grid_min_L, source_grid_max_L : float, optional
        Passed unchanged to ``find_mappings_from_source_to_target``
        (presumably length scales in metres -- TODO confirm).
    max_target_grid_radius : int, optional
        Used, cast to ``int``, as pyresample's ``radius_of_influence``.

    Returns
    -------
    tuple
        Whatever ``pr.kd_tree.get_neighbour_info`` returns; per the
        pyresample documentation this is ``(valid_input_index,
        valid_output_index, index_array, distance_array)``.

    Raises
    ------
    AttributeError
        If ``grid_ds`` has no ``effective_grid_radius`` field.
    """
    # Use the function's own arguments. The previous body read the notebook
    # globals ``ungridded_lats`` / ``ungridded_lons`` and silently ignored
    # whatever the caller passed in. ``np.asarray`` lets plain lists through.
    target_grid = pr.geometry.SwathDefinition(
        lats=np.asarray(ungridded_lat), lons=np.asarray(ungridded_lon)
    )

    source_grid = pr.geometry.SwathDefinition(
        lons=grid_ds.XC.values.ravel(), lats=grid_ds.YC.values.ravel()
    )

    # make sure we have this attribute before using it
    if not hasattr(grid_ds, 'effective_grid_radius'):
        raise AttributeError(
            "grid_ds has no 'effective_grid_radius' field, which is required "
            "to compute the source-to-target mappings"
        )
    # NOTE(review): this radius is read from the model grid, which is the
    # *source* here, yet it is passed in the helper's target-radius slot --
    # confirm that this is what the helper expects.
    target_grid_radius = grid_ds.effective_grid_radius.values.ravel()

    # The return value is currently unused; the call is kept so the helper
    # still runs (and reports, since less_output=False) exactly as before.
    factors = find_mappings_from_source_to_target(
        source_grid,
        target_grid,
        target_grid_radius,
        source_grid_min_L,
        source_grid_max_L,
        grid_name='ecco',
        less_output=False
    )

    nn_info = pr.kd_tree.get_neighbour_info(
        source_grid,
        target_grid,
        radius_of_influence=int(max_target_grid_radius),
        neighbours=nneighbours,
    )

    return nn_info

# Generate ungridded point
You can use the random point generator function below or hard-code your own points. Note that, by default, the generator only produces points at high latitudes (between 60 and 90 degrees north); see the `lat_range` argument.

In [None]:
# generate nobs random observations
def generate_random_points(nobs, lon_range=(-180, 180), lat_range=(60, 90)):
    """Draw ``nobs`` random points, uniform in degrees within the given ranges.

    Parameters
    ----------
    nobs : int
        Number of points to draw.
    lon_range, lat_range : pair of numbers, optional
        ``(low, high)`` bounds for longitude and latitude. The default
        latitude range restricts the points to 60..90.

    Returns
    -------
    lons, lats : numpy.ndarray
        Longitudes and latitudes of the points -- note the (lon, lat) order.
    """
    # Longitudes are drawn before latitudes; this order fixes how the global
    # NumPy random stream is consumed.
    west, east = lon_range[0], lon_range[1]
    lons = np.random.uniform(west, east, nobs)

    south, north = lat_range[0], lat_range[1]
    lats = np.random.uniform(south, north, nobs)

    return lons, lats

# Number of synthetic observations to draw.
nobs = 10

# The generator returns (lons, lats), in that order.
ungridded_lons, ungridded_lats = generate_random_points(nobs)
#ungridded_lons, ungridded_lats = ([-50], [89])

# get_interp_points takes latitude first, then longitude.
nn_info = get_interp_points(
    ungridded_lats, ungridded_lons, ecco_grid, nneighbours=4
)

# Pull out elements 1..3 of the neighbour info. Per the pyresample docs these
# are valid_output_index, index_array and distance_array -- confirm against
# the installed version.
nn_tf, nn_src_to_target, nn_dist = nn_info[1], nn_info[2], nn_info[3]
# nn_src_to_target: nobs by nneighbours -- nearest_grid_index

### Examine weighting methods

In [None]:
print(nn_dist)

# Inverse-distance and inverse-squared-distance weights.
# NOTE: a distance of exactly zero (an observation sitting on a grid point)
# yields an infinite inverse distance and therefore NaN weights.
inv_dist = 1 / nn_dist
inv_dist_sq = inv_dist ** 2

# Normalise along the neighbour axis so that the weights of each observation
# sum to 1. Dividing by the grand total (as before) only gives usable
# interpolation weights when there is a single observation.
# nn_dist is expected to be (nobs, nneighbours); for a 1-D array this reduces
# to the previous behaviour.
weights = inv_dist / np.sum(inv_dist, axis=-1, keepdims=True)
weights_is = inv_dist_sq / np.sum(inv_dist_sq, axis=-1, keepdims=True)

# Print the per-observation sums (each should be 1) followed by the weights.
print(f'{weights.sum(axis=-1)}, {weights}')
print(f'{weights_is.sum(axis=-1)}, {weights_is}')

# Demonstrate the difference in a plot

In [None]:
# Coordinates of the model grid points selected by the grid-aware search.
gridded_lons = ecco_grid.XC.values.ravel()[nn_src_to_target]
gridded_lats = ecco_grid.YC.values.ravel()[nn_src_to_target]

# The same observations mapped with the plain KD-tree search, for comparison.
gridded_lat_naive, gridded_lon_naive = latlon_ungridded_to_gridded(
    ecco_grid, ungridded_lats, ungridded_lons, nn=4
)

# Flattened model grid, drawn as the grey background layer.
grid_lon, grid_lat = ecco_grid.XC, ecco_grid.YC  # each are size (13, 90, 90)
nskip = None  # easier to see plot with fewer points (not used below)
grid_lon_flat = grid_lon.values.flatten()
grid_lat_flat = grid_lat.values.flatten()

# plot: all data are given in lon/lat and drawn on a north-polar view
plate_carree = ccrs.PlateCarree()
north_polar = ccrs.Orthographic(central_longitude=0, central_latitude=90)
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=north_polar))

# One entry per scatter layer, in drawing (and legend) order.
scatter_layers = [
    (grid_lon_flat, grid_lat_flat,
     dict(color='grey', s=3, alpha=.5, label='model grid points')),
    (gridded_lons, gridded_lats,
     dict(color='blue', s=20, label='observations (grid-aware nearest nbr)')),
    (gridded_lon_naive, gridded_lat_naive,
     dict(color='g', s=20, label='observations (naive nearest nbr)')),
    (ungridded_lons, ungridded_lats,
     dict(color='r', s=10, label='observations (ungridded)')),
]
for layer_lon, layer_lat, layer_style in scatter_layers:
    ax.scatter(layer_lon, layer_lat, transform=plate_carree, **layer_style)

# play with extent to view more points
ax.set_extent([-180, 180, 80, 90], crs=plate_carree)
ax.coastlines()
ax.gridlines()
ax.legend(facecolor='white', framealpha=1)

# Layout, resize and title stay in the original order.
fig.tight_layout()
fig.set_size_inches(5, 5)
ax.set_title('ungridded-to-gridded interpolation', fontsize=20)

### Sweet's method
Haven't finished this yet, but this is [Sweet's method](https://github.com/ECCO-GROUP/ECCO-Insitu-Python/blob/fcf3415bae73d833953df2749134f9eda1573fc3/tools.py) for finding nearest neighbors

In [None]:
import math
from scipy.interpolate import griddata

def sph2cart(r, theta, phi):
    """Convert spherical coordinates (r, theta, phi) to Cartesian (x, y, z).

    ``theta`` is the azimuth and ``phi`` the polar angle measured from the
    +z axis, both in radians. Inputs may be scalars or NumPy arrays; arrays
    are broadcast against each other (the ``math`` functions used previously
    only accepted scalars).
    """
    sin_phi = np.sin(phi)
    x = r * sin_phi * np.cos(theta)
    y = r * sin_phi * np.sin(theta)
    z = r * np.cos(phi)
    return x, y, z

deg2rad = np.pi/180.0

# Model grid longitude (X) and latitude (Y) in degrees, plus a unit radius (Z).
# `ds` is never assigned in this notebook; the grid loaded above is
# `ecco_grid`, so that is what is used here.
X, Y = ecco_grid.XC.values.ravel(), ecco_grid.YC.values.ravel()
Z = np.ones_like(X)

# Cartesian position of every grid point on the unit sphere. The nearest
# neighbour search has to happen in Cartesian space on both sides, so
# longitude is the azimuth and (90 - latitude) is the polar angle.
xyz = np.column_stack(sph2cart(Z, X*deg2rad, (90.0 - Y)*deg2rad))
# map a grid index to each profile.
AI = np.arange(X.size)

# Observation ("profile") positions, converted the same way.
prof_lon = np.asarray(ungridded_lons, dtype=float)
prof_lat = np.asarray(ungridded_lats, dtype=float)
prof_x, prof_y, prof_z = sph2cart(1.0, prof_lon*deg2rad, (90.0 - prof_lat)*deg2rad)

# griddata may hand the indices back as floats, so cast the whole array
# (int() on an array only works for a single profile).
prof_llcN_cell_index = griddata(
    xyz, AI, np.column_stack((prof_x, prof_y, prof_z)), 'nearest'
).astype(int)

# X and Y hold lon/lat in degrees, so indexing them with the cell index gives
# the matched grid point's lon/lat directly -- no conversion back is needed.
X[prof_llcN_cell_index], Y[prof_llcN_cell_index]


### Plot model longitude and latitude (XC and YC, respectively)

In [None]:
import matplotlib.pyplot as plt

# `ds` is never assigned in this notebook, so the calls below raised a
# NameError on a fresh kernel; the model grid loaded earlier is `ecco_grid`.
# Only the second return value of plot_tiles is kept (presumably the tiles
# assembled into one lat-lon array -- confirm); each figure is closed at once.
_, XC_wm = ecco.plot_tiles(ecco_grid.XC, layout='latlon', rotate_to_latlon=True, show_tile_labels=False)
plt.close()
_, YC_wm = ecco.plot_tiles(ecco_grid.YC, layout='latlon', rotate_to_latlon=True, show_tile_labels=False)
plt.close()
# XC_wm.shape