In [81]:
import numpy as np
from numpy import sinc, pi, sin, cos, arctan2, empty, ma, ceil, errstate, ones, arange, meshgrid
from scipy.special import j1
from cv2 import filter2D
from netCDF4 import Dataset
from datetime import datetime, timedelta
from scipy.signal import bessel

In [None]:
import os
import sys

# `from ../script/generate_segmentation_mask import eddy` is not valid Python:
# a module path cannot contain "../". Put the script directory on sys.path and
# import the module by its plain name instead.
# NOTE(review): assumes the notebook's working directory sits next to
# `script/` (i.e. "../script" resolves to the intended folder) - confirm.
SCRIPT_DIR = os.path.abspath(os.path.join("..", "script"))
if SCRIPT_DIR not in sys.path:
    sys.path.append(SCRIPT_DIR)
from generate_segmentation_mask import eddy

# Root folder shared by the interpolated inputs and the mask outputs.
BASE_DIR = "/albedo/home/ssunar/CNN_eddy_detection/for_paper"
regions_list = ["Gulfstream"]
years = ["2015", "2016", "2017", "2018", "2019", "2020"]

for region_ in regions_list:
    mask_dir = f"{BASE_DIR}/segmentation_masks/{region_}"
    os.makedirs(mask_dir, exist_ok=True)
    for year in years:
        for month in range(1, 13):
            # Two-digit month tag used in both the input and output file names.
            month_tag = str(month).zfill(2)
            data_addr_nn = f"{BASE_DIR}/interpolation/{region_}/interpolation_{year}_001_{month_tag}.nc"
            eddy_instance = eddy(dataset_path=data_addr_nn)
            outfile = f"{mask_dir}/seg_mask_gridded_{year}_001_{month_tag}_new.nc"
            eddy_instance.generate_mask(outfile)

In [58]:
def distance(lon0, lat0, lon1, lat1):
    """
    Haversine distance between (lon0, lat0) and (lon1, lat1).

    Coordinates are given in degrees. Scalars and numpy arrays are both
    accepted, since only numpy element-wise functions are applied to them.

    :param float lon0: longitude of the first point(s)
    :param float lat0: latitude of the first point(s)
    :param float lon1: longitude of the second point(s)
    :param float lat1: latitude of the second point(s)
    :return: distance (in m)
    :rtype: array
    """
    deg_to_rad = pi / 180.0
    # Sines of the half-differences in latitude and longitude.
    half_dlat_sin = sin((lat1 - lat0) * 0.5 * deg_to_rad)
    half_dlon_sin = sin((lon1 - lon0) * 0.5 * deg_to_rad)
    # Haversine term: sin^2(dlon/2) * cos(lat0) * cos(lat1) + sin^2(dlat/2).
    haversine = (
        half_dlon_sin ** 2 * cos(lat0 * deg_to_rad) * cos(lat1 * deg_to_rad)
        + half_dlat_sin ** 2
    )
    # Central angle (2 * atan2) scaled by the sphere radius used here (6370997 m).
    return 6370997.0 * 2 * arctan2(haversine ** 0.5, (1 - haversine) ** 0.5)

In [59]:
def get_step_in_km(lat, wave_length, xstep, ystep):
    """
    Convert the grid steps (in degrees) to kilometres at a given latitude.

    :param float lat: latitude (degrees) at which the x step is evaluated
    :param float wave_length: filter wave length in km, checked against the grid resolution
    :param float xstep: grid step along x, in degrees of longitude
    :param float ystep: grid step along y, in degrees of latitude
    :return: (step_x_km, step_y_km)
    :rtype: tuple
    :raises ValueError: if ``wave_length`` is shorter than twice the coarser grid step
    """
    # Length of one degree of latitude / of longitude at `lat`, converted m -> km.
    step_y_km = ystep * distance(0, 0, 0, 1) / 1000
    step_x_km = xstep * distance(0, lat, 1, lat) / 1000
    # The wave length must span at least two grid steps in the coarser direction.
    min_wave_length = max(step_x_km, step_y_km) * 2
    if wave_length < min_wave_length:
        # ValueError is still caught by callers using `except Exception`.
        raise ValueError(
            f"wave_length ({wave_length} km) is too short for the grid resolution, "
            f"it must be >= {min_wave_length} km"
        )
    return step_x_km, step_y_km

