In [None]:
import xarray as xr
import numpy as np
import scipy
from scipy import stats
from scipy.stats import chi2, ncx2
from matplotlib import pyplot as plt
# import matplotlib.cm as cm
import matplotlib.colors as plc

from depsi.network import form_network, arc_selection, remove_isolated_stm, _network_relation_matrix
from depsi.arc_estimation import periodogram
from depsi.mht_utils import pretest

# import dask
# dask.config.set(scheduler='processes')

In [None]:
# Radar wavelength of Sentinel-1 (C-band), expressed in meters
WAVELENGTH = 5.5465763e-2

In [None]:
# Open the STM from zarr. open_zarr is lazy: data are only read when computed.
stm = xr.open_zarr('../../data/stm_amsterdam_173p.zarr')

# Remove the mother epoch: its h2ph values are all 0.
# Test "any non-zero value" (the documented criterion) rather than "mean != 0",
# so a valid epoch whose h2ph values happen to average to 0 is not dropped.
# np.flatnonzero always returns a 1-D index array, so isel keeps the time
# dimension even if a single epoch remains (np.squeeze would give a 0-d index).
is_non_mother = (stm['h2ph_values'] != 0).any(axis=0).values
idx_non_mother = np.flatnonzero(is_non_mother)
stm = stm.isel(time=idx_non_mother)

# For debugging, shorten the time series to the first 30 epochs
stm = stm.isel(time=slice(0, 30))
stm

In [None]:
# Form the arc network between points of the STM.
# NOTE(review): max_length is presumably in the units of the point coordinates — TODO confirm
stm_arcs = form_network(
    stm,
    'sd_phase',
    'h2ph_values',
    'years',
    max_length=0.001,
)
stm_arcs

In [None]:
# Coordinates of each arc's source and target point (one isel per endpoint)
src = stm.isel(space=stm_arcs['source'])
tgt = stm.isel(space=stm_arcs['target'])
# x (lon) and y (lat) of sources and targets, one row per arc
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

# Visualize created arcs
fig, ax = plt.subplots()
for x_pair, y_pair in zip(xx, yy):
    ax.plot(x_pair, y_pair, color='b', linewidth=0.5)
ax.set(title="Network STM arcs", xlabel="lon", ylabel="lat");

In [None]:
# Run the periodogram estimation on every arc and attach its five outputs
# to stm_arcs under their own names, then materialize the result.
phs_obs_unwrapped, ambigs, arc_height, arc_velo, ens_coh = periodogram(
    stm_arcs, "d_phase", "h2ph", "Btemp", wavelength=WAVELENGTH,
)

stm_arcs = stm_arcs.assign(
    phs_obs_unwrapped=phs_obs_unwrapped,
    ambigs=ambigs,
    arc_height=arc_height,
    arc_velo=arc_velo,
    ens_coh=ens_coh,
).compute()
stm_arcs

In [None]:
# Coordinates of each arc's source and target point
src = stm.isel(space=stm_arcs['source'])
tgt = stm.isel(space=stm_arcs['target'])
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

# Visualize arcs, colored by |ens_coh| on a fixed [0, 1] color scale
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
arc_coh = np.abs(stm_arcs['ens_coh'].data)
for x_pair, y_pair, coh in zip(xx, yy, arc_coh):
    ax.plot(x_pair, y_pair, color=cmap(norm(coh)), linewidth=0.5)
