In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
from glob import glob
from datetime import datetime
import sys
sys.path.append('../../../spicy-snow/')

from spicy_snow.processing.snow_index import calc_delta_cross_ratio, calc_delta_gamma, \
    clip_delta_gamma_outlier, calc_snow_index, calc_snow_index_to_snow_depth
from spicy_snow.processing.wet_snow import id_newly_wet_snow, id_wet_negative_si, \
    id_newly_frozen_snow, flag_wet_snow
from spicy_snow.retrieval import retrieval_from_parameters

from dask.distributed import Client


In [None]:
# In-process (threaded) dask cluster; worker scratch space lives under /tmp.
client = Client(processes=False, local_directory='/tmp')
client

In [None]:
files = sorted(glob('spicy_s1_stacks/*.nc'))

# Pick a single stack for interactive exploration of one parameter set.
f = files[1]
ds_name = f.split('stacks/')[-1].split('.')[0]
print(datetime.now(), f' -- starting {ds_name}')

# Open dataset
ds_ = xr.open_dataset(f).load()
dataset = ds_[['s1','deltaVV','ims','fcf','lidar-sd']]

# Time step closest to the lidar flight; argmin returns the first one on ties,
# matching the previous np.where(td == td.min())[0][0].
td = abs(pd.to_datetime(dataset.time) - pd.to_datetime(dataset.attrs['lidar-flight-time']))
closest_ts_idx = int(td.argmin())
closest_ts = dataset.time[closest_ts_idx]

# Same manual comparison-date override that the batch sweep applies for this
# site/date, so exploring it here compares against the same time step.
if 'Frasier_2020-02-11' in ds_name:
    closest_ts = '2020-02-16T13:09:43.000000000'

# Retrieval coefficients A, B, C passed to retrieval_from_parameters.
a = 2.5
b = 0.2
c = 0.55

In [None]:
# Run the retrieval once with the chosen A/B/C and wet-snow flagging thresholds.
ds = retrieval_from_parameters(
    dataset,
    A=a,
    B=b,
    C=c,
    wet_SI_thresh=2,
    freezing_snow_thresh=1,
    wet_snow_thres=-2,
)

In [None]:
# One panel per time step, 10 panels per row. Other retrieval outputs that can be
# faceted the same way: 'wet_flag', 'alt_wet_flag', 'freeze_flag', 'perma_wet', 'snow_index'.
ds['wet_snow'].plot(
    col='time',
    col_wrap=10,
)


In [None]:
# Compare retrieved snow depth at the comparison time step against lidar snow depth,
# with and without excluding pixels flagged as wet snow.
sd_closest = ds['snow_depth'].sel(time=closest_ts)
wet_closest = ds['wet_snow'].sel(time=closest_ts).astype(bool)  # NaN flags become True -> excluded

# Pixels where both lidar and retrieved depth exist (coordinate-aligned via .isnull(),
# unlike pd.isnull which returns a bare ndarray combined positionally).
mask = ~(ds['lidar-sd'].isnull() | sd_closest.isnull())
# ...and additionally not flagged as wet.
mask_wet = mask & ~wet_closest

# RMSE over the compared pixels only: .mean() skips the NaNs introduced by .where(),
# so the denominator is the number of valid pixels rather than the full grid size.
diff_wet = ds['lidar-sd'].where(mask_wet) - sd_closest.where(mask_wet)
rmse_wet = float(np.sqrt((diff_wet ** 2).mean()))
print(f'RMSE with wet snow masked out = {rmse_wet:0.2f}')

# Compare snow depths - no wet snow mask
diff = ds['lidar-sd'].where(mask) - sd_closest.where(mask)
rmse = float(np.sqrt((diff ** 2).mean()))
print(f'Full RMSE = {rmse:0.2f}')

# Share of comparable pixels that survive the wet-snow mask.
valid_frac = float(mask_wet.sum() / mask.sum())
print(f'Frac valid pixels = {valid_frac:0.2f}')

In [None]:
# Left: pixels usable without the wet-snow mask; right: pixels left after it.
f, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
left_panel, right_panel = ax
mask.plot(ax=left_panel)
mask_wet.plot(ax=right_panel)

In [None]:
files = sorted(glob('spicy_s1_stacks/*.nc'))

