In [1]:
import xarray
import xarray as xr
import pandas as pd
from typing import Any, Callable, Union
from climpred.preprocessing.shared import set_integer_time_axis
import os
import subprocess
import glob
import numpy as np
%matplotlib inline
! module load gcp

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def get_path_hind(
    username: str = 'xyz.pqr',
    member: int = 1,
    init: int = 1,
    varname: str = 'xyz',
    expname: str = 'NWA12_COBALT_decr7_jra3q_PRED_',
    ipid: str = 'gfdl.ncrc6-intel23-prod',
    stream: str = 'pp',
    domain: str = 'xyz',
    ts: str = 'ts',
    freq: str = 'xyz',
    length: str = '1yr',
    ending: str = 'nc') -> Union[str, None]:
    """
    Locate one hindcast member's output files, stage them locally, and return their path.

    The experiment directory is the single entry under
    ``/archive/{username}/fre/NWA/2024_11/`` whose name contains ``expname`` and
    ``i{init}`` and ends with ``_e{member}``. Files matching
    ``{ipid}/{stream}/{domain}/{ts}/{freq}/{length}/{varname}.{ending}`` below it are
    passed to ``dmget`` and then copied with ``gcp --sync`` into a freshly emptied
    staging directory under ``$TMPDIR``.

    Args:
        username: Owner of the ``/archive`` tree to search.
        member: Ensemble member number (matched as the ``_e{member}`` suffix).
        init: Initialization tag (matched as the substring ``i{init}``).
        varname: File stem; may be a glob pattern such as ``'*ssh'``.
        expname: Substring identifying the experiment family.
        ipid, stream, domain, ts, freq, length: Path components below the experiment.
        ending: File extension without the dot.

    Returns:
        str or None: glob pattern ``{staging_dir}/{varname}.{ending}`` for the staged
        files, or None if the base directory is missing, the experiment is not found
        or is ambiguous, no source files match, or staging fails.
    """
    import shutil  # local import: only used here to clear the staging directory

    # Get experiment_id
    dir_base_experiment = f'/archive/{username}/fre/NWA/2024_11/'
    if not os.path.exists(dir_base_experiment):
        print(f"Base directory does not exist: {dir_base_experiment}")
        return None

    init_tag = f'i{init}'
    member_suffix = f'_e{member}'
    experiment_id = [
        x for x in os.listdir(dir_base_experiment)
        if expname in x and init_tag in x and x.endswith(member_suffix)
    ]
    if len(experiment_id) != 1:
        print(f"Experiment ID not found or ambiguous for init={init}, member={member}.")
        return None

    experiment_id = experiment_id[0]
    # os.path.join avoids the doubled '/' that plain string concatenation produced
    # (the base directory already ends with a separator).
    dir_outdata = os.path.join(
        dir_base_experiment, experiment_id, ipid, stream, domain, ts, freq, length
    )
    src_path = os.path.join(dir_outdata, f'{varname}.{ending}')
    files = glob.glob(src_path)
    print(src_path)
    if not files:
        print(f"No files found in source path: {src_path}")
        return None

    # NOTE(review): the fallback staging root is tied to one user; confirm whether
    # it should follow the user running the notebook instead.
    tmp_dir = os.environ.get('TMPDIR', '/vftmp/Vimal.Koul/')
    # Strip only a leading glob wildcard from the variable name (a plain name is
    # kept whole), and name the directory after the requested domain so different
    # domains do not share one staging directory.
    var_label = varname.lstrip('*')
    dst_dir = os.path.join(tmp_dir, f'{domain}_{var_label}_{init}01_e{member}')

    # Start from an empty staging directory without going through a shell.
    if os.path.islink(dst_dir) or os.path.isfile(dst_dir):
        os.remove(dst_dir)
    elif os.path.isdir(dst_dir):
        shutil.rmtree(dst_dir)
    os.makedirs(dst_dir, exist_ok=True)
    print(dst_dir)

    try:
        subprocess.run(['dmget', *files], check=True)
        subprocess.run(['gcp', '--sync', *files, dst_dir], check=True)
    except (subprocess.SubprocessError, OSError) as e:
        # Covers a non-zero exit status as well as a missing dmget/gcp executable.
        print(f"Error processing files: {e}")
        return None

    return os.path.join(dst_dir, f'{varname}.{ending}')

def preprocess_1var(ds, v='xyz'):
    """Reduce `ds` to a dataset holding only variable `v`, then squeeze it."""
    selected = ds[v]
    single_var_ds = selected.to_dataset(name=v)
    return single_var_ds.squeeze()

def create_preprocessor(varname):
    """Return a one-argument callable that applies `preprocess_1var` with `v=varname`."""
    def _keep_only_varname(ds):
        # `varname` is captured from the enclosing call.
        return preprocess_1var(ds, v=varname)

    return _keep_only_varname

def _nan_member_dataset(varname, n_time, n_yh, n_xh):
    """Build an all-NaN ``(time, yh, xh)`` dataset standing in for an unavailable member."""
    return xarray.Dataset({
        varname: (("time", "yh", "xh"), np.full((n_time, n_yh, n_xh), float('nan')))
    })


