In [None]:
import yaml
import os
import xarray as xr
import fsspec
import pandas as pd
from pathlib import Path

def convert_time_index_to_datetime(ds, dim):
    """Return ``ds`` with the index along ``dim`` converted to a pandas DatetimeIndex.

    The conversion relies on the index providing ``to_datetimeindex()`` (as
    xarray's cftime-based ``CFTimeIndex`` does). If the index has no such
    method -- e.g. the time axis was already decoded to ``datetime64`` and is a
    plain ``pandas.DatetimeIndex`` -- ``ds`` is returned unchanged instead of
    raising ``AttributeError``.

    Parameters
    ----------
    ds : dataset-like
        Object exposing ``.indexes`` and ``.assign_coords`` (an xarray Dataset).
    dim : str
        Name of the dimension whose index is converted.
    """
    index = ds.indexes[dim]
    if not hasattr(index, "to_datetimeindex"):
        # Nothing to convert: already a datetime-like pandas index.
        return ds
    return ds.assign_coords({dim: index.to_datetimeindex()})

In [None]:
# Google Cloud Storage filesystem; detect_rundirs uses it to glob for run directories.
fs = fsspec.filesystem(protocol='gs')

In [None]:
def detect_rundirs(bucket):
    """List the run-directory names under ``bucket`` that contain a ``diags.nc``.

    Globs ``<bucket>/*/diags.nc`` with the module-level ``fs`` filesystem and
    returns, for each match, the name of the directory holding the file.
    """
    pattern = os.path.join(bucket, '*', 'diags.nc')
    names = []
    for match in fs.glob(pattern):
        names.append(Path(match).parent.name)
    return names

In [None]:
# Root of the prognostic-run diagnostics; can be overridden via the INPUT environment variable.
BUCKET = os.getenv('INPUT', "gs://vcm-ml-data/experiments-2020-03/prognostic_run_diags")

rundirs = detect_rundirs(BUCKET)


def _load_diags(rundir):
    """Read ``<BUCKET>/<rundir>/diags.nc`` with h5netcdf and load it into memory."""
    url = os.path.join(BUCKET, rundir, 'diags.nc')
    with fsspec.open(url, "rb") as remote_file:
        # .load() runs inside the ``with`` so the data is read before the file closes.
        return xr.open_dataset(remote_file, engine='h5netcdf').load()


# Maps run-directory name -> in-memory diagnostics dataset.
metrics = {rundir: _load_diags(rundir) for rundir in rundirs}

In [None]:
# HoloViews drives every plot below; activate its Bokeh plotting backend.
import holoviews as hv
hv.extension('bokeh')

In [None]:
def get_ts(ds):
    """Return ``ds`` without any variable whose dimensions are not exactly ``('time',)``.

    Iterating ``ds`` visits its variables (for an xarray Dataset, the data
    variables); every one whose dimension set differs from ``{'time'}`` is
    passed to ``ds.drop``.
    """
    def _is_pure_time_series(name):
        return set(ds[name].dims) == {'time'}

    to_remove = [name for name in ds if not _is_pure_time_series(name)]
    return ds.drop(to_remove)

In [None]:
# Per-run datasets reduced to pure time series (get_ts), with the 'time' index
# converted by convert_time_index_to_datetime.
time_series = {
    run_name: convert_time_index_to_datetime(get_ts(diags), 'time')
    for run_name, diags in metrics.items()
}

In [None]:
def holomap_filter(time_series, varfilter):
    """Build a HoloMap of curves for every variable whose name contains ``varfilter``.

    Parameters
    ----------
    time_series : dict
        Maps run name -> dataset of time series (as built with ``get_ts``).
    varfilter : str
        Substring a variable name must contain to be plotted.

    Returns
    -------
    hv.HoloMap
        Keyed by ``(variable long name, run)``. Runs whose name ends in
        ``'baseline'`` are drawn dashed, all others solid; colors come from
        the 'Colorblind' cycle.

    A variable whose data cannot be loaded is skipped with a warning rather
    than silently. A variable with no ``long_name`` attribute is keyed by its
    variable name instead of raising ``AttributeError``.
    """
    import warnings

    colors = hv.Cycle('Colorblind')
    hmap = hv.HoloMap(kdims=['variable', 'run'])
    for run, ds in time_series.items():
        # Line style depends only on the run, so decide it once per run.
        style = 'dashed' if run.endswith('baseline') else 'solid'
        for varname in ds:
            if varfilter not in varname:
                continue
            try:
                values = ds[varname].load()
            except Exception as err:
                # Best-effort: skip unreadable variables, but say so.
                warnings.warn("skipping %r in run %r: %s" % (varname, run, err))
                continue
            long_name = ds[varname].attrs.get('long_name', varname)
            hmap[(long_name, run)] = hv.Curve(values, label=varfilter).options(
                line_dash=style, color=colors
            )
    return hmap