# to_netcdf does not create parent directories, so make sure the output folder exists.
out_dir = Path('rmse_test_wet_snow')
out_dir.mkdir(exist_ok=True)

# Create parameter space: A/B/C are held fixed, only the wet-snow flag thresholds are swept.
a = 2.5
b = 0.2
c = 0.55

wet_snow_thresh = np.arange(-3, -0.9, 0.1)
freeze_snow_thresh = np.arange(1, 3.1, 0.1)
SI_thresh = [0,-100]

total_count = len(wet_snow_thresh)*len(freeze_snow_thresh)*len(SI_thresh)

# Shared layout of the three result cubes (one value per threshold combination).
grid_dims = ('SI_thresh', 'wet_snow_thresh', 'freeze_snow_thresh')
grid_coords = {
    'SI_thresh': SI_thresh,
    'wet_snow_thresh': wet_snow_thresh,
    'freeze_snow_thresh': freeze_snow_thresh,
}
grid_shape = tuple(len(values) for values in grid_coords.values())

for f in files:
    ds_name = f.split('stacks/')[-1].split('.')[0]
    print(datetime.now(), f' -- starting {ds_name}')

    out_paths = {
        'wet_flag': out_dir / f'{ds_name}_wet_flag.nc',
        'no_flag': out_dir / f'{ds_name}_no_flag.nc',
        'valid_pixels': out_dir / f'{ds_name}_valid_pixels.nc',
    }
    # Skip only when every output exists; a run interrupted between the three
    # writes below must be redone rather than silently left incomplete.
    if all(path.is_file() for path in out_paths.values()):
        print('This file already exists, continuing.')
        continue

    # Open dataset
    ds_ = xr.open_dataset(f).load()
    dataset = ds_[['s1','deltaVV','ims','fcf','lidar-sd']]
    # Time step closest to the lidar flight (first one on ties).
    td = abs(pd.to_datetime(dataset.time) - pd.to_datetime(dataset.attrs['lidar-flight-time']))
    closest_ts_idx = int(td.argmin())
    closest_ts = dataset.time[closest_ts_idx]

    # Manual override of the comparison time step for this site/date.
    if 'Frasier_2020-02-11' in ds_name:
        closest_ts = '2020-02-16T13:09:43.000000000'

    # Initialize result arrays (NaN = not computed)
    rmse_wet_flag = xr.DataArray(np.full(grid_shape, np.nan), coords=grid_coords, dims=grid_dims)
    rmse_no_flag = xr.DataArray(np.full(grid_shape, np.nan), coords=grid_coords, dims=grid_dims)
    valid_pixels = xr.DataArray(np.full(grid_shape, np.nan), coords=grid_coords, dims=grid_dims)

    # Brute-force loop
    for wst in wet_snow_thresh:
        for fst in freeze_snow_thresh:
            for sit in SI_thresh:
                print(f'sit={sit:0.2f}, wst={wst:0.2f}; fst={fst:0.2f}')

                ds = retrieval_from_parameters(dataset,A=a,B=b,C=c,wet_SI_thresh=sit,freezing_snow_thresh=fst,wet_snow_thres=wst)
                sd_closest = ds['snow_depth'].sel(time=closest_ts)

                # Pixels where both lidar and retrieved depth exist (coordinate-aligned) ...
                mask = ~(ds['lidar-sd'].isnull() | sd_closest.isnull())
                # ... and additionally not flagged as wet (NaN flags count as wet via astype(bool)).
                mask_wet = mask & ~ds['wet_snow'].sel(time=closest_ts).astype(bool)

                # RMSE over compared pixels only: .mean() skips the NaNs from .where(), so the
                # denominator is the valid-pixel count, not the full grid size.
                diff_wet = ds['lidar-sd'].where(mask_wet) - sd_closest.where(mask_wet)
                rmse_wet = float(np.sqrt((diff_wet ** 2).mean()))
                print(f'RMSE with wet snow masked out = {rmse_wet:0.2f}')
                rmse_wet_flag.loc[sit,wst,fst] = rmse_wet

                # Compare snow depths - no wet snow mask
                diff = ds['lidar-sd'].where(mask) - sd_closest.where(mask)
                rmse = float(np.sqrt((diff ** 2).mean()))
                print(f'Full RMSE = {rmse:0.2f}')
                rmse_no_flag.loc[sit,wst,fst] = rmse

                # Share of comparable pixels that survive the wet-snow mask.
                valid_frac = float(mask_wet.sum() / mask.sum())
                valid_pixels.loc[sit,wst,fst] = valid_frac
                print(f'Frac valid pixels = {valid_frac:0.2f}')

    # After loop, save results per file
    rmse_wet_flag.to_netcdf(out_paths['wet_flag'])
    rmse_no_flag.to_netcdf(out_paths['no_flag'])
    valid_pixels.to_netcdf(out_paths['valid_pixels'])
    

