### Build masks 
- This notebook shows how to build region masks for the South Atlantic and for its eastern and western sectors, and how to apply them to CESM2 Large Ensemble ocean output (TEMP, SALT, SHF).

In [None]:
%load_ext autoreload
%autoreload 2
import xarray as xr 
import numpy as np  
import pop_tools
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import cartopy
import cartopy.feature as cfeature
import distributed
import ncar_jobqueue
import intake
from dask.distributed import Client
from ncar_jobqueue import NCARCluster
%matplotlib inline
import warnings, getpass, os

### Read the POP nominal 1° grid (gx1v7) from pop_tools

In [None]:
# Load the POP nominal 1-degree grid (gx1v7) shipped with pop_tools.
# Later cells use its REGION_MASK, TLAT/TLONG and TAREA variables.
grid_name = 'POP_gx1v7'
pop_grid = pop_tools.get_grid(grid_name)

In [None]:
# Display the grid dataset to inspect its variables (REGION_MASK, TLAT, TLONG, TAREA, ...).
pop_grid

In [None]:
# Atlantic mask: 1.0 where REGION_MASK == 6, 0.0 everywhere else.
# The T-grid TLAT/TLONG are attached as coordinates so later cells can
# select on latitude/longitude.
atlantic_only = pop_grid.REGION_MASK.where(pop_grid.REGION_MASK == 6)
atl = (atlantic_only / atlantic_only).fillna(0)
atl = atl.assign_coords(TLAT=pop_grid['TLAT'], TLONG=pop_grid['TLONG'])

In [None]:
# Display the Atlantic mask (1 inside REGION_MASK == 6, 0 elsewhere).
atl

In [None]:
# Plot the Atlantic mask on a Robinson projection.
# The mask's TLONG/TLAT are plain lon/lat, hence transform=PlateCarree.
# A single labelled grid is drawn (the original drew a second, redundant one).
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.Robinson()})
pc = atl.plot.pcolormesh(ax=ax,
                         transform=ccrs.PlateCarree(),
                         x='TLONG',
                         y='TLAT',
                         add_colorbar=True)
ax.coastlines()
ax.gridlines(draw_labels=True)
ax.set_title('Atlantic mask (REGION_MASK == 6)');

### South Atlantic Mask

In [None]:
# South Atlantic mask: Atlantic cells south of the equator (TLAT < 0) -> 1, all else -> 0.
south_of_equator = atl.where(atl.TLAT < 0.)
south_atl = (atl * south_of_equator).fillna(0)

In [None]:
# Plot the South Atlantic mask on a Robinson projection (data given in lon/lat).
# A single labelled grid is drawn (the original drew a second, redundant one).
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.Robinson()})
pc = south_atl.plot.pcolormesh(ax=ax,
                               transform=ccrs.PlateCarree(),
                               x='TLONG',
                               y='TLAT')
ax.coastlines()
ax.gridlines(draw_labels=True)
ax.set_title('South Atlantic mask (TLAT < 0)');

### Eastern South Atlantic Mask

In [None]:
# Eastern South Atlantic mask: South Atlantic cells with TLONG > 345 or TLONG < 20.
# The band crosses the 0/360 TLONG seam (assuming TLONG runs 0..360), so it is
# built from two disjoint pieces that are summed.
east_piece_gt345 = (south_atl * south_atl.where(south_atl.TLONG > 345)).fillna(0)
east_piece_lt20 = (south_atl * south_atl.where(south_atl.TLONG < 20)).fillna(0)
east_south_atl = east_piece_gt345 + east_piece_lt20

In [None]:
# Plot the eastern South Atlantic mask on a Robinson projection (data given in lon/lat).
# A single labelled grid is drawn (the original drew a second, redundant one).
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.Robinson()})
pc = east_south_atl.plot.pcolormesh(ax=ax,
                                    transform=ccrs.PlateCarree(),
                                    x='TLONG',
                                    y='TLAT',
                                    add_colorbar=True)
ax.coastlines()
ax.gridlines(draw_labels=True)
ax.set_title('Eastern South Atlantic mask (TLONG > 345 or TLONG < 20)');

### Western South Atlantic Mask

In [None]:
# Western South Atlantic mask: South Atlantic cells with 50 < TLONG < 345 -> 1, all else -> 0.
west_south_atl = (
    south_atl
    * south_atl.where(south_atl.TLONG < 345)
    * south_atl.where(south_atl.TLONG > 50)
).fillna(0)

