# Calculate the feedbacks in single model abrupt-4xCO2 experiment by using the radiatvie kernel 
ref: Soden, et.al., (2008)


In [5]:
import numpy as np
import subprocess
import xarray as xr
import numba
from numba import njit
import time
sstart_time0 = time.time()

import matplotlib.pyplot as plt
%matplotlib inline

import Radiative_Repsonse_with_Raditive_kernel as R3k

ModuleNotFoundError: No module named 'Radiative_Repsonse_with_Raditive_kernel'

In [2]:
def error_info(string):
    """
    print any input info if it is not empty
    and raise exception """
    if string!="" :
        raise Exception("Error : "+string)
        
def global_mean_xarray_nan(ds_XLL):
    """
    A function to compute the global mean value of the data.
    The data has to have the lat and lon in its dimensions.
    
    Parameters
    ----------
    ds_XLL   :  DataArray with lat and lon. ds_XLL.lat will be 
                used for area weight.

    Returns
    ----------
    GM_mean  : global mean value.
    
    """
    lat = ds_XLL.coords['lat']        # readin lat
    weight_area = np.cos(np.deg2rad(lat))
    data_XLL_weighted = (ds_XLL*weight_area).mean(dim=['lat','lon'])
    data_ones_nan = xr.where(np.isfinite(ds_XLL),1.,np.nan)
    weight_sum = (data_ones_nan*weight_area).mean(dim=['lat','lon'])
    GM_mean = data_XLL_weighted/weight_sum
    
    return GM_mean

## Prepare the data 
 - __ta(3D), hus(3D), ts(2D)__: 
> in both perturbation (abrupt-4xCO2) and control (piControl) \
> compute the perturbation for temperature, water vapor \
> $dt = T_{per} - T_{con}  $
 - __rlut rsdt rsut__ : 
> in perturbation \
> for the dR
 - __rsutcs rlutcs__: 
> clear-sky data needed in both perturbation (abrupt-4xCO2) and control (piControl) \
> for the dR^0, as well as the direct forcing in clear sky (and therefore we can infer the direct forcing in all sky ) \
> $D = D^0/1.16$

In [3]:
experiments = ["piControl","abrupt-4xCO2"]
# experiments = ["piControl","1pctCO2"]

var_list = "ta hus ts rlut rsdt rsut rlutcs rsutcs rsus rsds".split()
res_hori = '2x2.5'
## run the regrid scripts before this one!
dvc_info = '2020051522'

In [4]:
# def decompose_RK(model)
# process for each model
model = 'GFDL-CM4'
# model = 'CESM2'
var_cont = {}
var_pert = {}
model_ava = []
# read in all var needed
print(">>> start "+model)
for var in var_list:
    print(">>> reading <var>|{0:>7s}".format(var) )
    try:
        dirpath = "./data/CMIP6_post_regrid/piControl/{3:}/{0:}/{1:}.*0001-0450.ltm.nc.{2:}.2x2.5".format(var,var,dvc_info,model)
        tmp = subprocess.run(["source ~/.bash_env;  ls  "+dirpath], shell=True, capture_output=True)
        error_info(tmp.stderr.decode("utf-8"))
    except:
        print("model {0:} does't have variable {1:} in control experiment".format(model,var))
        raise 
    filelist = tmp.stdout.decode("utf-8")[:-1]
    # read variables
    with xr.open_dataset(filelist) as ds:
#             print(ds)
        var_cont[var] =  ds.isel(model=0).load()

    try:
        dirpath = "./data/CMIP6_post_regrid/abrupt-4xCO2/{3:}/{0:}/{1:}.*.nc.{2:}.2x2.5".format(var,var,dvc_info,model)
        tmp = subprocess.run(["source ~/.bash_env;  ls  "+dirpath], shell=True, capture_output=True)
        error_info(tmp.stderr.decode("utf-8"))
    except:
        print("model {0:} does't have variable {1:} in perturbe experiment".format(model,var))
        raise
    filelist = tmp.stdout.decode("utf-8")[:-1]
    # read variables
    with xr.open_dataset(filelist) as ds:
        var_pert[var] =  ds.isel(model=0).load()

>>> start GFDL-CM4
>>> reading <var>|     ta
model GFDL-CM4 does't have variable ta in control experiment


Exception: Error : ls: cannot access ./data/CMIP6_post_regrid/piControl/GFDL-CM4/ta/ta.*0001-0450.ltm.nc.2020051522.2x2.5: No such file or directory


