This file is used to do pretraining for a dataset with temporal variability of image data from the global ocean. 
Authors: Zuchuan Li
Date: 09/27/2024

# 1. Meta data

In [1]:
import pickle
import os
import numpy as np
from importlib import reload 
import utils
utils = reload(utils)

cat_cols = ['SST', 'CHL', 'PAR', 
            'U', 'V', 
            'MLD_CLM', 'SAL_CLM', 'SST_CLM',
            ]
date_loc_cols = ['YR', 'DY', 'X', 'Y']

dt_path = "/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27"

# Meta data
f_name = dt_path + '/pretraining_metadata_2024-10-09.pickle'
with open(f_name, 'rb') as fid:
    dt_meta = pickle.load(fid)

# ------------------- #
# Load data for given year and day
# ------------------- #
def load_data_yr_dy(yr, dy):
    name = "/{}.{}.bin_encode.pickle".format(yr, str(dy).rjust(3,"0"))
    assert os.path.exists(dt_path + name)
    with open(dt_path + name, 'rb') as f:
        dt = pickle.load(f)
        dt['YR'] = np.ones(dt['CHL'].shape[-1]) * yr
        dt['DY'] = np.ones(dt['CHL'].shape[-1]) * dy
        dt['Y'], dt['X'] = np.where(dt['MASK'])
        del dt['MASK']
        return dt


# ------------------- #
# Combine data
# ------------------- #
def combine_cat_cols(in_dt):
    return np.concatenate([in_dt[ii] for ii in cat_cols], axis=2)

def combine_date_loc_cols(in_dt):
    return np.concatenate([in_dt[ii][:,None] for ii in date_loc_cols], axis=1)


# ------------------- #
# Combine data
# ------------------- #
def load_dataset_yrs(in_yrs, in_dys, loc_date=False):
    rs = []
    locs = []
    for ii in in_yrs:
        for jj in in_dys:
            in_dt = load_data_yr_dy(ii, jj)
            rs.append(combine_cat_cols(in_dt))
            if loc_date:
                locs.append(combine_date_loc_cols(in_dt))
    rs = np.concatenate(rs, axis=-1)
    if loc_date:
        locs = np.concatenate(locs, axis=0)
    return rs, locs


# 2. Prepare training data

In [2]:
with open(dt_path + '/test.data.pickle', 'rb') as fid:
    dt = pickle.load(fid)


# 3. Training

In [3]:
from dataclasses import dataclass
import datetime
import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from functools import partial

import oxygen_pretraining
oxygen_pretraining = reload(oxygen_pretraining)
from oxygen_pretraining import MaskedAutoencoderViT

# ------------------------ #
# Train the model
# ------------------------ #
def main(args, dataset_train):
    # fix the seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    
    # load data
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, 
        sampler=sampler_train,
        batch_size=args.batch_size,
        drop_last=True,
    )
    
    # the model
    model = args.model.to(args.device)
    
    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    print(optimizer)
    
    start_time = time.time()
    for epoch in range(0, args.epochs):
        loss = simple_train_one_epoch(model, data_loader_train, 
                                      optimizer, args)
        print(epoch, loss)
        
    print('Training time {}'.format(time.time() - start_time))

# ------------------- #
# Training for one epoch
# ------------------- #
def simple_train_one_epoch(model, data_loader, optimizer, arg):
    model.train()
    loss_avg = 0
    for ep, xx in enumerate(data_loader):
        xx = torch.tensor(xx, dtype=torch.int, 
                          device=arg.device, requires_grad=False)
        loss, _, _ = model(xx, arg.mask_ratio)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_avg += loss.detach().cpu().numpy()
    return loss_avg
        