In [None]:
# Plot the western South Atlantic mask on a Robinson projection (data given in lon/lat).
# A single labelled grid is drawn (the original drew a second, redundant one).
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.Robinson()})
pc = west_south_atl.plot.pcolormesh(ax=ax,
                                    transform=ccrs.PlateCarree(),
                                    x='TLONG',
                                    y='TLAT',
                                    add_colorbar=True)
ax.coastlines()
ax.gridlines(draw_labels=True)
ax.set_title('Western South Atlantic mask (50 < TLONG < 345)');

### Applying the masks to surface heat flux (SHF), temperature (TEMP) and salinity (SALT)

In [None]:
# Dask cluster on NCAR batch nodes: single-core, single-process workers,
# each requesting `mem_per_worker` GB for up to 6 hours.
mem_per_worker = 30  # GB of memory per worker
num_workers = 26     # number of workers to request
cluster = NCARCluster(
    cores=1,
    processes=1,
    memory=f'{mem_per_worker} GB',
    resource_spec=f'select=1:ncpus=1:mem={mem_per_worker}GB',
    walltime='6:00:00',
)
cluster.scale(num_workers)
client = Client(cluster)
print(client)
client

In [None]:
# intake-esm catalog describing the CESM2 Large Ensemble holdings on GLADE.
catalog_path = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
catalog = intake.open_esm_datastore(catalog_path)

In [None]:
# Monthly ocean output for the three variables used below.
cat_subset = catalog.search(
    component='ocn',
    variable=['SHF', 'TEMP', 'SALT'],
    frequency='month_1',
)

In [None]:
%%time
# Load catalog entries for subset into a dictionary of xarray datasets
dset_dict_raw  = cat_subset.to_dataset_dict(zarr_kwargs={'consolidated': True}, storage_options={'anon': True})
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

In [None]:
# Build one dataset per variable: for each forcing, concatenate the two periods
# along `time`; then concatenate the two forcings along `member_id`.
# Results are exposed as ds_TEMP, ds_SALT, ds_SHF, which later cells look up by name.
fb = ['TEMP', 'SALT', 'SHF']       # temperature [oC], salinity, total surface heat flux [W/m2]
periods = ('historical', 'ssp370')  # concatenated along time
forcings = ('cmip6', 'smbb')        # concatenated along member_id

for var in fb:
    per_forcing = []
    for forcing in forcings:
        per_period = [dset_dict_raw[f'ocn.{period}.pop.h.{forcing}.{var}'] for period in periods]
        per_forcing.append(xr.combine_nested(per_period, concat_dim=['time']))
    # Assigned into the notebook namespace explicitly instead of via exec strings.
    globals()[f'ds_{var}'] = xr.combine_nested(per_forcing, concat_dim=['member_id'])
    print('Done!')

In [None]:
# Grid-index window (not degrees) enclosing the South Atlantic:
#   nlat in [ilat, flat); nlon in [ilon1, flon1) followed by [ilon2, flon2),
#   i.e. columns from the end and the start of the nlon index range.
ilat, flat = 85, 187
ilon1, flon1, ilon2, flon2 = 307, 320, 0, 54

pop_grid = pop_tools.get_grid('POP_gx1v7')

# Cell area (TAREA) over the same window, with the two nlon pieces glued side by side.
lat_window = slice(ilat, flat)
area_tail_cols = pop_grid.TAREA.isel(nlat=lat_window, nlon=slice(ilon1, flon1))
area_head_cols = pop_grid.TAREA.isel(nlat=lat_window, nlon=slice(ilon2, flon2))
area = xr.combine_nested([[area_tail_cols, area_head_cols]], concat_dim=['nlat', 'nlon'])

In [None]:
%%time
# For each basin mask and each variable in `fb`:
#   1. crop the mask and the variable to the South Atlantic index window
#      (ilat/flat, ilon1/flon1, ilon2/flon2 from the previous cell),
#   2. take annual means and blank out cells outside the mask,
#   3. for TEMP and SALT only, convert to anomalies relative to the first annual
#      mean and reduce to an area-weighted horizontal mean.
# Names are built with exec, so this cell creates globals such as
# `south_atl_new` (cropped mask) and `south_atl_TEMP` (processed variable).
ba=('south_atl','west_south_atl','east_south_atl')

