In [1]:
%matplotlib inline
import xarray as xr
import numpy as np
import metpy.calc as mpcalc
from metpy.units import units
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
def transpose_da(da, dim_list):
    """Reorder the dimensions of ``da`` so they follow ``dim_list``.

    Parameters
    ----------
    da : object exposing ``transpose(*dims)`` (an ``xarray.DataArray`` in this notebook)
        The array to reorder.
    dim_list : sequence of str
        Dimension names in the desired order.

    Returns
    -------
    ``da`` itself when ``dim_list`` names at most one dimension (nothing to
    reorder); otherwise whatever ``da.transpose(*dim_list)`` returns.
    """
    # Zero or one name cannot change the ordering, so hand the input back untouched.
    if len(dim_list) <= 1:
        return da
    # Argument unpacking covers any number of dimensions; the previous
    # per-length branches silently returned None beyond five dimensions.
    return da.transpose(*dim_list)

In [5]:
def format_da_to_target(da_replace, da_temp_var, data_type):
    """Reshape a source DataArray so it can stand in for a template variable.

    Relies on the notebook-level lookups ``coords_dic``,
    ``coord_name_dic_forcing`` and ``coord_name_dic_reverse`` as well as
    ``transpose_da``; they must be defined before this function is called.

    Parameters
    ----------
    da_replace : xarray.DataArray
        Source data taken from the forcing or static dataset.
    da_temp_var : xarray.DataArray
        Template variable. Its attrs are copied onto the result; for
        ``'forcing'`` its dims also define the output dims.
    data_type : str
        One of ``'forcing'``, ``'scalar'`` or ``'initial'``.

    Returns
    -------
    xarray.DataArray
        New array carrying the template's attrs.

    Raises
    ------
    ValueError
        If ``data_type`` is not one of the three supported values (previously
        this surfaced as an ``UnboundLocalError`` at the ``return``).
    KeyError
        If a template dim has no entry in ``coord_name_dic_forcing``.
    """
    if data_type == 'forcing':
        # Translate template dim names to the source's names so the source
        # can be transposed into the template's dimension order.
        dim_replace = [coord_name_dic_forcing[dim] for dim in da_temp_var.dims]
        # Only dims with a prepared coordinate get one attached.
        var_coords_dic = {dim: coords_dic[dim]
                          for dim in da_temp_var.dims if dim in coords_dic}
        da_replace_t = transpose_da(da_replace, dim_replace)
        da_replace_n = xr.DataArray(da_replace_t.values, dims=da_temp_var.dims,
                                    coords=var_coords_dic)
    elif data_type == 'scalar':
        source_values = da_replace.values
        if len(np.shape(source_values)) == 0:
            value = np.float64(source_values)
        else:
            # Non-scalar source: keep the first entry along the leading axis.
            value = np.float64(source_values[0])
        da_replace_n = xr.DataArray(value)
    elif data_type == 'initial':
        # Initial state = first time step, wrapped in a length-1 't0' axis.
        dims_tuple = ('t0',)
        var_coords_dic = {'t0': coords_dic['t0']}
        da_init = da_replace.isel(time=0)
        for dim in da_init.dims:
            if dim in coord_name_dic_reverse:
                dim_target = coord_name_dic_reverse[dim]
                dims_tuple += (dim_target,)
                if dim_target in coords_dic:
                    # Key by the *target* name: the output dims use target
                    # names, so keying by the source name (as before) attached
                    # the coordinate under the wrong name, e.g. 'levels'.
                    var_coords_dic[dim_target] = coords_dic[dim_target]
        da_replace_n = xr.DataArray([da_init.values], dims=dims_tuple,
                                    coords=var_coords_dic)
    else:
        raise ValueError(
            f"Unsupported data_type {data_type!r}; "
            "expected 'forcing', 'scalar' or 'initial'"
        )

    da_replace_n.attrs = da_temp_var.attrs
    return da_replace_n

In [6]:
# Input files for building the new SCM driver file.
# NOTE(review): these are absolute, user-specific paths; consider a configurable base directory.
# Existing SCM driver file, opened below as the template (`ds_temp`).
temp_file = "/glade/work/yifanc/code/scm_v7/ccpp-scm/scm/data/processed_case_input/gabls3_noahmp_SCM_driver.nc"
# Static fields file for the US-Whs site, opened below as `ds_static`.
static_file = "/glade/work/yifanc/proj/wpo2023/static/ameriflux/ameriflux_static_fields.C1152.US-Whs.nc"
# Processed ERA5 file for the US-Whs site, opened below as `ds_forcing`.
forcing_file = "/glade/derecho/scratch/yifanc/wpo2023/SCM/processed/US-Whs/processed_era5.201902.US-Whs.derecho.nc"

