# Imports

In [None]:
from tensorflow.keras.models import load_model
import xarray as xr
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K

# Train/val/test split

## Data Preprocessing

The data expected by this notebook has dimensions lat, lon, lev, and time. "lat" and "lon", which represent latitude and longitude respectively, are expected to have a resolution of 0.5. The expected covariates of the dataset are ['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL'], and the dataset must also contain the label variable 'Wstd'. The expected "lev" values are 1, 2, ..., 72, each representing an atmospheric level as documented at https://gmao.gsfc.nasa.gov/global_mesoscale/7km-G5NR/docs/.

In [None]:
# Global Variables

# Chunk size used along the stacked sample dimension "s" when saving splits.
BATCH_SIZE = 2480

# Lat/lon bounds of the US region used to crop the global data.
US_LOCS = dict(
    lat1=25,
    lat2=50,
    lon1=-150,
    lon2=-50,
)
COORDS = {"US": US_LOCS}

# Standardization constants, positionally aligned with the covariate order
# ['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL'].
MEANS = [
    243.9,    # T
    0.6,      # AIRD
    6.3,      # U
    0.013,    # V
    0.0002,   # W
    5.04,     # KM
    21.8,     # RI
    0.002,    # QV
    9.75e-7,  # QI
    7.87e-6,  # QL
]
STDS = [
    30.3,     # T
    0.42,     # AIRD
    16.1,     # U
    7.9,      # V
    0.05,     # W
    20.6,     # KM
    20.8,     # RI
    0.0036,   # QV
    7.09e-6,  # QI
    2.7e-5,   # QL
]

# Variables whose level-71 values are replicated across all levels as
# additional "<var>_sfc" features.
SURF_VARS = ['AIRD', 'KM', 'RI', 'QV']

In [None]:
def standardize(ds, s, m):
    """
    Standardize every data variable of ``ds`` as ``(x - mean) / std``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to standardize. In this notebook its data variables are, in
        order, ['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL'].
    s : sequence of float
        Standard deviations, positionally aligned with ``ds.data_vars``.
    m : sequence of float
        Means, positionally aligned with ``ds.data_vars``.

    Returns
    -------
    xarray.Dataset
        A shallow copy of ``ds`` with every data variable standardized. The
        input dataset itself is left unmodified.

    Raises
    ------
    ValueError
        If ``m`` or ``s`` does not have exactly one entry per data variable.
    """
    names = list(ds.data_vars)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(m) != len(names) or len(s) != len(names):
        raise ValueError(
            f"standardize: expected {len(names)} means and stds, "
            f"got {len(m)} means and {len(s)} stds"
        )

    # Write into a shallow copy so the caller's dataset is not mutated.
    out = ds.copy()
    for name, mean, std in zip(names, m, s):
        out[name] = (ds[name] - mean) / std

    return out

In [None]:
# Load and process the data
file_path = ""  # INSERT FILEPATH WITH GLOBAL DATA

global_data = xr.open_mfdataset(file_path)
# Keep only entries whose "lev" coordinate is non-zero; drop=True removes the
# masked-out labels instead of leaving them filled with NaN.
global_data = global_data.where(global_data["lev"] != 0, drop=True)

# (Optional) Filter coords for quicker processing: crop to the US box.
# NOTE(review): label slicing assumes lat/lon are stored in ascending order.
us_box = COORDS["US"]
lat_range = slice(us_box["lat1"], us_box["lat2"])
lon_range = slice(us_box["lon1"], us_box["lon2"])
global_data = global_data.sel(lat=lat_range, lon=lon_range)

In [None]:
# Split into Covariates and Labels
# Load desired time stamps in the form of a list of strings,
# ex: ['2006-01-08T10:30:00.000000000']
times = ['']

# Select the requested time stamps once; inputs and target come from the same
# subset.
selected = global_data.sel(time=times)

# Covariate order must match the positional MEANS/STDS used by standardize.
data_in = selected[['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL']]
# Standardize block-by-block; the output has the same layout as the input,
# so the input itself serves as the template. Result is a Dataset.
data_in = xr.map_blocks(standardize, data_in, kwargs={"m": MEANS, "s": STDS}, template=data_in)

data_out = selected['Wstd']  # this is a DataArray (the label)

In [None]:
# Prepare X and y in appropriate array shapes
# NOTE: Xall is the same object as data_in, so the "_sfc" variables added
# below also appear on data_in.
Xall = data_in
yall = data_out

levs = Xall.coords['lev'].values
for var in SURF_VARS:
    Xs = Xall[var].sel(lev=[71])  # 1 level above surface (length-1 lev axis)

    # Tile that single level once per level with one concat. Growing the
    # array one copy at a time would be quadratic in len(levs).
    Xsfc = xr.concat([Xs] * len(levs), dim='lev')

    # Relabel the tiled copies with the real level coordinates.
    Xsfc = Xsfc.assign_coords(lev=levs)
    Xall[f"{var}_sfc"] = Xsfc