In [60]:
def estimate_kernel_shape(lat, wave_length, order, x_c, y_c):
    """
    Build the grid of normalised distances on which a kernel is evaluated.

    :param float lat: latitude (degrees) of the kernel centre
    :param float wave_length: wave length in km
    :param int order: multiplier applied to the kernel half size
    :param array x_c: x coordinates of the grid, used only for their mean step
    :param array y_c: y coordinates of the grid, used only for their mean step
    :return: (half_x_pt, half_y_pt, dist_norm), where ``dist_norm`` is the
        distance to the kernel centre divided by ``wave_length``; only the
        x >= 0 half is computed
    :rtype: tuple
    """
    # Mean spacing of each coordinate axis.
    lon_step = (x_c[1:] - x_c[:-1]).mean()
    lat_step = (y_c[1:] - y_c[:-1]).mean()
    km_per_x, km_per_y = get_step_in_km(lat, wave_length, lon_step, lat_step)

    # Number of grid points per wave length; the half size is this times `order`.
    half_x_pt = ceil(wave_length / km_per_x).astype(int)
    half_y_pt = ceil(wave_length / km_per_y).astype(int)

    # Original note: x size is not good over 60 degrees.
    y_reach = lat_step * half_y_pt * order
    x_reach = lon_step * half_x_pt * order
    # The extra 1% of a step keeps the upper bound inside arange's half-open range.
    y_axis = arange(lat - y_reach, lat + y_reach + 0.01 * lat_step, lat_step)
    # Only x >= 0 is built here; the other half is obtained later by symmetry.
    x_axis = arange(0, x_reach + 0.01 * lon_step, lon_step)

    y_grid, x_grid = meshgrid(y_axis, x_axis)
    # Distance (m -> km) from the centre (0, lat), in units of wave length.
    dist_norm = distance(0, lat, x_grid, y_grid) / 1000.0 / wave_length
    return half_x_pt, half_y_pt, dist_norm

In [61]:
def finalize_kernel(kernel, order, half_x_pt, half_y_pt):
    """
    Mirror a half kernel along its first axis and trim all-zero borders.

    :param array kernel: half kernel with ``half_x_pt * order + 1`` rows and
        ``half_y_pt * 2 * order + 1`` columns; row 0 is the central row
    :param int order: multiplier applied to the half sizes
    :param int half_x_pt: half size along the first axis, before ``order``
    :param int half_y_pt: half size along the second axis, before ``order``
    :return: full kernel cropped to the bounding box of its non-zero values
    :rtype: array
    """
    n_half_rows = half_x_pt * order
    full = empty((half_x_pt * 2 * order + 1, half_y_pt * 2 * order + 1))
    # Lower part (centre row included) is the half kernel itself ...
    full[n_half_rows:] = kernel
    # ... upper part is the half kernel reversed, without its centre row.
    full[:n_half_rows] = kernel[:0:-1]

    # Keep only the span between the first and last row/column holding a non-zero value.
    nonzero = full != 0
    rows = np.flatnonzero(nonzero.any(axis=1))
    cols = np.flatnonzero(nonzero.any(axis=0))
    return full[rows[0] : rows[-1] + 1, cols[0] : cols[-1] + 1]

In [62]:
def kernel_bessel(lat, wave_length, x_c, y_c, order=1):
    """
    Bessel kernel centred on latitude ``lat``.

    :param float lat: latitude (degrees) of the kernel centre
    :param float wave_length: wave length in km
    :param array x_c: x coordinates of the grid
    :param array y_c: y coordinates of the grid
    :param int order: must be an int; the kernel is zeroed where the normalised
        distance exceeds it
    :return: full kernel, mirrored and trimmed by ``finalize_kernel``
    :rtype: array
    """
    half_x_pt, half_y_pt, dist_norm = estimate_kernel_shape(
        lat, wave_length, order, x_c, y_c
    )
    # Column of the kernel centre within the half kernel (row 0 is the centre row).
    centre_col = half_y_pt * order
    # The centre has dist_norm == 0, giving 0/0 there; silence that warning.
    with errstate(invalid="ignore"):
        window = sinc(dist_norm / order)
        kernel = window * j1(2 * pi * dist_norm) / dist_norm
    # Overwrite the 0/0 centre value: j1(2*pi*d) / d tends to pi as d -> 0.
    kernel[0, centre_col] = pi
    # No weight beyond `order` wave lengths from the centre.
    kernel[dist_norm > order] = 0
    return finalize_kernel(kernel, order, half_x_pt, half_y_pt)

