In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

import os
import cftime
import intake
import fsspec
import numpy as np
import pandas as pd
import seaborn as sb
import xarray as xr
import json
import copy

import fv3viz as viz
from vcm.catalog import catalog
from vcm.fv3.metadata import standardize_fv3_diagnostics

def weighted_average(array, weights, axis=None):
    """Return the NaN-aware weighted average of ``array``.

    Elements whose weighted value is NaN (because either ``array`` or
    ``weights`` is NaN there) are excluded from both the numerator and the
    denominator. Previously such elements were dropped from the numerator
    only, so their weights still inflated the denominator and biased the
    result toward zero.

    Args:
        array: numpy-compatible array of values to average.
        weights: numpy-compatible array of weights, broadcastable against
            ``array``.
        axis: axis or axes to reduce over; ``None`` reduces over everything.

    Returns:
        The weighted average along ``axis``. If no valid element contributes
        along a reduction, the result there is 0/0 (NaN, with a numpy
        RuntimeWarning).
    """
    weighted = array * weights
    # np.where broadcasts `weights` to the shape of the product, so the
    # denominator is summed over exactly the elements the numerator uses.
    effective_weights = np.where(np.isnan(weighted), 0, weights)
    return np.nansum(weighted, axis=axis) / np.sum(effective_weights, axis=axis)


# Dimension/coordinate naming passed as keyword arguments to plotting helpers.
# NOTE(review): not referenced in the cells visible here; presumably consumed
# by fv3viz -- confirm before removing.
MAPPABLE_VAR_KWARGS = dict(
    coord_x_center="x",
    coord_y_center="y",
    coord_x_outer="x_interface",
    coord_y_outer="y_interface",
    coord_vars=dict(
        lonb=["y_interface", "x_interface", "tile"],
        latb=["y_interface", "x_interface", "tile"],
        lon=["y", "x", "tile"],
        lat=["y", "x", "tile"],
    ),
)

SECONDS_PER_DAY = 24 * 60 * 60

# Grid description from the catalog; `area` is used as the weight for the
# horizontal averages computed later in the notebook.
grid = catalog["grid/c48"].read()
area = grid["area"]

# Mask used below to split domains (==1 is treated as land, ==0 as ocean).
land_sea_mask = catalog["landseamask/c48"].read()["land_sea_mask"]

# Verification dataset: standardize the diagnostics, then rename the
# horizontal dims x/y to grid_xt/grid_yt.
# NOTE(review): presumably so they line up with the dimension names in the
# prognostic-run zarrs -- confirm.
verif = (
    standardize_fv3_diagnostics(
        catalog['40day_c48_atmos_8xdaily_additional_vars_may2020'].to_dask()
    )
    .rename({"x": "grid_xt", "y": "grid_yt"})
)


In [None]:
# Initial-condition timestamps, one per prognostic run in each experiment.
ics = [
    "20160805.000000",
    "20160813.000000",
    "20160821.000000",
    "20160829.000000",
]

# One evaluation start time per entry of `ics` (each is six days after the
# corresponding initial condition).
start_datetimes = [
    cftime.DatetimeJulian(2016, month, day, 0, 0, 0, 0)
    for month, day in [(8, 11), (8, 19), (8, 27), (9, 4)]
]

# Run output locations, one list per experiment, ordered like `ics`.
baseline_no_ML_paths = [
    f"gs://vcm-ml-experiments/2021-04-13/baseline-physics-run-201608{start_day}-start-rad-step-1800s"
    for start_day in ("05", "13", "21", "29")
]
temperature_moisture_RF_paths = [
    "gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/"
    f"control-dq1-dq2-rf/initial_conditions_runs/{ic}"
    for ic in ics
]
temperature_moisture_winds_RF_paths = [
    "gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/"
    f"control-rf/prognostic_run_tendencies_only_ics/{ic}"
    for ic in ics
]
temperature_moisture_winds_prescribed_sfc_RF_paths = [
    "gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/"
    f"rf/initial_conditions_runs/{ic}"
    for ic in ics
]
temperature_moisture_winds_prescribed_sfc_NN_ensemble_paths = [
    "gs://vcm-ml-experiments/2021-05-11-nudge-to-c3072-corrected-winds/"
    f"nn-ensemble-model/initial_conditions_runs_rectified_nn_rad/{ic}"
    for ic in ics
]

