In [None]:
import matplotlib.pyplot as plt
import numpy as np
import re
import glob
#%matplotlib widget
import xarray as xr
import datetime

In [None]:
# Interactive (ipympl) backend: the dz_picker below relies on mouse-click callbacks,
# which need an interactive backend rather than static inline images.
%matplotlib widget

In [None]:
import pointCollection as pc

In [None]:
# Time conversions.  The ATL15 `time` variable arrives here as datetime64 values
# (decoded by xarray); the analysis works in decimal years, defined as
# 2018 + (elapsed time since 2018-01-01T00:00:00) / (365.25 days).
_ATL15_EPOCH = np.datetime64('2018-01-01T00:00:00', 'ns')
_ATL15_EPOCH_YEAR = 2018
_SECONDS_PER_YEAR = 365.25 * 24. * 3600.


def to_decimal_years(t):
    """Convert datetime64 values to decimal years (365.25-day years from 2018).

    Parameters
    ----------
    t : datetime64 scalar, numpy array, or xarray DataArray
        Times in any datetime64 unit.  Dividing by an explicit
        ``np.timedelta64(1, 's')`` makes the result independent of that unit
        (the old ``.astype(float)/1e9`` was only right for ns inputs).

    Returns
    -------
    float (or array / DataArray of floats, matching the input container)
    """
    elapsed_s = (t - _ATL15_EPOCH) / np.timedelta64(1, 's')
    return elapsed_s / _SECONDS_PER_YEAR + _ATL15_EPOCH_YEAR


def from_decimal_years(y):
    """Inverse of ``to_decimal_years``: decimal years -> datetime64[ns].

    Parameters
    ----------
    y : float or array-like of floats
        Decimal years (e.g. 2019.5).

    Returns
    -------
    numpy.datetime64 (scalar input) or ndarray of datetime64[ns].
    """
    # Subtract the epoch year first (the old code added 2018 years twice) and
    # build an explicit ns timedelta, so the unit matches the epoch's unit.
    # Replaces the removed np.int.
    elapsed_ns = (np.asarray(y, dtype=float) - _ATL15_EPOCH_YEAR) * _SECONDS_PER_YEAR * 1.e9
    return _ATL15_EPOCH + np.round(elapsed_ns).astype('int64').astype('timedelta64[ns]')

In [None]:
# Load the Scripps ice-shelf mask and turn it into a 0/1 float "grounded" mask:
# cells whose value is 0 become 1.0, every other value becomes 0.0.
# NOTE(review): 0 presumably means "not ice shelf", which may also include open ocean -- confirm.
grounded_mask = pc.grid.data().from_geotif(
    '../masks/Antarctic/scripps_antarctica_IceShelves1km_v1.tif'
)
grounded_mask.z = (grounded_mask.z == 0).astype(float)

In [None]:
# 2009 MOA image at 1 km (presumably the MODIS Mosaic of Antarctica), used as a grayscale basemap.
MOA = pc.grid.data().from_geotif(
    '/Users/ben/data/MOA/2009/moa_2009_1km.tif'
)

In [None]:
# ATL15 Antarctic 1-km product read by the cells below.
# '/Users/' matches the capitalization used for the MOA path; '/users/' only
# resolves on case-insensitive filesystems.
nc_file = '/Users/ben/Downloads/ATL15_AA_0310_01km_001_01.nc'

In [None]:
# read the data, and turn it into a pc.data() instance
# Read the ATL15 '/delta_h' group and repackage it as a pc.grid.data instance.
# dz is used with (time, y, x) ordering in later cells (e.g. dh.dz[ii, :, :] per epoch).
with xr.open_dataset(nc_file, group='/delta_h') as fh:
    dh = pc.grid.data().from_dict({
        'x': np.array(fh['x']),
        'y': np.array(fh['y']),
        'dz': np.array(fh['delta_h']),
        'sigma_dz': np.array(fh['delta_h_sigma']),
        # to_decimal_years returns an xarray DataArray for DataArray input;
        # convert so t is a plain ndarray like every other field.
        't': np.asarray(to_decimal_years(fh['time'])),
        'cell_area': np.array(fh['cell_area']),
    })
    print(fh)

