This notebook prepares the training and validation datasets used to fine-tune the pretrained model on air-sea oxygen flux.
Author: Zuchuan Li
Date: 10/15/2024

# 1. Load data

In [1]:
import pickle

from utils import binning_img
import numpy as np
import pandas as pd
from scipy.stats import uniform
from oxygen_mae_constants import cat_cols, dt_meta, get_encode_shift

# Load the predictor fields and the row positions used to pick matching flux values.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
o2_path = '/data1/zuchuan/data/Oxygen_Ocean/'
with open(f'{o2_path}o2_spacetime_2024-10-14_all.pickle', 'rb') as h:
    dt = pickle.load(h)
dt_x, dt_idx = dt['dt'], dt['idx']

# Load the air-sea oxygen flux table.
dt_o2 = pd.read_pickle(f'{o2_path}o2_Flux_2024-08-08.pkl')

# Select flux rows *by position* (iloc) so the target lines up with the predictors.
# Assumes dt['idx'] holds positional indices or a boolean mask -- TODO confirm.
flux_col = dt_o2['O2_FLUX']
dt_y = flux_col.iloc[dt_idx]


# 2. Preprocessing

In [2]:
# --------------------- #
# 2.1. Normalize output
# --------------------- #
# Standardize the flux target to zero mean / unit variance using the population
# standard deviation (ddof=0), then convert it to a plain numpy array.
# NOTE(review): mean/std are computed over *all* samples, including those that are
# later assigned to the validation split in 2.3.
dt_y_mean = dt_y.mean()
dt_y_std = dt_y.std(ddof=0)
dt_y = ((dt_y - dt_y_mean) / dt_y_std).to_numpy()


# --------------------- #
# 2.2. Encode variables
# --------------------- #
# Bin each categorical predictor using its cut points from dt_meta, then add a
# per-variable offset from get_encode_shift (presumably so each variable's codes
# occupy a distinct range -- TODO confirm against oxygen_mae_constants).
dt_x_bin_enc = {}
for col in cat_cols:
    print(col)  # progress: one line per variable
    codes, _, _ = binning_img(dt_x[col], bins_level=dt_meta['CUT_POINT'][col])
    codes += get_encode_shift(col)
    dt_x_bin_enc[col] = codes
    
    
# --------------------- #
# 2.3. Organize data for training
# --------------------- #
# Stack the encoded variables along axis 2, then move the sample axis to the front.
# Per the printed output, the result has shape (n_samples, channels, 3, 3).
dt_xy = np.concatenate([dt_x_bin_enc[col] for col in cat_cols], axis=2)
dt_x = np.transpose(dt_xy, (3, 2, 0, 1))
print(dt_x.shape, dt_y.shape)
# Each sample must have exactly one target; fail with a clear message rather than
# an opaque IndexError during boolean-mask indexing below.
assert dt_x.shape[0] == dt_y.shape[0], (
    f'predictor/target sample mismatch: {dt_x.shape[0]} vs {dt_y.shape[0]}')

# 80% for training, 20% for validation; the fixed seed makes the split reproducible.
TRAIN_FRAC = 0.8
SPLIT_SEED = 123
FINETUNE_OUT_PREFIX = '/data0/zuchuan/mae_output/o2_finetune_data_'
random_vals = uniform.rvs(size=dt_x.shape[0], loc=0, scale=1, random_state=SPLIT_SEED)
idx_train = random_vals <= TRAIN_FRAC
idx_val = ~idx_train

# The targets were standardized in 2.1. Store the statistics with each split so that
# model outputs can be mapped back to the original flux scale:
#     y_original = y * y_std + y_mean
# The existing 'x' / 'y' keys are unchanged; 'y_mean' / 'y_std' are additional keys.
y_norm_stats = {'y_mean': float(dt_y_mean), 'y_std': float(dt_y_std)}
for split_name, split_mask in (('train', idx_train), ('val', idx_val)):
    with open(f'{FINETUNE_OUT_PREFIX}{split_name}.pickle', 'wb') as h:
        pickle.dump({'x': dt_x[split_mask], 'y': dt_y[split_mask], **y_norm_stats}, h)
    

SST
CHL
PAR
U
V
MLD_CLM
SAL_CLM
SST_CLM
(174376, 20, 3, 3) (174376,)


In [4]:
# Number of samples in the training and validation splits (train first).
print(np.count_nonzero(idx_train), np.count_nonzero(idx_val))

139543 34833