def load_hindcast(
    inits=range(1965, 2024, 1),
    members=range(1, 5, 1),
    fixed_time_length=10,
    preprocess: Union[Callable, None] = None,
    lead_offset: int = 1,
    parallel: bool = True,
    engine: Union[str, None] = None,
    varname: str = 'xyz',
    domain: str = 'xyz',
    freq: str = 'xyz',
    username: str = 'Vimal.Koul',
) -> Union[xarray.DataArray, xarray.Dataset]:
    """
    Concat multi-member, multi-initialization hindcast experiment.
    Into one :py:class:`xarray.Dataset`.

    For every ``init``/``member`` pair the files are located and staged with
    ``get_path_hind`` and opened with ``xarray.open_mfdataset``. Members that are
    missing or fail to load are replaced by all-NaN data; members with fewer than
    ``fixed_time_length`` time steps are padded with NaN.

    Args:
        inits: Initialization years to load.
        members: Ensemble member numbers to load.
        fixed_time_length: Number of time steps every member is padded to.
        preprocess: Callable applied to each file's dataset before concatenation.
            When None, each file is reduced to the single variable ``varname``.
        lead_offset: Currently unused.
            NOTE(review): confirm whether it should be forwarded to
            ``set_integer_time_axis``.
        parallel: Forwarded to ``xarray.open_mfdataset``.
        engine: Forwarded to ``xarray.open_mfdataset``.
        varname: Variable to load; files are matched with the glob ``*{varname}``.
        domain: Post-processing component directory passed to ``get_path_hind``.
        freq: Frequency directory passed to ``get_path_hind``.
        username: Archive owner passed to ``get_path_hind``.

    Returns:
        dataset with dims: ``member``, ``init``, ``lead``.

    Raises:
        ValueError: if a member has to be NaN-filled before any member has been
            loaded successfully (the grid size is not yet known).
    """
    if preprocess is None:
        # Fall back to selecting `varname` so that real members have the same
        # single-variable layout as the NaN fillers.
        def _keep_only_varname(ds):
            return ds[varname].to_dataset(name=varname).squeeze()

        preprocess = _keep_only_varname

    init_list = []
    yh, xh = None, None

    for init in inits:
        print(f"Processing init {init} ...")
        member_list = []

        for member in members:
            loaded = None
            p = get_path_hind(
                username=username,
                member=member,
                init=init,
                varname=f'*{varname}',
                domain=domain,
                freq=freq,
            )
            if p is None:
                print(f"Appending NaNs for init={init}, member={member}")
            else:
                try:
                    # Open all leads for specified member and init
                    member_ds = xarray.open_mfdataset(
                        p,
                        combine="nested",
                        concat_dim="time",
                        preprocess=preprocess,
                        parallel=parallel,
                        engine=engine,
                        coords="minimal",
                        data_vars="minimal",
                        compat="override",
                    ).squeeze()

                    # squeeze() drops length-1 dimensions, so a member with a single
                    # time step loses `time`; restore it so the padding below acts
                    # on the data and not only on the coordinate.
                    if "time" not in member_ds.dims:
                        member_ds = member_ds.expand_dims("time")

                    # Extract yh and xh dimensions from the first valid dataset
                    if yh is None or xh is None:
                        yh, xh = member_ds.sizes["yh"], member_ds.sizes["xh"]

                    # Set new integer time
                    member_ds = set_integer_time_axis(member_ds)
                    n_time = member_ds.sizes["time"]
                    print(n_time)
                    if n_time < fixed_time_length:
                        # Pad short runs with NaN up to the common length.
                        member_ds = member_ds.reindex(
                            time=np.arange(1, fixed_time_length + 1),
                            fill_value=np.nan,
                        )
                    loaded = member_ds
                except Exception as e:
                    # Best effort: report the failure and fall through to a NaN filler.
                    print(f"Error loading dataset for init={init}, member={member}: {e}")

            if loaded is None:
                if yh is None or xh is None:
                    raise ValueError("Cannot append NaNs without knowing yh and xh dimensions. Ensure at least one valid file is processed first.")
                loaded = _nan_member_dataset(varname, fixed_time_length, yh, xh)
            member_list.append(loaded)

        # Concatenate along the member dimension
        init_list.append(xarray.concat(member_list, "member"))

    # Concatenate along the init dimension and finalize
    ds = xarray.concat(init_list, "init").rename({"time": "lead"})
    ds["member"] = members
    ds["init"] = inits
    return ds

In [11]:
%%time

nmembers   = 5
init_strt  = 1965
init_end   = 2024
VARNAME    = 'ssh'
domain     = 'ocean_annual' #'ocean_cobalt_btm' 'ocean_cobalt_omip_sfc'
freq       = 'annual'

if freq=='annual':
    fixed_time_length=10
elif freq=='monthly':
    fixed_time_length=120

preprocessor = create_preprocessor(VARNAME)

ds_nwa12_orig = load_hindcast(inits=range(init_strt, init_end+1, 1), \
                          members=range(1,nmembers+1), \
                          fixed_time_length=fixed_time_length, \
                          preprocess=preprocessor, varname=VARNAME, \
                          domain=domain, freq=freq)
