In [1]:
import numpy as np
import torch
from torch import nn  
import time 
import os  
import xarray as xr
import subprocess
import matplotlib.pyplot as plt
from matplotlib import colors
import cartopy.crs as ccrs

In [2]:
## import self defined functions
# This cell extends sys.path with two hard-coded cluster directories so that the
# project-local modules below can be imported; it has import-time side effects.
from sys import path 
# insert at 1, 0 is the script path (or '' in REPL)
path.insert(1, '/tigress/cw55/local/python_lib')
from cg_funcs import global_mean_xarray
# `root` is reused by later cells to build `work_dir`.
root = '/tigress/cw55/work/2022_radi_nn/NN_AM4'
# Make the experiment packages under <root>/work importable.
path.insert(1,  root+'/work')

# One training-script module per experiment; each is referenced later through
# its `NNRTMC_NN` attribute.
import lw_csaf_Li5Relu_EY.train_script01 as lwcsafey 
import lw_af_Li5Relu_EY.train_script01 as lwafey 
import sw_csaf_Li5Relu_EY.train_script01 as swcsafey 
import sw_af_Li5Relu_EY.train_script01 as swafey 

In [3]:
from get_AM4_data_lw import get_AM4_data_lw
from get_AM4_data_sw import get_AM4_data_sw
######################################################
# common functions to split the training and test data
from NNRTMC_lw_utils import  split_train_test_sample, \
draw_batches, data_std_normalization, print_key_results, return_exp_dir
    
from diag_utils import batch_index_sta_end, pred_NN_batch,\
create_6tiles_lw,regrid_6tile2latlon

In [4]:
# Global configuration for the notebook: GPU selection, experiment names,
# data file lists and the vertical-coordinate coefficients.
# Every name assigned inside this `if` is a module-level global that later
# cells rely on (e.g. `device`, `Exp_name`, `work_dir`, `A_k`, `B_k`).
if __name__ == '__main__': 
    # NOTE(review): requires a CUDA-capable device; there is no CPU fallback.
    torch.cuda.set_device(0) # select gpu_id, default 0 means the first GPU
    device = f'cuda:{torch.cuda.current_device()}'
    ######################################################
    # set exp name and runs
    # Order of this list is the order in which experiments are processed.
    Exp_name = [ 
        'lw_csaf_Li5Relu_EY', 
        'lw_af_Li5Relu_EY'  ,
        'sw_csaf_Li5Relu_EY', 
        'sw_af_Li5Relu_EY'  ,
    ]
    # Experiment name -> imported training-script module.
    Exp_name_model_dict = { 
        'lw_csaf_Li5Relu_EY': lwcsafey, 
        'lw_af_Li5Relu_EY'  : lwafey,
        'sw_csaf_Li5Relu_EY': swcsafey, 
        'sw_af_Li5Relu_EY'  : swafey,
    }
    # Experiment name -> value passed as `condition=` to the data readers.
    sky_cond = { 
        'lw_csaf_Li5Relu_EY': 'csaf', 
        'lw_af_Li5Relu_EY'  : 'af',
        'sw_csaf_Li5Relu_EY': 'csaf', 
        'sw_af_Li5Relu_EY'  : 'af',
    }
    work_dir = root+'/work/' 
    # file list AM4 runs 
    # Full set: tiles 1..6. These two lists are immediately overridden below.
    out_filelist = [f'/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.fluxes.tile{_}.nc' for _ in range(1,7)]
    inp_filelist = [f'/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.new_offline_input.tile{_}.nc' for _ in range(1,7)]
    # use one file
    # Effective setting: only tile 1 is read (range(1,2)); remove these two
    # lines to fall back to all six tiles.
    out_filelist = [f'/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.fluxes.tile{_}.nc' for _ in range(1,2)]
    inp_filelist = [f'/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.new_offline_input.tile{_}.nc' for _ in range(1,2)]

    # presumably the hybrid sigma-pressure coefficients of the AM4 vertical
    # grid (inferred from the variable/file names) - TODO confirm.
    # NOTE(review): the dataset is opened but never closed in this cell.
    hybrid_p_sigma_para = xr.open_dataset('/tigress/cw55/data/NNRTMC_dataset/AM4_pk_bk_202207.nc')
    # `[None,:]` prepends a length-1 axis to each coefficient array.
    A_k = hybrid_p_sigma_para.ak.values[None,:]
    B_k = hybrid_p_sigma_para.bk.values[None,:] 

In [5]:
# Set the global matplotlib default font size for all figures created after
# this cell (the value is given as the string '6').
plt.rcParams['font.size'] = '6'

