In [None]:
# standard imports
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray
import matplotlib.pyplot as plt

# load up the saved dataset
import os
import pickle
import sys
from pathlib import Path

# Root of the local spicy-snow checkout. Set the SPICY_SNOW_DIR environment
# variable to point elsewhere instead of editing the hard-coded default.
PROJECT_DIR = Path(os.environ.get('SPICY_SNOW_DIR', '/Users/zachkeskinen/Documents/spicy-snow'))

# Alternative inputs that have been used with this notebook:
#   PROJECT_DIR / 'tests' / 'test_data' / '2_img_ds'
#   PROJECT_DIR / 'data' / 'main_test_proc.pkl'
fp = PROJECT_DIR / 'data' / '10_img_dB.pkl'
# NOTE: pickle.load can execute arbitrary code; only load files you created or trust.
with open(fp, 'rb') as f:
    ds = pickle.load(f)

# Put the spicy_snow source directory on sys.path so its `processing`
# subpackage can be imported directly. Guarded so re-running this cell does
# not append duplicate entries.
SPICY_SNOW_SRC = str(PROJECT_DIR / 'spicy_snow')
if SPICY_SNOW_SRC not in sys.path:
    sys.path.append(SPICY_SNOW_SRC)

from processing.snow_index import calc_prev_snow_index, calc_snow_index, find_repeat_interval, calc_delta_gamma

In [None]:
# fcf is presumably stored as a percentage (0-100); convert it to a 0-1 fraction.
# The 'units' attribute records the conversion so that re-running this cell
# does not divide by 100 a second time.
if ds['fcf'].attrs.get('units') != 'fraction':
    ds['fcf'] = (ds['fcf'] / 100).assign_attrs(units = 'fraction')

In [None]:
# Add delta-gamma to the loaded dataset via the processing.snow_index helper.
ds = calc_delta_gamma(ds)

In [None]:
# standard imports
import numpy as np
import xarray as xr
import rioxarray
import matplotlib.pyplot as plt

# These imports repeat the first cell so this cell can be run without loading
# the pickled dataset. Put the spicy_snow source directory on sys.path so its
# `processing` subpackage is importable; guarded so re-running the cell does
# not append duplicate entries.
import sys
SPICY_SNOW_SRC = '/Users/zachkeskinen/Documents/spicy-snow/spicy_snow'
if SPICY_SNOW_SRC not in sys.path:
    sys.path.append(SPICY_SNOW_SRC)

from processing.snow_index import calc_prev_snow_index, calc_snow_index, find_repeat_interval

# Synthetic test data: 10 x 10 pixels, 3 acquisitions on one relative orbit,
# 3 bands. A fixed seed keeps every run of the notebook reproducible.
rng = np.random.default_rng(0)
backscatter = rng.standard_normal((10, 10, 3, 3))
deltaGamma = rng.standard_normal((10, 10, 3))
times = [np.datetime64(t) for t in ['2020-01-01', '2020-01-07', '2020-01-14']]
x = np.linspace(0, 9, 10)
y = np.linspace(10, 19, 10)
# indexing='ij' makes lon vary along the "x" dim and lat along "y", matching
# the dims they are attached to below (the default 'xy' swaps them).
lon, lat = np.meshgrid(x, y, indexing = 'ij')

test_ds = xr.Dataset(
    data_vars = dict(
        s1 = (["x", "y", "time", "band"], backscatter),
        deltaGamma = (["x", "y", "time"], deltaGamma)
    ),

    coords = dict(
        lon = (["x", "y"], lon),
        lat = (["x", "y"], lat),
        band = ['VV', 'VH', 'inc'],
        time = times,
        relative_orbit = (["time"], [24, 24, 24])))

In [None]:
# Snow index on the 3-acquisition fixture, using the processing.snow_index
# implementation imported above (the in-notebook redefinition comes later).
ds1 = calc_snow_index(test_ds)

In [None]:
# The snow index at the first time step is expected to be NaN for every pixel.
# Compare against the slice size rather than a hard-coded 100 (the 10 x 10 grid)
# so the check survives a change of fixture size.
si_t0 = ds1['snow_index'].isel(time = 0)
assert int(si_t0.isnull().sum()) == si_t0.size, 'expected an all-NaN snow index at the first time step'