In [None]:
# Attach the grounded mask, interpolated onto the dh grid, as field `grounded`.
dh.assign({
    'grounded': grounded_mask.interp(dh.x, dh.y, gridded=True),
})

In [None]:
# Read the ATL15 '/dhdt_lag8' group of the same file as a pc.grid.data instance.
with xr.open_dataset(nc_file, group='/dhdt_lag8') as fh:
    dhdt = pc.grid.data().from_dict({
        'x': np.array(fh['x']),
        'y': np.array(fh['y']),
        'dhdt': np.array(fh['dhdt']),
        'sigma_dhdt': np.array(fh['dhdt_sigma']),
        # convert the DataArray returned by to_decimal_years into a plain ndarray,
        # consistent with the other fields
        't': np.asarray(to_decimal_years(fh['time'])),
    })
# Map of dh/dt variability: standard deviation over time of successive-epoch dz differences.
# dh.dz is ordered (time, y, x), so the diff and the std must both run along axis 0;
# the result is a (y, x) grid that matches dh.x / dh.y.  (axis=2 differenced along x
# and produced a (time, y) array.)
# The factor 4 presumably converts per-epoch differences to per-year rates
# (quarterly epochs) -- TODO confirm.  The first two epochs are skipped, as before.
dhdt_sigma = pc.grid.data().from_dict({
    'x': dh.x,
    'y': dh.y,
    'z': np.std(np.diff(dh.dz[2:, :, :], axis=0), axis=0) * 4,
})

In [None]:
# dhdt.sigma_dhdt (time index 1) overlaid semi-transparently on the MOA basemap.
plt.figure()
MOA.show(cmap='gray', clim=[14000, 17000])
plt.imshow(dhdt.sigma_dhdt[1], clim=[0, 0.1], alpha=0.6,
           extent=dhdt.extent, origin='lower', interpolation='nearest')

In [None]:
# dhdt.dhdt (time index 1) over the MOA basemap, with a color range symmetric about zero.
plt.figure()
MOA.show(cmap='gray', clim=[14000, 17000])
plt.imshow(dhdt.dhdt[1], cmap='Spectral', clim=[-0.5, 0.5], alpha=0.5,
           extent=dhdt.extent, origin='lower', interpolation='nearest')

In [None]:
# True for grid cells whose distance from the coordinate origin exceeds 2 degrees of
# arc on a 6370 km sphere (presumably excludes the area around the pole -- TODO confirm).
xg, yg = np.meshgrid(dh.x, dh.y)
lat_mask = np.hypot(xg, yg) > (2*np.pi/180*6370e3)

In [None]:
# Per-epoch volume sum: dz * cell_area over cells that pass lat_mask and the grounded
# mask; NaNs are ignored.  One value per leading (time) slice of dh.dz.
V = np.array(
    [np.nansum(dz_epoch*dh.cell_area*lat_mask*dh.grounded) for dz_epoch in dh.dz],
    dtype=float,
)

In [None]:
# Time series of the summed volume, scaled down by 1e9.
plt.figure()
plt.plot(dh.t, V / 1.e9)

In [None]:
from scipy.interpolate import interpn

