In [1]:
import numpy as np
import torch
from torch import nn  
import time 
import os  
import xarray as xr
import subprocess
import matplotlib.pyplot as plt
from matplotlib import colors
import cartopy.crs as ccrs

In [2]:
## import self defined functions
from sys import path 
# Make the author's helper library importable. Insert at index 1, not 0:
# index 0 is the script path (or '' in a REPL) and should keep priority.
path.insert(1, '/tigress/cw55/local/python_lib')
from cg_funcs import global_mean_xarray
# Project root on the cluster file system; used below to build work_dir.
root = '/tigress/cw55/work/2022_radi_nn' 

In [3]:
# Project-specific readers for the AM4 shortwave (sw) / longwave (lw) datasets.
from get_data_sw_AM4_std import get_data_sw_AM4
from get_data_lw_AM4_std import get_data_lw_AM4
######################################################
# Shared NNRTMC helpers: sw/lw model classes, train/test splitting, batching,
# data normalization and experiment-directory lookup.
from NNRTMC_utils import NNRTMC_NN_sw, NNRTMC_NN_lw, split_train_test_sample, \
draw_batches, data_std_normalization_sw, data_std_normalization_lw,  return_exp_dir
# Allow PyTorch to pick faster float32 matmul kernels with reduced internal
# precision where the backend supports them (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('high') 

In [4]:
# Small default font so multi-panel figures stay readable.
plt.rcParams.update({'font.size': '6'})

In [5]:
# Compute device: CPU. To run on a GPU instead, select it explicitly, e.g.
#   torch.cuda.set_device(0)   # gpu_id, 0 = first GPU
#   device = f'cuda:{torch.cuda.current_device()}'
device = 'cpu'
######################################################
# Experiments to export: one per band (lw / sw) and sky condition
# ('cs' / 'all', presumably clear-sky / all-sky).
Exp_name = [
    'ens_AM4std_lw_cs_LiH4W256Relu_EY',
    'ens_AM4std_lw_all_LiH4W256Relu_EY',
    'ens_AM4std_sw_cs_LiH4W256Relu_EY',
    'ens_AM4std_sw_all_LiH4W256Relu_EY',
]
# Sky condition handed to the data readers, in the same order as Exp_name.
sky_cond = dict(zip(Exp_name, ['cs', 'all', 'cs', 'all']))
# Ensemble member to export per experiment (chosen from std_[l/s]w_analysis).
ei = dict(zip(Exp_name, [0, 0, 1, 0]))

work_dir = f'{root}/work/'
# AM4 history file(s) used for the consistency check (tile 1 only).
filelist = [
    '/scratch/gpfs/cw55/AM4/work/CTL2000_train_y2000_stellarcpu_intelmpi_22_768PE/'
    f'HISTORY/20000101.atmos_8xdaily.tile{tile}.nc'
    for tile in range(1, 2)
]
# Hybrid pressure/sigma vertical-coordinate coefficients (ak, bk), stored as
# row vectors with a leading axis of length 1 (assumes ak/bk are 1-D).
hybrid_p_sigma_para = xr.open_dataset('/tigress/cw55/data/NNRTMC_dataset/AM4_pk_bk_202207.nc')
A_k = hybrid_p_sigma_para['ak'].values[None, :]
B_k = hybrid_p_sigma_para['bk'].values[None, :]

In [6]:
# Subset of the AM4 history read for the export check.
# Larger-sample alternative (previously assigned here and immediately
# overwritten):  month_sel = None;  day_sel = [15, 18, 21, 24, 27]
month_sel = [1]
day_sel = [1]

In [7]:

