In [None]:
import xarray as xr
import numpy as np
import scipy
from scipy import stats
from scipy.stats import chi2, ncx2
from matplotlib import pyplot as plt
# import matplotlib.cm as cm
import matplotlib.colors as plc

from depsi.network import form_network, arc_selection, remove_network_points_min_connections, _network_relation_matrix
from depsi.arc_estimation import periodogram
from depsi.mht_utils import pretest
from depsi.network import _mht_network_adjustment, _solve_float_ambiguities

# import dask
# dask.config.set(scheduler='processes')

In [None]:
# Radar wavelength constant; handed to periodogram() as its `wavelength` argument further down.
WAVELENGTH = 0.055465763  # Sentinel-1, in meters

In [None]:
# Open the STM dataset from the Zarr store
stm = xr.open_zarr('../../data/stm_amsterdam_173p.zarr')

# Remove the mother epoch: the mother image has all h2ph values equal to 0,
# so its mean over axis 0 is exactly 0.
# np.flatnonzero always returns a 1-D index array, so the "time" dimension is kept
# even when a single epoch remains (np.squeeze(np.where(...)) would collapse to a
# 0-d index in that case and isel would then drop the dimension).
idx_non_mother = np.flatnonzero(stm['h2ph_values'].mean(axis=0).values != 0)
stm = stm.isel(time=idx_non_mother)

# For debugging, shorten the time series to the first 30 remaining epochs
stm = stm.isel(time=slice(0, 30))
stm

## Step 1: Form network and remove arcs with low ensemble coherence

In [None]:
# Build the arc dataset from the points in `stm`. The three strings name entries of `stm`
# that form_network reads (their exact roles are defined in depsi.network, not shown here).
# NOTE(review): the unit of max_length=0.001 is not visible in this notebook — confirm.
stm_arcs = form_network(stm, 'sd_phase', 'h2ph_values', 'years', max_length=0.001)
stm_arcs

In [None]:
# Show the distinct values stored in the 'uid' variable of the arcs
np.unique(stm_arcs['uid'].values)

In [None]:
# Points at both ends of every arc
src_pnts = stm.isel(space=stm_arcs['source'])
tgt_pnts = stm.isel(space=stm_arcs['target'])

# End-point coordinates: row 0 holds the sources, row 1 the targets
lon_ends = np.stack([src_pnts['lon'].values, tgt_pnts['lon'].values])
lat_ends = np.stack([src_pnts['lat'].values, tgt_pnts['lat'].values])

# Draw each arc as a thin blue segment (one column of the arrays above per arc)
fig, ax = plt.subplots()
for lon_pair, lat_pair in zip(lon_ends.T, lat_ends.T):
    ax.plot(lon_pair, lat_pair, color='b', linewidth=0.5)

In [None]:
# Run the periodogram on the arcs; it returns five results, unpacked in this order
phs_obs_unwrapped, ambigs, arc_height, arc_velo, ens_coh = periodogram(
    stm_arcs, "d_phase", "h2ph", "Btemp", wavelength=WAVELENGTH
)

# Attach every result to the arc dataset under its own name
for var_name, var_values in {
    "phs_obs_unwrapped": phs_obs_unwrapped,
    "ambigs": ambigs,
    "arc_height": arc_height,
    "arc_velo": arc_velo,
    "ens_coh": ens_coh,
}.items():
    stm_arcs[var_name] = var_values

# Materialize the dataset so the following cells work on computed values
stm_arcs = stm_arcs.compute()
stm_arcs

In [None]:
# Points at both ends of every arc
src_pnts = stm.isel(space=stm_arcs['source'])
tgt_pnts = stm.isel(space=stm_arcs['target'])

# One row per arc: [source, target] longitude / latitude
arc_lon = np.stack([src_pnts['lon'].values, tgt_pnts['lon'].values]).T
arc_lat = np.stack([src_pnts['lat'].values, tgt_pnts['lat'].values]).T

# Colour every arc by the absolute value of its ens_coh, mapped onto [0, 1]
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
abs_ens_coh = np.abs(stm_arcs['ens_coh'].data)
for lon_pair, lat_pair, coh in zip(arc_lon, arc_lat, abs_ens_coh):
    ax.plot(lon_pair, lat_pair, color=cmap(norm(coh)), linewidth=0.5)