class dz_picker(object):
    """Interactive picker: click on a map to plot the time series of a grid field there.

    A ``button_press_event`` callback (this object) is connected to
    ``handles['figure']``.  Each click inside ``handles['map_ax']`` interpolates
    the ``field`` attribute -- and ``'sigma_' + field`` when the grid lists it in
    ``fields`` -- of every grid at the clicked (x, y) for all of its times, and
    plots the result (value, plus dashed +/- sigma lines) in ``handles['plot_ax']``.

    Parameters
    ----------
    handles : dict, optional
        Existing matplotlib handles with keys 'figure', 'map_ax', 'plot_ax'.
        When None or empty, a new figure with two side-by-side axes is created.
    field : str
        Name of the grid attribute to interpolate (e.g. 'dz').
    file_dict : dict, optional
        {label: filename}.  Used only when ``dz_dict`` is None: a box of width
        ``W`` around each click is read with ``pc.grid.data().from_h5``.
    dz_dict : dict, optional
        {label: grid} of already-loaded grids.  Grids are expected to be stored
        either as (y, x, t) or as (t, y, x).
    file_args : dict, optional
        Extra keyword arguments passed to ``from_h5``.
    W : float
        Width of the box read from each file around a click (map units).
    map_data : optional
        Object with a ``show(ax=..., **map_args)`` method, drawn on the map axes
        when a new UI is created.
    **map_args
        Passed to ``map_data.show``.

    Attributes
    ----------
    messages : list
        Debug log; exceptions raised while handling a click are appended here
        (matplotlib callbacks otherwise hide them).
    last_pt : list
        Clicked (x, y) tuples; starts as ``[[]]``.
    this_zi : dict
        Set after the first successful click:
        {label: {'t': times, 'z': values[, 'sigma_z': errors]}}.
    """

    def __init__(self, handles=None, field='dh', file_dict=None, dz_dict=None, file_args=None, W=2.e3, map_data=None, **map_args):
        self.dz_dict = dz_dict
        self.messages = [[]]
        self.last_pt = [[]]
        self.field = field

        if handles is not None and len(handles):
            self.handles = handles
        else:
            self.handles = {}
            self.__init_new_ui__(map_data, map_args)

        self.file_dict = file_dict
        self.file_args = {} if file_args is None else file_args
        self.W = W
        self.last_data = {}  # not used inside this class; kept for backward compatibility
        # matplotlib calls self(event) on every mouse press in the figure
        self.cid = self.handles['figure'].canvas.mpl_connect('button_press_event', self)

    def __init_new_ui__(self, map_data, map_args):
        """Create the figure / map+plot axes that are missing and draw map_data on the map."""
        if 'figure' not in self.handles:
            self.handles['figure'] = plt.figure()
        if 'map_ax' not in self.handles:
            self.handles['map_ax'], self.handles['plot_ax'] = \
                self.handles['figure'].subplots(1, 2)
        if map_data is not None:
            map_data.show(ax=self.handles['map_ax'], **map_args)

    def __interp__(self, dz_dict, xy0):
        """Interpolate every grid in dz_dict at point xy0 = (x, y) for all of its times.

        The result, {key: {'t', 'z'[, 'sigma_z']}}, is stored in ``self.this_zi``.
        """
        sigma_field = 'sigma_' + self.field
        out = {}
        for key, dz0 in dz_dict.items():
            tt = dz0.t
            x_col = np.full(np.shape(tt), xy0[0], dtype=float)
            y_col = np.full(np.shape(tt), xy0[1], dtype=float)
            # Grid axis order is inferred from the shape; ambiguous if ny == nx == nt.
            if dz0.shape == (dz0.y.size, dz0.x.size, dz0.t.size):
                self.messages += ['y, x, t']
                coords_in = (dz0.y, dz0.x, dz0.t)
                coords_out = (y_col, x_col, tt)
            else:
                self.messages += ['t, y, x']
                coords_in = (dz0.t, dz0.y, dz0.x)
                coords_out = (tt, y_col, x_col)
            out[key] = {'t': tt,
                        'z': interpn(coords_in, getattr(dz0, self.field), coords_out)}
            if sigma_field in dz0.fields:
                out[key]['sigma_z'] = interpn(coords_in, getattr(dz0, sigma_field), coords_out)
        self.this_zi = out

    def _get_dz_dict(self, xy0):
        """Return the grids to sample: dz_dict, or W-wide tiles read from file_dict around xy0."""
        if self.dz_dict is not None:
            return self.dz_dict
        if self.file_dict is not None:
            pad = np.array([-0.5, 0.5])*self.W
            return {key: pc.grid.data().from_h5(file, bounds=[xy0[0]+pad, xy0[1]+pad], **self.file_args)
                    for key, file in self.file_dict.items()}
        raise ValueError('dz_picker: neither dz_dict nor file_dict was provided')

    def _rescale_y(self):
        """Fit the plot axes' y-limits to the finite y data of all lines on it."""
        ax = self.handles['plot_ax']
        y_parts = [np.ravel(line.get_ydata()) for line in ax.lines]
        if not y_parts:
            return
        # concatenate (not np.r_ of a nested list) so lines of different lengths work
        y_vals = np.concatenate(y_parts).astype(float)
        y_vals = y_vals[np.isfinite(y_vals)]
        if y_vals.size:
            ax.set_ylim([y_vals.min(), y_vals.max()])

    def __call__(self, event):
        """matplotlib ``button_press_event`` handler: sample and plot at the clicked point."""
        plot_ax = self.handles['plot_ax']
        try:
            if event.inaxes is not self.handles['map_ax']:
                # Clicks on the profile axes (or outside any axes) are not map
                # coordinates; sampling there would only raise out-of-bounds errors.
                self.messages += ['dz_picker: last point not in tiles axis']
                return
            xy0 = (event.xdata, event.ydata)
            self.last_pt += [xy0]
            tx = 'xy =[%f,%f]' % xy0
            plot_ax.set_title(tx)
            self.__interp__(self._get_dz_dict(xy0), [event.xdata, event.ydata])
            for key, dzi in self.this_zi.items():
                self.messages += ['before line']
                self.messages += [key]
                h_line = plot_ax.plot(dzi['t'], dzi['z'], label=tx+' '+str(key))
                # __interp__ stores the interpolated error under 'sigma_z'
                if 'sigma_z' in dzi:
                    color = h_line[0].get_color()
                    self.messages += ['before sigma']
                    for sign in [-1, 1]:
                        plot_ax.plot(dzi['t'], dzi['z']+sign*dzi['sigma_z'], '--', color=color)
            self._rescale_y()
        except Exception as e:
            self.messages += [e]
            # report on the picker's own axes, not whatever plt.gca() happens to be
            plot_ax.set_title('ERROR (see "messages" )')
        plot_ax.figure.canvas.draw()

    def clear_lines(self):
        """Remove every line from the plot axes and redraw."""
        ax = self.handles['plot_ax']
        # Axes.lines is a read-only list view in recent matplotlib; Artist.remove() is supported.
        for line in list(ax.lines):
            line.remove()
        ax.figure.canvas.draw()

