In [None]:
%matplotlib inline
import logging
import os
from os.path import join

import konrad
import matplotlib.pyplot as plt
import numpy as np
import typhon
import xarray as xr
from scipy import interpolate
from typhon.arts import xml

# Only let ERROR (and more severe) records through the root logger.
logging.root.setLevel(logging.ERROR)

# Use the typhon matplotlib style sheet for every figure in this notebook.
plt.style.use(typhon.plots.styles('typhon'))


def results2xarray(fluxes, heating):
    """Wrap ARTS flux and heating-rate output in an ``xarray.Dataset``.

    Parameters:
        fluxes: Array split column-wise into three equal parts, interpreted
            (in this order) as half-level pressure, upward flux and
            downward flux.
        heating: Array split column-wise into two equal parts, interpreted
            (in this order) as full-level pressure and heating rate.

    Returns:
        xarray.Dataset: A single ``time`` step (``time=[0]``) holding
        ``lw_flxu`` and ``lw_flxd`` on ``phlev`` and ``lw_htngrt`` on
        ``plev``. The downward flux is sign-flipped and the heating rate is
        multiplied by 86400 (presumably K/s -> K/day; confirm against the
        ARTS output). Variable descriptions are attached in place by
        ``konrad.utils.append_description``.
    """
    half_level_p, flux_up, flux_down = np.split(fluxes, 3, axis=1)
    full_level_p, heating_rate = np.split(heating, 2, axis=1)
    seconds_per_day = 24 * 3600

    result = xr.Dataset(
        coords=dict(
            time=[0],
            plev=full_level_p.ravel(),
            phlev=half_level_p.ravel(),
        ),
    )
    # Columns become (time, level) rows via the transpose.
    result['lw_flxu'] = xr.DataArray(flux_up.T, dims=('time', 'phlev'))
    # Sign flip, presumably to match konrad's downward-flux convention.
    result['lw_flxd'] = -xr.DataArray(flux_down.T, dims=('time', 'phlev'))
    result['lw_htngrt'] = seconds_per_day * xr.DataArray(
        heating_rate.T, dims=('time', 'plev')
    )

    konrad.utils.append_description(result)
    return result


def radiation_from_atmfield(atm_fields_compact, **kwargs):
    """Run konrad's RRTMG on an ARTS atmospheric field.

    A ``konrad.atmosphere.Atmosphere`` is built from *atm_fields_compact*.
    A ``konrad.surface.SurfaceHeatCapacity`` derived from that atmosphere is
    attached as ``atmosphere.attrs['surface']`` before radiation is computed.

    Parameters:
        atm_fields_compact: ARTS field accepted by
            ``Atmosphere.from_atm_fields_compact`` (presumably an ARTS
            ``atm_fields_compact``; confirm).
        **kwargs: Forwarded unchanged to ``konrad.radiation.RRTMG``.

    Returns:
        The result of ``RRTMG.get_heatingrates``. The cells below index it
        like a dataset (``phlev``, ``plev``, ``lw_flxu``, ``lw_flxd``,
        ``lw_htngrt``).
    """
    atmosphere = konrad.atmosphere.Atmosphere.from_atm_fields_compact(
        atm_fields_compact
    )
    atmosphere.attrs['surface'] = (
        konrad.surface.SurfaceHeatCapacity.from_atmosphere(atmosphere)
    )
    rrtmg = konrad.radiation.RRTMG(**kwargs)
    return rrtmg.get_heatingrates(atmosphere)


def dataset_diff(one, another, key):
    """Return ``another[key] - one[key]`` evaluated on *another*'s grid.

    *key* is expected to have dimensions ``(time, <height>)``. The height
    dimension is taken from the second dimension of ``one[key]``. Only the
    last time step is compared, which matches the ``[-1, :]`` indexing used
    for plotting. For single-time-step data this is identical to flattening
    the array; multi-step data no longer fails on a length mismatch.

    *one* is linearly interpolated onto *another*'s height coordinate and
    linearly extrapolated outside its range.

    Parameters:
        one: Reference dataset (the subtrahend).
        another: Dataset whose height grid is used for the result.
        key: Name of the variable to compare.

    Returns:
        tuple: ``(another[height_dim], difference)``, where ``difference``
        is a 1D numpy array on that grid.
    """
    height_dim = one[key].dims[1]

    # Reference profile (last time step) as a function of height.
    reference = interpolate.interp1d(
        one[height_dim],
        one[key].values[-1],
        bounds_error=False,
        fill_value='extrapolate',
    )

    target_grid = another[height_dim]
    return target_grid, another[key].values[-1] - reference(target_grid)
        
    
def calc_co2_factor(atm_field, co2_baseline=348e-6):
    """Return the mean CO2 amount of *atm_field* relative to a baseline.

    Parameters:
        atm_field: Object providing
            ``get('abs_species-CO2', keep_dims=False)`` (an ARTS field in
            this notebook).
        co2_baseline: Value the CO2 profile is divided by (default 348e-6).

    Returns:
        ``np.mean(CO2 / co2_baseline)``. This is the CO2 scaling factor used
        to label the plots and name the output files.
    """
    relative_co2 = atm_field.get('abs_species-CO2', keep_dims=False) / co2_baseline
    return np.mean(relative_co2)
    
    
# Input batch: the ARTS atmospheres and the ARTS reference fluxes and heating
# rates computed for them. All three batches are reversed with ``[::-1]``.
# NOTE(review): absolute, user-specific path; consider making it configurable.
flux_dir = '/work/um0878/users/lkluft/fluxes_arts/'
atmdata = 'rce-rrtmg.xml'

