In [1]:
import os

import numpy as np
import xarray as xr
from ndpyramid.utils import set_zarr_encoding


In [2]:
# Open the existing v2 store; `chunks={}` asks xarray for dask-backed arrays
# using the store's own chunking, so nothing is loaded eagerly.
url = "s3://carbonplan-oae-efficiency/v2/store1b_rechunked.zarr/"
current = xr.open_dataset(
    url,
    engine="zarr",
    chunks={},
)
current

Unnamed: 0,Array,Chunk
Bytes,0.94 MiB,0.94 MiB
Shape,"(384, 320)","(384, 320)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 0.94 MiB 0.94 MiB Shape (384, 320) (384, 320) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",320  384,

Unnamed: 0,Array,Chunk
Bytes,0.94 MiB,0.94 MiB
Shape,"(384, 320)","(384, 320)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.90 MiB,485.16 kiB
Shape,"(180, 690, 4)","(180, 690, 1)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.90 MiB 485.16 kiB Shape (180, 690, 4) (180, 690, 1) Dask graph 4 chunks in 2 graph layers Data type float32 numpy.ndarray",4  690  180,

Unnamed: 0,Array,Chunk
Bytes,1.90 MiB,485.16 kiB
Shape,"(180, 690, 4)","(180, 690, 1)"
Dask graph,4 chunks in 2 graph layers,4 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [3]:
# Location of the raw NetCDF file. The default lives in a user-specific home
# directory, so allow overriding it through the OAE_EFFICIENCY_NC environment
# variable when the notebook is run elsewhere.
path = os.environ.get(
    'OAE_EFFICIENCY_NC',
    '/global/homes/a/abanihi/OAE_efficiency_corrected_.nc',
)
raw_ds = xr.open_dataset(path)
raw_ds

In [4]:
def break_into_seasons(ds, *, polygon:int=0, region:int=0):
    """Split one polygon/region selection of ``ds`` into four seasonal datasets.

    For each of the four positions along the ``season`` dimension, the dataset
    is indexed positionally at (``polygon``, ``region``, season position) and
    then filtered with ``where(..., drop=True)`` so that only entries whose
    ``OAE_efficiency`` is not null are kept.

    Parameters
    ----------
    ds : dataset-like
        Object exposing ``isel`` with ``polygon``, ``region`` and ``season``
        indexers (an xarray Dataset in this notebook).
    polygon : int, keyword-only
        Positional index along ``polygon``. Defaults to 0.
    region : int, keyword-only
        Positional index along ``region``. Defaults to 0.

    Returns
    -------
    dict
        Keys ``'january'``, ``'april'``, ``'july'``, ``'october'`` (in that
        order), mapped to season positions 0, 1, 2 and 3 respectively.
    """
    # Output key -> position along the ``season`` dimension.
    season_positions = {'january': 0, 'april': 1, 'july': 2, 'october': 3}

    results = {}
    for label, position in season_positions.items():
        selected = ds.isel(polygon=polygon, region=region, season=position)
        # Keep only the entries where OAE_efficiency actually has a value.
        has_value = selected.OAE_efficiency.notnull()
        results[label] = selected.where(has_value, drop=True)
    return results

In [5]:
# Seasonal split for the first polygon of the first region (the function's defaults),
# spelled out explicitly here.
datasets = break_into_seasons(raw_ds, polygon=0, region=0)
datasets

{'january': <xarray.Dataset> Size: 3kB
 Dimensions:         (time: 180)
 Coordinates:
     region          <U14 56B 'Pacific'
     season          <U7 28B 'January'
     polygon         int64 8B 0
   * time            (time) object 1kB 0347-01-16 12:00:00 ... 0361-12-16 12:0...
 Data variables:
     OAE_efficiency  (time) float64 1kB 0.03391 0.08991 0.1478 ... 0.8193 0.8194,
 'april': <xarray.Dataset> Size: 3kB
 Dimensions:         (time: 180)
 Coordinates:
     region          <U14 56B 'Pacific'
     season          <U7 28B 'April'
     polygon         int64 8B 0
   * time            (time) object 1kB 0347-04-16 00:00:00 ... 0362-03-16 12:0...
 Data variables:
     OAE_efficiency  (time) float64 1kB 0.0245 0.05309 0.0858 ... 0.8213 0.8215,
 'july': <xarray.Dataset> Size: 3kB
 Dimensions:         (time: 180)
 Coordinates:
     region          <U14 56B 'Pacific'
     season          <U7 28B 'July'
     polygon         int64 8B 0
   * time            (time) object 1kB 0347-07-16 12:00:00

In [6]:
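# NOTE: disabled draft of a season-overlap plot, kept for reference only.
# It is not runnable as written: `plt` and `mdates` are not imported in the
# import cell, and `datasets` maps season names to xarray Datasets rather than
# to arrays of time stamps.
# fig, ax = plt.subplots(figsize=(12, 6))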
# # Define colors and offsets for each season
# colors = ['blue', 'green', 'red', 'orange']
# offsets = [0.8, 0.6, 0.4, 0.2]  # For spacing the seasons vertically

# # Plot each season as a line with points
# for i, (season, times) in enumerate(datasets.items()):
#     ax.scatter(times, [offsets[i]] * len(times), 
#              color=colors[i], alpha=0.7, label=season)
    
#     # Connect points with lines
#     ax.plot(times, [offsets[i]] * len(times), 
#            color=colors[i], alpha=0.3, linewidth=2)

# # Add vertical lines at points of overlap
# all_times = np.concatenate([times for times in datasets.values()])
# unique_times = np.unique(all_times)

# # Count occurrences of each time point
# time_counts = {}
# for t in all_times:
#     time_counts[t] = time_counts.get(t, 0) + 1

# # Highlight overlaps
# for t, count in time_counts.items():
#     if count > 1:  # If time appears in more than one season
#         ax.axvline(x=t, color='purple', alpha=0.3 * count/4, 
#                   linestyle='--', linewidth=count)

# # Format the plot
# ax.set_yticks(offsets)
# ax.set_yticklabels(['January', 'April', 'July', 'October'])
# ax.set_title('Time Overlaps Between Seasons')

# # Format x-axis to show dates properly
# ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
# plt.xticks(rotation=45)

# ax.grid(True, axis='x', alpha=0.3)
# ax.legend(loc='upper right')

# plt.tight_layout()

In [7]:
def combine_seasons_with_relative_time(season_dict):
    """
    Combine per-season datasets onto one shared ``elapsed_time`` axis.

    Each season's ``OAE_efficiency`` series is re-indexed by an integer
    ``elapsed_time`` coordinate (0, 1, 2, ...) in place of its original
    ``time`` stamps, labelled with an ``injection_date`` coordinate, and the
    seasons are then concatenated along ``injection_date``.

    Parameters:
    -----------
    season_dict : dict
        Dictionary with keys as season names and values as xarray Datasets.
        Every dataset must provide a 1-D ``OAE_efficiency`` variable, a
        ``time`` coordinate and scalar ``polygon`` / ``region`` coordinates.
        All datasets must have the same number of time points (180 in this
        notebook).

    Returns:
    --------
    xarray.Dataset
        Dataset holding ``OAE_efficiency`` with dimensions ``polygon`` (length 1),
        ``region`` (length 1), ``injection_date`` (one entry per season) and
        ``elapsed_time`` (one entry per time point).

    Raises:
    -------
    ValueError
        If ``season_dict`` is empty or the seasons differ in length.
    """
    if not season_dict:
        raise ValueError("season_dict must contain at least one season dataset")

    # All seasons share one elapsed-time axis, so their series must be equally long.
    lengths = {name: len(ds.OAE_efficiency.data) for name, ds in season_dict.items()}
    n_steps = next(iter(lengths.values()))
    if any(length != n_steps for length in lengths.values()):
        raise ValueError(
            f"all seasons must have the same number of time points, got {lengths}"
        )

    # The polygon/region labels are taken from the last season in the dict.
    # NOTE(review): this assumes every season was cut from the same
    # polygon/region selection, as break_into_seasons does — confirm for other callers.
    reference = list(season_dict.values())[-1]

    with xr.set_options(keep_attrs=True):
        # Common integer axis; presumably "months since injection", which holds
        # only if the input time steps are monthly.
        relative_months = np.arange(n_steps).astype('int32')

        standardized_datasets = []
        for ds in season_dict.values():
            # Same values, but indexed by elapsed_time instead of the original time stamps.
            new_ds = xr.Dataset(
                data_vars={
                    'OAE_efficiency': ('elapsed_time', ds.OAE_efficiency.data)
                },
                coords={
                    'elapsed_time': relative_months,
                },
            )

            # Label the season by the calendar month of its first time stamp
            # (a month number, not a full date).
            new_ds = new_ds.assign_coords(injection_date=ds.time.values[0].month)

            standardized_datasets.append(new_ds)

        # Stack the seasons along a new injection_date dimension, then add
        # length-1 polygon/region dimensions so results can be merged later.
        combined = xr.concat(standardized_datasets, dim='injection_date').expand_dims(
            {'polygon': [reference.polygon.values], 'region': [reference.region.values]}
        )

    return combined

In [8]:
# Try the combiner on the single polygon/region split above.
# Append `.OAE_efficiency.plot(col='injection_date')` to get one panel per injection month.
combine_seasons_with_relative_time(datasets)

In [9]:
# Build one combined dataset per (polygon, region) pair, then merge them by
# their polygon/region coordinates.
dsets = []
n_polygons = len(raw_ds.polygon)
n_regions = len(raw_ds.region)
for polygon in range(n_polygons):
    for region in range(n_regions):
        parts = break_into_seasons(raw_ds, polygon=polygon, region=region)
        # Skip pairs whose January series does not have the full 180 time points.
        if len(parts['january'].time) != 180:
            continue
        dsets.append(combine_seasons_with_relative_time(parts))
ds = xr.combine_by_coords(dsets)
ds

In [10]:
# Region label -> number of leading polygons to keep for that region.
region_polygon_counts = {
    'Atlantic': 150,
    'Pacific': 200,
    'South': 300,
    'Southern_Ocean': 40,
}
atlantic, pacific, south_atlantic, southern = (
    ds.sel(region=label).isel(polygon=slice(0, count))
    for label, count in region_polygon_counts.items()
)

In [11]:
# Display the Atlantic subset (region='Atlantic', first 150 polygons).
atlantic

In [12]:
# Stack the four regional subsets end to end along ``polygon`` and drop the
# ``region`` variable.
dset = xr.concat([atlantic, pacific, south_atlantic, southern], dim='polygon').drop_vars('region')
# Rename the polygon axis to ``polygon_id`` and fix the dimension order.
dset = (
    dset.rename_vars({'polygon': 'polygon_id'})
    .swap_dims({'polygon': 'polygon_id'})
    .transpose('elapsed_time', 'polygon_id', 'injection_date')
)
# Renumber polygons 0..N-1 across all regions. N is read from the data rather
# than hard-coded (690 for the current slices), so changing the regional
# slices above does not break this assignment.
dset['polygon_id'] = np.arange(0, dset.sizes['polygon_id'], dtype=np.int32)
# Apply the zarr encoding helper, then use one chunk per injection date.
dset = set_zarr_encoding(dset, float_dtype='float32', int_dtype='int32').chunk({'injection_date': 1})
dset

Unnamed: 0,Array,Chunk
Bytes,1.90 MiB,485.16 kiB
Shape,"(180, 690, 4)","(180, 690, 1)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.90 MiB 485.16 kiB Shape (180, 690, 4) (180, 690, 1) Dask graph 4 chunks in 1 graph layer Data type float32 numpy.ndarray",4  690  180,

Unnamed: 0,Array,Chunk
Bytes,1.90 MiB,485.16 kiB
Shape,"(180, 690, 4)","(180, 690, 1)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [13]:
%%time
dset.to_zarr("s3://carbonplan-oae-efficiency/v3/store1b.zarr/", consolidated=True, zarr_format=2, mode='w')

CPU times: user 396 ms, sys: 257 ms, total: 653 ms
Wall time: 4.5 s


<xarray.backends.zarr.ZarrStore at 0x7ff18c1f8d30>