# 1 Import Packages

In [1]:
import xarray as xr
import os
import numpy as np

In [2]:
import src.config as config
import src.utils as utils
from tqdm import tqdm
import math

In [3]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# 2 Metadata

In [4]:
# Identifiers of the processed ensemble files; the loaders below fill in the
# per-realization "realization" and "processing" entries themselves.
experiment_family = "GrandEnsemble"
experiment = "hist"
realization = "lkm0001"
processing = "mergetime_rho_remap_masked_ym"

# Key order is kept as experiment_family -> experiment -> realization.
filename_dict = {
    "experiment_family": experiment_family,
    "experiment": experiment,
    "realization": realization,
}

In [9]:
experiment_name = "cv_samplestandardized_20yr_split"
output_subdirs = ("ml_transform", experiment_name)

# All per-fold train/valid files of this experiment are written here.
output_dir = os.path.join(config.data_pro_path, *output_subdirs)
os.makedirs(output_dir, exist_ok=True)

In [10]:
n_realizations = 100  # realizations lkm0001 ... lkm0100
realization_id_list = np.arange(1, n_realizations + 1)

In [11]:
# Zero-padded realization labels passed to the filename helper: "lkm0001", "lkm0002", ...
realization_list = list(map("lkm{:04d}".format, realization_id_list))

In [12]:
def load_realization_list(realization_list, filename_dict, lev_index):
    """Load the ``rho`` variable at one vertical level for every realization.

    For each realization, the processed file (processing step
    ``"mergetime_rho_remap_masked_ym"``) is located with
    ``utils.gen_absolute_path_and_filename`` under ``config.data_pro_path``.
    The file is opened with xarray and tagged with a ``realization``
    coordinate. It is then reduced to position ``lev_index`` along ``lev``.

    Parameters
    ----------
    realization_list : sequence of str
        Realization labels such as ``"lkm0001"``.
    filename_dict : dict
        Base filename metadata (experiment family, experiment, ...). It is
        NOT modified; a per-realization copy is built instead.
    lev_index : int
        Positional index along the ``lev`` dimension.

    Returns
    -------
    xarray.DataArray
        ``rho`` concatenated along a new ``realization`` dimension.

    Raises
    ------
    ValueError
        If ``realization_list`` is empty.
    FileNotFoundError
        If the processed file of a realization does not exist.
    """
    if len(realization_list) == 0:
        raise ValueError("realization_list must not be empty")

    data_list = []
    for realization in tqdm(realization_list):
        # Copy instead of mutating the caller's dict; existing key order is
        # preserved and "processing" is appended, as before.
        realization_filename_dict = dict(
            filename_dict,
            realization=realization,
            processing="mergetime_rho_remap_masked_ym",
        )

        path, filename = utils.gen_absolute_path_and_filename(
            filename_dict=realization_filename_dict,
            init_path=config.data_pro_path,
            init_filestem="",
            filetype="nc",
        )

        file_path = os.path.join(path, filename)
        # Explicit raise: an `assert` would be stripped when run with `python -O`.
        if not os.path.exists(file_path):
            raise FileNotFoundError("file {} does not exist".format(file_path))

        data = (
            xr.open_dataset(file_path)["rho"]
            .assign_coords({"realization": realization})
            .isel(lev=lev_index)
        )
        data_list.append(data)

    return xr.concat(data_list, dim="realization")

In [13]:
def load_amoc_realization_list(realization_list, filename_dict):
    """Load the ``atlantic_moc`` variable for every realization.

    For each realization, the processed file (processing step ``"amoc_ym"``)
    is located with ``utils.gen_absolute_path_and_filename`` under
    ``config.data_pro_path``. The file is opened with xarray and tagged with
    a ``realization`` coordinate.

    Parameters
    ----------
    realization_list : sequence of str
        Realization labels such as ``"lkm0001"``.
    filename_dict : dict
        Base filename metadata (experiment family, experiment, ...). It is
        NOT modified; a per-realization copy is built instead.

    Returns
    -------
    xarray.DataArray
        ``atlantic_moc`` concatenated along a new ``realization`` dimension.

    Raises
    ------
    ValueError
        If ``realization_list`` is empty.
    FileNotFoundError
        If the processed file of a realization does not exist.
    """
    if len(realization_list) == 0:
        raise ValueError("realization_list must not be empty")

    data_list = []
    for realization in tqdm(realization_list):
        # Copy instead of mutating the caller's dict; existing key order is
        # preserved and "processing" is appended (or overwritten) as before.
        realization_filename_dict = dict(
            filename_dict,
            realization=realization,
            processing="amoc_ym",
        )

        path, filename = utils.gen_absolute_path_and_filename(
            filename_dict=realization_filename_dict,
            init_path=config.data_pro_path,
            init_filestem="",
            filetype="nc",
        )

        file_path = os.path.join(path, filename)
        # Explicit raise: an `assert` would be stripped when run with `python -O`.
        if not os.path.exists(file_path):
            raise FileNotFoundError("file {} does not exist".format(file_path))

        data = xr.open_dataset(file_path)["atlantic_moc"].assign_coords({"realization": realization})
        data_list.append(data)

    return xr.concat(data_list, dim="realization")

# 3 Processing

In [56]:
lev_index = 23  # positional (isel) index into the `lev` dimension of the rho files

## 3.1 Load Data

In [57]:
# rho at level `lev_index` for all realizations (one file read per realization).
data = load_realization_list(
    realization_list=realization_list,
    filename_dict=filename_dict,
    lev_index=lev_index,
)

100%|██████████| 100/100 [02:33<00:00,  1.53s/it]


