Commit: numba correction with masked array

AntSimi committed May 4, 2023
1 parent 1728815 · commit 1b9ab25
Showing 12 changed files with 256 additions and 94 deletions.
1 change: 0 additions & 1 deletion doc/spectrum.rst
@@ -28,7 +28,6 @@ Compute and display spectrum
ax.set_title("Spectrum")
ax.set_xlabel("km")
for name_area, area in areas.items():
lon_spec, lat_spec = raw.spectrum_lonlat("adt", area=area)
mappable = ax.loglog(*lat_spec, label="lat %s raw" % name_area)[0]
ax.loglog(
4 changes: 3 additions & 1 deletion examples/16_network/pet_atlas.py
@@ -129,7 +129,9 @@ def update_axes(ax, mappable=None):
# Merging in networks longer than 10 days, with dead ends removed (shorter than 10 observations)
# --------------------------------------------------------------------------------------------
ax = start_axes("")
merger = n10.remove_dead_end(nobs=10).merging_event()
n10_ = n10.copy()
n10_.remove_dead_end(nobs=10)
merger = n10_.merging_event()
g_10_merging = merger.grid_count(bins)
m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1)
update_axes(ax, m).set_label("Pixel used in % of time")
3 changes: 2 additions & 1 deletion examples/16_network/pet_follow_particle.py
@@ -41,7 +41,8 @@ def save(self, *args, **kwargs):
# %%
n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651)
n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269))
n = n.remove_dead_end(nobs=0, ndays=10)
n.remove_dead_end(nobs=0, ndays=10)
n = n.remove_trash()
n.numbering_segment()
c = GridCollection.from_netcdf_cube(
get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"),
8 changes: 6 additions & 2 deletions examples/16_network/pet_relative.py
@@ -127,7 +127,9 @@
# Remove dead branch
# ------------------
# Remove all tiny segments with less than N obs which didn't join two segments
n_clean = n.remove_dead_end(nobs=5, ndays=10)
n_clean = n.copy()
n_clean.remove_dead_end(nobs=5, ndays=10)
n_clean = n_clean.remove_trash()
fig = plt.figure(figsize=(15, 12))
ax = fig.add_axes([0.04, 0.54, 0.90, 0.40])
ax.set_title(f"Original network ({n.infos()})")
@@ -261,7 +263,9 @@
# --------------------

# Get a simplified network
n = n2.remove_dead_end(nobs=50, recursive=1)
n = n2.copy()
n.remove_dead_end(nobs=50, recursive=1)
n = n.remove_trash()
n.numbering_segment()
# %%
# A map alone can be tricky to understand; with a timeline it's easier!
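The three network examples above all replace the old chained call with an explicit copy / in-place clean / trash-removal sequence. A minimal sketch of the new pattern, using only the demo file and methods that appear in the hunks above (remove_trash is assumed here to return the network without the observations flagged by remove_dead_end):

from py_eddy_tracker.data import get_demo_path
from py_eddy_tracker.observations.network import NetworkObservations

n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651)
n_clean = n.copy()                          # keep the original network intact
n_clean.remove_dead_end(nobs=5, ndays=10)   # now edits the network in place
n_clean = n_clean.remove_trash()            # drop the flagged observations (assumed behaviour)
n_clean.numbering_segment()                 # renumber segments after cleaning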
1 change: 0 additions & 1 deletion examples/16_network/pet_replay_segmentation.py
@@ -163,7 +163,6 @@ def get_obs(dataset):
for b0, b1 in [
(datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008)
]:

ref, delta = datetime(1950, 1, 1), 20
b0_, b1_ = (b0 - ref).days, (b1 - ref).days
ax = timeline_axes()
118 changes: 87 additions & 31 deletions src/py_eddy_tracker/dataset/grid.py
@@ -2,20 +2,45 @@
"""
Class to load and manipulate RegularGrid and UnRegularGrid
"""
import logging
from datetime import datetime
import logging

from cv2 import filter2D
from matplotlib.path import Path as BasePath
from netCDF4 import Dataset
from numba import njit, prange
from numba import types as numba_types
from numpy import (arange, array, ceil, concatenate, cos, deg2rad, empty,
errstate, exp, float_, floor, histogram2d, int_, interp,
isnan, linspace, ma)
from numpy import mean as np_mean
from numpy import (meshgrid, nan, nanmean, ones, percentile, pi, radians,
round_, sin, sinc, where, zeros)
from numba import njit, prange, types as numba_types
from numpy import (
arange,
array,
ceil,
concatenate,
cos,
deg2rad,
empty,
errstate,
exp,
float_,
floor,
histogram2d,
int_,
interp,
isnan,
linspace,
ma,
mean as np_mean,
meshgrid,
nan,
nanmean,
ones,
percentile,
pi,
radians,
round_,
sin,
sinc,
where,
zeros,
)
from pint import UnitRegistry
from scipy.interpolate import RectBivariateSpline, interp1d
from scipy.ndimage import gaussian_filter
@@ -26,13 +51,25 @@
from .. import VAR_DESCR
from ..data import get_demo_path
from ..eddy_feature import Amplitude, Contours
from ..generic import (bbox_indice_regular, coordinates_to_local, distance,
interp2d_geo, local_to_coordinates, nearest_grd_indice,
uniform_resample)
from ..generic import (
bbox_indice_regular,
coordinates_to_local,
distance,
interp2d_geo,
local_to_coordinates,
nearest_grd_indice,
uniform_resample,
)
from ..observations.observation import EddiesObservations
from ..poly import (create_vertice, fit_circle, get_pixel_in_regular,
poly_area, poly_contain_poly, visvalingam,
winding_number_poly)
from ..poly import (
create_vertice,
fit_circle,
get_pixel_in_regular,
poly_area,
poly_contain_poly,
visvalingam,
winding_number_poly,
)

logger = logging.getLogger("pet")

@@ -86,7 +123,7 @@ def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size

@njit(cache=True)
def mean_on_regular_contour(
x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None, nan_remove=False
x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=-1, nan_remove=False
):
x_val, y_val = vertices[:, 0], vertices[:, 1]
x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size)
@@ -406,8 +443,8 @@ def setup_coordinates(self):
x_name, y_name = self.coordinates
if self.is_centered:
# logger.info("Grid center")
self.x_c = self.vars[x_name].astype("float64")
self.y_c = self.vars[y_name].astype("float64")
self.x_c = array(self.vars[x_name].astype("float64"))
self.y_c = array(self.vars[y_name].astype("float64"))

self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],)))
self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],)))
@@ -419,8 +456,8 @@ def setup_coordinates(self):
self.y_bounds[-1] -= d_y[-1] / 2

else:
self.x_bounds = self.vars[x_name].astype("float64")
self.y_bounds = self.vars[y_name].astype("float64")
self.x_bounds = array(self.vars[x_name].astype("float64"))
self.y_bounds = array(self.vars[y_name].astype("float64"))

if len(self.x_dim) == 1:
self.x_c = self.x_bounds.copy()
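The array(...) wrapping added above is the masked-array fix for the coordinates: numpy.array on a MaskedArray returns a plain ndarray of the underlying data (the mask is dropped), so the jitted helpers further down receive ordinary float64 arrays. A tiny illustrative sketch with assumed values, not library code:

import numpy as np

lon = np.ma.array([0.0, 0.25, 0.5], mask=[False, True, False])
lon_plain = np.array(lon.astype("float64"))  # mask dropped, base ndarray returned
print(type(lon_plain).__name__)              # ndarray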
@@ -757,7 +794,7 @@ def eddy_identification(

# Test of the rotating sense: cyclone or anticyclone
if has_value(
data, i_x_in, i_y_in, cvalues, below=anticyclonic_search
data.data, i_x_in, i_y_in, cvalues, below=anticyclonic_search
):
continue

@@ -788,7 +825,6 @@ def eddy_identification(
contour.reject = 4
continue
if reset_centroid:

if self.is_circular():
centi = self.normalize_x_indice(reset_centroid[0])
else:
@@ -1285,8 +1321,8 @@ def compute_pixel_path(self, x0, y0, x1, y1):
def clean_land(self, name):
"""Function to remove all land pixel"""
mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat")
x,y = meshgrid(self.x_c, self.y_c)
m = mask_land.interp('mask', x.reshape(-1), y.reshape(-1), 'nearest')
x, y = meshgrid(self.x_c, self.y_c)
m = mask_land.interp("mask", x.reshape(-1), y.reshape(-1), "nearest")
data = self.grid(name)
self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T)

@@ -1310,7 +1346,7 @@ def get_step_in_km(self, lat, wave_length):
min_wave_length = max(step_x_km, step_y_km) * 2
if wave_length < min_wave_length:
logger.error(
"wave_length too short for resolution, must be > %d km",
"Wave_length too short for resolution, must be > %d km",
ceil(min_wave_length),
)
raise Exception()
@@ -1361,6 +1397,24 @@ def kernel_lanczos(self, lat, wave_length, order=1):
kernel[dist_norm > order] = 0
return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt)

def kernel_loess(self, lat, wave_length, order=1):
"""
https://fr.wikipedia.org/wiki/R%C3%A9gression_locale
"""
order = self.check_order(order)
half_x_pt, half_y_pt, dist_norm = self.estimate_kernel_shape(
lat, wave_length, order
)

def inc_func(xdist):
f = zeros(xdist.size)
f[abs(xdist) < 1] = 1
return f

kernel = (1 - abs(dist_norm) ** 3) ** 3
kernel[abs(dist_norm) > order] = 0
return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt)

def kernel_bessel(self, lat, wave_length, order=1):
"""wave_length in km
order must be int
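For reference, the new kernel_loess applies the tricube weight used in LOESS regression, w(d) = (1 - |d|^3)^3 inside the cutoff and 0 beyond it. The helper below is only an illustrative sketch of that weight, not part of the library:

import numpy as np

def tricube(dist_norm, order=1):
    # Tricube weight: smooth taper to zero at the cutoff distance
    w = (1 - np.abs(dist_norm) ** 3) ** 3
    w[np.abs(dist_norm) > order] = 0
    return w

print(tricube(np.array([0.0, 0.5, 1.0, 1.5])))  # approx [1.0, 0.67, 0.0, 0.0]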
@@ -1638,11 +1692,13 @@ def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=Fal
data1[-schema:] = nan
data2[:schema] = nan

d = self.EARTH_RADIUS * 2 * pi / 360 * 2 * schema
# Distance for one degree
d = self.EARTH_RADIUS * 2 * pi / 360
# Multiply by 2 * schema steps
if vertical:
d *= self.ystep
d *= self.ystep * 2 * schema
else:
d *= self.xstep * cos(deg2rad(self.y_c))
d *= self.xstep * cos(deg2rad(self.y_c)) * 2 * schema
return (data1 - data2) / d

def compute_stencil(
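The refactor above only separates the per-degree distance from the grid-step and 2 * schema factors of the centred difference; the product is unchanged. A rough worked example with assumed values (quarter-degree grid, schema=1, 45°N):

from math import cos, pi, radians

EARTH_RADIUS = 6370.0                       # km, assumed value for illustration
deg = EARTH_RADIUS * 2 * pi / 360           # ~111.2 km per degree
dx = deg * 0.25 * cos(radians(45)) * 2 * 1  # xstep * cos(lat) * 2 * schema
print(round(deg, 1), round(dx, 1))          # 111.2 39.3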
@@ -1855,7 +1911,7 @@ def speed_coef_mean(self, contour):
return mean_on_regular_contour(
self.x_c,
self.y_c,
self._speed_ev,
self._speed_ev.data,
self._speed_ev.mask,
contour.vertices,
nan_remove=True,
@@ -1945,7 +2001,7 @@ def interp(self, grid_name, lons, lats, method="bilinear"):
g = self.grid(grid_name)
m = self.get_mask(g)
return interp2d_geo(
self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest"
self.x_c, self.y_c, g.data, m, lons, lats, nearest=method == "nearest"
)

def uv_for_advection(
Expand Down Expand Up @@ -1981,7 +2037,7 @@ def uv_for_advection(
u = -u
v = -v
m = u.mask + v.mask
return u, v, m
return u.data, v.data, m

def advect(self, x, y, u_name, v_name, nb_step=10, rk4=True, **kw):
"""
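Most of the grid.py changes above follow the same pattern behind the commit title: numba's nopython mode cannot type a numpy MaskedArray, so the jitted helpers (has_value, mean_on_regular_contour, interp2d_geo, the advection routines) now receive the raw .data buffer and the boolean .mask as separate arguments. A self-contained sketch of the idea, not the library's own helper:

import numpy as np
from numba import njit

@njit(cache=True)
def masked_mean(values, mask):
    # Plain arrays only: the mask is applied manually inside the jitted code
    total, count = 0.0, 0
    for i in range(values.size):
        if not mask[i]:
            total += values[i]
            count += 1
    if count == 0:
        return np.nan
    return total / count

field = np.ma.array([1.0, 2.0, 4.0], mask=[False, True, False])
print(masked_mean(field.data, np.asarray(field.mask)))  # 2.5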
@@ -8,7 +8,6 @@


class CheltonTracker(Model):

__slots__ = tuple()

GROUND = RegularGridDataset(
32 changes: 24 additions & 8 deletions src/py_eddy_tracker/generic.py
@@ -3,11 +3,27 @@
Tool method which use mostly numba
"""

from numba import njit, prange
from numba import types as numba_types
from numpy import (absolute, arcsin, arctan2, bool_, cos, empty, floor,
histogram, interp, isnan, linspace, nan, ones, pi, radians,
sin, where, zeros)
from numba import njit, prange, types as numba_types
from numpy import (
absolute,
arcsin,
arctan2,
bool_,
cos,
empty,
floor,
histogram,
interp,
isnan,
linspace,
nan,
ones,
pi,
radians,
sin,
where,
zeros,
)


@njit(cache=True)
@@ -285,14 +301,14 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y):


@njit(cache=True, fastmath=True)
def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
def uniform_resample(x_val, y_val, num_fac=2, fixed_size=-1):
"""
Resample contours to have (nearly) equal spacing.
:param array_like x_val: input x contour coordinates
:param array_like y_val: input y contour coordinates
:param int num_fac: factor to increase lengths of output coordinates
:param int,None fixed_size: if defined, will be used to set sampling
:param int fixed_size: if > -1, will be used to set sampling
"""
nb = x_val.shape[0]
# Get distances
@@ -303,7 +319,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
dist[1:][dist[1:] < 1e-3] = 1e-3
dist = dist.cumsum()
# Get uniform distances
if fixed_size is None:
if fixed_size == -1:
fixed_size = dist.size * num_fac
d_uniform = linspace(0, dist[-1], fixed_size)
x_new = interp(d_uniform, dist, x_val)
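The switch from fixed_size=None to the integer sentinel -1 in uniform_resample (and in its caller mean_on_regular_contour above) is presumably about typing inside @njit functions: branching on `is None` and then assigning an int can leave numba trying to unify NoneType with int64, whereas a plain int default keeps a single type. A minimal standalone sketch of the sentinel pattern, not the library function itself:

import numpy as np
from numba import njit

@njit(cache=True)
def sample_positions(n_pts, num_fac=2, fixed_size=-1):
    # -1 plays the role of the old None default
    if fixed_size == -1:
        fixed_size = n_pts * num_fac
    return np.linspace(0.0, 1.0, fixed_size)

print(sample_positions(5).size)            # 10
print(sample_positions(5, fixed_size=3))   # [0.  0.5 1. ]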
