In [None]:
import xarray as xr
import numpy as np
import scipy
from scipy import stats
from scipy.stats import chi2, ncx2
from matplotlib import pyplot as plt
# import matplotlib.cm as cm
import matplotlib.colors as plc

from depsi.network import form_network, arc_selection, remove_isolated_stm, _network_relation_matrix
from depsi.arc_estimation import periodogram
from depsi.mht_utils import pretest
from depsi.network import _mht_network_adjustment, _solve_float_ambiguities

# import dask
# dask.config.set(scheduler='processes')

In [None]:
# Radar wavelength of Sentinel-1 [m], used to convert phase to range
WAVELENGTH = 5.5465763e-2

In [None]:
# Open the space-time matrix (STM). xr.open_zarr is lazy: data are only read
# from disk when they are computed (e.g. via `.values` below).
stm = xr.open_zarr('../../data/stm_amsterdam_173p.zarr')

# Remove the mother epoch: the mother image has all h2ph values equal to 0.
# Reduce over the named "space" dimension so this does not depend on the
# dimension order, and use np.flatnonzero so the index is always 1-D (np.squeeze
# would give a 0-d index for a single remaining epoch, and isel would then drop
# the "time" dimension altogether).
h2ph_mean_per_epoch = stm['h2ph_values'].mean(dim='space')
idx_non_mother = np.flatnonzero((h2ph_mean_per_epoch != 0).values)
stm = stm.isel(time=idx_non_mother)

# For debugging, shorten the time series; set N_EPOCHS_DEBUG = None for a full run
N_EPOCHS_DEBUG = 30
if N_EPOCHS_DEBUG is not None:
    stm = stm.isel(time=slice(0, N_EPOCHS_DEBUG))
stm

In [None]:
# Build the arc network from the point STM. max_length presumably bounds the
# arc length in the units of the STM coordinates (lon/lat degrees?) — confirm
# against depsi.network.form_network.
stm_arcs = form_network(
    stm,
    'sd_phase',
    'h2ph_values',
    'years',
    max_length=0.001,
)
stm_arcs

In [None]:
# Visualize the created arcs: one line segment per arc, from its source point
# to its target point in lon/lat.
# x (lon) of sources and targets, one row per arc
xx = np.stack([stm.isel(space=stm_arcs['source'])['lon'].values,
               stm.isel(space=stm_arcs['target'])['lon'].values]).T
# y (lat) of sources and targets, one row per arc
yy = np.stack([stm.isel(space=stm_arcs['source'])['lat'].values,
               stm.isel(space=stm_arcs['target'])['lat'].values]).T
fig, ax = plt.subplots()
# Passing the transposed (2, n_arcs) arrays draws one 2-point line per column,
# i.e. one line per arc, in a single call instead of one call per arc.
ax.plot(xx.T, yy.T, color='b', linewidth=0.5)
ax.set(title="Network STM arcs", xlabel="lon", ylabel="lat");

In [None]:
# Run the periodogram on every arc and attach its five outputs to the arc STM
(
    phs_obs_unwrapped,
    ambigs,
    arc_height,
    arc_velo,
    ens_coh,
) = periodogram(stm_arcs, "d_phase", "h2ph", "Btemp", wavelength=WAVELENGTH)

# Add all outputs as new variables, then evaluate the (possibly lazy) result
stm_arcs = stm_arcs.assign(
    phs_obs_unwrapped=phs_obs_unwrapped,
    ambigs=ambigs,
    arc_height=arc_height,
    arc_velo=arc_velo,
    ens_coh=ens_coh,
).compute()
stm_arcs

In [None]:
# Visualize all arcs, colored by the absolute ensemble coherence (ens_coh)
# estimated by the periodogram.
# x (lon) of sources and targets, one row per arc
xx = np.stack([stm.isel(space=stm_arcs['source'])['lon'].values,
               stm.isel(space=stm_arcs['target'])['lon'].values]).T
# y (lat) of sources and targets, one row per arc
yy = np.stack([stm.isel(space=stm_arcs['source'])['lat'].values,
               stm.isel(space=stm_arcs['target'])['lat'].values]).T
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
ens_coh_abs = np.abs(stm_arcs['ens_coh'].data)
for i in range(stm_arcs.sizes["space"]):
    ax.plot(xx[i], yy[i], color=cmap(norm(ens_coh_abs[i])), linewidth=0.5)