In [None]:
## import kernel file and compute: dR_wv, dR_T, dR_Ts, dR_alb + clear sky
rk_source = 'GFDL'
# kernel file is from Brian Soden
dirpath = "./data/kernels_TOA_"+rk_source+"_CMIP6-standard.nc"
tmp = subprocess.run(["source ~/.bash_env;  ls  "+dirpath], shell=True, capture_output=True)
file_rk = tmp.stdout.decode("utf-8")
print(file_rk)
f_RK =  xr.open_dataset(file_rk[:-1],decode_times=False) 
f_RK=f_RK.rename({'time': 'month'})
f_RK.coords['month']= np.arange(1,13,1)
f_RK.coords['plev']= f_RK.coords['plev']*100

## Compute the cloud feedback
$
\begin{align}
dR &= D + dR_wv + dR_a + dR_T + dR_c \\
dR^0 &= D^0 + dR^0_{wv} + dR^0_a + dR^0_T \\
D^0 &= dR^0 - (dR^0_{wv} + dR^0_a + dR^0_T)\\
\Rightarrow dR_c &= dR - D - (dR_{wv} + dR_a + dR_T) \\
with\quad D &= D^0/1.16 
\end{align}
$

## convert the hus for RK and calculate the $dR_{wv}$
$
\begin{align}
dR_{wv} &= K^w  dq \\
        &= \underline{(K^w * \zeta)} \quad \underline{(dq / \zeta)}\\
        &= K^{\omega} \omega  \\
\zeta &= \frac{q}{q_s} \frac{dq_s}{dT} \\
\Rightarrow \omega  &=  dq / \zeta \\
                    &= \frac{q_s}{q} \frac{dT}{dq_s} dq \\
                    &= \frac{dT}{dln(q_s)}\frac{dq}{q} 
\end{align}
$

Since the water vapor RK is store with factor $\zeta$ and the unit $W/m^{-2} K^{-1}$, the convertion for specific humidity is needed for the $dR_{wv}$ calculation.

## time benchmark  (test with GFDL-CM4)
- compile and parallel with numba < 10s
- pure xarray ~ 100s

In [None]:
%%time
# run with dummy data to compile the jit functions and speed up the computation
# see https://numba.pydata.org/numba-doc/latest/index.html for detail
dummy_TPLL = np.random.rand(36,2,3,4).astype('float32')
dummy_TPLL12 = np.random.rand(12,2,3,4).astype('float32')
dummy_TLL = np.random.rand(36,3,4).astype('float32')
dummy_TLL12 = np.random.rand(12,3,4).astype('float32')
dummy_plev = np.random.rand(2).astype('float32')
cfrk.diff_pert_mon_cont_12mon_TPLL_fast(dummy_TPLL, dummy_TPLL12)
cfrk.diff_pert_mon_cont_12mon_TLL_fast(dummy_TLL, dummy_TLL12)
cfrk.alb_diff_pert_mon_cont_12mon_TLL_fast(dummy_TLL, dummy_TLL, dummy_TLL12, dummy_TLL12)
cfrk.omega_wv_fast(dummy_TPLL, dummy_TPLL12, dummy_TPLL12)
cfrk.RK_compute_TLL_fast (dummy_TLL, dummy_TLL12)
cfrk.RK_compute_TPLL_plev_fast(dummy_TPLL, dummy_TPLL12, dummy_plev)
print("@njit Functions finished compile!")

In [None]:
%%time
# numba version (fast, will be 10x faster if run it after compiled)
## convert the hus for RK and calculate the dR_wv
ta_anom  = cfrk.diff_pert_mon_cont_12mon_TPLL_fast(var_pert['ta'].ta.values, var_cont['ta'].ta.values)
ts_anom  = cfrk.diff_pert_mon_cont_12mon_TLL_fast (var_pert['ts'].ts.values, var_cont['ts'].ts.values)
alb_anom = cfrk.alb_diff_pert_mon_cont_12mon_TLL_fast(var_pert['rsus'].rsus.values, var_pert['rsds'].rsds.values,
                                                      var_cont['rsus'].rsus.values, var_cont['rsds'].rsds.values)
omega_wv = cfrk.omega_wv_fast (var_pert['hus'].hus.values, var_cont['hus'].hus.values, var_cont['ta'].ta.values)
dR_sw   = cfrk.diff_pert_mon_cont_12mon_TLL_fast((var_pert['rsdt'].rsdt.values-var_pert['rsut'].rsut.values),
                                                 (var_cont['rsdt'].rsdt.values-var_cont['rsut'].rsut.values) )
dR_lw   = cfrk.diff_pert_mon_cont_12mon_TLL_fast((-var_pert['rlut'].rlut.values),
                                                 (-var_cont['rlut'].rlut.values) )