In [None]:
# Picker on the dh grid, with dhdt_sigma as the clickable map and a gray axes background.
dzp = dz_picker(
    field='dz',
    dz_dict={'dz': dh},
    map_data=dhdt_sigma,
    cmap='magma',
    clim=[0, 1],
)
dzp.handles['map_ax'].set_facecolor('gray')


In [None]:
# Inspect the picker state.  `this_zi` only exists after a point has been clicked
# on the map, so guard it to keep Restart & Run All working.
if hasattr(dzp, 'this_zi'):
    display(dzp.this_zi['dz']['z'])
dzp.messages

In [None]:
# Remove all profile lines from the picker's plot axes and redraw.
dzp.clear_lines()

In [None]:
# Cross-check the picker: interpolate dh.dz, ordered (t, y, x), at the most recently
# clicked point.  dzp.last_pt starts as [[]], so before any click there is no point to
# sample; skip the check instead of raising IndexError during Restart & Run All.
last_xy = dzp.last_pt[-1]
if len(last_xy) == 2:
    coords_out = (dh.t, np.zeros_like(dh.t)+last_xy[1], np.zeros_like(dh.t)+last_xy[0])
    coords_in = (dh.t, dh.y, dh.x)
    dz_at_last_pt = interpn(coords_in, np.array(dh.dz), coords_out)
else:
    dz_at_last_pt = None
dz_at_last_pt