This notebook is used for fine-tuning a pretrained model to predict air-sea oxygen flux.
Authors: Zuchuan Li
Date: 10/15/2024

# 1. Load data

In [7]:
import pickle
import torch

from utils import binning_img, sin_days, sin_loc_idx
import numpy as np
import pandas as pd
from scipy.stats import uniform
from oxygen_mae_constants import cat_cols, dt_meta, get_encode_shift

# ---- Predictors ----
# Folder holding both the predictor pickle and the flux table.
o2_path = '/data1/zuchuan/data/Oxygen_Ocean/'

# The pickle holds a dict: 'dt' is the predictor data (indexed by variable
# name further down) and 'idx' is used below to select the matching flux rows.
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open(o2_path + 'o2_spacetime_2024-10-14_all.pickle', 'rb') as h:
    dt = pickle.load(h)
dt_x, dt_idx = dt['dt'], dt['idx']

# ---- Target: air-sea oxygen flux ----
dt_o2 = pd.read_pickle(o2_path + 'o2_Flux_2024-08-08.pkl')

# dt_idx is applied positionally (iloc) so that each target lines up with its
# predictor sample -- presumably how 'idx' was built; TODO confirm.
flux_values = dt_o2['O2_FLUX'].iloc[dt_idx].to_numpy()
dt_y = torch.tensor(flux_values, dtype=torch.float)

# location
# Convert latitude/longitude into row/column indices of a 2160 x 4320 grid
# (per the formulas below, row 0 is at lat = 90 and column 0 at lon = -180),
# then pass the indices to sin_loc_idx.
# NOTE(review): lat == -90 or lon == 180 yields an index equal to the grid
# size (one past the last cell) -- confirm sin_loc_idx tolerates that.
N_ROW, N_COL = 2160, 4320

lat = dt_o2['LAT'].iloc[dt_idx]
lon = dt_o2['LON'].iloc[dt_idx]

row_frac = (90 - lat) / 180
col_frac = (lon + 180) / 360
row_idx = (row_frac * N_ROW).astype(int).to_numpy()
col_idx = (col_frac * N_COL).astype(int).to_numpy()

locs = sin_loc_idx(row_idx, col_idx, row=N_ROW, col=N_COL).to(torch.float)

# days
# Pull the sample dates once, then derive day-of-year and calendar year.
sample_dates = dt_o2['DATE'].iloc[dt_idx]
day_of_year = sample_dates.dt.dayofyear.to_numpy()
yrs = sample_dates.dt.year.to_numpy()

# sin_days is given a (1, N) tensor; its result has its first two axes swapped
# so that samples sit on axis 0 (dys is later indexed with per-sample masks).
dys = sin_days(torch.tensor(day_of_year)[None, :]).transpose(1, 0).to(torch.float)


In [8]:
# Share of all flux records that falls in each calendar year.
dt_o2['DATE'].dt.year.value_counts().div(len(dt_o2))

DATE
2016    0.085012
2017    0.083235
2014    0.080751
2019    0.075542
2018    0.073792
2021    0.073239
2015    0.071999
2020    0.069724
2013    0.054101
2012    0.049264
2011    0.042712
2010    0.040419
2009    0.039696
2008    0.035566
2022    0.033362
2007    0.026606
2006    0.017789
2003    0.008242
2005    0.008017
2004    0.007386
2002    0.006092
2001    0.005599
2000    0.005056
1998    0.003690
1999    0.003109
Name: count, dtype: float64

# 2. Preprocessing

In [10]:
# --------------------- #
# 2.1. Normalize output
# --------------------- #
# Rebuild the raw target from the source table instead of reading `dt_y`:
# this cell overwrites `dt_y` with its standardized version, so re-running it
# would otherwise bin already-normalized values and reset dt_y_mean / dt_y_std
# to roughly 0 / 1, losing the statistics needed to undo the scaling.
dt_y_raw = torch.tensor(dt_o2['O2_FLUX'].iloc[dt_idx].to_numpy(), dtype=torch.float)

# Discretize the target into 300 bins; y_bins / y_levels are kept as returned
# by binning_img (presumably the bin boundaries and per-bin values -- confirm).
dt_y_bin, y_bins, y_levels = binning_img(dt_y_raw, bins_num=300)

# Standardized bin index. np.std uses the population formula (ddof=0) whereas
# torch.std below uses the sample formula by default.
dt_y_bin_norm = (dt_y_bin - dt_y_bin.mean()) / np.std(dt_y_bin)
dt_y_bin = torch.tensor(dt_y_bin, dtype=torch.int)
dt_y_bin_norm = torch.tensor(dt_y_bin_norm, dtype=torch.float)

# Standardized continuous target.
# NOTE(review): the mean/std are computed over all samples, i.e. before the
# train/val/test split further down -- confirm this is intended.
dt_y_mean = torch.mean(dt_y_raw)
dt_y_std = torch.std(dt_y_raw)
dt_y = (dt_y_raw - dt_y_mean) / dt_y_std


# --------------------- #
# 2.2. Encode variables
# --------------------- #
# Bin every predictor with its cut points from dt_meta, then add the
# variable-specific shift returned by get_encode_shift.
dt_x_bin_enc = {}
for col in cat_cols:
    print(col)
    cut_points = dt_meta['CUT_POINT'][col]
    binned, _, _ = binning_img(dt_x[col], bins_level=cut_points)
    binned += get_encode_shift(col)  # in-place, so the binned dtype is kept
    dt_x_bin_enc[col] = binned
    
    
# --------------------- #
# 2.3. Organize data for training
# --------------------- #
# Axis permutation that brings the original last axis to the front; in the
# recorded output the encoded tensor comes out as (samples, 20, 3, 3).
SAMPLE_FIRST = (3, 2, 0, 1)

# binned and shifted predictors, stacked along axis 2
dt_xy = np.concatenate([dt_x_bin_enc[name] for name in cat_cols], axis=2)
dt_x_enc = torch.tensor(dt_xy.transpose(SAMPLE_FIRST), dtype=torch.float)

# raw data without binning, stacked and permuted the same way
dt_x_raw = np.concatenate([dt_x[name] for name in cat_cols], axis=2).transpose(SAMPLE_FIRST)


# --------------------- #
# Split the samples and save every subset
# --------------------- #
# Output path pattern: first slot is the subset name, second the split scheme.
OUT_TEMPLATE = '/data0/zuchuan/mae_output/o2_finetune_data_{}_{}.pickle'


def save_splits(tag, title, idx_train, idx_val, idx_test):
    """Pickle the train/val/test subsets of one split scheme and print their sizes.

    Reads the notebook-level arrays prepared above (dt_x_enc, dt_x_raw, dt_y,
    dt_y_bin, dt_y_bin_norm, locs, dys) and writes one pickle per subset.

    Parameters
    ----------
    tag : str
        Suffix identifying the scheme in the file name ('rand' or 'yr').
    title : str
        Text shown in the banner line of the printed summary.
    idx_train, idx_val, idx_test : boolean arrays
        Per-sample masks selecting each subset.

    Raises
    ------
    ValueError
        If the three masks do not assign every sample to exactly one subset.
    """
    masks = {'train': idx_train, 'val': idx_val, 'test': idx_test}

    # Guard against overlapping or incomplete splits before anything is written.
    membership = np.sum([np.asarray(m, dtype=int) for m in masks.values()], axis=0)
    if not (membership == 1).all():
        raise ValueError('train/val/test masks must partition the samples')

    for name, mask in masks.items():
        with open(OUT_TEMPLATE.format(name, tag), 'wb') as h:
            pickle.dump({'x': dt_x_enc[mask],
                         'x_raw': dt_x_raw[mask],
                         'y': dt_y[mask],
                         'y_bin': dt_y_bin[mask],
                         'y_bin_norm': dt_y_bin_norm[mask],
                         'locs': locs[mask],
                         'dys': dys[mask]}, h)

    print('-------------{}--------------'.format(title))
    print('training: {}'.format(idx_train.sum()))
    print('validation: {}'.format(idx_val.sum()))
    print('testing: {}'.format(idx_test.sum()))


# --------------------- #
# randomly split
# 70% for training, 15% for validation, 15% for testing
# --------------------- #
# One seeded uniform draw per sample decides which subset it belongs to.
random_vals = uniform.rvs(size=dt_x_enc.shape[0], loc=0, scale=1, random_state=123)
idx_train = random_vals <= 0.7
idx_val = (~idx_train) & (random_vals <= 0.85)
idx_test = (~idx_train) & (~idx_val)
save_splits('rand', 'randomly split dataset', idx_train, idx_val, idx_test)


# --------------------- #
# split based on year
# validation: 2017-2018 (~15% of the samples in the recorded run)
# testing: 2019-2020 (~14% of the samples in the recorded run)
# training: every other year, which includes years after the test period
# --------------------- #
idx_val = (2017 <= yrs) & (yrs <= 2018)
idx_test = (2019 <= yrs) & (yrs <= 2020)
idx_train = (~idx_val) & (~idx_test)
save_splits('yr', 'year-based split dataset', idx_train, idx_val, idx_test)

# Later inspection cells index with `idx`, which used to be the loop variable
# left over from the last save loop (the year-based test mask); keep it defined.
idx = idx_test


SST
CHL
PAR
U
V
MLD_CLM
SAL_CLM
SST_CLM
-------------randomly split dataset--------------
training: 122007
validation: 26338
testing: 26031
-------------year-based split dataset--------------
training: 123877
validation: 26074
testing: 24425


In [31]:
# Inspect the second value returned by binning_img for the flux target
# (presumably the bin boundaries -- confirm against utils.binning_img).
y_bins

array([-3.96294975e+00, -5.75321034e-01, -4.07091945e-01, -3.27882610e-01,
       -2.79664457e-01, -2.46485755e-01, -2.20838197e-01, -2.01027174e-01,
       -1.84448928e-01, -1.71685226e-01, -1.60624772e-01, -1.50105212e-01,
       -1.41630396e-01, -1.34412967e-01, -1.27911896e-01, -1.21963767e-01,
       -1.16500907e-01, -1.11625977e-01, -1.06836732e-01, -1.02261469e-01,
       -9.81954858e-02, -9.44323838e-02, -9.08339955e-02, -8.76537841e-02,
       -8.46035928e-02, -8.18491466e-02, -7.91602917e-02, -7.65374228e-02,
       -7.42140338e-02, -7.18159415e-02, -6.98071495e-02, -6.77542686e-02,
       -6.58683777e-02, -6.40064888e-02, -6.22897092e-02, -6.05793120e-02,
       -5.90162389e-02, -5.73235834e-02, -5.57463411e-02, -5.41594168e-02,
       -5.27633056e-02, -5.13647087e-02, -5.01201674e-02, -4.88052610e-02,
       -4.75151427e-02, -4.62928908e-02, -4.50735763e-02, -4.39472748e-02,
       -4.28872555e-02, -4.18827301e-02, -4.09720149e-02, -3.99350254e-02,
       -3.89115140e-02, -

In [13]:
# Shape of the encoded predictors in the year-based test subset.
# Uses idx_test explicitly instead of `idx`, the variable leaked from the last
# save loop, so the cell does not depend on that loop's final iteration.
dt_x_enc[idx_test].shape

torch.Size([24425, 20, 3, 3])

In [14]:
# Count NaNs in the encoded predictors of the year-based test subset.
# Uses idx_test explicitly instead of `idx`, the variable leaked from the last
# save loop, so the cell does not depend on that loop's final iteration.
torch.isnan(dt_x_enc[idx_test]).sum()

tensor(0)