In [None]:
'''

This code is part of the SIPN2 project focused on improving sub-seasonal to seasonal predictions of Arctic Sea Ice. 
If you use this code for a publication or presentation, please cite the reference in the README.md on the
main page (https://github.com/NicWayand/ESIO). 

Questions or comments should be addressed to nicway@uw.edu

Copyright (c) 2018 Nic Wayand

GNU General Public License v3.0


'''
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
import pandas as pd
import os
import xarray as xr
import glob
# import loadobservations as lo
from esio import import_data
from esio import metrics
from esio import EsioData as ed
from esio import ice_plot
import datetime
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import dask
import timeit
import xesmf as xe


# Directories from the ESIO configuration
E = ed.EsioData.load()
data_dir = E.obs_dir  # root of the observation data

# Products to import (each has <data_dir>/<product>/native and /sipn_nc sub-dirs)
product_list = ['iceBridgeQuickLook']

# Flags
# UpdateAll: True -> re-import / re-grid every file; False -> only files without existing output
UpdateAll = True

In [None]:
# Use dask's threaded scheduler instead of the default one;
# multi-threading was found to be fastest for this script.
dask.config.set({'scheduler': 'threads'})

In [None]:
# xESMF regridding method used below. Other options include:
# 'bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s', 'patch'
method = 'conservative_normed'
# UpdateAll is configured only in the first (configuration) cell; re-assigning it
# here used to silently override that setting.

In [None]:
# Convert each native IceBridge Quick Look text file of every product to netCDF.
for c_product in product_list:
    print('Importing ', c_product, '...')
    start_pt = 0  # Starting point index handed to the loader; carried across files

    native_dir = os.path.join(data_dir, c_product, 'native')
    nc_dir = os.path.join(data_dir, c_product, 'sipn_nc')
    # Make sure the output directory exists before writing into it.
    os.makedirs(nc_dir, exist_ok=True)

    # Glob with full (escaped) paths instead of os.chdir(): changing the working
    # directory would silently leak into every later cell of the notebook.
    native_files = sorted(glob.glob(os.path.join(glob.escape(native_dir), '*.txt')))
    nc_files = sorted(glob.glob(os.path.join(glob.escape(nc_dir), '*.nc')))
    # Bare file names without extension, in the same (sorted) order as above
    native_names = [os.path.basename(x).split('.txt')[0] for x in native_files]
    nc_names = [os.path.basename(x).split('.nc')[0] for x in nc_files]

    if UpdateAll:
        new_files = native_names
        print('Updating all ', len(native_files), ' files...')
    else:
        # Only native files that do not yet have a matching .nc output
        new_files = np.setdiff1d(native_names, nc_names)
        print('Found ', len(new_files), ' new files to import...')

    # Loop through each file
    for nf in new_files:
        print(nf)

        # Load in
        ds = import_data.load_1_iceBridgeQL(filein=os.path.join(native_dir, nf + '.txt'),
                                            start_pt=start_pt)

        # Continue point numbering in the next file from this file's last point.
        # NOTE(review): this hands over the last *used* index rather than last+1;
        # confirm load_1_iceBridgeQL does not reuse it (duplicate 'point' labels).
        start_pt = ds.point.values[-1]

        # Save to netcdf file
        ds.to_netcdf(os.path.join(nc_dir, nf + '.nc'))
        ds = None

    # For each Product
    print("Finished ", c_product)
    print("")

### Grid point data

In [None]:
FB_unc_threshold: float = 0.15  # max accepted freeboard uncertainty, in metres

In [None]:
# Scratch calculation (0.15 scaled by 10); its purpose is not documented — TODO confirm or remove
10 * 0.15

In [None]:
# IceBridge Quick Look: open all converted point files as one dataset along 'point'
ds_all = xr.open_mfdataset(os.path.join(E.obs_dir, 'iceBridgeQuickLook', 'sipn_nc', '*.nc'), concat_dim='point')
# set_coords returns a new object; its inplace= keyword is deprecated (removed in newer xarray)
ds_all = ds_all.set_coords(['date', 'lat', 'lon'])
# Shift lon to -180 to 180 space
ds_all['lon'] = ((ds_all.lon + 180) % 360) - 180
# Remove points where lat is greater than 90 or less than 50 (assumed recording errors),
# in a single pass rather than two successive where/drop calls
ds_all = ds_all.where((ds_all.lat <= 90) & (ds_all.lat >= 50), drop=True)

