In [1]:
import sys
import pathlib
import xarray as xr
import dask
import loky
import ndpyramid
import os
import fsspec
import traceback
from rich.console import Console
import tqdm

parent_dir = pathlib.Path.cwd().parent
sys.path.append(str(parent_dir))
sys.path.append(str(parent_dir.parent))
import functools
from dor_config import DORConfig

@property
def compressed_data_dir(self) -> pathlib.Path:
    """Return ``<parent_data_dir>/research-grade-compressed``, creating it if needed.

    Note that this is a property with a side effect: every access calls
    ``mkdir`` (with ``parents=True, exist_ok=True``) on the resulting path.
    """
    target = pathlib.Path(self.parent_data_dir).joinpath("research-grade-compressed")
    target.mkdir(parents=True, exist_ok=True)
    return target

# Attach the property defined above to the imported DORConfig class at runtime,
# so that `config.compressed_data_dir` is available on instances in this notebook.
DORConfig.compressed_data_dir = compressed_data_dir


from dor_cli import setup_directories


from vis_pyramid import create_template_store2, get_nc_glob_pattern, load_ssh_data, reduction, integrate_column_mol, reshape_into_month_year, setup_memory

In [2]:
# Rich console used for the status and error messages printed by later cells.
console = Console()

# Resolve the project directories, then build the cache object from the
# "joblib_cache_dir" entry; `memory.cache` decorates the processing function below.
dirs = setup_directories()
memory = setup_memory(dirs["joblib_cache_dir"])
# Bare expression: display the cache object (and its location) as the cell output.
memory

Memory(location=/pscratch/sd/a/abanihi/dor/joblib)

In [3]:
# Run configuration. `parent_data_dir` is the root that `compressed_data_dir` is
# derived from; `store_2_path` is the Zarr store the pyramids are written into.
config = DORConfig(parent_data_dir='/global/cfs/projectdirs/m4746/Datasets/Ocean-CDR-Atlas-v0/OAE-Efficiency-Map/', 
                   store_1_path='/pscratch/sd/a/abanihi/oae/store1b.zarr', 
                   store_2_path='s3://carbonplan-oae-efficiency/v3/store2.zarr/',
                  )
# Bare expression: display the fully resolved configuration.
config

DORConfig(scratch_env='/pscratch/sd/a/abanihi', parent_data_dir='/global/cfs/projectdirs/m4746/Datasets/Ocean-CDR-Atlas-v0/OAE-Efficiency-Map/', data_archive_dir='/global/cfs/projectdirs/m4746/Projects/Ocean-CDR-Atlas-v0/data/archive', store_1_path='/pscratch/sd/a/abanihi/oae/store1b.zarr', store_2_path='s3://carbonplan-oae-efficiency/v3/store2.zarr/', cumulative_fg_co2_percent_store_path='/pscratch/sd/a/abanihi/test/cumulative_FG_CO2_percent.zarr')

In [4]:
# Display the derived directory. Accessing this property also creates the
# directory on disk if it does not exist yet.
config.compressed_data_dir

PosixPath('/global/cfs/projectdirs/m4746/Datasets/Ocean-CDR-Atlas-v0/OAE-Efficiency-Map/research-grade-compressed')

In [5]:
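# One-time setup, intentionally left commented out: the pyramid writes further down
# open the store with mode="r+" and region="auto", which presumably requires the
# template store to exist already. Uncomment and run once to (re)create it.
#create_template_store2(output_store=config.store_2_path, variables=["DIC", "ALK", "FG", "PH", "pCO2SURF"], levels=2)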

In [6]:
def reduction(ds, ssh):
    """Reduce a model-output dataset to the surface / column fields used for the pyramid.

    This notebook-local definition shadows the ``reduction`` imported from
    ``vis_pyramid`` at the top of the notebook.

    For ALK the top ``z_t`` level is taken; PH, pCO2SURF and FG_CO2 are used as
    they are; DIC is passed through ``integrate_column_mol`` together with
    ``ds["dz"]`` and ``ssh``. For every field a "delta" is also computed as the
    field minus its ``*_ALT_CO2`` counterpart. Attributes are kept
    (``keep_attrs=True``) for all arithmetic.

    Args:
        ds: Dataset providing ALK, PH, pCO2SURF, FG_CO2, DIC, their ``*_ALT_CO2``
            counterparts (``FG_ALT_CO2`` for the flux) and ``dz``.
        ssh: Forwarded unchanged to ``integrate_column_mol``.

    Returns:
        A new ``xr.Dataset`` with the ten reduced variables, keeping only the
        coordinates polygon_id, injection_date, elapsed_time, ULONG and ULAT.
    """
    keep_coords = {"polygon_id", "injection_date", "elapsed_time", "ULONG", "ULAT"}

    with xr.set_options(keep_attrs=True):
        # Surface alkalinity is the first z_t level.
        alk_surface = ds["ALK"].isel(z_t=0)
        alk_surface_alt = ds["ALK_ALT_CO2"].isel(z_t=0)

        # Column integrals: experimental run first, then the ALT_CO2 counterpart.
        dic_column = integrate_column_mol(ds["DIC"], ds["dz"], ssh)
        dic_column_alt = integrate_column_mol(ds["DIC_ALT_CO2"], ds["dz"], ssh)

        # Key order here fixes the variable order of the resulting dataset.
        reduced = xr.Dataset(
            {
                "ALK_SURF": alk_surface,
                "ALK_DELTA_SURF": alk_surface - alk_surface_alt,
                "FG_CO2_SURF": ds["FG_CO2"],
                "FG_CO2_DELTA_SURF": ds["FG_CO2"] - ds["FG_ALT_CO2"],
                "DIC_COLUMN_INTEGRATED": dic_column,
                "DIC_DELTA_COLUMN_INTEGRATED": dic_column - dic_column_alt,
                "PH_SURF": ds["PH"],
                "PH_DELTA_SURF": ds["PH"] - ds["PH_ALT_CO2"],
                "pCO2_DELTA_SURF": ds["pCO2SURF"] - ds["pCO2SURF_ALT_CO2"],
                "pCO2_SURF": ds["pCO2SURF"],
            }
        )

        # Drop every coordinate that is not explicitly whitelisted.
        unwanted_coords = set(reduced.coords) - keep_coords
        return reduced.drop_vars(unwanted_coords)

def concatenate_into_bands(ds: xr.Dataset) -> xr.Dataset:
    """Stack each delta/experimental variable pair along a new ``band`` dimension.

    Args:
        ds: Dataset containing the ``*_DELTA_*`` and plain variables listed in
            the mapping below.

    Returns:
        A new dataset sharing ``ds``'s coordinates, with variables ALK, DIC, PH,
        FG and pCO2SURF, each carrying a ``band`` dimension labelled
        ``["delta", "experimental"]`` (in that order).
    """
    # output variable -> (delta source, experimental source).
    # Dict order determines the order variables are added to the result.
    band_sources = {
        "ALK": ("ALK_DELTA_SURF", "ALK_SURF"),
        "DIC": ("DIC_DELTA_COLUMN_INTEGRATED", "DIC_COLUMN_INTEGRATED"),
        "PH": ("PH_DELTA_SURF", "PH_SURF"),
        "FG": ("FG_CO2_DELTA_SURF", "FG_CO2_SURF"),
        "pCO2SURF": ("pCO2_DELTA_SURF", "pCO2_SURF"),
    }

    stacked = xr.Dataset(coords=ds.coords)
    for out_name, (delta_name, experimental_name) in band_sources.items():
        # A fresh index per variable, labelling the two concatenated members.
        band_index = xr.DataArray(
            name="band", data=["delta", "experimental"], dims="band"
        )
        stacked[out_name] = xr.concat(
            [ds[delta_name], ds[experimental_name]], dim=band_index
        )
    return stacked


In [7]:
@memory.cache
def process_and_create_pyramid(
    polygon_id: str,
    injection_month: str,
    data_dir: str,
    store_path: str,
    weights_store: str,
    levels: int = 2,
) -> None:
    """Build the visualization pyramid for one (polygon, injection month) and write it to Zarr.

    Pipeline: open the matching NetCDF files, run ``reduction``, stack the
    delta/experimental fields into bands, reshape into month/year, regrid into a
    web-mercator pyramid, and write the result into the store at ``store_path``.

    Args:
        polygon_id: Polygon identifier, forwarded to ``get_nc_glob_pattern``.
        injection_month: Injection month, forwarded to ``get_nc_glob_pattern``
            and ``load_ssh_data``.
        data_dir: Directory forwarded to ``get_nc_glob_pattern`` to locate inputs.
        store_path: Zarr store updated in place (``mode="r+"``, ``region="auto"``).
        weights_store: Zarr location of the regridding weights. Reused when it
            exists; otherwise weights are generated and written there.
        levels: Number of pyramid levels, used both when generating weights and
            when regridding.

    Returns:
        The pyramid object produced by ``ndpyramid.pyramid_regrid`` (after
        ``dask.optimize``).
        NOTE(review): the signature is annotated ``-> None`` but the pyramid is
        returned, so the ``memory.cache`` wrapper will presumably persist it too.
        Confirm which behaviour is intended.

    Raises:
        Exception: Any failure is printed with its traceback and re-raised.
    """
    try:
        path = get_nc_glob_pattern(data_dir, polygon_id, injection_month)
        console.print(f"Loading data from {path}", style="blue")

        # os.cpu_count() may return None, and half of a single CPU is zero workers;
        # both would make the executor constructor fail, so clamp to at least one.
        n_workers = max(1, (os.cpu_count() or 1) // 2)

        # Using the executor as a context manager shuts its worker processes down
        # when this task finishes, instead of leaving one pool behind per call.
        with loky.ProcessPoolExecutor(
            max_workers=n_workers, timeout=120
        ) as pool, dask.config.set(pool=pool):
            ds = xr.open_mfdataset(
                path,
                coords="minimal",
                combine="by_coords",
                data_vars="minimal",
                compat="override",
                decode_times=True,
                parallel=True,
                decode_timedelta=True,
            )
            ds = dask.optimize(ds)[0]

            console.print("Processing dataset through reduction pipeline", style="blue")
            # NOTE(review): the SSH path is derived from the notebook-global `config`,
            # not from this function's arguments, so it is not part of the cache key.
            ssh_path = (
                f"{config.compressed_data_dir}/control/"
                "g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.SMYLE.005.pop.h.SSH.030601-036812.nc"
            )
            ssh = load_ssh_data(injection_month, ssh_path=ssh_path)

            bands_ds = (
                ds.pipe(reduction, ssh)
                .pipe(concatenate_into_bands)
                .pipe(reshape_into_month_year)
            )

            console.print("Building visualization pyramid", style="blue")
            other_chunks = dict(
                month=1, year=-1, band=1, polygon_id=1, injection_date=1, x=128, y=128
            )

            if fsspec.get_mapper(weights_store).fs.exists(weights_store):
                console.print(
                    f"Using weights from {weights_store} for regridding", style="blue"
                )
                # NOTE(review): an existing store is assumed to have been generated
                # with the same `levels`; this is not verified here.
                weights = xr.open_datatree(weights_store, engine="zarr", chunks={})

            else:
                console.print(
                    "No weights store provided or does not exist. "
                    "Weights will be generated on-the-fly.",
                    style="yellow",
                )
                # Generate weights for the requested number of levels so they match
                # the pyramid_regrid call below (a hard-coded value would only be
                # correct for the default).
                weights = ndpyramid.regrid.generate_weights_pyramid(
                    bands_ds, levels=levels
                )
                weights.to_zarr(
                    weights_store, consolidated=True, zarr_format=2, mode="w"
                )

            pyramid = ndpyramid.pyramid_regrid(
                bands_ds,
                levels=levels,
                projection="web-mercator",
                parallel_weights=False,
                other_chunks=other_chunks,
                weights_pyramid=weights,
            )

            pyramid = dask.optimize(pyramid)[0]

            console.print(f"Saving pyramid to {store_path}", style="blue")
            pyramid.to_zarr(store_path, region="auto", mode="r+")

            return pyramid

    except Exception:
        # markup=False: traceback text can contain square brackets that Rich would
        # otherwise try to interpret as markup tags. The style is applied via `style`.
        console.print(
            f"Error processing polygon_id={polygon_id}, "
            f"injection_month={injection_month}: {traceback.format_exc()}",
            style="bold red",
            markup=False,
        )
        raise

In [8]:
# Every (zero-padded polygon id, injection month) combination, polygon-major order.
polygon_ids = range(690)
padded_polygon_ids = [format(pid, "03d") for pid in polygon_ids]
padded_injection_months = ['01', '04', '07', '10']
tasks = [
    (padded_id, month)
    for padded_id in padded_polygon_ids
    for month in padded_injection_months
]
# Preview the first few tasks as the cell output.
tasks[:10]

[('000', '01'),
 ('000', '04'),
 ('000', '07'),
 ('000', '10'),
 ('001', '01'),
 ('001', '04'),
 ('001', '07'),
 ('001', '10'),
 ('002', '01'),
 ('002', '04')]

In [9]:
%%time

for polygon_id, injection_month in tqdm.tqdm(tasks):
    try:
        process_and_create_pyramid(polygon_id=polygon_id, injection_month=injection_month, 
                               data_dir=f'{config.compressed_data_dir}/experiments', 
                               store_path=config.store_2_path, 
                               weights_store=f"{os.environ['SCRATCH']}/oae/weights.zarr")
        #console.print(f"Finished processing polygon_id={polygon_id}, injection_month={injection_month}",style="green")
    except Exception:
            console.print(
                f"[bold red]Error processing {polygon_id}/{injection_month}: "
                f"{traceback.format_exc()}[/bold red]"
            )

100%|██████████| 2760/2760 [1:53:11<00:00,  2.46s/it]

CPU times: user 1h 46min 13s, sys: 4min 32s, total: 1h 50min 45s
Wall time: 1h 53min 11s