In [None]:
directory = 'rmse_test_wet_snow'

# Index into the sorted per-site result files.
which_site = 0

# Match the exact suffixes written by the sweep. Loose patterns such as '*no*' or '*wet*'
# also match site names that contain those letters, which would misalign the three lists.
results1 = sorted(glob(f'{directory}/*_wet_flag.nc'))
results2 = sorted(glob(f'{directory}/*_no_flag.nc'))
results3 = sorted(glob(f'{directory}/*_valid_pixels.nc'))
assert len(results1) == len(results2) == len(results3), 'result files are incomplete for some site'

# All three files must belong to the same site.
site_name = Path(results1[which_site]).name[:-len('_wet_flag.nc')]
assert Path(results2[which_site]).name == f'{site_name}_no_flag.nc', results2[which_site]
assert Path(results3[which_site]).name == f'{site_name}_valid_pixels.nc', results3[which_site]

wet_snow = xr.open_dataarray(results1[which_site])
all_snow = xr.open_dataarray(results2[which_site])
frac_valid = xr.open_dataarray(results3[which_site])

# Stack along a new dim: index 0 = wet-masked RMSE, index 1 = unmasked RMSE.
all_rmse = xr.concat([wet_snow,all_snow],'wet_or_all')



In [None]:
# SI_thresh = 0 slice: wet-masked RMSE (left) and dry-pixel fraction (right).
f, ax = plt.subplots(ncols=2)
left_ax, right_ax = ax
wet_snow.sel(SI_thresh=0).plot(ax=left_ax)
frac_valid.sel(SI_thresh=0).plot(ax=right_ax)

In [None]:
# Stacked RMSE cube: wet_or_all index 0 is the wet-masked RMSE, index 1 the unmasked RMSE.
all_rmse

In [None]:
sit = 0
# One facet per wet-snow threshold: RMSE (masked and unmasked) vs. freeze threshold,
# with the dry-pixel fraction on a twin axis and the RMSE-minimising freeze threshold marked.
f=all_rmse.sel(SI_thresh=sit).plot(hue='wet_or_all',col='wet_snow_thresh',add_legend=False)
# Iterate the loaded array's own coordinate (not the sweep cell's global), so this cell
# works on a fresh kernel and always matches the facet order; .flat also handles col_wrap.
for wst, ax in zip(wet_snow['wet_snow_thresh'].values, f.axs.flat):
    rmse_curve = wet_snow.sel(SI_thresh=sit, wet_snow_thresh=wst)
    best_fst = float(rmse_curve.idxmin())
    min_rmse = float(rmse_curve.min())
    dry_percent = 100 * float(frac_valid.sel(SI_thresh=sit, wet_snow_thresh=wst, freeze_snow_thresh=best_fst))

    frac_ax = ax.twinx()
    fv = frac_valid.sel(SI_thresh=sit,wet_snow_thresh=wst).plot(ax=frac_ax,color='green',label='dry pixel fraction')
    frac_ax.set_title('')
    ax.axvline(best_fst, color='black', linestyle='--')
    ax.set_title(f'sit={sit:0.1f}, wst={wst:0.1f}, \n min(RMSE)={min_rmse:0.2f} @ {best_fst:0.2f}dB,\n Dry={dry_percent:0.2f}%')

# Legends on the last facet only.
ax.legend(labels=['wet snow mask','no mask'], title= 'RMSE', loc='lower right')
frac_ax.legend(handles=fv,labels=['Dry pixel fraction'], loc='upper right')
plt.tight_layout()