# Importing packages

In [1]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import os
import re
from tqdm import tqdm

# Relative Humidity Conversion Functions

In [2]:
def eliq(T):
    """Saturation vapor pressure over liquid water as a polynomial fit.

    The fit is an 8th-degree polynomial in the offset ``T - 273.16``; the
    offset is clamped from below at -80 before evaluation and the polynomial
    value is multiplied by 100.

    Parameters
    ----------
    T : scalar or array-like
        Temperature (presumably in Kelvin, given the 273.16 reference).

    Returns
    -------
    Same shape as ``T``: 100 times the polynomial value (presumably a
    conversion from hPa to Pa -- TODO confirm the units of the fit).
    """
    referenceTemp = np.float32(273.16)
    lowestOffset = np.float32(-80.0)
    # Coefficients are ordered highest power first, as np.polyval expects.
    polyCoeffs = np.float32(np.array([
        -0.976195544e-15, -0.952447341e-13, 0.640689451e-10,
        0.206739458e-7, 0.302950461e-5, 0.264847430e-3,
        0.142986287e-1, 0.443987641, 6.11239921,
    ]))
    # Offsets below -80 all evaluate the polynomial at -80.
    clampedOffset = np.maximum(lowestOffset, T - referenceTemp)
    return np.float32(100.0) * np.polyval(polyCoeffs, clampedOffset)

def eice(T):
    """Saturation vapor pressure over ice, selected from three fits.

    * ``T > 273.15``: the liquid-water value ``eliq(T)`` is returned.
    * ``T <= 185``: a quadratic in the offset ``T - 273.16`` (floored at
      -100) is used.
    * otherwise: an 8th-degree polynomial in ``T - 273.16``.

    The two ice fits are multiplied by 100, like ``eliq``.

    Parameters
    ----------
    T : scalar or array-like
        Temperature (presumably in Kelvin).

    Returns
    -------
    numpy.ndarray with the shape of ``T``.
    """
    referenceTemp = np.float32(273.16)
    # Coefficients are ordered highest power first, as np.polyval expects.
    iceCoeffs = np.float32(np.array([
        0.252751365e-14, 0.146898966e-11, 0.385852041e-9,
        0.602588177e-7, 0.615021634e-5, 0.420895665e-3,
        0.188439774e-1, 0.503160820, 6.11147274,
    ]))
    # [0] liquid threshold, [1] cold threshold, [2] floor on the offset,
    # [3:6] constant / linear / quadratic terms of the cold-branch fit.
    limits = np.float32(np.array([273.15, 185, -100, 0.00763685, 0.000151069, 7.48215e-07]))

    offset = T - referenceTemp
    flooredOffset = np.maximum(limits[2], offset)

    coldBranch = np.float32(100.0) * (limits[3] + flooredOffset * (limits[4] + flooredOffset * limits[5]))
    polyBranch = np.float32(100.0) * np.polyval(iceCoeffs, offset)

    belowLiquidLimit = np.where(T <= limits[1], coldBranch, polyBranch)
    return np.where(T > limits[0], eliq(T), belowLiquidLimit)

def esat(T):
    """Saturation vapor pressure blended between the ice and liquid fits.

    * ``T > 273.16``: ``eliq(T)``.
    * ``T < 253.16``: ``eice(T)``.
    * in between: a linear mix whose liquid weight rises from 0 at 253.16
      to 1 at 273.16.

    Parameters
    ----------
    T : scalar or array-like
        Temperature (presumably in Kelvin).

    Returns
    -------
    numpy.ndarray with the shape of ``T``.
    """
    warmLimit = np.float32(273.16)
    coldLimit = np.float32(253.16)

    liquidValue = eliq(T)
    iceValue = eice(T)

    # Liquid fraction of the mix, clipped to [0, 1].
    rawWeight = (T - coldLimit) / (warmLimit - coldLimit)
    liquidWeight = np.maximum(np.float32(0.0), np.minimum(np.float32(1.0), rawWeight))
    mixedValue = liquidWeight * liquidValue + (1 - liquidWeight) * iceValue

    notWarm = np.where(T < coldLimit, iceValue, mixedValue)
    return np.where(T > warmLimit, liquidValue, notWarm)