ds_all

In [None]:
# Mean thickness uncertainty (hi_unc) of flights after 2018-01-01 (~0.8 m when last run)
ds_all.where(ds_all['date'] > np.datetime64('2018-01-01'), drop=True)['hi_unc'].mean().values

In [None]:
# All points flown after 2018-01-01
ds_all.where(ds_all['date'] > np.datetime64('2018-01-01'), drop=True)

In [None]:
# plt.plot(ds_all.hi_unc.values, ds_all.hi.values,'k*')

In [None]:
# ds_all.hi_unc.plot()
# ds_all.hi.plot()

In [None]:
# ds_all.mean_fb.plot()
# ds_all.fb_unc.plot()


In [None]:
# ds_all.fb_unc.plot.hist(bins=100);

In [None]:
# Mask sea-ice thickness (hi) where the freeboard uncertainty (fb_unc) exceeds
# FB_unc_threshold (m). Masked values become NaN and are dropped before gridding.
ds_all['hi'] = ds_all['hi'].where(ds_all['fb_unc'] <= FB_unc_threshold)

In [None]:
# Get target grid (the NSIDC_0051 grid used for the observations)
stero_grid_file = E.obs['NSIDC_0051']['grid']
obs_grid = import_data.load_grid_info(stero_grid_file, model='NSIDC')
# Clip cell-corner latitudes to at most 90: the grid file contains values such as
# 90.000001, presumably out of range for the regridder. (Only the upper bound is clipped.)
obs_grid['lat_b'] = obs_grid.lat_b.where(obs_grid.lat_b < 90, other=90)
obs_grid

In [None]:
# Distinct flight dates, newest first
unique_dates = sorted(set(ds_all['date'].values), reverse=True)
print(len(unique_dates))

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
test_plots: bool = False  # set True to draw diagnostic figures for every regridded date

In [None]:
# Pull the (dask-backed) point data into memory; load() fills ds_all in place and returns it
ds_all = ds_all.load()
ds_all

In [None]:
# xESMF regridding of IceBridge point thickness onto obs_grid, one output file per
# flight date: points -> regular ~1 km lat/lon "fine" grid (cell means) -> obs_grid.

cvar = 'hi'
m_res = 1000  # fine-grid spacing in metres (converted to degrees at the mean latitude)
offset = 10   # shift added before regridding so empty target cells can be masked afterwards


def _fine_axis(vmin, vmax, res):
    """Return (centres, bounds) of a regular 1-D axis from vmin up to vmax with spacing res.

    Bounds are derived from the centres, so there is always exactly one more bound
    than centre (needed for conservative regridding). If vmin == vmax (e.g. a single
    point) a single centre at vmin is returned instead of an empty axis.
    """
    centres = np.arange(vmin, vmax, res)
    if centres.size == 0:
        centres = np.array([vmin], dtype=float)
    bounds = np.append(centres - res / 2, centres[-1] + res / 2)
    return centres, bounds


def _build_fine_grid(pts, var, res_m):
    """Build an all-NaN regular lat/lon grid (dims y, x; corners y_b, x_b) covering pts.

    pts is the point DataArray with 'lat'/'lon' coordinates. The degree spacing is
    res_m divided by the metres per degree of longitude at the mean latitude of pts;
    the same degree spacing is used for latitude.
    """
    deg_res = np.cos(np.deg2rad(pts.lat.mean().values)) * 111.321 * 1000
    tar_res = res_m / deg_res

    lat_arr, lat_arr_b = _fine_axis(pts.lat.min().values, pts.lat.max().values, tar_res)
    lon_arr, lon_arr_b = _fine_axis(pts.lon.min().values, pts.lon.max().values, tar_res)

    lonmesh, latmesh = np.meshgrid(lon_arr, lat_arr)
    lonmesh_b, latmesh_b = np.meshgrid(lon_arr_b, lat_arr_b)

    ds = xr.Dataset({var: (('y', 'x'), np.full(latmesh.shape, np.nan))})
    ds.coords['lat'] = (('y', 'x'), latmesh)
    ds.coords['lon'] = (('y', 'x'), lonmesh)
    ds.coords['y'] = ('y', np.arange(lat_arr.size))
    ds.coords['x'] = ('x', np.arange(lon_arr.size))
    ds.coords['lat_b'] = (('y_b', 'x_b'), latmesh_b)
    ds.coords['lon_b'] = (('y_b', 'x_b'), lonmesh_b)
    ds.coords['y_b'] = ('y_b', np.arange(lat_arr_b.size))
    ds.coords['x_b'] = ('x_b', np.arange(lon_arr_b.size))
    # 1-D centre coordinates used to find each point's nearest cell
    ds.coords['lat_f'] = ('y', lat_arr)
    ds.coords['lon_f'] = ('x', lon_arr)
    return ds