for iba in range(len(ba)):
    print(f'Mask: {ba[iba]}')

    # Masks: crop to the two nlon index ranges and glue them along nlon
    # (rechunked so the whole window is one chunk in nlon).
    st=f'{ba[iba]}_new=xr.combine_nested([[{ba[iba]}.isel(nlat = slice(ilat,flat),nlon = slice(ilon1,flon1)),{ba[iba]}.isel(nlat = slice(ilat,flat),nlon = slice(ilon2,flon2))]],concat_dim=[\'nlat\',\'nlon\']).chunk(nlon=(flon1 + flon2 - ilon1 - ilon2))'; exec(st)
    # NOTE(review): this rewrites the values of the `nlon` coordinate. nlon is the
    # grid-column dimension sliced with isel above; its values are presumably
    # indices, not TLONG degrees, so this does not convert longitudes to -180..180
    # as the original comment claimed. TODO confirm intent.
    st=f'{ba[iba]}_new.coords[\'nlon\'] = ({ba[iba]}_new.coords[\'nlon\'] + 180) % 360 - 180'; exec(st) # relabel nlon (see NOTE above)
    # Sort the cropped mask by the relabelled nlon values.
    st=f'{ba[iba]}_new={ba[iba]}_new.sortby({ba[iba]}_new.nlon)'; exec(st)
    # Unused alternative kept from development:
    #st=f'{ba[iba]}_area=area[np.where({ba[iba]}_new != 0.)]'; exec(st)

    for ifb in range(len(fb)):
        print(f'Variable: {fb[ifb]}')

        # Variable: pull the data variable (e.g. ds_TEMP.TEMP) out of its dataset
        st=f'ds_{fb[ifb]}_new=ds_{fb[ifb]}.{fb[ifb]}'; exec(st)
        # Attach the full-grid TLONG/TLAT from pop_grid
        st=f'ds_{fb[ifb]}_new[\'TLONG\']=pop_grid.TLONG; ds_{fb[ifb]}_new[\'TLAT\'] = pop_grid.TLAT'; exec(st)
        st=f'{ba[iba]}_{fb[ifb]}=ds_{fb[ifb]}_new'; exec(st)
        # Crop to the same index window as the mask, then relabel/sort nlon identically
        st=f'{ba[iba]}_{fb[ifb]}=xr.combine_nested([[{ba[iba]}_{fb[ifb]}.isel(nlat = slice(ilat,flat),nlon = slice(ilon1,flon1)),{ba[iba]}_{fb[ifb]}.isel(nlat = slice(ilat,flat),nlon = slice(ilon2,flon2))]],concat_dim=[\'nlat\',\'nlon\']).chunk(nlon=(flon1 + flon2 - ilon1 - ilon2))'; exec(st)
        st=f'{ba[iba]}_{fb[ifb]}.coords[\'nlon\'] = ({ba[iba]}_{fb[ifb]}.coords[\'nlon\'] + 180) % 360 - 180 # change the longitudes: -180 0 180'; exec(st)
        st=f'{ba[iba]}_{fb[ifb]}={ba[iba]}_{fb[ifb]}.sortby({ba[iba]}_{fb[ifb]}.nlon)'; exec(st)
        # Annual means
        st=f'{ba[iba]}_{fb[ifb]}={ba[iba]}_{fb[ifb]}.resample(time=\'1Y\', closed=\'left\').mean(\'time\')'; exec(st)
        # Cells outside the basin mask become NaN
        st=f'{ba[iba]}_{fb[ifb]}={ba[iba]}_{fb[ifb]}.where({ba[iba]}_new != 0.)'; exec(st)

        # With fb = ['TEMP','SALT','SHF'] this branch runs for TEMP and SALT only.
        # NOTE(review): SHF therefore stays a masked, annual-mean gridded field
        # (no anomaly, no area averaging). Confirm this is intended.
        if ifb<=1:
            # Ocean/land flag from the first time step and first member:
            # 2 where the value is finite (ocean), 1 where it is NaN
            # (land/sea floor, or outside the mask), 3 never occurs.
            st=f'mask_ocean_{ba[iba]}_{fb[ifb]} = 2 * np.ones((len({ba[iba]}_{fb[ifb]}.coords[\'nlat\']), len({ba[iba]}_{fb[ifb]}.coords[\'nlon\']))) * np.isfinite({ba[iba]}_{fb[ifb]}.isel(time=0,member_id=0))'; exec(st)
            st=f'mask_land_{ba[iba]}_{fb[ifb]} = 1 * np.ones((len({ba[iba]}_{fb[ifb]}.coords[\'nlat\']), len({ba[iba]}_{fb[ifb]}.coords[\'nlon\']))) * np.isnan({ba[iba]}_{fb[ifb]}.isel(time=0,member_id=0))'; exec(st)
            st=f'mask_array_{ba[iba]}_{fb[ifb]} = mask_ocean_{ba[iba]}_{fb[ifb]} + mask_land_{ba[iba]}_{fb[ifb]}'; exec(st)

            # Restrict the cell areas to the basin mask.
            # NOTE(review): `area` keeps its original nlon labels while the mask's nlon
            # was relabelled; if nlon carries coordinate values, `where` aligns by label
            # and may drop columns. Verify that area_new has the full window width.
            st=f'area_new=area.where({ba[iba]}_new != 0.)'; exec(st) # Applying the basin masks in the area

            st=f'area_new=np.array([area_new]*len({ba[iba]}_{fb[ifb]}.coords[\'z_t\']))'; exec(st) # Replicate the 2-D surface area over every z_t level (plain numpy array)
            st=f'area_array=xr.DataArray(area_new, dims=[\'z_t\', \'nlat\', \'nlon\'])'; exec(st) # Wrap as a DataArray (dims z_t, nlat, nlon; no coordinates) for the operations below
            st=f'area_array_{ba[iba]}_{fb[ifb]}=area_array.where(mask_array_{ba[iba]}_{fb[ifb]} != 1.)'; exec(st) # Keep area only where the variable is finite (flag 2); NaN elsewhere

            # Anomaly relative to the first annual mean of each member (applied to TEMP and SALT alike)
            st=f'{ba[iba]}_{fb[ifb]}={ba[iba]}_{fb[ifb]}-{ba[iba]}_{fb[ifb]}.isel(time=0)'; exec(st)

            # Area-weighted mean over nlon/nlat (NaN weights -> 0 so masked cells do not
            # contribute), leaving the member_id, time and z_t dimensions
            st=f'{ba[iba]}_{fb[ifb]}={ba[iba]}_{fb[ifb]}.weighted(area_array_{ba[iba]}_{fb[ifb]}.fillna(0)).mean(dim=[\'nlon\',\'nlat\'])'; exec(st)

