# Test `preproc.interp`

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import sys
sys.path.append('/efs_ecco/mgoldber/EH24-processors-llc/MITpreprobs/')

from MITpreprobs.preproc import UngriddedObsPreprocessor

%load_ext autoreload
%autoreload 2

In [None]:
# Open the ECCO V4r4 LLC90 grid-geometry file; it is passed to the
# preprocessor below as `grid_noblank_ds`.
grid_noblank_ds = xr.open_dataset(
    '/efs_ecco/ECCO/V4/r4/ECCO_L4_GEOMETRY_LLC0090GRID_V4R4/'
    'GRID_GEOMETRY_ECCO_V4r4_native_llc0090.nc'
)

In [None]:
from MITpreprobs.utils import generate_random_points
# signature: generate_random_points(nobs, lon_range=(-180, 180), lat_range=(-90, 90))

# Small smoke test: a handful of random observation locations.
nobs = 10
ungridded_lons, ungridded_lats = generate_random_points(nobs)
# Alternatives for probing specific regions:
#   ungridded_lons, ungridded_lats = generate_random_points(nobs, lat_range=(70, 90))
#   ungridded_lons, ungridded_lats = ([-50], [89])
#   ungridded_lons, ungridded_lats = ([-40], [40])

grid_file = '/efs_ecco/ECCO/V4/r4/ECCO_L4_GEOMETRY_LLC0090GRID_V4R4/GRID_GEOMETRY_ECCO_V4r4_native_llc0090.nc'
grid_noblank_ds = xr.open_dataset(grid_file)

# Locate the LLC grid points (4 per observation) for the random locations.
UOP = UngriddedObsPreprocessor('profiles')
UOP.get_obs_point(ungridded_lons, ungridded_lats,
                  grid_type='llc', grid_noblank_ds=grid_noblank_ds,
                  num_interp_points=4)

In [None]:
def get_rnd(lower_lim, upper_lim, N, ndepth=50):
    """Return ``N`` synthetic profiles of ``ndepth`` uniform random values.

    Each returned row (one profile) is sorted in non-increasing order along
    the depth axis.

    Parameters
    ----------
    lower_lim, upper_lim : float
        Bounds of the uniform distribution. They may be given in either
        order: ``np.random.uniform`` is documented as undefined when
        ``high < low``, so the bounds are sorted before sampling.
    N : int
        Number of profiles.
    ndepth : int, optional
        Number of depth levels per profile (default 50).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, ndepth)``.
    """
    low, high = sorted((lower_lim, upper_lim))
    # One profile per column, so the sort below runs along depth.
    arr = np.random.uniform(low, high, size=(ndepth, N))
    arr_sorted = np.sort(arr, axis=0)[::-1]
    return arr_sorted.T

nobs = 100
ungridded_lons, ungridded_lats = generate_random_points(nobs)

# Synthetic temperature profiles between 3 and 30, non-increasing with depth.
prof_T = get_rnd(3, 30, nobs, 50)
# Synthetic weights between 0.5 and 3. get_rnd returns each profile
# non-increasing with depth; flip along the depth axis only, so weights become
# non-decreasing with depth. The limits are passed in (low, high) order because
# np.random.uniform is undefined for high < low. Flipping without an axis would
# also reverse the profile order.
prof_Tweight = get_rnd(0.5, 3, nobs, 50)
prof_Tweight = np.flip(prof_Tweight, axis=1)

In [None]:
# make ungridded_ds: let the preprocessor build its obs dataset from the raw
# lon/lat points, then attach the synthetic profile variables to a copy of it.
UOP = UngriddedObsPreprocessor('profiles')
UOP.get_obs_point(ungridded_lons, ungridded_lats)
ungridded_ds = UOP.ungridded_obs_ds.copy().assign(
    prof_T=xr.DataArray(prof_T, dims=['iPROF', 'iDEPTH']),
    prof_Tweight=xr.DataArray(prof_Tweight, dims=['iPROF', 'iDEPTH']),
)
print(ungridded_ds)

# Rebuild the preprocessor around the augmented dataset and locate the LLC
# grid points (num_interp_points=4) for each profile.
UOP = UngriddedObsPreprocessor('profiles', ungridded_obs_ds=ungridded_ds)
UOP.get_obs_point(grid_type='llc', grid_noblank_ds=grid_noblank_ds, num_interp_points=4)