def _grid_point_means(ds_fine, pts, var):
    """Set ds_fine[var] to the mean of all pts values whose nearest cell it is.

    Cells that receive no points stay NaN. Vectorized replacement for a per-point
    Python loop; the accumulation order (and hence the sums) is unchanged.
    """
    nearest = ds_fine.swap_dims({'x': 'lon_f', 'y': 'lat_f'}).sel(
        lat_f=pts.lat, lon_f=pts.lon, method='nearest')
    iy = nearest['y'].values
    ix = nearest['x'].values

    sums = np.zeros(ds_fine[var].shape)
    counts = np.zeros(ds_fine[var].shape)
    # np.add.at accumulates repeated (iy, ix) pairs instead of overwriting them
    np.add.at(sums, (iy, ix), pts.values)
    np.add.at(counts, (iy, ix), 1)
    with np.errstate(invalid='ignore', divide='ignore'):
        means = np.where(counts > 0, sums / counts, np.nan)
    ds_fine[var] = (('y', 'x'), means)
    return ds_fine


out_dir = os.path.join(E.obs_dir, 'iceBridgeQuickLook', 'sipn_nc_grid')
os.makedirs(out_dir, exist_ok=True)  # to_netcdf below fails if the directory is missing

for cdate in unique_dates:
    print(cdate)

    # Output file path
    f_o = os.path.join(out_dir, 'IB_' + pd.to_datetime(cdate).strftime('%Y-%m-%d') + '.nc')

    # Check if we already imported it
    if not UpdateAll and os.path.isfile(f_o):
        print("Skipping ", os.path.basename(f_o), " already imported.")
        continue  # Skip, file already imported

    # Grab icebridge data from current date (what we want to grid), dropping masked points
    X1 = ds_all[cvar].where(ds_all.date == cdate, drop=True)
    X2 = X1.where(X1.notnull(), drop=True)
    if X2.point.size == 0:
        continue

    # Build the empty fine grid and add point data to it
    ds_fine = _build_fine_grid(X2, cvar, m_res)
    print("Found", X2.point.size, "points.")
    print("Adding to fine grid...")
    start_time_cmod = timeit.default_timer()
    ds_fine = _grid_point_means(ds_fine, X2, cvar)
    print("Took ", (timeit.default_timer() - start_time_cmod) / X2.point.size, " seconds / point.")

    # Source mask: True only where the fine grid received data
    ds_fine['mask'] = ds_fine[cvar].notnull()

    # Check ds_fine still contains some non-NaN data
    if ds_fine[cvar].notnull().sum() == 0:
        raise ValueError("No data left in ds_fine!")

    # Regrid. reuse_weights must be the boolean False: the string 'False' is truthy
    # and would let xESMF reuse a stale weight file that happens to share the name.
    regridder = xe.Regridder(ds_fine, obs_grid, method, periodic=False,
                             reuse_weights=False)
    try:
        ds_coarse = regridder(ds_fine[cvar] + offset)
    finally:
        # Remove the weight file even if regridding fails
        regridder.clean_weight_file()
    # Cells that come back below offset (presumably no source coverage) are masked,
    # then the shift is undone
    ds_coarse = ds_coarse.where(ds_coarse >= offset) - offset

    # Add info
    ds_coarse.coords['time'] = cdate

    # Save to file
    ds_coarse.to_netcdf(f_o)
    print("Saved", f_o)

    if test_plots:
        plt.figure()
        X2.plot.hist(bins=100, color='k', alpha=0.3, label='Point Raw')
        ds_fine[cvar].plot.hist(bins=100, color='r', alpha=0.3, label='1km gridded mean')
        ds_coarse.plot.hist(bins=100, color='g', alpha=0.3, label='25km gridded mean')
        plt.legend()

        plt.figure()
        plt.plot(X2.point, X2)

        # Raw points on a polar map
        plt.figure(figsize=(10, 10))
        ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=-45))
        ax.coastlines(linewidth=0.75, color='black', resolution='50m')
        plt.scatter(X2.lon, X2.lat, c=X2, transform=ccrs.PlateCarree())
        plt.colorbar()

        # Fine grid on a polar map
        plt.figure(figsize=(10, 10))
        ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=-45))
        ax.coastlines(linewidth=0.75, color='black', resolution='50m')
        ds_fine[cvar].plot(x='lon', y='lat', transform=ccrs.PlateCarree())

        # Regridded (coarse) field, zoomed to the fine-grid extent
        plt.figure(figsize=(10, 10))
        ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=-45))
        ax.coastlines(linewidth=0.75, color='black', resolution='50m')
        ds_coarse.plot(x='lon', y='lat', transform=ccrs.PlateCarree())
        ax.set_extent([ds_fine.lon.min().values,
                       ds_fine.lon.max().values,
                       ds_fine.lat.min().values,
                       ds_fine.lat.max().values])



