In [None]:
import matplotlib.pyplot as plt
import numpy as np
import glob
import h5py
import os
import re
import pointCollection as pc
# Tile file names look like E<east>_N<north>.h5 with (possibly signed) integer values.
# The escaped "." and the "$" anchor stop stray characters or directory names from
# being captured; apply it to os.path.basename(path) to be safe.
tile_re=re.compile(r'E([-+]?\d+)_N([-+]?\d+)\.h5$')


In [None]:
# Interactive (ipympl) backend: tile_picker / dz_picker below rely on live mouse-button callbacks.
%matplotlib widget

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np
import glob
import h5py
import os
import re
import pointCollection as pc

class tile_picker(object):
    """Interactive picker that maps a click on a tile-index map to a tile file.

    Files in ``thedir`` named ``E<east>_N<north>.h5`` are indexed by their
    (east, north) values times 1000. A quick click (press and release within
    ``max_click_time`` seconds) inside ``handles['tiles_ax']`` selects a tile
    and stores its path in ``self.last_file``.

    Parameters
    ----------
    thedir : str
        Directory (glob wildcards allowed) containing the tile files.
    handles : dict, optional
        Existing matplotlib handles with at least 'figure' and 'tiles_ax'.
        If None or empty, a new figure with a 'tiles_ax' and a 'messages'
        axis is created and the tile centers are plotted on it.
    W : float
        Clicks are snapped to the nearest multiple of W/2 before the lookup
        (presumably half the tile width -- TODO confirm).
    map_data : object, optional
        Background map drawn with ``map_data.show(ax=..., **map_args)``.
    **map_args
        Extra keyword arguments forwarded to ``map_data.show``.

    Errors raised while handling a click are appended to ``self.messages``
    and the tiles-axis title is set to 'ERROR'.
    """

    # Matched against the basename only and anchored, so letters in directory
    # names or stray characters before "h5" cannot be captured.
    _tile_re = re.compile(r'E([-+]?\d+)_N([-+]?\d+)\.h5$')

    def __init__(self, thedir, handles=None, W=8.e4, map_data=None, **map_args):
        self.xy_file_dict = {}
        for ff in glob.glob(os.path.join(thedir, '*.h5')):
            match = self._tile_re.search(os.path.basename(ff))
            if match is None:
                # not a tile file: skip it instead of failing the whole constructor
                continue
            # E/N values are scaled by 1000 (presumably km -> map meters; TODO confirm)
            xy = tuple(1000*np.array([int(val) for val in match.groups()]))
            self.xy_file_dict[xy] = ff
        # reshape keeps a (0, 2) shape when no tiles were found, so column indexing still works
        self.xy_tiles = np.array(list(self.xy_file_dict.keys())).reshape(-1, 2)

        if handles is not None and len(handles):
            self.handles = handles
        else:
            self.handles = {}
            self.__init_new_ui__(map_data, map_args)
        self.messages = [[]]
        self.last_pt = [[]]
        self.last_file = ''
        self.dz_picker = None
        self.last_click_time = 0.0
        # a press/release pair longer than this (seconds) is ignored, presumably
        # so that pan/zoom drags are not taken as tile picks
        self.max_click_time = 0.1

        self.W = W

        canvas = self.handles['figure'].canvas
        self.cid_down = canvas.mpl_connect('button_press_event', self.buttondown)
        self.cid_up = canvas.mpl_connect('button_release_event', self.buttonup)
        # self.cid historically held the release-callback id; kept for compatibility
        self.cid = self.cid_up

    def __init_new_ui__(self, map_data, map_args):
        """Create any missing figure/axes, draw the optional background map and mark tile centers."""
        if 'figure' not in self.handles:
            self.handles['figure'] = plt.figure()
        if 'tiles_ax' not in self.handles:
            self.handles['tiles_ax'], self.handles['messages'] = \
                self.handles['figure'].subplots(1, 2)
        if map_data is not None:
            map_data.show(ax=self.handles['tiles_ax'], **map_args)
        self.handles['tiles_ax'].plot(self.xy_tiles[:, 0], self.xy_tiles[:, 1], 'k.')

    def buttondown(self, event):
        """Record the time of a button press inside the tiles axis."""
        if event.inaxes not in [self.handles['tiles_ax']]:
            return
        self.last_click_time = time.time()

    def buttonup(self, event):
        """On a quick click inside the tiles axis, select the matching (or nearest) tile."""
        try:
            if event.inaxes not in [self.handles['tiles_ax']]:
                self.messages += ['tile_picker: last point not in tiles axis']
                return
            dt_click = time.time() - self.last_click_time
            if dt_click > self.max_click_time:
                self.messages += [f'too much time has elapsed : {dt_click}']
                return
            xy0 = (event.xdata, event.ydata)
            half_W = self.W/2
            # snap the click to the W/2 grid used by the tile keys
            xy_tile = tuple((np.round(np.array(xy0)/half_W)*half_W).astype(int))
            self.messages = [f'xy0={xy0}, xy_tile={xy_tile}']
            if xy_tile not in self.xy_file_dict:
                self.messages += [f'searching by dist for {xy0}']
                this = np.argmin((self.xy_tiles[:, 0]-xy0[0])**2 + (self.xy_tiles[:, 1]-xy0[1])**2)
                xy_tile = tuple(self.xy_tiles[this, :])
            self.last_file = self.xy_file_dict[xy_tile]
            tiles_ax = self.handles['tiles_ax']
            tiles_ax.plot(xy0[0], xy0[1], 'x')
            tiles_ax.plot(xy_tile[0], xy_tile[1], 'r.')
            # artists added inside an event callback are not shown until the canvas redraws
            tiles_ax.figure.canvas.draw_idle()

        except Exception as e:
            self.messages += [e]
            self.handles['tiles_ax'].set_title('ERROR')