dRcs_sw = cfrk.diff_pert_mon_cont_12mon_TLL_fast((var_pert['rsdt'].rsdt.values-var_pert['rsutcs'].rsutcs.values),
                                                 (var_cont['rsdt'].rsdt.values-var_cont['rsutcs'].rsutcs.values) )
dRcs_lw = cfrk.diff_pert_mon_cont_12mon_TLL_fast((-var_pert['rlutcs'].rlutcs.values),
                                                 (-var_cont['rlutcs'].rlutcs.values) )
plev_weight = cfrk.RK_plev_weight(var_cont['hus'].hus.plev.values)
dR_wv_lw    = cfrk.RK_compute_TPLL_plev_fast(omega_wv,f_RK.lw_q.values.astype('float32'),    plev_weight)
dR_wv_sw    = cfrk.RK_compute_TPLL_plev_fast(omega_wv,f_RK.sw_q.values.astype('float32'),    plev_weight)
dR_wvcs_lw  = cfrk.RK_compute_TPLL_plev_fast(omega_wv,f_RK.lwclr_q.values.astype('float32'), plev_weight)
dR_wvcs_sw  = cfrk.RK_compute_TPLL_plev_fast(omega_wv,f_RK.swclr_q.values.astype('float32'), plev_weight)
dR_Ta       = cfrk.RK_compute_TPLL_plev_fast(ta_anom, f_RK.lw_ta.values.astype('float32'),   plev_weight)
dR_Tacs     = cfrk.RK_compute_TPLL_plev_fast(ta_anom, f_RK.lwclr_ta.values.astype('float32'),plev_weight)
dR_Ts       = cfrk.RK_compute_TLL_fast (ts_anom,f_RK.lw_ts.values.astype('float32'))
dR_Tscs     = cfrk.RK_compute_TLL_fast (ts_anom,f_RK.lwclr_ts.values.astype('float32'))
dR_alb      = cfrk.RK_compute_TLL_fast (alb_anom,f_RK.sw_alb.values.astype('float32'))
dR_albcs    = cfrk.RK_compute_TLL_fast (alb_anom,f_RK.swclr_alb.values.astype('float32'))

## dR due to cloud change
Dcs_lw   = dRcs_lw - (dR_Tacs - dR_Tscs - dR_wvcs_lw)
Dcs_sw   = dRcs_sw - (- dR_albcs - dR_wvcs_sw)
D_lw     = Dcs_lw / 1.16
D_sw     = Dcs_sw / 1.16
dR_c_lw  = dR_lw - D_lw - (dR_Ta - dR_Ts - dR_wv_lw)
dR_c_sw  = dR_sw - D_sw - (- dR_alb - dR_wv_sw)
np.nanmean(dR_c_lw+dR_c_sw)

In [None]:
%%time
# xarray version (slow but easy to understand and modify)

ta_anom0 = var_pert['ta'].ta.groupby('time.month') - var_cont['ta'].ta
omega_wv0 = cfrk.omega_wv_xarray (var_pert['hus'].hus, var_cont['hus'].hus, var_cont['ta'].ta)
#### dt, dts, omega_wv, dRt_l, dRt_s, dRt_lcs, dRt_scs
ts_anom0 = var_pert['ts'].ts.groupby('time.month') - var_cont['ts'].ts
alb_pert = var_pert['rsus'].rsus/var_pert['rsds'].rsds
alb_pert = alb_pert.where(np.isfinite(alb_pert),0)
alb_cont = var_cont['rsus'].rsus/var_cont['rsds'].rsds
alb_cont = alb_cont.where(np.isfinite(alb_cont),0)
alb_anom_0 = alb_pert.groupby('time.month') - alb_cont
dR0   = (var_pert['rsdt'].rsdt-var_pert['rsut'].rsut-var_pert['rlut'].rlut).groupby('time.month') \
       -(var_cont['rsdt'].rsdt-var_cont['rsut'].rsut-var_cont['rlut'].rlut)
dRcs0 = (var_pert['rsdt'].rsdt-var_pert['rsutcs'].rsutcs-var_pert['rlutcs'].rlutcs).groupby('time.month') \
       -(var_cont['rsdt'].rsdt-var_cont['rsutcs'].rsutcs-var_cont['rlutcs'].rlutcs)
