In [1]:
import numpy as np
import torch
from torch import nn  
import time 
import os  
import xarray as xr
import subprocess
import matplotlib.pyplot as plt
from matplotlib import colors
import cartopy.crs as ccrs

In [2]:
## import self defined functions
from sys import path 
# insert at 1, 0 is the script path (or '' in REPL)
path.insert(1, '/tigress/cw55/local/python_lib')
from cg_funcs import global_mean_xarray
root = '/tigress/cw55/work/2022_radi_nn/NN_AM4'
path.insert(1,  root+'/work')
# import work.lw_csaf_Li5Relu_EN.train_script01 as lwcsafen
import work.lw_csaf_Li5Relu_EY.train_script01 as lwcsafey 
# import work.lw_af_Li5Relu_EN.train_script01 as lwafen
import work.lw_af_Li5Relu_EY.train_script01 as lwafey 

In [3]:
from get_AM4_data_lw import get_AM4_data_lw
######################################################
# common functions to split the training and test data
from NNRTMC_lw_utils import  split_train_test_sample, \
draw_batches, data_std_normalization, print_key_results, return_exp_dir
    
from diag_utils import batch_index_sta_end, pred_NN_batch,\
create_6tiles_lw,regrid_6tile2latlon

In [4]:
if __name__ == '__main__':
    # select gpu_id; 0 means the first visible GPU
    torch.cuda.set_device(0)
    device = f'cuda:{torch.cuda.current_device()}'
    ######################################################
    # Experiments to evaluate. To include the *_EN runs, uncomment them here,
    # in the two tables below, and their imports in the import cell.
    Exp_name = [
        # 'lw_csaf_Li5Relu_EN',
        'lw_csaf_Li5Relu_EY',
        # 'lw_af_Li5Relu_EN'  ,
        'lw_af_Li5Relu_EY'  ,
    ]
    # experiment name -> imported training module that provides NNRTMC_NN
    Exp_name_model_dict = {
        # 'lw_csaf_Li5Relu_EN': lwcsafen,
        'lw_csaf_Li5Relu_EY': lwcsafey,
        # 'lw_af_Li5Relu_EN'  : lwafen,
        'lw_af_Li5Relu_EY'  : lwafey,
    }
    # experiment name -> `condition` argument for get_AM4_data_lw
    sky_cond = {
        # 'lw_csaf_Li5Relu_EN': 'csaf',
        'lw_csaf_Li5Relu_EY': 'csaf',
        # 'lw_af_Li5Relu_EN'  : 'af',
        'lw_af_Li5Relu_EY'  : 'af',
    }
    work_dir = root+'/work/'
    ######################################################
    # AM4 run files: one fluxes file and one offline-input file per tile (tile1 .. tile6).
    # Set n_tiles = 6 to read every tile; only tile 1 is read here for a quick check.
    data_dir = '/scratch/gpfs/cw55/NNRTMC_data/AM4_v2'
    n_tiles = 1
    tile_ids = range(1, n_tiles + 1)
    out_filelist = [f'{data_dir}/20000101.fluxes.tile{_}.nc' for _ in tile_ids]
    inp_filelist = [f'{data_dir}/20000101.new_offline_input.tile{_}.nc' for _ in tile_ids]

    # Hybrid sigma-pressure coefficients ak / bk. `[None,:]` adds a leading axis,
    # presumably so they broadcast over a batch axis inside NNRTMC_NN (TODO confirm).
    hybrid_p_sigma_para = xr.open_dataset('/tigress/cw55/data/NNRTMC_dataset/AM4_pk_bk_202207.nc')
    A_k = hybrid_p_sigma_para.ak.values[None,:]
    B_k = hybrid_p_sigma_para.bk.values[None,:]

In [5]:
# use a small default font for all figures in this notebook
plt.rcParams.update({'font.size': '6'})

In [6]:
# Time subset passed to get_AM4_data_lw (month_sel / day_sel).
# Larger-sample alternative: month_sel = None; day_sel = [15, 18, 21, 24, 27]
# (None presumably means "no month filter" — confirm in get_AM4_data_lw).
month_sel = [1]
day_sel = [15]

