In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Analysis window; pd.date_range includes both endpoints, one timestamp per day.
start_date = "2021-11-10"
end_date = "2021-12-10"
date_list = list(pd.date_range(start=start_date, end=end_date))

In [None]:
# Open the precipitation (ppt) and mean-temperature (tmean) files for the
# analysis window; xr.open_dataset reads values lazily on first access.
precip, tmean = (
    xr.open_dataset(f'precip_data/ppt_{start_date}-{end_date}.nc'),
    xr.open_dataset(f'precip_data/tmean_{start_date}-{end_date}.nc'),
)

In [None]:
# Show the latitude coordinate of the precipitation grid.
precip['lat'].to_numpy()

In [None]:
snow_threshold = 0  # degrees Celsius; precipitation counts as snowfall below this temperature
mf = 5  # melt factor: mm of snow that melts per degree Celsius above snow_threshold


def accumulate_snowpack(precip_da, temp_da, dates, threshold=0, melt_factor=5):
    """Run a daily degree-day snowpack model on the precipitation grid.

    Starting from zero snow, for each day in ``dates``:
      * precipitation is added where the mean temperature is below ``threshold``;
      * ``melt_factor * (T - threshold)`` is removed where it is above ``threshold``;
      * the running total is clipped at zero (there is no negative snow).
    NaN inputs propagate into the affected cells rather than being hidden.

    Parameters
    ----------
    precip_da : xr.DataArray
        Daily precipitation; selecting one ``date`` must leave exactly ``lat``/``lon`` dims.
    temp_da : xr.DataArray
        Daily mean temperature on the same lat/lon grid as ``precip_da``.
    dates : iterable of date-likes
        Days to simulate, in order; each is passed to ``.sel(date=...)``.
    threshold : float
        Snow/melt temperature threshold (degrees Celsius).
    melt_factor : float
        Melt per degree above ``threshold`` (mm per degree Celsius).

    Returns
    -------
    xr.DataArray
        Snowpack at the end of each day, dims ``('date', 'lat', 'lon')``, one entry per
        element of ``dates`` (the first day included), on ``precip_da``'s lat/lon coords.

    Raises
    ------
    ValueError
        If ``dates`` is empty or the two inputs are on different lat/lon grids.
    """
    dates = list(dates)
    if not dates:
        raise ValueError("dates must contain at least one day")

    lat = precip_da['lat'].values
    lon = precip_da['lon'].values
    # The daily arithmetic below is positional (plain numpy), so refuse to combine
    # mismatched grids instead of silently mislabelling cells.
    for name, coord in (('lat', lat), ('lon', lon)):
        other = temp_da[name].values
        if other.shape != coord.shape or not np.allclose(other, coord):
            raise ValueError(f"precip and temperature grids differ along '{name}'")

    snowpack = np.zeros((lat.size, lon.size))
    daily_snowpack = []
    for day in dates:
        # Force (lat, lon) order so the numpy arrays line up with the output coords.
        day_precip = precip_da.sel(date=day).transpose('lat', 'lon').values
        day_temp = temp_da.sel(date=day).transpose('lat', 'lon').values
        snowfall = np.where(day_temp < threshold, day_precip, 0)
        melt = np.where(day_temp > threshold, melt_factor * (day_temp - threshold), 0)
        # np.clip maps negatives to 0 but leaves NaN as NaN.
        snowpack = np.clip(snowpack + snowfall - melt, 0, None)
        daily_snowpack.append(snowpack)

    return xr.DataArray(np.stack(daily_snowpack),
                        dims=['date', 'lat', 'lon'],
                        coords={'date': dates, 'lat': lat, 'lon': lon})


# Keep every simulated day; the previous `dss[1:30]` slice silently dropped the
# first and last days of the window.
final_snow = accumulate_snowpack(precip['ppt'], tmean['tmean'], date_list,
                                 threshold=snow_threshold, melt_factor=mf)

In [None]:
# One map panel per day of modelled snowpack. Units follow the melt factor
# (mm per degree Celsius); this assumes ppt is also in mm — confirm against the data source.
max_panels = 31  # upper bound on the number of daily panels drawn
snow_vmax = 20  # fixed colour-scale maximum (mm) so panels are comparable
snow_facets = final_snow.isel(date=slice(0, max_panels)).plot.imshow(
    col='date', col_wrap=3, vmin=0, vmax=snow_vmax,
    cbar_kwargs={'label': 'Snowpack (mm)'},
)