In [14]:
import numpy as np
import matplotlib.pyplot as plt

In [15]:
from numpy.typing import ArrayLike
import xarray as xr
import cftime
from importlib import reload

In [16]:
import sys
sys.path.append('../')
import constants
sys.path.append(constants.MODULE_DIR)

In [17]:
import xarray_class_accessors as xca

## Setup

In [18]:
def generate_mock_dataset(input_mock_values:ArrayLike) -> xr.DataArray:
    t0 = cftime.datetime(1, 1, 1, 0, 0, 0, 0, calendar='gregorian')

    lat = np.linspace(-25, -10, 10)
    lon = np.linspace(110, 135, 20)

    time = xr.cftime_range(start=t0, periods=len(input_mock_values), freq='1Y')

    mock_values = np.array([np.array([input_mock_values for _ in lat]) for _ in lon])

    mock_da = xr.Dataset({'tas':(('lon', 'lat', 'time'), mock_values)},
                         {'lat':lat, 'lon':lon,'time':time}).to_array(name='tas').squeeze()
    
    return mock_da

In [19]:
number_points = 100
gradient = 2
window = 30
min_periods = window

In [20]:
input_mock_values = np.arange(0, number_points) * gradient
mock_da = generate_mock_dataset(input_mock_values)

In [21]:
window=31

In [22]:
mock_signal_new = mock_da.sn.rolling_signal(window=window)

In [24]:
mock_signal_new.chunk('auto').chunks

((1,), (20,), (10,), (70,))

## Test

# No Import of XCA

In [58]:
def trend_line(x, use = [0][0]):
    if all(~np.isfinite(x)):
        return np.nan

    t = np.arange(len(x))

    # Getting the gradient of a linear interpolation
    idx = np.isfinite(x) #checking where the nans.
    x = x[idx]
    t = t[idx]

    if len(x) < 3:
        return np.nan

    poly = np.polyfit(t,x,1)

    return poly[use]

def _apply_along_helper(arr, axis, func1d,logginglevel='ERROR'):
    axis = axis if isinstance(axis, int) else axis[0]

    # func1ds, axis, arr 
    return np.apply_along_axis(func1d, axis, arr)

In [63]:
signal_da = mock_da.rolling(time = window, min_periods = min_periods, center = True)\
    .reduce(_apply_along_helper, func1d = trend_line) * window

In [258]:
def mult_func(arr, arr2):
    return arr * arr2

def calculate_grad(arr, axis):
    print(f'arr={arr.shape}')
    xs = np.arange(arr.shape[axis])
    
    xs_mult_arr = np.apply_along_axis(mult_func, axis=axis, arr=arr, arr2=xs)
    denominator = np.mean(xs) **2 - np.mean(xs**2)
    
    t1 = np.nanmean(xs) * np.nanmean(arr, axis=axis)
    t2 = np.nanmean(xs_mult_arr, axis=axis)
    numerator = (t1-t2)
    print(f't1={t1.shape}, t2={t2.shape}')
    print(f'den = {denominator}')
    result = numerator/denominator
    print(f'result={result.shape}')
    return result

In [318]:
window=30

In [149]:
time_axis_num = mock_da.get_axis_num('time')
time_axis_num

2

In [327]:
mock_sig_01 = mock_da.rolling(time=window, min_periods=min_periods,center=True).reduce(calculate_grad) * window

arr=(20, 10, 100, 30)
t1=(20, 10, 100), t2=(20, 10, 100)
den = -74.91666666666669
result=(20, 10, 100)


In [328]:
signal_da.isel(lat=0, lon=0)

In [329]:
mock_sig_01.isel(lat=0, lon=0)

# Prototyping

In [150]:
ys = mock_da.values
ys.shape

(20, 10, 100)

In [293]:
ys[...,:30].shape

(20, 10, 30)

In [294]:
ys = ys[...,:30]

In [295]:
xs = np.arange(ys.shape[time_axis_num])
xs.shape

(30,)

In [296]:
xs

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

In [297]:
applied = np.apply_along_axis(mult_func, axis=time_axis_num, arr=ys, arr2=xs)

In [298]:
xs

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

In [299]:
ys[0,0,:]

array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32,
       34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58])

In [300]:
applied[0,0,:]

array([   0,    2,    8,   18,   32,   50,   72,   98,  128,  162,  200,
        242,  288,  338,  392,  450,  512,  578,  648,  722,  800,  882,
        968, 1058, 1152, 1250, 1352, 1458, 1568, 1682])

In [301]:
xs_mult_ys = np.apply_along_axis(mult_func, axis=time_axis_num, arr=ys, arr2=xs)

In [302]:
(np.mean(xs) * np.mean(ys, axis=time_axis_num)).shape

(20, 10)

In [304]:
np.mean(xs) **2 - np.mean(xs**2)

-74.91666666666669

In [305]:
np.polyfit(xs, ys[0,0,:], 1)

array([2.00000000e+00, 2.59453523e-15])

In [306]:
(np.mean(xs) * np.mean(ys, axis=time_axis_num) - np.mean(xs_mult_ys, axis=time_axis_num))/(np.mean(xs) **2 - np.mean(xs**2))

array([[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]])

In [278]:
td = mock_da.rolling(time=window, min_periods=min_periods,center=True).construct('window_dim')

In [290]:
td.mean(dim='window_dim')

In [289]:
mock_da.isel(lat=0, lon=0, time=3)

In [288]:
td.isel(lat=0, lon=0, time=3)

In [320]:
mock_sig_01.isel(lat=1, lon=0)

In [255]:
mock_da.rolling(time=window, min_periods=min_periods,center=True).reduce(calculate_grad,
                                                                         time_axis_num=time_axis_num)

arr=(20, 10, 100, 30)
t1=(20, 10, 100), t2=(20, 10, 100)
den = -74.91666666666669
result=(20, 10, 100)


In [243]:
def add_1(arr, axis, time_axis_num):
    print(arr.shape)
    return arr + 1

In [253]:
window

30

In [245]:
_=mock_da.rolling(time=window, min_periods=min_periods,center=True).reduce(add_1,
                                                                         time_axis_num=time_axis_num)

(20, 10, 100, 30)
