## Flow of Workshop
1. Reading data 
    1. Read LPJ dataset from STAC
    2. Read a subset of MERRA-2 data (a few variables) using credentials
    3. Create different regions of interest
2. Statistical Analysis
    1. Select the region and create time series of total emission in the region
    Note: we will plot line charts for two different years to show interannual variability.
    2. Create a dual folium map for visual comparison
    3. Plot the monthly mean and climatological mean using MERRA-2 data for the same region of interest
    4. Plot time series for the following data and time period to interpret the results:
        a. 2020, 2021 - LPJ monthly emissions
        b. 2020, 2021 - MERRA-2 T2M monthly anomaly
        c. 2020, 2021 - Total precipitation rate

Datasets to be used
1. Monthly LPJ Wetland CH4 Emissions
2. Monthly MERRA-2 Precipitation Rate
Dataset: MERRA2_400.tavgM_2d_flx_Nx
Variable: ‘PRECTOT’
https://disc.gsfc.nasa.gov/datasets/M2TMNXFLX_5.12.4/summary

3. Monthly MERRA-2 Surface Soil Moisture
Dataset: MERRA2_400.tavgM_2d_lnd_Nx
Variable: ‘SFMC’
Long-term mean variable: ‘GWETTOP’
https://disc.gsfc.nasa.gov/datasets/M2TMNXLND_5.12.4/summary

4. Monthly MERRA-2 T2M
Dataset: MERRA2_400.instM_2d_asm_Nx
Variable: ‘T2M’
https://disc.gsfc.nasa.gov/datasets/M2IMNXASM_5.12.4/summary

5. MERRA-2 Long-Term Means
Dataset: MERRA2.tavgC_2d_ltm_Nx
https://disc.gsfc.nasa.gov/datasets/M2TCNXLTM_1/summary

Use case to be discussed:
1. Midwest floods in 2019
2. Pick events of interest

## Defining region of interest

In [None]:
#   Named regions of interest. Each entry is a bounding box stored as
#   [lon_min, lon_max, lat_min, lat_max].
boundaries = dict(
    Global=[-180, 180, -90, 90],
    Louisiana=[-95.9, -87.50, 28.7, 33.5],
    CONUS=[-127.08, -63.87, 23.55, 49.19],   #   contiguous United States
    Florida=[-84.07, -79.14, 24.85, 30.5],
    Northeast=[-74.88, -69.81, 40.48, 42.88],
)

### Reading MERRA-2 Data

In [None]:
#   Per-dataset configuration, keyed by the display name used in plot titles.
#   Fields of each entry:
#     var      - variable name read from the monthly files
#     cmap     - colormap name (presumably for map plots; not used in this section)
#     dir      - directory prefix of the monthly files; a '<year>/*.nc4' pattern is
#                appended directly, so it must end with a path separator
#     nickname - short tag used when building saved-figure file names
#     climdir  - directory prefix of the long-term-mean files ('*.nc4' is appended)
#     climvar  - variable name read from the long-term-mean files
params = {
    'MERRA-2 T2M': dict(
        var='T2M',
        cmap='Spectral_r',
        dir=merra_t2m_dir,
        nickname='merra2_t2m',
        climdir=merra_t2m_clim_dir,
        climvar='T2MMEAN',
    ),
    #   NOTE(review): the workshop notes list 'SFMC' as the monthly soil-moisture
    #   variable, while 'GWETTOP' is used here for both fields - confirm intended.
    'MERRA-2 Surface Soil Moisture': dict(
        var='GWETTOP',
        cmap='Blues',
        dir=merra_soil_moisture_dir,
        nickname='merra2_sm',
        climdir=merra_soil_moisture_clim_dir,
        climvar='GWETTOP',
    ),
    'MERRA-2 Precipitation Rate': dict(
        var='PRECTOT',
        cmap='Spectral_r',
        dir=merra_precip_rate_dir,
        nickname='merra2_pr',
        climdir=merra_precip_rate_clim_dir,
        climvar='PRECTOT',
    ),
}

## Monthly Time Series

In [None]:
def _first_slice_with_nan(variable):
    """Return ``variable[0, :, :]`` as a plain array with fill values set to NaN.

    Handles both masked reads (masked cells become NaN) and unmasked reads
    (cells equal to the variable's ``_FillValue`` become NaN).
    Assumes a floating-point variable, since NaN cannot be stored in integers.
    """
    field = np.ma.filled(variable[0,:,:], np.nan)
    fill_value = getattr(variable, '_FillValue', None)
    if fill_value is not None:
        field[field == fill_value] = np.nan
    return field


