In [None]:
'''

This code is part of the SIPN2 project focused on improving sub-seasonal to seasonal predictions of Arctic Sea Ice. 
If you use this code for a publication or presentation, please cite the reference in the README.md on the
main page (https://github.com/NicWayand/ESIO). 

Questions or comments should be addressed to nicway@uw.edu

Copyright (c) 2018 Nic Wayand

GNU General Public License v3.0


'''

%matplotlib inline
%load_ext autoreload
%autoreload
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import numpy.ma as ma
import struct
import os
import xarray as xr
import glob
import datetime 
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import seaborn as sns
import pandas as pd

# ESIO Imports

from esio import EsioData as ed

import dask

from dask.distributed import Client
# Start a local dask distributed client; it is displayed in the next cell
# and shut down in the final cell.
client = Client()

# Plot style: seaborn "ticks" look with an enlarged "talk" context
sns.set_style('ticks')
sns.set_context("talk", font_scale=1.5, rc={"lines.linewidth": 2.5})

# ---- Project paths from the ESIO configuration ----
E = ed.EsioData.load()
data_dir, grid_dir = E.data_dir, E.grid_dir
fig_dir = os.path.join(E.fig_dir, 'model', 'extent_test')

In [None]:
# Display the dask distributed client created in the first cell.
client

In [None]:
# Run configuration
test_plots = False
runType = 'forecast'   # key into E.model[<model>] used when loading forecasts
variables = ['sic']
cvar = variables[0]    # the single variable analysed below

In [None]:
# Models to process. To run every configured model except piomas instead, use:
#   models_2_process = [x for x in E.model.keys() if x != 'piomas']
models_2_process = ['yopp']
models_2_process

In [None]:
# Observed fields (NSIDC_0081 product), concatenated along time
obs_files = E.obs['NSIDC_0081']['sipn_nc'] + '/*.nc'
da_obs_in = xr.open_mfdataset(obs_files, concat_dim='time', autoclose=True)

# Region mask used below to subset model and obs to a single region
region_file = os.path.join(E.grid_dir, 'sio_2016_mask_Update.nc')
ds_region = xr.open_mfdataset(region_file)

In [None]:
# Open each model's forecast files. Only the last model in models_2_process
# is left in ds_model / da_mod_in for the cells below.
for c_model in models_2_process:
    print(c_model)

    # Output temp dir
    out_dir = os.path.join(data_dir, 'model', c_model, 'forecast', 'agg_nc')
    os.makedirs(out_dir, exist_ok=True)

    # Load in Model
    model_forecast = os.path.join(E.model[c_model][runType]['sipn_nc'], '*.nc')
    ds_model = xr.open_mfdataset(model_forecast,
                                 chunks={'ensemble': 1, 'fore_time': 1, 'init_time': 1, 'nj': 304, 'ni': 448},
                                 concat_dim='init_time')
    # Rename the model grid dims to the names used by the obs / region mask (nj -> x, ni -> y).
    # Reassign the result: the inplace=True keyword is deprecated (and later removed) in xarray.
    ds_model = ds_model.rename({'nj': 'x', 'ni': 'y'})

    # Select by variable
    da_mod_in = ds_model[cvar]
    
    


In [None]:
# Inspect the forecast dataset opened for the last model in models_2_process.
ds_model

In [None]:
# One forecast to evaluate: first ensemble member, first initialisation.
# Index it by valid time (init_time + fore_time) so it can be matched to obs times.
c_ds = ds_model.isel(ensemble=0, init_time=0)
c_ds.coords['valid_time'] = c_ds.init_time + c_ds.fore_time
# swap_dims returns a new object; its inplace=True form is deprecated/removed in xarray.
c_ds = c_ds.swap_dims({'fore_time': 'valid_time'})
c_ds

In [None]:
# Forecast valid times that also have an observation, kept in the model's time order.
# np.isin does the membership test in one vectorised pass instead of an O(n*m) scan
# that re-read the obs time coordinate for every model time.
mod_times = c_ds.valid_time.values
comm_time = mod_times[np.isin(mod_times, da_obs_in.time.values)]
c_obs = da_obs_in.sic.sel(time=comm_time)
c_mod = c_ds.sic.sel(valid_time=comm_time).rename({'valid_time': 'time'})

In [None]:
# Mask by region: pick the region id used below and show its name
cR = 3
region_info = ds_region.sel(nregions=cR)
region_info.region_names.values

In [None]:
# Keep only grid cells in region cR (drop=True trims to the region's bounding box)
in_region = ds_region.mask == cR
c_mod_reg = c_mod.where(in_region, drop=True)
c_obs_reg = c_obs.where(in_region, drop=True)

In [None]:
# Observed change between time index 0 and time index 15 of the common period
obs_change = c_obs_reg.isel(time=15) - c_obs_reg.isel(time=0)
obs_change.plot()

In [None]:
# Model change between time index 0 and time index 15 of the common period
mod_change = c_mod_reg.isel(time=15) - c_mod_reg.isel(time=0)
mod_change.plot()

