In [1]:
import itertools
import json
import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim

from helpers.cm26 import DatasetCM26, operator_Kochkov
from helpers.computational_tools import select_NA, select_Pacific, select_Cem, select_Equator, StateFunctions, compare
from helpers.ann_tools import ANN, import_ANN, minibatch, export_ANN
%load_ext autoreload
%autoreload 3

In [2]:
# Dataset object; the training cell below calls its `sample_epoch(time=..., operator=...)`
# method to draw per-factor training and testing batches.
ds = DatasetCM26()

In [3]:
# Fine-tune starting from the epoch-2000 checkpoints stored in trained_models/ANN_CM26_Kochkov.
# To train from scratch instead, build the networks directly:
# ANN([27, 20, 1]) for Txy and ANN([27, 20, 2]) for Txx/Tyy.
ann_Txy = import_ANN('trained_models/ANN_CM26_Kochkov/Txy_epoch_2000.nc')
ann_Txx_Tyy = import_ANN('trained_models/ANN_CM26_Kochkov/Txx_Tyy_epoch_2000.nc')

# Loss history keyed by 'epoch-{n}-factor-{f}'; filled by the training loop and dumped as JSON.
log_dict = dict()

In [4]:
# Fine-tune both ANNs with rotation/reflection data augmentation, then periodically
# export the networks and the loss history (`log_dict`) to `checkpoint_dir`.
num_epochs = 100
save_every = 50                 # checkpoint/log frequency in epochs; the final epoch is always saved too
checkpoint_dir = 'trained_models/ANN_CM26_Kochkov_augmented'
factors = [2, 4, 6, 10, 20]     # keys into the per-epoch datasets returned by ds.sample_epoch
n_train_times = 6950            # time indices [0, 6950): approximately 19 years for training
test_time_range = (6950, 8035)  # time indices [6950, 8035): approximately 3 years for testing
n_inputs = 27                   # input width of both ANNs (cf. ANN([27, 20, ...]) in the setup cell) -- TODO confirm
seed = None                     # set to an int for a reproducible sequence of sampled snapshots

# Create the output directory before training: otherwise the first export, which only
# happens after `save_every` epochs, would crash if the directory does not exist.
os.makedirs(checkpoint_dir, exist_ok=True)

if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)


def _normalized_sgs(batch):
    """Return (SGSx, SGSy, SGS_norm) for `batch`, with SGSx/SGSy already multiplied by SGS_norm.

    SGS_norm = 1 / sqrt(mean(SGSx**2 + SGSy**2)), so the scaled targets have unit
    mean-square magnitude and the losses of different factors are comparable.
    """
    SGSx = torch.tensor(batch.data.SGSx.values).type(torch.float32)
    SGSy = torch.tensor(batch.data.SGSy.values).type(torch.float32)
    SGS_norm = 1. / torch.sqrt((SGSx**2 + SGSy**2).mean())
    return SGS_norm * SGSx, SGS_norm * SGSy, SGS_norm


def _normalized_mse(prediction, SGSx, SGSy, SGS_norm):
    """MSE between the ANN prediction (scaled by SGS_norm) and SGS targets already scaled by SGS_norm."""
    ANNx = prediction['ZB20u'] * SGS_norm
    ANNy = prediction['ZB20v'] * SGS_norm
    return ((ANNx - SGSx)**2 + (ANNy - SGSy)**2).mean()


all_parameters = list(ann_Txy.parameters()) + list(ann_Txx_Tyy.parameters())
optimizer = optim.Adam(all_parameters, lr=1e-3)
# Decay the learning rate by 10x at 1/2, 3/4 and 7/8 of the run.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
        milestones=[int(num_epochs/2), int(num_epochs*3/4), int(num_epochs*7/8)], gamma=0.1)

# (rotation, reflect_x, reflect_y) combinations. itertools.product varies the last element
# fastest, so the final combination is (0, False, False): the un-augmented model. Hence the
# reported MSE_train is the loss of the last, un-augmented step.
augmentations = list(itertools.product([90, 0], [True, False], [True, False]))

