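# Analysis notebook for the moremelt piControl run
Loads the selected CESM2 piControl datasets (LENS2 baseline, lessmelt, moremelt), restricts them to the chosen spatial domain and 200-year window, averages them in time, and saves the processed fields to NetCDF for the plotting notebooks.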
## Set up
### Packages

In [1]:
import numpy as np
import xarray as xr
import xesmf as xe
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import pandas as pd
import scipy
from scipy import stats, interpolate
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.mathtext import _mathtext as mathtext
import matplotlib.ticker as mticker
from matplotlib import gridspec, animation
import matplotlib.path as mpath
import matplotlib.colors as colors
import matplotlib.dates as mdates
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import warnings
warnings.simplefilter('ignore', UserWarning)
warnings.filterwarnings('ignore')
import datetime as dt
from datetime import timedelta
from cmcrameri import cm
from Processing_functions import FixLongitude, Wilks_pcrit
import jinja2
import cftime
import dask
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from functools import partial
from collections import defaultdict
import os

In [2]:
## Plot types to make - CHANGE
# Each entry: [make this plot?, processing type]
# Processing type:
# 0: Weighted spatial mean or sea ice area
# 1: Volume
# 2: Leave alone (doing spatial map or sea ice concentration)
# NOTE: CreateMasterDS applies the processing of the FIRST enabled plot whose
# type is 0 or 1 to every dataset, so enabling 'map' together with a line plot
# in one run feeds already-reduced data to the map cells.
plots = {
    'map': [True, 2],
    'ts': [False, 0],
    'sie': [False, 0],
    'siv': [False, 1]
}

## Categorical plot type - DO NOT CHANGE
# Derived flags: 'spatial' -> map cells run, 'line' -> line-plot cells run
plot_types = {
    'spatial': plots['map'][0],
    'line': plots['ts'][0] or plots['siv'][0] or plots['sie'][0]
}

# Spatial domain - CHANGE s_domain & t_domain only
s_domain = 2 # 0: Global, 1: Arctic, 2: Antarctic
# NOTE(review): a_domain is not read anywhere else in this notebook
a_domain = plot_types['spatial'] # True: 50-90, False: 70-90
t_domain = 911 # start year of the 200-year analysis window (see slice_time)

## Time averaging type - CHANGE
# Map cells (SpatZonAvg) handle 0 and 2; line-plot cells handle 0 and 1
time_avg = 0   # 0: Monthly, 1: Yearly, 2: Seasonal, 3: All data, 4: Timeseries

# Variables - CHANGE
comp = 'ice' # model component: 'atm' or 'ice' (key into var_list / h_ext)
var_ind = 1  # index into var_list[comp], e.g. 1 -> 'hi' when comp == 'ice'

# DO NOT CHANGE
var_list = {'atm': ['TREFHT'],
            'ice': ['aice', 'hi', 'hs']}
var = var_list[comp][var_ind]

In [3]:
## Test names
# 0: CMIP6 baseline
# 1: lessmelt
# 2: moremelt
# attribute structure: [use dataset, dataset type]
# "dataset type" is the set_type passed to LoadData, which uses it to pick the
# file-name pattern and directory
# CHANGE ONLY - use dataset
ds_names = {
    'LENS2 piControl': [True, 0],
    'lessmelt piControl': [False, 1],
    'moremelt_rsnw0': [True, 2]
}
# Case-name prefix for the experiment runs (dataset type 2)
vercompres = 'b.e21.B1850cmip6.f09_g17.'
# Case-name prefix for the CMIP6 piControl and its lessmelt branch (types 0, 1)
cesm2piC = 'b.e21.B1850.f09_g17.'

## Filepaths - DO NOT CHANGE
path_to_work = '/glade/work/glydia/'
path_to_pidata = path_to_work+'processed_CESM2_LENS_data/moremelt_comparison/'
path_to_lmdata = path_to_work+'processed_CESM2_lessmelt_data/'
path_to_expdata = path_to_work+'Arctic_controls_processed_data/'
path_to_plotdata  = path_to_expdata+'climo_plotting_data/'

# Extensions - DO NOT CHANGE
# History-file stream infix per component
h_ext = {'atm': '.h0.',
       'ice': '.h.'}
# Glob wildcard for the date-range part of the file name
yr_extn = ".*."
# Per-variable flags, same order as var_list. Presumably True = variable has
# vertical levels and is stored as zarr rather than netCDF — confirm.
vert_lev = {'atm': [False],
            'ice': [False,False,False]}
# True -> '.nc' files, False -> '.zarr' stores (see file_ext)
file_bool = not vert_lev[comp][var_ind]
file_ext = {True: 'nc', False: 'zarr'}

In [4]:
########################## DO NOT CHANGE ANYTHING BELOW THIS LINE #############################

In [5]:
%%time
    
## Labels used in output file names and dimension names
# time_avg index -> time-averaging label (also the time dim name after averaging)
time_str_list = dict(enumerate(('month', 'year', 'season', 'all', 'timeseries')))
time_outstr = time_str_list[time_avg]

# s_domain index -> spatial-domain label; start year as a zero-padded string
sd_str_list = dict(enumerate(('Global', 'Arctic', 'Antarctic')))
sd_str = sd_str_list[s_domain]
td_str = str(t_domain).zfill(4)

CPU times: user 5 µs, sys: 1e+03 ns, total: 6 µs
Wall time: 8.58 µs


### Cluster

In [6]:
# Start a PBS-backed Dask cluster on the Casper queue. Each PBS job requests
# 1 core, 50 GiB of memory and a 12-hour walltime.
# NOTE(review): the job-name prefix 'PiC_UVnudge_process_' looks copied from
# another notebook; presumably it only affects the PBS job name — confirm/rename.
cluster = PBSCluster(cores    = 1,
                     memory   = '50GiB',
                     queue    = 'casper',
                     walltime = '12:00:00',
                     account  = 'UCUB0155',
                     name='PiC_UVnudge_process_'+var)
# Scale to 36 workers (4*9), i.e. up to 36 x 50 GiB.
# NOTE(review): the recorded run loads two datasets in about a minute —
# confirm this many workers is actually needed.
cluster.scale(4*9)
client = Client(cluster)

In [7]:
# Display the Dask client summary (dashboard link, workers, memory).
# The saved output shows 0 workers, presumably because the PBS jobs had not
# started yet when this cell ran.
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/glydia/Arctic_breakdown/proxy/42753/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/glydia/Arctic_breakdown/proxy/42753/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.98:46005,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/glydia/Arctic_breakdown/proxy/42753/status,Total threads: 0
Started: Just now,Total memory: 0 B


### Custom functions

#### LoadData

In [8]:
def LoadData(casename, set_type, varname):
    """Open all history files of one dataset with ``xr.open_mfdataset``.

    Parameters
    ----------
    casename : str
        Dataset label from ``ds_names``. Only used to build the directory and
        file name of experiment cases (``set_type`` other than 0 or 1).
    set_type : int
        0 = CMIP6 piControl baseline (under ``path_to_pidata``),
        1 = lessmelt branch (under ``path_to_lmdata``),
        anything else = experiment case under ``path_to_expdata``.
    varname : str
        History-file variable name, e.g. 'hi'.

    Returns
    -------
    xarray.Dataset
        All files matching the glob pattern, combined.
    """
    # Shared tail of every pattern: <stream infix><var><year wildcard><nc|zarr>
    suffix = h_ext[comp] + varname + yr_extn + file_ext[file_bool]

    if set_type == 0:
        pattern = path_to_pidata + cesm2piC + 'CMIP6-piControl.001' + suffix
    elif set_type == 1:
        pattern = path_to_lmdata + cesm2piC + 'CMIP6-piControl.001_branch2' + suffix
    else:
        case_dir = path_to_expdata + 'processed_' + casename + '_data/'
        pattern = case_dir + vercompres + casename + suffix

    return xr.open_mfdataset(pattern)

In [9]:
# Set spatial domain slicing
# CICE grid: rows (nj) kept for each s_domain. 0 (Global) maps to an empty
# selection, so the full grid is kept instead of raising a KeyError.
slice_icemod = {0: dict(), 1: dict(nj=slice(250,385)), 2: dict(nj=slice(0,250))}
# ATM grid: latitude band used when s_domain == 1 (Arctic)
slice_atmwei = dict(lat=slice(70,90))
# Analysis window: N_YEARS model years starting on 1 January of t_domain
N_YEARS = 200
slice_time = dict(time=slice(td_str+'-01-01',str(t_domain+N_YEARS-1).zfill(4)+'-12-31'))

#### CreateMasterDS

In [10]:
def _select_processing(plot_cfg, component):
    """Return the reduction code for the first enabled plot that needs one.

    Codes: 'S' = sea ice area/extent (ice), 'W' = weighted spatial mean
    (non-ice), 'V' = sea ice volume, None = no reduction (spatial maps only).
    Plots with processing type 2 never set a code and are skipped.
    """
    for attrs in plot_cfg.values():
        if not attrs[0]:
            continue
        if attrs[1] == 0:
            return 'S' if component == 'ice' else 'W'
        if attrs[1] == 1:
            return 'V'
    return None


def _slice_domain(da):
    """Restrict ``da`` to the spatial domain selected by the global ``s_domain``."""
    if comp == 'ice':
        # Domains without an entry (e.g. 0 = Global) keep the full CICE grid
        # instead of raising a KeyError.
        indexers = slice_icemod.get(s_domain)
        return da if indexers is None else da.loc[indexers]
    if s_domain == 1:
        return da.loc[slice_atmwei]
    # NOTE(review): for the atmosphere only the Arctic has a latitude slice;
    # s_domain == 2 (Antarctic) currently returns global data — confirm.
    return da


def CreateMasterDS(varname):
    """Load, slice and optionally reduce every enabled dataset, then merge them.

    Parameters
    ----------
    varname : str
        Variable to read from each dataset's history files (e.g. 'hi').

    Returns
    -------
    xarray.Dataset
        One data variable per enabled entry of ``ds_names``, named after that
        entry. The reduction applied (none / sea ice area / volume / weighted
        mean) is set by the first enabled plot in ``plots`` whose processing
        type is 0 or 1.
    """
    ds_load_list = []

    funcstr = _select_processing(plots, comp)
    print('Figure out processing calculations to do')

    for dsname, attrs in ds_names.items():
        if not attrs[0]:
            continue
        print(dsname)

        # Keep only the requested variable, named after the dataset so the
        # merged Dataset holds one variable per case.
        ds_load = LoadData(dsname, attrs[1], varname)[varname].rename(dsname)
        print('  Initial data loading complete')

        ds_load = _slice_domain(ds_load)
        print('  Sliced data')

        if funcstr is None:
            print('  No proceesing, spatial')
        elif funcstr == 'S':
            print('  Calculating sea ice area')
            ds_load = CalcSIA(ds_load, attrs[1])
        elif funcstr == 'V':
            print('  Calculating sea ice volume')
            ds_load = CalcSIV(ds_load)
        elif funcstr == 'W':
            print('  Calculating weighted average')
            ds_load = CalcWeightedMean(ds_load)

        # Arithmetic in the reductions drops the name; restore it before merging.
        ds_load = ds_load.rename(dsname)

        print('  Processing on dataset complete')
        ds_load_list.append(ds_load)

    ds_proc = xr.merge(ds_load_list)
    print('All datasets merged')

    return ds_proc

#### CalcWeightedMean

In [11]:
def CalcWeightedMean(ds):
    """Cosine-latitude weighted spatial mean over the 'lon' and 'lat' dims.

    Parameters
    ----------
    ds : xarray.DataArray
        Field with dimensions named 'lon' and 'lat' and a 'lat' coordinate
        (assumed to be in degrees, since it goes through deg2rad).

    Returns
    -------
    xarray.DataArray
        The weighted mean, evaluated eagerly so the dask graph runs once.
    """
    avg_dim = ('lon', 'lat')

    # On a regular lat/lon grid, cell area scales with cos(latitude).
    weights = np.cos(np.deg2rad(ds.lat))

    # .compute() returns a NEW object and does not modify in place. The old code
    # called it without keeping the result, so the graph was evaluated, thrown
    # away, and evaluated again later. Return the computed value instead.
    return ds.weighted(weights).mean(avg_dim, skipna=True).compute()

#### Regrid

In [12]:
def Regrid(ds_timeavg, regridder):
    """Regrid every data variable from the CICE grid onto the ATM lat/lon grid.

    For each variable in ``ds_timeavg`` the regridder output has its 'x'/'y'
    dims renamed to 'lon'/'lat'. The module-level ``lons``/``lats`` arrays are
    attached as coordinates (no cyclic point; that is added later). Every NaN,
    presumably cells with no source value such as land, is replaced by 1e-6.
    This fill applies to all variables, not only sic/sit.

    Returns an xarray.Dataset with one variable per input name.
    """
    print('Regridding CICE grid -> ATM grid...')

    fill_value = 0.000001

    def _regrid_one(name, field):
        # Regrid one variable and give it ATM-grid coordinate names and values.
        print('Regridding '+name)
        out = regridder(field).rename({'x':'lon', 'y':'lat'})
        out = out.assign_coords({'lon':lons, 'lat': lats}).fillna(fill_value)
        return out.rename(name)

    return xr.merge([_regrid_one(name, field) for name, field in ds_timeavg.items()])

#### CalcSIA

In [13]:
def CalcSIA(ds, set_type, threshold=0.15):
    """Sea ice extent: total area of cells whose concentration exceeds ``threshold``.

    Parameters
    ----------
    ds : xarray.DataArray
        Sea ice concentration as a fraction, on the CICE grid, already sliced
        to the domain of interest by the caller. Must carry a 'tarea'
        coordinate (cell area, assumed m^2 — TODO confirm) and 'ni'/'nj' dims.
    set_type : int
        Dataset type from ``ds_names``. Unused here; kept for the caller's
        signature.
    threshold : float, optional
        Concentration above which a cell counts as ice covered. Default 0.15
        is the conventional 15% extent cutoff.

    Returns
    -------
    xarray.DataArray
        Extent in million km^2, summed over 'ni' and 'nj'.
    """
    # 1 where ice covered, 0 elsewhere (NaN compares False -> 0)
    ice_mask = xr.where(ds > threshold, 1, 0)

    # Area sum: m^2 -> km^2 (1e-6), then km^2 -> million km^2 (1e-6)
    dsa = (ice_mask*ds['tarea']).sum(dim=['ni','nj'])*1e-6*1e-6

    dsa.attrs['units'] = 'million km^2'
    dsa.attrs['long_name'] = 'sea ice extent'

    return dsa

#### CalcSIV

In [14]:
def CalcSIV(ds):
    """Total sea ice volume over the (already sliced) CICE domain.

    ``ds`` is per-cell ice thickness (e.g. 'hi') carrying a 'tarea' cell-area
    coordinate and 'ni'/'nj' dims. Thickness x area is summed over the domain
    and scaled by 1e-13, i.e. reported in units of 10^13 m^3, assuming
    thickness in m and area in m^2.
    """
    scale = 1.0e-13  # m^3 -> 10^13 m^3
    cell_volume = ds*ds['tarea']
    dsvol = cell_volume.sum(dim=['ni','nj'])*scale

    dsvol.attrs['units'] = '10^13 m^3'
    dsvol.attrs['long_name'] = 'sea ice volume'

    return dsvol

#### AddCoordTrend

In [15]:
def AddCoordTrend(da, time_type):
    """Replace the 'time' or 'year' coordinate of ``da`` with the integers 1..N.

    ``time_type`` must be a dimension of ``da`` (its length sets N). For any
    dimension other than 'time' or 'year', ``da`` is returned unchanged.
    """
    n_steps = da.sizes[time_type]
    index_coord = np.arange(1, n_steps+1)
    if time_type not in ('time', 'year'):
        return da
    return da.assign_coords({time_type: index_coord})

#### AddCyclic

In [16]:
def AddCyclic(da: xr.DataArray, londim) -> xr.DataArray:
    """Return a copy of ``da`` with a cyclic longitude point appended.

    Parameters
    ----------
    da : xr.DataArray
        Field with an equally spaced 1-D coordinate named ``londim``.
    londim : str
        Name of the longitude dimension.

    Returns
    -------
    xr.DataArray
        Same name and attrs, with one extra longitude column: the first column
        appended at (last lon + grid spacing), per cartopy's add_cyclic_point.
        Only dimension coordinates are carried over.
    """
    # add_cyclic_point pads along `axis` (default -1). Pass the longitude axis
    # explicitly so the result is correct even when lon is not the last dim.
    lon_axis = da.get_axis_num(londim)
    cyclic_data, cyclic_lon = add_cyclic_point(da.data, coord=da[londim], axis=lon_axis)

    # Keep dimension coordinates; dims without one are skipped instead of
    # raising a KeyError.
    cyclic_coords = {dim: da.coords[dim] for dim in da.dims if dim in da.coords}
    cyclic_coords[londim] = cyclic_lon

    return xr.DataArray(cyclic_data, dims=da.dims, coords=cyclic_coords,
                        attrs=da.attrs, name=da.name)

#### SpatZonAvg

In [17]:
def SpatZonAvg(ds, tavg=None):
    """Climatological time average of ``ds`` for the map plots.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data with a datetime-like 'time' dimension.
    tavg : int, optional
        Averaging type. Defaults to the notebook-level ``time_avg``.
        0 -> monthly climatology on a 'month' dim labelled with ``mon_str``.
        2 -> seasonal climatology on a 'season' dim labelled
        ['MAM', 'JJA', 'SON', 'DJF'].

    Returns
    -------
    Same type as ``ds``, with 'time' replaced by 'month' or 'season'.

    Raises
    ------
    ValueError
        For any other averaging type. These previously fell through to an
        UnboundLocalError.
    """
    if tavg is None:
        tavg = time_avg

    # NOTE(review): CESM history files are often time-stamped at the END of
    # each averaging period. If 'time' was not shifted during preprocessing,
    # the 'time.month' groups are off by one month — confirm.
    if tavg == 0:
        ds_avg = ds.groupby('time.month').mean('time')
        return ds_avg.assign_coords(month=mon_str)

    if tavg == 2:
        # Seasonal means anchored on December (DJF, MAM, JJA, SON), then each
        # season averaged across years. Seasons at the edges of the record
        # may be partial (e.g. the first DJF when data start in January) and
        # are still included.
        ds_avg = ds.resample(time='QS-DEC').mean('time')
        ds_avg = ds_avg.groupby('time.month').mean('time')
        # groupby orders seasons by start month: 3, 6, 9, 12. The global
        # `seas_str` used previously is never defined in this notebook.
        season_labels = np.array(['MAM', 'JJA', 'SON', 'DJF'])
        ds_avg = ds_avg.assign_coords(month=season_labels)
        return ds_avg.rename({'month':'season'})

    raise ValueError('SpatZonAvg supports time_avg 0 (monthly) or 2 (seasonal), got '+str(tavg))

#### AddAllCyclic

In [18]:
def AddAllCyclic(ds):
    """Add a cyclic longitude point to every data variable of ``ds``.

    Each variable is transposed so its last two dims are ('lat', 'lon'); any
    leading dims (e.g. 'month' or 'season') keep their order. The variable is
    then passed through AddCyclic. Returns a new merged Dataset.
    """
    londim = 'lon'
    latdim = 'lat'

    cyclic_ds_list = []
    for dsname, da in ds.items():
        # `...` keeps the remaining dims in front. The old code hard-coded the
        # global `time_outstr`, which failed for data without that dim.
        da = da.transpose(..., latdim, londim)
        cyclic_ds_list.append(AddCyclic(da, londim).rename(dsname))

    return xr.merge(cyclic_ds_list)

#### SaveData

In [19]:
def SaveData(ds, plot_type, varname, tavg, plot_level=None):
    """Write ``ds`` as NetCDF4 into ``path_to_plotdata`` for the plotting notebooks.

    File name format:
        <plot_type>.<varname>.<spatial domain>.[Z<level>.]<start year>.<tavg>.nc
    e.g. 'Map.hi.Antarctic.0911.month.nc'. ``plot_type`` may carry extra
    qualifiers such as 'Linear.abs'.

    Parameters
    ----------
    ds : xarray.Dataset
        Data to write.
    plot_type, varname, tavg : str
        File-name components (see format above).
    plot_level : number, optional
        Vertical level; when given, adds a 'Z<int(level)>.' component.

    Returns
    -------
    None
    """
    level_str = '' if plot_level is None else 'Z'+str(int(plot_level))+'.'
    filename = plot_type+'.'+varname+'.'+sd_str+'.'+level_str+td_str+'.'+tavg+'.nc'

    print('Saving '+filename)
    # Create the output directory on first use instead of failing in to_netcdf.
    os.makedirs(path_to_plotdata, exist_ok=True)
    # An existing file with the same name is overwritten (to_netcdf mode='w').
    ds.to_netcdf(os.path.join(path_to_plotdata, filename), format='NETCDF4')
    return None

In [20]:
# NOTE(review): `hix` is not used anywhere below and hard-codes the absolute
# path of one moremelt_rsnw0 'hi' file. It looks like a leftover debugging
# cell; consider deleting it (the file handle stays open until hix is closed).
hix = xr.open_dataset('/glade/work/glydia/Arctic_controls_processed_data/processed_moremelt_rsnw0_data/b.e21.B1850cmip6.f09_g17.moremelt_rsnw0.h.hi.101101-106012.nc')

### Process and load data

In [21]:
%%time

# Load every enabled dataset, slice it to the spatial domain and apply the
# reduction implied by `plots`. Result: one data variable per dataset name.
ds_proc = CreateMasterDS(var)

Figure out processing calculations to do
LENS2 piControl
  Initial data loading complete
  Sliced data
  No proceesing, spatial
  Processing on dataset complete
moremelt_rsnw0
  Initial data loading complete
  Sliced data
  No proceesing, spatial
  Processing on dataset complete
All datasets merged
CPU times: user 2.06 s, sys: 443 ms, total: 2.5 s
Wall time: 59.6 s


In [22]:
# Restrict to the 200-year analysis window starting in t_domain (slice_time).
ds_proc = ds_proc.loc[slice_time]

### Plotting set-up

In [23]:
%%time

# Month labels for the 'month' coordinate of monthly climatologies (used by
# SpatZonAvg and the line-plot cells); must be defined before those run.
mon_str = np.array(['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec'])


CPU times: user 13 µs, sys: 2 µs, total: 15 µs
Wall time: 17.4 µs


## Line plots
### Set up

In [24]:
if plot_types['line']:
    graph_type_str = 'Linear'

    # time_avg -> (groupby key, label used in the output file name).
    # Previously, time_avg values other than 0/1 left dim_avg and period
    # undefined, and the processing cell failed with a NameError.
    line_avg_opts = {0: ('time.month', 'month'), 1: ('time.year', 'year')}
    if time_avg not in line_avg_opts:
        raise ValueError('Line plots support time_avg 0 (monthly) or 1 (yearly), got '+str(time_avg))
    dim_avg, period = line_avg_opts[time_avg]

### Data processing

In [25]:
%%time

if plot_types['line']:
    # Mean over the grouping chosen in the set-up cell ('time.month' or 'time.year')
    grouped_mean = ds_proc.groupby(dim_avg).mean('time')
    # Monthly climatologies get 'Jan'..'Dec' labels instead of 1..12
    ds_abs = grouped_mean.assign_coords(month=mon_str) if time_avg == 0 else grouped_mean
    SaveData(ds_abs, graph_type_str+'.abs', var, period)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.72 µs


## Create regridder (CICE grid → ATM lat/lon grid)

In [26]:
%%time

## Only run after time averaging!!!
# Target ATM grid for regridding CICE output:
#   - 192 equally spaced latitudes from -90 to 90, poles included
#     (spacing 180/191 deg);
#   - 288 longitudes from -180 to 178.75 in 1.25 deg steps, presumably the
#     f09 atmosphere grid.
# Generated rather than hard-coded: the old literal list was rounded
# inconsistently (e.g. 87.17277486911 vs -87.1727748691099), so it was not
# exactly symmetric about the equator. Values agree with it to ~1e-13.
# Longitudes have no cyclic point; AddAllCyclic adds it after regridding.
lats = np.linspace(-90, 90, 192)
lons = np.arange(-180, 180, 1.25)

# The regridder below does NOT include the cyclic longitude point; it is
# added after regridding by AddAllCyclic.

if comp == 'ice' and plot_types['spatial']:
    ## Create Regridder for CICE grid
    lon2d, lat2d = np.meshgrid(lons, lats)

    # Target: ATM lat/lon grid as 2-D (y, x) fields
    target_gridSIC = xr.Dataset({'lat': (['y', 'x'], lat2d),'lon': (['y', 'x'], lon2d)})

    # Source: first time slice of the first loaded dataset (presumably all
    # datasets share the CICE grid). The old code hard-coded 'LENS2 piControl',
    # which raised a KeyError whenever that dataset was switched off in ds_names.
    sample_name = next(iter(ds_proc.data_vars))
    ds_samplens = ds_proc[sample_name][0]

    # Create regridder for sea ice
    regridderSIC = xe.Regridder(ds_samplens, target_gridSIC, 'nearest_s2d', reuse_weights=False)

CPU times: user 5.03 s, sys: 47.5 ms, total: 5.07 s
Wall time: 5.62 s


In [27]:
# Spatial plots: monthly maps are labelled with the month names
if plot_types['spatial'] and time_avg == 0:
    period = 'time'
    date_str = mon_str

## Spatial plots
### Set up

In [28]:
# File-name prefix for the saved map data (see SaveData)
if plots['map'][0]:
    graph_type_str = 'Map'

### Data processing

In [29]:
%%time

if plots['map'][0]:
    # Alias of the time-sliced data (no copy)
    ds_avg = ds_proc
    
    # Monthly (time_avg 0) or seasonal (time_avg 2) climatology via SpatZonAvg
    ds_sp = SpatZonAvg(ds_avg)

CPU times: user 54.6 ms, sys: 399 µs, total: 55 ms
Wall time: 60.3 ms


In [30]:
%%time

if plots['map'][0]:
    # Sea-ice fields are on the CICE grid; move them onto the ATM lat/lon grid
    if comp == 'ice':
        ds_sp = Regrid(ds_sp, regridderSIC)

    # Append the wrap-around longitude column, then write the NetCDF file
    # used by the plotting notebooks
    ds_sp = AddAllCyclic(ds_sp)
    SaveData(ds_sp, graph_type_str, var, time_outstr)

Regridding CICE grid -> ATM grid...
Regridding LENS2 piControl
Regridding moremelt_rsnw0
Saving Map.hi.Antarctic.0911.month.nc
CPU times: user 3.08 s, sys: 251 ms, total: 3.34 s
Wall time: 8.43 s


In [31]:
# Shut down the Dask scheduler and workers started in the Cluster section.
client.shutdown()