In [None]:
# Absolute path on the processing cluster. NOTE(review): not referenced by the later cells
# shown here (the tile directory is derived from nc_file instead); consider one DATA_DIR constant.
thedir='/att/nobackup/project/icesat-2/ATL14_processing/rel001/north/CS'

In [None]:
nc_file='/att/nobackup/project/icesat-2/ATL14_processing/rel001/north/CS/ATL14_CS_0310_100m_001_01.nc'
# NOTE(review): not used by the visible cells.
bounds={'lat':[], 'lon':[]}

# Read the fill value of 'h' so fill-valued heights can be masked later.
# np.float was removed in NumPy 1.24; .item() also handles a 1-element attribute array.
with h5py.File(nc_file,'r') as h5f:
    h_fill=float(np.asarray(h5f['h'].attrs['_FillValue']).item())
    #print(h5f.keys())

In [None]:
# Load the 'h', 'h_sigma' and 'ice_mask' fields of the ATL14 file as a pointCollection grid.
D=pc.grid.data().from_h5(nc_file, fields=['h','h_sigma','ice_mask'])

In [None]:
# Replace fill values with NaN (np.NaN / np.float were removed in NumPy 2.0 / 1.24).
D.h[D.h==h_fill]=np.nan
# ice_mask is made float so that 127 (presumably its no-data flag -- TODO confirm) can become NaN
D.ice_mask=D.ice_mask.astype(float)
D.ice_mask[D.ice_mask==127]=np.nan

In [None]:
# Show the ice mask as the background map and pick a tile by clicking; tile files are searched
# for in the subdirectories of the ATL14 file's directory. The chosen path lands in tp.last_file.
tp=tile_picker(os.path.dirname(nc_file)+'/*/', map_data=D, field='ice_mask')

In [None]:

# Load the z0 group of the picked tile and show its 'mask' and 'count' fields side by side.
# tp.last_file stays '' until a tile is clicked, which would otherwise fail inside from_h5.
assert tp.last_file, 'No tile selected yet: click a tile in the tile_picker map first'
z0=pc.grid.data().from_h5(tp.last_file, group='z0')
hf, hax=plt.subplots(1,2, sharex=True, sharey=True)
z0.show(field='mask', ax=hax[0])
z0.show(field='count', ax=hax[1])

In [None]:
plt.figure()
# origin='lower' matches the other grid images in this notebook (first row at the bottom)
plt.imshow(z0.cell_area, origin='lower')

In [None]:
plt.figure()
# Hide z0 where the mask is non-finite or zero (np.NaN was removed in NumPy 2.0).
z0.z0[~np.isfinite(z0.mask)]=np.nan
z0.z0[z0.mask==0]=np.nan
z0.show(field='z0')
plt.colorbar()

