In [204]:
import xarray as xr
import numpy as np
import zarr
import pandas as pd

import os, sys, time, glob, re

# Seed NumPy's global RNG: RTMA_data_splitting draws its monthly validation
# blocks from np.random, so without this the train/validation split changes
# on every kernel restart.
seed = 107
np.random.seed(seed)
# Data root; set the RTMA_ROOT_DIR environment variable to run on another machine.
root_dir = os.environ.get('RTMA_ROOT_DIR', '/data/harish')

In [230]:
def RTMA_data_splitting(zarr_path, dates_range, in_times, out_times, opt_test=False,
                        val_window_hours=6 * 24, seed=None):
    '''
    Build rolling-window (input, output) timestamp samples for the RTMA zarr store.

    The RTMA data have gaps in them, so the windows are first built on a complete
    hourly reference time series covering ``dates_range``. Any window that contains
    a reference hour missing from the store is then discarded.

    Parameters
    ----------
    zarr_path : str
        Path of the zarr store (opened with ``xr.open_zarr``); it must have a
        ``time`` coordinate.
    dates_range : (start, end)
        Inclusive bounds used both for the hourly reference series and for
        ``ds.sel(time=slice(start, end))``.
    in_times : int
        Number of consecutive hours in each input window.
    out_times : int
        Number of consecutive hours in each output window (the hours directly
        following the input window).
    opt_test : bool, default False
        If True, return every valid window as a test set instead of splitting
        into train/validation.
    val_window_hours : int, default 144 (6 days)
        Number of consecutive valid samples held out for validation in each
        calendar month. Months with fewer valid samples contribute none.
    seed : int or None, default None
        Seed for choosing the validation blocks. ``None`` draws from NumPy's
        global RNG (``np.random``), as before.

    Returns
    -------
    ``(X_train_times, Y_train_times, X_val_times, Y_val_times)`` if ``opt_test``
    is False, otherwise ``(X_test_times, Y_test_times)``. Each is an
    ``xr.DataArray`` of timestamps with dims ``('sample', 'time_window')``.
    '''
    dims = ['sample', 'time_window']

    # Complete hourly reference series for the requested period.
    reference_dates = pd.date_range(start=dates_range[0], end=dates_range[1], freq='h')
    reference_values = reference_dates.values

    # Sliding windows: row i of in_samples holds reference hours i .. i+in_times-1,
    # row i of out_samples holds the out_times hours right after them.
    n_windows = max(len(reference_values) - in_times - out_times + 1, 0)
    starts = np.arange(n_windows)[:, None]
    in_samples = reference_values[starts + np.arange(in_times)]
    out_samples = reference_values[starts + in_times + np.arange(out_times)]

    # Timestamps actually present in the store for this period; close the store afterwards.
    with xr.open_zarr(zarr_path) as ds:
        time_coord = ds.sel(time=slice(*dates_range)).coords['time']
        original_times = pd.to_datetime(time_coord.values)
    missing_times = reference_dates.difference(original_times)

    # Drop every window whose input or output hours include a missing time.
    missing_values = missing_times.values
    has_gap = (np.isin(in_samples, missing_values).any(axis=1)
               | np.isin(out_samples, missing_values).any(axis=1))
    filtered_in_samples = in_samples[~has_gap]
    filtered_out_samples = out_samples[~has_gap]

    if opt_test:
        X_test_times = xr.DataArray(filtered_in_samples, dims=dims)
        Y_test_times = xr.DataArray(filtered_out_samples, dims=dims)
        return X_test_times, Y_test_times

    # Hold out one random block of val_window_hours consecutive valid samples per
    # calendar month (month taken from each sample's first input hour).
    randomizer = np.random if seed is None else np.random.default_rng(seed)
    first_hours = pd.DatetimeIndex(filtered_in_samples[:, 0])
    years = np.asarray(first_hours.year)
    months = np.asarray(first_hours.month)
    validation_samples = np.zeros(len(filtered_in_samples), dtype=bool)
    for year in np.unique(years):
        for month in range(1, 13):
            month_indices = np.flatnonzero((years == year) & (months == month))
            if len(month_indices) < val_window_hours:
                continue  # not enough valid samples for a full block
            # Valid block starts are 0 .. len - val_window_hours (inclusive).
            start_index = randomizer.choice(len(month_indices) - val_window_hours + 1)
            validation_indices = month_indices[start_index:start_index + val_window_hours]
            validation_samples[validation_indices] = True
    # NOTE(review): consecutive windows overlap by in_times + out_times - 1 hours,
    # so training windows adjacent to a validation block share hours with it;
    # add a buffer around each block if that leakage matters.

    X_train_times = xr.DataArray(filtered_in_samples[~validation_samples], dims=dims)
    Y_train_times = xr.DataArray(filtered_out_samples[~validation_samples], dims=dims)
    X_val_times = xr.DataArray(filtered_in_samples[validation_samples], dims=dims)
    Y_val_times = xr.DataArray(filtered_out_samples[validation_samples], dims=dims)
    return X_train_times, Y_train_times, X_val_times, Y_val_times