In [None]:
%%time
# TEMP
for iba in range(len(ba)):
    st=f'{ba[iba]}_TEMP_array = list()'; exec(st)
    st=f'print(f\'Mask: {ba[iba]}\')'; exec(st) 
    for member_id in range(100):
        st=f'{ba[iba]}_TEMP_small={ba[iba]}_TEMP.isel(member_id=member_id)'; exec(st)
        st=f'{ba[iba]}_TEMP_small={ba[iba]}_TEMP_small.load()'; exec(st)
        st=f'{ba[iba]}_TEMP_array.append({ba[iba]}_TEMP_small)'; exec(st)
        st=f'print(f\'done with member #{member_id}\')'; exec(st)
    st=f'{ba[iba]}_TEMP_merged = xr.concat({ba[iba]}_TEMP_array, dim=\'member_id\', compat=\'override\', join=\'override\', coords=\'minimal\')'; exec(st)
    st=f'del {ba[iba]}_TEMP'; exec(st)

In [None]:
# TEMP: gather the three regional results into one dataset, annotate it, and save it.
ds_out_TEMP = xr.merge([
    east_south_atl_TEMP_merged.rename('east_south_atl_TEMP'),
    west_south_atl_TEMP_merged.rename('west_south_atl_TEMP'),
    south_atl_TEMP_merged.rename('south_atl_TEMP'),
])
ds_out_TEMP.attrs['description'] = (
    'Annual-mean, area-weighted horizontal mean temperature anomaly (relative to the first '
    'annual mean of each member) in each South Atlantic region: the entire basin '
    '(south_atl_TEMP), eastern side (east_south_atl_TEMP), western side (west_south_atl_TEMP)'
)
ds_out_TEMP.attrs['units'] = 'K'
ds_out_TEMP.attrs['author'] = 'Mauricio Rocha'
ds_out_TEMP.attrs['email'] = 'mauricio.rocha@usp.br'

# Save under the current user's scratch space (the original `.format` call had
# no placeholder, so the user name was never substituted).
path = f'/glade/scratch/{getpass.getuser()}/Data/LENS2/TEMP/'
os.makedirs(path, exist_ok=True)
ds_out_TEMP.to_netcdf(path + 'TEMP_south_atl_regions.nc')
ds_out_TEMP