t_s = time()

for epoch in range(num_epochs):
    t_e = time()

    training_dataset = ds.sample_epoch(time=np.random.randint(n_train_times), operator=operator_Kochkov)
    testing_dataset = ds.sample_epoch(time=np.random.randint(*test_time_range), operator=operator_Kochkov)
    print(f'---------- Epoch {epoch} ------------\n')

    for factor in factors:
        ############# Training step ################
        batch = training_dataset[factor]
        SGSx, SGSy, SGS_norm = _normalized_sgs(batch)

        for rotation, reflect_x, reflect_y in augmentations:
            optimizer.zero_grad()
            prediction = batch.state.Apply_ANN(ann_Txy, ann_Txx_Tyy,
                rotation=rotation, reflect_x=reflect_x, reflect_y=reflect_y)
            MSE_train = _normalized_mse(prediction, SGSx, SGSy, SGS_norm)
            MSE_train.backward()
            optimizer.step()

        ############ Testing step ##################
        with torch.no_grad():
            batch = testing_dataset[factor]
            prediction = batch.state.Apply_ANN(ann_Txy, ann_Txx_Tyy)
            MSE_test = _normalized_mse(prediction, *_normalized_sgs(batch))

        mse_train, mse_test = MSE_train.item(), MSE_test.item()
        print(f'Factor: {factor}. MSE train/test: [{mse_train:.6f}, {mse_test:.6f}]')
        ########### Saving history of losses ############
        log_dict[f'epoch-{epoch+1}-factor-{factor}'] = dict(MSE_train=mse_train, MSE_test=mse_test)
        ######## Freeing memory ############
        del training_dataset[factor].data
        del training_dataset[factor].param
        del testing_dataset[factor].data
        del testing_dataset[factor].param

    t = time()
    print('Epoch time/Remaining time in seconds: [%d/%d]' % (t-t_e, (t-t_s)*(num_epochs/(epoch+1)-1)))
    scheduler.step()

    # Save periodically, and always after the final epoch so that the trained model is
    # never lost when num_epochs is not a multiple of save_every.
    if (epoch + 1) % save_every == 0 or (epoch + 1) == num_epochs:
        export_ANN(ann_Txy, input_norms=torch.ones(n_inputs), output_norms=torch.ones(1),
           filename=f'{checkpoint_dir}/Txy_epoch_{epoch+1}.nc')
        export_ANN(ann_Txx_Tyy, input_norms=torch.ones(n_inputs), output_norms=torch.ones(2),
           filename=f'{checkpoint_dir}/Txx_Tyy_epoch_{epoch+1}.nc')
        with open(f'{checkpoint_dir}/log_dict_epoch_{epoch+1}', 'w') as file:
            json.dump(log_dict, file)

---------- Epoch 0 ------------

Factor: 2. MSE train/test: [0.584559, 0.597894]
Factor: 4. MSE train/test: [0.572371, 0.607747]
Factor: 6. MSE train/test: [0.647255, 0.647357]
Factor: 10. MSE train/test: [0.814737, 0.798697]
Factor: 20. MSE train/test: [0.973381, 0.982239]
Epoch time/Remaining time in seconds: [27/2676]
---------- Epoch 1 ------------

Factor: 2. MSE train/test: [0.661409, 0.636738]
Factor: 4. MSE train/test: [0.635682, 0.593563]
Factor: 6. MSE train/test: [0.710839, 0.657809]
Factor: 10. MSE train/test: [0.815673, 0.824961]
Factor: 20. MSE train/test: [0.960262, 0.973040]
Epoch time/Remaining time in seconds: [26/2599]
---------- Epoch 2 ------------

Factor: 2. MSE train/test: [0.707960, 0.672557]
Factor: 4. MSE train/test: [0.660206, 0.635050]
Factor: 6. MSE train/test: [0.676754, 0.708666]
Factor: 10. MSE train/test: [0.809021, 0.796882]
Factor: 20. MSE train/test: [0.974526, 0.972888]
Epoch time/Remaining time in seconds: [27/2590]
---------- Epoch 3 ------------