plt.title("Network STM arcs, colored by ens_coh")
plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh')

In [None]:
# Select arcs using a threshold of 0.75 on 'ens_coh'
# NOTE(review): the trailing 3 is presumably a minimum number of connections per point —
# confirm against the signature of arc_selection.
stm_arcs = arc_selection(stm_arcs, 0.75, 'ens_coh', 3)

In [None]:
# Coordinates of the arc end points after the selection; one row per arc: [source, target]
src_pnts = stm.isel(space=stm_arcs['source'])
tgt_pnts = stm.isel(space=stm_arcs['target'])
arc_lon = np.stack([src_pnts['lon'].values, tgt_pnts['lon'].values]).T
arc_lat = np.stack([src_pnts['lat'].values, tgt_pnts['lat'].values]).T

# Pre-compute one colour per arc from |ens_coh| on a fixed [0, 1] scale
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
arc_colors = [cmap(norm(coh)) for coh in np.abs(stm_arcs['ens_coh'].data)]

# Draw the remaining arcs
fig, ax = plt.subplots()
for lon_pair, lat_pair, arc_color in zip(arc_lon, arc_lat, arc_colors):
    ax.plot(lon_pair, lat_pair, color=arc_color, linewidth=0.5)
plt.title("Network STM arcs, colored by ens_coh")
plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh')

In [None]:
# Step 1 works on its own copies of the point set and the arc set
stm_step1, stm_arcs_step1 = stm.copy(), stm_arcs.copy()

# Sentinel that no point count can equal, so the pruning loop in the next cell
# (which enforces at least 3 connections per point) is entered at least once
previous_size = -1

In [None]:
# Repeat the pruning until the number of points stops changing
while True:
    n_points_now = stm_step1.sizes["space"]
    if n_points_now == previous_size:
        break  # the last pass removed nothing
    previous_size = n_points_now
    # Drop points that have fewer than 3 connections (i.e. <= 2)
    stm_step1, stm_arcs_step1 = remove_network_points_min_connections(
        stm_step1, stm_arcs_step1, min_connections=3
    )

In [None]:
# End points of the step-1 arcs, looked up in the step-1 point set
src_pnts = stm_step1.isel(space=stm_arcs_step1['source'])
tgt_pnts = stm_step1.isel(space=stm_arcs_step1['target'])

# One row per arc: [source, target]
arc_lon = np.stack([src_pnts['lon'].values, tgt_pnts['lon'].values]).T
arc_lat = np.stack([src_pnts['lat'].values, tgt_pnts['lat'].values]).T

# Visualize the pruned network, coloured by |ens_coh| on a fixed [0, 1] scale
fig, ax = plt.subplots()
cmap = plt.cm.rainbow
norm = plc.Normalize(vmin=0, vmax=1.0)
abs_ens_coh = np.abs(stm_arcs_step1['ens_coh'].data)
for arc_idx, (lon_pair, lat_pair) in enumerate(zip(arc_lon, arc_lat)):
    arc_color = cmap(norm(abs_ens_coh[arc_idx]))
    ax.plot(lon_pair, lat_pair, color=arc_color, linewidth=0.5)
plt.title("Network STM arcs, colored by ens_coh")
plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='ens_coh')

## Step 2: Reject arcs/points with MHT

In [None]:
# Network relation matrix A, built from the arc end points and the number of points
n_points_step1 = stm_step1.sizes["space"]
A_sparse = _network_relation_matrix(
    stm_arcs_step1["source"], stm_arcs_step1["target"], n_points_step1
)

# A_sparse = _network_relation_matrix(stm_arcs_step1["target"], stm_arcs_step1["source"], stm_step1.sizes["space"])

# Observations y: the 'ambigs' variable of the arcs
y = stm_arcs_step1['ambigs'].data

# Stochastic model: every arc gets variance 1, so the inverse covariance is a
# diagonal matrix holding the reciprocals of those variances
N_arcs = stm_arcs_step1.sizes["space"]
Qyy_diag = np.ones(N_arcs)
invQy = scipy.sparse.diags(np.reciprocal(Qyy_diag), offsets=0, shape=(N_arcs, N_arcs))

# First estimation; only the second return value (residuals) is kept
echeck = _solve_float_ambiguities(A_sparse, y, invQy)[1]

