Skip to content

Commit

Permalink
Updated tiles_interp to use interpn
Browse files Browse the repository at this point in the history
  • Loading branch information
acolite committed Apr 18, 2024
1 parent ad24126 commit 23af717
Showing 1 changed file with 14 additions and 30 deletions.
44 changes: 14 additions & 30 deletions acolite/shared/tiles_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,44 @@
## 2020-11-18 (QV) added dtype to convert from griddata float64, by default float32
## this improves peak memory use when several datasets are kept in memory
## 2021-02-11 (QV) added smooth keyword, default to nearest
## 2024-04-18 (QV) new version using interpn


def tiles_interp(data, xnew, ynew, smooth = False, kern_size=2, method='nearest', mask=None,
target_mask=None, target_mask_full=False, fill_nan = True, dtype='float32'):

import numpy as np
from scipy.interpolate import griddata
from scipy.interpolate import interpn
from scipy.ndimage import uniform_filter,percentile_filter, distance_transform_edt
import acolite as ac

if mask is not None: data[mask] = np.nan

## fill nans with closest value
if fill_nan:
#ind = distance_transform_edt(np.isnan(data), return_distances=False, return_indices=True)
#cur_data = data[tuple(ind)]
cur_data = ac.shared.fillnan(data)
else:
cur_data = data*1.0

## smooth dataset
if smooth:
z = uniform_filter(cur_data, size=kern_size)
zv=list(z.ravel())
else:
zv=list(cur_data.ravel())
dim = data.shape

### tile centers
#x = arange(0.5, dim[1], 1)
#y = arange(0.5, dim[0], 1)
dim = cur_data.shape
if smooth: cur_data = uniform_filter(cur_data, size = kern_size)

## tile edges
## grid positions
x = np.arange(0., dim[1], 1)
y = np.arange(0., dim[0], 1)

xv, yv = np.meshgrid(x, y, sparse=False)
ci = (list(xv.ravel()), list(yv.ravel()))

## interpolate
if target_mask is None:
## full dataset
znew = griddata(ci, zv, (xnew[None,:], ynew[:,None]), method=method)
znew = interpn((y,x), cur_data, (ynew[:,None], xnew[None, :]), method = method, bounds_error = False)
else:
## limit to target mask
vd = np.where(target_mask)
if target_mask_full:
## return a dataset with the proper dimensions
znew = np.zeros((len(ynew), len(xnew)))+np.nan
znew[vd] = griddata(ci, zv, (xnew[vd[1]], ynew[vd[0]]), method=method)
znew = np.zeros((len(ynew), len(xnew))).astype(dtype)+np.nan
znew[vd] = interpn((y,x), cur_data, (ynew[vd[0]], xnew[vd[1]]), method = method, bounds_error = False)
else:
## return only target_mask data
znew = griddata(ci, zv, (xnew[vd[1]], ynew[vd[0]]), method=method)
znew = interpn((y,x), cur_data, (ynew[vd[0]], xnew[vd[1]]), method = method, bounds_error = False)

if dtype is None:
return(znew)
else:
return(znew.astype(np.dtype(dtype)))
## to convert data type - scipy always returns float64
if dtype is not None: znew = znew.astype(np.dtype(dtype))
return(znew)

0 comments on commit 23af717

Please sign in to comment.