In [None]:
import datetime
import os
import pathlib
import sys

import dask.array as da
import numcodecs
import numpy as np
import pandas as pd
import tqdm
import xarray as xr

from data_config import (
    get_scratch_dir,
    get_dask_log_dir,
    get_dask_local_dir,
)
from process_files import (
    memory,
    get_case_metadata,
)

In [None]:
# Resolve the scratch and dask working directories from the data_config helpers.
scratch, dask_log_directory, dask_local_directory = (
    get_scratch_dir(),
    get_dask_log_dir(),
    get_dask_local_dir(),
)

In [None]:
# Make the parent of the notebook's working directory importable. The next cell
# imports `atlas`, which presumably lives there. The membership check keeps the
# cell idempotent: re-running it does not append duplicate sys.path entries.
parent_dir = pathlib.Path.cwd().parent
if str(parent_dir) not in sys.path:
    sys.path.append(str(parent_dir))

In [None]:
import atlas

In [None]:
@memory.cache
def get_done_cases_df(today=datetime.datetime.today().date()):
    """Return the DOR case table without the control run, plus its sorted case names.

    ``today`` is never read in the body. Its default is evaluated once, when this
    cell runs, so it presumably serves as a per-day key for the ``memory.cache``
    memoisation (a new day means a new argument and therefore a fresh computation).

    Returns
    -------
    (pandas.DataFrame, list[str])
        ``calc.df`` restricted to (and ordered by) the sorted case names, and that
        sorted list of names.

    Raises
    ------
    ValueError
        If the control case is not in ``calc.df``'s index (raised by ``list.remove``).
    """
    calc = atlas.global_irf_map(cdr_forcing="DOR", vintage="001")

    # NOTE(review): the archive filter (calc.df.loc[calc.df.archive]) is disabled,
    # so every case in the table is treated as done.
    case_names = calc.df.index.to_list()
    case_names.remove("smyle.cdr-atlas-v0.control.001")
    case_names.sort()

    return calc.df.loc[case_names], case_names

In [None]:
%%time

# Load the finished-case table. The result is memoised by `memory.cache`, so a
# re-run with the same (date-derived) argument reuses the cached value.
df, done_cases = get_done_cases_df()
df

In [None]:
# Directory holding one `<case>.analysis.zarr` store per case (read by process_case).
# It defaults to the project's shared-filesystem path. Set the DOR_ANALYSIS_DIR
# environment variable to read the stores from another location.
base_directory = pathlib.Path(
    os.environ.get(
        "DOR_ANALYSIS_DIR",
        "/global/cfs/projectdirs/m4746/Projects/Ocean-CDR-Atlas-v0/data/analysis",
    )
)
base_directory

In [None]:
def add_additional_coords(
    ds: xr.Dataset,
    case: str,
    case_metadata: pd.Series,
    n_polygons: int = 690,
) -> xr.Dataset:
    """Attach length-1 ``polygon_id`` and ``injection_date`` coordinates to one case.

    Also drops the ``time`` variable and renames the ``time`` dimension to
    ``elapsed_time``, so each case is indexed by time since injection.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset for a single case. It must contain a ``time`` variable and a
        ``time`` dimension.
    case : str
        Case name. It is not used here but is part of the ``.pipe`` call signature.
    case_metadata : pd.Series
        Metadata row for the case. ``polygon_master`` and ``start_date`` are read.
    n_polygons : int, default 690
        Number of valid polygons; ``polygon_master`` must lie in ``[0, n_polygons)``.

    Returns
    -------
    xr.Dataset
        The renamed dataset with int32 ``polygon_id`` and ``injection_date``
        dimension coordinates of length 1.

    Raises
    ------
    ValueError
        If ``polygon_master`` is outside ``[0, n_polygons)``.
    """
    polygon_master = int(case_metadata.polygon_master)
    if not 0 <= polygon_master < n_polygons:
        raise ValueError(
            f"Polygon id must be in range [0, {n_polygons}). Found polygon_id={polygon_master}"
        )

    # Length-1 int32 dimension coordinate identifying this case's polygon.
    polygon_id_coord = xr.DataArray(
        name="polygon_id",
        dims="polygon_id",
        data=[polygon_master],
        attrs={"long_name": "polygon ID"},
    ).astype("int32")

    # Injection date: the last "-"-separated field of start_date, parsed as an int.
    # NOTE(review): this is the month only if start_date ends with the month
    # (e.g. "YYYY-MM"); for "YYYY-MM-DD" it would pick the day. Confirm the format.
    injection_date_coord = xr.DataArray(
        data=[int(case_metadata.start_date.split("-")[-1])],
        dims=["injection_date"],
        attrs={"long_name": "injection date", "units": "month of 1999"},
    ).astype("int32")

    # Drop absolute time and keep only the dimension, renamed. set_elapsed_time
    # later assigns the integer coordinate for this axis.
    renamed = ds.drop_vars("time").rename_dims(time="elapsed_time")

    return renamed.assign_coords(
        polygon_id=polygon_id_coord,
        injection_date=injection_date_coord,
    )