In [4]:
@dataclass
class Parameters:
    batch_size = 64
    epochs = 100
    accum_iter = 1

    # Model parameters
    mask_ratio =0.75

    # Optimizer parameters
    weight_decay = 0.05
    lr = 1e-3
    blr = 1e-3
    min_lr = 1e-7
    warmup_epochs = 5

    # Dataset parameters
    data_path = '/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/'
    output_dir = '/data1/zuchuan/data/Oxygen_Ocean/mae_output'
    log_dir = '/data1/zuchuan/data/Oxygen_Ocean/mae_output'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    seed = 0
    
    def __init__(self, dt_meta=None):
        # Configurate model
        emb_num = 808
        if dt_meta is not None:
            emb_num = 0
            for ii in ['SST', 'CHL', 'PAR', 'U', 'V', 'MLD_CLM', 'SAL_CLM', 'SST_CLM']:
                emb_num += dt_meta['LEVEL_NUM'][ii]
                
        # only SST and CHL
        # NOTE: they are placed at the beginning columns
        mask_chn_idx = torch.zeros((20,), dtype=torch.bool)
        mask_chn_idx[:6] = True
        
        self.model = MaskedAutoencoderViT(img_size=3, in_chans=20, patch_size=1, 
                                          embed_dim=512, depth=6, num_heads=8, 
                                          decoder_embed_dim=256, decoder_depth=2, decoder_num_heads=8, 
                                          mlp_ratio=4, num_embeddings=emb_num,
                                          mask_chn_idx=mask_chn_idx, 
                                          mean=dt_meta['MEAN'], std=dt_meta['STD'],
                                          norm_layer=partial(nn.LayerNorm, eps=1e-6))
        
args = Parameters(dt_meta)


In [6]:
main(args, dt)


AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)


  xx = torch.tensor(xx, dtype=torch.int,


0 261.1647734137101
1 26.030115758942298
2 23.061196158581513
3 17.58628619799727
4 12.822234465323035
5 11.947737373937269
6 11.786150262828194
7 11.448473888011113
8 11.781023103006437
9 11.251271553497588
10 11.08070163442602
11 10.98065305602715
12 10.905316096278263
13 10.8471296662917
14 10.618238786394796
15 10.62891915941448
16 10.625722298382389
17 10.80060153540156
18 10.440718603863726
19 10.448699724329279
20 10.445506787850496
21 10.427569509248366
22 10.6425760075915
23 10.396765563949879
24 10.197957396665466
25 10.281220195479909
26 10.26829925968168
27 10.428822387204963
28 10.24322594958143
29 10.397941488707213
30 10.330613991519074
31 10.165562910522766
32 10.281298672833218
33 10.119751423282812
34 10.098191053078718
35 10.19321313412734
36 10.073639542279635
37 10.238856883620732
38 10.27610493315246
39 10.30834558209197
40 10.099747227785649
41 10.233325354062345
42 10.248079160422373
43 10.107344543488866
44 10.24648565603638
45 10.199357105726497
46 10.13432257

In [8]:
import torchvision.datasets as datasets
dataset_train = datasets.ImageFolder('/data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/')

FileNotFoundError: Couldn't find any class folder in /data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27/.

In [11]:
def load_dt(in_name):
    import pickle
    with open(in_name, 'rb') as ff:
        ii = pickle.load(ff)
    return ii
tmp = os.path.join('/data1/zuchuan/data/Oxygen_Ocean', 'pretraining_data_2024-09-27')
datasets.DatasetFolder(tmp, load_dt, extensions=('bin_encode.pickle',), )

FileNotFoundError: Couldn't find any class folder in /data1/zuchuan/data/Oxygen_Ocean/pretraining_data_2024-09-27.

In [30]:
utils = reload(utils)
from utils import SatDataFolder, sat_loader_pickle
import os

tmp = SatDataFolder('/data1/zuchuan/data/Oxygen_Ocean/processed_data/', sat_loader_pickle,
                    extensions=('.pickle',))

In [31]:
for i, (xx,id) in enumerate(tmp):
    print(xx)
    if i == 0:
        break

[[[[100 100 100]
   [100 100 100]
   [100 100 100]]

  [[100 100 100]
   [ 15  15  15]
   [ 14  14  15]]

  [[ 14  14  13]
   [ 14  14  14]
   [ 15  15  15]]

  ...

  [[591 591 591]
   [591 591 591]
   [592 592 592]]

  [[670 670 670]
   [670 670 670]
   [670 670 670]]

  [[726 727 727]
   [726 726 727]
   [727 726 726]]]


 [[[100 100 100]
   [100 100 100]
   [100 100 100]]

  [[ 15  15  15]
   [ 14  14  15]
   [100 100 100]]

  [[ 14  14  14]
   [ 15  15  15]
   [ 15  15  15]]

  ...

  [[591 591 591]
   [592 592 592]
   [592 592 592]]

  [[670 670 670]
   [670 670 670]
   [670 670 670]]

  [[726 726 727]
   [727 726 726]
   [727 727 726]]]


 [[[100 100 100]
   [100 100 100]
   [100 100 100]]

  [[ 15  15  15]
   [ 15  15  15]
   [ 15  15  15]]

  [[ 14  14  14]
   [ 15  15  15]
   [ 15  15  15]]

  ...

  [[591 591 591]
   [592 592 592]
   [592 592 592]]

  [[670 670 670]
   [670 670 670]
   [670 670 670]]

  [[727 727 727]
   [726 727 726]
   [725 725 725]]]


 ...


 [[[100 100 

In [33]:
xx.shape

(495396, 20, 3, 3)

In [34]:
t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                    [7, 8]]])
t

tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

In [38]:
torch.flatten(t, start_dim=1, end_dim=2)

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

In [43]:
import os
os.mkdir('/data0/Zuchuan')

PermissionError: [Errno 13] Permission denied: '/data0/Zuchuan'

In [17]:
from mae_main.util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid
from utils import _encode_spacetime2, encode_pos2d_chn1d

get_2d_sincos_pos_embed(16, 3, cls_token=True)

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00],
       [ 8.41471016e-01,  9.98334214e-02,  9.99983307e-03,
         9.99999931e-04,  5.40302277e-01,  9.95004177e-01,
         9.99949992e-01,  9.99999523e-01,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00],
       [ 9.09297407e-01,  1.98669329e

In [15]:
get_1d_sincos_pos_embed_from_grid(4, np.arange(4))

array([[ 0.        ,  0.        ,  1.        ,  1.        ],
       [ 0.84147098,  0.00999983,  0.54030231,  0.99995   ],
       [ 0.90929743,  0.01999867, -0.41614684,  0.99980001],
       [ 0.14112001,  0.0299955 , -0.9899925 ,  0.99955003]])

In [16]:
_encode_spacetime2(4,4)

array([[ 0.        ,  1.        ,  0.        ,  1.        ],
       [ 0.84147098,  0.54030231,  0.00999983,  0.99995   ],
       [ 0.90929743, -0.41614684,  0.01999867,  0.99980001],
       [ 0.14112001, -0.9899925 ,  0.0299955 ,  0.99955003]])

In [32]:
encode_pos2d_chn1d(8, 2, 2, cls_token=False)

tensor([[[0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 1.0000],
         [0.0000, 1.0000, 0.0000, 1.0000, 0.8415, 0.0100, 0.5403, 0.9999],
         [0.8415, 0.5403, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 1.0000],
         [0.8415, 0.5403, 0.0000, 1.0000, 0.8415, 0.0100, 0.5403, 0.9999],
         [0.0000, 1.0000, 0.8415, 0.5403, 0.0000, 0.0000, 1.0000, 1.0000],
         [0.0000, 1.0000, 0.8415, 0.5403, 0.8415, 0.0100, 0.5403, 0.9999],
         [0.8415, 0.5403, 0.8415, 0.5403, 0.0000, 0.0000, 1.0000, 1.0000],
         [0.8415, 0.5403, 0.8415, 0.5403, 0.8415, 0.0100, 0.5403, 0.9999]]])

In [29]:
from importlib import reload 
utils = reload(utils)
from utils import encode_pos2d_chn1d

In [30]:
tmp = encode_pos2d_chn1d(1024, 3, 20, cls_token=True)

In [31]:
tmp[0]

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.0000, 0.0000, 0.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9093, 0.9581, 0.9870,  ..., 1.0000, 1.0000, 1.0000],
        [0.9093, 0.9581, 0.9870,  ..., 1.0000, 1.0000, 1.0000],
        [0.9093, 0.9581, 0.9870,  ..., 1.0000, 1.0000, 1.0000]])

In [58]:
for i in range(tmp.shape[1]):
    for j in range(i+1, tmp.shape[1]):
        count = (np.abs(tmp[0,i,:] - tmp[0,j,:]) > 0.1).sum()
        if count < 20:
            print(i,j)