In [7]:
%%time
# Restore each experiment's saved model and run it on the selected AM4 data.
# Results are keyed by experiment name:
#   predi[mo]    - NN prediction; columns 3: are converted from K/s to K/day below
#   error[mo]    - predi[mo] minus the AM4 reference output_array_ori (same unit conversion)
#   eng_err[mo]  - second value returned by pred_NN_batch (not used in the cells shown)
#   NN_model[mo] - the restored NNRTMC_NN instance
# ds_regrid is created but not filled in this cell.
predi = {}
error = {}
eng_err = {}
NN_model = {}
ds_regrid = {}

for mo in Exp_name:
    ######################################################
    # load restart file
    # NOTE(review): assumes return_exp_dir returns the next run number, so run_num-1
    # is the latest saved restart — confirm.
    run_num, exp_dir = return_exp_dir(work_dir, mo, create_dir=False)
    PATH_last =  exp_dir+f'/restart.{run_num-1:02d}.pth'
    # torch.load unpickles the file, so only load trusted checkpoints. No map_location is
    # given, so tensors are restored onto the device they were saved from.
    restart_data = torch.load(PATH_last)  # load the saved checkpoint (no training is resumed here)
    print(f'load: {PATH_last}')
    # read normalization parameters and model weights from the checkpoint
    nor_para = restart_data['nor_para']
    model_state_dict = restart_data['model_state_dict']
    # read data (un-normalized inputs/outputs, plus coordinates in ds_coords)
    input_array_ori, output_array_ori, ds_coords = \
    get_AM4_data_lw(out_filelist, inp_filelist, condition=sky_cond[mo], 
                    month_sel = month_sel, day_sel = day_sel, return_coords=True) 
    # initialize model 
    NN_model[mo] = Exp_name_model_dict[mo].NNRTMC_NN(device, nor_para, A_k, B_k, input_array_ori.shape[1],model_state_dict)  
 
    # normalize data via saved nor_para in restart file
    # NOTE(review): nor_para is reassigned from the return value; presumably unchanged
    # when an existing nor_para is passed in — confirm in data_std_normalization.
    nor_para, input_array, output_array   = data_std_normalization(input_array_ori, output_array_ori, nor_para)
    
    # run the NN on the selected data (not a held-out split in this cell)  
    predi[mo], eng_err[mo] = pred_NN_batch(input_array, output_array, NN_model[mo], nor_para, device)
    # the error is taken against the un-normalized reference before the unit conversion,
    # so predi[mo] is presumably already de-normalized by pred_NN_batch
    error[mo] = predi[mo] - output_array_ori
    predi[mo][:,3:] = predi[mo][:,3:]*86400 # HR K/s >> K/day
    error[mo][:,3:] = error[mo][:,3:]*86400 # HR K/s >> K/day
    # NOTE(review): this break stops after the first experiment (Exp_name[0]). Later cells
    # rely on mo, input_array, output_array and model_state_dict left over from that
    # iteration. Remove it to evaluate every experiment.
    break

load: /tigress/cw55/work/2022_radi_nn/NN_AM4/work/lw_csaf_Li5Relu_EY/restart.04.pth
Data files:
['/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.fluxes.tile1.nc'] ['/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.new_offline_input.tile1.nc']
Data selection:
    Month: [1] 
    Day: [15] 
Reading data... 0 Done.
Total data size: 73728
CPU times: user 1.63 s, sys: 556 ms, total: 2.19 s
Wall time: 1.66 s


# Process the NN state dict: fold BatchNorm into the linear layers and save the parameters

In [8]:
# Parameter names stored in the checkpoint, in order. regroup_linear_BN_para relies on this
# positional order: Linear weight/bias, then BatchNorm weight/bias/running stats.
model_state_dict.keys()