ds_nwa12_orig

Processing init 1965 ...
/archive/Vimal.Koul/fre/NWA/2024_11//NWA12_COBALT_decr7_jra3q_PRED_i19650101_e1/gfdl.ncrc6-intel23-prod/pp/ocean_annual/ts/annual/1yr/*ssh.nc
/vftmp/Vimal.Koul/ocean_annual_ssh_196501_e1
10
/archive/Vimal.Koul/fre/NWA/2024_11//NWA12_COBALT_decr7_jra3q_PRED_i19650101_e2/gfdl.ncrc6-intel23-prod/pp/ocean_annual/ts/annual/1yr/*ssh.nc
/vftmp/Vimal.Koul/ocean_annual_ssh_196501_e2
10
/archive/Vimal.Koul/fre/NWA/2024_11//NWA12_COBALT_decr7_jra3q_PRED_i19650101_e3/gfdl.ncrc6-intel23-prod/pp/ocean_annual/ts/annual/1yr/*ssh.nc
/vftmp/Vimal.Koul/ocean_annual_ssh_196501_e3
10
/archive/Vimal.Koul/fre/NWA/2024_11//NWA12_COBALT_decr7_jra3q_PRED_i19650101_e4/gfdl.ncrc6-intel23-prod/pp/ocean_annual/ts/annual/1yr/*ssh.nc
/vftmp/Vimal.Koul/ocean_annual_ssh_196501_e4
3
/archive/Vimal.Koul/fre/NWA/2024_11//NWA12_COBALT_decr7_jra3q_PRED_i19650101_e5/gfdl.ncrc6-intel23-prod/pp/ocean_annual/ts/annual/1yr/*ssh.nc
/vftmp/Vimal.Koul/ocean_annual_ssh_196501_e5
1
Processing init 1966 ...
/a

Unnamed: 0,Array,Chunk
Bytes,14.64 GiB,5.00 MiB
Shape,"(60, 5, 10, 845, 775)","(1, 1, 1, 845, 775)"
Dask graph,3000 chunks in 7828 graph layers,3000 chunks in 7828 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 14.64 GiB 5.00 MiB Shape (60, 5, 10, 845, 775) (1, 1, 1, 845, 775) Dask graph 3000 chunks in 7828 graph layers Data type float64 numpy.ndarray",5  60  775  845  10,

Unnamed: 0,Array,Chunk
Bytes,14.64 GiB,5.00 MiB
Shape,"(60, 5, 10, 845, 775)","(1, 1, 1, 845, 775)"
Dask graph,3000 chunks in 7828 graph layers,3000 chunks in 7828 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [12]:
%%time

ds_nwa12_orig["lead"] = np.arange(1, 1 + ds_nwa12_orig.lead.size)

# climpred looks for this to know lead time resolution.
ds_nwa12_orig["lead"].attrs["units"] = "years"

# Extract the years from the init coordinate
years = ds_nwa12_orig.init.values

# Convert the years to pandas.Timestamp objects
dates = [pd.Timestamp(year=year, month=1, day=1) for year in years]
#dates = [year for year in years]

# Create a DatetimeIndex from the pandas.Timestamp objects
dt_index = pd.DatetimeIndex(dates)

# Assign the DatetimeIndex to the init coordinate of the dataset
ds_nwa12_orig = ds_nwa12_orig.assign_coords(init=dt_index)
ds_nwa12_orig

CPU times: user 4.16 ms, sys: 349 μs, total: 4.51 ms
Wall time: 3.83 ms


Unnamed: 0,Array,Chunk
Bytes,14.64 GiB,5.00 MiB
Shape,"(60, 5, 10, 845, 775)","(1, 1, 1, 845, 775)"
Dask graph,3000 chunks in 7828 graph layers,3000 chunks in 7828 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 14.64 GiB 5.00 MiB Shape (60, 5, 10, 845, 775) (1, 1, 1, 845, 775) Dask graph 3000 chunks in 7828 graph layers Data type float64 numpy.ndarray",5  60  775  845  10,

Unnamed: 0,Array,Chunk
Bytes,14.64 GiB,5.00 MiB
Shape,"(60, 5, 10, 845, 775)","(1, 1, 1, 845, 775)"
Dask graph,3000 chunks in 7828 graph layers,3000 chunks in 7828 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [13]:
# Load the dataset's data into memory.
# The repr in the preceding output reports 14.64 GiB for the data variable.
ds_nwa12_orig = ds_nwa12_orig.load()

In [14]:
# Display the dataset after loading.
ds_nwa12_orig

In [15]:
file_path = f'/work/vnk/outdata_for_analysis/post_202412/jra3q/nwa12_hindcast_{freq}_{VARNAME}.nc'

# Remove any previous output, then write the dataset once with `init` renamed to `time`.
if os.path.exists(file_path):
    os.remove(file_path)
ds_nwa12_orig.rename({'init': 'time'}).to_netcdf(path=file_path, invalid_netcdf=False)