In [1]:
import numpy as np
import torch
from torch import nn  
import time 
import os  
import xarray as xr
import subprocess
import matplotlib.pyplot as plt
from matplotlib import colors
import cartopy.crs as ccrs

In [2]:
## Import self-defined (project-local) functions.
from sys import path 
# Insert at index 1; index 0 is the script path (or '' in a REPL).
path.insert(1, '/tigress/cw55/local/python_lib')
from cg_funcs import global_mean_xarray
# Project root. root+'/work' is put on sys.path so that the per-experiment
# packages imported below can be resolved; it is also reused later as work_dir.
root = '/tigress/cw55/work/2022_radi_nn/NN_AM4'
path.insert(1,  root+'/work')

# One training-script module per experiment, aliased for the model lookup table
# built in a later cell (lw/sw and cs/all presumably mean longwave/shortwave and
# clear-sky/all-sky -- TODO confirm).
import AM4std_lw_cs_LiH4Relu_EY.train_script01  as lwcsey 
import AM4std_lw_all_LiH4Relu_EY.train_script01 as lwalley 
import AM4std_sw_cs_LiH4Relu_EY.train_script01  as swcsey 
import AM4std_sw_all_LiH4Relu_EY.train_script01 as swalley 

In [3]:
from get_data_lw_AM4_std import get_data_lw_AM4
from get_data_sw_AM4_std import get_data_sw_AM4
######################################################
# common functions to split the training and test data
from NNRTMC_utils import  split_train_test_sample, \
draw_batches, data_std_normalization, data_std_normalization_sw, print_key_results, return_exp_dir

In [4]:
if __name__ == '__main__':
    # Everything below runs on the CPU. To use a GPU instead:
    #   torch.cuda.set_device(0)  # 0 selects the first GPU
    #   device = f'cuda:{torch.cuda.current_device()}'
    device = 'cpu'

    # ----------------------------------------------------------------
    # Experiments to process. The two lookup tables below are keyed by
    # experiment name and must stay in the same order as Exp_name.
    Exp_name = [
        'AM4std_lw_cs_LiH4Relu_EY',
        'AM4std_lw_all_LiH4Relu_EY',
        'AM4std_sw_cs_LiH4Relu_EY',
        'AM4std_sw_all_LiH4Relu_EY',
    ]
    # experiment name -> imported training-script module
    Exp_name_model_dict = dict(zip(Exp_name, (lwcsey, lwalley, swcsey, swalley)))
    # experiment name -> sky condition passed to the data readers
    sky_cond = dict(zip(Exp_name, ('cs', 'all', 'cs', 'all')))
    work_dir = root + '/work/'

    # ----------------------------------------------------------------
    # AM4 history files to read; range(1, 2) selects tile 1 only.
    filelist = [
        '/scratch/gpfs/cw55/AM4/work/FIXSST_2000s_stellarcpu_intelmpi_22_768PE/'
        f'HISTORY/20000101.atmos_8xdaily.tile{tile}.nc'
        for tile in range(1, 2)
    ]

    # ak / bk coefficients read from the pk/bk file, each with a leading
    # length-1 axis added.
    hybrid_p_sigma_para = xr.open_dataset('/tigress/cw55/data/NNRTMC_dataset/AM4_pk_bk_202207.nc')
    A_k = hybrid_p_sigma_para.ak.values[None, :]
    B_k = hybrid_p_sigma_para.bk.values[None, :]


In [5]:
# Global matplotlib default: small font size for every figure in this notebook.
plt.rcParams['font.size'] = '6'

In [6]:
# Month / day selection handed to get_data_lw_AM4 / get_data_sw_AM4 in the
# export loop. A larger selection that was used before and immediately
# overridden: month_sel = None, day_sel = [15, 18, 21, 24, 27].
month_sel = [1]
day_sel = [1]

In [7]:

# Specific to networks laid out as (Linear -> BatchNorm -> ReLU) * n_hidden -> Linear
# (referred to as Li5ReluBN in the original note).
def regroup_linear_BN_para(ori_NN_parameters, nor_para, n_hidden=4, bn_eps=1e-5):
    """Fold input/output normalization and batch norm into plain linear layers.

    The result describes an equivalent (Linear -> ReLU) * n_hidden -> Linear
    network that takes un-normalized inputs and returns un-normalized outputs.

    Parameters
    ----------
    ori_NN_parameters : sequence of array-like
        Flattened model parameters. Each hidden block occupies 7 consecutive
        entries: linear weight (out, in), linear bias, then the batch-norm
        entries used here as scale (+2), shift (+3), mean (+4) and
        variance (+5). Entry +6 is skipped (presumably PyTorch's
        ``num_batches_tracked`` -- TODO confirm). The last two entries read are
        the output layer's weight and bias.
    nor_para : mapping
        Must provide 'input_scale', 'input_offset', 'output_scale' and
        'output_offset'. The folding assumes inputs were normalized as
        ``(x - input_offset) * input_scale`` and outputs are recovered as
        ``y / output_scale + output_offset``.
    n_hidden : int, optional
        Number of Linear+BatchNorm hidden blocks (default 4).
    bn_eps : float, optional
        Epsilon added to the batch-norm variance (default 1e-5).

    Returns
    -------
    list
        ``[W1, b1, ..., W_{n_hidden+1}, b_{n_hidden+1}]`` of folded parameters.
    """
    entries_per_block = 7  # weight, bias + 5 batch-norm entries
    new_NN_parameters = []
    for layer in range(n_hidden):
        base = layer * entries_per_block
        weight = ori_NN_parameters[base]
        bias = ori_NN_parameters[base + 1]
        if layer == 0:
            # Absorb the input normalization into the first linear layer:
            # W @ ((x - o) * s) + b == (W * s) @ x + (b - W @ (s * o))
            bias = bias - weight @ (nor_para['input_scale'] * nor_para['input_offset'])
            weight = nor_para['input_scale'] * weight
        bn_scale = ori_NN_parameters[base + 2]
        bn_shift = ori_NN_parameters[base + 3]
        bn_mean = ori_NN_parameters[base + 4]
        bn_var = ori_NN_parameters[base + 5]
        # Batch norm is an affine map per output node, so it folds into W and b.
        gain = bn_scale / np.sqrt(bn_var + bn_eps)
        new_NN_parameters.append(gain[:, None] * weight)
        new_NN_parameters.append((bias - bn_mean) * gain + bn_shift)

    # Output layer: absorb the output de-normalization.
    base = n_hidden * entries_per_block
    weight = ori_NN_parameters[base] / nor_para['output_scale'][:, None]
    bias = ori_NN_parameters[base + 1] / nor_para['output_scale'] + nor_para['output_offset']
    new_NN_parameters.append(weight)
    new_NN_parameters.append(bias)
    return new_NN_parameters

def save_fnn_parameters(a_list_of_parameters):
    """Pack a flat ``[W1, B1, W2, B2, ...]`` parameter list into an xarray Dataset.

    For layer ``i`` (1-based) the dataset holds ``W{i}`` on dims
    ``(x{i}, y{i})``, ``B{i}`` on dim ``x{i}``, and the two weight-matrix sizes
    as scalars ``size{i}0`` and ``size{i}1``. The scalar ``LN`` holds the
    number of layers.

    Parameters
    ----------
    a_list_of_parameters : sequence
        Alternating 2-D weight and 1-D bias arrays, one pair per layer.

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If the list does not contain an even number of entries.
    """
    num_layer, leftover = divmod(len(a_list_of_parameters), 2)
    if leftover:
        raise ValueError('num_layer must be integer')
    ds_nn_save = xr.Dataset()
    ds_nn_save['LN'] = num_layer
    for i in range(num_layer):
        weight = a_list_of_parameters[i*2]
        bias = a_list_of_parameters[i*2+1]
        ds_nn_save[f'W{i+1}'] = ((f'x{i+1}', f'y{i+1}'), weight)
        ds_nn_save[f'B{i+1}'] = (f'x{i+1}', bias)
        ds_nn_save[f'size{i+1}0'] = weight.shape[0]
        ds_nn_save[f'size{i+1}1'] = weight.shape[1]
    return ds_nn_save

In [8]:
def Rad_NN_activation_function(x):
    """Scalar ReLU: return ``x`` itself when it is positive, else a float32 zero."""
    return x if x > 0 else np.float32(0)