def get_merra2_timeseries(year,focus,p,anomaly):
    """Build a monthly time series for one parameter over one region.

    Reads every ``*.nc4`` file under ``params[p]['dir'] + '<year>/'``, reduces
    the first time slice of ``params[p]['var']`` over the ``boundaries[focus]``
    box (sum when 'LPJ' is in ``p``, mean when 'MERRA' is in ``p``), and
    optionally subtracts the matching long-term-mean field.

    Parameters
    ----------
    year : int
        Year to read; used for the sub-directory name and the month dates.
    focus : str
        Key into ``boundaries`` ([lon_min, lon_max, lat_min, lat_max]).
    p : str
        Key into ``params``; must contain 'LPJ' or 'MERRA'.
    anomaly : bool
        If true, return (monthly value - long-term mean) instead of the
        monthly value, using files under ``params[p]['climdir']``.

    Returns
    -------
    dict
        'month_labels' (month names), 'box_totals' (one value per month),
        'month_fields' (stacked 2-D fields with NaN for fill values), all
        sorted by month, plus 'units', 'lat' and 'lon' taken from the last
        file read.

    Raises
    ------
    ValueError
        If ``p`` contains neither 'LPJ' nor 'MERRA'.
    FileNotFoundError
        If no monthly files are found, or (with ``anomaly``) no long-term-mean
        file is found for a month.
    """
    var = params[p]['var']
    lon_min,lon_max,lat_min,lat_max = boundaries[focus]

    #   Sum for emissions, mean for meteorological parameters
    if 'LPJ' in p:
        reduce_box = np.nansum
    elif 'MERRA' in p:
        reduce_box = np.nanmean
    else:
        raise ValueError("Parameter name must contain 'LPJ' or 'MERRA': %r"%(p,))

    pattern = params[p]['dir']+'%s/*.nc4'%(year)
    files = glob.glob(pattern)
    if not files:
        raise FileNotFoundError('No monthly files found matching '+pattern)

    clim_files = []
    if anomaly:
        #   glob returns an empty list (it does not raise) when nothing matches
        clim_files = glob.glob(params[p]['climdir']+'*.nc4')
        if not clim_files:
            raise FileNotFoundError(
                'Climatological mean files (climdir) not found for specified parameter.'
            )

    month_labels = []
    box_totals = []
    month_field = []
    dt = []
    for path in files:
        #   The second-to-last dot-separated token of the path ends in the
        #   two-digit month
        month = int(path.split('.')[-2][-2:])
        stamp = datetime(year,month,1)

        #   Read everything needed, then close the file straight away so that
        #   handles do not accumulate over the loop
        data = nc.Dataset(path)
        try:
            lat = data['lat'][:]
            lon = data['lon'][:]
            variable = data[var]
            units = variable.units
            nowfield = _first_slice_with_nan(variable)
        finally:
            data.close()

        #   Get bounding box (strict inequalities on all four edges)
        lat_values = np.asarray(lat)
        lon_values = np.asarray(lon)
        wlat = np.logical_and(lat_values < lat_max, lat_values > lat_min)
        wlon = np.logical_and(lon_values < lon_max, lon_values > lon_min)
        box = np.ix_(wlat,wlon)
        now_box_total = reduce_box(nowfield[box])

        if anomaly:
            #   Make sure you read the climatology for the right month:
            #   '%y%m' of a 2020 date gives '20' + two-digit month ('2001' for
            #   January); the first path containing that token is used.
            #   NOTE(review): this is a substring match on the whole path -
            #   confirm it cannot match a directory name.
            token = datetime(2020,month,1).strftime('%y%m')
            matches = [c for c in clim_files if token in c]
            if not matches:
                raise FileNotFoundError(
                    'No climatological mean file containing %r found for month %d.'
                    %(token,month)
                )
            climdata = nc.Dataset(matches[0])
            try:
                climfield = _first_slice_with_nan(climdata[params[p]['climvar']])
            finally:
                climdata.close()

            #   Difference current month and long-term mean.
            #   Assumes the climatology is on the same lat/lon grid as the
            #   monthly file - TODO confirm.
            box_totals.append(now_box_total - reduce_box(climfield[box]))
            month_field.append(nowfield - climfield)
        else:
            box_totals.append(now_box_total)
            month_field.append(nowfield)

        dt.append(stamp)
        month_labels.append(stamp.strftime('%B'))

    #   Sort in case months are out of order
    dti = np.argsort(dt)
    month_labels = np.array(month_labels)[dti]
    box_totals = np.array(box_totals)[dti]
    month_field = np.array(month_field)[dti]

    print('mean ',np.nanmean(month_field))
    print('std ',np.nanstd(month_field))

    return {
        'month_labels':month_labels,
        'box_totals':box_totals,
        'month_fields':month_field,
        'units':units,
        'lat':lat,
        'lon':lon
    }