In [None]:
# Snow index at the third acquisition (time index 2).
ds1.snow_index.isel(time=2)

In [None]:
# Sum of delta-gamma at time indexes 1 and 2, presumably for comparison with
# the snow index at time index 2 shown above.
ds1.deltaGamma.isel(time=1) + ds1.deltaGamma.isel(time=2)

In [None]:
# Is the snow index at time index 1 equal (within tolerance) to its delta-gamma?
np.allclose(ds1.snow_index.isel(time=1), ds1.deltaGamma.isel(time=1))

In [None]:
# Delta-gamma at the second acquisition (time index 1).
ds1.deltaGamma.isel(time=1)

In [None]:
# Snow index at the second acquisition (time index 1).
ds1.snow_index.isel(time=1)

In [None]:
def calc_prev_snow_index(dataset: xr.Dataset, current_time: np.datetime64, repeat: "pd.Timedelta") -> xr.DataArray:
    """
    Calculate the previous snow index as a time-weighted average of the snow
    indexes within +/- 5 days (6 day repeat) or +/- 11 days (12 day repeat) of
    the previous time step (current_time - repeat).

    SI (i, t_previous) = sum (t_pri - 5/11 days, t_pri + 5/11 days)(SI * weights) / sum(weights)

    with:
        w_k: repeat days minus the whole-day distance from t_previous, e.g. for
        a 6-day repeat (MATLAB: wgts=repmat(win+1-abs([-win:win]),dim,1)):
        [1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1]

    Args:
    dataset: dataset of sentinel-1 images with a 'snow_index' data variable
        and a 'time' dimension (in any dimension order)
    current_time: the current image date (np.datetime64 or anything
        pd.Timestamp accepts)
    repeat: repeat interval of s1 acquisitions for this region (6 or 12 days);
        anything pd.Timedelta accepts

    Returns:
    prev_si: the weighted average of previous snow indexes with the 'time'
        dimension reduced away. NaN where the window contains no acquisitions
        (e.g. at the first time step), since the weights then sum to zero.
    """
    # Imported here (and named only in a string annotation above) so this cell
    # does not depend on an earlier cell having imported pandas.
    import pandas as pd

    repeat = pd.Timedelta(repeat)
    # Centre of the previous window: one repeat interval before the current
    # image. Convert explicitly instead of relying on numpy deferring
    # datetime64 - Timedelta arithmetic to pandas.
    t_prev = pd.Timestamp(current_time) - repeat
    # Window half-width: 5 days for a 6-day repeat, 11 days for a 12-day repeat.
    half_width = repeat - pd.Timedelta('1 day')
    t_oldest, t_youngest = t_prev - half_width, t_prev + half_width
    # slice dataset to get all images in previous period
    prev = dataset.sel(time = slice(t_oldest, t_youngest))
    # Distance of each acquisition from the window centre, rounded to whole
    # days. Rounding (rather than Timedelta.days, which floors) keeps weights
    # symmetric: an image a few seconds before t_prev keeps the full weight
    # instead of being treated as a whole day away.
    offsets = (pd.to_datetime(prev.time.values) - t_prev) / pd.Timedelta('1 day')
    offset_days = np.rint(np.asarray(offsets, dtype = float)).astype(int)
    wts = repeat.days - np.abs(offset_days)
    # Weights as a 'time' DataArray so they broadcast by dimension name. A bare
    # numpy array broadcasts against the LAST axis of snow_index, which is only
    # correct when 'time' happens to be the last dimension.
    weights = xr.DataArray(wts, dims = 'time', coords = {'time': prev['time'].values})
    # NOTE(review): .sum() skips NaN snow indexes in the numerator, but their
    # weights still count in the denominator. Confirm this is intended.
    prev_si = (prev['snow_index'] * weights).sum(dim = 'time') / np.sum(wts)

    return prev_si