In [None]:
from scipy.interpolate import interpn
class dz_picker(object):
    """Click-to-plot time series of dz at a map location.

    A button press inside ``handles['map_ax']`` samples every dz grid at the
    clicked (x, y) and plots dz(t) on ``handles['plot_ax']``, plus dashed
    dz +/- sigma_dz curves when the grid has a ``sigma_dz`` field.

    Parameters
    ----------
    handles : dict
        Must contain 'figure', 'map_ax' and 'plot_ax'.
    file_dict : dict, optional
        {label: filename}; used only when ``dz_dict`` is None. For each click a
        W-wide window around the point is read with ``pc.grid.data().from_h5``.
    dz_dict : dict, optional
        {label: grid} of loaded grids exposing x, y, t, dz, fields (and
        optionally sigma_dz); dz is sampled on the (y, x, t) grid.
    file_args : dict, optional
        Extra keyword arguments for ``from_h5``.
    W : float
        Width of the window read around each click, in map units.

    Errors raised while handling a click are appended to ``self.messages``
    and the plot axis title is set to 'ERROR'.
    """

    def __init__(self, handles, file_dict=None, dz_dict=None, file_args=None, W=2.e3):
        self.handles = handles
        self.dz_dict = dz_dict
        self.file_dict = file_dict
        self.file_args = {} if file_args is None else file_args
        self.W = W
        # log of messages/exceptions and of clicked points (each starts with a placeholder entry)
        self.messages = [[]]
        self.last_pt = [[]]
        # the instance itself is the callback (see __call__)
        self.cid = self.handles['figure'].canvas.mpl_connect('button_press_event', self)

    def __call__(self, event):
        """Handle a button press: plot dz(t) at the clicked map location."""
        plot_ax = self.handles['plot_ax']
        try:
            if event.inaxes not in [self.handles['map_ax']]:
                # ignore clicks elsewhere, e.g. on the time-series axis itself
                self.messages += ['dz_picker: last point not in map axis']
                return
            xy0 = (event.xdata, event.ydata)
            self.last_pt += [xy0]
            tx = 'xy =[%f,%f]' % xy0
            plot_ax.set_title(tx)
            for key, dz0 in self._grids_for(xy0).items():
                tt = dz0.t
                grid = (dz0.y, dz0.x, dz0.t)
                # sample along the full time axis at the clicked (y, x)
                query = (event.ydata*np.ones_like(tt), event.xdata*np.ones_like(tt), tt)
                zz = interpn(grid, dz0.dz, query)
                h_line = plot_ax.plot(tt, zz, label=tx+' '+str(key))

                if 'sigma_dz' in dz0.fields:
                    szz = interpn(grid, dz0.sigma_dz, query)
                    color = h_line[0].get_color()
                    for sign in [-1, 1]:
                        plot_ax.plot(tt, zz+sign*szz, '--', color=color)
            self._autoscale_y()
        except Exception as e:
            self.messages += [e]
            # report on this picker's own axis, not whatever plt.gca() happens to be
            plot_ax.set_title('ERROR')
        plot_ax.figure.canvas.draw()

    def _grids_for(self, xy0):
        """Return the {label: grid} mapping to sample for a click at xy0."""
        if self.dz_dict is not None:
            return self.dz_dict
        if self.file_dict is not None:
            pad = np.array([-0.5, 0.5])*self.W
            return {key: pc.grid.data().from_h5(file, bounds=[xy0[0]+pad, xy0[1]+pad], **self.file_args)
                    for key, file in self.file_dict.items()}
        self.messages += ['dz_picker: neither dz_dict nor file_dict was provided']
        return {}

    def _autoscale_y(self):
        """Fit the y limits of the plot axis to the finite data of all its lines."""
        plot_ax = self.handles['plot_ax']
        # lines may have different lengths (different grids), so concatenate rather than stack
        y_parts = [np.asarray(line.get_ydata(), dtype=float).ravel() for line in plot_ax.lines]
        if not y_parts:
            return
        y_vals = np.concatenate(y_parts)
        y_vals = y_vals[np.isfinite(y_vals)]
        if y_vals.size:
            plot_ax.set_ylim([y_vals.min(), y_vals.max()])

    def clear_lines(self):
        """Remove every line from the plot axis and redraw."""
        plot_ax = self.handles['plot_ax']
        # Axes.lines is read-only in recent matplotlib; remove each artist instead of popping
        for line in list(plot_ax.lines):
            line.remove()
        plot_ax.figure.canvas.draw()

In [None]:
# Load the dz and z0 groups of the picked tile.
# tp.last_file stays '' until a tile is clicked, which would otherwise fail inside from_h5.
assert tp.last_file, 'No tile selected yet: click a tile in the tile_picker map first'
dz=pc.grid.data().from_h5(tp.last_file, group='dz')
z0=pc.grid.data().from_h5(tp.last_file, group='z0')


In [None]:
fig, hax=plt.subplots(1,2)

# Blank out zero-area cells in every epoch (np.NaN was removed in NumPy 2.0).
# cell_area indexes the first two axes of dz.dz, so one 2-D boolean mask covers all slices.
no_area = dz.cell_area==0
dz.dz[no_area]=np.nan
dz.sigma_dz[no_area]=np.nan

z0.z0[z0.cell_area==0]=np.nan

# Left: gradient view of z0, overlaid with the std of successive dz differences.
z0.show(ax=hax[0], gradient=True, cmap='gray', field='z0', clim=[-0.1, 0.1], interpolation='nearest')
hax[0].imshow(np.std(np.diff(dz.dz, axis=2), axis=2), clim=[0, 0.5], alpha=0.3, extent=dz.extent, origin='lower')
# Clicking the left panel plots dz(t) for that point in the right panel.
dzp=dz_picker({'figure':fig,'map_ax':hax[0], 'plot_ax':hax[1]}, dz_dict={tp.last_file:dz})




