# Import packages

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import os
import re
import math
import scipy.integrate as integrate
from tensorflow import keras

from tqdm import tqdm

# Load data, normalization, and models

In [None]:
# Project tag; it is substituted into the saved-model file names.
proj_name = "expanded"
# Number of leading entries kept along axis 1 of the test arrays
# (presumably 48 steps per day for 2 days -- TODO confirm).
timesteps = 2 * 48
!pwd

In [None]:
def checktime():
    """Print the current local date and time.

    ``datetime`` is imported here because the notebook's import cell does
    not import it; without the import the call raises ``NameError``.
    """
    import datetime

    print(datetime.datetime.now())
# TODO: make the base paths configurable (functional-programming refactor).

# Directories for the test arrays, normalization vectors, and saved models.
data_path = "testing_data/"
norm_path = "norm_files/"
model_path = "../coupling_folder/h5_models/"

# Number of saved models to evaluate.
num_models = 330

# Offset and divisor applied to the raw input before it is fed to a model.
inpsub = np.loadtxt(norm_path + "inp_sub.txt")
inpdiv = np.loadtxt(norm_path + "inp_div.txt")

# Output scaling vector: 30 entries of heatScale followed by 30 of moistScale.
heatScale = 1004
moistScale = 2.5e6
outscale = np.concatenate(
    [np.repeat(scale, 30) for scale in (heatScale, moistScale)]
)

# Load the test arrays, keeping only the first `timesteps` entries of axis 1.
with open(data_path + 'test_input.npy', 'rb') as f:
    test_input = np.load(f)[:, :timesteps, :, :]

with open(data_path + 'test_target.npy', 'rb') as f:
    test_target = np.load(f)[:, :timesteps, :, :]

# Normalize the input; the 1-D vectors are broadcast along the leading axis.
nn_input = (test_input - inpsub[:, None, None, None]) / inpdiv[:, None, None, None]

# Model history files that supply the grid, hybrid coefficients, and NNPS.
spData = xr.open_mfdataset(
    [
        "/ocean/projects/atm200007p/jlin96/longSPrun_o3/AndKua_aqua_Base_training.cam2.h1.0001-01-19-00000.nc",
        "/ocean/projects/atm200007p/jlin96/longSPrun_o3/AndKua_aqua_Base_training.cam2.h1.0001-01-20-00000.nc",
    ],
    decode_times=False,
)

# Input and target must cover the same number of entries along axis 1.
assert test_input.shape[1] == test_target.shape[1]

#Creating mass weights
def createPressureGrid(h1Data):
    """Return the pressure difference between adjacent model levels.

    The pressure at every level is computed as
    ``hyam * 1e5 + NNPS * hybm`` and then differenced along the level axis.

    Parameters
    ----------
    h1Data : mapping
        Object indexable by variable name (e.g. an xarray Dataset) that
        provides ``"hyam"`` and ``"hybm"`` with shape ``(time, lev)`` or
        ``(lev,)``, and ``"NNPS"`` with shape ``(time, lat, lon)``.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(time, lev - 1, lat, lon)``.
    """
    hyam = np.asarray(h1Data["hyam"])
    hybm = np.asarray(h1Data["hybm"])
    ps = np.asarray(h1Data["NNPS"])

    # Per-level coefficients without a time axis apply to every time step.
    if hyam.ndim == 1:
        hyam = hyam[np.newaxis, :]
    if hybm.ndim == 1:
        hybm = hybm[np.newaxis, :]

    # Only the first len(ps) coefficient rows are paired with NNPS.
    n_times = ps.shape[0]
    hyam = hyam[:n_times]
    hybm = hybm[:n_times]

    # Broadcast (time, lev, 1, 1) against (time, 1, lat, lon) instead of
    # looping over every time step and grid point in Python.
    pressureGrid = (
        hyam[:, :, np.newaxis, np.newaxis] * 1e5
        + ps[:, np.newaxis, :, :] * hybm[:, :, np.newaxis, np.newaxis]
    )
    # Difference in float64 regardless of the input dtype.
    return np.diff(pressureGrid.astype(np.float64), axis=1)
# Average the level-to-level pressure differences over axis 0, then keep
# slice 11:29 of the leading (level-difference) axis.
pressures = createPressureGrid(spData).mean(axis=0)[11:29]
# Normalize so that all mass weights sum to one.
mass_weights = pressures / sum(pressures.ravel())

#Creating area weights
r = 6371  # sphere radius; presumably Earth's mean radius in km -- TODO confirm
def integrand(t):
    """Return sin(t), the integrand used for the latitude-band integral."""
    return math.sin(t)

def surfArea(lat1, lat2, lon1, lon2):
    """Return the area of a latitude/longitude cell on a sphere of radius ``r``.

    Parameters
    ----------
    lat1, lat2 : float
        Latitude bounds in degrees, in either order.
    lon1, lon2 : float
        Longitude bounds in degrees, in either order.

    Returns
    -------
    float
        ``r**2 * dlon * integral(sin)`` over the shifted latitude range.
    """
    # Shift latitudes by 90 degrees and order them BEFORE converting to
    # radians. Converting one bound first and then comparing it against the
    # other (still in degrees) collapses the interval to zero width whenever
    # the arguments arrive in descending order.
    lower, upper = sorted((lat1 + 90, lat2 + 90))
    lower = lower * math.pi / 180
    upper = upper * math.pi / 180
    lons = (max(lon1, lon2) - min(lon1, lon2)) * math.pi / 180
    # quad returns (value, estimated absolute error); only the value is used.
    band_integral, _ = integrate.quad(integrand, lower, upper)
    return lons * r * r * band_integral