In [6]:
# Month/day selection forwarded to the data readers as `month_sel=` / `day_sel=`.
# The first pair of assignments is an alternative setting that is overridden by
# the second pair; only the LAST assignment of each name takes effect.
month_sel = None
day_sel = [15,18,21,24,27]
# Effective setting: month 1, day 1.
month_sel = [1]
day_sel = [1]

In [7]:

# A specific implementation for the Li5ReluBN architecture:
# (Linear -> BatchNorm -> ReLU) x 4, followed by a final Linear layer
def regroup_linear_BN_para(ori_NN_parameters, nor_para, n_hidden_layers=4, bn_eps=1e-5):
    """Fold input/output normalization and BatchNorm into plain (weight, bias) pairs.

    The flat parameter list is read as ``n_hidden_layers`` blocks of 7 entries
    followed by one final (weight, bias) pair. Within each block the entries
    are used as:

    ====== ==========================================================
    offset role in this function
    ====== ==========================================================
    0, 1   linear weight ``(n_out, n_in)`` and bias ``(n_out,)``
    2, 3   BatchNorm scale and shift
    4, 5   BatchNorm mean and variance
    6      ignored
    ====== ==========================================================

    NOTE(review): this matches the PyTorch ``state_dict`` order of
    ``Linear`` followed by ``BatchNorm1d`` (weight, bias, running_mean,
    running_var, num_batches_tracked) - confirm against the model definition.

    The folding is algebraically exact when the network is evaluated as
    ``x_norm = (x - input_offset) * input_scale`` on the way in and
    ``y = y_norm / output_scale + output_offset`` on the way out, with
    BatchNorm in inference mode.

    Parameters
    ----------
    ori_NN_parameters : sequence of array-like
        Flat parameter list laid out as described above.
    nor_para : mapping
        Must provide ``'input_scale'``, ``'input_offset'``, ``'output_scale'``
        and ``'output_offset'``.
    n_hidden_layers : int, optional
        Number of (Linear, BatchNorm) blocks before the final Linear layer.
        Defaults to 4, the value that was previously hard-coded.
    bn_eps : float, optional
        Epsilon added to the BatchNorm variance. Defaults to 1e-5, the value
        that was previously hard-coded; it must equal the ``eps`` the
        BatchNorm layers were built with.

    Returns
    -------
    list
        ``[W1, B1, ..., W_last, B_last]`` with ``2 * (n_hidden_layers + 1)``
        entries, to be applied directly to un-normalized inputs.
    """
    entries_per_block = 7  # Linear (2 entries) + BatchNorm (5 entries)
    new_NN_parameters = []
    for layer in range(n_hidden_layers):
        base = layer * entries_per_block
        weight = ori_NN_parameters[base]
        bias = ori_NN_parameters[base + 1]
        if layer == 0:
            # Account for the input normalization. The bias must be updated
            # first because it needs the weight *before* it is rescaled.
            bias = bias - weight @ (nor_para['input_scale'] * nor_para['input_offset'])
            weight = nor_para['input_scale'] * weight
        bn_scale = ori_NN_parameters[base + 2]
        bn_shift = ori_NN_parameters[base + 3]
        bn_mean = ori_NN_parameters[base + 4]
        bn_var = ori_NN_parameters[base + 5]
        # Per-output-node multiplicative factor of the BatchNorm transform.
        bn_factor = bn_scale / np.sqrt(bn_var + bn_eps)
        new_NN_parameters.append(bn_factor[:, None] * weight)
        new_NN_parameters.append((bias - bn_mean) * bn_factor + bn_shift)

    # Final Linear layer: fold in the output de-normalization.
    base = n_hidden_layers * entries_per_block
    weight = ori_NN_parameters[base] / nor_para['output_scale'][:, None]
    bias = ori_NN_parameters[base + 1] / nor_para['output_scale'] + nor_para['output_offset']
    new_NN_parameters.append(weight)
    new_NN_parameters.append(bias)
    return new_NN_parameters