def monthly_timeseries(year,focus,param,anomaly,ylim=(-4e-5,4e-5)):
    """Plot and save the monthly box-aggregated time series for ``param``.

    Parameters
    ----------
    year : int
        Year to plot.
    focus : str
        Region name, passed to the time-series readers and used in the title
        and file name.
    param : list of str or str
        Parameter name(s) to plot; each must contain 'LPJ' or 'MERRA'.
        A single string is treated as a one-element list.
    anomaly : bool
        Forwarded to ``get_merra2_timeseries``; also adds ' Anomaly' to the
        title, draws a zero line and appends '_Anomaly' to the file name.
    ylim : tuple of float or None, optional
        (ymin, ymax) for the y axis. The default keeps the previously
        hard-coded range; pass ``None`` to let matplotlib autoscale.

    Returns
    -------
    dict
        The time-series dict of the last parameter plotted.

    Raises
    ------
    ValueError
        If ``param`` is empty, a name contains neither 'LPJ' nor 'MERRA', or
        a series does not have exactly twelve monthly values.
    """
    if isinstance(param,str):
        #   Iterating a bare string would loop over its characters
        param = [param]
    if len(param) == 0:
        raise ValueError('param must name at least one parameter to plot.')

    cmap = plt.get_cmap('gnuplot')
    colors = cmap(np.linspace(0,1,len(param)))
    months = list(range(0,12))
    last = len(param) - 1

    for i,p in enumerate(param):
        if 'LPJ' in p:
            ts = get_lpj_timeseries(year,focus,p)
        elif 'MERRA' in p:
            ts = get_merra2_timeseries(year,focus,p,anomaly)
        else:
            raise ValueError("Parameter name must contain 'LPJ' or 'MERRA': %r"%(p,))

        #   The x axis is fixed at twelve months, so an incomplete year cannot
        #   be plotted; fail with a clear message instead of a debugger prompt
        if len(ts['box_totals']) != len(months):
            raise ValueError(
                'Double check that you have all twelve months of MERRA-2 data downloaded! '
                '(found %d for %s in %s)'
                %(len(ts['box_totals']),p,params.get(p,{}).get('dir'))
            )

        if i == 0:
            fig = plt.figure(figsize=(6,3))
            ax = fig.add_subplot(111)

        ax.plot(
            months,
            ts['box_totals'],
            linestyle='-',
            linewidth=2,
            color=colors[i],
            markersize=4,
            marker='o',
            label=p
        )

        #   Construct plot title
        title = '%s\n%s Mean Monthly %s'%(focus,year,p)
        if anomaly:
            title+=' Anomaly'
        if 'LPJ' in p:
            title = title.replace('Mean','Total')
        ax.set_title(title)

        ax.set_xticks(months)
        ax.set_xticklabels(ts['month_labels'],rotation=40,ha='right')

        #   Finish and save once, after the last series (compare by index so
        #   duplicate names in ``param`` do not trigger an early save)
        if i == last:
            if i > 0:
                ax.legend(loc='best')
                #   Name the file after the parameters actually plotted
                nickname = '_'.join(params[q]['nickname'] for q in param)
                savename = '%s/box_summed_%s_%s_%s.png'% \
                    (savedir,nickname,year,focus)
            else:
                nickname = params[p]['nickname']
                savename = '%s/%s/%s/box_summed_%s_%s_%s.png'% \
                    (savedir,nickname,focus,nickname,year,focus)
            if anomaly:
                ax.plot(list(range(-1,13)),np.zeros(14),linewidth=0.4)
                savename = savename.replace('.png','_Anomaly.png')
            ax.set_xlim(-1,12)
            if ylim is not None:
                ax.set_ylim(*ylim)     #   manual per parameter
            print('Saving to '+savename)
            #   Save this figure, not whichever figure happens to be number 1
            fig.savefig(savename,dpi=300,bbox_inches='tight')

    return ts