In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pointCollection as pc
import glob
import re
import os
# Matches tile file names of the form E<east>_N<north>.h5 and captures the two
# coordinate fields.  The dot before 'h5' is escaped so that it only matches a
# literal '.', not an arbitrary character.
tile_re = re.compile(r'E(.*)_N(.*)\.h5')

from LSsurf.fd_grid import fd_grid
from LSsurf.grid_functions import calc_cell_area, sum_cell_area, setup_mask,\
    setup_averaging_ops
import h5py


In [2]:
%matplotlib widget

In [3]:

# Fit configuration: domain centre ('ctr'), full domain width ('W') and node
# spacing for each grid, plus the projection string and the averaging scales
# and dz/dt lags handed to setup_averaging_ops.
args = {
    'ctr': {'x': -1600000.0, 'y': -320000.0, 't': 2020.},
    'W': {'x': 6.e4, 'y': 6.e4, 't': 3},
    'spacing': {'z0': 100, 'dz': 1.e3, 'dt': 0.25},
    'srs_proj4': '+proj=stere +lat_0=-90 +lat_ts=-71 +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs',
    'mask_file': None,
    'avg_scales': [1.e4],
    'dzdt_lags': [1, 2],
}

# Read bands 17..23 of the mask file over a window reaching 5e4 m either side
# of the domain centre.
pad = np.array([-5.e4, 5.e4])
mask = pc.grid.data().from_h5(
    '../../ATL1415/masks/Antarctic/Greene_22_shelf_plus_10m_mask.h5',
    bounds=[args['ctr']['x'] + pad, args['ctr']['y'] + pad],
    bands=np.arange(17, 24))

# Append two copies of the last mask band to the end of the mask data, with
# time values 1 and 2 units after the last original time value.
final_band = mask.z[:, :, [-1]]
mask.z = np.concatenate([mask.z, final_band, final_band], axis=2)
mask.t = np.concatenate([mask.t, mask.t[-1] + np.array([1, 2])], axis=0)
mask.__update_size_and_shape__()
args['mask_data'] = mask

# Domain bounds in each dimension: centre -/+ half of the full width
unit_span = np.array([-0.5, 0.5])
bds = {dim: ctr_val + unit_span * args['W'][dim]
       for dim, ctr_val in args['ctr'].items()}

grids = dict()

In [4]:

# --- Build the z0 (static), dz (time-varying) and t grids -------------------
# When in-memory mask data is supplied it replaces the mask file for the dz
# grid; a 3-D mask is also collapsed along time to make the z0 mask.
z0_mask_data = None
if args['mask_data'] is not None:
    mask_file = None
    if len(args['mask_data'].shape) == 3:
        # mask slices whose time falls inside the fit's time bounds
        # (lower bound inclusive, upper bound exclusive)
        valid_t = (args['mask_data'].t >= bds['t'][0]) & (args['mask_data'].t < bds['t'][-1])
        # a z0 node is unmasked if it is nonzero in any of the selected slices
        z0_mask_data = pc.grid.data().from_dict({
            'x': args['mask_data'].x,
            'y': args['mask_data'].y,
            'z': np.sum(args['mask_data'].z[:, :, valid_t], axis=2) > 0
            })
    # NOTE(review): for mask data that is not 3-D, z0_mask_data stays None, so
    # the z0 grid gets no in-memory mask -- confirm this is intended.
else:
    mask_file = args['mask_file']

# NOTE(review): the z0 grid receives args['mask_file'] directly, while the dz
# grid receives the resolved mask_file (None whenever mask_data is given).
# Both are None in this notebook; confirm which one is intended in general.
grids['z0'] = fd_grid([bds['y'], bds['x']], args['spacing']['z0']*np.ones(2),
                      name='z0', srs_proj4=args['srs_proj4'],
                      mask_file=args['mask_file'], mask_data=z0_mask_data)

# the dz grid's columns start after the z0 grid's nodes
grids['dz'] = fd_grid([bds['y'], bds['x'], bds['t']],
                      [args['spacing']['dz'], args['spacing']['dz'], args['spacing']['dt']],
                      name='dz', col_0=grids['z0'].N_nodes, srs_proj4=args['srs_proj4'],
                      mask_file=mask_file, mask_data=args['mask_data'])

# both grids share the same final column index
grids['z0'].col_N = grids['dz'].col_N
grids['t'] = fd_grid([bds['t']], [args['spacing']['dt']], name='t')
grids['z0'].cell_area = calc_cell_area(grids['z0'])


In [5]:

