In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

import cftime
import intake
import fsspec
import numpy as np
import pandas as pd
import xarray as xr
from dask.diagnostics import ProgressBar

import fv3viz as viz
from vcm.catalog import catalog
from vcm import local_time

def weighted_average(array, weights, axis=None):
    """Return the NaN-aware weighted mean of ``array`` along ``axis``.

    Entries whose weighted value ``array * weights`` is NaN are excluded from
    both the numerator and the denominator, so missing values no longer drag
    the mean toward zero. Weights are broadcast against ``array`` before they
    are summed, which keeps the normalisation consistent with the numerator.

    Parameters
    ----------
    array : array-like
        Values to average; may contain NaN.
    weights : array-like
        Weights, broadcastable against ``array``.
    axis : int or tuple of int, optional
        Axis or axes to reduce over; ``None`` reduces over all of them.

    Returns
    -------
    The weighted mean. It is NaN (or inf) when the retained weights sum to
    zero, because the division is left unguarded.
    """
    weighted = array * weights
    # True where the term contributes to the numerator; multiplying the weights
    # by this mask zeroes the weight of every entry that nansum skips.
    valid = ~np.isnan(weighted)
    return np.nansum(weighted, axis=axis) / np.nansum(weights * valid, axis=axis)

In [None]:
def _load_remote_dataset(url):
    """Open the file at ``url`` through fsspec and load it fully into memory."""
    with fsspec.open(url, "rb") as remote_file:
        # load() reads everything before the file handle is closed
        return xr.open_dataset(remote_file).load()


# Diagnostics of the three prognostic runs compared in the figure.
diags_nn = _load_remote_dataset("gs://vcm-ml-public/argo/prog-report-nudge-to-3km-compare-dq1-dq2-only/neural_networks_dQ1_dQ2_only/diags.nc")
diags_rf = _load_remote_dataset("gs://vcm-ml-public/argo/prog-report-nudge-to-3km-compare-dq1-dq2-only/random_forests_dQ1_dQ2_only/diags.nc")
diags_baseline = _load_remote_dataset("gs://vcm-ml-public/argo/prog-report-nudge-to-3km-nn-rf-comparison/baseline/diags.nc")
    
    

In [None]:
# Diagnostics of the nudged run, opened lazily from the zarr store.
nudged_run_diags_zarr = "gs://vcm-ml-experiments/2021-04-13-n2f-c3072/3-hrly-ave-rad-precip-setting-30-min-rad-timestep-shifted-start-tke-edmf/diags.zarr"
nudged_run_diags = (
    intake.open_zarr(nudged_run_diags_zarr)
    .to_dask()
    # NOTE(review): the first 475 time steps are skipped; the reason is not visible here
    .isel(time=slice(475, None))
    # follow what the prognostic run report does to compute diurnal cycles
    .resample(time="1H")
    .nearest()
)

In [None]:
# C48 grid and land/sea mask taken from the vcm catalog.
grid_entry = catalog['grid/c48']
grid = grid_entry.to_dask()

land_sea_mask_entry = catalog['landseamask/c48']
mask = land_sea_mask_entry.to_dask().land_sea_mask

In [None]:
SECONDS_PER_DAY = 86400


def total_precipitation_rate(physics_precipitation, column_integrated_moistening):
    """Compute the rectified and unrectified total precipitation rate.

    The unrectified rate is
    ``SECONDS_PER_DAY * (physics_precipitation - column_integrated_moistening)``.
    The rectified rate replaces values that are not positive with 0.0. Missing
    values (NaN) are kept as NaN rather than being reported as zero
    precipitation.

    Parameters
    ----------
    physics_precipitation : xarray.DataArray
        Precipitation rate from the physics.
        NOTE(review): presumably in kg/m^2/s given the per-day scaling; confirm.
    column_integrated_moistening : xarray.DataArray
        Column-integrated moistening to subtract, assumed to be in the same
        units as ``physics_precipitation``.

    Returns
    -------
    tuple
        ``(rectified, unrectified)``, named ``total_precipitation_rate`` and
        ``total_precipitation_rate_unrectified``, each carrying ``long_name``
        and ``units`` attributes.
    """
    unrectified = SECONDS_PER_DAY * (physics_precipitation - column_integrated_moistening)
    # NaN > 0 is False, so a bare `unrectified > 0` test would turn missing data
    # into 0.0; keeping null entries preserves them as NaN.
    keep = (unrectified > 0) | unrectified.isnull()
    rectified = unrectified.where(keep, 0.0)
    rectified.attrs = {'long_name': 'total precip rate to surface max(PRATE - <dQ2> - <nQ2>, 0)', 'units': 'mm/day'}
    unrectified.attrs = {'long_name': 'total precip rate to surface PRATE - <dQ2> - <nQ2>', 'units': 'mm/day'}
    return (
        rectified.rename('total_precipitation_rate'),
        unrectified.rename('total_precipitation_rate_unrectified'),
    )