ax.set(title="Network STM arcs, colored by ens_coh", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh');

In [None]:
# Select the arcs whose ens_coh passes ENS_COH_THRESHOLD (the exact comparison
# is implemented in depsi.network.arc_selection)
ENS_COH_THRESHOLD = 0.75
# NOTE(review): the meaning of the last positional argument (3) is defined by
# arc_selection — presumably a minimum number of arcs per point; confirm.
stm_arcs_selected = arc_selection(stm_arcs, ENS_COH_THRESHOLD, 'ens_coh', 3)

In [None]:
# Visualize the arcs that survived the selection, colored by the absolute
# ensemble coherence (ens_coh).
# x (lon) of sources and targets, one row per selected arc
xx = np.stack([stm.isel(space=stm_arcs_selected['source'])['lon'].values,
               stm.isel(space=stm_arcs_selected['target'])['lon'].values]).T
# y (lat) of sources and targets, one row per selected arc
yy = np.stack([stm.isel(space=stm_arcs_selected['source'])['lat'].values,
               stm.isel(space=stm_arcs_selected['target'])['lat'].values]).T
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
ens_coh_abs = np.abs(stm_arcs_selected['ens_coh'].data)
for i in range(stm_arcs_selected.sizes["space"]):
    ax.plot(xx[i], yy[i], color=cmap(norm(ens_coh_abs[i])), linewidth=0.5)
ax.set(title="Selected network STM arcs, colored by ens_coh", xlabel="lon", ylabel="lat")
fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh');

In [None]:
# Drop points that no selected arc connects any more, and update the arcs'
# source/target indices so they refer to the reduced point STM.
stm_updated, stm_arcs_selected = remove_isolated_stm(
    stm,
    stm_arcs_selected,
)

## Network integration

### Initial adjustment

In [None]:
# Network relation matrix A, built from the selected arcs' source/target point
# indices and the number of points in the updated STM
A_sparse = _network_relation_matrix(
    stm_arcs_selected["source"],
    stm_arcs_selected["target"],
    stm_updated.sizes["space"],
)

# Observations y: the per-arc ambiguities estimated by the periodogram
y = stm_arcs_selected["ambigs"].data

# Stochastic model: unit variance for every arc, so the inverse covariance
# matrix invQy is the (sparse) identity
N_arcs = stm_arcs_selected.sizes["space"]
Qyy_diag = np.ones(N_arcs)
invQy = scipy.sparse.diags(1 / Qyy_diag, 0, shape=(N_arcs, N_arcs))

# First estimation; only the residuals echeck are kept
_, echeck = _solve_float_ambiguities(A_sparse, y, invQy)

In [None]:
# Overall Model Test (OMT): trace of echeck^T Qy^-1 echeck, i.e. the weighted
# sum of squared residuals over all columns of echeck
kOMT = 1e-10  # test threshold
weighted_residual_products = echeck.T @ invQy @ echeck
OMT = weighted_residual_products.diagonal().sum()
print("OMT:", OMT)
# True means the test fails and the network needs adjustment
print("OMT > kOMT:", OMT > kOMT)

In [None]:
# Set up the critical values of the multiple-hypothesis tests
a0 = 0.1  # passed to pretest; presumably the 1-D test significance level — confirm
g0 = 0.5  # passed to pretest; presumably the test power — confirm

# Largest column sum of |A|. Assuming A has one row per arc and one column per
# point (with +-1 entries), this is the maximum number of arcs at a single point.
# int() because range() rejects numpy float scalars if A_sparse has a float dtype.
max_con = int(np.abs(A_sparse).sum(axis=0).max())
# Without at least one connection the loop below never runs and k1 is undefined
assert max_con >= 1, "the network has no arcs to test"

kb_dict = {}  # n_con -> kb returned by pretest
for n_con in range(1, max_con + 1):
    # k1 is overwritten on every pass; the adjustment loop later uses the value
    # from n_con == max_con (presumably k1 does not depend on n_con — confirm)
    _, k1, kb, _ = pretest(n_con, a0, g0)
    kb_dict[n_con] = kb

In [None]:
def visualize_network_echeck(pnt, arcs, echeck, vmax=10.0):
    """Plot the network arcs colored by the row sum of the residuals ``echeck``.

    Parameters
    ----------
    pnt : xarray.Dataset
        Point STM providing ``lon`` and ``lat`` along the ``space`` dimension.
    arcs : xarray.Dataset
        Arc STM whose ``source``/``target`` variables index ``pnt``'s ``space``
        dimension.
    echeck : array-like
        Residuals with one row per arc (``arcs.sizes["space"]`` rows); each row
        is summed to a single value per arc.
    vmax : float, optional
        Upper limit of the color scale (the lower limit is 0). Default 10.0.

    Returns
    -------
    fig, ax
        The matplotlib Figure and Axes holding the plot.
    """
    # lon of sources and targets, one row per arc
    xx = np.stack([pnt.isel(space=arcs['source'])['lon'].values,
                   pnt.isel(space=arcs['target'])['lon'].values]).T
    # lat of sources and targets, one row per arc
    yy = np.stack([pnt.isel(space=arcs['source'])['lat'].values,
                   pnt.isel(space=arcs['target'])['lat'].values]).T

    fig, ax = plt.subplots()
    cmap = plt.cm.rainbow
    norm = plc.Normalize(vmin=0, vmax=vmax)
    # np.asarray so the row sums are a flat ndarray even for np.matrix input
    echeck_sum = np.asarray(echeck).sum(axis=1)
    for i in range(arcs.sizes["space"]):
        # Map through the same norm as the colorbar. A Colormap called directly
        # treats floats as fractions of [0, 1] (so every sum >= 1 saturates) and
        # integers as lookup-table indices, so the colors would not match the bar.
        ax.plot(xx[i], yy[i], color=cmap(norm(echeck_sum[i])), linewidth=0.5)
    ax.set_title("Network STM arcs, colored by echeck sum")
    fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='echeck sum')
    return fig, ax