def calc_snow_index(dataset: xr.Dataset, inplace: bool = False) -> xr.Dataset:
    """
    Compute the snow index at every time step by adding the current
    delta-gamma to a weighted average of earlier snow indexes.

    SI (i, t) = SI (i, t_previous) + delta-gamma (i, t)

    where SI (i, t_previous) comes from calc_prev_snow_index:
        SI (i, t_previous) = sum (t_pri - 5/11 days, t_pri + 5/11 days)(SI * weights) / sum(weights)

    Args:
    dataset: Xarray Dataset of sentinel images containing a 'deltaGamma'
        variable with a 'time' dimension
    inplace: if True, write 'snow_index' into `dataset` itself; otherwise
        work on a deep copy

    Returns:
    dataset: the copy with a 'snow_index' variable added when inplace is
        False; None when inplace is True (the input is modified instead)
    """
    target = dataset if inplace else dataset.copy(deep=True)

    # Start every pixel/time at zero; the loop overwrites one time slice per step.
    target['snow_index'] = xr.zeros_like(target['deltaGamma'])

    # The repeat interval sets the width of the previous-window average.
    repeat_interval = find_repeat_interval(target)

    # Walk acquisitions in stored order so earlier snow indexes are filled in
    # before they are averaged into later ones (assumes 'time' is sorted ascending).
    for acq_time in target.time.values:
        previous_si = calc_prev_snow_index(target, acq_time, repeat_interval)
        target['snow_index'].loc[dict(time = acq_time)] = previous_si + target['deltaGamma'].sel(time = acq_time)

    return None if inplace else target

In [None]:
# Second synthetic test: two interleaved tracks (relative orbits 24 and 1) one
# day apart, six acquisitions in total. A fixed seed keeps runs reproducible.
rng = np.random.default_rng(1)
backscatter = rng.standard_normal((10, 10, 6, 3))
deltaGamma = rng.standard_normal((10, 10, 6))
times = [np.datetime64(t) for t in ['2020-01-01','2020-01-02', '2020-01-07','2020-01-08', '2020-01-14', '2020-01-15']]
x = np.linspace(0, 9, 10)
y = np.linspace(10, 19, 10)
# indexing='ij' makes lon vary along the "x" dim and lat along "y", matching
# the dims they are attached to below.
lon, lat = np.meshgrid(x, y, indexing = 'ij')

test_ds = xr.Dataset(
    data_vars = dict(
        s1 = (["x", "y", "time", "band"], backscatter),
        deltaGamma = (["x", "y", "time"], deltaGamma),
        snow_index = (["x", "y", "time"], np.zeros_like(deltaGamma)),
    ),

    coords = dict(
        lon = (["x", "y"], lon),
        lat = (["x", "y"], lat),
        band = ['VV', 'VH', 'inc'],
        time = times,
        relative_orbit = (["time"], [24, 1, 24, 1, 24, 1])))

# NOTE: this rebinds `ds`, replacing the pickled dataset loaded in the first
# cell; the checks that follow read this synthetic result from `ds`.
ds = calc_snow_index(test_ds)

In [None]:
# Hand-computed expectation for the snow index at time index 2 (2020-01-07).
# Assuming find_repeat_interval returns 6 days, the window is centred on
# 2020-01-01 and holds t0 (weight 6, NaN snow index, skipped in the sum) and
# t1 = 2020-01-02 (weight 5); both weights count in the denominator.
ds.snow_index.isel(time=1) * 5 / (6 + 5) + ds.deltaGamma.isel(time=2)

In [None]:
# Check the computed snow index at time index 2 against the 6/5-weighted
# expectation from the cell above. The exact element-wise `==` is still kept in
# `v` for inspection, but pass/fail uses a tolerance so it cannot hinge on
# floating-point rounding, and it is actually asserted.
si_t2 = ds['snow_index'].isel(time = 2)
expected_si_t2 = ds['snow_index'].isel(time = 1)*5/(6+5) + ds['deltaGamma'].isel(time = 2)
v = si_t2 == expected_si_t2
assert np.allclose(si_t2, expected_si_t2), 'snow index at time index 2 does not match the 6/5-weighted expectation'