#### dR_wv, dR_T, dR_Ts, dR_alb + clear sky
dR_wv0    = cfrk.RK_compute_TPLL(omega_wv0,f_RK.lw_q+f_RK.sw_q)
dR_wvcs0  = cfrk.RK_compute_TPLL(omega_wv0,f_RK.lwclr_q+f_RK.swclr_q)
dR_Ta0    = cfrk.RK_compute_TPLL(ta_anom0,f_RK.lw_ta)
dR_Tacs0  = cfrk.RK_compute_TPLL(ta_anom0,f_RK.lwclr_ta)
dR_Ts0    = cfrk.RK_compute_suf (ts_anom0,f_RK.lw_ts)
dR_Tscs0  = cfrk.RK_compute_suf (ts_anom0,f_RK.lwclr_ts)
dR_alb0   = cfrk.RK_compute_suf (alb_anom_0*100,f_RK.sw_alb)
dR_albcs0 = cfrk.RK_compute_suf (alb_anom_0*100,f_RK.swclr_alb)
D_00   = dRcs0 - (dR_Tacs0 - dR_Tscs0 - dR_albcs0 - dR_wvcs0)
D0     = D_00 / 1.16
dR_c0  = dR0 - D0 - (dR_Ta0 - dR_Ts0 - dR_alb0 - dR_wv0)
print(dR_c0.mean().values)

In [None]:

## write to file
ds_write = xr.Dataset()
time_ds = var_pert['ts'].time
ds_write.coords['time'] = time_ds
ds_write.coords['lat']  = var_pert['ts'].lat
ds_write.coords['lon']  = var_pert['ts'].lon
ds_write.attrs['RK_Info'] = file_rk
ds_write.attrs['data_dir'] = '/tigress/cw55/data/CMIP6_post/CMIP6_post_large_scratch/experi/var/var.dvc_info'
ds_write.attrs['dvc_info'] = dvc_info

ds_write['dR_wv_lw'] = (('time','lat','lon'),dR_wv_lw)
ds_write['dR_wv_sw'] = (('time','lat','lon'),dR_wv_sw)
ds_write['dR_wvcs_lw'] = (('time','lat','lon'),dR_wvcs_lw)
ds_write['dR_wvcs_sw'] = (('time','lat','lon'),dR_wvcs_sw)
ds_write['dR_Ta'] = (('time','lat','lon'),dR_Ta)
ds_write['dR_Tacs'] = (('time','lat','lon'),dR_Tacs)
ds_write['dR_Ts'] = (('time','lat','lon'),dR_Ts)
ds_write['dR_Tscs'] = (('time','lat','lon'),dR_Tscs)
ds_write['dR_alb'] = (('time','lat','lon'),dR_alb)
ds_write['dR_albcs'] = (('time','lat','lon'),dR_albcs)
ds_write['dR_c_lw'] = (('time','lat','lon'),dR_c_lw)
ds_write['dR_c_sw'] = (('time','lat','lon'),dR_c_sw)
ds_write['Dcs_lw'] = (('time','lat','lon'),Dcs_lw)
ds_write['Dcs_sw'] = (('time','lat','lon'),Dcs_sw)
ds_write['dR_sw'] = (('time','lat','lon'),dR_sw)
ds_write['dR_lw'] = (('time','lat','lon'),dR_lw)
ds_write['dRcs_sw'] = (('time','lat','lon'),dRcs_sw)
ds_write['dRcs_sw'] = (('time','lat','lon'),dRcs_sw)
ds_write['dts'] = (('time','lat','lon'),ts_anom)
#file size GB
print(ds_write.nbytes/1e9)

# save the rk results for future use
# out_filename = 'rk.test.toa'
# print(out_filename)
# ds_write.to_netcdf(out_filename)

# Plot section

In [None]:
dR_temp_gm = global_mean_xarray_nan(ds_write['dR_Ta']+ds_write['dR_Ts']).groupby('time.year').mean()
dR_clou_gm = global_mean_xarray_nan(ds_write['dR_c_lw']+ds_write['dR_c_sw']).groupby('time.year').mean()
dR_wv_gm = global_mean_xarray_nan(ds_write['dR_wv_lw']+ds_write['dR_wv_sw']).groupby('time.year').mean()
dR_albe_gm = global_mean_xarray_nan(ds_write['dR_alb']).groupby('time.year').mean()
dts_gm = global_mean_xarray_nan(ds_write['dts']).groupby('time.year').mean()

In [None]:
plt.close()
fig0 = plt.figure(figsize=(4,3),dpi=150)
ax1   = fig0.add_subplot(111)
ax1.plot(dts_gm,dR_temp_gm/5,'.',label='dR_Temp/5')
ax1.plot(dts_gm,dR_clou_gm,'.',label='dR_Cloud')
ax1.plot(dts_gm,dR_albe_gm,'.',label='dR_Albe')
ax1.plot(dts_gm,dR_wv_gm,'.',label='dR_WV')
ax1.set_xlabel('dTs')
plt.legend()
plt.show()

In [None]:
# #### exit programm f####
print("--- %s seconds ---" % (time.time() - sstart_time0))