batch_atmfields = xml.load(join(flux_dir, 'atmdata', atmdata))[::-1]
batch_fluxes = xml.load(join(flux_dir, 'results', f'fluxes_{atmdata}'))[::-1]
batch_htngrt = xml.load(
    join(flux_dir, 'results', f'heatingrates_{atmdata}')
)[::-1]

# Flux and heating-rate differences (RRTMG - ARTS)

In [None]:
# Compare konrad's RRTMG with the ARTS reference for every atmosphere in the
# batch. Panels: RRTMG fluxes, flux differences, RRTMG heating rates and
# heating-rate differences (RRTMG - ARTS). Colours distinguish atmospheres
# (CO2 scenarios); line styles distinguish flux direction.
fig, axes = plt.subplots(2, 2, sharey=True, figsize=(12, 10))
axes = axes.ravel()

flux_keys = {
    'lw_flxd': {'linestyle': 'dashed'},
    'lw_flxu': {'linestyle': 'dotted'},
}

# Empty lines that only populate the line-style legend of panel 0.
axes[0].plot([], label='Downward', color='k', linestyle='dashed')
axes[0].plot([], label='Upward', color='k', linestyle='dotted')

# Zero reference for the difference panels, drawn once behind the data.
for diff_ax in (axes[1], axes[3]):
    diff_ax.axvline(0, color='black', linewidth=0.8, zorder=-1)

for i in range(len(batch_atmfields)):
    color = f'C{i}'
    arts = results2xarray(batch_fluxes[i], batch_htngrt[i])
    rrtmg = radiation_from_atmfield(batch_atmfields[i])
    scale = calc_co2_factor(batch_atmfields[i])

    # Empty line that only populates the CO2-scenario legend of panel 2.
    axes[2].plot([], color=color, label=rf'$\sf CO_2 \times {scale:g}$')

    for key, kwargs in flux_keys.items():
        # Total fluxes (last time step of the RRTMG output).
        typhon.plots.profile_p_log(
            rrtmg['phlev'].values, rrtmg[key][-1, :],
            color=color,
            ax=axes[0],
            **kwargs,
        )
        # Flux differences on the RRTMG grid.
        typhon.plots.profile_p_log(
            *dataset_diff(arts, rrtmg, key),
            color=color,
            ax=axes[1],
            **kwargs,
        )

    # Total heating rates.
    typhon.plots.profile_p_log(
        rrtmg['plev'].values, rrtmg['lw_htngrt'][-1, :],
        color=color,
        ax=axes[2],
    )
    # Heating-rate differences.
    typhon.plots.profile_p_log(
        *dataset_diff(arts, rrtmg, 'lw_htngrt'),
        color=color,
        ax=axes[3],
    )

# Axis decoration is loop-invariant, so set it once after plotting.
# Raw strings keep the mathtext backslashes without invalid-escape warnings.
axes[0].set_xlabel(r'Flux [$\sf Wm^{-2}$]', fontsize='small')
axes[0].set_xlim(left=0, right=500)
axes[0].set_title('RRTMG')
axes[1].set_xlabel(r'Flux difference [$\sf Wm^{-2}$]', fontsize='small')
axes[1].set_title(r'$\sf RRTMG - ARTS$')
axes[2].set_xlabel('Heating rate [K/day]', fontsize='small')
axes[3].set_xlabel('Heating rate difference [K/day]', fontsize='small')
axes[3].set_xlim(-0.6, 0.6)

# y is shared between all panels.
axes[0].set_ylim(top=0.01e2)
axes[0].legend(loc='upper left', fontsize='small')
axes[2].legend(loc='lower left', fontsize='small')

# Create the output directory if it does not exist yet.
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/flux_comparison.pdf')

# Save bias correction (RRTMG - ARTS differences) to `netCDF`

In [None]:
# Write the RRTMG - ARTS flux differences of every atmosphere to its own
# netCDF file, named after the CO2 scaling factor, for use as a bias
# correction.
bias_dir = 'results/bias-correction'
# Create the output directory if it does not exist yet; to_netcdf won't.
os.makedirs(bias_dir, exist_ok=True)

for i in range(len(batch_atmfields)):
    arts = results2xarray(batch_fluxes[i], batch_htngrt[i])
    rrtmg = radiation_from_atmfield(batch_atmfields[i])
    scale = calc_co2_factor(batch_atmfields[i])

    plev, htngrt_bias = dataset_diff(arts, rrtmg, 'lw_htngrt')
    phlev, lw_flxd_bias = dataset_diff(arts, rrtmg, 'lw_flxd')
    phlev, lw_flxu_bias = dataset_diff(arts, rrtmg, 'lw_flxu')

    ds = xr.Dataset(coords={'plev': plev, 'phlev': phlev})
    ds['lw_flxd'] = xr.DataArray(lw_flxd_bias.ravel(), dims=['phlev'])
    ds['lw_flxu'] = xr.DataArray(lw_flxu_bias.ravel(), dims=['phlev'])
    # NOTE(review): `htngrt_bias` is computed (and `plev` stored as a
    # coordinate) but never written. If a heating-rate bias is wanted, add
    #     ds['lw_htngrt'] = xr.DataArray(htngrt_bias.ravel(), dims=['plev'])
    # after confirming how the consumer of these files applies each variable.

    ds.to_netcdf(join(bias_dir, f'radiation-bias_{scale:g}.nc'))