In [63]:
def convolve_filter_with_dynamic_kernel(
    grid, kernel_func, x_c, y_c, lat_max=85, extend=False, **kwargs_func
):
    """
    Filter a grid column by column with a kernel rebuilt for every latitude.

    :param array grid: masked array to filter, indexed ``[x, y]`` (column ``i``
        is paired with ``y_c[i]``); it is copied, not modified
    :param func kernel_func: kernel builder, called as
        ``kernel_func(lat=..., x_c=..., y_c=..., **kwargs_func)``
    :param array x_c: x coordinates, forwarded to ``kernel_func``
    :param array y_c: y coordinates, one per column of ``grid``
    :param float lat_max: columns with ``abs(lat) > lat_max`` are left masked
    :param extend: if falsy, the input mask is re-applied to the result; if
        truthy, it is used as a fraction: a cell is masked when the kernel
        weight over valid data is below ``extend * kernel.sum()``
    :param dict kwargs_func: extra keyword arguments for ``kernel_func``
    :return: filtered values, cast to the dtype of ``grid`` if it differs
    :rtype: array
    """
    # Matrix for result, fully masked until a column is computed
    data = grid.copy()
    data_out = ma.empty(data.shape)
    data_out.mask = np.ones(data_out.shape, dtype=bool)
    nb_lines = y_c.shape[0]
    # Durations of the most recent iterations, used only for the debug ETA
    dt = list()

    debug_active = False

    for i, lat in enumerate(y_c):
        # Skip columns beyond lat_max or without any valid value
        if abs(lat) > lat_max or data[:, i].mask.all():
            data_out.mask[:, i] = True
            continue
        # Get kernel
        kernel = kernel_func(lat=lat, x_c=x_c, y_c=y_c, **kwargs_func)
        # Kernel shape
        k_shape = kernel.shape
        t0 = datetime.now()
        if debug_active and len(dt) > 0:
            # Mean iteration time; the original called `np_mean`, which is not
            # defined in this notebook and raised NameError when debug was on.
            dt_mean = sum(dt, timedelta()) / len(dt) * (nb_lines - i)
            print(
                "Remain ",
                dt_mean,
                "ETA ",
                t0 + dt_mean,
                "current kernel size :",
                k_shape,
                "Step : %d/%d    " % (i, nb_lines),
                end="\r",
            )

        # Half size, k_shape must always be odd
        d_lat = int((k_shape[1] - 1) / 2)
        d_lon = int((k_shape[0] - 1) / 2)
        # Temporary matrix, padded by d_lon rows on each side, to get the exact shape at output
        tmp_matrix = ma.zeros((2 * d_lon + data.shape[0], k_shape[1]))
        tmp_matrix.mask = ones(tmp_matrix.shape, dtype=bool)
        # Slice to apply on input data
        # +1 for upper bound, to take this column into account
        sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat + 1, data.shape[1]))
        # Slice to apply on temporary matrix to store input data
        sl_lat_in = slice(
            d_lat - (i - sl_lat_data.start), d_lat + (sl_lat_data.stop - i)
        )
        # Copy data
        # NOTE: `d_lon:-d_lon` is an empty slice when d_lon == 0, so this relies
        # on the kernel having at least 3 rows.
        tmp_matrix[d_lon:-d_lon, sl_lat_in] = data[:, sl_lat_data]
        # Convolution: masked cells contribute 0 to the weighted sum
        m = ~tmp_matrix.mask
        tmp_matrix[~m] = 0

        demi_x, demi_y = k_shape[0] // 2, k_shape[1] // 2
        # Keep only the central column and drop the padded rows
        values_sum = filter2D(tmp_matrix.data, -1, kernel)[demi_x:-demi_x, demi_y]
        # Same filter on the validity map gives the kernel weight over valid cells
        kernel_sum = filter2D(m.astype(float), -1, kernel)[demi_x:-demi_x, demi_y]
        with errstate(invalid="ignore", divide="ignore"):
            if extend:
                data_out[:, i] = ma.array(
                    values_sum / kernel_sum,
                    mask=kernel_sum < (extend * kernel.sum()),
                )
            else:
                data_out[:, i] = values_sum / kernel_sum
        dt.append(datetime.now() - t0)
        # Cap the timing history
        if len(dt) == 100:
            dt.pop(0)
    if extend:
        out = ma.array(data_out, mask=data_out.mask)
    else:
        # Re-apply the input mask on top of the computed one
        out = ma.array(data_out, mask=data.mask + data_out.mask)
    if debug_active:
        print()
    if out.dtype != data.dtype:
        return out.astype(data.dtype)
    return out