In [None]:
# Iteratively adapt the network until the Overall Model Test passes
# (OMT <= kOMT) or max_iter iterations have run. Each iteration asks
# _mht_network_adjustment whether to remove an arc (flag_rm == 0) or a point
# together with all of its arcs (flag_rm == 1), applies that removal, rebuilds
# A, y and invQy, re-estimates the residuals and recomputes the OMT.
# NOTE(review): k1 is the value left over from the last pass of the pretest
# loop (n_con == max_con); confirm that k1 does not depend on n_con.
# NOTE(review): if flag_rm is neither 0 nor 1, nothing is removed and the
# iteration recomputes the same model until max_iter is reached.
niter = 0
max_iter = 10
while OMT > kOMT and niter < max_iter:
    # Inside the loop the OMT has failed
    # Choose between two alternative hypotheses Ha: 1) remove an arc; 2) remove a point
    flag_rm, idx_rm = _mht_network_adjustment(A_sparse, y, Qyy_diag, k1, kb_dict)

    if flag_rm == 0:  # remove arcs
        stm_arcs_selected = stm_arcs_selected.drop_isel(space=idx_rm)  # Remove the arc
        stm_updated, stm_arcs_selected = remove_isolated_stm(
            stm_updated, stm_arcs_selected
        )  # Remove isolated points and update idx
        Qyy_diag = np.delete(Qyy_diag, idx_rm)  # Remove the corresponding Qyy entry
    elif flag_rm == 1:  # remove points
        # Keep only the arcs that do not touch the point to remove
        # (idx_rm is compared with the arcs' source/target point indices)
        idx_arcs_selected = np.where(
            (
                (stm_arcs_selected["source"] != idx_rm)
                & (stm_arcs_selected["target"] != idx_rm)
            ).data
        )[0]  # Arc indices in the previous arc stm for selection
        stm_arcs_selected = stm_arcs_selected.isel(space=idx_arcs_selected)

        # Remove isolated points and update idx
        stm_updated, stm_arcs_selected = remove_isolated_stm(
            stm_updated, stm_arcs_selected
        )
        # Select the corresponding Qyy entries; this assumes remove_isolated_stm
        # removes points only, so the arc order still matches idx_arcs_selected
        Qyy_diag = Qyy_diag[idx_arcs_selected]

    # Update the functional and stochastic model for the reduced network
    N_arcs = stm_arcs_selected.sizes["space"]
    A_sparse = _network_relation_matrix(
        stm_arcs_selected["source"],
        stm_arcs_selected["target"],
        stm_updated.sizes["space"],
    )
    y = stm_arcs_selected["ambigs"].data
    invQy = scipy.sparse.diags(1 / Qyy_diag, 0, shape=(N_arcs, N_arcs))

    # Estimate residual again
    _, echeck = _solve_float_ambiguities(A_sparse, y, invQy)

    # Same OMT statistic as in the initial test
    OMT = (echeck.T @ invQy @ echeck).diagonal().sum()

    niter += 1
    print(
        f"Iteration {niter}: OMT = {OMT}, flag_rm: {flag_rm}, OMT > kOMT: {OMT > kOMT}"
    )

    # One figure per iteration showing the residuals of the reduced network
    visualize_network_echeck(stm_updated, stm_arcs_selected, echeck)


In [None]:
# Residuals from the most recent _solve_float_ambiguities call (the initial
# estimate if the adjustment loop did not iterate)
echeck