def RH(T,qv,P0,PS,hyam,hybm):
    """Relative humidity from temperature, specific humidity and pressure.

    Pressure is rebuilt from the hybrid coefficients as
    ``P0 * hyam + PS[:, None] * hybm`` and the result is
    ``Rv * p * qv / (R * esat(T))``, evaluated in float32.

    Parameters
    ----------
    T, qv : array-like
        Temperature and specific humidity; they must broadcast against the
        rebuilt pressure array.
    P0 : scalar or array-like
        Reference pressure multiplying ``hyam``.
    PS : numpy.ndarray
        Surface pressure; a new axis is inserted at position 1 so that it
        broadcasts against ``hybm``.
    hyam, hybm : array-like
        Hybrid coefficients.

    Returns
    -------
    numpy.ndarray of relative humidity as a fraction (not a percentage).
    """
    # 287 / 461: presumably the gas constants of dry air and water vapor.
    dryConst = np.float32(287.0)
    vaporConst = np.float32(461.0)

    # Total pressure (Pa)
    pressure = np.float32(P0 * hyam + PS[:, None] * hybm)
    temperature = np.float32(T)
    humidity = np.float32(qv)

    return vaporConst*pressure*humidity/(dryConst*esat(temperature))

# Data Processing Functions (Part 1)

In [3]:
def doMonth(month):
    """Open every file of one month from the current working directory.

    A file is selected when its name contains ``"h1.0000-"`` followed by the
    month zero-padded to two digits. The directory is listed with
    ``os.listdir`` (hidden entries skipped, names sorted, as ``ls`` would
    report them) instead of the IPython ``!ls`` escape, so the function is
    plain Python and can run outside a notebook.

    Parameters
    ----------
    month : int or str
        Month number; converted with ``str`` and zero-padded.

    Returns
    -------
    The dataset returned by ``xr.open_mfdataset`` for the matching files.

    Raises
    ------
    FileNotFoundError
        If no file in the current directory matches the month.
    """
    tag = "h1.0000-" + str(month).zfill(2)
    datasets = sorted(name for name in os.listdir(".")
                      if not name.startswith(".") and tag in name)
    if not datasets:
        # Fail early with a message that says what was searched for.
        raise FileNotFoundError("no files containing '" + tag + "' in " + os.getcwd())
    return xr.open_mfdataset(datasets)

def makeSuffix(month):
    """Return the file-name suffix for a month.

    The suffix is an underscore followed by ``str(month)`` zero-padded to
    two characters, e.g. ``3 -> "_03"``.
    """
    return "_{}".format(str(month).zfill(2))

def _levelFirst(field):
    """Drop the first entry along axis 0 and move that axis to position 1."""
    return np.moveaxis(field[1:, :, :, :], 0, 1)


