### Build masks 
- This notebook shows how to build region masks for the South Atlantic and for the Eastern and Western South Atlantic.

In [None]:
%load_ext autoreload
%autoreload 2
import xarray as xr 
import numpy as np  
import pop_tools
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import cartopy
import dask
import cartopy.feature as cfeature
import distributed
import ncar_jobqueue
import intake
from dask.distributed import Client
from ncar_jobqueue import NCARCluster
%matplotlib inline
import warnings, getpass, os

### Read the pop 1 deg grid from pop_tools

In [None]:
# Load the 'POP_gx1v7' grid definition from pop_tools.
# Later cells read its TLONG/TLAT and ULONG/ULAT coordinates as well as the
# REGION_MASK, TAREA and KMT fields.
pop_grid = pop_tools.get_grid('POP_gx1v7')
pop_grid

In [None]:
# Basin mask for REGION_MASK code 6: 1 inside the region, 0 everywhere else.
# Dividing the selected cells by themselves turns the region code into 1 and
# leaves NaN outside the selection; those NaNs are then replaced by 0.
region6_cells = pop_grid.REGION_MASK.where(pop_grid.REGION_MASK == 6)
atl = region6_cells / region6_cells
# Attach the T-grid latitude/longitude so the mask can be subset and plotted by position.
atl = atl.assign_coords(TLAT=pop_grid['TLAT'], TLONG=pop_grid['TLONG'])
atl = atl.fillna(0)
atl

In [None]:
# Map of the basin mask (values 0/1) on a Robinson projection.
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = atl.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # TLONG/TLAT are given in degrees
    x='TLONG',
    y='TLAT',
    add_colorbar=True,
)
# One labelled gridline layer is enough; a second unlabelled ax.gridlines()
# call would only draw the same lines again on top of it.
ax.gridlines(draw_labels=True)
ax.coastlines();

### South Atlantic Mask

In [None]:
# Restrict the basin mask to the latitude band -34 < TLAT < 0.
# Each step multiplies the 0/1 mask by a copy that is NaN outside the wanted
# latitudes, then turns the resulting NaNs back into 0.
south_of_equator = (atl * atl.where(atl.TLAT < 0.)).fillna(0)
south_atl = (south_of_equator * south_of_equator.where(south_of_equator.TLAT > -34.)).fillna(0)

In [None]:
# Map of the South Atlantic mask, saved to 'south_atl.png'.
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = south_atl.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # TLONG/TLAT are given in degrees
    x='TLONG',
    y='TLAT',
)
# One labelled gridline layer is enough; a second unlabelled ax.gridlines()
# call would only draw the same lines again on top of it.
ax.gridlines(draw_labels=True)
ax.coastlines()
plt.savefig('south_atl.png', dpi=300, bbox_inches='tight')

### Eastern South Atlantic Mask

In [None]:
# Eastern sector: South Atlantic cells with TLONG > 345 plus those with TLONG < 20.
# The two pieces are built separately (NaN outside each band -> 0) and then summed,
# so the selection can span both ends of the longitude axis.
# NOTE(review): presumably TLONG runs 0-360, so the two bands are contiguous across 0E - confirm.
lon_above_345 = (south_atl * south_atl.where(south_atl.TLONG > 345)).fillna(0)
lon_below_20 = (south_atl * south_atl.where(south_atl.TLONG < 20)).fillna(0)
east_south_atl = lon_above_345 + lon_below_20

In [None]:
# Quick sanity check: display the dimensions of the South Atlantic mask.
south_atl.shape

In [None]:
# Map of the eastern South Atlantic mask on a Robinson projection.
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = east_south_atl.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # TLONG/TLAT are given in degrees
    x='TLONG',
    y='TLAT',
    add_colorbar=True,
)
# One labelled gridline layer is enough; a second unlabelled ax.gridlines()
# call would only draw the same lines again on top of it.
ax.gridlines(draw_labels=True)
ax.coastlines();
# Uncomment to save the figure:
#plt.savefig('eastern_atl.png',dpi=300,bbox_inches='tight')

### Western South Atlantic Mask

In [None]:
# Western sector: South Atlantic cells with 50 < TLONG < 345.
# The product of the two NaN-padded copies is NaN wherever either longitude
# test fails; multiplying by the 0/1 mask and filling NaN with 0 gives the final mask.
within_lon_band = south_atl.where(south_atl.TLONG < 345) * south_atl.where(south_atl.TLONG > 50)
west_south_atl = (south_atl * within_lon_band).fillna(0)