In [None]:
# Workbook whose "var_mapping" sheet maps SCM variable names to their source.
excel_path = "ERA5_variable.SCM.input.xlsx"
# Open through a context manager so the workbook handle is released as soon as
# the sheet has been parsed. `excel_file` stays bound to the (now closed)
# ExcelFile, and the path keeps its own name instead of being overwritten.
with pd.ExcelFile(excel_path) as excel_file:
    mapping_df = excel_file.parse(sheet_name="var_mapping")

In [8]:
# Index the mapping by SCM variable name so rows can be fetched with
# `mapping_df.loc[var, ...]`. The guard keeps this cell re-runnable: once the
# column has become the index, a second `set_index` would raise KeyError.
if mapping_df.index.name != "scm_var_name":
    mapping_df = mapping_df.set_index("scm_var_name")

In [9]:
# Open the template driver, the static fields and the forcing data, in that order.
ds_temp, ds_static, ds_forcing = (
    xr.open_dataset(path) for path in (temp_file, static_file, forcing_file)
)

In [10]:
# Time coordinate of the forcing dataset.
time_dims = ds_forcing["time"]
# Value of the first forcing time entry.
# NOTE(review): neither `time_dims` nor `t0` is referenced again in the visible cells.
t0 = time_dims[0].values

In [11]:
# Names of the template's coordinates; the fill loop further down skips these.
coord_list = list(ds_temp.coords)
for coord in coord_list:
    print(coord)

t0
time
lev


In [14]:
# Empty dataset that the following cells fill with coordinates, variables and attributes.
out = xr.Dataset()

In [15]:
# 'lev' coordinate: values come from the forcing file's `levels`,
# metadata from the template's `lev`.
_lev_values = ds_forcing.levels.values
lev_coords = xr.DataArray(data=_lev_values, coords={'lev': _lev_values}, dims=['lev'])
lev_coords.attrs = ds_temp.lev.attrs

In [16]:
# 't0' holds only the first forcing time (as a length-1 axis); 'time' holds all of them.
# Metadata for both comes from the template.
_forcing_times = ds_forcing.time.values
_first_time = _forcing_times[[0]]

t0_coords = xr.DataArray(data=_first_time, coords={'t0': _first_time}, dims=['t0'])
t0_coords.attrs = ds_temp.t0.attrs

time_coords = xr.DataArray(data=_forcing_times, coords={'time': _forcing_times}, dims=['time'])
time_coords.attrs = ds_temp.time.attrs


In [18]:
# Attach the three prepared coordinates to the output dataset.
for _coord_name, _coord_da in (("lev", lev_coords),
                               ("t0", t0_coords),
                               ("time", time_coords)):
    out.coords[_coord_name] = _coord_da

In [19]:
# Lookup from output dimension name to its coordinate array
# (read by format_da_to_target and by the fill loop).
coords_dic = dict(lev=lev_coords, t0=t0_coords, time=time_coords)

In [21]:
# Lookup from the mapping sheet's `data_source` label to the opened dataset.
data_dic = dict(forcing=ds_forcing, static=ds_static)

In [22]:
# Template dimension name -> dimension name used in the source datasets.
coord_name_dic_forcing = dict(
    lev='levels',
    lat='latitude',
    lon='longitude',
    time='time',
    t0='t0',
    nsoil='nsoil',
)
# Inverse lookup (source name -> template name), built from the forward map
# so the two cannot drift apart.
coord_name_dic_reverse = {
    source_name: target_name
    for target_name, source_name in coord_name_dic_forcing.items()
}

In [None]:
# Output variables that get a `<name>_nud` nudging field.
nudging_var_list = ['ua', 'va', 'ta', 'qv']
# Name of the forcing-file variable that supplies each nudging field.
nudging_invar_dic = dict(
    ta='T_nudge',
    qv='qt_nudge',
    ua='u_nudge',
    va='v_nudge',
)
# Written to the `pa_nudging_<var>` global attributes (presumably a pressure in Pa - TODO confirm).
nudging_res_lev = 50_000
# Written to the `nudging_<var>` global attributes (presumably seconds, i.e. 2 h - TODO confirm).
nudging_timestep = 2 * 60 * 60