def saveNNInput(month):
    """Build the stacked array for one month and write it to disk.

    Steps:

    1. Open the month's files with ``doMonth`` and copy the needed variables
       into NumPy arrays; the dataset is closed as soon as that is done.
    2. Recompute relative humidity column by column with ``RH`` from NNTBP,
       NNQBP, NNPS, P0 and the hybrid coefficients.
    3. Drop one time step from every variable (the first for all variables
       except LHFLX and SHFLX, which drop the last), put the level axis
       first for the 4-D fields and give the surface fields a leading axis
       of length 1.
    4. Concatenate everything along the first axis, in the order nntbp,
       nnqbp, lhflx, shflx, ps, solin, newhum, oldhum, tphystnd, phq.
    5. Save the result as float32 to ``nnInput_MM.npy`` and append a text
       summary (humidity error statistics and shapes) to
       ``diagnostics_MM.txt``; both are written to the current directory.

    Parameters
    ----------
    month : int or str
        Month number, passed to ``doMonth`` and ``makeSuffix``.

    Returns
    -------
    None
    """
    suffix = makeSuffix(month)

    # Materialize everything inside the context manager so the underlying
    # files are released even if a later step raises.
    with doMonth(month) as spData:
        print("read in data")
        nntbp = spData["NNTBP"].values
        nnqbp = spData["NNQBP"].values
        p0 = spData["P0"].values
        nnps = spData["NNPS"].values
        hyam = spData["hyam"].values
        hybm = spData["hybm"].values
        relhum = spData["RELHUM"].values
        tphystnd = spData["TPHYSTND"].values
        phq = spData["PHQ"].values
        lhflxAll = spData["LHFLX"].values
        shflxAll = spData["SHFLX"].values
        solinAll = spData["SOLIN"].values
        humShape = tuple(spData[dim].shape[0] for dim in ("time", "lev", "lat", "lon"))

    # Collapse repeated P0 values to the distinct ones; np.unique also copes
    # with a 0-d P0 and returns a deterministic (sorted) order.
    p0 = np.unique(p0)
    print("loaded in data")

    newhum = np.zeros(humShape)
    print("starting for loop")
    for latIndex in tqdm(range(humShape[2])):
        for lonIndex in range(humShape[3]):
            # One (time, lev) column at a time, using the shared RH helper.
            newhum[:, :, latIndex, lonIndex] = RH(nntbp[:, :, latIndex, lonIndex],
                                                  nnqbp[:, :, latIndex, lonIndex],
                                                  p0,
                                                  nnps[:, latIndex, lonIndex],
                                                  hyam,
                                                  hybm)

    # Order matters: it fixes the row layout of the saved array.
    # LHFLX / SHFLX keep time steps 0..n-2; every other variable keeps 1..n-1.
    pieces = [
        ("nntbp", _levelFirst(nntbp)),
        ("nnqbp", _levelFirst(nnqbp)),
        ("lhflx", lhflxAll[np.newaxis, :-1, :, :]),
        ("shflx", shflxAll[np.newaxis, :-1, :, :]),
        ("ps", nnps[np.newaxis, 1:, :, :]),
        ("solin", solinAll[np.newaxis, 1:, :, :]),
        ("newhum", _levelFirst(newhum)),
        ("oldhum", _levelFirst(relhum)),
        ("tphystnd", _levelFirst(tphystnd)),
        ("phq", _levelFirst(phq)),
    ]
    for label, block in pieces:
        print(label)
        print(block.shape)

    nnInput = np.concatenate([block for _, block in pieces])
    print("nnInput")
    # A bare `nnInput.shape` displays nothing inside a function; print it.
    print(nnInput.shape)

    # RELHUM is compared after division by 100, i.e. as a fraction like newhum.
    byName = dict(pieces)
    errors = (byName["newhum"] - byName["oldhum"]/100).flatten()

    summary = ["Mean error: " + str(np.mean(errors)),
               "Variance: " + str(np.var(errors))]
    summary += [label + ".shape: " + str(block.shape) for label, block in pieces]
    summary.append("nnInput.shape: " + str(nnInput.shape))
    result = "\n".join(summary) + "\n"
    print(result)

    #added 32 bit fix
    nnInput = np.float32(nnInput)

    fileName = 'nnInput' + suffix + '.npy'
    with open(fileName, 'wb') as f:
        np.save(f, nnInput)

    # Appended, so repeated runs for the same month accumulate in one file.
    diagnostics = 'diagnostics' + suffix + '.txt'
    with open(diagnostics, 'a') as fp:
        fp.write(result)

# Data Processing Functions (Part 2)

In [None]:
path = "/ocean/projects/atm200007p/jlin96/longSPrun/"
# List `path` itself: the names below are prefixed with `path`, so listing the
# current working directory (what `!ls` did) only worked when the notebook
# happened to be started from inside `path`. Hidden entries are skipped and
# names sorted, matching what `ls` reports.
subFolders = sorted(x for x in os.listdir(path) if not x.startswith("."))
files = [path + x for x in subFolders if "nnInput" in x]
files

In [None]:
def sampleIndices(size, spacing, fixed = True):
    """Pick about ``size / spacing`` zero-based indices out of ``range(size)``.

    Parameters
    ----------
    size : int
        Length of the axis being sampled.
    spacing : number
        Target ratio between ``size`` and the number of indices returned;
        the count is ``round(size / spacing)``.
    fixed : bool, default True
        If True, return evenly spaced indices (always including the first
        and, when more than one index is requested, the last). If False,
        return a random subset drawn with ``np.random.shuffle`` (so it
        follows the global NumPy seed).

    Returns
    -------
    numpy.ndarray of ints when ``fixed`` is True, list of ints otherwise.
    The array always has an integer dtype, including when the count is 0,
    so it can be used directly as an index.
    """
    numIndices = int(np.round(size/spacing))
    if fixed:
        # linspace over 1..size, rounded, then shifted to zero-based.
        # astype(int) keeps an integer dtype even for an empty result.
        return np.round(np.linspace(1, size, numIndices)).astype(int) - 1
    indices = list(range(size))
    np.random.shuffle(indices)
    return indices[0:numIndices]