# --- Cell areas for the dz grid ----------------------------------------------
mask_data = args['mask_data']
if np.any(grids['dz'].delta[0:2] > grids['z0'].delta):
    # the dz spacing exceeds the z0 spacing in at least one horizontal
    # dimension, so dz cell areas are summed from a z0-resolution grid
    if mask_data is not None and mask_data.t is not None and len(mask_data.t) > 1:
        # time-dependent mask: compute a separate area map for each dz epoch
        grids['dz'].cell_area = np.zeros(grids['dz'].shape)
        for t_ind, this_t in enumerate(grids['dz'].ctrs[2]):
            if this_t <= mask_data.t[0]:
                # at or before the first mask time: use the first slice
                this_mask = mask_data[:, :, 0]
            elif this_t >= mask_data.t[-1]:
                # at or after the last mask time: use the last slice
                this_mask = mask_data[:, :, -1]
            else:
                # argmin of the boolean array gives the first mask time that
                # is >= this_t; step back one to get the slice before it
                # (assumes mask_data.t is sorted ascending -- TODO confirm)
                i_t = np.argmin(mask_data.t < this_t) - 1
                # fractional position of this_t between slices i_t and i_t+1
                di = (this_t - mask_data.t[i_t])/(mask_data.t[i_t+1] - mask_data.t[i_t])
                this_mask = pc.grid.data().from_dict({
                    'x': mask_data.x,
                    'y': mask_data.y,
                    'z': mask_data.z[:, :, i_t]*(1 - di) + mask_data.z[:, :, i_t+1]*di})
            # Build the temporary z0-resolution grid and sum its areas for
            # EVERY epoch.  Previously these two statements sat inside the
            # else branch, so epochs at or outside the mask's time range kept
            # a cell area of zero.
            temp_grid = fd_grid([bds['y'], bds['x']], args['spacing']['z0']*np.ones(2),
                                name='z0', srs_proj4=args['srs_proj4'],
                                mask_data=this_mask)
            grids['dz'].cell_area[:, :, t_ind] = sum_cell_area(temp_grid, grids['dz'])
    else:
        # no time-varying mask: sum the z0 cell areas directly
        grids['dz'].cell_area = sum_cell_area(grids['z0'], grids['dz'])
else:
    grids['dz'].cell_area = calc_cell_area(grids['dz'])
# last-- multiply the z0 cell area by the z0 mask
if grids['z0'].mask is not None:
    grids['z0'].cell_area *= grids['z0'].mask



In [6]:
# Build the dictionary of dz/dt and averaging operators on the dz grid
dz_grid = grids['dz']
ops = setup_averaging_ops(dz_grid, grids['z0'].col_N, args, dz_grid.cell_area)


dzdt_lag1
dzdt_lag2


In [7]:
# Display the dictionary of operators returned by setup_averaging_ops
ops

{'dzdt_lag1': <LSsurf.lin_op.lin_op at 0x14bebb5b0>,
 'dzdt_lag2': <LSsurf.lin_op.lin_op at 0x14bebbf70>,
 'avg_dz_10000m': <LSsurf.lin_op.lin_op at 0x14bebb100>,
 'avg_dzdt_10000m_lag1': <LSsurf.lin_op.lin_op at 0x14bebb1f0>,
 'avg_dzdt_10000m_lag2': <LSsurf.lin_op.lin_op at 0x14bebb430>}

In [8]:
# Display the shape of the cell-area array stored on the destination grid of
# the 'avg_dz_10000m' operator
ops['avg_dz_10000m'].dst_grid.cell_area.shape

(5, 5, 13)

In [9]:
# Cell areas on the dz/dt (lag 1) destination grid and on the 10000-m
# averaging grid
c_1km = ops['dzdt_lag1'].dst_grid.cell_area
g_10km = ops['avg_dz_10000m'].dst_grid
c_10km = g_10km.cell_area

# imshow's extent is (left, right, bottom, top), i.e. x limits then y limits.
# The grids are built in (y, x, t) order, so bds[1] (x) must come before
# bds[0] (y); the previous ordering swapped the two axes.
# (assumes .bds keeps the dimension order passed to fd_grid -- TODO confirm)
ext_1km = np.r_[grids['dz'].bds[1] + [-500, 500], grids['dz'].bds[0] + [-500, 500]]
# NOTE(review): the +/-500 pad equals half of the dz spacing (1.e3); for the
# 10000-m grid a half-cell pad would presumably be larger -- confirm.
ext_10km = np.r_[g_10km.bds[1] + [-500, 500], g_10km.bds[0] + [-500, 500]]

# one row per time slice: dz-resolution area on the left, averaged on the right
hf, hax = plt.subplots(12, 2, gridspec_kw={'wspace': 0.01, 'hspace': 0.01},
                       figsize=(4, 12), sharex=True, sharey=True)
for ti, ha in enumerate(hax):
    ha[0].imshow(c_1km[:, :, ti], origin='lower', clim=[0, 1.4e6], extent=ext_1km)
    ha[1].imshow(c_10km[:, :, ti], origin='lower', clim=[0, 1.4e8], extent=ext_10km)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Cell areas on the dz/dt (lag 1) destination grid and on the 10000-m
# averaged dz/dt grid; the latter is reshaped to its destination-grid shape
c_1km = ops['dzdt_lag1'].dst_grid.cell_area
g_10km = ops['avg_dzdt_10000m_lag1'].dst_grid
c_10km = g_10km.cell_area.reshape(g_10km.shape)

