In [None]:
"""
Created on Fri Jun 4 15:50 2020

This is a script to convert the NEMO temperature and salinity to potential temperature and practical salinity

@author: Clara Burgard
"""

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import gsw
import matplotlib.pyplot as plt
import basal_melt_param.useful_functions as uf
import basal_melt_param.T_S_profile_functions as tspf
import basal_melt_param.melt_functions as meltf

import itertools

import distributed
import glob

In [None]:
# Start a local Dask cluster for this notebook; scratch files go to /tmp and the
# dashboard is served on port 8795.
client = distributed.Client(
    n_workers=4,
    memory_limit='6GB',
    local_directory='/tmp',
    dashboard_address=':8795',
)

In [None]:
# Show the client summary as the cell output.
client

In [None]:
# Select the Qt5 backend so that figures open in separate interactive windows.
# NOTE(review): this needs a Qt binding and a display; confirm it works on the machine running the notebook.
%matplotlib qt5

READ IN DATA

In [None]:
nemo_run = 'bf663' # identifier of the NEMO run, used to build all input/output paths below; options noted by the author: 'bf663','bi646' 

In [None]:
# Folders holding the interim data of the selected NEMO run
inputpath_data = f'/bettik/burgardc/DATA/NN_PARAM/interim/SMITH_{nemo_run}/'
inputpath_profiles = f'/bettik/burgardc/DATA/NN_PARAM/interim/T_S_PROF/SMITH_{nemo_run}/'
inputpath_isf = f'/bettik/burgardc/DATA/NN_PARAM/interim/ANTARCTICA_IS_MASKS/SMITH_{nemo_run}/'

# Limits used to crop the domain (applied to both x and y) so that the computation
# is cheaper - the ice-shelf file was already reduced when it was created
map_lim = [-3000000, 3000000]

In [None]:
# Mask variables for all years; the time axis is relabelled with integers counting up from 1970
file_mask = xr.open_dataset(f'{inputpath_data}mask_variables_of_interest_allyy_Ant_stereo.nc')
file_mask = file_mask.assign_coords(time=range(1970, 1970 + len(file_mask.time)))

# Keep the first 72 time entries only
file_mask_71 = file_mask.isel({'time': range(72)})

# File holding the depth coordinate ('gdept_0' is read from it further down)
file_mask2 = xr.open_dataset(f'{inputpath_data}mask_depth_coord_Ant_stereo.nc')


In [None]:
# Ice-shelf masks from the file with suffix yy00, cropped to the reduced domain
file_isf_1970 = xr.open_dataset(f'{inputpath_isf}nemo_5km_isf_masks_and_info_and_distance_oneFRIS_yy00.nc')
file_isf_1970_cutted = uf.cut_domain_stereo(file_isf_1970, map_lim, map_lim)

# Same crop applied to the time-limited mask dataset
file_mask_cutted = uf.cut_domain_stereo(file_mask_71, map_lim, map_lim)

#file_TS_orig  = xr.open_mfdataset(inputpath_data+'variables_of_interest_2000_Ant_stereo.nc', chunks={'x': 600, 'y': 600})
#file_TS_orig_cutted = uf.cut_domain_stereo(file_TS_orig, map_lim, map_lim).squeeze().drop('time')

In [None]:
# Longitude and latitude fields of the cropped domain (passed to gsw further down)
lon, lat = (file_isf_1970_cutted[coord_name] for coord_name in ('longitude', 'latitude'))

Prepare the depth axis

In [None]:
# Depth axis taken from 'gdept_0' with the singleton lon/lat dimensions removed ...
nemo_depth = file_mask2['gdept_0'].squeeze(dim=['lon','lat'])
# ... and rounded to three decimals (mm scale - should be enough)
nemo_depth = np.round(nemo_depth, 3)

Cut out the temperature and salinity and assign the new depth axis

CONVERT CONSERVATIVE TEMPERATURE FOR OPEN OCEAN REGIONS TO POTENTIAL TEMPERATURE 
AND ABSOLUTE SALINITY TO PRACTICAL SALINITY