In [None]:
# Map of the western South Atlantic mask on a Robinson projection.
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())
pc = west_south_atl.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),  # TLONG/TLAT are given in degrees
    x='TLONG',
    y='TLAT',
    add_colorbar=True,
)
# One labelled gridline layer is enough; a second unlabelled ax.gridlines()
# call would only draw the same lines again on top of it.
ax.gridlines(draw_labels=True)
ax.coastlines();
# Uncomment to save the figure:
#plt.savefig('western_atl.png',dpi=300,bbox_inches='tight')

### Merge the masks

In [None]:
# Collect the three 0/1 masks in one Dataset; the variable names are the keys
# used later to select a basin.
# The individual mask arrays are intentionally NOT deleted here: deleting them
# made this cell raise NameError when re-run, and they are only 2-D grids.
ds_south_atl_masks = xr.merge([
    south_atl.rename('south_atl'),
    west_south_atl.rename('west_south_atl'),
    east_south_atl.rename('east_south_atl'),
])

### Dask framework

In [None]:
mem_per_worker = 30  # memory per worker in GB
num_workers = 26     # number of workers

# Each worker is requested as a single-core, single-process job with
# `mem_per_worker` GB of memory and a 6-hour walltime.
cluster = NCARCluster(
    cores=1,
    processes=1,
    memory=f'{mem_per_worker} GB',
    resource_spec=f'select=1:ncpus=1:mem={mem_per_worker}GB',
    walltime='6:00:00',
)
cluster.scale(num_workers)

# Connect this notebook to the cluster; print a text summary and show the rich repr.
client = Client(cluster)
print(client)
client

In [None]:
# Open the intake-esm datastore described by the JSON catalog on GLADE
# (the file name suggests the CESM2 Large Ensemble - confirm if in doubt).
catalog = intake.open_esm_datastore(
    '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
)

In [None]:
# Keep only the monthly ocean-component entries for the variable PV.
cat_subset = catalog.search(component='ocn',variable=['PV'],frequency='month_1')

In [None]:
%%time
# Load catalog entries for subset into a dictionary of xarray datasets.
# The dictionary keys are printed so the key pattern used by the
# concatenation step (e.g. 'ocn.historical.pop.h.cmip6.PV') can be checked.
dset_dict_raw  = cat_subset.to_dataset_dict(zarr_kwargs={'consolidated': True}, storage_options={'anon': True})
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

### Concatenation of variables

In [None]:
pd = ('historical', 'ssp370')   # Periods (historical and projection)
ff = ('cmip6', 'smbb')          # Forcings
fb = ['PV']                     # Variable(s)

ds_dict = {}
# A loop over variables is kept so the same structure works if more variables are added.
for var in fb:
    # 1- For each forcing, join the historical and ssp370 datasets along time.
    ds_dict_tmp = {
        forcing: xr.combine_nested(
            [
                dset_dict_raw[f'ocn.historical.pop.h.{forcing}.{var}'],
                dset_dict_raw[f'ocn.ssp370.pop.h.{forcing}.{var}'],
            ],
            concat_dim=['time'],
        )
        for forcing in ff
    }
    # 2- Join the cmip6 and smbb results along member_id.
    ds_dict[var] = xr.combine_nested(
        [ds_dict_tmp['cmip6'], ds_dict_tmp['smbb']],
        concat_dim=['member_id'],
    )

### Apply POP grid

In [None]:
# Store the grid coordinates next to the data: TLONG/TLAT (coordinate T)
# and ULONG/ULAT (coordinate U).
for grid_coord in ('TLONG', 'TLAT', 'ULONG', 'ULAT'):
    ds_dict[grid_coord] = pop_grid[grid_coord]

### Build a mask for each variable (continent x ocean)
- This mask will be useful to apply to the area after you transform it into 3D

In [None]:
# Land/ocean mask per variable, taken from the first time step of the first member:
# 2 where the field is finite (OCEAN), 1 where it is NaN (CONTINENT).
mask_array = {}
for var in fb:
    var_field = ds_dict[var][var]
    first_snapshot = var_field.isel(time=0, member_id=0)
    ones_2d = np.ones((len(var_field.coords['nlat']),
                       len(var_field.coords['nlon'])))
    mask_ocean = 2 * ones_2d * np.isfinite(first_snapshot)  # 2 is OCEAN
    mask_land = 1 * ones_2d * np.isnan(first_snapshot)      # 1 is CONTINENT
    mask_array[var] = mask_ocean + mask_land  # Indicates the continent and the ocean
# Drop the temporaries; only mask_array is needed afterwards.
del mask_ocean, mask_land, var_field, first_snapshot, ones_2d

### Extend the area to 3D and apply the continent x ocean mask
- Remember, you don't have to include z_t if you have a 2D variable, for example SHF
- We add the member coordinate in the area3D, so that down the road we do not lose this reference when we compute the data