def expand_ensemble_dims(ds: xr.Dataset) -> xr.Dataset:
    """Give every data variable leading ``polygon_id`` and ``injection_date`` dims.

    Works on a copy of ``ds``. Each data variable gains the two ensemble dimensions
    (length 1, named after the case's coordinates) so that cases can later be
    combined along them. Coordinates such as time bounds are left unexpanded.
    """
    ensemble_dims = ["polygon_id", "injection_date"]
    expanded = ds.copy()

    for var_name in list(ds.data_vars):
        expanded[var_name] = expanded[var_name].expand_dims(ensemble_dims)

    return expanded


def compute_dor_efficiency(ds: xr.Dataset) -> xr.Dataset:
    """Return ``ds`` plus a float32 ``DOR_efficiency = -DIC_ADD_TOTAL / DIC_FLUX``.

    The variable is added with ``Dataset.assign``, which returns a new dataset, so
    the dataset passed in is not modified.

    NOTE(review): where DIC_FLUX is zero the ratio is presumably inf/NaN. Confirm
    that this is the intended value for those points.
    """
    efficiency = (-ds.DIC_ADD_TOTAL / ds.DIC_FLUX).astype("float32")
    return ds.assign(DOR_efficiency=efficiency)


def set_compression_encoding(ds: xr.Dataset) -> xr.Dataset:
    """Set Zlib(level=1) compression and an explicit fill value on every variable.

    Each variable's ``encoding`` is replaced (not merged) in place, and ``ds`` is
    returned so the function can be used with ``.pipe``:

    * Integer variables use ``_FillValue`` = 2_147_483_647 (the int32 maximum). For
      integer dtypes too narrow to hold that value, the dtype's own maximum is used
      instead, so the fill value always fits.
    * float32 variables use ``_FillValue`` = 9.969209968386869e36 (netCDF's default
      float fill), so NaN is not used as the fill value.
    * All other variables get the compressor only.

    A single compressor instance is shared by all variables.
    """
    compressor = numcodecs.Zlib(level=1)

    for name, var in ds.variables.items():
        # avoid using NaN as a fill value, and avoid overflow errors in encoding
        if np.issubdtype(var.dtype, np.integer):
            # >=32-bit integer types keep 2_147_483_647. int8/int16/uint8/uint16
            # cannot represent it, so they fall back to their own maximum.
            fill_value = min(2_147_483_647, int(np.iinfo(var.dtype).max))
            ds[name].encoding = {"compressor": compressor, "_FillValue": fill_value}
        elif var.dtype == np.dtype("float32"):
            ds[name].encoding = {
                "compressor": compressor,
                "_FillValue": 9.969209968386869e36,
            }
        else:
            ds[name].encoding = {"compressor": compressor}

    return ds