In [None]:
import cartopy.crs as ccrs

# Map of the grid points selected for each profile, plus the interp corner points.
# Swap in the commented Orthographic projection to inspect the Arctic.
#fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.Orthographic(central_longitude=0, central_latitude=90)})
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})

nskip = 2  # subsampling stride for the (commented-out) model-grid scatter below
#ax.scatter(UOP.xc.ravel()[::nskip], UOP.yc.ravel()[::nskip], color='grey', s=3, transform=ccrs.PlateCarree(), alpha=.5, label='model grid points')
# Look up lon/lat of each profile's grid points in the flattened grid;
# the first axis indexes profiles (one group of points per profile).
prof_lons = UOP.xc_wm.ravel()[UOP.ungridded_obs_ds.prof_point]
prof_lats = UOP.yc_wm.ravel()[UOP.ungridded_obs_ds.prof_point]

# Color interp groups together (only readable for a modest number of profiles).
color_by_group = (nobs <= 100)
if color_by_group:
    cmap = plt.cm.winter
    num_groups = prof_lons.shape[0]
    colors = cmap(np.linspace(0, 1, num_groups))  # one color per group
    for i in range(num_groups):
        # Label only the first scatter so the legend shows a single entry
        # standing for all groups (empty labels are omitted from the legend).
        ax.scatter(prof_lons[i], prof_lats[i], color=colors[i], s=40,
                   transform=ccrs.PlateCarree(),
                   label='obs groups' if i == 0 else "", alpha=0.3)
else:
    ax.scatter(prof_lons, prof_lats, color='b', s=40, transform=ccrs.PlateCarree(), label='observations', alpha=0.1)

ax.scatter(UOP.ungridded_obs_ds.prof_interp_XC11, UOP.ungridded_obs_ds.prof_interp_YC11, color='k', s=30, transform=ccrs.PlateCarree(), label='interp points 11')
ax.scatter(UOP.ungridded_obs_ds.prof_interp_XCNINJ, UOP.ungridded_obs_ds.prof_interp_YCNINJ, color='r', s=20, transform=ccrs.PlateCarree(), label='interp points NINJ')

# play with extent to view more points
#ax.set_extent([-180, 180, 60, 90], crs=ccrs.PlateCarree())
ax.coastlines()
gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')

ax.set_title('ungridded-to-gridded interpolation', fontsize=20)
ax.legend(facecolor='white', framealpha=0.7, loc='lower left', bbox_to_anchor=(0, 0.1))

# Run the layout pass last so the title is included in the computation.
fig.tight_layout()
#fig.set_size_inches(5, 5)
#fig.set_size_inches(20, 4)

### Bilinear interpolation, compute `sample_interp_weights`

In [None]:
# Work on a copy of the preprocessor's obs dataset.
ds = UOP.ungridded_obs_ds.copy()
# Profile positions (nobs,) and their grid-point indices, as numpy arrays.
prof_lon, prof_lat, prof_points = (
    ds[field].values for field in ('prof_lon', 'prof_lat', 'prof_point')
)

# Convert lon/lat (degrees) to Cartesian coordinates on the unit sphere.
def spherical_to_cartesian(lon, lat, *, return_z=False):
    """Map longitude/latitude in degrees to unit-sphere Cartesian coordinates.

    By default only ``(x, y)`` is returned, i.e. the point's projection onto
    the equatorial plane. Be aware that this 2-D projection cannot tell the
    hemispheres apart (``lat`` and ``-lat`` map to the same point). It also
    compresses latitude differences near the equator, where
    ``d(cos lat)/d lat`` goes to 0.

    Parameters
    ----------
    lon, lat : float or array_like
        Longitude and latitude in degrees (broadcast against each other).
    return_z : bool, keyword-only, optional
        If True, also return ``z = sin(lat)``.

    Returns
    -------
    tuple
        ``(x, y)``, or ``(x, y, z)`` when ``return_z`` is True.
    """
    lon_rad = np.radians(lon)
    lat_rad = np.radians(lat)
    cos_lat = np.cos(lat_rad)
    x = cos_lat * np.cos(lon_rad)
    y = cos_lat * np.sin(lon_rad)
    if return_z:
        return x, y, np.sin(lat_rad)
    return x, y