# Specific to the "(Linear -> BatchNorm -> ReLU) x N -> Linear" networks
# (N = 4 for the LiH4 models): fold normalization and BatchNorm into Linear layers.
def regroup_linear_BN_para(ori_NN_parameters, nor_para, num_hidden_layers=4, bn_eps=1e-5):
    """Fold input/output normalization and BatchNorm into plain Linear layers.

    The returned network evaluates as ``(W x + b -> ReLU) x N -> W x + b`` on
    un-normalized inputs and produces un-normalized outputs.

    Parameters
    ----------
    ori_NN_parameters : sequence of np.ndarray
        Flat list of parameters. Each hidden layer is expected to contribute 7
        consecutive entries: Linear weight (out, in), Linear bias, BN scale
        (gamma), BN shift (beta), running mean, running variance and one
        unused entry (presumably ``num_batches_tracked``). These are followed
        by the output Linear weight and bias.
    nor_para : dict
        Normalization arrays. The folding corresponds to
        ``x_n = (x - input_offset) * input_scale`` on input and
        ``y = y_n / output_scale + output_offset`` on output.
    num_hidden_layers : int, optional
        Number of Linear/BN blocks before the output layer (default 4).
    bn_eps : float, optional
        Epsilon added to the running variance (default 1e-5).

    Returns
    -------
    list of np.ndarray
        ``[W1, b1, ..., WN, bN, Wout, bout]`` with the normalization folded in.

    Raises
    ------
    ValueError
        If ``num_hidden_layers`` < 1 (the input normalization is folded into
        the first hidden layer).
    """
    if num_hidden_layers < 1:
        raise ValueError('num_hidden_layers must be >= 1')
    new_NN_parameters = []
    para_ind = 0
    for layer in range(num_hidden_layers):
        weight = ori_NN_parameters[para_ind]
        bias   = ori_NN_parameters[para_ind+1]
        gamma, beta, running_mean, running_var = ori_NN_parameters[para_ind+2:para_ind+6]
        if layer == 0:
            # W((x - o) * s) + b == (s * W) x + (b - W (s * o)), with s applied per input column.
            bias   = bias - weight@(nor_para['input_scale']*nor_para['input_offset'])
            weight = nor_para['input_scale']*weight
        # BN(z) = (z - mean) * gamma / sqrt(var + eps) + beta, folded row-wise into (W, b).
        bn_scale = gamma/np.sqrt(running_var + bn_eps)
        new_NN_parameters.append(bn_scale[:,None]*weight)
        new_NN_parameters.append((bias-running_mean)*bn_scale+beta)
        # skip the 7th (unused) entry of this block
        para_ind = para_ind+7

    # Output layer: undo the output normalization, y = y_n / scale + offset.
    weight = ori_NN_parameters[para_ind]
    bias   = ori_NN_parameters[para_ind+1]
    new_NN_parameters.append(weight/nor_para['output_scale'][:,None])
    new_NN_parameters.append(bias/nor_para['output_scale']+nor_para['output_offset'])
    return new_NN_parameters

def save_fnn_parameters(a_list_of_parameters):
    """Pack a flat ``[W1, b1, ..., WL, bL]`` parameter list into an xarray Dataset.

    The dataset holds:
    * ``LN`` -- the number of (weight, bias) pairs;
    * ``W{i}`` with dims ``(x{i}, y{i})`` and ``B{i}`` with dim ``x{i}``;
    * ``size{i}0`` / ``size{i}1`` -- the two dimensions of ``W{i}``.
    Layers are numbered from 1.

    Raises
    ------
    ValueError
        If the list length is odd, i.e. it is not a sequence of (weight, bias) pairs.
    """
    if len(a_list_of_parameters) % 2 != 0:
        raise ValueError('a_list_of_parameters must hold (weight, bias) pairs; '
                         f'got an odd length {len(a_list_of_parameters)}')
    num_layer = len(a_list_of_parameters) // 2
    ds_nn_save = xr.Dataset()
    ds_nn_save['LN'] = num_layer
    for i in range(num_layer):
        weight = a_list_of_parameters[2*i]
        bias = a_list_of_parameters[2*i+1]
        ds_nn_save[f'W{i+1}'] = ((f'x{i+1}', f'y{i+1}'), weight)
        ds_nn_save[f'B{i+1}'] = (f'x{i+1}', bias)
        ds_nn_save[f'size{i+1}0'] = weight.shape[0]
        ds_nn_save[f'size{i+1}1'] = weight.shape[1]
    return ds_nn_save

In [8]:
def Rad_NN_activation_function(x):
    """Scalar ReLU used by the prototype forward pass.

    Returns ``x`` itself when it is strictly positive; otherwise (including
    zero and NaN) returns ``np.float32(0)``.
    """
    return x if x > 0 else np.float32(0)
