In [None]:
%matplotlib inline
import os

import fsspec
import intake
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from vcm.catalog import catalog
from vcm.fv3.metadata import standardize_fv3_diagnostics

SECONDS_PER_DAY = 86400


In [None]:
# C48 grid cell areas, plus a copy of them in which every cell whose
# land_sea_mask value is not 1 becomes NaN (value 1 presumably marks land —
# confirm the mask encoding). NaN-skipping sums over `area_land_mask` therefore
# only include the cells where the mask equals 1.
grid = catalog["grid/c48"].read()
area = grid["area"]

land_sea_mask = catalog["landseamask/c48"].read()["land_sea_mask"]
area_land_mask = area.where(cond=(land_sea_mask == 1))

In [None]:
# Fine-resolution verification run, with diagnostics renamed by
# standardize_fv3_diagnostics (the land mean below sums over "tile", "x", "y").
verif_physics = standardize_fv3_diagnostics(
    catalog['40day_c48_gfsphysics_15min_may2020'].to_dask()
)

# Scale the precipitation rate by seconds per day. The figure below labels the
# result "mm / day", i.e. PRATEsfc is presumably in kg/m^2/s — confirm.
verif_physics["total_precip_to_surface"] = verif_physics["PRATEsfc"] * SECONDS_PER_DAY

# Area-weighted horizontal mean over the cells kept in `area_land_mask`: the
# weighted sum skips the NaN (masked-out) cells, and the normaliser is the total
# area of the kept cells.
verif_physics["total_precip_to_surface_spatial_mean_physics_land"] = (
    (verif_physics["total_precip_to_surface"] * area_land_mask).sum(
        dim=["tile", "x", "y"],
        skipna=True,
    )
    / area_land_mask.sum(skipna=True)
)

In [None]:
# Prognostic-run diagnostics for each sensitivity experiment, read fully into
# memory and keyed by experiment name.
experiments = [
    "training-control-ml-tendencies-only-RF",
    "training-control-ml-tendencies-rad-RF",
    "training-prescribed-ml-tendencies-rad-RF",
    "training-prescribed-ml-tendencies-rad-NN",
    "baseline",
]
paths = [
    f"gs://vcm-ml-public/argo/prog-report-nudge-to-3km-sensitivity/{experiment}"
    for experiment in experiments
]
diags = {}
for experiment, path in zip(experiments, paths):
    print(path)
    with fsspec.open(f"{path}/diags.nc", "rb") as f:
        # load_dataset reads every variable into memory and closes the dataset
        # before returning. Unlike open_dataset(...).load(), this leaves no open
        # dataset pointing at `f` once the `with` block closes it.
        diags[experiment] = xr.load_dataset(f)

In [None]:
# Daily land-mean surface precipitation for each sensitivity experiment,
# compared against the fine-resolution verification run.

PRECIP_LAND_VARIABLE = "total_precip_to_surface_spatial_mean_physics_land"
PRECIP_YLIM = (0.6, 4.2)
N_PRECIP_YTICKS = 10
FIGURE_PATH = "figures/sensitivity_precip_time_series.png"

# (experiment, restrict to `subset_times` before daily averaging?), in plotting
# order. The order fixes which default colour each line receives.
PLOT_ORDER = [
    ("baseline", True),
    # NOTE(review): unlike the other runs, this one is averaged over all of its
    # output times instead of only `subset_times`, so its daily values are not
    # computed the same way. Confirm this is intentional (its output times may
    # not cover `subset_times`).
    ("training-control-ml-tendencies-only-RF", False),
    ("training-control-ml-tendencies-rad-RF", True),
    ("training-prescribed-ml-tendencies-rad-RF", True),
    ("training-prescribed-ml-tendencies-rad-NN", True),
]


def _daily_mean(da, times=None):
    """Return daily means of ``da`` along "time".

    Args:
        da: DataArray with a "time" dimension.
        times: optional list of time values to select before averaging;
            ``None`` keeps every time.

    Returns:
        ``da`` (optionally subset), sorted by time, resampled to 1-day means.
    """
    if times is not None:
        da = da.sel(time=times)
    return da.sortby("time").resample(time="1D").mean("time")


# Times to compare: the first output time of the reference run plus every
# output time in hour 0, restricted to times that the verification data also has.
reference_times = diags["training-prescribed-ml-tendencies-rad-RF"].time.values
downsample_times = [reference_times[0]] + [t for t in reference_times if t.hour == 0]
# Sorted so the selection order is deterministic (a set has no defined order).
subset_times = sorted(set(downsample_times).intersection(list(verif_physics.time.values)))

# A single figure. The original created an extra, empty one first.
fig, ax = plt.subplots(figsize=(8, 7))

for experiment, restrict_to_subset in PLOT_ORDER:
    _daily_mean(
        diags[experiment][PRECIP_LAND_VARIABLE],
        subset_times if restrict_to_subset else None,
    ).plot(ax=ax, label=experiment, linewidth=1.5)

# NOTE(review): the verification series is plotted at the raw subset times,
# not daily-averaged like the experiments above.
verif_physics[PRECIP_LAND_VARIABLE].sel(time=subset_times).sortby("time").plot(
    ax=ax,
    linestyle="--",
    label="fine resolution",
    linewidth=3,
    color="black",
)

ax.set_xlabel("date", fontsize=14)
ax.set_ylabel("mm / day", fontsize=14)
ax.set_xlim(min(subset_times), max(subset_times))
ax.set_ylim(*PRECIP_YLIM)
ax.grid(True, axis="y", alpha=0.4)
ax.set_yticks(np.linspace(*PRECIP_YLIM, N_PRECIP_YTICKS))
ax.tick_params(axis="both", labelsize=14)
ax.set_title("Surface precipitation, land average", fontsize=16)
ax.legend(loc="lower right", fontsize=14)

# Create the output directory so a fresh checkout does not fail on savefig.
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
fig.savefig(FIGURE_PATH)