# Prototype forward pass for a folded network: (Linear -> ReLU) * n_hidden -> Linear
def Rad_NN_pred(NN_parameters, input_X):
    """Evaluate a plain feed-forward ReLU network with numpy.

    Parameters
    ----------
    NN_parameters : sequence of numpy arrays
        ``[W1, b1, W2, b2, ..., W_out, b_out]`` with each ``W`` shaped
        (n_out, n_in) and each ``b`` shaped (n_out,). Every pair except the
        last is followed by a ReLU; for the 10-entry list used in this
        notebook that is 4 hidden layers.
    input_X : numpy array, shape (n_samples, n_features)

    Returns
    -------
    numpy array, shape (n_samples, n_outputs)
    """
    n_hidden = len(NN_parameters) // 2 - 1
    intermediate = input_X.T  # work in (features, samples) layout
    for layer in range(n_hidden):
        weight = NN_parameters[2 * layer]
        bias = NN_parameters[2 * layer + 1]
        # y = w @ x + b
        intermediate = weight @ intermediate + bias[:, None]
        # ReLU over all nodes. np.where with a same-dtype zero keeps the array
        # dtype and also works for an empty batch; non-positive and NaN
        # entries become 0.
        zero = intermediate.dtype.type(0)
        intermediate = np.where(intermediate > 0, intermediate, zero)
    # Output layer has no activation.
    output_Y = NN_parameters[2 * n_hidden] @ intermediate + NN_parameters[2 * n_hidden + 1][:, None]
    return output_Y.T

In [9]:
# %%time
# For every experiment: load the latest restart file, fold normalization and
# batch norm into plain linear layers, save them to netCDF, then compare the
# PyTorch prediction with the numpy prototype on the first 10 samples.
predi = {}
error = {}
eng_err = {}
NN_model = {}
ds_regrid = {}
ds_save = {}
for mo in Exp_name:
    print(mo)
    ######################################################
    # load restart file
    run_num, exp_dir = return_exp_dir(work_dir, mo, create_dir=False)
    PATH_last = exp_dir + f'/restart.{run_num-1:02d}.pth'
    # map_location moves tensors saved on a GPU onto `device`; without it a
    # GPU-saved checkpoint cannot be loaded on a CPU-only host and the
    # .numpy() conversion below would fail on CUDA tensors.
    restart_data = torch.load(PATH_last, map_location=device)
    print(f'load: {PATH_last}')
    # read normalization parameters and model parameters
    nor_para = restart_data['nor_para']
    model_state_dict = restart_data['model_state_dict']
    # read data (the sw reader additionally returns rsdt)
    if 'lw' in mo:
        input_array_ori, output_array_ori, ds_coords = \
        get_data_lw_AM4(filelist, condition=sky_cond[mo],
                        month_sel=month_sel, day_sel=day_sel, return_coords=True)
    else:
        input_array_ori, output_array_ori, rsdt_array_ori, ds_coords = \
        get_data_sw_AM4(filelist, condition=sky_cond[mo],
                        month_sel=month_sel, day_sel=day_sel, return_coords=True)

    # initialize model
    NN_model[mo] = Exp_name_model_dict[mo].NNRTMC_NN(
        device, nor_para, A_k, B_k, input_array_ori.shape[1], model_state_dict)

    # normalize data via saved nor_para in restart file
    if 'lw' in mo:
        nor_para, input_array, output_array = \
        data_std_normalization(input_array_ori, output_array_ori, nor_para)
    else:
        nor_para, input_array, output_array, rsdt_array, day_ind = \
        data_std_normalization_sw(input_array_ori, output_array_ori, rsdt_array_ori, nor_para)

    # process NN dict and save parameters
    a_list_of_parameters = [model_state_dict[_].numpy() for _ in model_state_dict.keys()]
    a_list_of_parameters = regroup_linear_BN_para(a_list_of_parameters, nor_para)
    ds_save[mo] = save_fnn_parameters(a_list_of_parameters)
    ds_save[mo].attrs['info'] = 'FNN parameters for RadNN AM4 standard, 20230330'
    ds_save[mo].attrs['model'] = mo
    ds_save[mo].to_netcdf(f'RadNN_para_AM4std_LiH4ReluW256.{mo}.nc')

    # check results
    print(input_array.shape, output_array.shape)
    # select a subset of input
    input_X = input_array[:10, :]
    output_Y = output_array[:10, :]
    if 'lw' in mo:
        input_X_ori = input_array_ori[:10, :]
    else:
        # the normalized sw arrays are indexed through day_ind, so pick the
        # matching rows of the original array
        input_X_ori = input_array_ori[day_ind[:10], :]
    # results from pytorch (de-normalized to physical units)
    NN_pred1 = NN_model[mo].predict(torch.tensor(input_X).to(device)).cpu().numpy()
    NN_pred1 = NN_pred1/nor_para['output_scale'] + nor_para['output_offset']
    # results from prototype function (takes un-normalized inputs)
    NN_pred2 = Rad_NN_pred(a_list_of_parameters, input_X_ori)
    # |a-b| / |a+b|, averaged over the 10 samples -> one value per output.
    # The label used to say "Maximum" although the mean is what is computed.
    r_err = (abs(NN_pred1 - NN_pred2)/abs(NN_pred1 + NN_pred2)).mean(axis=0)
    print('Mean Relative Error:')
    print(r_err)
    # break