# Pure-NumPy prototype forward pass for the folded (BatchNorm-free) networks.
def Rad_NN_pred(NN_parameters, input_X):
    """Evaluate a folded feed-forward ReLU network with plain NumPy.

    Parameters
    ----------
    NN_parameters : list of np.ndarray
        Flat ``[W1, b1, ..., Wk, bk, Wout, bout]`` list (e.g. the output of
        ``regroup_linear_BN_para``). Each ``W`` has shape (out, in) and each
        ``b`` has shape (out,). Every pair except the last is followed by a
        ReLU. The depth is inferred from the list length.
    input_X : np.ndarray
        Shape (n_samples, n_in), or (n_in,) for a single sample.

    Returns
    -------
    np.ndarray
        Shape (n_samples, n_out), or (n_out,) for 1-D input.

    Raises
    ------
    ValueError
        If ``NN_parameters`` is not a non-empty list of (weight, bias) pairs.
    """
    if len(NN_parameters) < 2 or len(NN_parameters) % 2:
        raise ValueError('NN_parameters must be a non-empty list of (weight, bias) pairs')
    weights = NN_parameters[0::2]
    biases = NN_parameters[1::2]
    hidden = np.asarray(input_X)
    for weight, bias in zip(weights[:-1], biases[:-1]):
        pre_activation = hidden @ weight.T + bias  # y = x W^T + b
        # ReLU. np.where keeps the pre-activation dtype; np.vectorize without
        # otypes would infer it from the first element and could silently
        # downcast to float32. Like the scalar ReLU, NaN maps to 0.
        hidden = np.where(pre_activation > 0, pre_activation, pre_activation.dtype.type(0))
    return hidden @ weights[-1].T + biases[-1]

In [9]:
# %%time
# For each experiment: load the selected ensemble member's restart file,
# fold normalization and BatchNorm into plain Linear layers, write them to
# netCDF, and check that a pure-NumPy forward pass on the folded parameters
# reproduces the PyTorch model on a few samples.
predi = {}
error = {}
eng_err = {}
NN_model = {}
ds_regrid = {}
ds_save = {}
save_dir = './saved_para_files/2023May_O36_LiH4ReluW256'
# to_netcdf does not create missing parent directories.
os.makedirs(save_dir, exist_ok=True)
for mo in Exp_name:
    print(mo)
    ######################################################
    # load the restart file of ensemble member ei[mo] (restart index run_num-1)
    run_num, exp_dir = return_exp_dir(work_dir, mo, create_dir=False)
    PATH_last = exp_dir + f'/model{ei[mo]}_restart.{run_num-1:02d}.pth'
    # map_location lets a checkpoint written on a GPU be loaded on `device`.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    restart_data = torch.load(PATH_last, map_location=device)
    print(f'load: {PATH_last}')
    # normalization parameters and model weights stored in the restart file
    nor_para = restart_data['nor_para']
    model_state_dict = restart_data['model_state_dict']
    nn_struc_name = list(model_state_dict.keys())
    input_size = model_state_dict[nn_struc_name[0]].shape[1]
    output_size = model_state_dict[nn_struc_name[-1]].shape[0]
    hidden_layer_width = model_state_dict[nn_struc_name[0]].shape[0]
    print(input_size, output_size, hidden_layer_width,)
    # read data (the sw reader additionally returns rsdt)
    if 'lw' in mo:
        input_array_ori, output_array_ori, ds_coords = \
        get_data_lw_AM4(filelist, condition=sky_cond[mo],
                        month_sel=month_sel, day_sel=day_sel, return_coords=True)
    else:
        input_array_ori, output_array_ori, rsdt_array_ori, ds_coords = \
        get_data_sw_AM4(filelist, condition=sky_cond[mo],
                        month_sel=month_sel, day_sel=day_sel, return_coords=True)

    # initialize model
    model_cls = NNRTMC_NN_lw if 'lw' in mo else NNRTMC_NN_sw
    NN_model[mo] = model_cls(device, nor_para, A_k, B_k, input_size, output_size,
                             hidden_layer_width, model_state_dict)

    # normalize data via saved nor_para in restart file
    if 'lw' in mo:
        nor_para, input_array, output_array = \
        data_std_normalization_lw(input_array_ori, output_array_ori, nor_para)
    else:
        # day_ind presumably indexes the daylight samples kept by the sw normalization
        nor_para, input_array, output_array, rsdt_array, day_ind = \
        data_std_normalization_sw(input_array_ori, output_array_ori, rsdt_array_ori, nor_para)

    # fold normalization + BatchNorm into Linear layers and save them
    a_list_of_parameters = [t.detach().cpu().numpy() for t in model_state_dict.values()]
    a_list_of_parameters = regroup_linear_BN_para(a_list_of_parameters, nor_para)
    ds_save[mo] = save_fnn_parameters(a_list_of_parameters)
    # NOTE(review): the 9-digit date stamp below looks like a typo -- confirm.
    ds_save[mo].attrs['info'] = 'FNN parameters for RadNN AM4 standard, 202305220'
    # record the exported ensemble member (not the whole `ei` mapping)
    ds_save[mo].attrs['model'] = f'{mo} #{ei[mo]}'
    ds_save[mo].to_netcdf(f'{save_dir}/RadNN_para_AM4std_LiH4ReluW256.{mo}.nc')

    # check results on the first 10 samples
    print(input_array.shape, output_array.shape)
    input_X = input_array[:10, :]
    output_Y = output_array[:10, :]
    if 'lw' in mo:
        input_X_ori = input_array_ori[:10, :]
    else:
        input_X_ori = input_array_ori[day_ind[:10], :]
    # results from pytorch, de-normalized with output_scale / output_offset
    NN_pred1 = NN_model[mo].predict(torch.tensor(input_X).to(device))
    NN_pred1 = NN_pred1/nor_para['output_scale']+nor_para['output_offset']
    # results from the pure-NumPy prototype on un-normalized inputs
    NN_pred2 = Rad_NN_pred(a_list_of_parameters, input_X_ori)
    # symmetric relative difference, averaged over samples, per output column
    r_err = (abs(NN_pred1 - NN_pred2)/abs(NN_pred1 + NN_pred2)).mean(axis=0)
    print(f'Maximum Relative Error | MAX : {np.max(r_err):8.2e} | Mean: {np.mean(r_err):8.2e}')