In [None]:
# Messages and caught exceptions collected by the dz_picker click callback.
dzp.messages

In [None]:
# Remove all time-series lines from the dz_picker plot axis.
dzp.clear_lines()

In [None]:
# Load the 'data' group of the picked tile as point data.
# NOTE(review): this rebinds D, which earlier held the ATL14 grid fields.
D=pc.data().from_h5(tp.last_file, group='data')

In [None]:
# Overlay points on the map axis: cycle-1 points in blue, then the subset of
# those with three_sigma_edit == 1 in red.
ii=D.cycle==1
hax[0].plot(D.x[ii], D.y[ii],'b.')
ii &= (D.three_sigma_edit==1)
hax[0].plot(D.x[ii], D.y[ii],'r.')


In [None]:
# Number of points that are in cycle 1 and have three_sigma_edit == 1.
np.sum(ii)

In [None]:
from scipy.ndimage import label

In [None]:
# `islets` was never defined in this notebook (NameError on Run All). Presumably it was the
# (labels, count) result of scipy.ndimage.label on the nonzero-area cells -- TODO confirm.
islets = label(dz.cell_area>0)
plt.figure(); plt.imshow(islets[0], cmap='jet', origin='lower')

In [None]:
from scipy.ndimage import label

# Connected regions of nonzero cell area; label() numbers them 1..n_components (0 = background).
components, n_components = label(dz.cell_area>0)
n_slices = dz.dz.shape[2]

# First/last epoch in which any cell of each component has count > 1.
# Arrays have n_components+1 entries so they can be indexed directly by label (entry 0 unused).
# Sentinels for never-sampled components put every epoch outside [first, last], so they are fully masked.
first_epoch = np.full(n_components+1, n_slices, dtype=int)
last_epoch = np.full(n_components+1, -1, dtype=int)

for comp in range(1, n_components+1):
    these = components==comp
    # dz.count[these] is (cells in component, n_slices); reduce over the cells
    sampled = np.flatnonzero(np.any(dz.count[these]>1, axis=0))
    if sampled.size:
        first_epoch[comp] = sampled[0]
        last_epoch[comp] = sampled[-1]

# Paint each component's epochs onto its pixels; NaN outside every component.
last_epoch_map=np.zeros_like(dz.cell_area)+np.nan
first_epoch_map=np.zeros_like(dz.cell_area)+np.nan
inside = components>0
last_epoch_map[inside]=last_epoch[components[inside]]
first_epoch_map[inside]=first_epoch[components[inside]]

# Mask dz outside each component's sampled time span (NaN map pixels compare False).
for t_slice in range(n_slices):
    dz.dz[:,:,t_slice][t_slice < first_epoch_map]=np.nan
    dz.dz[:,:,t_slice][t_slice > last_epoch_map]=np.nan

In [None]:
# Rebuild the per-pixel first/last sampled-epoch maps (NaN outside every component).
last_epoch_map=np.zeros_like(dz.cell_area)+np.nan
first_epoch_map=np.zeros_like(dz.cell_area)+np.nan

# scipy.ndimage.label numbers components 1..n_components inclusive. The bound is capped by the
# per-component array lengths, so labels without an entry stay NaN instead of raising IndexError.
n_painted = min(n_components+1, len(first_epoch), len(last_epoch))
for comp in range(1, n_painted):
    these = components==comp
    last_epoch_map[these]=last_epoch[comp]
    first_epoch_map[these]=first_epoch[comp]

In [None]:
# Mask each dz slice outside its component's [first_epoch, last_epoch] span.
# NaN map pixels compare False, so background cells are left unchanged here.
for t_slice in range(dz.dz.shape[2]):
    dz.dz[:,:,t_slice][t_slice < first_epoch_map]=np.nan
    dz.dz[:,:,t_slice][t_slice > last_epoch_map]=np.nan

In [None]:
# sigma_z0 of the picked tile, color scale clipped to [0, 10].
plt.figure(); plt.imshow(z0.sigma_z0, origin='lower', clim=[0, 10])

In [None]:
# Inspect slice-0 dz values at pixels whose component's first sampled epoch is after slice 0.
t_slice=0
dz.dz[:,:,t_slice][t_slice < first_epoch_map]

In [None]:
# sigma_dz of the first epoch; origin='lower' matches the other grid images in this notebook.
plt.figure(); plt.imshow(dz.sigma_dz[:,:,0], clim=[0, 5], origin='lower')

In [None]:
# sigma_z0 plotted against y; plt.plot draws one series per column of sigma_z0.
plt.figure(); plt.plot(z0.y, z0.sigma_z0,'.')