# Clean NaN data out of vertical wind profiles

Some GCMs fill the ua/va/zg fields with NaN wherever a pressure level lies below the ground, i.e., where the level's pressure exceeds the local surface pressure, so the ua/va/zg vertical level sits below the surface elevation.

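This notebook replaces those NaN values in ua/va/zg with the corresponding surface fields (uas, vas, and orog, respectively). Levels whose geopotential height (zg) is below the orography are filled the same way, and the highest-pressure level is always set to the surface fields.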

GCMs with this problem:
- MRI-ESM2-0
- CESM2
- CESM2-WACCM

In [None]:
from concurrent.futures import ProcessPoolExecutor
import os
import time
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from PIL import Image
from rex import init_logger
import xarray as xr
from scipy.ndimage.filters import gaussian_filter
import scipy.stats as stats
import seaborn as sns
from scipy import ndimage

In [None]:
# Logging configuration for this notebook. ``log_file`` is None, so no log
# file path is handed to ``init_logger``.
log_file = None

# Initialize DEBUG-level logging for this module and for the 'sup3r' logger.
for _logger_name in (__name__, 'sup3r'):
    init_logger(_logger_name, log_level='DEBUG', log_file=log_file)

# Module-level logger used by the cells below.
logger = logging.getLogger(__name__)

In [None]:
def run(ua_fps, uas_fps, fp_out):
    """Fill NaN / below-ground values in ua/va/zg with surface fields.

    The 3D files are opened as one dataset and the 2D surface files as
    another. For every timestep and vertical level, any grid cell where
    ua, va or zg is NaN, or where zg is lower than orog, is overwritten with
    uas / vas / orog respectively. The level with the largest ``plev`` value
    is then set entirely to the surface fields. The modified dataset is
    written to ``fp_out``.

    Parameters
    ----------
    ua_fps : list of str
        File paths passed to ``xr.open_mfdataset``. The combined dataset
        must provide ``time``, ``plev``, ``orog``, ``ua``, ``va`` and ``zg``.
        The indexing below assumes ua/va/zg are ordered
        (time, level, <orog dims>) -- TODO confirm for new GCMs.
    uas_fps : list of str
        File paths passed to ``xr.open_mfdataset``. The combined dataset
        must provide ``time``, ``uas`` and ``vas``.
    fp_out : str
        Path of the netCDF file to write.

    Raises
    ------
    IndexError
        If a timestamp of the 3D dataset has no exact match in the time
        index of the surface dataset.
    """
    logger.info('Opening files...')
    logger.info(f'3D Vert File 0 = {os.path.basename(ua_fps[0])}')
    logger.info(f'2D Surf File 0 = {os.path.basename(uas_fps[0])}')

    # Context managers make sure the underlying files are closed even if
    # processing or writing fails part-way through.
    with xr.open_mfdataset(ua_fps) as handle, \
            xr.open_mfdataset(uas_fps, chunks={'time': 1}) as surface:

        logger.info('Loading data...')
        ti_surf = surface['time'].values
        ti = handle['time'].values
        orog = handle['orog'].values
        ua = handle['ua'].values
        va = handle['va'].values
        zg = handle['zg'].values
        plev = handle['plev'].values

        # Index of the level with the highest pressure value; it does not
        # depend on the timestep so it is computed once.
        id0 = np.argmax(plev)

        for idt, timestamp in enumerate(ti):
            # Surface data is matched to the 3D data by exact timestamp.
            matches = np.where(ti_surf == timestamp)[0]
            if len(matches) == 0:
                msg = (f'Timestamp {timestamp} from the 3D files was not '
                       'found in the time index of the 2D surface files.')
                logger.error(msg)
                raise IndexError(msg)
            idt_surf = matches[0]

            # Only this one timestep of surface data is loaded at a time.
            uas = surface['uas'][idt_surf].values
            vas = surface['vas'][idt_surf].values

            # Replace NaN data, and data at heights below the orography,
            # with the surface fields.
            for idz in range(ua.shape[1]):
                fill_mask = (np.isnan(ua[idt, idz])
                             | np.isnan(va[idt, idz])
                             | np.isnan(zg[idt, idz])
                             | (zg[idt, idz] < orog))
                # ua[idt, idz] is a view, so masked assignment writes
                # straight into the full array.
                ua[idt, idz][fill_mask] = uas[fill_mask]
                va[idt, idz][fill_mask] = vas[fill_mask]
                zg[idt, idz][fill_mask] = orog[fill_mask]

            # Force the highest-pressure level to the surface fields.
            ua[idt, id0] = uas
            va[idt, id0] = vas
            zg[idt, id0] = orog

            if (idt + 1) % 100 == 0:
                logger.info(f'Finished {idt + 1} out of {len(ti)} '
                            f'({timestamp}).')

        logger.info('Finished.')

        # Swap the cleaned arrays back in, keeping the original dimensions.
        handle['ua'] = (handle['ua'].dims, ua)
        handle['va'] = (handle['va'].dims, va)
        handle['zg'] = (handle['zg'].dims, zg)

        # Write while the source files are still open, since other
        # variables in the dataset may still be lazily loaded.
        handle.to_netcdf(fp_out)

    logger.info(f'Wrote output file: {fp_out}')

In [None]:
# Run configuration: which model / scenario to process and where the files
# live.
gcm = 'MRI-ESM2-0'
base_dir = f'/projects/alcaps/cmip6/{gcm}/'
scenario = 'historical'

# Start / end date strings found in the 3D (ua/va/zg) file names, one pair
# per file, plus the date range of the single 2D (uas/vas) file.
# NOTE(review): this first set is overwritten by the assignments right
# below, so it has no effect on the run.
ua_dates0 = ['20150101', '20250101', '20350101', '20450101', '20550101']
ua_dates1 = ['20241231', '20341231', '20441231', '20541231', '20641231']
uas_date0, uas_date1 = '20150101', '20641231'

# Date strings that are actually used.
# TODO(slater): enter the real date strings here
ua_dates0 = ['19800101']
ua_dates1 = ['20001231']
uas_date0, uas_date1 = '19800101', '20001231'

# One job per 3D date range; every job uses the same 2D surface files.
futures = []
with ProcessPoolExecutor() as exe:
    for ua_date0, ua_date1 in zip(ua_dates0, ua_dates1):
        fp_out = os.path.join(
            base_dir,
            f'wind_day_{gcm}_{scenario}_r1i1p1f1_gn_{ua_date0}_{ua_date1}.nc')

        # 3D wind / height files for this date range, followed by the
        # fixed orography file.
        ua_fps = [
            os.path.join(
                base_dir,
                f'{var}_day_{gcm}_{scenario}_r1i1p1f1_gn_'
                f'{ua_date0}-{ua_date1}.nc')
            for var in ('ua', 'va', 'zg')]
        ua_fps.append(
            os.path.join(base_dir, f'orog_fx_{gcm}_ssp585_r1i1p1f1_gn.nc'))

        # 2D surface wind files.
        uas_fps = [
            os.path.join(
                base_dir,
                f'{var}_day_{gcm}_{scenario}_r1i1p1f1_gn_'
                f'{uas_date0}-{uas_date1}.nc')
            for var in ('uas', 'vas')]

        futures.append(exe.submit(run, ua_fps, uas_fps, fp_out))

    # Wait for the jobs in submission order; result() re-raises any
    # exception raised inside a worker.
    n_futures = len(futures)
    for i, future in enumerate(futures):
        _ = future.result()
        logger.info(f'Completed {i+1} out of {n_futures} futures')