def save_fnn_parameters(a_list_of_parameters):
    """Pack a flat ``[W1, B1, W2, B2, ...]`` list into an ``xarray.Dataset``.

    For layer ``i`` (1-based) the dataset receives:

    * ``W{i}`` with dimensions ``('x{i}', 'y{i}')``
    * ``B{i}`` with dimension ``'x{i}'``
    * ``size{i}0`` / ``size{i}1``: the two entries of ``W{i}.shape``

    plus a scalar ``LN`` holding the number of layers.

    Parameters
    ----------
    a_list_of_parameters : sequence of array-like
        Alternating weight (2-D) and bias (1-D) arrays.

    Returns
    -------
    xarray.Dataset
        The assembled dataset; nothing is written to disk here.

    Raises
    ------
    ValueError
        If the list does not contain an even number of entries, i.e. it
        cannot be split into (weight, bias) pairs.
    """
    num_layer, leftover = divmod(len(a_list_of_parameters), 2)
    if leftover:
        raise ValueError(
            'num_layer must be integer: got '
            f'{len(a_list_of_parameters)} parameter arrays, '
            'expected (weight, bias) pairs'
        )
    ds_nn_save = xr.Dataset()
    ds_nn_save['LN'] = num_layer
    for i in range(num_layer):
        weight = a_list_of_parameters[i * 2]
        bias = a_list_of_parameters[i * 2 + 1]
        ds_nn_save[f'W{i+1}'] = ((f'x{i+1}', f'y{i+1}'), weight)
        ds_nn_save[f'B{i+1}'] = (f'x{i+1}', bias)
        ds_nn_save[f'size{i+1}0'] = weight.shape[0]
        ds_nn_save[f'size{i+1}1'] = weight.shape[1]
    return ds_nn_save

In [8]:
def Rad_NN_activation_function(x):
    """Scalar ReLU.

    Returns ``x`` unchanged when it is strictly positive; every other value
    (zero, negative, or anything for which ``x > 0`` is false) yields
    ``np.float32(0)``.
    """
    is_positive = x > 0
    if not is_positive:
        return np.float32(0)
    return x
# A specific implementation for the Li5ReluBN architecture (NumPy prototype of the forward pass)
def Rad_NN_pred(NN_parameters, input_X, n_hidden_layers=4):
    """NumPy forward pass of a ReLU feed-forward network.

    Applies ``n_hidden_layers`` times ``h = relu(W @ h + b)`` and then one
    final linear layer ``y = W @ h + b`` without activation.

    The ReLU is evaluated with a single vectorized ``np.where`` call, so the
    dtype of the activations follows the dtype of the inputs and parameters.
    As with ``Rad_NN_activation_function``, any entry for which ``v > 0`` is
    false (including NaN) is mapped to 0.

    Parameters
    ----------
    NN_parameters : sequence of ndarray
        Flat list ``[W1, B1, W2, B2, ...]`` holding at least
        ``2 * (n_hidden_layers + 1)`` entries; ``W`` is ``(n_out, n_in)`` and
        ``B`` is ``(n_out,)``.
    input_X : ndarray, shape (n_samples, n_features)
        One sample per row.
    n_hidden_layers : int, optional
        Number of Linear+ReLU layers before the final Linear layer. Defaults
        to 4, the value that was previously hard-coded.

    Returns
    -------
    ndarray, shape (n_samples, n_outputs)
    """
    # Work in (features, samples) layout so each layer is a plain W @ h.
    hidden = input_X.T
    for layer in range(n_hidden_layers):
        weight = NN_parameters[2 * layer]
        bias = NN_parameters[2 * layer + 1]
        # y = w*x + b
        pre_activation = weight @ hidden + bias[:, None]
        # y = sigma(y): ReLU applied to all nodes at once
        hidden = np.where(pre_activation > 0, pre_activation, 0)
    # Final layer: linear only, no activation.
    last = 2 * n_hidden_layers
    output_Y = NN_parameters[last] @ hidden + NN_parameters[last + 1][:, None]
    return output_Y.T