# RMS Errors

In [None]:
%%opts Curve [width=700, height=500] {+framewise}
# Time series of every variable whose name contains 'rms', with runs overlaid
# on a single plot per variable.
hmap = holomap_filter(time_series, 'rms')
hmap.overlay('run')

# Global Averages

In [None]:
%%opts Curve [width=700, height=500] {+framewise}
# Time series of every variable whose name contains 'global_avg', with runs
# overlaid on a single plot per variable.
hmap = holomap_filter(time_series, 'global_avg')
hmap.overlay('run')

# Scalar Metrics

In [None]:
import json

# Run-name suffix -> label stored in the ``baseline`` column of the scalar-metrics table.
_RUN_SUFFIX_LABELS = (('-baseline', "Baseline"), ('-rf', "RF"))


def _split_run_name(run):
    """Split ``run`` into ``(one_step, label)`` using ``_RUN_SUFFIX_LABELS``.

    A run with none of the known suffixes keeps its full name and gets the
    label "Other", so it can never inherit another run's labels.
    """
    for suffix, label in _RUN_SUFFIX_LABELS:
        if run.endswith(suffix):
            return run[:-len(suffix)], label
    return run, "Other"


def flatten(metrics):
    """Yield ``(one_step, baseline, metric_name, value)`` rows from per-run metrics.

    ``metrics`` maps run name -> {metric name -> {'value': ...}}. Each run name
    is split into the one-step run it derives from and a label: "Baseline" for
    a ``-baseline`` suffix, "RF" for ``-rf``, and "Other" (with the full run
    name kept as ``one_step``) for anything else.
    """
    for run, run_metrics in metrics.items():
        # Parse the run name once per run, not once per metric.
        one_step, baseline = _split_run_name(run)
        for name, metric in run_metrics.items():
            yield one_step, baseline, name, metric['value']

            
# Load every run's metrics.json. The results go into ``scalar_metrics`` rather
# than ``metrics`` so the dict of diagnostics datasets loaded earlier is not
# clobbered; re-running the time-series cells after this one would otherwise break.
scalar_metrics = {}

for rundir in rundirs:
    path = os.path.join(BUCKET, rundir, 'metrics.json')
    with fsspec.open(path, "rb") as f:
        scalar_metrics[rundir] = json.load(f)

# One row per (run, metric): columns one_step, baseline (run label), metric, value.
df = pd.DataFrame(flatten(scalar_metrics), columns=['one_step', 'baseline', 'metric', 'value'])


# Collect the scalar metrics into HoloViews bar charts, one frame per metric:
# names starting with "rmse" go into ``hmap``, those starting with "drift" into
# ``bias``, and any other metric is not plotted.
hmap = hv.HoloMap(kdims=["metric"])
bias = hv.HoloMap(kdims=["metric"])
routes = (('rmse', hmap), ('drift', bias))

for metric_name in df.metric.unique():
    rows = df[df.metric == metric_name]
    chart = hv.Bars(
        (rows.one_step, rows.baseline, rows.value),
        kdims=["one_step", "type"],
    )
    for prefix, target in routes:
        if metric_name.startswith(prefix):
            target[metric_name] = chart
            break

In [None]:
%%opts Bars [width=600] {+framewise}
# RMSE scalar metrics: one bar chart per metric, bars keyed by one-step run and run type.
hmap

In [None]:
%%opts Bars [width=600] {+framewise}
# Drift scalar metrics (held in ``bias``): one bar chart per metric, bars keyed by
# one-step run and run type.
bias