def shrinkArray(nnData, spacing):
    """Subsample the last axis and flatten everything but the first axis.

    Parameters
    ----------
    nnData : numpy.ndarray
        4-D array; the first axis is kept as the row axis.
    spacing : number
        Passed to ``sampleIndices`` to choose evenly spaced positions along
        the last axis.

    Returns
    -------
    2-D numpy.ndarray of shape ``(nnData.shape[0], -1)`` whose columns
    enumerate the remaining axes in Fortran (first-index-fastest) order.
    """
    keptColumns = sampleIndices(nnData.shape[3], spacing, True)
    nnData = nnData[:,:,:,keptColumns]
    # The row count comes from the array instead of a hard-coded 184, and a
    # Fortran-order reshape already reads the source in Fortran order, so the
    # separate ravel (an extra full copy) is not needed.
    return nnData.reshape(nnData.shape[0], -1, order = 'F')

def splitArray(nnData, variant):
    """Select the input rows for one variant and append the target rows.

    Parameters
    ----------
    nnData : numpy.ndarray
        2-D array with variables along the first axis.
    variant : int
        * 0 -- inputs are rows 0..63.
        * 1 -- inputs are rows 0..29, then rows 64..93, then rows 60..63.

    Returns
    -------
    numpy.ndarray: the selected input rows followed by rows 124 onward
    (the targets).

    Raises
    ------
    ValueError
        If ``variant`` is neither 0 nor 1.
    """
    if variant == 0:
        nnInput = nnData[0:64,:]
    elif variant == 1:
        nnInput = np.concatenate((nnData[:30,:], nnData[64:94,:], nnData[60:64,:]))
    else:
        # Previously an unknown variant surfaced as an UnboundLocalError.
        raise ValueError("variant must be 0 or 1, got " + repr(variant))
    # Rows 124 onward are the targets and are the same for both variants.
    return np.concatenate((nnInput, nnData[124:,:]))

def reorderArray(nnData):
    """Return the first 124 rows of ``nnData`` in canonical row order.

    Rows 0..59 and 64..123 keep their positions; the four single rows
    60..63 are emitted in the order 62, 63, 61, 60.
    """
    rowBlocks = (
        slice(0, 30),
        slice(30, 60),
        slice(62, 63),
        slice(63, 64),
        slice(61, 62),
        slice(60, 61),
        slice(64, 94),
        slice(94, 124),
    )
    return np.concatenate([nnData[rows, :] for rows in rowBlocks], axis = 0)

In [None]:
# Load every array listed in `files`, trim and subsample it, and join the
# results column-wise into `combinedData`.
datasets = []
for arr in tqdm(files):
    with open(arr, 'rb') as f:
        nnData = np.load(f)
        # Drop the last entry along axis 1 (presumably the time axis of the
        # saved nnInput arrays -- TODO confirm).
        nnData = nnData[:,:-1,:,:] # this was to account for the weird humidity error at the end.
    # Keep roughly every 5th position along the last axis and flatten to 2-D.
    datasets.append(shrinkArray(nnData, 5))
    # Release the full-size array before the next file is loaded.
    del nnData
# Rows stay aligned across files; columns from each file are appended.
combinedData = np.concatenate(datasets, axis = 1)

In [None]:
# Variant 0 of splitArray: inputs are rows 0..63, followed by the target rows.
arr_specific = splitArray(combinedData, 0)
# combinedData is deliberately not deleted here: the cell that builds
# arr_relative still reads it, and deleting it made a fresh Run All fail.

# The array saved below was never defined; reorderArray is the only
# reordering helper in this notebook and takes exactly the 124 rows that
# splitArray returns, so it is used to build it.
rearrangedSpecific = reorderArray(arr_specific)

path = "/ocean/projects/atm200007p/jlin96/nnIngredientFactory/preprocessing/ingredientsLong/"

with open(path + 'nnDataSpecific_long_5.npy', 'wb') as f:
    np.save(f, rearrangedSpecific)

In [None]:
# Variant 1 of splitArray: inputs are rows 0..29, 64..93 and 60..63, followed
# by the target rows. Requires combinedData to still exist when this cell runs.
arr_relative = splitArray(combinedData, 1)
# Last use of combinedData, so it can be released now.
del combinedData

# The array saved below was never defined; reorderArray is the only
# reordering helper in this notebook and takes exactly the 124 rows that
# splitArray returns, so it is used to build it.
rearrangedRelative = reorderArray(arr_relative)

path = "/ocean/projects/atm200007p/jlin96/nnIngredientFactory/preprocessing/ingredientsLong/"

with open(path + 'nnDataRelative_long_5.npy', 'wb') as f:
    np.save(f, rearrangedRelative)

# Data Processing Functions (Part 3)