In [None]:
# For every year 1970-2011: crop the yearly 3D file to the reduced domain, keep the
# open-ocean points only (ISF_mask == 1) and write the masked temperature and the
# practical salinity to one NetCDF file per year.
for timet in range(1970, 2012):
    print(timet)

    path_ts = inputpath_data + '3D_variables_of_interest_allyy_Ant_stereo_'+str(timet)+'.nc'
    path_isf = inputpath_isf+'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_yy'+str(timet-1970).zfill(2)+'.nc'

    # Both yearly files are opened in context managers so their handles are released
    # at the end of each iteration instead of accumulating over the 42 years.
    with xr.open_dataset(path_ts) as ds_ts, xr.open_dataset(path_isf) as file_isf:
        ds_ts_cutted = uf.cut_domain_stereo(ds_ts, map_lim, map_lim)
        ds_temp_saline_input = ds_ts_cutted[['thetao', 'so']]
        ds_temp_saline_input = ds_temp_saline_input.rename({'thetao': 'temperature', 'so': 'salinity', 'deptht': 'depth'})
        # Replace the depth axis by the rounded NEMO depth (this overwrites whatever
        # depth values came with the file, so no separate rounding step is needed)
        ds_temp_saline_input = ds_temp_saline_input.assign_coords(depth=nemo_depth.values)

        file_isf_cutted = uf.cut_domain_stereo(file_isf, map_lim, map_lim)
        mask_ocean = file_isf_cutted['ISF_mask'] == 1  #ocean without ice shelf cavity

        # The temperature is only masked here; no temperature conversion is applied.
        # NOTE(review): confirm that 'thetao' already is potential temperature.
        ds_temp_saline_input['theta_ocean'] = ds_temp_saline_input['temperature'].where(mask_ocean)
        # NOTE(review): gsw.SP_from_SA documents its second argument as sea pressure (dbar),
        # while the depth axis is passed here - confirm that this approximation is intended.
        ds_temp_saline_input['salinity_ocean'] = xr.apply_ufunc(
            gsw.SP_from_SA,
            ds_temp_saline_input['salinity'].where(mask_ocean),
            ds_temp_saline_input['depth'],
            lon,
            lat,
            dask='allowed',
        )

        # Written while the input files are still open
        ds_temp_saline_output = ds_temp_saline_input[['theta_ocean', 'salinity_ocean']]
        ds_temp_saline_output.to_netcdf(inputpath_profiles + 'T_S_theta_ocean_corrected_'+str(timet)+'.nc')

In [None]:
# Visual check: plot the practical salinity at depth index 40
# (ds_temp_saline_input holds the year that was processed last).
ds_temp_saline_input['salinity_ocean'].isel(depth=40).plot()

ALTERNATIVE: USE THE OCEAN CONCENTRATION TO CORRECT T AND S (BUT THE RESULT IS NOT REALLY SATISFYING)

In [None]:
# Experiment: divide T and S by (1 - ocean/ice mask value) before masking and converting.
# The all-years 3D file and the custom mask are opened once and stay open, because
# cells further down still read from ds_ts_cutted and oi_mask_cutted.
ds_ts  = xr.open_dataset(inputpath_data + '3D_variables_of_interest_allyy_Ant_stereo.nc') #, chunks={'x': 600, 'y': 600})
ds_ts_cutted = uf.cut_domain_stereo(ds_ts, map_lim, map_lim)
# Relabel the time axis with integers counting up from 1970
ds_ts_cutted = ds_ts_cutted.assign_coords({'time': range(1970,1970+len(ds_ts_cutted.time))})#.chunk({'time': 1})

float_oi_mask = xr.open_dataset(inputpath_data+'custom_ocean_ice_mask_Ant_stereo.nc')
oi_mask_cutted = uf.cut_domain_stereo(float_oi_mask['oi_mask01'], map_lim, map_lim)
oi_mask_cutted = oi_mask_cutted.assign_coords({'time': range(1970,1970+len(oi_mask_cutted.time))})#.chunk({'time': 1})


for timet in range(1970, 1974):
    print(timet)

    ocean_conc = oi_mask_cutted.sel(time=timet)
    # Values of 0.999 and above become NaN so that the division below never uses a (near-)zero denominator
    ocean_conc_clean = ocean_conc.where(ocean_conc < 0.999)
    # NOTE(review): the meaning of 'oi_mask01' is not visible here - confirm that (1 - value) is the intended divisor.
    ds_temp_saline_input = ds_ts_cutted[['thetao', 'so']].sel(time=timet) / (1 - ocean_conc_clean)
    ds_temp_saline_input = ds_temp_saline_input.rename({'thetao': 'temperature', 'so': 'salinity', 'deptht': 'depth'})
    # Replace the depth axis by the rounded NEMO depth (this overwrites whatever
    # depth values came with the file, so no separate rounding step is needed)
    ds_temp_saline_input = ds_temp_saline_input.assign_coords(depth=nemo_depth.values)

    # The yearly ice-shelf file is only needed to build the mask, so it is closed right away
    with xr.open_dataset(inputpath_isf+'nemo_5km_isf_masks_and_info_and_distance_oneFRIS_yy'+str(timet-1970).zfill(2)+'.nc') as file_isf:
        file_isf_cutted = uf.cut_domain_stereo(file_isf, map_lim, map_lim)
        mask_ocean = file_isf_cutted['ISF_mask'] == 1  #ocean without ice shelf cavity

    ds_temp_saline_input['theta_ocean'] = ds_temp_saline_input['temperature'].where(mask_ocean)
    # NOTE(review): gsw.SP_from_SA documents its second argument as sea pressure (dbar),
    # while the depth axis is passed here - confirm that this approximation is intended.
    ds_temp_saline_input['salinity_ocean'] = xr.apply_ufunc(
        gsw.SP_from_SA,
        ds_temp_saline_input['salinity'].where(mask_ocean),
        ds_temp_saline_input['depth'],
        lon,
        lat,
        dask='allowed',
    )
    ds_temp_saline_output = ds_temp_saline_input[['theta_ocean', 'salinity_ocean']]
    ds_temp_saline_output.to_netcdf(inputpath_profiles + 'T_S_theta_ocean_corrected_withconc_'+str(timet)+'.nc')

In [None]:
# Check: maximum of 'thetao' in the uncropped 1973 file.
# Note that this rebinds ds_ts, which the previous cell used for the all-years file.
ds_ts  = xr.open_dataset(inputpath_data + '3D_variables_of_interest_allyy_Ant_stereo_1973.nc') #, chunks={'x': 600, 'y': 600})
ds_ts['thetao'].max()

In [None]:
    # Debug copy of the first lines of the loop body above; it reuses `timet` as left
    # behind by that loop and rebinds ds_temp_saline_input to the un-renamed
    # ('thetao', 'so') dataset.
    # NOTE(review): the indentation is a leftover from the copy - consider dedenting or deleting this cell.
    ocean_conc = oi_mask_cutted.sel(time=timet)
    ocean_conc_clean = ocean_conc.where(ocean_conc < 0.999)
    ds_temp_saline_input = ds_ts_cutted[['thetao', 'so']].sel(time=timet) / (1 - ocean_conc_clean)

In [None]:
# Count the points at depth index 20 whose open-ocean temperature is below -5.
# The field is read from ds_temp_saline_output (set by the loop above): the preceding
# cell rebinds ds_temp_saline_input to a dataset without 'theta_ocean', so reading it
# from there fails when the notebook is run from top to bottom.
ds_temp_saline_output['theta_ocean'].isel(depth=20).where(
    ds_temp_saline_output['theta_ocean'].isel(depth=20) < -5
).count()

In [None]:
    # Debug copy of the loop body above (indentation kept from the copy): rebuild the
    # concentration-corrected T/S dataset for the year left in `timet`.
    ocean_conc = oi_mask_cutted.sel(time=timet)
    # Values of 0.999 and above become NaN so that the division never uses a (near-)zero denominator
    ocean_conc_clean = ocean_conc.where(ocean_conc < 0.999)
    ds_temp_saline_input = ds_ts_cutted[['thetao', 'so']].sel(time=timet) / (1 - ocean_conc_clean)
    ds_temp_saline_input = ds_temp_saline_input.rename({'thetao': 'temperature', 'so': 'salinity', 'deptht': 'depth'})
    # Replace the depth axis by the rounded NEMO depth (this overwrites whatever
    # depth values came with the file, so no separate rounding step is needed)
    ds_temp_saline_input = ds_temp_saline_input.assign_coords(depth=nemo_depth.values)

In [None]:
# Concentration-corrected salinity restricted to the open-ocean mask (used by a plot further down).
test = ds_temp_saline_input['salinity'].where(mask_ocean)

In [None]:
# Plot 'so' from ds_ts at depth index 40, colour scale starting at 33.
ds_ts['so'].isel(deptht=40).plot(vmin=33)

In [None]:
# Plot the concentration-corrected salinity at depth index 40, colour scale limited to 33-35.
ds_temp_saline_input['salinity'].isel(depth=40).plot(vmin=33,vmax=35)

In [None]:
# Plot the masked salinity at depth index 40, showing only values above 5.
test.isel(depth=40).where(test.isel(depth=40) > 5).plot()

In [None]:
# Plot the practical salinity of the last written output at depth index 40.
ds_temp_saline_output['salinity_ocean'].isel(depth=40).plot()

In [None]:
# Plot the corrected temperature at depth index 20 at the points where the corrected
# salinity exceeds 1000. The renamed variables ('temperature', 'salinity', 'depth') are
# used because that is what ds_temp_saline_input contains when the notebook is run from
# top to bottom; the original names 'thetao'/'so'/'deptht' no longer exist at this point.
ds_temp_saline_input['temperature'].isel(depth=20).where(
    ds_temp_saline_input['salinity'].isel(depth=20) > 1000, drop=True
).plot()

In [None]:
# Plot the mask values that are kept for the correction (those below 0.999).
ocean_conc.where(ocean_conc < 0.999).plot()

In [None]:
# Display the uncorrected 'thetao' for the year in `timet` at depth index 0.
ds_ts_cutted['thetao'].sel(time=timet).isel(deptht=0)

In [None]:
# Plot the practical salinity of the last written output at depth index 10.
ds_temp_saline_output['salinity_ocean'].isel(depth=10).plot()

Write the results to multiple files (1 per year)

In [None]:
# Split the output dataset along 'time' and build one output path per piece.
# NOTE(review): when the notebook is run from top to bottom, ds_temp_saline_output is the
# single-year result of the concentration experiment above, and the prefix
# 'T_S_theta_ocean_corrected' is the one already used by the first loop - depending on what
# tspf.create_filepath returns this could overwrite those files. The loops above already
# write one file per year, so confirm whether this cell is still needed.
yearly_datasets = list(tspf.split_by_chunks(ds_temp_saline_output.unify_chunks(),'time'))
paths = [tspf.create_filepath(ds, 'T_S_theta_ocean_corrected', inputpath_profiles, ds.time[0].values) for ds in yearly_datasets]

This takes approximately 1 min per year.

In [None]:
# Write every dataset of the list to its matching path.
xr.save_mfdataset(datasets=yearly_datasets, paths=paths)