In [None]:
# Overall Model Test (OMT)
kOMT = 1e-10  # threshold

# Weighted squared residuals; the diagonal entries are summed into one statistic
weighted_sq_residuals = echeck.T @ invQy @ echeck
OMT = weighted_sq_residuals.diagonal().sum()
print("OMT:", OMT)

# True means the test fails and the network needs adjustment
omt_exceeds_threshold = OMT > kOMT
print("OMT > kOMT:", omt_exceeds_threshold)

In [None]:
# Setup tests
# NOTE(review): a0 and g0 are passed straight to pretest(); presumably the level of
# significance and the power of the test — confirm against depsi.mht_utils.
a0 = 0.1
g0 = 0.5

# Largest column sum of |A| (presumably the highest number of arcs attached to one point).
# The sum of a sparse matrix carries the dtype of the matrix, which may be float;
# range() only accepts integers, so cast explicitly.
max_con = int(np.abs(A_sparse).sum(axis=0).max())

# kb for every possible number of connections, keyed by that number.
# k1 keeps the value from the last iteration and is used by the adjustment loop later on.
kb_dict = {}

for n_con in range(1, max_con + 1):
    _, k1, kb, _ = pretest(n_con, a0, g0)
    kb_dict[n_con] = kb

In [None]:
def visualize_network_echeck(pnt, arcs, echeck):
    """Plot the network with every arc coloured by the row sum of ``echeck``.

    Parameters
    ----------
    pnt
        Point dataset; must support ``isel(space=...)`` and provide 'lon' and 'lat'.
    arcs
        Arc dataset providing 'source' and 'target' (indices into ``pnt``) and
        ``sizes["space"]`` (number of arcs).
    echeck
        2-D array of residuals with one row per arc; rows are summed over axis 1.

    Notes
    -----
    Colours are obtained as ``cmap(norm(value))`` with a fixed scale of [0, 10], the same
    normalization the colorbar uses, so line colours and colorbar agree.
    """
    # Coordinates of the arc end points, one row per arc: [source, target]
    src_pnts = pnt.isel(space=arcs['source'])
    tgt_pnts = pnt.isel(space=arcs['target'])
    xx = np.stack([src_pnts['lon'].values, tgt_pnts['lon'].values]).T
    yy = np.stack([src_pnts['lat'].values, tgt_pnts['lat'].values]).T

    fig, ax = plt.subplots()
    cmap = plt.cm.rainbow
    norm = plc.Normalize(vmin=0, vmax=10.0)

    # One value per arc
    echeck_sum = np.asarray(echeck).sum(axis=1)
    for i in range(arcs.sizes["space"]):
        # Normalize before the colormap lookup; the raw value would not match the colorbar
        ax.plot(xx[i], yy[i], color=cmap(norm(echeck_sum[i])), linewidth=0.5)

    # Points on top of the arcs
    ax.scatter(pnt['lon'], pnt['lat'], c='k', s=3)
    ax.set_title("Network STM arcs, colored by echeck sum")
    fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='echeck sum')

In [None]:
TT1_THRES = 1  # Threshold for TT1, if TT1max < TT1_THRES, stop iteration
TT1max = TT1_THRES + 1.0  # Initial TT1_max to trigger the while loop
MAX_ITER = stm_arcs_step1.sizes["space"] * stm_arcs_step1.sizes["time"]
stm_step2 = stm_step1.copy()
stm_arcs_step2 = stm_arcs_step1.copy()
niter = 0

# Snapshot of the variances in step-1 arc order. The loop below overwrites Qyy_diag with
# the variances of the surviving arcs, always sliced from this snapshot.
Qyy_diag_step1 = np.asarray(Qyy_diag).copy()
# A, y and Qyy_diag are overwritten by this cell; a mismatch here means the cell is being
# re-run without re-running the setup cell that builds the step-1 model.
assert Qyy_diag_step1.shape[0] == stm_arcs_step1.sizes["space"], (
    "Qyy_diag does not match the step-1 arcs; re-run the Step 2 setup cell first"
)

