# GEOG 497 - Spring 2022 - Cryosphere & Climate Systems
## A6: Remote sensing of ice sheet surface melt

## Part 3: Applying and evaluating your melt detection method at other AWS sites.

Input data: ASCAT backscatter and AWS-derived melt rates (see Part 1 for details).

For this part of the assignment, you will scale up your melt detection method and your chosen `melt_thresh` value to other AWS sites in Antarctica. Does it work equally well elsewhere? And how well does ASCAT detect melt across the broader AWS network?

### Run the following code block first.
This loads Python packages, sets some plotting-related options, and defines where the AWS and ASCAT data files are located. 

In [None]:
# Import python packages

# for file searching
import glob 

# for data reading/analysis
import xarray as xr
import pandas as pd
import numpy as np

#  for geographic projections
from pyproj import Transformer

# for plotting
import matplotlib.pyplot as plt
import matplotlib.style as style
import matplotlib.dates as mdates
import matplotlib.lines as mlines
import matplotlib.transforms as mtransforms
from matplotlib.dates import DateFormatter
import datetime

# seaborn adds some extra visual appeal to our plots
import seaborn as sns

# Handle date time conversions between pandas and matplotlib
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()

# set some universal plot settings here:
# higher figure DPI, a small horizontal axes margin, and seaborn's dark-grid
# theme at a reduced font scale
plt.rcParams["figure.dpi"] = 200
plt.rcParams['axes.xmargin'] = 0.05
sns.set_style('darkgrid')
sns.set_context("notebook", font_scale=0.75)
# IPython/Jupyter magic (only valid inside a notebook, not plain Python)
%config InlineBackend.figure_format = 'retina' # make high res plots for retina 5k displays

# lastly, this specifies where the AWS and ASCAT data are located.
# Both folders are joined to file names by plain string concatenation later
# on, so they must keep their trailing '/'.
# NOTE(review): the ASCAT path looks like a shared cluster directory - update
# it if running elsewhere.
awsFolder = './Data/'
ascatFolder = '/gpfs/group/ljt5282/default/ASCAT_data/netcdfs/annual/'

### Preliminary step 1: read all the ASCAT data files
This reads all of the continent-wide, daily temporal resolution ASCAT netcdf files using `xarray`. Afterward, we will subset these data to only the grid cell nearest to the AWS location of interest.

Here, we'll also read netcdfs containing winter-mean (June, July, August; JJA) backscatter grids that have been pre-calculated.

In [None]:
# Open every yearly ASCAT netcdf in the folder (files matching 'msfa*.nc')
# as a single combined xarray Dataset of daily backscatter.
ascat_ds = xr.open_mfdataset(f"{ascatFolder}msfa*.nc")

# Open the pre-computed winter (JJA) mean backscatter grids, stored together
# in one file inside the 'jja_means' sub-folder.
ascat_jja_ds = xr.open_dataset(f"{ascatFolder}jja_means/msfa-a-Ant2007-2021_jja_means.nc")

## Define an `ascat_melt_detector` function.
Running this code block defines a new Python function that performs all of the operations that were previously spread across individual code blocks in Parts 1 and 2 of this assignment. After defining it, you'll be able to run the full melt detection method by defining only a few variables. 

The input variables are:
1. `aws_name`: string containing a name for the AWS (e.g., 'AWS 14').
2. `aws_fname`: string containing the file name for this AWS located in the Data folder (e.g., 'IMAU_aws18_high-res_meteo.tab')
3. `aws_lat`: latitude of the AWS (e.g., -71.49)
4. `aws_lon`: longitude of the AWS (e.g., -50.03)
5. `melt_thresh`: threshold value in dB below the winter mean at which point to classify melt (e.g., 1)
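6. `showPlot`: True or False — if True, shows a plot of the melt detection at that AWS site.

The function returns two variables: the daily-mean AWS data (limited to 2008–2020) and the ASCAT melt classification at the AWS grid cell (1 = melt, 0 = no melt).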



In [None]:
def _plot_ascat_melt_detection(aws_name, aws_df_daily, ascat_ds_aws, jja_sigma0_site,
                               melt_thresholds, melt_detected_da, years):
    """Plot the ASCAT melt-detection diagnostics for one AWS site.

    Private helper for ``ascat_melt_detector``. On a shared time axis it draws
    the ASCAT backscatter series, the AWS-derived daily melt rate, each year's
    winter (JJA) mean backscatter, each melt year's threshold, and the ASCAT
    melt classification as a shaded area.
    """
    _, ax = plt.subplots(figsize=(15, 7.5))

    # ASCAT backscatter on the primary (left) vertical axis
    line1 = ax.plot(ascat_ds_aws.time, ascat_ds_aws, label='ASCAT sigma0')

    # AWS-derived melt rate on a second vertical axis
    ax2 = ax.twinx()
    line2 = ax2.plot(aws_df_daily['Melt rate [mm w.e.] (surface melt, within dt)'],
                     label='AWS-derived melt rate',
                     color='indianred')

    for year in map(int, years):
        # Winter (JJA) mean: solid bar spanning June 1 - August 31 of that year
        jja_mean = jja_sigma0_site.sel(year=year).values
        ax.plot([datetime.date(year, 6, 1), datetime.date(year, 8, 31)],
                [jja_mean, jja_mean],
                color='black',
                linewidth=3)

        # Melt threshold: dotted bar spanning the melt year it is applied to
        # (July 1 of this year to June 30 of the next)
        threshold = melt_thresholds.sel(year=year).values
        ax.plot([datetime.date(year, 7, 1), datetime.date(year + 1, 6, 30)],
                [threshold, threshold],
                color='dimgray',
                linewidth=2,
                linestyle='dotted')

    # ASCAT melt classification (0/1) as a shaded area on a third axis spanning [0, 1]
    ax3 = ax.twinx()
    melt_fill = ax3.fill_between(melt_detected_da.time.values, melt_detected_da,
                                 label='ASCAT-detected melt',
                                 color='cornflowerblue',
                                 alpha=0.3)

    # Labels, title and per-axis display ranges
    ax.set(ylabel='ASCAT backscatter at AWS site [dB]',
           xlabel='Date',
           title=aws_name)
    ax2.set(ylabel='Daily mean melt rate [mm w.e.]')
    ax2.yaxis.grid(False)
    ax3.yaxis.grid(False)

    ax.set_ylim((-30, 30))
    ax2.set_ylim((-0.5, 0.5))
    ax3.set_ylim((0, 1))

    # x-range: from the first winter-mean bar to the end of the last melt-year threshold
    ax.set_xlim(datetime.date(int(years[0]), 6, 1),
                datetime.date(int(years[-1]) + 1, 6, 30))

    # Legend for the three data series (the winter-mean/threshold bars are unlabeled)
    handles = line1 + line2 + [melt_fill]
    ax.legend(handles, [h.get_label() for h in handles], loc='lower left')

    plt.show()


def ascat_melt_detector(aws_name, aws_fname, aws_lat, aws_lon, melt_thresh, showPlot):
    """Classify daily melt from ASCAT backscatter at one AWS site.

    Each melt year runs from July 1 to June 30 and must overlap the AWS
    record. Within it, a day is classified as melt (1) when the ASCAT sigma0 of the
    grid cell nearest the AWS is at or below that year's winter (JJA) mean
    minus ``melt_thresh`` dB, and as no melt (0) otherwise. Missing sigma0
    values compare False and are therefore classified as 0.

    Uses the notebook-level ``awsFolder``, ``ascat_ds`` and ``ascat_jja_ds``.

    Parameters
    ----------
    aws_name : str
        Label for the AWS, used as the plot title.
    aws_fname : str
        AWS data file name inside ``awsFolder``.
    aws_lat, aws_lon : float
        AWS latitude and longitude in degrees.
    melt_thresh : float
        Melt threshold in dB below the winter-mean backscatter.
    showPlot : bool
        If true, display a diagnostic plot for this site.

    Returns
    -------
    aws_df_daily : pandas.DataFrame
        Daily means of the AWS data, limited to 2008-01-01 .. 2020-12-31.
    melt_detected_da : xarray.DataArray
        Daily 0/1 ASCAT melt classification at the AWS grid cell.

    Raises
    ------
    ValueError
        If the AWS record has no data in 2008-2020 (e.g. AWS4, 1997-2002).
    KeyError
        If a needed melt year has no winter mean in ``ascat_jja_ds``.
    """
    # Read the tab-separated AWS file; the first 40 lines are skipped as header.
    # NOTE(review): assumes every AWS file has the same 40-line header - confirm per file.
    aws_df = pd.read_csv(awsFolder + aws_fname,
                         sep='\t',
                         skiprows=40,
                         parse_dates=['Date/Time'],
                         index_col=['Date/Time'])

    # Daily means, limited to the ASCAT era used here (2008-2020)
    aws_df_daily = aws_df.resample('D').mean().loc['2008-01-01':'2020-12-31']
    if aws_df_daily.empty:
        raise ValueError(
            f"{aws_name}: no AWS data between 2008-01-01 and 2020-12-31, "
            "so it cannot be compared with ASCAT."
        )

    # Convert lon/lat to the projected x/y coordinates of the ASCAT grids (EPSG:3976).
    # With EPSG:4326's authority axis order, pyproj expects (lat, lon) here.
    transformer = Transformer.from_crs("epsg:4326", "epsg:3976")
    aws_x, aws_y = transformer.transform(aws_lat, aws_lon)

    # Subset the daily ASCAT data and the winter means to the grid cell nearest the AWS
    ascat_ds_aws = ascat_ds['sigma0'].sel(x=aws_x, y=aws_y, method="nearest")
    jja_sigma0_site = ascat_jja_ds['sigma0'].sel(x=aws_x, y=aws_y, method="nearest")

    # Melt-year start years: from one year before the first AWS year (so data in
    # Jan-Jun at the start of the record are covered) through the last AWS year.
    years = np.arange(aws_df_daily.index.min().year - 1,
                      aws_df_daily.index.max().year + 1)

    # Per-year melt threshold: winter (JJA) mean minus melt_thresh dB
    melt_thresholds = jja_sigma0_site - melt_thresh

    # Classify each melt year (Jul 1 of Y to Jun 30 of Y+1) against its own threshold
    melt_detected_list = []
    for year in years:
        melt_year_sigma0 = ascat_ds_aws.sel(time=slice(f'{year}-07-01', f'{year + 1}-06-30'))
        year_threshold = melt_thresholds.sel(year=year).values
        melt_detected_list.append(xr.where(melt_year_sigma0 <= year_threshold, 1, 0))
    melt_detected_da = xr.concat(melt_detected_list, dim='time')

    if showPlot:
        _plot_ascat_melt_detection(aws_name, aws_df_daily, ascat_ds_aws, jja_sigma0_site,
                                   melt_thresholds, melt_detected_da, years)

    # return some variables for future use
    return aws_df_daily, melt_detected_da

### Run the `ascat_melt_detector` function
Provide inputs to the function and run it as in the following code block.

In [None]:
# First, set the inputs for one AWS site (set melt_thresh to the value you chose earlier!)
aws_name = 'AWS14'
aws_fname = 'IMAU_aws14_high-res_meteo.tab'
aws_lat, aws_lon = -67.021, -61.5
melt_thresh = 2 # <- ENTER NUMBER HERE! (before #) 
showPlot = True

# Then run the melt detector; it returns (daily AWS data, ASCAT melt classification)
outputs = ascat_melt_detector(aws_name, aws_fname, aws_lat, aws_lon,
                              melt_thresh=melt_thresh, showPlot=showPlot)

## Part 3b: Apply and evaluate melt detection method at all 10 AWS sites

To do this, we'll first define a second function that we can apply at each site. For every full melt year (July 1 – June 30) at the site, this function counts the melt days detected by ASCAT and the melt days derived from the in situ observations and surface energy balance (SEB) model, and returns both sets of counts. A later code block combines the counts from all sites into a scatter plot showing how well the ASCAT-detected melt compares to the in situ SEB-derived melt. 

The input variables are:
1. `aws_name`: string containing a name for the AWS (e.g., 'AWS 14').
2. `aws_fname`: string containing the file name for this AWS located in the Data folder (e.g., 'IMAU_aws18_high-res_meteo.tab')
3. `aws_latlons`: a single (lat, lon) pair for this AWS (e.g., (-71.49, -50.13)); when looping over all sites, pass one item of the list of pairs
4. `ascat_melt_tresh` (spelled this way in the function definition): threshold value in dB below the winter mean at which point to classify melt (e.g., 1)
5. `aws_melt_thresh`: threshold value in mm w.e. of SEB-derived melt above which to consider a day as experiencing melt (probably best to keep at 0)

While this function could be applied at an individual AWS site, the real power will be feeding the function all available AWS data and seeing the results. See code block after the function for an idea of how this will work.

In [None]:
def _full_melt_years(daily_index):
    """Return the start years of the full melt years (Jul 1 - Jun 30) in a DatetimeIndex.

    The first melt year is the one beginning on July 1 of the first data year
    if the record has started by then, otherwise the following one. The last
    melt year is counted as complete if the record reaches May of its closing
    year, since little melt occurs during winter.
    """
    first_day = daily_index.min()
    last_day = daily_index.max()

    if first_day <= pd.Timestamp(first_day.year, 7, 1):
        sum_start_year = first_day.year
    else:
        sum_start_year = first_day.year + 1

    # np.arange excludes sum_end_year, so the last melt year counted is the
    # one ending on June 30 of sum_end_year.
    sum_end_year = last_day.year - 1 if last_day.month <= 4 else last_day.year
    return np.arange(sum_start_year, sum_end_year)


def ascat_aws_cal_val_plotter(aws_name, aws_fname, aws_latlons, ascat_melt_tresh, aws_melt_thresh,
                              show_plot=None):
    """Count annual melt days from ASCAT and from the AWS/SEB record at one site.

    Runs ``ascat_melt_detector`` for the site. Then, for every full melt year
    (Jul 1 - Jun 30) covered by the AWS record, it counts two things:
    - AWS melt days: days whose SEB melt rate exceeds ``aws_melt_thresh``.
    - ASCAT melt days: days classified as melt by ASCAT.
    This function does not plot anything itself. A per-site plot is shown by
    ``ascat_melt_detector`` when plotting is enabled.

    Parameters
    ----------
    aws_name, aws_fname : str
        Passed through to ``ascat_melt_detector``.
    aws_latlons : sequence of float
        A single ``(lat, lon)`` pair for this AWS.
    ascat_melt_tresh : float
        ASCAT melt threshold in dB below the winter mean.
    aws_melt_thresh : float
        SEB melt-rate threshold in mm w.e.; days above it count as melt days.
    show_plot : bool or None, optional
        Whether ``ascat_melt_detector`` shows its plot. If None (default), the
        notebook-level ``showPlot`` variable is used.

    Returns
    -------
    (list, list)
        AWS and ASCAT melt-day counts, one entry per full melt year, in the same order.
    """
    aws_lat = aws_latlons[0]
    aws_lon = aws_latlons[1]

    # Fall back to the notebook-level showPlot setting when no explicit value is given
    if show_plot is None:
        show_plot = showPlot

    aws_df_daily, ascat_melt_detected_da = ascat_melt_detector(
        aws_name, aws_fname, aws_lat, aws_lon, ascat_melt_tresh, show_plot)

    # AWS: binary melt days (NaN melt rates compare False and count as no melt).
    # NOTE(review): AWS gap days therefore count as "no melt" while ASCAT may
    # still flag them, which can bias the comparison - consider masking gaps.
    aws_melt_days = aws_df_daily['Melt rate [mm w.e.] (surface melt, within dt)'] > aws_melt_thresh

    # Sum both melt records over the same full melt years
    aws_annual_meltDay_sums = []
    ascat_annual_meltDay_sums = []
    for year in _full_melt_years(aws_df_daily.index):
        start, end = f'{year}-07-01', f'{year + 1}-06-30'
        aws_annual_meltDay_sums.append(aws_melt_days.loc[start:end].sum())
        ascat_annual_meltDay_sums.append(
            int(ascat_melt_detected_da.sel(time=slice(start, end)).sum().values))

    return aws_annual_meltDay_sums, ascat_annual_meltDay_sums
    

### Now run the melt detection calibration/validation at all valid AWS sites
To run this, first define lists of all AWS names, files, and lat/lon pairs. 

Start with the two sites listed and then expand the code to include all AWS sites except AWS4. 

#### Important:

1.  **For this to work properly, you must keep the AWS sites in order in each variable. For example, AWS14 is the first value in `aws_names`, so its filename needs to be first in `aws_fnames`, and its lat/lon needs to be first in `aws_latlons`!**

2. Do not use AWS4: its record (1997–2002) falls entirely outside the period covered by the ASCAT data, so there is nothing to compare. 

In [None]:
# AWS sites to process: one row per site as (name, file name in the Data folder, (lat, lon)).
# Writing each site as a single row keeps its name, file and location together,
# so the three lists built from it (aws_names, aws_fnames, aws_latlons) always
# stay in the same order. Add more sites as new rows (all except AWS4).
aws_names, aws_fnames, aws_latlons = map(list, zip(*[
    ('AWS14', 'IMAU_aws14_high-res_meteo.tab', (-67.021, -61.5)),
    ('AWS15', 'IMAU_aws15_high-res_meteo.tab', (-67.57, -62.15)),
]))

# Melt thresholds: ASCAT in dB below the winter mean, AWS in mm w.e. of SEB melt
ascat_melt_thresh = 2 # <- ENTER NUMBER HERE! (before #) 
aws_melt_thresh = 0 

# Show the per-site diagnostic plots? (slower, but useful for evaluating your method)
showPlot = False

### When the above lists and variables are defined, we can go on with the number crunching...
This will iterate through each AWS site and then output a scatter plot and some basic statistics for goodness of fit. 

In [None]:
# Run the melt detection and melt-day counting at every configured AWS site
nAWS = len(aws_names)
if not (len(aws_fnames) == len(aws_latlons) == nAWS):
    raise ValueError(
        f"aws_names ({nAWS}), aws_fnames ({len(aws_fnames)}) and aws_latlons "
        f"({len(aws_latlons)}) must have the same length, one entry per site."
    )

full_aws_melt_days = []
full_ascat_melt_days = []

# For each site, append its paired (AWS, ASCAT) annual melt-day counts
for site_number, (name, fname, latlon) in enumerate(zip(aws_names, aws_fnames, aws_latlons), start=1):
    print(f'Processing site {site_number} of {nAWS}: {name}')
    aws_sums, ascat_sums = ascat_aws_cal_val_plotter(name, fname, latlon,
                                                     ascat_melt_thresh, aws_melt_thresh)
    full_aws_melt_days.extend(aws_sums)
    full_ascat_melt_days.extend(ascat_sums)

# Pool all site-years: x = SEB (AWS) melt days, y = ASCAT melt days
x = np.asarray(full_aws_melt_days, dtype=float)
y = np.asarray(full_ascat_melt_days, dtype=float)
if x.size < 2:
    raise ValueError('At least two paired melt years are needed for the regression '
                     'and r-squared; add more AWS sites.')

# Scatter plot of the paired annual melt durations
_, ax = plt.subplots(figsize=(4, 4))
ax.scatter(x, y)

# 1:1 line drawn corner to corner in axes coordinates; it equals y = x only
# because the x and y limits are set identical below
ax.add_line(mlines.Line2D([0, 1], [0, 1], transform=ax.transAxes,
                          color='dimgray', linestyle=':', label='1:1 line'))

# Least-squares linear fit, drawn between the smallest and largest x
slope, intercept = np.polyfit(x, y, 1)
x_fit = np.array([x.min(), x.max()])
ax.plot(x_fit, slope * x_fit + intercept, label="regression line")

# Identical axis limits with a small margin
maxval = max(x.max(), y.max())
ax.set_xlim(-5, maxval + 20)
ax.set_ylim(-5, maxval + 20)

ax.set(ylabel='ASCAT annual melt duration [days]',
       xlabel='SEB annual melt duration [days]')
ax.legend()
plt.show()

# Squared Pearson correlation between SEB and ASCAT annual melt durations
r2 = np.corrcoef(x, y)[0, 1] ** 2
print('r-squared = ' + str(r2))

# Mean bias: ASCAT minus SEB annual melt days (positive = ASCAT detects more melt days)
bias = np.mean(y) - np.mean(x)
print('bias = ' + str(bias))