Prepare data for pretraining.

Authors: Zuchuan Li
Date: 2024-10-09


# 1. Get metadata

In [1]:
import pickle
import os
import numpy as np
from importlib import reload 
import utils
utils = reload(utils)
from utils import binning_img
from oxygen_mae_constants import cat_cols

# Directory holding the per-date input pickles; the metadata pickle and the
# "bin_encode" pickles produced by later cells are also written here.
dt_path = "/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27"

# ---------------- #
# Load data function
# ---------------- #
def load_data(yr, col, data_dir=None):
    """Load one column from every 8-day pickle of a year, joined on the last axis.

    Files are looked up as ``<data_dir>/<yr>.<DDD>.pickle`` for
    DDD = 001, 009, ..., 361. Dates whose file does not exist are skipped.

    Parameters
    ----------
    yr : int
        Year used in the file name.
    col : str
        Key read from each unpickled object.
    data_dir : str, optional
        Directory containing the pickles. Defaults to the notebook-level
        ``dt_path`` so existing ``load_data(yr, col)`` calls are unchanged.

    Returns
    -------
    numpy.ndarray or list
        The per-file values concatenated along the last axis in day order
        (the value itself when only one file exists), or an empty list when
        no file exists for the year.
    """
    if data_dir is None:
        data_dir = dt_path

    chunks = []
    for dy in range(1, 366, 8):
        name = "/{}.{}.pickle".format(yr, str(dy).rjust(3, "0"))
        if os.path.exists(data_dir + name):
            print(yr, dy)
            # NOTE: pickle.load can execute arbitrary code; only read trusted files.
            with open(data_dir + name, 'rb') as h:
                chunks.append(pickle.load(h)[col])

    if not chunks:
        return []
    if len(chunks) == 1:
        return chunks[0]
    # Join once at the end instead of re-copying the growing array per file.
    return np.concatenate(chunks, axis=-1)

def load_datas(yrs, col):
    """Load one column for several years and join the results on the last axis.

    Parameters
    ----------
    yrs : iterable of int
        Years to load, in the order they should be concatenated.
    col : str
        Key passed through to ``load_data``.

    Returns
    -------
    numpy.ndarray or list
        The per-year results concatenated along the last axis (the single
        result itself when only one year has data), or an empty list when no
        year has any data.
    """
    chunks = []
    for yr in yrs:
        tmp = load_data(yr, col)
        # A year with no files comes back empty; concatenating that onto an
        # N-D array would raise, so skip it.
        if len(tmp) > 0:
            chunks.append(tmp)

    if not chunks:
        return []
    if len(chunks) == 1:
        return chunks[0]
    # Join once at the end instead of re-copying the growing array per year.
    return np.concatenate(chunks, axis=-1)

In [12]:
# Repeat counts used to expand the per-column MEAN/STD, one entry per column
# of cat_cols, in order.
# NOTE(review): presumably the number of channels each column contributes to
# the model input -- confirm against the data layout.
repeat_counts = (3, 3, 3, 4, 4, 1, 1, 1)
# Fail fast: np.repeat needs exactly one count per column. Checking here avoids
# discovering a mismatch only after every column has been loaded and binned.
if len(repeat_counts) != len(cat_cols):
    raise ValueError(
        "repeat_counts has {} entries but cat_cols has {}".format(
            len(repeat_counts), len(cat_cols)))

cut_points = {}   # column -> bin edges returned by binning_img (outer edges opened to +/-inf)
dt_levels = {}    # column -> third value returned by binning_img, used as the code offset step
mean = []         # per-column mean of the offset codes, in cat_cols order
std = []          # per-column std of the offset codes, in cat_cols order
num = 0           # running offset added to the next column's codes

for col in cat_cols:
    print(col)
    dt = load_datas(range(2003, 2008), col)
    # load_datas returns an empty list when no file was found; stop with a
    # clear message instead of failing on `.shape` below.
    if len(dt) == 0:
        raise FileNotFoundError(
            "No data found for column {!r} under {}".format(col, dt_path))
    print(dt.shape)

    print("Binning...")
    # NOTE(review): binning_img lives in utils; presumably it returns
    # (binned codes, bin edges, number of levels) -- confirm.
    dt, bins, dt_levels[col] = binning_img(dt, bins_num=100)
    # Open the outermost edges so values outside the 2003-2007 range still
    # fall into the first/last bin when these cut points are reused.
    bins[0] = -np.inf
    bins[-1] = np.inf
    cut_points[col] = bins

    print('Encode data...')
    # Shift this column's codes by the levels consumed by earlier columns.
    dt = dt + num
    num += dt_levels[col]

    mean.append(np.mean(dt))
    std.append(np.std(dt))