AM4std_lw_cs_LiH4Relu_EY
load: /tigress/cw55/work/2022_radi_nn/NN_AM4/work/AM4std_lw_cs_LiH4Relu_EY/restart.04.pth
Data files:
['/scratch/gpfs/cw55/AM4/work/FIXSST_2000s_stellarcpu_intelmpi_22_768PE/HISTORY/20000101.atmos_8xdaily.tile1.nc']
Data selection:
    Month: [1] 
    Day: [1] 
Reading data... 0 Done.
Total data size: 73728
(73728, 101) (73728, 36)
Maximum Relative Error:
[3.1104037e-07 2.7030077e-07 2.8069408e-07 6.5790636e-07 2.4520901e-07
 3.9918012e-07 1.4922189e-06 7.8801520e-07 1.0084433e-06 2.1688361e-06
 2.0623937e-05 3.4149743e-05 3.2982123e-06 1.5311558e-06 6.1635035e-07
 1.9528552e-06 2.6253156e-06 2.2480540e-06 2.0374714e-06 1.5900863e-06
 2.3006630e-06 1.3225094e-06 7.6378484e-07 7.6612594e-07 1.8549705e-06
 2.6536461e-06 5.4304983e-06 2.4725100e-06 1.4331661e-06 1.4699725e-06
 1.9452498e-06 3.1930172e-06 3.9239430e-06 3.8944372e-06 1.4462348e-05
 7.5794429e-05]
AM4std_lw_all_LiH4Relu_EY
load: /tigress/cw55/work/2022_radi_nn/NN_AM4/work/AM4std_lw_all_LiH4Relu_EY/re

  rsdt_r = np.where(np.isclose(rsdt,0,rtol=1e-05, atol=1e-1,), 0, 1/rsdt)


Done.
Total data size: 73728
Night time will be removed! (rsdt==0)
Total data size (daylight): (45722,)
(45722, 106) (45722, 36)
Maximum Relative Error:
[2.0233676e-07 1.4293546e-07 6.5711953e-07 1.8241627e-07 2.7779507e-07
 2.6873963e-07 1.7350837e-07 1.3345806e-07 1.5248105e-07 1.9116735e-07
 1.5810807e-07 1.4517727e-07 1.0548096e-07 1.9186091e-07 4.5169719e-07
 3.8325143e-07 1.1848590e-06 4.5437409e-07 4.9352064e-07 4.1384834e-07
 1.2141197e-06 4.7816104e-07 5.3893081e-07 4.4950986e-07 4.2453743e-07
 2.7594189e-07 5.2784736e-07 1.5631545e-07 3.1245898e-07 3.1626143e-07
 3.3281651e-07 3.1324083e-07 3.5890633e-07 3.2364431e-07 2.7173272e-07
 2.1107485e-07]
AM4std_sw_all_LiH4Relu_EY
load: /tigress/cw55/work/2022_radi_nn/NN_AM4/work/AM4std_sw_all_LiH4Relu_EY/restart.04.pth
Data files:
['/scratch/gpfs/cw55/AM4/work/FIXSST_2000s_stellarcpu_intelmpi_22_768PE/HISTORY/20000101.atmos_8xdaily.tile1.nc']
Data selection:
    Month: [1] 
    Day: [1] 
Reading data... 0 Done.
Total data size: 7372