while (TT1max > TT1_THRES) and (OMT > kOMT) and (niter < MAX_ITER):
    # In the loop, OMT fail
    # Choose from two Ha: 1) remove an arc; 2) remove a point
    # k1 and kb_dict come from the test-setup cell
    flag_rm, idx_rm, TT1max, TTqmax = _mht_network_adjustment(A_sparse, y, Qyy_diag, k1, kb_dict)

    if flag_rm == 0:  # remove arcs
        stm_arcs_step2 = stm_arcs_step2.drop_isel(space=idx_rm)  # Remove the arc
    elif flag_rm == 1:  # remove points
        # Removing points is achieved by removing all arcs connects to the point
        # Later the points will be actually removed when ensuring minimum connections
        # Indices of the arcs NOT connected to the point, i.e. the arcs to keep
        idx_arcs_selected = np.where(
            (
                (stm_arcs_step2["source"] != idx_rm)
                & (stm_arcs_step2["target"] != idx_rm)
            ).data
        )[0]
        # Remove all arcs connects to the point to remove
        stm_arcs_step2 = stm_arcs_step2.isel(space=idx_arcs_selected)

    # Ensure all points have at least 3 connections
    previous_size = -1  # Initialize with an impossible value to trigger the while loop
    # Keep iterating until no more points are removed
    while stm_step2.sizes["space"] != previous_size:
        previous_size = stm_step2.sizes["space"]
        # Remove points with <=2 connections
        stm_step2, stm_arcs_step2 = remove_network_points_min_connections(
            stm_step2, stm_arcs_step2, min_connections=3
        )

    # Positions, in step-1 arc order, of the arcs that are still present (matched on uid).
    # The test must be "step-1 uid is in step 2": the opposite direction is true for every
    # remaining arc and would merely select the first N entries of the variance vector.
    # Assumes uids are unique and the surviving arcs keep their step-1 order.
    idx_arcs_selected = np.flatnonzero(
        np.asarray(stm_arcs_step1["uid"].isin(stm_arcs_step2["uid"]))
    )

    # Update the functional and stochastic model
    Qyy_diag = Qyy_diag_step1[idx_arcs_selected]
    N_arcs = stm_arcs_step2.sizes["space"]
    invQy = scipy.sparse.diags(1 / Qyy_diag, 0, shape=(N_arcs, N_arcs))
    A_sparse = _network_relation_matrix(
        stm_arcs_step2["source"],
        stm_arcs_step2["target"],
        stm_step2.sizes["space"],
    )

    y = stm_arcs_step2["ambigs"].data

    # Estimate residual again
    _, echeck = _solve_float_ambiguities(A_sparse, y, invQy)

    OMT = (echeck.T @ invQy @ echeck).diagonal().sum()

    niter += 1
    print(
        f"Iteration {niter}: OMT = {OMT}, flag_rm: {flag_rm}, TT1max: {TT1max}, TTqmax: {TTqmax}"
    )

    visualize_network_echeck(stm_step2, stm_arcs_step2, echeck)


## Step 3: fix ambiguity errors per epoch

In [None]:
kOMT = 1e-7  # threshold for each epoch, changed from 1e-10 to avoid repetitively fixing several same arcs

# Fix unwrapping ambiguities by looping over epochs
# Deep copy: Dataset.copy() is shallow by default, so without deep=True the write-back
# below would also overwrite the 'ambigs' of stm_arcs_step2.
stm_arcs_step3 = stm_arcs_step2.copy(deep=True)
stm_step3 = stm_step2.copy()
estimated_ambigs = np.zeros((stm_step3.sizes["space"], stm_step3.sizes["time"]))

# Upper bound on the number of single-arc fixes per epoch, so an epoch whose OMT never
# drops below kOMT cannot hang the notebook
max_fixes_per_epoch = 10 * stm_arcs_step3.sizes["space"]