def set_elapsed_time(ds: xr.Dataset):
    """Set ``elapsed_time`` to int32 month offsets 0..179 with units "months".

    Modifies ``ds`` in place and also returns it (for use with ``.pipe``). If ``ds``
    already has an ``elapsed_time`` dimension, its length must be 180.
    """
    ds["elapsed_time"] = xr.DataArray(
        data=np.arange(180),
        dims=("elapsed_time",),
        attrs=dict(units="months"),
    ).astype("int32")
    return ds


def process_case(case: str, df: pd.DataFrame) -> "xr.Dataset | None":
    """Load one case's analysis store and reduce it to ``DOR_efficiency``.

    Parameters
    ----------
    case : str
        Case name. The store is read from ``base_directory / f"{case}.analysis.zarr"``.
    df : pd.DataFrame
        Case table passed to ``get_case_metadata`` to look up this case's metadata.

    Returns
    -------
    xr.Dataset or None
        ``DOR_efficiency`` with its integer ``elapsed_time`` coordinate and the
        ``polygon_id``/``injection_date`` coordinates, with ``time_delta`` dropped.
        Returns ``None`` when the store does not exist. Callers should test
        ``is None`` rather than rely on the dataset's truthiness.
    """
    case_metadata = get_case_metadata(case, df=df)
    path = base_directory / f"{case}.analysis.zarr"
    if not path.exists():
        return None

    # Lazy (dask-backed) open, then: add case coordinates, expand every data variable
    # over the ensemble dims, derive the efficiency, and set integer elapsed months.
    ds = (
        xr.open_dataset(path, engine="zarr", chunks={}, decode_timedelta=True)
        .pipe(add_additional_coords, case, case_metadata)
        .pipe(expand_ensemble_dims)
        .pipe(compute_dor_efficiency)
        .pipe(set_elapsed_time)
    )
    return ds[["DOR_efficiency", "elapsed_time"]].drop_vars(["time_delta"])


def process_case_without_data(
    case: str, df: pd.DataFrame, ds: xr.Dataset
) -> xr.Dataset:
    """Relabel a template dataset as ``case`` for a case with no analysis store.

    The data in ``ds`` is kept as-is. Only the ``polygon_id`` coordinate changes: it
    is replaced by ``[polygon_master]`` from the case's metadata and cast to int32,
    and the template's original ``polygon_id`` attributes are restored on it.
    """
    metadata = get_case_metadata(case, df=df)
    polygon_attrs = ds.polygon_id.attrs

    relabelled = ds.assign_coords(polygon_id=[metadata.polygon_master])
    relabelled["polygon_id"] = relabelled["polygon_id"].astype("int32")
    relabelled["polygon_id"].attrs = polygon_attrs

    return relabelled