In [None]:
def diurnal_cycles(ds):
    """Average every variable of ``ds`` into bins of whole local-time hours.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing the ``time`` and ``lon`` entries whose names are
        passed to ``vcm.local_time``, plus the variables to average. It is not
        modified.

    Returns
    -------
    xarray.Dataset
        Mean of each variable (``lon`` excluded) grouped by ``local_time``,
        with attributes kept.
    """
    local_time_ = local_time(ds, time="time", lon_var="lon")
    local_time_.attrs = {"long_name": "local time", "units": "hour"}
    # assign() returns a new dataset, so the caller's ``ds`` is left untouched
    binned = ds.assign(local_time=np.floor(local_time_))  # equivalent to hourly binning
    with xr.set_options(keep_attrs=True):
        # drop_vars() replaces the deprecated Dataset.drop() for removing variables
        return binned.drop_vars("lon").groupby("local_time").mean()

In [None]:
# Target precipitation from the nudged run, loaded into memory with a progress bar.
with ProgressBar():
    rectified_lazy, unrectified_lazy = total_precipitation_rate(
        nudged_run_diags.physics_precip,
        nudged_run_diags.net_moistening_due_to_nudging,
    )
    target_total_precipitation_rate = rectified_lazy.load()
    target_total_precipitation_rate_unrectified = unrectified_lazy.load()

In [None]:
# Keep only the points where the land/sea mask equals 1.0; all others become NaN.
is_land = mask == 1.0
target_total_precipitation_rate_land = target_total_precipitation_rate.where(is_land)
target_total_precipitation_rate_unrectified_land = target_total_precipitation_rate_unrectified.where(is_land)

In [None]:
with ProgressBar():
    diurnal_cycles = diurnal_cycles(xr.merge([
        target_total_precipitation_rate_land,
        target_total_precipitation_rate_unrectified_land,
        nudged_run_diags.physics_precip,
        nudged_run_diags.net_moistening_due_to_nudging,
        grid.lon
    ]))

In [None]:
default_colormap = plt.rcParams['axes.prop_cycle'].by_key()['color']

colors = {
    "base-no-ML": default_colormap[3],
    "$TqR$-RF": default_colormap[0],
    "$TqR$-NN": default_colormap[5],
}

In [None]:
var = 'total_precip_to_surface_diurnal_land'

labels = ["base-no-ML", "$TqR$-NN", "$TqR$-RF"]
datasets = [diags_baseline, diags_nn, diags_rf]

verif_land_precip_diurnal = diags_baseline[var] - diags_baseline['diurn_bias_total-precipitation_diurnal_land']

fig = plt.figure(figsize=(8,7))

verif_land_precip_diurnal.plot(label="fine grid", linestyle="--", color="black", linewidth=3)
diurnal_cycles.total_precipitation_rate_unrectified.plot(label=r"physics precipitation - $\langle \Delta Q_q \rangle$", linestyle=":", color="black", linewidth=3)
for label, ds in zip(labels, datasets):
    ds[var].plot(label=label, linewidth=2, color=colors[label])

plt.legend(fontsize=12)
plt.grid(True, axis="both", alpha=0.4)
plt.xlim(0, 23)
plt.ylim(1.3, 4.4)
plt.xlabel("local time [hr]", fontsize=14)
plt.ylabel("[mm / day]", fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.title("Diurnal cycle of precipitation over land", fontsize=18)
fig.savefig("figures/Figure_12_PrecLandDiurnalCycle.pdf", bbox_inches = "tight")