# missing data interpolation

statistics is the answer to everything

### potential shenanigans

"Several techniques have been used to fill the gaps in either the UWLS or OI derived total vector maps.

These are implemented using covariance derived from normal mode analysis (Lipphardt et al. 2000), open-boundary modal analysis (OMA) (Kaplan and Lekien 2007), and empirical orthogonal function (EOF) analysis (Beckers and Rixen 2003; Alvera-Azcárate et al. 2005); and using idealized or smoothed observed covariance (Davis 1985)."

- normal mode analysis
- open-boundary modal analysis (OMA)
- empirical orthogonal function analysis (EOF)
- use idealized/smoothed observed covariance

---

### other ideas

DINEOF (could only find an implementation in R)

to be honest I don't understand any of these methods but they look cool
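for reference, the core loop of DINEOF is simple enough to sketch: put a first guess in the gaps, take a truncated SVD (the EOFs), overwrite only the gap values with the truncated reconstruction, and repeat until the gap values stop changing. The block below is a stripped-down illustration of that idea and not the published algorithm: the real method also removes the mean and picks the number of modes by cross-validation, which is skipped here. `eof_fill`, `n_modes`, and the synthetic field are made up for the example.

```python
import numpy as np


def eof_fill(data, n_modes=2, n_iter=500, tol=1e-10):
    """Fill NaNs in a (time, space) matrix by iterating a truncated SVD."""
    missing = np.isnan(data)
    filled = np.where(missing, 0.0, data)  # first guess for the gaps: zero
    if not missing.any():
        return filled
    for _ in range(n_iter):
        u, s, vt = np.linalg.svd(filled, full_matrices=False)
        recon = (u[:, :n_modes] * s[:n_modes]) @ vt[:n_modes]
        change = np.max(np.abs(recon[missing] - filled[missing]))
        filled[missing] = recon[missing]  # observed values are never touched
        if change < tol:
            break
    return filled


# synthetic rank-2 field on a (time, space) grid, with about 10% of it knocked out
t = np.linspace(0, 2 * np.pi, 40)[:, None]
x = np.linspace(0, np.pi, 25)[None, :]
truth = np.sin(t) * np.cos(x) + 0.5 * np.cos(2 * t) * np.sin(2 * x)

rng = np.random.default_rng(0)
gappy = truth.copy()
gappy[rng.random(truth.shape) < 0.1] = np.nan

filled = eof_fill(gappy, n_modes=2)
max_err = np.max(np.abs(filled - truth))
print(f"max reconstruction error: {max_err:.2e}")
```

this only works this cleanly because the synthetic field really is rank 2; on real current data the number of modes is the hard part.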

### currently implemented:

rip data straight from the lower resolution data for areas where data is considered missing in the high resolution data
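a toy version of that idea, assuming the references have already been resampled onto the target grid (the notebook does that step with parcels interpolation instead). `cascade_fill` and the arrays are hypothetical and only show the fill order: each reference only gets to fill what is still missing after the previous one.

```python
import numpy as np


def cascade_fill(target, references):
    """Fill NaNs in `target` from `references`, tried in the given order."""
    filled = target.copy()
    for ref in references:
        gaps = np.isnan(filled)
        filled[gaps] = ref[gaps]  # a NaN in the reference leaves the gap open
    return filled


target_grid = np.array([[1.0, np.nan], [np.nan, np.nan]])
ref_fine = np.array([[9.0, 2.0], [np.nan, np.nan]])
ref_coarse = np.array([[9.0, 9.0], [3.0, np.nan]])

result = cascade_fill(target_grid, [ref_fine, ref_coarse])
print(result)
```

the valid `1.0` is kept, `2.0` comes from the finer reference, `3.0` from the coarser one, and the last cell stays nan because no reference has data there.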

In [None]:
%matplotlib inline

In [None]:
import math
import time
import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
from parcels import FieldSet
from datetime import timedelta, datetime

from utils import generate_mask, conv_to_dataarray
from parcels_utils import get_file_info, xr_dataset_to_fieldset

### target and interp_references

`target` is the data you are interpolating.

`interp_references` is a list of reference datasets to interpolate from. Requirements:
- order the list from most accurate data to least accurate
- each reference's time domain should cover the target's (identical or larger)
- each reference's lat and lon domain should extend beyond the target's to prevent any out-of-bounds complications
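the domain requirements can be checked before running anything. A minimal sketch, assuming each dataset exposes its coordinates as 1-D arrays under `"time"`, `"lat"`, and `"lon"` keys (the `"time"` key and the helper `uncovered_axes` are assumptions for this example, not part of `get_file_info`):

```python
import numpy as np


def uncovered_axes(target, ref, axes=("time", "lat", "lon")):
    """Names of the axes on which `ref` does not span all of `target`."""
    return [
        name for name in axes
        if ref[name].min() > target[name].min() or ref[name].max() < target[name].max()
    ]


target_axes = {
    "time": np.arange(0, 24),
    "lat": np.linspace(32.5, 32.7, 5),
    "lon": np.linspace(-117.3, -117.1, 5),
}
wide_ref = {
    "time": np.arange(0, 48),
    "lat": np.linspace(32.0, 33.0, 11),
    "lon": np.linspace(-118.0, -117.0, 11),
}
# same as wide_ref, but its western edge stops short of the target's
narrow_ref = dict(wide_ref, lon=np.linspace(-117.2, -117.0, 5))

print(uncovered_axes(target_axes, wide_ref))    # []
print(uncovered_axes(target_axes, narrow_ref))  # ['lon']
```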

In [None]:
target = get_file_info("current_netcdfs/west_coast_1km_hourly/tijuana_river_now.nc", 1, name="target")

interp_references = [
    get_file_info("current_netcdfs/west_coast_2km_hourly/region0.nc", 2, name="ref2km"),
    get_file_info("current_netcdfs/west_coast_6km_hourly/region0.nc", 6, name="ref6km")
]

### interpolation type

more information can be found in the `tutorial_interpolation` notebook

EDIT: just use `linear`

In [None]:
reference_interp_method = "linear"

for r in interp_references:
    r["fs_flat"].U.interp_method = reference_interp_method
    r["fs_flat"].V.interp_method = reference_interp_method

## nan values and parcels

note that once an xarray Dataset is passed into parcels, all of its nan values become 0, so the mask can no longer be generated from it

so the mask is generated first, and a copy of the Dataset is used for the FieldSet instead
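a small numpy illustration of why the order matters. It assumes `generate_mask` behaves like `np.isnan`, and uses `np.nan_to_num` only as a stand-in for whatever parcels does to the data internally:

```python
import numpy as np

u = np.array([[0.2, np.nan], [0.0, -0.1]])

mask = np.isnan(u)                         # built while the NaNs still exist
u_for_fieldset = np.nan_to_num(u, nan=0.0)  # new array; `u` itself is untouched

print(f"missing according to mask: {mask.sum()}")
print(f"zeros after conversion: {(u_for_fieldset == 0).sum()}")
```

after the conversion the genuine `0.0` current and the former nan are indistinguishable, which is why the mask has to come from the untouched data.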

In [None]:
invalid = generate_mask(target["xrds"]["u"].values)
num_invalid = invalid.sum()
print(f"total invalid values on target data: {num_invalid}")

In [None]:
invalid_pos = np.where(invalid)
invalid_pos

### use of Parcels Field for interpolation

indexing Field values goes [time, depth, lat, lon]

Field does interpolation automatically when indexing values between its coordinate values
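conceptually, linear interpolation on a lat/lon grid is a weighted average of the four surrounding grid values. The sketch below is a plain numpy bilinear interpolation to show the arithmetic; it is not parcels code, and parcels may treat edges, land cells, and time differently. `bilinear` is a made-up helper.

```python
import numpy as np


def bilinear(lats, lons, values, lat, lon):
    """Bilinearly interpolate `values[lat_idx, lon_idx]` at (lat, lon)."""
    # index of the grid cell containing the point, clipped to stay inside the grid
    i = np.clip(np.searchsorted(lats, lat) - 1, 0, len(lats) - 2)
    j = np.clip(np.searchsorted(lons, lon) - 1, 0, len(lons) - 2)
    # fractional position inside the cell along each axis
    wy = (lat - lats[i]) / (lats[i + 1] - lats[i])
    wx = (lon - lons[j]) / (lons[j + 1] - lons[j])
    south = (1 - wx) * values[i, j] + wx * values[i, j + 1]
    north = (1 - wx) * values[i + 1, j] + wx * values[i + 1, j + 1]
    return (1 - wy) * south + wy * north


lats = np.array([32.0, 32.1])
lons = np.array([-117.2, -117.1])
u = np.array([[0.0, 1.0], [2.0, 3.0]])

center = bilinear(lats, lons, u, 32.05, -117.15)
corner = bilinear(lats, lons, u, 32.0, -117.2)
print(center, corner)
```

at the cell center all four corners get equal weight, and at a grid point the interpolated value is just the grid value.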

### note

https://stackoverflow.com/questions/12923586/nearest-neighbor-search-python

in theory, the latitude and longitude values are equally spaced. however, the difference between consecutive coordinate values always fluctuates by a very small amount between two distinct values, so the grid is not perfectly equally spaced.

from testing, this causes enough error to completely change a simulation, so a kd-tree must be used for the nearest-point lookup.
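to see how a tiny spacing wobble can shift an index, here is a numpy-only comparison between an index computed as if the axis were evenly spaced and a true nearest-neighbor search. The alternating spacings are invented for the example, and `nearest_index` uses `np.searchsorted` on a sorted axis rather than the kd-tree the notebook uses. Whether this is what actually went wrong in the simulation is not established here.

```python
import numpy as np


def nearest_index(axis, value):
    """Index of the entry of sorted 1-D `axis` closest to `value`."""
    right = np.clip(np.searchsorted(axis, value), 1, len(axis) - 1)
    left = right - 1
    return left if value - axis[left] <= axis[right] - value else right


def uniform_index(axis, value):
    """Index computed as if `axis` were perfectly evenly spaced."""
    step = (axis[-1] - axis[0]) / (len(axis) - 1)
    return int(round((value - axis[0]) / step))


# spacing alternates between two slightly different values
steps = np.tile([0.0090, 0.0094], 50)
lat = 32.0 + np.concatenate([[0.0], np.cumsum(steps)])

# query points just past the midpoint of every cell
queries = lat[:-1] + 0.5 * np.diff(lat) + 5e-5
exact = np.array([nearest_index(lat, q) for q in queries])
approx = np.array([uniform_index(lat, q) for q in queries])
print(f"lookups that disagree: {(exact != approx).sum()} of {len(queries)}")
```

the disagreements only show up for points near cell midpoints, where an offset much smaller than the grid spacing is enough to pick the neighboring grid point.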

### another note

I just looked at the particle trajectory that used the indexing without kd-trees and it looked badly broken, so I probably wrote something wrong in that method. TODO: come back to this later.

In [None]:
import scipy.spatial

# set up kdtrees for all references
for f in interp_references:
    f["latkdtree"] = scipy.spatial.cKDTree(np.array([f["lat"]]).T)
    f["lonkdtree"] = scipy.spatial.cKDTree(np.array([f["lon"]]).T)

In [None]:
def get_nearest_current(ref, t, lat, lon):
    index = (t, ref["latkdtree"].query([lat])[1], ref["lonkdtree"].query([lon])[1])
    return ref["xrds"]["u"].values[index], ref["xrds"]["v"].values[index]


def get_current_interp(ref, t, lat, lon):
    return ref["fs_flat"].U[t, 0, lat, lon], ref["fs_flat"].V[t, 0, lat, lon]


def get_interped(i, ref, invalid_where):
    """
    Args:
        i (int): index on invalid_where
        ref (dict): reference Dataset
        invalid_where (array-like): (3, n) dimensional array representing all invalid positions
    
    Returns:
        (u, v): (nan, nan) if no data was found, interpolated values otherwise
    """
    time_diff = np.diff(ref["fs_flat"].U.grid.time)[0]
    t = invalid_where[0][i]
    lat = target["lat"][invalid_where[1][i]]
    lon = target["lon"][invalid_where[2][i]]
    current_u, current_v = get_current_interp(ref, t * time_diff, lat, lon)
    current = current_u + current_v
    # TODO: testing whether `current == 0` actually helps or not
    # basically makes interpolation a bit more "aggressive"?
    if np.isnan(get_nearest_current(ref, t, lat, lon)[0]):
        return np.nan, np.nan
    return current_u, current_v

### today i learned

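`DataArray.values` does not return a copy of the data: writing into the returned numpy array can change the Dataset itself, so call `.copy()` on it before modifying anything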

In [None]:
time_st = time.time()
target_interped_u = target["xrds"]["u"].values.copy()
target_interped_v = target["xrds"]["v"].values.copy()
invalid_interped = invalid.copy()
for f in interp_references:
    invalid_pos_new = np.where(invalid_interped)
    num_invalid_new = invalid_interped.sum()
    arr_u = np.zeros(num_invalid_new)
    arr_v = np.zeros(num_invalid_new)
    for i in range(num_invalid_new):
        c_u, c_v = get_interped(i, f, invalid_pos_new)
        arr_u[i] = c_u
        arr_v[i] = c_v
    target_interped_u[invalid_pos_new] = arr_u
    target_interped_v[invalid_pos_new] = arr_v
    invalid_interped = generate_mask(target_interped_u)
    print(f"total invalid values after interpolation with {f['name']}: {invalid_interped.sum()}")
    print(f"    values filled: {num_invalid_new - invalid_interped.sum()}")
time_en = time.time()
print(f"time elapsed: {time_en - time_st}")

#### result: interpolate only when interpolated current is non-zero OR not nan

    total invalid values after interpolation with ref2km: 1562
        values filled: 12589
    total invalid values after interpolation with ref6km: 2
        values filled: 1560
    
YEP it fills a lot more values

#### result: interpolate whenever nearest data point is not nan

    total invalid values after interpolation with ref2km: 3176
        values filled: 10975
    total invalid values after interpolation with ref6km: 758
        values filled: 2418

In [None]:
print(f"total invalid values on interpolated: {invalid_interped.sum()}")

### gridfill invalid values after interpolation

in theory there should still be invalid spaces left over because the 6 km data will sometimes have gaps

hopefully after this interpolation, the gaps of invalid data are small enough to let gridfilling finish the job
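gridfill fills gaps by solving Poisson's equation with a relaxation scheme. The block below is a bare-bones numpy sketch of the same general idea (Jacobi relaxation of Laplace's equation on one 2-D slice), just to show what "relaxation" means here. It is not gridfill's implementation: `relax_fill` is made up, and its `eps` and `itermax` only borrow the names of the gridfill keywords.

```python
import numpy as np


def relax_fill(grid, missing, eps=1e-8, itermax=10_000):
    """Fill `missing` cells of a 2-D grid by Jacobi relaxation.

    Returns the filled grid and whether the iteration converged.
    """
    if not missing.any():
        return grid.copy(), True
    # first guess for the gaps: mean of the valid values
    filled = np.where(missing, np.nanmean(grid), grid)
    for _ in range(itermax):
        p = np.pad(filled, 1, mode="edge")
        # average of the four direct neighbours of every cell
        neighbours = 0.25 * (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:])
        change = np.max(np.abs(neighbours[missing] - filled[missing]))
        filled[missing] = neighbours[missing]  # valid cells are never modified
        if change < eps:
            return filled, True
    return filled, False


# a linear ramp with a 3x3 hole punched into the middle
yy, xx = np.mgrid[0:8, 0:8]
truth = yy + 2.0 * xx
gappy = truth.copy()
gappy[3:6, 2:5] = np.nan

filled, converged = relax_fill(gappy, np.isnan(gappy))
print(converged, np.max(np.abs(filled - truth)))
```

a linear ramp is recovered almost exactly because every cell of it already equals the average of its neighbours; real current fields are not that smooth, and a gap next to the coast has few valid neighbours to relax from.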

In [None]:
import numpy.ma as ma
from gridfill import fill

masked_u = ma.masked_array(target_interped_u, invalid_interped)
masked_v = ma.masked_array(target_interped_v, invalid_interped)

# itermax is an iteration count, so pass it as an int rather than the float 1e4
kw = dict(eps=1e-4, relax=0.6, itermax=10_000, initzonal=False,
          cyclic=False, verbose=True)

# data is (time, lat, lon): x is axis 2 (lon) and y is axis 1 (lat); axis 0 is time
filled_u, converged_u = fill(masked_u, 2, 1, **kw)
filled_v, converged_v = fill(masked_v, 2, 1, **kw)

In [None]:
invalid_filled = generate_mask(filled_u)
print(f"total invalid values on filled: {invalid_filled.sum()}")

In [None]:
# the ratio is a fraction, so format it as a percentage to match the label
print(f"percent invalid filled: {(invalid_interped.sum() - invalid_filled.sum()) / invalid_interped.sum():.2%}")

### wow

unsurprisingly gridfill did jack shit

### formatting, saving, and testing

In [None]:
# re-add coordinates, dimensions, and metadata to interpolated data
darr_u = conv_to_dataarray(filled_u, target["xrds"]["u"])
darr_v = conv_to_dataarray(filled_v, target["xrds"]["v"])
target_interped_xrds = target["xrds"].drop_vars(["u", "v"]).assign(u=darr_u, v=darr_v)
target_interped_xrds

In [None]:
save_path = target["path"].split(".nc")[0] + "_interped.nc"
target_interped_xrds.to_netcdf(save_path)
print(f"saved to {save_path}")

### display field to see if interpolation worked

In [None]:
fs_interp = xr_dataset_to_fieldset(target_interped_xrds)

In [None]:
target["fs"].U.show()
fs_interp.U.show()