In [None]:
def create_empty_target_store(
    n_months: int = 180,
    n_polygons: int = 690,
    injection_months=(1, 4, 7, 10),
):
    """Build a lazily allocated template dataset describing the target Zarr store.

    The template holds one data variable, ``DOR_efficiency``: float32 with dims
    ``(elapsed_time, polygon_id, injection_date)``. It is backed by an uninitialised
    dask array, so ``to_zarr(..., compute=False)`` can write metadata and coordinates
    without materialising any data. It is chunked with one chunk per injection date,
    each covering the full ``elapsed_time`` and ``polygon_id`` extents.

    Parameters
    ----------
    n_months : int, default 180
        Length of ``elapsed_time`` (integer months ``0..n_months-1``).
    n_polygons : int, default 690
        Number of polygons (``polygon_id`` is ``0..n_polygons-1``).
    injection_months : sequence of int, default (1, 4, 7, 10)
        Values of the ``injection_date`` coordinate (month of 1999).

    Returns
    -------
    xr.Dataset
        Template with compression encoding applied by ``set_compression_encoding``.
    """
    injection_values = np.array(list(injection_months))

    placeholder = xr.Dataset()
    placeholder["elapsed_time"] = xr.DataArray(
        np.arange(n_months), dims=["elapsed_time"], attrs={"units": "months"}
    )
    placeholder["polygon_id"] = xr.DataArray(
        np.arange(n_polygons),
        dims=["polygon_id"],
        attrs={"long_name": "Polygon ID"},
    )
    placeholder["injection_date"] = xr.DataArray(
        injection_values,
        dims=["injection_date"],
        attrs={"long_name": "injection date", "units": "month of 1999"},
    )

    # dask keys chunk dicts by axis index, not dimension name. Pass an explicit tuple
    # that already matches the final layout: one chunk per injection date.
    var_dims = ["elapsed_time", "polygon_id", "injection_date"]
    var_shape = (n_months, n_polygons, len(injection_values))
    placeholder["DOR_efficiency"] = xr.DataArray(
        da.empty(
            shape=var_shape,
            chunks=(n_months, n_polygons, 1),
            dtype="float32",
        ),
        dims=var_dims,
    )
    placeholder = (
        placeholder.pipe(set_compression_encoding)
        .chunk(polygon_id=-1, injection_date=1, elapsed_time=-1)
        .transpose("elapsed_time", "polygon_id", "injection_date")
    )

    return placeholder

In [None]:
# Destination Zarr store (on S3) for the combined DOR-efficiency dataset.
store_path = "s3://carbonplan-dor-efficiency/store1b.zarr"

# Lazily allocated template describing the store's variables, coordinates and chunking.
placeholder = create_empty_target_store()
placeholder

In [None]:
# Initialise the target store. compute=False writes metadata and the in-memory
# coordinates now and defers the (uninitialised) dask data, which is never computed.
# The region writes below then fill it in.
# WARNING: mode="w" overwrites any existing store at store_path, including data
# already written by a previous run of the loop below.
placeholder.to_zarr(
    store_path,
    mode="w",
    compute=False,
    consolidated=True,
    zarr_format=2,
)

In [None]:
%%time

for key, group in df.groupby("start_date"):
    cases_without_data = []
    dsets = []
    for case in tqdm.tqdm(group.index):
        single_ds = process_case(case=case, df=group)
        if single_ds:
            dsets.append(single_ds)
        else:
            cases_without_data.append(case)

    for case in cases_without_data:
        single_ds = xr.zeros_like(dsets[0])
        single_ds = process_case_without_data(case=case, df=group, ds=single_ds)
        dsets.append(single_ds)
    dataset = (
        xr.combine_by_coords(dsets, combine_attrs="drop_conflicts")
        .transpose("elapsed_time", "polygon_id", ...)
        .chunk(polygon_id=-1, elapsed_time=-1)
    )
    dataset.to_zarr(store_path, region="auto")

    print(f"Number of cases without data for group={key}: {len(cases_without_data)}")

In [None]:
# Re-open the written store lazily and spot-check it: every 23rd polygon at the
# first injection date, one facet per polygon.
ds = xr.open_dataset(store_path, engine="zarr", chunks={})
ds.isel(
    polygon_id=slice(0, 690, 23),
    injection_date=0,
)["DOR_efficiency"].plot(col="polygon_id", col_wrap=5)

In [None]:
# Display the structure of the store that was just written.
ds

In [None]:
# NOTE(review): this opens the OAE efficiency store (carbonplan-oae-efficiency,
# v2/store1b_rechunked), not the DOR store written above. It is presumably a layout
# reference for comparison; confirm this is intended.
dset = xr.open_dataset(
    "https://carbonplan-oae-efficiency.s3.us-west-2.amazonaws.com/v2/store1b_rechunked.zarr/",
    chunks={},
    engine="zarr",
)
dset

In [None]:
# Record the library versions used for this run, for reproducibility.
xr.show_versions()