In [None]:
# ## Going from point to fine grid







# # # Check lat/lon bounds
# # f = plt.figure(figsize=(10, 10))
# # ax = plt.axes(projection=ccrs.PlateCarree())
# # ax.coastlines(linewidth=0.75, color='black', resolution='50m')
# # ax.scatter(ds_fine.lon[0:10,0:10],ds_fine.lat[0:10,0:10],color='k', transform=ccrs.PlateCarree())
# # ax.scatter(ds_fine.lon_b[0:10,0:10],ds_fine.lat_b[0:10,0:10],color='r', transform=ccrs.PlateCarree())

# # # ax.scatter(ds_fine.lon[0:100:ds_fine.lon.size].values,
# # #            ds_fine.lat[0:100:ds_fine.lat.size].values,color='k', 
# # #            transform=ccrs.PlateCarree())

# # # Compare target and source grids
# # # Check lat/lon bounds
# # f = plt.figure(figsize=(10, 10))
# # ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=-45))
# # ax.coastlines(linewidth=0.75, color='black', resolution='50m')
# # ax.scatter(ds_fine.lon,ds_fine.lat,color='k',marker='.', transform=ccrs.PlateCarree())
# # ax.scatter(obs_grid.lon,obs_grid.lat,color='r', transform=ccrs.PlateCarree())
# # ax.scatter(ds_fine.lon,ds_fine.lat,color='k',marker='.', transform=ccrs.PlateCarree())






# # # Save to bucket

# # ds_fine.to_netcdf('/home/disk/sipn/nicway/data/temp/IBdata/ds_fine.nc')
# # obs_grid.to_netcdf('/home/disk/sipn/nicway/data/temp/IBdata/ds_target.nc')

# # ds_fine.to_zarr('/home/disk/sipn/nicway/data/temp/IBdata/fine.zarr', mode='w')
# # obs_grid.to_zarr('/home/disk/sipn/nicway/data/temp/IBdata/target.zarr', mode='w')

# ds_fine.hi.plot.hist(bins=100);







# f = plt.figure(figsize=(10, 10))
# ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=-45))
# ax.coastlines(linewidth=0.75, color='black', resolution='50m')
# ds_fine[cvar].plot(x='lon',y='lat',transform=ccrs.PlateCarree(),vmin=0,vmax=1)

# f = plt.figure(figsize=(10, 10))
# ax = plt.axes(projection=ccrs.NorthPolarStereo(central_longitude=-45))
# ax.coastlines(linewidth=0.75, color='black', resolution='50m')
# #ax.gridlines(crs=ccrs.PlateCarree(), linestyle='-')
# ds_coarse.plot(x='lon',y='lat',transform=ccrs.PlateCarree(),vmin=0,vmax=1)
# #ds_fine[cvar].plot(x='lon',y='lat',transform=ccrs.PlateCarree(), add_colorbar=False)

# ax.set_extent([ds_fine.lon.min().values,
#                ds_fine.lon.max().values,
#                ds_fine.lat.min().values,
#                ds_fine.lat.max().values])
# # ax.scatter(ds_fine.lon[0:100:ds_fine.lon.size].values,
# #            ds_fine.lat[0:100:ds_fine.lat.size].values,color='k', 
# #            transform=ccrs.PlateCarree())





# xr.exit()



# # LAT_STR = 'lat'
# # LON_STR = 'lon'
# # LAT_BOUNDS_STR = 'lat_b'
# # LON_BOUNDS_STR = 'lon_b'
# # X_STR = 'x'
# # Y_STR = 'y'
# # X_BOUNDS_STR = 'x_b'
# # Y_BOUNDS_STR = 'y_b'