In [231]:
# Train/validation period 2018-2022; windows touching RTMA gaps are dropped.
zarr_path = f'{root_dir}/rtma_i10fg_NYS_subset.zarr'
train_val_dates_range = ('2018-01-01T00', '2022-12-31T23')
# Define input/output window sizes (in hours)
in_times = 3   # 3 consecutive input hours
out_times = 1   # predict the next 1 hour

X_train_times, Y_train_times, X_val_times, Y_val_times = RTMA_data_splitting(
    zarr_path, train_val_dates_range, in_times, out_times, opt_test=False)

(43821, 3) (43821, 1)
Missing times: DatetimeIndex(['2018-11-05 18:00:00', '2018-11-05 19:00:00',
               '2018-11-05 20:00:00', '2018-11-05 21:00:00',
               '2018-11-05 22:00:00', '2018-11-05 23:00:00',
               '2019-09-18 15:00:00', '2019-12-08 00:00:00',
               '2020-01-25 00:00:00', '2020-03-07 18:00:00',
               ...
               '2021-03-26 14:00:00', '2021-03-26 15:00:00',
               '2021-03-26 16:00:00', '2021-03-26 17:00:00',
               '2021-03-26 18:00:00', '2021-03-26 19:00:00',
               '2021-03-26 20:00:00', '2021-03-26 21:00:00',
               '2021-03-26 22:00:00', '2021-03-26 23:00:00'],
              dtype='datetime64[ns]', length=131, freq=None), Total missing times: 131
(43645, 3) (43645, 1)
(35005, 3) (35005, 1) (8640, 3) (8640, 1)


In [233]:
# Held-out test period: calendar year 2023 (no train/validation split).
zarr_path = f'{root_dir}/rtma_i10fg_NYS_subset.zarr'
test_dates_range = ('2023-01-01T00', '2023-12-31T23')
# Define input/output window sizes (in hours); must match the training cell
in_times = 3   # 3 consecutive input hours
out_times = 1   # predict the next 1 hour

X_test_times, Y_test_times = RTMA_data_splitting(
    zarr_path, test_dates_range, in_times, out_times, opt_test=True)

(8757, 3) (8757, 1)
Missing times: DatetimeIndex([], dtype='datetime64[ns]', freq='h'), Total missing times: 0
(8757, 3) (8757, 1)
(8757, 3) (8757, 1)


In [240]:
ds = xr.open_zarr(zarr_path)
data = ds.i10fg


def _select_windows(window_times):
    """Gather i10fg at the (sample, time_window) timestamps, ordered (sample, y, x, time_window)."""
    return data.sel(time=window_times).transpose('sample', 'y', 'x', 'time_window')


X_train, Y_train = _select_windows(X_train_times), _select_windows(Y_train_times)
X_val, Y_val = _select_windows(X_val_times), _select_windows(Y_val_times)
X_test, Y_test = _select_windows(X_test_times), _select_windows(Y_test_times)

print('X_train', X_train.shape, 'Y_train', Y_train.shape,
      'X_val', X_val.shape, 'Y_val', Y_val.shape,
      'X_test', X_test.shape, 'Y_test', Y_test.shape)

X_train (35005, 256, 384, 3) Y_train (35005, 256, 384, 1) X_val (8640, 256, 384, 3) Y_val (8640, 256, 384, 1) X_test (8757, 256, 384, 3) Y_test (8757, 256, 384, 1)