In [None]:
%%time
# SHF
for iba in range(len(ba)):
    st=f'{ba[iba]}_SHF_array = list()'; exec(st)
    st=f'print(f\'Mask: {ba[iba]}\')'; exec(st) 
    for member_id in range(100):
        st=f'{ba[iba]}_SHF_small={ba[iba]}_SHF.isel(member_id=member_id)'; exec(st)
        st=f'{ba[iba]}_SHF_small={ba[iba]}_SHF_small.load()'; exec(st)
        st=f'{ba[iba]}_SHF_array.append({ba[iba]}_SHF_small)'; exec(st)
        st=f'print(f\'done with member #{member_id}\')'; exec(st)
    st=f'{ba[iba]}_SHF_merged = xr.concat({ba[iba]}_SHF_array, dim=\'member_id\', compat=\'override\', join=\'override\', coords=\'minimal\')'; exec(st)
    st=f'del {ba[iba]}_SHF'; exec(st)

In [None]:
# SHF: gather the three regional results into one dataset, annotate it, and save it.
ds_out_SHF = xr.merge([
    east_south_atl_SHF_merged.rename('east_south_atl_SHF'),
    west_south_atl_SHF_merged.rename('west_south_atl_SHF'),
    south_atl_SHF_merged.rename('south_atl_SHF'),
])
ds_out_SHF.attrs['description'] = (
    'Annual-mean total surface heat flux on the model grid, masked to each South Atlantic '
    'region: the entire basin (south_atl_SHF), eastern side (east_south_atl_SHF), '
    'western side (west_south_atl_SHF)'
)
ds_out_SHF.attrs['units'] = 'W/m2'
ds_out_SHF.attrs['author'] = 'Mauricio Rocha'
ds_out_SHF.attrs['email'] = 'mauricio.rocha@usp.br'

# Save under the current user's scratch space (the original `.format` call had
# no placeholder, so the user name was never substituted).
path = f'/glade/scratch/{getpass.getuser()}/Data/LENS2/SHF/'
os.makedirs(path, exist_ok=True)
ds_out_SHF.to_netcdf(path + 'SHF_south_atl_regions.nc')
ds_out_SHF

In [None]:
%%time
# SALT
for iba in range(len(ba)):
    st=f'{ba[iba]}_SALT_array = list()'; exec(st)
    st=f'print(f\'Mask: {ba[iba]}\')'; exec(st) 
    for member_id in range(100):
        st=f'{ba[iba]}_SALT_small={ba[iba]}_SALT.isel(member_id=member_id)'; exec(st)
        st=f'{ba[iba]}_SALT_small={ba[iba]}_SALT_small.load()'; exec(st)
        st=f'{ba[iba]}_SALT_array.append({ba[iba]}_SALT_small)'; exec(st)
        st=f'print(f\'done with member #{member_id}\')'; exec(st)
    st=f'{ba[iba]}_SALT_merged = xr.concat({ba[iba]}_SALT_array, dim=\'member_id\', compat=\'override\', join=\'override\', coords=\'minimal\')'; exec(st)
    st=f'del {ba[iba]}_SALT'; exec(st)

In [None]:
# SALT: gather the three regional results into one dataset, annotate it, and save it.
ds_out_SALT = xr.merge([
    east_south_atl_SALT_merged.rename('east_south_atl_SALT'),
    west_south_atl_SALT_merged.rename('west_south_atl_SALT'),
    south_atl_SALT_merged.rename('south_atl_SALT'),
])
ds_out_SALT.attrs['description'] = (
    'Annual-mean, area-weighted horizontal mean salinity anomaly (relative to the first '
    'annual mean of each member) in each South Atlantic region: the entire basin '
    '(south_atl_SALT), eastern side (east_south_atl_SALT), western side (west_south_atl_SALT)'
)
ds_out_SALT.attrs['units'] = 'gram/kilogram'
ds_out_SALT.attrs['author'] = 'Mauricio Rocha'
ds_out_SALT.attrs['email'] = 'mauricio.rocha@usp.br'

# Save under the current user's scratch space (the original `.format` call had
# no placeholder, so the user name was never substituted).
path = f'/glade/scratch/{getpass.getuser()}/Data/LENS2/SALT/'
os.makedirs(path, exist_ok=True)
ds_out_SALT.to_netcdf(path + 'SALT_south_atl_regions.nc')
ds_out_SALT