for epoch in range(stm_step2.sizes["time"]):
    # Work on a private copy: y is modified in place while fixing, and the selection
    # may otherwise be a view into the step-2 data
    y = stm_arcs_step2["ambigs"].isel(time=epoch).data.copy()
    acheck_ifg, echeck_ifg = _solve_float_ambiguities(A_sparse, y, invQy)
    OMT = echeck_ifg.T @ invQy @ echeck_ifg
    print(f"Epoch {epoch}, updated OMT: {OMT}")

    idx_previous_arc_fix = -1  # Avoid fixing the same arc again in the same epoch
    n_fixes = 0

    while OMT > kOMT and n_fixes < max_fixes_per_epoch:  # While OMT fail, fix for this epoch
        # Find arc index with largest abs echeck
        # When OMT > kOMT, echeck_ifg[idx_max_echeck] is guaranteed to be non-zero
        idx_sort = np.argsort(np.abs(echeck_ifg))[::-1]  # Indices of echeck sorted by abs value, descending
        idx_max_echeck = idx_sort[0]  # Index of arc with largest abs echeck
        if idx_max_echeck == idx_previous_arc_fix:
            # If get same arc as previous fix, take the second largest
            idx_max_echeck = idx_sort[1]

        if np.round(abs(echeck_ifg[idx_max_echeck])) >= 1:  # If >= 1, minus closest integer
            y[idx_max_echeck] -= np.round(echeck_ifg[idx_max_echeck])
        elif echeck_ifg[idx_max_echeck] > 0:  # if (0, 1), minus 1
            y[idx_max_echeck] -= 1.0
        elif echeck_ifg[idx_max_echeck] < 0:  # if (-1, 0), plus 1
            y[idx_max_echeck] += 1.0

        idx_previous_arc_fix = idx_max_echeck  # record the fixed arc index
        n_fixes += 1

        # Recalculate OMT
        acheck_ifg, echeck_ifg = _solve_float_ambiguities(A_sparse, y, invQy)
        OMT = echeck_ifg.T @ invQy @ echeck_ifg
        print(f"Epoch {epoch}, updated OMT: {OMT}, idx_fixed_arc: {idx_max_echeck}")

    if OMT > kOMT:
        # Only reachable when the cap above stopped the loop
        print(f"Epoch {epoch}: OMT still above kOMT after {n_fixes} fixes, stopped fixing")

    # Store the fixed arc ambiguities of this epoch in the step-3 arcs
    stm_arcs_step3["ambigs"].data[:, epoch] = y

    # Store the point estimates of this epoch
    estimated_ambigs[:, epoch] = acheck_ifg

In [None]:
# TODO: check estimated_ambigs.
# All values should be close to integers, since the arc ambiguities have been fixed above.
estimated_ambigs

In [None]:
# Store the ambiguities, rounded to the nearest integer, on the points of stm_step3
ambigs_rounded = np.round(estimated_ambigs)
stm_step3["estimated_ambigs"] = (("space", "time"), ambigs_rounded)

# Show the estimates (as estimated, before rounding) as an image
fig, ax = plt.subplots()
ax.imshow(estimated_ambigs)

## Step 4: estimate the unwrapped phase

In [None]:
# Solve for the point phases from the arc phases 'd_phase'.
# _solve_float_ambiguities is reused here, but it should be renamed to a more general name.
psc_phase, _ = _solve_float_ambiguities(A_sparse, stm_arcs_step3['d_phase'].data, invQy)

# Unwrapping adds whole cycles: use the ambiguities rounded to the nearest integer,
# not the float solution, so the added term is an integer multiple of 2*pi.
psc_phase_unw = 2 * np.pi * np.round(estimated_ambigs) + psc_phase
stm_step3["psc_phase_unw"] = (("space", "time"), psc_phase_unw)

# Visualize the unwrapped phase
fig, ax = plt.subplots()
ax.imshow(psc_phase_unw)

In [None]:
# Compare wrapped and unwrapped phase for a single point
pnt_id = 25
stm_1pnt = stm_step3.isel(space=pnt_id)
years = stm_1pnt["years"].data
phase_wrapped = stm_1pnt["sd_phase"].data

fig, ax = plt.subplots(figsize=(12, 5))
# Gray markers without a line: the wrapped phase shifted by one cycle up and one cycle down
for cycle_shift in (2*np.pi, -2*np.pi):
    ax.plot(years, phase_wrapped + cycle_shift, marker='.', linestyle='None', color='gray',)
ax.plot(years, stm_1pnt["psc_phase_unw"], marker='.', color='red', linewidth=0.5, label='Unwrapped phase')
ax.plot(years, phase_wrapped, color='blue', linewidth=0.5, label='Wrapped phase')
ax.set_ylabel('Phase [rad]')
ax.set_xlabel('Time [years]')
ax.legend()

## Step 5: Model estimation (TBC)

This can be implemented as a separate function