ens_AM4std_lw_cs_LiH4W256Relu_EY
load: /tigress/cw55/work/2022_radi_nn/work/ens_AM4std_lw_cs_LiH4W256Relu_EY/model0_restart.03.pth
101 35 256
Data files:
['/scratch/gpfs/cw55/AM4/work/CTL2000_train_y2000_stellarcpu_intelmpi_22_768PE/HISTORY/20000101.atmos_8xdaily.tile1.nc']
Data selection:
    Month: [1] 
    Day: [1] 
Reading data... 0 
Read data done. Use time:   0s
Total data size: 73728
(73728, 101) (73728, 35)
Maximum Relative Error | MAX : 1.47e-05 | Mean: 3.70e-06
ens_AM4std_lw_all_LiH4W256Relu_EY
load: /tigress/cw55/work/2022_radi_nn/work/ens_AM4std_lw_all_LiH4W256Relu_EY/model0_restart.03.pth
365 35 256
Data files:
['/scratch/gpfs/cw55/AM4/work/CTL2000_train_y2000_stellarcpu_intelmpi_22_768PE/HISTORY/20000101.atmos_8xdaily.tile1.nc']
Data selection:
    Month: [1] 
    Day: [1] 
Reading data... 0 
Read data done. Use time:   0s
Total data size: 73728
(73728, 365) (73728, 35)
Maximum Relative Error | MAX : 5.48e-05 | Mean: 7.17e-06
ens_AM4std_sw_cs_LiH4W256Relu_EY
load: /tigres

  rsdt_r = np.where(np.isclose(rsdt,0,rtol=1e-05, atol=1e-1,), 0, 1/rsdt)



Read data done. Use time:   0s
Total data size: 73728
Night time will be removed! (rsdt==0)
Total data size (daylight): (45722,)
(45722, 106) (45722, 36)
Maximum Relative Error | MAX : 4.44e-07 | Mean: 2.37e-07
ens_AM4std_sw_all_LiH4W256Relu_EY
load: /tigress/cw55/work/2022_radi_nn/work/ens_AM4std_sw_all_LiH4W256Relu_EY/model0_restart.03.pth
370 36 256
Data files:
['/scratch/gpfs/cw55/AM4/work/CTL2000_train_y2000_stellarcpu_intelmpi_22_768PE/HISTORY/20000101.atmos_8xdaily.tile1.nc']
Data selection:
    Month: [1] 
    Day: [1] 
Reading data... 0 
Read data done. Use time:   0s
Total data size: 73728
Night time will be removed! (rsdt==0)
Total data size (daylight): (45722,)
(45722, 370) (45722, 36)
Maximum Relative Error | MAX : 1.84e-06 | Mean: 7.47e-07