# Longitudes are equidistant so we can simplify surfArea
def weight_area(lat1, lat2):
    """Return the relative area weight of the band between two latitudes.

    This is the latitude integral of ``sin`` over the band (latitudes
    shifted by 90 degrees), without the longitude and radius factors.

    Parameters
    ----------
    lat1, lat2 : float
        Latitude bounds in degrees, in either order.

    Returns
    -------
    float
        Value of the integral; the full range -90..90 gives 2.
    """
    # Order the shifted bounds in degrees first, then convert both to
    # radians; comparing a radian value against a degree value makes the
    # band collapse to zero width when the arguments are descending.
    lower, upper = sorted((lat1 + 90, lat2 + 90))
    lower = lower * math.pi / 180
    upper = upper * math.pi / 180
    # quad returns (value, estimated absolute error); only the value is used.
    weight, _ = integrate.quad(math.sin, lower, upper)
    return weight

lats = np.array(spData["lat"])
# The last cell reuses the first latitude's distance to the pole as its
# half-width, which is only valid when the grid is symmetric about the
# equator. Compare with a tolerance and use the last element rather than a
# hard-coded index so any grid size works.
assert np.isclose(90 + lats[0], 90 - lats[-1])
last_lat_mdiff = 90 + lats[0]
# Half the spacing to the next latitude gives each cell's far edge.
lat_mdiff = np.diff(lats) / 2
lat_buff = np.append(lat_mdiff, last_lat_mdiff)
lat_edges = lat_buff + lats
# Prepend the pole so lat_edges holds len(lats) + 1 cell boundaries.
lat_edges = np.append(-90, lat_edges)
area_weights = np.array([
    weight_area(edge_a, edge_b)
    for edge_a, edge_b in zip(lat_edges[:-1], lat_edges[1:])
])
# Shape (1, lat, 1) so the weights broadcast over the other two axes.
area_weights = area_weights[np.newaxis, :, np.newaxis]

# Combine the pressure-thickness and latitude-area weights by broadcasting,
# then rescale in place so the combined weights sum to one.
error_weights = pressures * area_weights
error_weights /= sum(error_weights.ravel())

In [None]:
# Inspect the dimensions of the loaded (time-truncated) test input array.
test_input.shape

In [None]:
# Inspect the dimensions of the loaded (time-truncated) test target array.
test_target.shape

# Functions for getting predictions and mean squared error

In [None]:
def get_prediction(proj_name, model_rank):
    """Load one saved model and return its prediction on the test input.

    The model file is ``model_path + '<proj_name>_model_<rank>.h5'`` with the
    rank zero-padded to three digits. The module-level ``nn_input`` is
    flattened to rows of 125 features, passed through the model, and the
    output is divided by ``outscale``.

    Returns an array of shape ``(60, timesteps, 64, 128)``.
    """
    model_file = model_path + '%s_model_%03d.h5' % (proj_name, model_rank)
    net = keras.models.load_model(model_file, compile=False)
    # One row per sample, 125 columns per row.
    samples = nn_input.reshape(125, -1).T
    # Back to outputs-first layout before applying the per-output scale.
    scaled = net.predict(samples).T / outscale[:, np.newaxis]
    return scaled.reshape(60, timesteps, 64, 128)

def squared_error(prediction, target):
    """Return the squared error averaged over axis 1, split into two halves.

    The elementwise squared difference is taken first; rows 0-29 and rows
    30-59 of the leading axis are then averaged over axis 1 separately.

    Returns
    -------
    tuple
        ``(se_T, se_Q)`` for rows 0-29 and rows 30-59 respectively.
    """
    residual = prediction - target
    squared = residual * residual
    first_half = squared[:30, :, :, :]
    second_half = squared[30:60, :, :, :]
    return np.mean(first_half, axis=1), np.mean(second_half, axis=1)

def weight_error(se):
    """Return entries 12-29 of ``se`` multiplied by ``error_weights``.

    ``error_weights`` is the module-level, normalized product of the area
    and pressure weights.
    """
    selected_levels = se[12:30]
    return selected_levels * error_weights

def root_error(wse):
    """Return the square root of the sum of all entries of ``wse``."""
    total = np.sum(wse)
    return total ** 0.5

def get_rmse(prediction, target):
    """Return the weighted root-mean-squared errors ``(rmse_T, rmse_Q)``.

    Each output of ``squared_error`` is weighted with ``weight_error`` and
    reduced to a single number with ``root_error``.
    """
    rmse_T, rmse_Q = (
        root_error(weight_error(se))
        for se in squared_error(prediction, target)
    )
    return rmse_T, rmse_Q

In [None]:
# Evaluate every saved model (ranks 1..num_models) against the test target.
# Each row of `rmse` holds [rmse_T, rmse_Q] for one model.
rmse = np.array([
    list(get_rmse(get_prediction(proj_name, model_rank), test_target))
    for model_rank in tqdm(range(1, num_models + 1))
])

# Save offline errors

In [None]:
# Output directory for the offline error table.
save_path = "offline_errors/"
# Store the errors in single precision.
with open(save_path + "rmse.npy", 'wb') as f:
    np.save(f, rmse.astype(np.float32))

In [None]:
# Completion marker printed once all preceding cells have run.
print("finished")

In [None]:
!pwd