### 2023-06-05 prescribed cloud radiative fluxes

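Evaluate the radiative fluxes from the prescribed-cloud runs (coarsened fine-resolution clouds under different cloud-overlap and `ccnorm` settings) against the coarsened fine-resolution reference fluxes.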

In [1]:
import cftime
import xarray as xr
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl
from cycler import cycler
import contextlib
import fv3viz
import cftime
import dataclasses
import os
from dask.distributed import Client
from vcm.catalog import catalog as CATALOG
from vcm.fv3.metadata import standardize_fv3_diagnostics

In [2]:
# All prescribed-cloud runs live under one experiment directory and differ
# only in the trailing part of the run name.
_RUN_ROOT = 'gs://vcm-ml-experiments/cloud-ml/2023-07-04'
_RUN_NAME_PREFIX = 'cloud-ml-prognostic-run-prescribed-cloud-'
_RUN_NAME_SUFFIXES = {
    'max_random_overlap_ccnorm_false': 'max-random',
    'random_overlap_ccnorm_false': 'random',
    'decorr_overlap_ccnorm_false': 'decorr',
    'max_random_overlap_ccnorm_true': 'cc-max-random',
    'random_overlap_ccnorm_true': 'cc-random',
    'decorr_overlap_ccnorm_true': 'cc-decorr',
}

# Maps a short label for each run to the URL of its fv3gfs_run directory.
RADIATION_RUNS = {
    label: f"{_RUN_ROOT}/{_RUN_NAME_PREFIX}{suffix}/fv3gfs_run"
    for label, suffix in _RUN_NAME_SUFFIXES.items()
}

# Catalog entry used as the reference (see get_reference_fine_fluxes).
C48_REFERENCE_KEY = '10day_c48_PIRE_ccnorm_gfsphysics_15min_may2023'
# Key under which the reference fluxes are stored next to the runs.
VERIFICATION_NAME = 'coarsened_fine_radiation'

In [3]:
@dataclasses.dataclass
class RadiationDiag:
    """Names of one radiative flux diagnostic in its two output forms.

    The entries of FLUXES are instances of this class; the renaming helpers
    in this notebook use the two name fields to map variables from either
    source onto ``python_diag_name``.
    """
    # Variable name looked up by _rename_fortran_fluxes (after
    # standardize_fv3_diagnostics has been applied).
    fortran_diag_name: str
    # Common name used throughout the notebook; _rename_python_fluxes looks
    # for this name with a "_python" suffix in the zarr output.
    python_diag_name: str
    # NOTE(review): not read by any cell visible here; presumably a plotting
    # range in W/m^2 -- confirm against the plotting cells.
    scale: float
    
# Flux diagnostics evaluated in this notebook.
FLUXES = [
    RadiationDiag(
        fortran_diag_name='DSWRFsfc',
        python_diag_name='total_sky_downward_shortwave_flux_at_surface',
        scale=100.0,
    ),
    RadiationDiag(
        fortran_diag_name='DLWRFsfc',
        python_diag_name='total_sky_downward_longwave_flux_at_surface',
        scale=30.0,
    ),
    RadiationDiag(
        fortran_diag_name='USWRFtoa',
        python_diag_name='total_sky_upward_shortwave_flux_at_top_of_atmosphere',
        scale=100.0,
    ),
    RadiationDiag(
        fortran_diag_name='ULWRFtoa',
        python_diag_name='total_sky_upward_longwave_flux_at_top_of_atmosphere',
        scale=30.0,
    ),
]

# Evaluation window, keeping every 4th time label between the two bounds.
# Original note: "hourly over days 8-10 validation".
# NOTE(review): "hourly" presumably relies on 15-minute output -- confirm.
_WINDOW_START = cftime.DatetimeJulian(2020, 8, 7, 0, 30, 0)
_WINDOW_END = cftime.DatetimeJulian(2020, 8, 9, 23, 30, 0)
TIME_SLICE = slice(_WINDOW_START, _WINDOW_END, 4)

# C48 grid dataset; GRID.area supplies the weights for the spatial statistics.
GRID = CATALOG['grid/c48'].read()

# Colors of the active matplotlib property cycle.
COLOR_CYCLE = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [4]:
def _rename_fortran_fluxes(ds):
    """Pick the FLUXES variables out of Fortran diagnostics under common names.

    ``ds`` is first passed through ``standardize_fv3_diagnostics``. Each flux
    in FLUXES whose ``fortran_diag_name`` is present is copied into a new
    dataset under its ``python_diag_name``; fluxes that are absent are
    skipped silently and all other variables are dropped.
    """
    standardized = standardize_fv3_diagnostics(ds)
    renamed = xr.Dataset()
    available = (flux for flux in FLUXES if flux.fortran_diag_name in standardized)
    for flux in available:
        renamed[flux.python_diag_name] = standardized[flux.fortran_diag_name]
    return renamed

def _rename_python_fluxes(ds):
    """Pick the FLUXES variables stored with a "_python" suffix, dropping the suffix.

    Each flux in FLUXES found in ``ds`` as ``<python_diag_name>_python`` is
    copied into a new dataset under ``python_diag_name``; fluxes that are
    absent are skipped silently and all other variables are dropped.
    """
    renamed = xr.Dataset()
    for flux in FLUXES:
        suffixed_name = diag_name = flux.python_diag_name
        suffixed_name = diag_name + "_python"
        if suffixed_name not in ds:
            continue
        renamed[diag_name] = ds[suffixed_name]
    return renamed

def _get_total_downward_flux(
    total_sky_downward_shortwave_flux_at_surface,
    total_sky_downward_longwave_flux_at_surface
):
    total_sky_downward_flux_at_surface = (
        total_sky_downward_shortwave_flux_at_surface + 
        total_sky_downward_longwave_flux_at_surface
    ).assign_attrs({
        'long_name': 'total sky downward flux at surface',
        'units': 'W/m^2'
    })
    return total_sky_downward_flux_at_surface

def _get_total_upward_flux(
    total_sky_upward_shortwave_flux_at_top_of_atmosphere,
    total_sky_upward_longwave_flux_at_top_of_atmosphere
):
    total_sky_upward_flux_at_top_of_atmosphere = (
        total_sky_upward_shortwave_flux_at_top_of_atmosphere + 
        total_sky_upward_longwave_flux_at_top_of_atmosphere
    ).assign_attrs({
        'long_name': 'total sky upward flux at top of atmosphere',
        'units': 'W/m^2'
    })
    return total_sky_upward_flux_at_top_of_atmosphere
        
def get_zarr_fluxes(rundir, zarrname, rename=None):
    """Open ``<rundir>/<zarrname>.zarr`` and return its fluxes under common names.

    The store path is printed, the store is opened with consolidated
    metadata, the "_python"-suffixed FLUXES variables are renamed, and the
    two summed fluxes (downward at the surface, upward at top of atmosphere)
    are added.

    NOTE(review): ``rename`` is accepted but never used by this function.
    """
    zarr_path = os.path.join(rundir, zarrname + '.zarr')
    print(zarr_path)
    fluxes = _rename_python_fluxes(xr.open_zarr(zarr_path, consolidated=True))
    total_downward = _get_total_downward_flux(
        fluxes.total_sky_downward_shortwave_flux_at_surface,
        fluxes.total_sky_downward_longwave_flux_at_surface
    )
    total_upward = _get_total_upward_flux(
        fluxes.total_sky_upward_shortwave_flux_at_top_of_atmosphere,
        fluxes.total_sky_upward_longwave_flux_at_top_of_atmosphere
    )
    return fluxes.assign({
        'total_sky_downward_flux_at_surface': total_downward,
        'total_sky_upward_flux_at_top_of_atmosphere': total_upward
    })

def get_reference_fine_fluxes():
    """Return the reference fluxes from the catalog under common names.

    Reads the ``C48_REFERENCE_KEY`` catalog entry, renames the Fortran flux
    diagnostics listed in FLUXES, and adds the two summed fluxes (downward at
    the surface, upward at top of atmosphere).
    """
    reference = _rename_fortran_fluxes(CATALOG[C48_REFERENCE_KEY].to_dask())
    total_downward = _get_total_downward_flux(
        reference.total_sky_downward_shortwave_flux_at_surface,
        reference.total_sky_downward_longwave_flux_at_surface
    )
    total_upward = _get_total_upward_flux(
        reference.total_sky_upward_shortwave_flux_at_top_of_atmosphere,
        reference.total_sky_upward_longwave_flux_at_top_of_atmosphere
    )
    return reference.assign({
        'total_sky_downward_flux_at_surface': total_downward,
        'total_sky_upward_flux_at_top_of_atmosphere': total_upward
    })

In [5]:
def get_bias(coarse_fluxes, reference_fluxes, provenance):
    """Return ``coarse_fluxes - reference_fluxes`` for every reference variable.

    Each bias variable gets a ``long_name`` of the form
    ``"<name> bias [<provenance>]"`` and inherits ``units`` from the coarse
    variable (``None`` when the coarse variable has no units attribute).
    A variable missing from ``coarse_fluxes`` raises a ``KeyError``.
    """
    biases = xr.Dataset()
    for name in reference_fluxes:
        coarse = coarse_fluxes[name]
        bias_attrs = {
            'long_name': f"{name} bias [{provenance}]",
            'units': coarse.attrs.get('units')
        }
        biases[name] = (coarse - reference_fluxes[name]).assign_attrs(bias_attrs)
    return biases

def weighted_rms(da, weights, dims=['x', 'y', 'tile']):
    """Return the weighted root-mean-square of ``da`` over ``dims``.

    Args:
        da: DataArray or Dataset to reduce.
        weights: DataArray of weights, broadcast against ``da``.
        dims: dimensions to reduce over.

    Returns:
        ``sqrt(sum(w * da**2) / sum(w))`` with the attrs of ``da``. Both sums
        run only over points where ``da`` is not missing; the result is NaN
        where no valid points exist.
    """
    squared = da ** 2
    numerator = (squared * weights).sum(dim=dims)
    # Count a weight only where `da` has data. The sum above skips missing
    # values, so normalizing by the sum of *all* weights would bias the RMS
    # low whenever `da` contains NaNs. This matches the masking done by
    # xarray's weighted mean, which weighted_mean uses.
    denominator = (squared.notnull() * weights).sum(dim=dims)
    rms = np.sqrt(numerator / denominator)
    return rms.assign_attrs(da.attrs)

def weighted_mean(da, weights, dims=['x', 'y', 'tile']):
    """Return the mean of ``da`` over ``dims``, weighted by ``weights``."""
    weighted_view = da.weighted(weights)
    return weighted_view.mean(dim=dims)

def time_mean(ds):
    """Return the time averages of all time-dependent variables in ``ds``.

    Each averaged variable is stored as ``<name>_time_mean`` and keeps its
    attrs. Variables without a "time" dimension are left out.
    """
    averaged = xr.Dataset()
    with xr.set_options(keep_attrs=True):
        time_varying = [name for name in ds if "time" in ds[name].dims]
        for name in time_varying:
            averaged[name + '_time_mean'] = ds[name].mean('time')
    return averaged

In [6]:
def get_metrics_df(bias, context, grid=GRID.area):
    """Summarize a bias dataset as a table of scalar metrics.

    Args:
        bias: Dataset of biases with a "time" dimension. Variables with a
            "pressure" dimension are ignored.
        context: label stored in the ``attrs`` of the returned DataFrame.
        grid: weights for the spatial reductions (grid cell area by default).

    Returns:
        DataFrame with one entry per metric and one column per variable.
        The metrics are:
        - 'bias': time mean of the weighted spatial-mean bias.
        - 'RMSE of time-mean pattern': weighted RMS of the time-mean bias.
        - 'time-mean of inst. RMSE': time mean of the weighted RMS bias.
    """
    vars_2d = [var for var in bias.data_vars if "pressure" not in bias[var].dims]
    bias_2d = bias[vars_2d]
    # Use the `grid` argument for the weights; previously it was ignored and
    # GRID.area was always used regardless of what the caller passed.
    metrics = {
        'bias': weighted_mean(bias_2d, grid).mean('time'),
        'RMSE of time-mean pattern': weighted_rms(bias_2d.mean('time'), grid),
        'time-mean of inst. RMSE': weighted_rms(bias_2d, grid).mean('time'),
    }
    metric_dim = xr.DataArray(list(metrics), dims=['metric'], name='metric')
    metrics_df = xr.concat(list(metrics.values()), dim=metric_dim).to_dataframe()
    metrics_df.attrs = {'context': context}
    return metrics_df

def get_r_squared(pred, truth, grid=GRID.area):
    """Return the weighted R^2 of ``pred`` against ``truth``.

    Args:
        pred: Dataset of predicted fields.
        truth: Dataset of target fields. Variables with a "pressure"
            dimension are ignored.
        grid: weights for the spatial reductions (grid cell area by default).

    Returns:
        Dataset of ``1 - MSE(pred - truth) / variance(truth)``, where both
        terms are weighted reductions over the default spatial dimensions of
        weighted_rms / weighted_mean.
    """
    vars_2d = [var for var in truth.data_vars if "pressure" not in truth[var].dims]
    # Restrict the truth statistics to the same variables as the error, so
    # no work is spent on variables that cannot appear in the result.
    truth_2d = truth[vars_2d]
    error_2d = (pred - truth)[vars_2d]
    # Use the `grid` argument for the weights; previously it was ignored and
    # GRID.area was always used regardless of what the caller passed.
    pred_rmse = weighted_rms(error_2d, grid)
    truth_anomaly = truth_2d - weighted_mean(truth_2d, grid)
    truth_stdev = weighted_rms(truth_anomaly, grid)
    return 1 - (pred_rmse ** 2) / (truth_stdev ** 2)

In [7]:
# Reference fluxes, restricted to the evaluation window.
reference_fine_diags = get_reference_fine_fluxes().sel(time=TIME_SLICE)

# One dataset per prescribed-cloud run over the same window, with the
# reference stored alongside them under VERIFICATION_NAME.
radiation_datasets = {
    run_name: get_zarr_fluxes(rundir, 'radiative_fluxes').sel(time=TIME_SLICE)
    for run_name, rundir in RADIATION_RUNS.items()
}
radiation_datasets[VERIFICATION_NAME] = reference_fine_diags

gs://vcm-ml-experiments/cloud-ml/2023-07-04/cloud-ml-prognostic-run-prescribed-cloud-max-random/fv3gfs_run/radiative_fluxes.zarr
gs://vcm-ml-experiments/cloud-ml/2023-07-04/cloud-ml-prognostic-run-prescribed-cloud-random/fv3gfs_run/radiative_fluxes.zarr
gs://vcm-ml-experiments/cloud-ml/2023-07-04/cloud-ml-prognostic-run-prescribed-cloud-decorr/fv3gfs_run/radiative_fluxes.zarr
gs://vcm-ml-experiments/cloud-ml/2023-07-04/cloud-ml-prognostic-run-prescribed-cloud-cc-max-random/fv3gfs_run/radiative_fluxes.zarr
gs://vcm-ml-experiments/cloud-ml/2023-07-04/cloud-ml-prognostic-run-prescribed-cloud-cc-random/fv3gfs_run/radiative_fluxes.zarr
gs://vcm-ml-experiments/cloud-ml/2023-07-04/cloud-ml-prognostic-run-prescribed-cloud-cc-decorr/fv3gfs_run/radiative_fluxes.zarr


In [8]:
# Start a dask.distributed client with default arguments before the datasets
# are loaded; as the last expression of the cell, its summary is displayed
# below (the recorded output shows a LocalCluster with 4 workers).
Client()

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 29.39 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:33699,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 29.39 GiB

0,1
Comm: tcp://127.0.0.1:39679,Total threads: 2
Dashboard: http://127.0.0.1:36879/status,Memory: 7.35 GiB
Nanny: tcp://127.0.0.1:39395,
Local directory: /tmp/dask-worker-space/worker-iltylau8,Local directory: /tmp/dask-worker-space/worker-iltylau8

0,1
Comm: tcp://127.0.0.1:42641,Total threads: 2
Dashboard: http://127.0.0.1:40875/status,Memory: 7.35 GiB
Nanny: tcp://127.0.0.1:44117,
Local directory: /tmp/dask-worker-space/worker-xpyvgj61,Local directory: /tmp/dask-worker-space/worker-xpyvgj61

0,1
Comm: tcp://127.0.0.1:38369,Total threads: 2
Dashboard: http://127.0.0.1:45873/status,Memory: 7.35 GiB
Nanny: tcp://127.0.0.1:39345,
Local directory: /tmp/dask-worker-space/worker-thamm6sz,Local directory: /tmp/dask-worker-space/worker-thamm6sz

0,1
Comm: tcp://127.0.0.1:34173,Total threads: 2
Dashboard: http://127.0.0.1:40377/status,Memory: 7.35 GiB
Nanny: tcp://127.0.0.1:40207,
Local directory: /tmp/dask-worker-space/worker-6am_7ngp,Local directory: /tmp/dask-worker-space/worker-6am_7ngp


In [9]:
radiation_time_mean_datasets = {}
radiation_bias_datasets = {}
radiation_bias_time_mean_datasets = {}

# Load the verification fluxes into memory once, before the loop. While they
# are still lazy, every `get_bias(...).load()` below has to recompute them
# (one extra read of the reference data per run), and they would be loaded
# again after the loop. `.load()` works in place and returns the same
# dataset, so radiation_datasets[VERIFICATION_NAME] is loaded as well.
verification_fluxes = radiation_datasets[VERIFICATION_NAME].load()

for run_name in RADIATION_RUNS:
    print(run_name)
    radiation_datasets[run_name].load()
    radiation_time_mean_datasets[run_name] = time_mean(radiation_datasets[run_name])
    radiation_bias_datasets[run_name] = get_bias(
        radiation_datasets[run_name],
        verification_fluxes,
        f"{run_name} - coarsened fine clouds' radiation"
    ).load()
    radiation_bias_time_mean_datasets[run_name] = time_mean(radiation_bias_datasets[run_name])

# The verification data have no bias against themselves; only their time
# mean is stored.
print(VERIFICATION_NAME)
radiation_time_mean_datasets[VERIFICATION_NAME] = time_mean(verification_fluxes)

max_random_overlap_ccnorm_false
random_overlap_ccnorm_false
decorr_overlap_ccnorm_false
max_random_overlap_ccnorm_true
random_overlap_ccnorm_true
decorr_overlap_ccnorm_true
coarsened_fine_radiation


In [10]:
# Scalar bias / RMSE metrics for each run, keyed by run name.
radiation_run_metrics = {}
for run_name in RADIATION_RUNS:
    print(run_name)
    run_bias = radiation_bias_datasets[run_name]
    radiation_run_metrics[run_name] = get_metrics_df(run_bias, run_name)

max_random_overlap_ccnorm_false
random_overlap_ccnorm_false
decorr_overlap_ccnorm_false
max_random_overlap_ccnorm_true
random_overlap_ccnorm_true
decorr_overlap_ccnorm_true


In [11]:
# R^2 of each run's time-mean fluxes against the time-mean verification
# fluxes, stacked along a new 'run' dimension.
r_squared_per_run = [
    get_r_squared(
        radiation_time_mean_datasets[run_name],
        radiation_time_mean_datasets[VERIFICATION_NAME]
    ).expand_dims({'run': [run_name]})
    for run_name in RADIATION_RUNS
]
r_squared_against_coarsened_fine_rad = xr.concat(r_squared_per_run, dim='run')

In [12]:
# Variables reported in the tables, in display order.
vars_ = [
    'total_sky_downward_shortwave_flux_at_surface',
    'total_sky_downward_longwave_flux_at_surface',
    'total_sky_upward_shortwave_flux_at_top_of_atmosphere',
    'total_sky_upward_longwave_flux_at_top_of_atmosphere',
    'total_sky_downward_flux_at_surface',
    'total_sky_upward_flux_at_top_of_atmosphere'
]


def _threshold_from_run_name(name):
    """Return the two digits following 'cond' in ``name`` divided by 100, else 0."""
    if 'cond' not in name:
        return 0.
    return int(name.split('cond')[1][:2]) / 100.


# NOTE(review): none of the run names in RADIATION_RUNS contains 'cond', so
# every threshold is 0 here.
threshold_runs = [_threshold_from_run_name(name) for name in radiation_run_metrics]

# One row per run, one column per variable, plus the threshold column.
bias_by_run = {
    run_name: metrics.loc['bias'].loc[vars_]
    for run_name, metrics in radiation_run_metrics.items()
}
bias_df = pd.DataFrame(bias_by_run).T
bias_df['threshold'] = threshold_runs

In [13]:
# LaTeX table of the mean bias for the first four variables (rows) by run (columns).
bias_table = bias_df.T.iloc[[0, 1, 2, 3]]
print(bias_table.to_latex(float_format="%.2f"))

\begin{tabular}{lrrrrrr}
\toprule
{} &  max\_random\_overlap\_ccnorm\_false &  random\_overlap\_ccnorm\_false &  decorr\_overlap\_ccnorm\_false &  max\_random\_overlap\_ccnorm\_true &  random\_overlap\_ccnorm\_true &  decorr\_overlap\_ccnorm\_true \\
\midrule
total\_sky\_downward\_shortwave\_flux\_at\_surface       &                            16.30 &                         8.80 &                        13.86 &                            4.69 &                      -16.90 &                       -0.74 \\
total\_sky\_downward\_longwave\_flux\_at\_surface        &                            -2.84 &                         2.54 &                        -1.50 &                           -0.70 &                       10.08 &                        1.45 \\
total\_sky\_upward\_shortwave\_flux\_at\_top\_of\_atmos... &                           -13.82 &                        -7.57 &                       -11.77 &                           -4.05 &                       14.14 &                 

In [14]:
# The R^2 dataset was computed from time means, so its variables carry the
# '_time_mean' suffix added by time_mean.
time_mean_vars = [var + '_time_mean' for var in vars_]
r_squared_df = r_squared_against_coarsened_fine_rad[time_mean_vars].to_dataframe()

In [15]:
# LaTeX table of R^2 for the first four variables (rows) by run (columns).
r_squared_table = r_squared_df.T.iloc[[0, 1, 2, 3]]
print(r_squared_table.to_latex(float_format="%.2f"))

\begin{tabular}{lrrrrrr}
\toprule
run &  max\_random\_overlap\_ccnorm\_false &  random\_overlap\_ccnorm\_false &  decorr\_overlap\_ccnorm\_false &  max\_random\_overlap\_ccnorm\_true &  random\_overlap\_ccnorm\_true &  decorr\_overlap\_ccnorm\_true \\
\midrule
total\_sky\_downward\_shortwave\_flux\_at\_surface\_ti... &                             0.94 &                         0.98 &                         0.96 &                            0.99 &                        0.92 &                        1.00 \\
total\_sky\_downward\_longwave\_flux\_at\_surface\_tim... &                             0.99 &                         0.99 &                         1.00 &                            1.00 &                        0.97 &                        1.00 \\
total\_sky\_upward\_shortwave\_flux\_at\_top\_of\_atmos... &                             0.84 &                         0.93 &                         0.88 &                            0.96 &                        0.79 &              