In [None]:
# Populate `out` with every non-coordinate variable of the template:
#   * mapped variables are pulled from their source dataset and reshaped,
#   * unmapped ('t0', 'lev') variables are zero-filled,
#   * unmapped variables without a 't0' dimension are copied from the template.
for var in ds_temp.variables:
    # Coordinates were already attached to `out` in an earlier cell.
    if var in coord_list:
        continue

    data_source = mapping_df.loc[var, "data_source"]
    vname_source = mapping_df.loc[var, "data_source_var_name"]
    data_type = mapping_df.loc[var, "data_type"]
    print(var,data_type)
    da_temp_var = ds_temp[var]

    # Only a string `data_source` counts as "mapped"
    # (an empty sheet cell is presumably read as NaN - verify against the workbook).
    if isinstance(data_source, str):
        print(data_source,var, vname_source, data_type)
        ds_source = data_dic[data_source]
        da_replace = ds_source[vname_source]
        da_replace_n = format_da_to_target(da_replace,da_temp_var,data_type)
        out[var] = da_replace_n
    elif 't0' in da_temp_var.dims:
        if 'lev' in da_temp_var.dims:
            # Float zeros (previously int64) that keep the template's
            # metadata, consistent with every other branch.
            da_zero_init = xr.DataArray(np.zeros((1, len(lev_coords))),
                                        dims=['t0','lev'],
                                        coords={'t0':coords_dic['t0'],
                                                'lev':coords_dic['lev']})
            da_zero_init.attrs = da_temp_var.attrs
            out[var] = da_zero_init
        # NOTE(review): unmapped 't0' variables without a 'lev' dimension are
        # not written to `out` at all - confirm this omission is intentional.
    else:
        out[var] = da_temp_var

soil_depth scalar
lon scalar
forcing lon longitude scalar
lat scalar
forcing lat latitude scalar
slmsk scalar
vegtyp scalar
static vegtyp vegetation_category scalar
soiltyp scalar
static soiltyp soil_category scalar
slopetyp scalar
static slopetyp slope_category scalar
tsfco scalar
vegfrac scalar
shdmin scalar
shdmax scalar
canopy scalar
hice scalar
fice scalar
tisfc scalar
snowd scalar
snoalb scalar
tg3 scalar
uustar scalar
alvsf scalar
alnsf scalar
alvwf scalar
alnwf scalar
facsf scalar
facwf scalar
weasd scalar
sncovr scalar
tsfcl scalar
zorl scalar
zorll scalar
zorli scalar
zorlw scalar
area scalar
thetal initial
forcing thetal thetal initial
qt initial
forcing qt qt_nudge initial
ua initial
forcing ua u_nudge initial
va initial
forcing va v_nudge initial
pa initial
forcing pa p initial
zh initial
forcing zh geopotential_height initial
ps initial
forcing ps p_surf initial
ql initial
qi initial
tke initial
o3 initial
stc initial
forcing stc soil_temperature initial
smc initial
forci

In [25]:
# Start from the template's global attributes; later cells add to or override them.
out.attrs = ds_temp.attrs

In [26]:
# Record the first and last forcing times as "<chars 0-9> <chars 11-18>" of their
# string form (date and time of day for an ISO-formatted timestamp).
for _attr_key, _time_idx in (("start_date", 0), ("end_date", -1)):
    _stamp = str(time_coords.values[_time_idx])
    out.attrs[_attr_key] = f"{_stamp[0:10]} {_stamp[11:19]}"

In [27]:
# Reorder the dataset along 'lev' so the coordinate runs from its largest to its smallest value.
# NOTE(review): presumably 'lev' is pressure, making this surface-first ordering - confirm.
out_lev = out.sortby('lev', ascending=False)

In [28]:
# era5_noahmp_exp_SCM_driver.test10.nc
# Store the temperature variable and its advective tendency under the
# `theta` names, and set the two `adv_*` flags to match.
_theta_renames = {"thetal": "theta",
                  "tnthetal_adv": "tntheta_adv"}
out_lev = out_lev.rename(_theta_renames)
out_lev.attrs.update(adv_theta=np.int32(1), adv_thetal=np.int32(0))

In [29]:
# era5_noahmp_exp_SCM_driver.test15.nc
# Add one `<var>_nud` field per nudged variable, shaped like the template's
# `pa_forc`, plus the two global attributes written alongside each field.
data_type = 'forcing'
da_temp_var = ds_temp['pa_forc']
for nvar in nudging_var_list:
    _source_da = ds_forcing[nudging_invar_dic[nvar]]
    _nudged_da = format_da_to_target(_source_da, da_temp_var, data_type)
    # Keep the source variable's own metadata rather than pa_forc's.
    _nudged_da.attrs = _source_da.attrs
    out_lev[f"{nvar}_nud"] = _nudged_da
    out_lev.attrs.update({f"pa_nudging_{nvar}": nudging_res_lev,
                          f"nudging_{nvar}": nudging_timestep})

In [59]:
# Write the assembled driver dataset to disk.
# NOTE(review): absolute, user-specific output path; an existing file at this location is presumably overwritten - confirm.
out_lev.to_netcdf("/glade/derecho/scratch/yifanc/wpo2023/SCM/processed/US-Whs/era5_noahmp_exp_SCM_driver.nc")