temperature_moisture_prescribed_sfc_RF_paths = [
    "gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/"
    f"rf/initial_conditions_runs/{ic}"
    for ic in ics
]
temperature_moisture_prescribed_sfc_NN_ensemble_paths = [
    "gs://vcm-ml-experiments/2021-06-21-nudge-to-c3072-dq1-dq2-only/"
    f"nn-ensemble-model/initial_conditions_runs/{ic}"
    for ic in ics
]



In [None]:
def get_prog_run_errors(
        diags_path,
        verif,
        var,
        start_time=cftime.DatetimeJulian(2016, 8, 11, 0, 0, 0, 0)
):
    """Time-mean error fields of one prognostic run against verification.

    Opens the zarr store at ``diags_path``, keeps the timestamps at or after
    ``start_time`` that are present in both the run and ``verif``, and
    averages the per-time error of ``var`` over those timestamps.

    Args:
        diags_path: path/URL of the run's diagnostics zarr store.
        verif: verification dataset containing ``var`` on a ``time`` dim.
        var: name of the variable to evaluate.
        start_time: earliest timestamp included in the comparison.

    Returns:
        A loaded dataset with two variables averaged over ``time``:
        ``{var}_bias`` (run minus verification) and ``{var}_rmse``.

    Raises:
        ValueError: if the run and verification share no timestamps at or
            after ``start_time``. Previously this produced an all-NaN
            result, which NaN-skipping sums downstream turn into 0.

    NOTE(review): ``{var}_rmse`` is ``sqrt(error**2)`` evaluated per
    timestamp, i.e. the absolute error, so its time mean is a mean absolute
    error rather than a root-mean-square error. The arithmetic is kept
    unchanged here; confirm whether the square root should instead be taken
    after the time mean.
    """
    print(diags_path)
    # Sort each dataset once and reuse it for both the time lookup and the
    # final selection.
    ds = intake.open_zarr(diags_path).to_dask().sortby("time")
    verif_sorted = verif.sortby("time")

    prog_times = ds.sel(time=slice(start_time, None)).time.values
    verif_times = verif_sorted.sel(time=slice(start_time, None)).time.values
    overlap_times = sorted(set(verif_times).intersection(prog_times))
    if not overlap_times:
        raise ValueError(
            f"No overlapping times at or after {start_time} between "
            f"{diags_path} and the verification dataset."
        )

    da_verif = verif_sorted[var].sel(time=overlap_times)
    ds = ds.sel(time=overlap_times)
    error = ds[var] - da_verif
    ds[f"{var}_bias"] = error
    ds[f"{var}_rmse"] = np.sqrt(error**2)
    return ds[[f"{var}_bias", f"{var}_rmse"]].mean("time").load()


def concat_ics(ic_paths, var, start_datetimes, ic_coord, verif_physics):
    """Compute error fields for several runs and stack them along ``ic``.

    Args:
        ic_paths: run directories; ``atmos_dt_atmos.zarr`` is appended to each.
        var: variable name forwarded to ``get_prog_run_errors``.
        start_datetimes: evaluation start time for each entry of ``ic_paths``
            (paired positionally via ``zip``).
        ic_coord: labels used as the values of the new ``ic`` dimension.
        verif_physics: verification dataset forwarded to
            ``get_prog_run_errors``.

    Returns:
        The per-run datasets concatenated along a new ``ic`` dimension.
    """
    per_ic_errors = [
        get_prog_run_errors(
            os.path.join(run_path, "atmos_dt_atmos.zarr"),
            verif_physics,
            var=var,
            start_time=ic_start,
        )
        for run_path, ic_start in zip(ic_paths, start_datetimes)
    ]
    ic_index = pd.Index(ic_coord, name="ic")
    return xr.concat(per_ic_errors, dim=ic_index)

In [None]:
# Single-run check of get_prog_run_errors on the first baseline run.
# Bound to its own name: the next cell assigns `baseline_no_ML` to the
# concatenation over all initial conditions (which carries an `ic`
# dimension), so reusing that name here was immediately overwritten and
# left the name meaning two different things depending on which cell ran.
baseline_no_ML_single_ic = get_prog_run_errors(
    "gs://vcm-ml-experiments/2021-04-13/baseline-physics-run-20160805-start-rad-step-1800s/atmos_dt_atmos.zarr",
    verif,
    var="TMP200",
    start_time=cftime.DatetimeJulian(2016, 8, 11, 0, 0, 0, 0)
)

In [None]:
# Positional arguments shared by every concat_ics call, in signature order:
# (var, start_datetimes, ic_coord, verif_physics).
args = ["TMP200", start_datetimes, ics, verif]

# Runs with the "2021-05-11-nudge-to-c3072-corrected-winds" paths, plus the
# no-ML baseline.
baseline_no_ML = concat_ics(baseline_no_ML_paths, *args)
temperature_moisture_RF = concat_ics(temperature_moisture_RF_paths, *args)
temperature_moisture_winds_RF = concat_ics(
    temperature_moisture_winds_RF_paths, *args
)
temperature_moisture_winds_prescribed_sfc_RF = concat_ics(
    temperature_moisture_winds_prescribed_sfc_RF_paths, *args
)
temperature_moisture_winds_prescribed_sfc_NN_ensemble = concat_ics(
    temperature_moisture_winds_prescribed_sfc_NN_ensemble_paths, *args
)

# Runs with the "2021-06-21-nudge-to-c3072-dq1-dq2-only" paths.
temperature_moisture_prescribed_sfc_RF = concat_ics(
    temperature_moisture_prescribed_sfc_RF_paths, *args
)
temperature_moisture_prescribed_sfc_NN_ensemble = concat_ics(
    temperature_moisture_prescribed_sfc_NN_ensemble_paths, *args
)

In [None]:
# Field to summarize. Because this is the time-mean "*_rmse" field, the
# "bias" column computed below is the area-weighted mean of that field and
# the "RMSE" column is the area-weighted root-mean-square of it; the column
# names are kept because later cells select on them.
var = "TMP200_rmse"

labels = [
    "base-no-ML",
    "Tq-RF",
    "Tquv-RF",
    "TquvR-RF",
    "TquvR-NN",
    "TqR-RF",
    "TqR-NN",
]

datasets = [
    baseline_no_ML,
    temperature_moisture_RF,
    temperature_moisture_winds_RF,
    temperature_moisture_winds_prescribed_sfc_RF,
    temperature_moisture_winds_prescribed_sfc_NN_ensemble,
    temperature_moisture_prescribed_sfc_RF,
    temperature_moisture_prescribed_sfc_NN_ensemble,
]
# zip() would silently drop experiments if the two lists drifted apart.
assert len(labels) == len(datasets), "labels and datasets must pair up"

# (domain name, mask) pairs; None means the whole globe is used unmasked.
domain_masks = [
    ("global", None),
    ("land", land_sea_mask == 1),
    ("ocean", land_sea_mask == 0),
]
horizontal_dims = ["tile", "x", "y"]

# Rename once per experiment instead of once per (domain, ic) combination.
# rename() returns a new object, so no explicit copy is required.
datasets_xy = [ds_.rename({"grid_xt": "x", "grid_yt": "y"}) for ds_ in datasets]

domain_avg_biases = []
for domain, mask in domain_masks:
    # The masked area and its total depend only on the domain.
    area_masked = area if mask is None else area.where(mask)
    total_area = area_masked.sum(skipna=True)
    for label, ds in zip(labels, datasets_xy):
        for ic in ics:
            da = ds[var].sel(ic=ic)
            if mask is not None:
                da = da.where(mask)
            bias = (
                (da * area_masked).sum(skipna=True, dim=horizontal_dims) / total_area
            ).values.item()
            rmse = np.sqrt(
                (da**2 * area_masked).sum(skipna=True, dim=horizontal_dims) / total_area
            ).values.item()
            domain_avg_biases.append([ic, domain, label, rmse, bias])

df = pd.DataFrame(domain_avg_biases, columns=["IC", "domain", "dataset", "RMSE", "bias"])

In [None]:
# One-row table of the global value for the first initial condition, keyed
# by experiment label.
ablation_tmp200_rmse = {
    "variable": "TMP200_rmse",
    "units": "K"
}
# Build the domain/IC filter once with a single combined boolean mask rather
# than chaining .loc calls with masks computed on the unfiltered frame.
global_first_ic = df[(df["domain"] == "global") & (df["IC"] == "20160805.000000")]
for experiment in labels:
    row = global_first_ic.loc[global_first_ic["dataset"] == experiment]
    # The "bias" column holds the area-weighted mean of the summarized field
    # (see where `df` is built). .item() raises unless exactly one row
    # matched, which guards against duplicate or missing entries.
    ablation_tmp200_rmse[experiment] = row["bias"].item()


# Create the output directory if needed so a fresh checkout does not fail
# with FileNotFoundError on open().
os.makedirs("tables", exist_ok=True)
with open("tables/ablation_tmp200_time_mean_rmse.json", "w") as f:
    json.dump([ablation_tmp200_rmse], f, indent=4)