# Make dask chunking consistent across variables before combining them.
Xall = Xall.unify_chunks()
# Dataset -> DataArray with a new "variable" dimension holding the features.
Xall = Xall.to_array()
# Flatten (time, lat, lon, lev) into a single sample dimension "s".
Xall = Xall.stack(s=('time', 'lat', 'lon', 'lev'))

## Saving the Split

In [None]:
# create train-test split
# Positions along the stacked sample dimension "s" are split 60/20/20 into
# train / validation / test.
indices = np.arange(Xall.sizes["s"])

# First hold out 20% of all samples as the test set ...
train_indices, test_indices = train_test_split(
    indices,
    test_size=0.2,
    random_state=42,
)
# ... then 25% of the remaining 80% (= 20% overall) becomes validation.
train_indices, val_indices = train_test_split(
    train_indices,
    test_size=0.25,
    random_state=42,
)

In [None]:
def save_split(filter_indices, filter_type, Xall, yall):
    """
    Subset features and label to one split and write both to netCDF.

    Parameters
    ----------
    filter_indices : array-like of int
        Positions along the stacked sample dimension ``s`` to keep.
    filter_type : str
        Split name used in the output files ``X_{filter_type}.nc`` and
        ``y_{filter_type}.nc`` (written relative to the working directory).
    Xall : xarray.DataArray
        Features with a ``variable`` dimension and a stacked ``s`` dimension.
    yall : xarray.DataArray
        Label with unstacked ``time``/``lat``/``lon``/``lev`` dimensions. It is
        stacked here in the same order as the features so indices line up.
    """
    print("Setting up Xall")
    Xall = Xall.rename({"variable": "ft"})  # "ft" = feature axis
    Xall = Xall.squeeze()
    # Assumes dims are (ft, s); transpose() reverses them to (s, ft).
    Xall = Xall.transpose()
    Xall = Xall.isel(s=filter_indices)
    # 14 features in this notebook: 10 covariates + 4 "_sfc" copies.
    Xall = Xall.chunk({"ft": 14, "s": BATCH_SIZE})

    print("Setting up yall")
    # Same stacking order as the features, so position i in both refers to
    # the same (time, lat, lon, lev) sample.
    yall = yall.stack(s=('time', 'lat', 'lon', 'lev'))
    yall = yall.squeeze()
    yall = yall.isel(s=filter_indices)
    yall = yall.transpose()
    yall = yall.chunk({"s": BATCH_SIZE})

    # netCDF cannot store a pandas MultiIndex, so flatten "s" into plain
    # coordinates before writing.
    print("saving Xall")
    Xall = Xall.reset_index('s')
    Xall.to_netcdf(f"X_{filter_type}.nc")

    print("saving yall")
    yall = yall.reset_index('s')
    yall.to_netcdf(f"y_{filter_type}.nc")

In [None]:
# Write each split to X_<name>.nc / y_<name>.nc, in train -> val -> test order.
for split_idx, split_name in (
    (train_indices, "train"),
    (val_indices, "val"),
    (test_indices, "test"),
):
    save_split(split_idx, split_name, Xall, yall)

# Generating bootstrapped training sets

In [None]:
from pathlib import Path

file_path = "" # INSERT LOCATION OF TRAIN DATA FROM PREVIOUS SPLIT
# Path("") resolves to the current directory. A plain f"{file_path}/..." would
# point at the filesystem root when the placeholder is left empty.
split_dir = Path(file_path)
X_train = xr.open_dataset(split_dir / "X_train.nc")
y_train = xr.open_dataset(split_dir / "y_train.nc")

# Number of samples along "s"; invariant across bootstrap draws.
dim_size = X_train.sizes["s"]
if y_train.sizes["s"] != dim_size:
    raise ValueError(
        f"X_train has {dim_size} samples but y_train has {y_train.sizes['s']}"
    )

for seed in range(15):
    # Legacy global seeding is kept so each seed reproduces the same
    # resample as previously generated files.
    np.random.seed(seed)
    # Bootstrap: draw dim_size positions with replacement.
    random_idx = np.random.choice(dim_size, size=dim_size, replace=True)

    # Separate names so the full-data Xall/yall from earlier cells are not
    # overwritten when cells are re-run.
    X_boot = X_train.isel(s=random_idx)
    X_boot.to_netcdf(f"X_train_{seed}.nc")

    y_boot = y_train.isel(s=random_idx)
    y_boot.to_netcdf(f"y_train_{seed}.nc")