dict_keys(['Res_stack.0.weight', 'Res_stack.0.bias', 'Res_stack.1.weight', 'Res_stack.1.bias', 'Res_stack.1.running_mean', 'Res_stack.1.running_var', 'Res_stack.1.num_batches_tracked', 'Res_stack.3.weight', 'Res_stack.3.bias', 'Res_stack.4.weight', 'Res_stack.4.bias', 'Res_stack.4.running_mean', 'Res_stack.4.running_var', 'Res_stack.4.num_batches_tracked', 'Res_stack.6.weight', 'Res_stack.6.bias', 'Res_stack.7.weight', 'Res_stack.7.bias', 'Res_stack.7.running_mean', 'Res_stack.7.running_var', 'Res_stack.7.num_batches_tracked', 'Res_stack.9.weight', 'Res_stack.9.bias', 'Res_stack.10.weight', 'Res_stack.10.bias', 'Res_stack.10.running_mean', 'Res_stack.10.running_var', 'Res_stack.10.num_batches_tracked', 'Res_stack.12.weight', 'Res_stack.12.bias'])

In [9]:

# a specific implementation for Li5ReluBN
# (Li BN ReLU ) *4 Li
def regroup_linear_BN_para(ori_NN_parameters, n_hidden=4, eps=1e-5):
    """Fold every BatchNorm layer into the Linear layer that precedes it.

    ``ori_NN_parameters`` is the flat, ordered list of state-dict arrays of a
    network laid out as ``(Linear -> BatchNorm -> ReLU) * n_hidden -> Linear``.
    Each hidden block contributes 7 consecutive entries::

        [W, b, gamma, beta, running_mean, running_var, num_batches_tracked]

    and the list ends with the output layer's ``[W, b]``.

    With running statistics, a hidden block computes
    ``gamma * (W x + b - mean) / sqrt(var + eps) + beta``, which equals
    ``W' x + b'`` where::

        scale = gamma / sqrt(var + eps)
        W'    = scale[:, None] * W
        b'    = (b - mean) * scale + beta

    Parameters
    ----------
    ori_NN_parameters : sequence of numpy arrays, length ``7 * n_hidden + 2``
    n_hidden : int, default 4
        Number of Linear+BatchNorm blocks (4 for Li5ReluBN).
    eps : float, default 1e-5
        BatchNorm epsilon. 1e-5 is PyTorch's BatchNorm default; it must match
        the epsilon the model was built with.

    Returns
    -------
    list
        ``[W1', b1', ..., Wn', bn', W_out, b_out]`` (``2 * n_hidden + 2`` arrays).
        The output-layer arrays are the same objects as in the input.

    Raises
    ------
    ValueError
        If the input length is not ``7 * n_hidden + 2``. A mismatch means the
        layer layout is not what this function assumes; without the check,
        surplus entries would be silently ignored.
    """
    expected_len = 7 * n_hidden + 2
    if len(ori_NN_parameters) != expected_len:
        raise ValueError(
            f'expected {expected_len} parameter arrays for n_hidden={n_hidden}, '
            f'got {len(ori_NN_parameters)}'
        )
    new_NN_parameters = []
    for start in range(0, 7 * n_hidden, 7):
        # entry start+6 is num_batches_tracked, which is not needed at inference
        W, b, gamma, beta, mean, var = ori_NN_parameters[start:start + 6]
        scale = gamma / np.sqrt(var + eps)
        new_NN_parameters.append(scale[:, None] * W)
        new_NN_parameters.append((b - mean) * scale + beta)
    # output Linear layer has no BatchNorm: pass its weight and bias through unchanged
    new_NN_parameters.extend(ori_NN_parameters[-2:])
    return new_NN_parameters


In [10]:

# Flatten the checkpoint into numpy arrays (state-dict order), then fold each BatchNorm
# into its preceding Linear layer, giving [W1, b1, ..., W5, b5].
# .cpu() keeps this working if torch.load restored the tensors onto a GPU
# (no map_location is passed there); it is a no-op for tensors already on the CPU.
raw_NN_parameters = [tensor.detach().cpu().numpy() for tensor in model_state_dict.values()]
a_list_of_parameters = regroup_linear_BN_para(raw_NN_parameters)

# Prototype NumPy forward pass (reference for the implementation)

In [11]:
def Rad_NN_activation_function(x):
    """Scalar ReLU used by the prototype forward pass.

    Returns ``x`` itself when it is strictly positive. Otherwise (zero,
    negative values, and NaN, since ``NaN > 0`` is False) returns ``np.float32(0)``.
    """
    return x if x > 0 else np.float32(0)