# # def add_lat_lon_bounds(arr, lat_str=LAT_STR, lon_str=LON_STR,
# #                        lat_bounds_str=LAT_BOUNDS_STR,
# #                        lon_bounds_str=LON_BOUNDS_STR,
# #                        lat_min=-90., lat_max=90.,
# #                        lon_min=0., lon_max=360.):
# #     """Add bounding arrays to lat and lon arrays."""
# #     lon_vals = arr[lon_str].values
# #     lon_bounds_values = 0.5*(lon_vals[:-1] + lon_vals[1:])
# #     lon_bounds = np.concatenate([[lon_min], lon_bounds_values, [lon_max]])
    
# #     lat_vals = arr[lat_str].values
# #     lat_bounds_values = 0.5*(lat_vals[:-1] + lat_vals[1:])
# #     lat_bounds = np.concatenate([[lat_min], lat_bounds_values, [lat_max]])
    
# #     ds = arr.to_dataset()
# #     ds.coords[lon_bounds_str] = xr.DataArray(lon_bounds, dims=[X_BOUNDS_STR])
# #     ds.coords[lat_bounds_str] = xr.DataArray(lat_bounds, dims=[Y_BOUNDS_STR])
    
# #     return ds

# # ds_fine_4_xesmf_bnds = add_lat_lon_bounds(ds_fine[cvar], 
# #                        lat_str=LAT_STR, lon_str=LON_STR,
# #                        lat_bounds_str=LAT_BOUNDS_STR,
# #                        lon_bounds_str=LON_BOUNDS_STR,
# #                        lat_min=ds_fine.lat.min().values, 
# #                        lat_max=ds_fine.lat.max().values, 
# #                        lon_min=ds_fine.lon.min().values, 
# #                        lon_max=ds_fine.lon.min().values)
# # ds_fine_4_xesmf_bnds                       







# ds_fine.hi.plot.hist(alpha=0.5,label='gridded');
# X2.plot.hist(alpha=0.5,label='point data')
# plt.legend()



# f = plt.figure(figsize=(10, 10))
# ax = plt.axes(projection=ccrs.PlateCarree())
# ax.coastlines(linewidth=0.75, color='black', resolution='50m')
# #ax.gridlines(crs=ccrs.PlateCarree(), linestyle='-')
# ax.scatter(ds_fine_lon[0:3893:100, 0:14527:100].values,
#            ds_fine_lat[0:3893:100, 0:14527:100].values,color='k', 
#            transform=ccrs.PlateCarree())
# plt.plot(X2.lon,X2.lat,'r*')

# # First grid point data

# import collocate

# pt_lat = obs_grid.stack(gridcell=('ni', 'nj')).rename({'lat':'latitude','lon':'longitude'}).latitude
# pt_lon = obs_grid.stack(gridcell=('ni', 'nj')).rename({'lat':'latitude','lon':'longitude'}).longitude

# da_pt = xr.DataArray(np.zeros(pt_lat.size), dims=('gridcell'), coords = {'latitude':pt_lat, 'longitude':pt_lon})
# da_pt

# ds_fine







# ds_IN.hi.to_dataframe('var')

# ds_IN = ds_fine.rename({'lat':'latitude','lon':'longitude'}).swap_dims({'latitude':'lat_i','longitude':'lon_i'})

# # (Points, grid, distance)
# IB_mean = collocate.collocate(da_pt, 
#                               ds_IN,
#                               h_sep=25)

# IB_mean





# xr.exit()





# ds_all.where(ds_all.hi.notnull(), drop=True).date.plot()

# ## Test plots below
# # Get sea ice thickness data


# # ds_all = xr.open_mfdataset(nc_dir+'/*.nc', concat_dim='point')

# # Due to file name change, point is not in time order
# ds_all.date.plot()

# ds_all.hi.mean().values

# ds_all.hi.plot.hist(bins=100);

# ds_all.sd.plot.hist(bins=100);

# plt.figure(figsize=(20,10))
# ds_all.hi.plot(color='k')

# ds_all.date.plot()

# plt.figure(figsize=(10,10))
# plt.plot(ds_all.hi.values, ds_all.date.values,'k*')

# (f, ax) = ice_plot.polar_axis()
# f.set_size_inches(20,10)
# plt.scatter(ds_all.lon.values, ds_all.lat.values, transform=ccrs.PlateCarree(), c='r', marker='o')