In [None]:
# Initial condition error: model vs observed value at the first common time.
# Put both fields into (y, x) order by name so each scatter point pairs the same
# grid cell, regardless of how each array's dims happen to be ordered
# (.T would simply reverse whatever order the model array has).
mod_t0 = c_mod_reg.isel(time=0).transpose('y', 'x')
obs_t0 = c_obs_reg.isel(time=0).transpose('y', 'x')
plt.figure()
plt.scatter(mod_t0.values, obs_t0.values)
plt.xlabel('Model SIC (first common time)')
plt.ylabel('Observed SIC (first common time)')
plt.axis('square');

In [None]:
# NOTE(review): disabled exploratory cell. As written, the inner loop selects
# x1/y1 (set in a later cell) instead of the loop variables x/y.
# plt.figure()
# for x in c_mod_reg.x:
#     for y in c_mod_reg.y:
#         tO = c_obs_reg.sel(x=x1, y=y1)
#         tM = c_mod_reg.sel(x=x1, y=y1)
#         if np.any(tO.notnull()):
#             tO.plot(color='b')
#         if np.any(tM.notnull()):
#             tM.plot(color='r')    

In [None]:
# Shape of the regional obs array. .shape reads metadata only, whereas .values.shape
# would first load the whole (presumably dask-backed, via open_mfdataset) array.
c_obs_reg.shape

In [None]:
# Flatten the regional fields into (time, point) so model and obs can be compared
# grid cell by grid cell. Both arrays are first transposed to the same named order
# (time, y, x): reshaping raw .values flattens in each array's own dim order, and the
# model grid (renamed from nj/ni) need not share the obs order, which would silently
# pair different grid cells.
obs_tyx = c_obs_reg.transpose('time', 'y', 'x')
mod_tyx = c_mod_reg.transpose('time', 'y', 'x')
assert mod_tyx.shape == obs_tyx.shape, 'model and obs regional grids differ in shape'
n_time, n_y, n_x = obs_tyx.shape
new_obs = xr.DataArray(obs_tyx.values.reshape(n_time, n_y * n_x),
                       dims=('time', 'point'), coords={'time': obs_tyx.time})
new_mod = xr.DataArray(mod_tyx.values.reshape(n_time, n_y * n_x),
                       dims=('time', 'point'), coords={'time': mod_tyx.time})
# Keep only times/points with observations; mask the model while new_obs is
# still unfiltered so both arrays are trimmed identically.
new_mod = new_mod.where(new_obs.notnull(), drop=True)
new_obs = new_obs.where(new_obs.notnull(), drop=True)


In [None]:
# Observed time series, one faint line per in-region grid cell
f, ax = plt.subplots(figsize=(10, 5))
ax.plot(new_obs.time, new_obs.values, color='blue', alpha=0.01)
ax.set(title='Observed', xlabel='Time', ylabel='SIC')
f.autofmt_xdate()

In [None]:
# Model time series, one faint line per in-region grid cell
f, ax = plt.subplots(figsize=(10, 5))
ax.plot(new_mod.time, new_mod.values, color='red', alpha=0.01)
ax.set(title='Model', xlabel='Time', ylabel='SIC')
f.autofmt_xdate()

In [None]:
# Model vs observed, one faint line per grid cell (column of the (time, point) arrays)
f, ax = plt.subplots(figsize=(10, 5))
ax.plot(new_mod.values, new_obs.values, color='red', alpha=0.05)
ax.set(xlabel='Model SIC', ylabel='Observed SIC');

In [None]:
# Model minus observed through time, one faint line per grid cell
f, ax = plt.subplots(figsize=(10, 5))
ax.plot(new_mod.time, (new_mod - new_obs).values, color='red', alpha=0.01)
ax.set(title='Model - Observed', xlabel='Time', ylabel='SIC error')
f.autofmt_xdate()

In [None]:
# Inspect the model array reshaped to (time, point) and filtered to where obs exist.
new_mod

In [None]:
# Inspect the obs array reshaped to (time, point) and filtered to non-null values.
new_obs

In [None]:
# Every model/obs pair in the region over all common times
plt.figure()
plt.scatter(new_mod.values, new_obs.values, alpha=0.05)
plt.xlabel('Model SIC')
plt.ylabel('Observed SIC')
plt.axis('square');

In [None]:
# Model/obs pairs at the first common time, after filtering to observed points
plt.figure()
plt.scatter(new_mod.isel(time=0).values, new_obs.isel(time=0).values, alpha=0.05)
plt.xlabel('Model SIC (first common time)')
plt.ylabel('Observed SIC (first common time)')
plt.axis('square');

In [None]:
# Observed (blue) vs model (red) time series at a single in-region grid cell
x1 = 60
y1 = 170
tO = c_obs_reg.sel(x=x1, y=y1)
tM = c_mod_reg.sel(x=x1, y=y1)
# Report explicitly when the chosen cell has no observations (e.g. it is masked out)
if not bool(tO.notnull().any()):
    print('No observations at x={}, y={}'.format(x1, y1))
plt.figure()
tO.plot(color='b', label='Observed')
tM.plot(color='r', label='Model')
plt.legend();

In [None]:
# Same grid cell (x1, y1 from the previous cell) over the full, unmasked fields.
# Select by label, as the previous cell does; bare x / y are undefined on a fresh kernel.
plt.figure()
c_mod.sel(x=x1, y=y1).plot(color='r', label='Model')
c_obs.sel(x=x1, y=y1).plot(color='b', label='Observed')
plt.legend();

In [None]:
# Shut down the dask distributed client started in the first cell.
client.close()