In [77]:
def bessel_high_filter(data, wave_length, y_c, x_c, order=1, lat_max=85, **kwargs):
    """
    Return ``data`` minus its Bessel-kernel filtered version.

    :param array data: grid to filter; the input is not replaced, the
        difference is returned
    :param float wave_length: in km
    :param array y_c: y coordinates of the grid
    :param array x_c: x coordinates of the grid
    :param int order: order to use, if > 1 negative values of the cardinal sinus are present in kernel
    :param float lat_max: absolute latitude, no filtering above
    :param dict kwargs: look at ``convolve_filter_with_dynamic_kernel``
    :return: ``data`` minus the filtered field
    :rtype: array

    .. minigallery:: py_eddy_tracker.RegularGridDataset.bessel_high_filter
    """
    # Field filtered with the Bessel kernel.
    smoothed = convolve_filter_with_dynamic_kernel(
        data,
        kernel_bessel,
        x_c=x_c,
        y_c=y_c,
        lat_max=lat_max,
        wave_length=wave_length,
        order=order,
        **kwargs,
    )
    # Subtracting it leaves the residual that the function returns.
    high_pass = data - smoothed
    return high_pass

In [78]:
# Path of one gridded SSH file; the concatenation below resolves to
# ".../ssh_gridded_1961_001_01_new.nc".
data_addr_nn = '/work/ollie/bpanthi/nn_interpolation_new/ssh_gridded_196'+str(1)+'_001_'+str(1).zfill(2)+'_new.nc'
# NOTE(review): this Dataset handle is never closed in the visible cells.
ds = Dataset(data_addr_nn)


In [79]:
# Coordinate axes of the file opened as `ds`.
lat = ds.variables["LATITUDE"][:]
lon = ds.variables["LONGITUDE"][:]
# First record of "ssh" (index 1-1 == 0).
data = ds.variables["ssh"][1-1, :]
# High-pass filter with wave_length=500 (documented as km by bessel_high_filter).
g = bessel_high_filter(data, wave_length = 500, y_c = lat, x_c = lon)

In [80]:
data

masked_array(
  data=[[-0.7038369178771973, -0.697141706943512, -0.693639874458313,
         ..., --, --, --],
        [-0.7445776462554932, -0.7236122488975525, -0.7205229997634888,
         ..., --, --, --],
        [-0.7601337432861328, -0.7467331290245056, -0.7473640441894531,
         ..., --, --, --],
        ...,
        [-1.6232173442840576, -1.617802619934082, -1.6113624572753906,
         ..., --, --, --],
        [-1.6377304792404175, -1.631127953529358, -1.6258326768875122,
         ..., --, --, --],
        [-1.6477681398391724, -1.6418572664260864, -1.6364002227783203,
         ..., --, --, --]],
  mask=[[False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        ...,
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True]],
  fill_value=9.96920996838

In [67]:
g


masked_array(
  data=[[0.3367607630048932, 0.34194302733633264, 0.34347102154531495,
         ..., --, --, --],
        [0.2995666427341772, 0.31897187663865734, 0.3200445882001648,
         ..., --, --, --],
        [0.2875474720150013, 0.29934650765875714, 0.2966630053230037,
         ..., --, --, --],
        ...,
        [-0.013228726941724833, -0.007120927147916678,
         -5.222077763145094e-05, ..., --, --, --],
        [-0.028218077546626086, -0.02086773168470324,
         -0.014892946996818424, ..., --, --, --],
        [-0.038686122149789615, -0.03197018950595987,
         -0.025780207813964395, ..., --, --, --]],
  mask=[[False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        ...,
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True],
        [False, False, False, ...,  True,  True,  True]],
  fill_val

In [68]:
# NOTE(review): the imports visible in this notebook bind `bessel` and `j1`
# but not the name `scipy`, so this expression presumably raises NameError on
# a fresh kernel. The output recorded below is a masked array, not a function
# repr, so it looks stale - confirm what this cell is meant to display.
scipy.signal.bessel

masked_array(data=[-60.      , -59.916668, -59.833332, -59.75    ,
                   -59.666668, -59.583332, -59.5     , -59.416668,
                   -59.333332, -59.25    , -59.166668, -59.083332,
                   -59.      , -58.916668, -58.833332, -58.75    ,
                   -58.666668, -58.583332, -58.5     , -58.416668,
                   -58.333332, -58.25    , -58.166668, -58.083332,
                   -58.      , -57.916668, -57.833332, -57.75    ,
                   -57.666668, -57.583332, -57.5     , -57.416668,
                   -57.333332, -57.25    , -57.166668, -57.083332,
                   -57.      , -56.916668, -56.833332, -56.75    ,
                   -56.666668, -56.583332, -56.5     , -56.416668,
                   -56.333332, -56.25    , -56.166668, -56.083332,
                   -56.      , -55.916668, -55.833332, -55.75    ,
                   -55.666668, -55.583332, -55.5     , -55.416668,
                   -55.333332, -55.25    , -55.166668, -55.083