In [None]:
for var in fb:
    var_field = ds_dict[var][var]
    n_depth = len(var_field.coords['z_t'])          # number of depth levels (z_t)
    n_member = len(var_field.coords['member_id'])   # number of ensemble members (member_id)
    # Horizontal cell area (cm2), kept only where KMT > 0.
    ocean_area = pop_grid.TAREA.where(pop_grid.KMT > 0)
    # Replicate the 2-D area over depth and member to get (member_id, z_t, nlat, nlon).
    area3D = xr.DataArray(
        np.tile(ocean_area.values, (n_member, n_depth, 1, 1)),
        dims=['member_id', 'z_t', 'nlat', 'nlon'],
    )
    # Carry over the member labels so the member reference is not lost later on.
    area3D.coords['member_id'] = var_field.coords['member_id']
    # Apply the mask (Remember: 1 is continent and 2 is ocean).
    area3D = area3D.where(mask_array[var] != 1.)

#### Check the 3D area

In [None]:
# Visual check: plot the masked area for the first member at the tenth-from-last z_t index.
area3D.isel(member_id=0,z_t=-10).plot() 

### Annual Mean
- Let's take the annual mean, because we are not interested in seasonality and it is a way to decrease the processing time (array reduction)

In [None]:
%%time
dask.config.set(**{'array.slicing.split_large_chunks': False})
for var in fb:
    # 1- Annual mean
    ds_dict[f'{var}']=ds_dict[f'{var}'].resample(time='1Y', closed='left').mean('time')
    # 2- Subtracting from all times minus the initial time the variables
    ds_dict[f'{var}'][f'{var}'] = ds_dict[f'{var}'][f'{var}'] - ds_dict[f'{var}'][f'{var}'].isel(time=0) # Anomaly regarding to the beginning of the time series 

#### Check the new data
- Check it by changing z_t and time

In [None]:
# Quick look at one horizontal slice per variable (second member, first year, first depth level);
# change the indices below to inspect other members, times or depths.
for var in fb:
    snapshot = ds_dict[var][var].isel(member_id=1, time=0, z_t=0)
    snapshot.plot()

### Apply the masks for each region
- Since this process takes time, we recommend doing it per variable

In [None]:
%%time
warnings.simplefilter("ignore")
masks_sa=(['south_atl','west_south_atl','east_south_atl']) # each mask

for var in fb:
    for basin in masks_sa:
        print(f'Done with region: {basin}') # Each mask
        var_array = list() # Build a list
        for member_id in range(len(ds_dict[f'{var}'][f'{var}'].coords['member_id'])): # STEP BY STEP per member:
            var_small=ds_dict[f'{var}'][f'{var}'].isel(member_id=member_id).where(ds_south_atl_masks[f'{basin}'] != 0.) # 1- Apply the basin mask to the variable
            var_small=(var_small.weighted(
                (area3D.isel(member_id=member_id).where(ds_south_atl_masks[f'{basin}'] != 0.)).fillna(0)) # 2- Apply the basin mask to the area
                       ).mean(dim=['nlon','nlat']) # 3- Make the spatial average
            var_small=var_small.load() # 4- Load the data
            var_array.append(var_small) # 5- Add items to the end of a given list
            print(f'Done with member: {member_id}') # 6- Go to the next member
        st=f'{basin}_var_merged = xr.concat(var_array, dim=\'member_id\', compat=\'override\', join=\'override\', coords=\'minimal\')' # concat the members
        exec(st) # 7- Go to the next basin

### Save data
- Don't forget to include the parameters of the defined variable, such as the unit, etc

In [None]:
# Gather the three regional time series in one Dataset.
ds_out_var = xr.merge([east_south_atl_var_merged.rename('east_south_atl_PV'),  # eastern side
                       west_south_atl_var_merged.rename('west_south_atl_PV'),  # western side
                       south_atl_var_merged.rename('south_atl_PV')])           # entire basin
# Global attributes; the description lists the actual variable names in the file.
ds_out_var.attrs['description'] = 'Potential Vorticity in each South Atlantic region: the entire basin (south_atl_PV), eastern side (east_south_atl_PV), western side (west_south_atl_PV)'
ds_out_var.attrs['units'] = '1/s/cm'
ds_out_var.attrs['author'] = 'Mauricio Rocha'
ds_out_var.attrs['email'] = 'mauricio.rocha@usp.br'
# Create a directory on the current user's scratch space to save the output.
# The user name is now actually inserted into the path; previously .format()
# was called on a string without a placeholder and had no effect.
path = '/glade/scratch/{}/Data/LENS2/PV/'.format(getpass.getuser())
# Create the directory from Python instead of spawning a shell for 'mkdir -p'.
os.makedirs(path, exist_ok=True)
ds_out_var.to_netcdf(os.path.join(path, 'PV_south_atl_regions.nc'))