ax.set(title="Network STM arcs, colored by ens_coh", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh');

In [None]:
# Select arcs by ens_coh using a threshold of 0.75.
# NOTE(review): the last argument (3) is passed to arc_selection as-is;
# presumably a minimum number of arcs per point — TODO confirm in depsi.network
stm_arcs_selected = arc_selection(stm_arcs, 0.75, 'ens_coh', 3)

In [None]:
# Coordinates of each selected arc's source and target point
src = stm.isel(space=stm_arcs_selected['source'])
tgt = stm.isel(space=stm_arcs_selected['target'])
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

# Visualize selected arcs, colored by |ens_coh| on a fixed [0, 1] color scale
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
arc_coh = np.abs(stm_arcs_selected['ens_coh'].data)
for x_pair, y_pair, coh in zip(xx, yy, arc_coh):
    ax.plot(x_pair, y_pair, color=cmap(norm(coh)), linewidth=0.5)
ax.set(title="Selected STM arcs, colored by ens_coh", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh');

In [None]:
# Update the STM and the selected arcs:
# drop points left without any arc and update the arcs' source/target indices
stm_updated, stm_arcs_selected = remove_isolated_stm(
    stm,
    stm_arcs_selected,
)

In [None]:
# Inspect the STM after removing isolated points
stm_updated

In [None]:
# Inspect the selected arcs after their indices were updated to stm_updated
stm_arcs_selected

In [None]:
# Coordinates of each arc's endpoints, now indexed into the reduced STM
src = stm_updated.isel(space=stm_arcs_selected['source'])
tgt = stm_updated.isel(space=stm_arcs_selected['target'])
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

# Visualize arcs after removing isolated points, colored by |ens_coh| in [0, 1]
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
arc_coh = np.abs(stm_arcs_selected['ens_coh'].data)
for x_pair, y_pair, coh in zip(xx, yy, arc_coh):
    ax.plot(x_pair, y_pair, color=cmap(norm(coh)), linewidth=0.5)
ax.set(title="Network STM arcs (isolated points removed), colored by ens_coh",
       xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh');

## Network integration

### Initial adjustment

In [None]:
# Network relation (design) matrix A: one row per arc, one column per point,
# converted to CSR for fast products
A_sparse = _network_relation_matrix(
    stm_arcs_selected["source"],
    stm_arcs_selected["target"],
    stm_updated.sizes["space"],
).tocsr()
A_sparse

In [None]:
# Prepare stochastic model: unit variance for every arc (Qyy = I),
# kept as the sparse diagonal weight matrix invQy = Qyy^-1
N_arcs = stm_arcs_selected.sizes["space"]
Qyy_diag = np.ones(N_arcs)
invQy = scipy.sparse.diags(np.reciprocal(Qyy_diag), offsets=0, shape=(N_arcs, N_arcs))
invQy

In [None]:
# Observations: the arc ambiguities estimated by the periodogram
y = stm_arcs_selected['ambigs'].data
np.shape(y)

In [None]:
# Weighted least squares for sparse A, vectorized so each call solves one epoch.
# Minimizing ||Qyy^(-1/2) (y - A x)||^2 requires whitening BOTH A and y:
#   lsmr(Qyy^(-1/2) A, Qyy^(-1/2) y)
# Weighting only A (lsmr(invQy @ A, y)) coincides with this only when Qyy = I.
@np.vectorize(signature="(i)->(j)")
def lsmr(y):
    """Least-squares point ambiguities for one epoch of arc ambiguities ``y``.

    Reads the module-level ``invQy`` (diagonal weight matrix) and ``A_sparse``
    at call time.
    """
    sqrt_invQy = scipy.sparse.diags(np.sqrt(invQy.diagonal()))
    x, *_ = scipy.sparse.linalg.lsmr(sqrt_invQy @ A_sparse, sqrt_invQy @ y)
    return x

In [None]:
# First estimation. lsmr solves one epoch per call (its core dimension is the
# arcs), so epochs are fed as rows and the result is transposed back.
acheck = np.transpose(lsmr(np.transpose(y)))  # estimated point ambiguity
ycheck = A_sparse.dot(acheck)  # estimated arc ambiguity
echeck = np.subtract(y, ycheck)  # residual arc ambiguity

In [None]:
# Plot echeck for all arcs
# Coordinates of each arc's source and target point
src = stm_updated.isel(space=stm_arcs_selected['source'])
tgt = stm_updated.isel(space=stm_arcs_selected['target'])
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=10.0)
echeck_sum = echeck.sum(axis=1)  # sum of residuals over epochs, one value per arc
for x_pair, y_pair, e_sum in zip(xx, yy, echeck_sum):
    # Map through norm so colors match the colorbar's [0, 10] range; a bare
    # cmap(value) expects [0, 1] and saturates for anything above 1.
    # Negative sums map to the colormap's lowest color.
    ax.plot(x_pair, y_pair, color=cmap(norm(e_sum)), linewidth=0.5)
ax.set(title="Network STM arcs, colored by echeck sum", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='echeck sum');

In [None]:
# Visualize residuals: image of echeck (arcs x epochs) and histogram of all values
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
# Draw the image once and reuse the returned AxesImage for the colorbar
# (a second imshow call would stack a duplicate image on the same axes)
im = axes[0].imshow(echeck, aspect="auto", cmap="viridis")
fig.colorbar(im, ax=axes[0])
axes[0].set(title="echeck", xlabel="epoch", ylabel="arc")
axes[1].hist(echeck.flatten(), bins=20)
axes[1].set(title="Histogram of echeck", xlabel="residual arc ambiguity", ylabel="count");

### Overall Model Test

In [None]:
# Decision on if the Overall Model Test is passed.
# OMT = trace(echeck^T invQy echeck): the weighted sum of squared residuals over
# all arcs and epochs, computed without forming the (epochs x epochs) product.
# NOTE(review): kOMT = 1e-10 effectively flags any non-zero residual; a
# chi-square critical value is the usual threshold — confirm intent.
kOMT = 1e-10  # threshold
OMT = np.sum(echeck * (invQy @ echeck))
print("OMT:", OMT)
print("OMT > kOMT:", OMT > kOMT)  # if True, fail, need adjustment

### Compute test statistics for rejecting arcs (TT1) or points (TTq)

In [None]:
# Setup tests.
# NOTE(review): a0 and g0 are presumably the significance level and power of the
# B-method passed to pretest — TODO confirm against depsi.mht_utils.pretest
a0 = 0.1
g0 = 0.5
# Maximum number of arcs connected to one point (largest column sum of |A|).
# int() because the sparse reduction can return a float scalar, which range() rejects.
max_con = int(np.abs(A_sparse).sum(axis=0).max())

kb_dict = {}
ab_dict = {}

# Critical values per test dimension. lam0 and k1 keep the values of the last
# iteration; k1 is used further below to normalise TT1.
for n_con in range(1, max_con + 1):
    lam0, k1, kb, ab = pretest(n_con, a0, g0)
    kb_dict[n_con] = kb
    ab_dict[n_con] = ab

In [None]:
# Compute the a-posteriori covariance of the residuals:
#   Qe = Qyy - A Qxx A^T,  with  Qxx = (A^T Qyy^-1 A)^-1
Qyy = scipy.sparse.diags(Qyy_diag, 0, shape=(N_arcs, N_arcs))
invQy = scipy.sparse.diags(1/Qyy_diag, 0, shape=(N_arcs, N_arcs))
# NOTE(review): inv requires A^T Qyy^-1 A to be non-singular, i.e. a reference
# point must be fixed in A — confirm _network_relation_matrix does this
Qxx = np.linalg.inv((A_sparse.T @ invQy @ A_sparse).todense())
# Subtract from the FULL Qyy matrix. Subtracting Qyy.diagonal() (a 1-D array)
# broadcasts it across every row and corrupts the off-diagonal entries, which
# TTq reads via Qecheck[arcs_idx, :][:, arcs_idx].
Qecheck = Qyy.toarray() - (A_sparse @ Qxx @ A_sparse.T)
Qecheck_diag = np.asarray(Qecheck.diagonal()).ravel()
Qecheck_diag

In [None]:
# Compute TT1: one test statistic per ARC. Each squared residual is scaled by
# its a-posteriori variance, summed over epochs (axis=1) and normalised by k1^2.
# Summing over axis=0 would give one value per epoch instead, yet
# np.argmax(TT1) is later used as the index of the arc to reject.
N_epochs = stm_updated.sizes["time"]
assert echeck.shape == (Qecheck_diag.size, N_epochs), "echeck must be (arcs, epochs)"
w = echeck**2 / np.abs(Qecheck_diag)[:, np.newaxis]
TT1 = np.sum(w, axis=1) / k1**2
TT1

In [None]:
# Compute TTq: one test statistic per point, from the residuals of its arcs
TTq = np.zeros(stm_updated.sizes["space"])
# Hoist the arc endpoint indices out of the loop as plain numpy arrays
source = stm_arcs_selected['source'].values
target = stm_arcs_selected['target'].values
for pnt_idx in range(stm_updated.sizes["space"]):
    # Get arcs connected to the point
    arcs_idx = np.flatnonzero((source == pnt_idx) | (target == pnt_idx))

    # Drop one arc to create basis, see e.g. verhoef97
    arcs_idx = arcs_idx[1:]
    if arcs_idx.size == 0:
        # A point with a single arc has nothing left to test after dropping one
        # arc (and kb_dict has no entry for 0): leave its TTq at 0
        continue

    # Relevant residuals (arcs x epochs) and residual covariance of this point
    echeck_point = echeck[arcs_idx, :]
    Qecheck_point = Qecheck[arcs_idx, :][:, arcs_idx]

    # Per-epoch quadratic forms e_t^T Qe^-1 e_t, i.e. the diagonal of
    # echeck_point.T @ inv(Qecheck_point) @ echeck_point, using solve instead of
    # an explicit inverse and without forming the (epochs x epochs) product.
    # These are non-negative when Qecheck_point is positive definite; abs is kept
    # as in the original. NOTE(review): a negative value signals a non-PD Qe.
    sol = np.asarray(np.linalg.solve(Qecheck_point, echeck_point))
    quad = np.einsum('ij,ij->j', echeck_point, sol)
    Tq = np.sum(np.abs(quad))

    TTq[pnt_idx] = Tq / kb_dict[len(arcs_idx)]

In [None]:
# Unused alternative kept for reference (all lines commented out): residual
# covariance for one point's arcs via a Cholesky factor R of the normal matrix
# N = A^T A, with H0 = A R^-1.
# NOTE(review): it relies on arcs_idx left over from the TTq loop above.
# N = A_sparse.T @ A_sparse
# R = np.linalg.cholesky(N.todense())
# Rinv = np.linalg.inv(R)
# H0 = A_sparse*Rinv

# Qyy_point = np.diag(Qyy_diag[arcs_idx])
# Qee_point = Qyy_point - (H0[arcs_idx, :] @ H0[arcs_idx, :].T)
# Qee_point

In [None]:
# Print the information for deciding whether to reject arcs or points:
#   max(TT1) > max(TTq) -> remove arcs
#   max(TTq) > max(TT1) -> remove points
# Shapes first (TT1 is expected to hold one value per arc, TTq one per point),
# then the largest statistic of each.
for stat_shape in (TT1.shape, TTq.shape):
    print(stat_shape)
for stat_max in (max(TT1), max(TTq)):
    print(stat_max)

### Reject points

In [None]:
# Get max and index of TTq: the point with the largest statistic is rejected
idx_pnt_remove = np.argmax(TTq)
max_TTq = max(TTq)
idx_pnt_remove

In [None]:
# Remove arcs connected to the rejected point.
# Boolean isel keeps only the wanted arcs. Dataset.where(..., drop=True) fills
# with NaN, which promotes integer variables such as the source/target indices
# to float, and it broadcasts the mask onto variables without a 'space' dimension.
keep_arc = (
    (stm_arcs_selected["source"] != idx_pnt_remove)
    & (stm_arcs_selected["target"] != idx_pnt_remove)
).values
stm_arcs_selected2 = stm_arcs_selected.isel(space=keep_arc)
stm_arcs_selected2

In [None]:
# Update STM and arc idx: drop points left without arcs and update the
# arcs' source/target indices to the reduced STM
stm_updated2, stm_arcs_selected2 = remove_isolated_stm(
    stm_updated,
    stm_arcs_selected2,
)
stm_updated2

In [None]:
# Network relation matrix A for the network after point rejection
A_sparse2 = _network_relation_matrix(
    stm_arcs_selected2["source"], stm_arcs_selected2["target"], stm_updated2.sizes["space"]
).tocsr()

# Stochastic model: unit variance for every arc (Qyy = I)
N_arcs = stm_arcs_selected2.sizes["space"]
Qyy_diag = np.ones(N_arcs)
invQy = scipy.sparse.diags(1/Qyy_diag, 0, shape=(N_arcs, N_arcs))

y = stm_arcs_selected2['ambigs'].data

# Least squares for sparse data, one epoch per call. Whiten BOTH A and y with
# Qyy^(-1/2); weighting only A gives a wrong solution as soon as Qyy != I.
@np.vectorize(signature="(i)->(j)")
def lsmr(y):
    """Least-squares point ambiguities for one epoch of arc ambiguities ``y``."""
    sqrt_invQy = scipy.sparse.diags(np.sqrt(invQy.diagonal()))
    x, *_ = scipy.sparse.linalg.lsmr(sqrt_invQy @ A_sparse2, sqrt_invQy @ y)
    return x

acheck = lsmr(y.T).T  # estimated point ambiguity
ycheck = A_sparse2 @ acheck  # estimated arc ambiguity
echeck = y - ycheck  # residual arc ambiguity

In [None]:
# Plot echeck for all arcs of the network after point rejection.
# echeck belongs to stm_arcs_selected2 (indices into stm_updated2), so the
# coordinates must come from that network, not the pre-rejection one.
src = stm_updated2.isel(space=stm_arcs_selected2['source'])
tgt = stm_updated2.isel(space=stm_arcs_selected2['target'])
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=10.0)
echeck_sum = echeck.sum(axis=1)  # sum of residuals over epochs, one value per arc
for x_pair, y_pair, e_sum in zip(xx, yy, echeck_sum):
    # norm maps the sum onto the colorbar's [0, 10] range (negatives -> lowest color)
    ax.plot(x_pair, y_pair, color=cmap(norm(e_sum)), linewidth=0.5)
ax.set(title="Network STM arcs, colored by echeck sum", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='echeck sum');

In [None]:
# Decision on if the Overall Model Test is passed.
# OMT = trace(echeck^T invQy echeck), computed without the (epochs x epochs) product.
# NOTE(review): kOMT = 1e-10 effectively flags any non-zero residual — confirm intent.
kOMT = 1e-10  # threshold
OMT = np.sum(echeck * (invQy @ echeck))
print("OMT:", OMT)
print("OMT > kOMT:", OMT > kOMT)  # if True, fail, need adjustment

### Reject arcs

In [None]:
# Update statistics based on pnt rejection step

# Compute the a-posteriori covariance of the residuals: Qe = Qyy - A Qxx A^T
Qyy = scipy.sparse.diags(Qyy_diag, 0, shape=(N_arcs, N_arcs))
invQy = scipy.sparse.diags(1/Qyy_diag, 0, shape=(N_arcs, N_arcs))
Qxx = np.linalg.inv((A_sparse2.T @ invQy @ A_sparse2).todense())
# Use the full Qyy: subtracting Qyy.diagonal() (1-D) broadcasts it across rows
# and corrupts the off-diagonal entries that TTq reads below
Qecheck = Qyy.toarray() - (A_sparse2 @ Qxx @ A_sparse2.T)
Qecheck_diag = np.asarray(Qecheck.diagonal()).ravel()


# Compute TT1: one statistic per ARC (sum over epochs, axis=1), since
# np.argmax(TT1) is later used as the index of the arc to reject
N_epochs = stm_updated2.sizes["time"]
assert echeck.shape == (Qecheck_diag.size, N_epochs), "echeck must be (arcs, epochs)"
w = echeck**2 / np.abs(Qecheck_diag)[:, np.newaxis]
TT1 = np.sum(w, axis=1) / k1**2


# Compute TTq: one statistic per point
TTq = np.zeros(stm_updated2.sizes["space"])
source = stm_arcs_selected2['source'].values
target = stm_arcs_selected2['target'].values
for pnt_idx in range(stm_updated2.sizes["space"]):
    # Get arcs connected to the point
    arcs_idx = np.flatnonzero((source == pnt_idx) | (target == pnt_idx))

    # Drop one arc to create basis, see e.g. verhoef97
    arcs_idx = arcs_idx[1:]
    if arcs_idx.size == 0:
        # Single-arc point: nothing left to test (and no kb_dict[0]); keep TTq = 0
        continue

    # Relevant residuals and residual covariance of this point
    echeck_point = echeck[arcs_idx, :]
    Qecheck_point = Qecheck[arcs_idx, :][:, arcs_idx]

    # Per-epoch quadratic forms e_t^T Qe^-1 e_t (diagonal of E^T Qe^-1 E) via solve;
    # non-negative for a positive-definite Qe, abs kept as in the original
    sol = np.asarray(np.linalg.solve(Qecheck_point, echeck_point))
    quad = np.einsum('ij,ij->j', echeck_point, sol)
    Tq = np.sum(np.abs(quad))

    TTq[pnt_idx] = Tq / kb_dict[len(arcs_idx)]

# Although TTq > TT1 but let's still reject an arc
print(TT1.shape)  # one value per arc
print(TTq.shape)  # one value per point
print(max(TT1))
print(max(TTq))

In [None]:
# Arc with the largest TT1 statistic: the candidate for rejection
idx_arc_remove = np.argmax(TT1)
max_TT1 = max(TT1)
idx_arc_remove

In [None]:
# Inspect the arcs after point rejection, before an arc is dropped
stm_arcs_selected2

In [None]:
# Remove the arc with index idx_arc_remove
stm_arcs_selected3 = stm_arcs_selected2.drop_isel(space=idx_arc_remove)

# The source/target indices of stm_arcs_selected2 were updated against
# stm_updated2 (after the point rejection), so the isolated-point removal must
# start from stm_updated2, not from the pre-rejection stm_updated.
stm_updated3, stm_arcs_selected3 = remove_isolated_stm(stm_updated2, stm_arcs_selected3)
stm_updated3

In [None]:
# Shape of the current observations y (still those of the network after point
# rejection; y is rebuilt for the arc-rejected network in the next cell)
np.shape(y)

In [None]:
# Network relation matrix A for the network after arc rejection
A_sparse3 = _network_relation_matrix(
    stm_arcs_selected3["source"], stm_arcs_selected3["target"], stm_updated3.sizes["space"]
).tocsr()

# Stochastic model: unit variance for every arc (Qyy = I)
N_arcs = stm_arcs_selected3.sizes["space"]
Qyy_diag = np.ones(N_arcs)
invQy = scipy.sparse.diags(1/Qyy_diag, 0, shape=(N_arcs, N_arcs))

y = stm_arcs_selected3['ambigs'].data

# Least squares for sparse data, one epoch per call. Whiten BOTH A and y with
# Qyy^(-1/2); weighting only A gives a wrong solution as soon as Qyy != I.
@np.vectorize(signature="(i)->(j)")
def lsmr(y):
    """Least-squares point ambiguities for one epoch of arc ambiguities ``y``."""
    sqrt_invQy = scipy.sparse.diags(np.sqrt(invQy.diagonal()))
    x, *_ = scipy.sparse.linalg.lsmr(sqrt_invQy @ A_sparse3, sqrt_invQy @ y)
    return x

acheck = lsmr(y.T).T  # estimated point ambiguity
ycheck = A_sparse3 @ acheck  # estimated arc ambiguity
echeck = y - ycheck  # residual arc ambiguity

In [None]:
# Plot echeck for all arcs of the network after arc rejection.
# echeck belongs to stm_arcs_selected3 (indices into stm_updated3), so the
# coordinates must come from that network.
src = stm_updated3.isel(space=stm_arcs_selected3['source'])
tgt = stm_updated3.isel(space=stm_arcs_selected3['target'])
xx = np.column_stack([src['lon'].values, tgt['lon'].values])
yy = np.column_stack([src['lat'].values, tgt['lat'].values])

fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=10.0)
echeck_sum = echeck.sum(axis=1)  # sum of residuals over epochs, one value per arc
for x_pair, y_pair, e_sum in zip(xx, yy, echeck_sum):
    # norm maps the sum onto the colorbar's [0, 10] range (negatives -> lowest color)
    ax.plot(x_pair, y_pair, color=cmap(norm(e_sum)), linewidth=0.5)
ax.set(title="Network STM arcs, colored by echeck sum", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='echeck sum');

In [None]:
# Decision on if the Overall Model Test is passed.
# OMT = trace(echeck^T invQy echeck), computed without the (epochs x epochs) product.
# NOTE(review): kOMT = 1e-10 effectively flags any non-zero residual — confirm intent.
kOMT = 1e-10  # threshold
OMT = np.sum(echeck * (invQy @ echeck))
print("OMT:", OMT)
print("OMT > kOMT:", OMT > kOMT)  # if True, fail, need adjustment