# save meta data
f_name = dt_path + '/pretraining_metadata_2024-10-09.pickle'
dt_meta = {
    'CUT_POINT': cut_points,
    'LEVEL_NUM': dt_levels,
    'MEAN': np.repeat(np.array(mean), repeat_counts),
    'STD': np.repeat(np.array(std), repeat_counts),
    'cat_cols': cat_cols, 
}
with open(f_name, 'wb') as fid:
    pickle.dump(dt_meta, fid)

SST
2003 1
2003 9
2003 17
2003 25
2003 33
2003 41
2003 49
2003 57
2003 65
2003 73
2003 81
2003 89
2003 97
2003 105
2003 113
2003 121
2003 129
2003 137
2003 145
2003 153
2003 161
2003 169
2003 177
2003 185
2003 193
2003 201
2003 209
2003 217
2003 225
2003 233
2003 241
2003 249
2003 257
2003 265
2003 273
2003 281
2003 289
2003 297
2003 305
2003 313
2003 321
2003 329
2003 337
2003 345
2003 353
2003 361
2004 1
2004 9
2004 17
2004 25
2004 33
2004 41
2004 49
2004 57
2004 65
2004 73
2004 81
2004 89
2004 97
2004 105
2004 113
2004 121
2004 129
2004 137
2004 145
2004 153
2004 161
2004 169
2004 177
2004 185
2004 193
2004 201
2004 209
2004 217
2004 225
2004 233
2004 241
2004 249
2004 257
2004 265
2004 273
2004 281
2004 289
2004 297
2004 305
2004 313
2004 321
2004 329
2004 337
2004 345
2004 353
2004 361
2005 1
2005 9
2005 17
2005 25
2005 33
2005 41
2005 49
2005 57
2005 65
2005 73
2005 81
2005 89
2005 97
2005 105
2005 113
2005 121
2005 129
2005 137
2005 145
2005 153
2005 161
2005 169
2005 177
2005 1

# 2. Binning

In [23]:
import oxygen_mae_constants
oxygen_mae_constants = reload(oxygen_mae_constants)
from oxygen_mae_constants import cat_cols, get_encode_shift

# ----------------------- #
# Bin and encode all data
# ----------------------- #
# Walk every (year, 8-day step) slot; slots without an input pickle are skipped.
for yr in range(2003, 2022):
    for dy in range(1, 366, 8):
        print(yr, dy)
        day_tag = str(dy).rjust(3, "0")
        name = "/{}.{}.pickle".format(yr, day_tag)
        if os.path.exists(dt_path + name):
            print('Load data...')
            with open(dt_path + name, 'rb') as h:
                dt = pickle.load(h)

            print('Bin and encode ...')
            for col in cat_cols:
                # Reuse the cut points stored in dt_meta, then add the
                # column's encode shift in place.
                dt[col], _, _ = binning_img(
                    dt[col],
                    bins_num=100,
                    bins_level=dt_meta['CUT_POINT'][col],
                )
                dt[col] += get_encode_shift(col)

            print('Save data...')
            # Written next to the input, with a "bin_encode" tag in the name.
            name = "/{}.{}.{}.pickle".format(yr, day_tag, 'bin_encode')
            with open(dt_path + name, 'wb') as h:
                pickle.dump(dt, h)
            

2004 1
Load data...
Bin and encode ...
Save data...
2004 9
Load data...
Bin and encode ...
Save data...
2004 17
Load data...
Bin and encode ...
Save data...
2004 25
Load data...
Bin and encode ...
Save data...
2004 33
Load data...
Bin and encode ...
Save data...
2004 41
Load data...
Bin and encode ...
Save data...
2004 49
Load data...
Bin and encode ...
Save data...
2004 57
Load data...
Bin and encode ...
Save data...
2004 65
Load data...
Bin and encode ...
Save data...
2004 73
Load data...
Bin and encode ...
Save data...
2004 81
Load data...
Bin and encode ...
Save data...
2004 89
Load data...
Bin and encode ...
Save data...
2004 97
Load data...
Bin and encode ...
Save data...
2004 105
Load data...
Bin and encode ...
Save data...
2004 113
Load data...
Bin and encode ...
Save data...
2004 121
Load data...
Bin and encode ...
Save data...
2004 129
Load data...
Bin and encode ...
Save data...
2004 137
Load data...
Bin and encode ...
Save data...
2004 145
Load data...
Bin and encode ...
Sa

In [24]:
# Display the metadata dict assembled in section 1 (prints every cut-point array).
dt_meta

{'CUT_POINT': {'SST': array([       -inf, -1.11500001, -0.65249997, -0.285     ,  0.04      ,
          0.36281249,  0.70999998,  1.07333326,  1.44999993,  1.83770263,
          2.2349999 ,  2.63999987,  3.05199981,  3.4749999 ,  3.90999985,
          4.3499999 ,  4.79374981,  5.23499966,  5.67749977,  6.125     ,
          6.57499981,  7.01999998,  7.4666667 ,  7.9199996 ,  8.37312508,
          8.81833267,  9.26277733,  9.71500015, 10.17916584, 10.64999962,
         11.125     , 11.61499977, 12.11499977, 12.61499977, 13.11499977,
         13.60999966, 14.10750008, 14.60499954, 15.10499954, 15.59999943,
         16.09000015, 16.56500053, 17.02499962, 17.46500015, 17.89374924,
         18.30500031, 18.70999908, 19.10499954, 19.48999977, 19.86999893,
         20.23999977, 20.60000038, 20.95499992, 21.30500031, 21.64500046,
         21.97500038, 22.29500008, 22.60499954, 22.90499878, 23.19624901,
         23.47999954, 23.75499916, 24.01999855, 24.27499962, 24.51999855,
         24.754999

# 3. Build data folder

In [4]:
import os
import shutil

# Move the "bin_encode" pickles out of the source directory into one
# sub-directory per year under dst_path.
src_path = '/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/'
dst_path = '/data1/zuchuan/data/Oxygen_Ocean/processed_data/'
# NOTE(review): this starts at 2004 although the binning loop starts at 2003 --
# confirm that 2003 files are meant to stay in src_path.
for yr in range(2004, 2022):
    yr_dir = dst_path + str(yr) + '/'
    # exist_ok=True keeps the cell re-runnable (os.mkdir raised FileExistsError
    # on a second run); makedirs also creates dst_path itself if it is missing.
    os.makedirs(yr_dir, exist_ok=True)
    for dy in range(1, 366, 8):
        name = "{}.{}.bin_encode.pickle".format(yr, str(dy).rjust(3,'0'))
        src_name = src_path + name
        if os.path.exists(src_name):
            print(src_name)
            dst_name = yr_dir + name
            shutil.copy(src_name, dst_name)
            # Delete the source only once the copy is confirmed to exist.
            if os.path.exists(dst_name):
                os.remove(src_name)

/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.001.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.009.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.017.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.025.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.033.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.041.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.049.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.057.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.065.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.073.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/2004.081.bin_encode.pickle
/data1/zuchuan/data/Oxygen_Ocean

In [2]:
# Spot-check: load one processed file back into `dt`.
# NOTE(review): this reads from /data0/zuchuan/processed_data/ while the copy
# step in section 3 writes to /data1/zuchuan/data/Oxygen_Ocean/processed_data/
# -- confirm both locations hold the same files.
path = '/data0/zuchuan/processed_data/2005/'
with open(path + '2005.009.bin_encode.pickle', 'rb') as h:
    dt = pickle.load(h)

In [4]:
# Show the 'MASK' entry of the file loaded in the previous cell.
dt['MASK']

array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])

In [7]:
import numpy as np
# Scratch check: recover (year, day-of-year) from a processed file name.
name = '/data0/zuchuan/processed_data/2005/2005.001.bin_encode.pickle'
# Take the base name, then its first two dot-separated fields: ['2005', '001'].
tmp = name.split('/')[-1].split('.')[:2]
# Cast to int, which drops the zero padding of the day field.
np.array(tmp).astype(int)

array([2005,    1])

In [15]:
import numpy as np
# Scratch check: position of the last cumulative count that is <= idx + 3.
# The cumulative counts are [3, 6, 9, 13, 17, 18, 19, 20].
idx = 6
# NOTE(review): if this is meant to map a 0-based channel index to its group,
# the +3 offset only lines up for the size-3 groups; e.g. idx = 9 also
# yields 2 -- confirm the intended mapping.
np.where((idx+3) >= np.cumsum([3, 3, 3, 4, 4, 1, 1, 1]))[0][-1]

2

In [2]:
# List the top-level keys of the metadata dict.
# NOTE(review): the recorded output below has no 'cat_cols' key although the
# section 1 cell stores one, so this output appears to predate that change --
# re-run to refresh.
dt_meta.keys()

dict_keys(['CUT_POINT', 'LEVEL_NUM', 'MEAN', 'STD'])