In [None]:
import os

from itertools import product

import xarray as xr
import dask
import numpy as np
import pandas as pd

import util

# Project identifier passed as `project=` to util.get_ClusterClient when the
# cluster is requested later in this notebook.
# NOTE(review): presumably the compute allocation/account code — confirm.
PROJECT = "NCGD0011"
# Login name taken from the environment; raises KeyError if USER is not set.
USER = os.environ["USER"]

In [None]:
# Basin names, in the order they are processed when the file catalog is built.
basins = [
    'North_Atlantic_basin',
    'North_Pacific_basin',
    'South',
    'Southern_Ocean',
]

# Number of polygons belonging to each basin, keyed by basin name.
npolygon = {
    'North_Atlantic_basin': 150,
    'North_Pacific_basin': 200,
    'South': 300,
    'Southern_Ocean': 40,
}

In [None]:
# Month suffixes '-01' .. '-12' and zero-padded model years '0347' .. '0362'.
mths = [f'-{m:02d}' for m in range(1, 13)]
yrs = np.array([f'{yr:04d}' for yr in range(347, 363)])

# Every 'YYYY-MM' label in chronological order (year-major, month-minor),
# kept as a NumPy string array so it can be sliced by position later on.
timestamps = np.array([yr + mth for yr in yrs for mth in mths])

In [None]:
%%time
path = '/glade/campaign/cesm/development/bgcwg/projects/OAE-Global-Efficiency/Mengyang_Global_OAE_Experiments/archive/'

# Build a catalog with one row per (basin, polygon, start month) case; each row
# carries the list of monthly history files belonging to that case.
rows = []
offset = 0  # number of polygons in the basins handled so far
for b in basins:
    first, offset = offset, offset + npolygon[b]

    # `p_id` numbers polygons consecutively across all basins, while `i`
    # restarts at zero for every basin and is what appears in the case name.
    for i, p_id in enumerate(f'{k:03d}' for k in range(first, offset)):

        for m in ('01', '04', '07', '10'):
            # zero-based position of the start month within `timestamps`
            start = int(m) - 1
            # 180 consecutive monthly labels beginning at the start month
            dates = timestamps[start:start + 180]

            case = f'smyle-fosi.{b}.alk-forcing-{b}.{i:03d}-1999-{m}'
            files = [f'{path}/{case}/ocn/hist/{case}.pop.h.{d}.nc' for d in dates]

            rows.append({
                'polygon': i,
                'polygon_id': p_id,
                'basin': b,
                'start_date': dates[0],
                'files': files,
            })

index_fields = ['polygon', 'basin', 'start_date']
df = pd.DataFrame(rows).set_index(index_fields)
df

In [None]:
# Distinct start dates present in the catalog index.
start_dates = list(df.index.unique(level='start_date'))

# For every basin, the first polygon index listed under the first start date.
first_start = start_dates[0]
polygons = []
for b in basins:
    basin_rows = df.xs((b, first_start), level=('basin', 'start_date'))
    polygons.append(basin_rows.index[0])

In [None]:
# Display one catalog row, keyed by (polygon, basin, start_date), as a spot check.
df.loc[0, 'South', '0347-01']

In [None]:
# Obtain a (cluster, client) pair from the project-local `util` helper.
# NOTE(review): presumably a dask cluster on the batch system with 2GB per
# worker and a 12 h wall-clock limit — confirm against util.get_ClusterClient.
cluster, client = util.get_ClusterClient(memory="2GB", project=PROJECT, walltime="12:00:00")
# Ask the cluster to scale to 256 (presumably workers — confirm).
cluster.scale(256)
# Last expression of the cell: show the client's rich repr.
client

In [None]:
%%time

@dask.delayed
def get_reference_data(index):
    """Read the ``ALK_ALT_CO2`` slice from every file of one catalog entry.

    Parameters
    ----------
    index : tuple
        Key into the global ``df`` catalog (its index levels are
        ``polygon``, ``basin`` and ``start_date``) selecting a single case.

    Returns
    -------
    list of xarray.DataArray
        One array per file listed in ``df.loc[index].files``, in the same
        order, each holding ``ALK_ALT_CO2`` at ``time=0, z_t=0``.

    Notes
    -----
    Each file is opened in a ``with`` block and its slice is loaded into
    memory before the file is closed.  Previously every dataset handle stayed
    open and the returned arrays were lazy views onto those open files.
    """

    def _read_slice(filename):
        # Load eagerly so the data remains usable once the file is closed.
        with xr.open_dataset(filename) as ds:
            return ds['ALK_ALT_CO2'].isel(time=0, z_t=0).load()

    return [_read_slice(f) for f in df.loc[index].files]
        

#reference_dsets = {f'{b}-{d}': get_reference_data((0, b, d))
#    for b, d in product(basins, start_dates)
#}

#reference_dsets = dask.compute(reference_dsets)[0]

In [None]:
%%time 

@dask.delayed
def comparison(index, index_ref):
    """Compare one case against a reference case, file by file.

    For each pair of files (paired by position), the ``ALK_ALT_CO2`` slice at
    ``time=0, z_t=0`` is read from both and the squared difference is summed
    over the whole field.

    Parameters
    ----------
    index : tuple
        Key into the global ``df`` catalog (``polygon``, ``basin``,
        ``start_date``) selecting the case under test.
    index_ref : tuple
        Key into ``df`` selecting the reference case.

    Returns
    -------
    numpy.ndarray
        One value per file pair: ``sum((test - reference) ** 2)``.

    Notes
    -----
    NOTE(review): callers store this under the name ``rmse``, but no mean or
    square root is taken — the value is a sum of squared differences.  The
    computation is left as it was; confirm which metric is intended.

    Files are opened in ``with`` blocks and read eagerly, one pair at a time,
    so no dataset handle outlives its use.  Previously all files of both cases
    were opened up front and never closed.
    """

    def _read_slice(filename):
        # Load eagerly so the data remains usable once the file is closed.
        with xr.open_dataset(filename) as ds:
            return ds['ALK_ALT_CO2'].isel(time=0, z_t=0).load()

    test_files = df.loc[index].files
    ref_files = df.loc[index_ref].files

    rmse = []
    for f_test, f_ref in zip(test_files, ref_files):
        da_test = _read_slice(f_test)
        da_reference = _read_slice(f_ref)
        rmse.append(
            ((da_test - da_reference) ** 2).sum().values.item()
        )
    return np.array(rmse)
        

rmse = []
for b, d in product(basins, start_dates):
    # Polygon indices catalogued for this basin and start date.
    polygons = df.xs((b, d), level=('basin', 'start_date')).index
    print((b, d))

    # Every polygon is compared against polygon 0 of the same basin/start date.
    reference_ndx = (0, b, d)
    objs_rmse = [
        dict(
            polygon=p,
            basin=b,
            start_date=d,
            rmse=comparison((p, b, d), reference_ndx),
        )
        for p in polygons
    ]

    # Evaluate this batch of delayed comparisons before moving to the next one.
    rmse.extend(dask.compute(objs_rmse)[0])



In [None]:
# Collect the per-case comparison records into a table indexed like the file catalog.
comparison_records = pd.DataFrame(rmse)
df_comp = comparison_records.set_index(index_fields)
df_comp

In [None]:
# Persist the comparison table to a pickle file in the current working directory.
df_comp.to_pickle('comparison_data.pkl')