# NumPy prototype of the Li5ReluBN forward pass
def Rad_NN_pred(NN_parameters, input_X):
    """Run the network forward with plain NumPy.

    The network is a stack of Linear layers with a ReLU after every layer
    except the last. BatchNorm must already be folded into the Linear weights
    (as done by ``regroup_linear_BN_para``), so each layer is ``y = W x + b``.

    Parameters
    ----------
    NN_parameters : sequence of arrays
        ``[W1, b1, W2, b2, ..., Wn, bn]`` with ``Wk`` of shape (n_out, n_in)
        and ``bk`` of shape (n_out,). For Li5ReluBN this is 10 arrays
        (4 hidden layers + output layer); other depths are accepted.
    input_X : array, shape (n_samples, n_in) or (n_in,)

    Returns
    -------
    array of shape (n_samples, n_out), or (n_out,) for 1-D input. The dtype
    follows NumPy promotion of the inputs (float32 stays float32, float64 stays float64).

    Raises
    ------
    ValueError
        If NN_parameters is not a non-empty list of (W, b) pairs.
    """
    if len(NN_parameters) < 2 or len(NN_parameters) % 2:
        raise ValueError(
            f'NN_parameters must hold (W, b) pairs, got {len(NN_parameters)} arrays'
        )
    weights = NN_parameters[0::2]
    biases = NN_parameters[1::2]
    # Row-major form of y = W x + b: each row of `hidden` is one sample.
    # This also makes a single 1-D sample work.
    hidden = np.asarray(input_X)
    for W, b in zip(weights[:-1], biases[:-1]):
        hidden = hidden @ W.T + b
        # Array form of Rad_NN_activation_function: keep values > 0, set the rest
        # (including NaN) to 0. Unlike np.vectorize, this keeps hidden's dtype
        # instead of inferring it from the first element's result.
        hidden = np.where(hidden > 0, hidden, hidden.dtype.type(0))
    # output layer: linear, no activation
    return hidden @ weights[-1].T + biases[-1]

In [12]:
print(input_array.shape, output_array.shape)
# compare the PyTorch model and the prototype on the first n_check samples only
n_check = 300
input_X = input_array[:n_check, :]
output_Y = output_array[:n_check, :]

(73728, 102) (73728, 36)


In [13]:
# results from pytorch
NN_pred1 = NN_model[mo].predict(torch.tensor(input_X).to(device)).cpu().numpy()
# results from prototype function
NN_pred2 = Rad_NN_pred(a_list_of_parameters, input_X)

In [15]:
# Per-output-column relative difference between the PyTorch model and the prototype:
# mean|a - b| / mean|a + b|  (about half the usual relative error when a ≈ b).
rel_diff = np.abs(NN_pred1 - NN_pred2).mean(axis=0) / np.abs(NN_pred1 + NN_pred2).mean(axis=0)
# Float32 round-off gives ~1e-6 here. A real mismatch (e.g. wrong BatchNorm folding)
# is expected to be far larger. nanmax skips 0/0 columns (both predictions exactly zero).
assert np.nanmax(rel_diff) < 1e-4, \
    f'prototype disagrees with PyTorch model: max rel. diff {np.nanmax(rel_diff):.3g}'
rel_diff

array([3.8300288e-07, 6.3589977e-07, 5.5322863e-07, 2.2787182e-07,
       3.0507857e-07, 4.9270773e-07, 4.1362333e-07, 3.4659183e-07,
       3.5326934e-07, 5.4356229e-07, 7.1080552e-07, 1.4992122e-06,
       7.1352912e-07, 5.2024774e-07, 8.4973902e-07, 5.2025030e-07,
       4.2171934e-07, 4.7106175e-07, 5.8741665e-07, 5.4361800e-07,
       5.3361902e-07, 6.8633005e-07, 5.8288930e-07, 4.3927122e-07,
       4.8868583e-07, 5.3402482e-07, 4.9524471e-07, 4.1708316e-07,
       5.1602080e-07, 3.7395199e-07, 5.6597230e-07, 6.4413018e-07,
       8.6686998e-07, 7.1540325e-07, 7.7731954e-07, 8.2777575e-07],
      dtype=float32)