In [9]:
%%time
# For every experiment: load the latest restart file, run the PyTorch model on
# the selected data, fold normalization + BatchNorm into plain (W, B) pairs,
# write them to a NetCDF file, and finally cross-check the folded parameters
# against the PyTorch model on a small subset.
# All result containers are keyed by experiment name.
predi = {}
error = {}
eng_err = {}
NN_model = {}
ds_regrid = {}  # not filled in this cell
ds_save = {}
for mo in Exp_name:
    print(mo)
    ######################################################
    # load restart file
    run_num, exp_dir = return_exp_dir(work_dir, mo, create_dir=False)
    # The restart file loaded is the one numbered `run_num-1`.
    PATH_last =  exp_dir+f'/restart.{run_num-1:02d}.pth'
    # NOTE(review): torch.load unpickles the file - only load trusted restart
    # files. No `map_location` is given, so tensors keep the device they were
    # saved from.
    restart_data = torch.load(PATH_last)  # load exist results and restart training
    print(f'load: {PATH_last}')
    # read training dataset, nor_para, model parameteres
    nor_para = restart_data['nor_para']
    model_state_dict = restart_data['model_state_dict']
    # read data
    # Longwave and shortwave experiments use different reader functions; the
    # call signature is otherwise identical.
    if 'lw' in mo:
        input_array_ori, output_array_ori, ds_coords = \
        get_AM4_data_lw(out_filelist, inp_filelist, condition=sky_cond[mo], 
                        month_sel = month_sel, day_sel = day_sel, return_coords=True) 
    else:
        input_array_ori, output_array_ori, ds_coords = \
        get_AM4_data_sw(out_filelist, inp_filelist, condition=sky_cond[mo], 
                        month_sel = month_sel, day_sel = day_sel, return_coords=True) 
        
    # initialize model 
    NN_model[mo] = Exp_name_model_dict[mo].NNRTMC_NN(device, nor_para, A_k, B_k, input_array_ori.shape[1],model_state_dict)  
    
    # normalize data via saved nor_para in restart file
    nor_para, input_array, output_array   = data_std_normalization(input_array_ori, output_array_ori, nor_para)
    
    # try NN on test dataset  
    predi[mo], eng_err[mo] = pred_NN_batch(input_array, output_array, NN_model[mo], nor_para, device)
    # The error is taken against the un-normalized targets, BEFORE the unit
    # conversion below, so both arrays are then scaled consistently.
    error[mo] = predi[mo] - output_array_ori
    # Columns from index 3 onward are scaled in place by 86400 (s/day).
    predi[mo][:,3:] = predi[mo][:,3:]*86400 # HR K/s >> K/day
    error[mo][:,3:] = error[mo][:,3:]*86400 # HR K/s >> K/day
    # process NN dict and save parameters
    # NOTE(review): `.numpy()` only works for CPU tensors that do not require
    # grad; this assumes the restart file stores the state dict on the CPU.
    a_list_of_parameters = [model_state_dict[_].numpy() for _ in model_state_dict.keys()]
    a_list_of_parameters = regroup_linear_BN_para(a_list_of_parameters, nor_para)
    ds_save[mo] = save_fnn_parameters(a_list_of_parameters)
    ds_save[mo].attrs['info'] = f'FNN parameters for RadNN AM4, 20230304'
    ds_save[mo].attrs['model'] = mo
    # Written relative to the current working directory of the kernel.
    ds_save[mo].to_netcdf(f'RadNN_para_ReLU_L5W256.{mo}.nc')
    
    
    # check results
    print(input_array.shape,output_array.shape)
    # select a subset of input
    # Only the first 300 samples are used for the consistency check.
    input_X = input_array[:300,:]
    output_Y = output_array[:300,:] 
    input_X_ori = input_array_ori[:300,:]
    # results from pytorch
    # The PyTorch model is fed normalized inputs and its output is
    # de-normalized by hand here.
    NN_pred1 = NN_model[mo].predict(torch.tensor(input_X).to(device)).cpu().numpy()
    NN_pred1 = NN_pred1/nor_para['output_scale']+nor_para['output_offset']
    # results from prototype function 
    # The folded parameters already contain the normalization, so the
    # prototype is fed the un-normalized inputs.
    NN_pred2 = Rad_NN_pred(a_list_of_parameters, input_X_ori)
    # Per-output-column discrepancy: mean |difference| divided by mean |sum|
    # over the 300 samples (axis=0).
    r_err = abs(NN_pred1 - NN_pred2).mean(axis=0)/abs(NN_pred1 + NN_pred2).mean(axis=0)
    print(r_err)
    # break

lw_csaf_Li5Relu_EY
load: /tigress/cw55/work/2022_radi_nn/NN_AM4/work/lw_csaf_Li5Relu_EY/restart.04.pth
Data files:
['/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.fluxes.tile1.nc'] ['/scratch/gpfs/cw55/NNRTMC_data/AM4_v2/20000101.new_offline_input.tile1.nc']
Data selection:
    Month: [1] 
    Day: [1] 
Reading data... 0 Done.
Total data size: 73728
(73728, 102) (73728, 36)
[4.6300187e-07 4.4846544e-07 3.1752694e-07 4.0320899e-07 1.2767708e-06
 8.6028223e-07 6.5514820e-07 1.1562497e-06 1.5120446e-06 3.6600886e-06
 7.1314789e-06 6.5557260e-06 8.6845794e-06 1.1165456e-05 2.5080274e-06
 1.2070267e-06 1.9513373e-06 1.0148282e-06 3.7516327e-06 5.7518655e-06
 7.4827262e-06 9.0632466e-06 8.6924929e-06 7.9995316e-06 7.9817864e-06
 9.8964338e-06 2.2988123e-05 2.8059429e-05 1.8143368e-05 1.9735104e-05
 2.1709140e-05 1.6707705e-05 4.7031412e-05 8.1867256e-06 3.7826059e-05
 4.2433840e-05]
lw_af_Li5Relu_EY
load: /tigress/cw55/work/2022_radi_nn/NN_AM4/work/lw_af_Li5Relu_EY/restart.04.pth
Data files