# Project the profile positions with spherical_to_cartesian -> (x, y).
x_prof, y_prof = spherical_to_cartesian(prof_lon, prof_lat)

# Look up lon/lat of each profile's grid points in the flattened model grid and
# project them the same way; expected shape (nobs, num_interp_points).
prof_lon_flat = UOP.xc_wm.ravel()[ds.prof_point]
prof_lat_flat = UOP.yc_wm.ravel()[ds.prof_point]
x_points, y_points = spherical_to_cartesian(prof_lon_flat, prof_lat_flat)

# Add a trailing singleton axis, (nobs,) -> (nobs, 1), so the profile
# coordinates broadcast against the per-profile grid points.
x_prof = np.expand_dims(x_prof, axis=1)
y_prof = np.expand_dims(y_prof, axis=1)

In [None]:

def bilinear_interpolation_weights(x, y, points, *, validate=False):
    '''Return the bilinear interpolation weights of (x, y) w.r.t. four corners.

    ``points`` is a sequence of four ``(x, y)`` pairs in any order; they are
    assumed to form an axis-aligned rectangle. If they do not, the weights are
    only approximate. The weights are returned in the SAME order as
    ``points``, so ``sum(w[k] * f(points[k]))`` interpolates ``f`` at
    ``(x, y)``. The weights always sum to 1. A degenerate rectangle (zero width
    or height) divides by zero.

    ``x`` and ``y`` may be scalars or arrays. Trailing singleton axes (e.g. a
    shape-``(1,)`` input) are squeezed, so scalar-like input gives shape
    ``(4,)``; an input of shape ``(k,)`` gives ``(4, k)``.

    With ``validate=True``, a ValueError is raised if the corners do not form
    an axis-aligned rectangle (up to ``np.isclose``) or if (x, y) lies outside
    it.

        >>> w = bilinear_interpolation_weights(12, 5.5,
        ...                                    [(10, 4), (20, 4), (10, 6), (20, 6)])
        >>> w.round(3).tolist()
        [0.2, 0.05, 0.6, 0.15]

    '''
    # See formula at:  http://en.wikipedia.org/wiki/Bilinear_interpolation
    points = [tuple(p) for p in points]
    if len(points) != 4:
        raise ValueError('exactly four corner points are required')

    # Order corner indices by x, then by y, and keep the permutation so the
    # weights can be handed back in the caller's point order.
    order = sorted(range(4), key=lambda k: points[k])
    (x1, y1), (_x1, y2), (x2, _y1), (_x2, _y2) = (points[k] for k in order)

    if validate:
        if not (np.isclose(x1, _x1) and np.isclose(x2, _x2)
                and np.isclose(y1, _y1) and np.isclose(y2, _y2)):
            raise ValueError('points do not form a rectangle')
        if not (np.all((x1 <= x) & (x <= x2)) and np.all((y1 <= y) & (y <= y2))):
            raise ValueError('(x, y) not within the rectangle')

    area = (x2 - x1) * (y2 - y1)
    weights_sorted = np.array([
        (x2 - x) * (y2 - y),  # corner (x1, y1)
        (x2 - x) * (y - y1),  # corner (x1, y2)
        (x - x1) * (y2 - y),  # corner (x2, y1)
        (x - x1) * (y - y1),  # corner (x2, y2)
    ], dtype=float) / area
    # Drop singleton axes from array-valued x/y (the old `[0]` selected only w11).
    weights_sorted = np.squeeze(weights_sorted)

    weights = np.empty_like(weights_sorted)
    weights[order] = weights_sorted  # sorted slot j belongs to points[order[j]]
    return weights


In [None]:
# x_points, y_points have shape (nprof, 4): projected coordinates of the four
# grid points for each profile; x_prof, y_prof hold the profile positions.
x_targets = np.ravel(x_prof)
y_targets = np.ravel(y_prof)
nprof = x_points.shape[0]  # was hard-coded to 100; follow the actual data

# Four (x, y) corner pairs per profile, in the same order as prof_point.
points_list = [list(zip(x_points[i], y_points[i])) for i in range(nprof)]
weights = np.zeros((nprof, 4))
for i, points in enumerate(points_list):
    weights[i, :] = bilinear_interpolation_weights(x_targets[i], y_targets[i], points)