In [58]:
# Atlantic MOC field for all realizations (one file read per realization).
data_amoc = load_amoc_realization_list(
    realization_list=realization_list,
    filename_dict=filename_dict,
)


100%|██████████| 100/100 [00:08<00:00, 11.59it/s]


In [59]:
# Coordinate handles; `lev` is the scalar coordinate left by isel(lev=lev_index).
lev, lon, lat = data.lev, data.lon, data.lat

## 3.2 Feature scaling

### 3.2.1 Stack Data

In [60]:
n_window = 20  # nominal window length in years
n_splits = 8   # number of contiguous year windows used for the splits below
# NOTE(review): windows are exactly `n_window` years only if len(years) == n_splits * n_window;
# np.array_split otherwise produces windows whose lengths differ by one.
years = data.time.dt.year.values

year_window_list = np.array_split(years, n_splits)


In [61]:
# Contiguous, non-overlapping windows of exactly `n_window` years; trailing
# years that do not fill a complete window are dropped. The slice width now
# follows `n_window` instead of a hard-coded 20.
# NOTE(review): not used by the train/validation split loop below, which is
# built from `year_window_list`.
n_full_windows = len(years) // n_window
year_split_window_list = [years[n_window * k:n_window * (k + 1)] for k in range(n_full_windows)]

In [62]:
import itertools

In [63]:
# All unordered pairs of year windows in lexicographic order; within each pair
# the earlier window comes first (used as training, the later as validation).
window_pairs = itertools.combinations(year_window_list, 2)
period_split_combinations = list(window_pairs)

In [64]:
# Sanity check on the split definition: each (train, valid) pair must use
# disjoint years, otherwise validation information would leak into training.
for train_years, valid_years in period_split_combinations:
    overlap = np.intersect1d(train_years, valid_years)
    assert overlap.size == 0, "train and validation years overlap: {}".format(overlap)

len(period_split_combinations)

In [65]:
# AMOC target definition (loop-invariant).
amoc_depth = 1020  # selected by value on the `depth_2` coordinate
amoc_lat = 26.5    # selected by value on the `lat` coordinate
# NOTE(review): presumably converts the AMOC from kg/s to Sv (reference density
# 1025 kg/m^3, 1 Sv = 10**6 m^3/s) — confirm the units of `atlantic_moc`.
amoc_unit_divisor = 1025 * 10**6


def _select_years(da, year_list):
    """Return the time steps of ``da`` whose calendar year is in ``year_list``.

    Selects the same time steps as ``da.where(mask).dropna(dim="time", how="all")``
    for data without all-NaN time slices. Unlike that form, it indexes
    directly and does not build a full NaN-masked copy of the array first.
    """
    in_years = da.time.dt.year.isin(year_list).values
    return da.isel(time=in_years)


for i, split_combination in enumerate(period_split_combinations):
    # Each combination is a (training years, validation years) pair of windows.
    train_year_list, valid_year_list = split_combination

    # --- Features: rho at level `lev_index`, standardized per grid point ---
    train_data = _select_years(data, train_year_list)
    valid_data = _select_years(data, valid_year_list)

    # Flatten (realization, time) into a single `sample` dimension.
    train_data_stack = train_data.stack(sample=("realization", "time"))
    valid_data_stack = valid_data.stack(sample=("realization", "time"))

    # Exact zeros are treated as land and set to NaN so they stay out of the statistics.
    train_data_stack_landmasked = train_data_stack.where(train_data_stack != 0)
    valid_data_stack_landmasked = valid_data_stack.where(valid_data_stack != 0)

    # Statistics come from the TRAINING samples only; the validation set is
    # standardized with them so no validation information enters the transform.
    train_data_stack_samplemean = train_data_stack_landmasked.mean(dim="sample")
    train_data_stack_samplestd = train_data_stack_landmasked.std(dim="sample")

    train_data_stack_sampleanom = train_data_stack_landmasked - train_data_stack_samplemean
    valid_data_stack_sampleanom = valid_data_stack_landmasked - train_data_stack_samplemean

    train_data_stack_samplestandardized = train_data_stack_sampleanom / train_data_stack_samplestd
    valid_data_stack_samplestandardized = valid_data_stack_sampleanom / train_data_stack_samplestd

    # TODO(review): the training mean/std are not written to disk, so this
    # transform cannot be re-applied or inverted later without recomputing them.
    train_data_stack_samplestandardized.unstack().to_netcdf(
        os.path.join(output_dir, "train_x_lev_{}_{}.nc".format(lev_index, i))
    )
    valid_data_stack_samplestandardized.unstack().to_netcdf(
        os.path.join(output_dir, "valid_x_lev_{}_{}.nc".format(lev_index, i))
    )

    # --- Targets: AMOC at (depth_2=amoc_depth, lat=amoc_lat), first lon index ---
    train_data_amoc = _select_years(data_amoc, train_year_list)
    valid_data_amoc = _select_years(data_amoc, valid_year_list)

    train_data_amoc_depth_1020_lat_26 = (
        train_data_amoc.sel(depth_2=amoc_depth, lat=amoc_lat).isel(lon=0) / amoc_unit_divisor
    )
    valid_data_amoc_depth_1020_lat_26 = (
        valid_data_amoc.sel(depth_2=amoc_depth, lat=amoc_lat).isel(lon=0) / amoc_unit_divisor
    )

    # Targets are written unstandardized.
    train_data_amoc_depth_1020_lat_26.to_netcdf(
        os.path.join(output_dir, "train_y_{}_{}.nc".format(lev_index, i))
    )
    valid_data_amoc_depth_1020_lat_26.to_netcdf(
        os.path.join(output_dir, "valid_y_{}_{}.nc".format(lev_index, i))
    )