# imshow's extent is (left, right, bottom, top), i.e. x limits then y limits.
# The grids are built in (y, x, t) order, so bds[1] (x) must come before
# bds[0] (y); the previous ordering swapped the two axes.
# (assumes .bds keeps the dimension order passed to fd_grid -- TODO confirm)
ext_1km = np.r_[grids['dz'].bds[1] + [-500, 500], grids['dz'].bds[0] + [-500, 500]]
# NOTE(review): the +/-500 pad equals half of the dz spacing (1.e3); for the
# 10000-m grid a half-cell pad would presumably be larger -- confirm.
ext_10km = np.r_[g_10km.bds[1] + [-500, 500], g_10km.bds[0] + [-500, 500]]

# one row per time slice: dz-grid mask, dz-resolution area, averaged area
hf, hax = plt.subplots(12, 3, gridspec_kw={'wspace': 0.01, 'hspace': 0.01},
                       figsize=(4, 12), sharex=True, sharey=True)
for ti, ha in enumerate(hax):
    ha[0].imshow(grids['dz'].mask_3d.z[:, :, ti], origin='lower', clim=[0, 1.4], extent=ext_1km)
    ha[1].imshow(c_1km[:, :, ti], origin='lower', clim=[0, 1.4e6], extent=ext_1km)
    ha[2].imshow(c_10km[:, :, ti], origin='lower', clim=[0, 1.4e8], extent=ext_10km)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Show the input cells associated with one output cell

In [11]:
# Keep only the columns of the averaged lag-1 dz/dt operator that belong to
# the dz grid
avg_dzdt_op = ops['avg_dzdt_10000m_lag1']
dz_cols = slice(grids['dz'].col_0, grids['dz'].col_N)
op = avg_dzdt_op.toCSR()[:, dz_cols]

# Flat row index of the output cell at subscripts (0, 3, 2) of the
# operator's destination grid
this_row = np.ravel_multi_index((0, 3, 2), avg_dzdt_op.dst_grid.shape)

In [12]:
# Coefficients that the selected operator row applies to each dz-grid node,
# laid out on the dz grid
op_slice = op[this_row, :].toarray().reshape(grids['dz'].shape)

# Third-axis (time) indices of the nonzero coefficients; only these are
# needed, so the row/column indices are not kept
kk = np.where(op_slice)[2]

# Show the time slice holding the second nonzero coefficient (in the order
# returned by np.where)
plt.figure(); plt.imshow(op_slice[:, :, kk[1]], origin='lower')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x14cd50280>

### Test the operator on simple inputs

The first input is all ones:

In [13]:
# Model vector: ones at every dz-grid node, zeros in the columns that come
# before the operator's first column
avg_dzdt_op = ops['avg_dzdt_10000m_lag1']
m_sub = np.ones(grids['dz'].shape)
m = np.zeros(avg_dzdt_op.col_N)
m[avg_dzdt_op.col_0:avg_dzdt_op.col_N] = m_sub.ravel()

# Apply the operator and reshape the result onto its destination grid
dzdt_est = avg_dzdt_op.toCSR().dot(m).reshape(avg_dzdt_op.dst_grid.shape)

# One panel per slice along the third axis of the result
hf, hax = plt.subplots(4, 4, gridspec_kw={'wspace': 0.01, 'hspace': 0.01}, figsize=(8, 8))
hax = hax.ravel()
n_slices = dzdt_est.shape[2]
for i_slice in range(n_slices):
    hax[i_slice].imshow(dzdt_est[:, :, i_slice], vmin=-5, vmax=5, origin='lower')
# Hide only the panels that were not used.  The previous loop started at the
# last plotted index, which also hid the final data panel.
for unused_ax in hax[n_slices:]:
    unused_ax.set_visible(False)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

The second input has z0=0, $\delta z$=1 for the even epochs, $\delta z$=-1 for the odd epochs:

In [14]:
# Alternating-sign dz field: +1 on even time indices, -1 on odd time indices
m_sub = np.zeros(grids['dz'].shape)
m_sub[:, :, 0::2] = 1
m_sub[:, :, 1::2] = -1

# Place the field in the operator's columns of an otherwise-zero model vector
lag1_avg_op = ops['avg_dzdt_10000m_lag1']
m = np.zeros(lag1_avg_op.col_N)
m[lag1_avg_op.col_0:lag1_avg_op.col_N] = m_sub.ravel()

In [15]:
# Apply the averaged lag-1 dz/dt operator to the model vector m and reshape
# the result onto the operator's destination grid
avg_dzdt_op = ops['avg_dzdt_10000m_lag1']
dzdt_est = avg_dzdt_op.toCSR().dot(m).reshape(avg_dzdt_op.dst_grid.shape)

# One panel per slice along the third axis of the result
hf, hax = plt.subplots(4, 4, gridspec_kw={'wspace': 0.01, 'hspace': 0.01}, figsize=(8, 8))
hax = hax.ravel()
n_slices = dzdt_est.shape[2]
for i_slice in range(n_slices):
    hax[i_slice].imshow(dzdt_est[:, :, i_slice], vmin=-5, vmax=5, origin='lower')
# Hide only the panels that were not used.  The previous loop started at the
# last plotted index, which also hid the final data panel.
for unused_ax in hax[n_slices:]:
    unused_ax.set_visible(False)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …