# Handle large design matrix A in network unwrapping

## The problem

Consider a network with `N_points` points (nodes) and `N_arcs` arcs, where each point has a wrapped phase time series of length `N_time`. Forming arcs gives the differential phases between points, and the ambiguities can be solved per arc.

Next, we want to integrate the arc ambiguities into point ambiguities. Multiple arcs can connect to a single point, and approaching that point along different arcs may give conflicting ambiguity solutions. We therefore need an optimal solution that adjusts for these inconsistencies. A very simplified solution uses the following linear model:


```math
y = Ax
```

where:

- `y` is the matrix of known arc ambiguities, with shape `N_arcs` x `N_time`.
- `A` is the design matrix, with shape `N_arcs` x `N_points`. `A` encodes the relationship between arcs and points, with values -1, 1, and 0. Each row of `A` corresponds to an arc, with 1 marking the target point of that arc and -1 the source point.
- `x` is the matrix of point ambiguity time series to be solved, with shape `N_points` x `N_time`.
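
As a minimal illustration of how `A` is laid out, the sketch below builds it for a hypothetical three-point network. The names `arcs` and `A_toy` are made up for this example and do not appear elsewhere in the notebook.

```python
import numpy as np

# Hypothetical toy network: 3 points, 3 arcs given as (source, target) pairs
arcs = np.array([[0, 1], [1, 2], [0, 2]])
n_arcs, n_points = arcs.shape[0], 3

# One row per arc: -1 at the source point, +1 at the target point
A_toy = np.zeros((n_arcs, n_points), dtype=np.int8)
rows = np.arange(n_arcs)
A_toy[rows, arcs[:, 0]] = -1
A_toy[rows, arcs[:, 1]] = 1

print(A_toy)
```

Every row contains exactly one -1 and one 1, so each row sums to zero.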

Based on the least-squares principle, and assuming all entries of `y` are independent and equally weighted, we can estimate `x` as:

```math
x = (A^T A)^{-1} A^T y
```
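
The estimator can be tried on a small example. One caveat: every row of `A` sums to zero, so `A^T A` is singular when all points are kept as unknowns. The sketch below therefore fixes one point as a reference by dropping its column. This is an assumption made for the illustration and is not a step taken in this notebook. The names `A_toy`, `x_true` and `A_red` are hypothetical.

```python
import numpy as np

# Hypothetical toy network: 3 points, 3 arcs given as (source, target) pairs
arcs = np.array([[0, 1], [1, 2], [0, 2]])
rows = np.arange(len(arcs))
A_toy = np.zeros((len(arcs), 3))
A_toy[rows, arcs[:, 0]] = -1
A_toy[rows, arcs[:, 1]] = 1

# Made-up point ambiguities (3 points x 2 epochs); point 0 is the reference
x_true = np.array([[0.0, 0.0], [1.0, -1.0], [3.0, 2.0]])
y_toy = A_toy @ x_true  # consistent arc ambiguities

# With all points as unknowns, A^T A is rank deficient
rank_full = np.linalg.matrix_rank(A_toy.T @ A_toy)

# ASSUMPTION: fix point 0 as reference by removing its column
A_red = A_toy[:, 1:]
x_hat = np.linalg.solve(A_red.T @ A_red, A_red.T @ y_toy)

print("rank of full A^T A:", rank_full)
print(x_hat)
```

Because `y_toy` is consistent here, the estimate reproduces the ambiguities of the non-reference points exactly.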

In this very simple model, `A` is a very large sparse matrix. This notebook investigates whether there is a simple way to perform this computation in Python.


## Some conclusions
This notebook tries three approaches to solve the above problem. Only the first finishes within a reasonable time. Based on these findings, I draw the following conclusions:

- A Dask array backed by sparse chunks (the `sparse` library) is very efficient at large sparse matrix multiplication. However, it does not support matrix inversion.
- `da.linalg.solve` works well for matrix inversion on a Dask array, but the array must be dense. It can handle large arrays with good memory management and performance.
- It is possible to delay the densification of a sparse Dask array and pass the result to `da.linalg.solve`. However, this seems to overcomplicate the task graph and even defeat the chunking mechanism, and it results in a memory error.

Based on the above findings, I recommend Solution 1 (the first approach) for a large matrix `A`. It solves a network with 10,000 points and 100,000 arcs in less than 5 minutes on a local laptop.

In [1]:
import dask
import dask.array as da
import sparse
import numpy as np

## Setup experiments

Even considering extremely large networks, we have:

- N_points < 50000
- N_arcs < 600000

A single point is connected to fewer than 100 arcs (`n_connection` < 100).

In this test, we use a network with 10,000 points and 100,000 arcs. Each point has a time series of length 100. This is larger than 90% of cases in practice.
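
To put these sizes in perspective, here is a rough back-of-the-envelope estimate of what `A` would cost as a dense `int8` array versus a COO sparse layout. The COO figure assumes that each non-zero stores one `int8` value plus two 8-byte coordinates. This is an assumption for illustration, since the actual index width depends on the library and version. The helper `dense_mib` is hypothetical.

```python
def dense_mib(n_rows, n_cols, itemsize=1):
    """Size of a dense n_rows x n_cols array in MiB."""
    return n_rows * n_cols * itemsize / 1024**2

# Dense int8 A for the test network and for the upper bounds listed above
dense_test = dense_mib(100_000, 10_000)
dense_upper = dense_mib(600_000, 50_000)

# COO layout: two non-zeros per arc
# ASSUMPTION: 1 byte for the value plus 2 x 8 bytes for the coordinates
nnz_test = 2 * 100_000
coo_test = nnz_test * (1 + 2 * 8) / 1024**2

print(f"dense A (test network):  {dense_test:.0f} MiB")
print(f"dense A (upper bound):   {dense_upper:.0f} MiB")
print(f"sparse A (test network): {coo_test:.1f} MiB")
```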

In [None]:
# Number of points and arcs
N_points = 10_000
N_arcs = 100_000
N_time = 100 # Only influence y_dummy

# Dask setup
N_chunks = 10

# Dummy arc ambiguities
# NOTE: the upper bound of randint is exclusive, so randint(0, 1) yields all zeros
y_dummy = da.random.randint(0, 1, (N_arcs, N_time), chunks=(N_arcs // N_chunks, N_time))

# Randomly generate connections between points
start_end_indices = da.random.randint(0, N_points, size=(N_arcs, 2), chunks=(N_arcs // N_chunks, 2)).astype(np.int32).compute()
xindex = np.arange(N_arcs)
yindex_start = start_end_indices[:, 0]
yindex_end = start_end_indices[:, 1]

## Solution 1: Build A as a dask sparse matrix

This approach creates `A` as a Dask array with sparse chunks, which takes very little memory. The disadvantage is that `A^T A` must be held in memory as a dense matrix, because the matrix inversion step, `da.linalg.solve`, does not support sparse matrices. This assumes that `A^T A`, which has size `N_points` x `N_points`, fits into memory. `A` has a dtype of `np.int8`. Since the number of connections per node is usually < 100, the entries of `A^T A` also fit in `np.int8` (range -128 to 127). Therefore, `A^T A` (called `N` in the code) is likely to fit into memory.
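
The `np.int8` argument rests on the structure of `A^T A`: its diagonal holds the number of arcs attached to each point, and its off-diagonal entries hold minus the number of arcs between a pair of points. A minimal sketch on a hypothetical four-point network (the names `arcs`, `A_toy` and `N_toy` are made up for this example) checks this:

```python
import numpy as np

# Hypothetical toy network: 4 points, 4 arcs given as (source, target) pairs
arcs = np.array([[0, 1], [1, 2], [0, 2], [2, 3]])
n_points = 4

rows = np.arange(len(arcs))
A_toy = np.zeros((len(arcs), n_points), dtype=np.int8)
A_toy[rows, arcs[:, 0]] = -1
A_toy[rows, arcs[:, 1]] = 1

# int8 @ int8 stays int8 in NumPy, so an overflow would go unnoticed
N_toy = A_toy.T @ A_toy

# Number of arcs attached to each point
degree = np.bincount(arcs.ravel(), minlength=n_points)

print(N_toy)
print("diagonal equals degree:", np.array_equal(np.diag(N_toy), degree))
print("fits in int8:", degree.max() <= np.iinfo(np.int8).max)
```

If any point had more than 127 arcs, the diagonal would overflow `np.int8`, so a check like the last line may be worth running on real networks.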

For N_points = 50000, the memory usage of `A^T A` is:

```python
N_points = 50000
N_bytes = N_points * N_points * np.dtype(np.int8).itemsize
N_bytes / 1024**2  # in MB
```

This is about 2384 MB, or roughly 2.3 GiB.
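
The estimate above covers `A^T A` itself. Its inverse is a dense floating-point matrix, so it is larger. The sketch below compares the two, assuming the inverse is stored as `float64` (an assumption, since the dtype was not checked in this notebook). The helper `square_mib` is hypothetical.

```python
import numpy as np

def square_mib(n, dtype):
    """Size of a dense n x n array of the given dtype in MiB."""
    return n * n * np.dtype(dtype).itemsize / 1024**2

for n in (10_000, 50_000):
    print(
        f"N_points = {n}: "
        f"int8 N = {square_mib(n, np.int8):.0f} MiB, "
        f"float64 inverse = {square_mib(n, np.float64):.0f} MiB"
    )
```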


In [None]:
# Create a sparse COO matrix
A_sparse_start = sparse.COO((xindex, yindex_start), np.full_like(xindex, -1, dtype=np.int8), shape=(N_arcs, N_points))
A_sparse_end = sparse.COO((xindex, yindex_end), np.full_like(xindex, 1, dtype=np.int8), shape=(N_arcs, N_points))
A_sparse = A_sparse_start + A_sparse_end
# Convert to Dask array with desired chunks
A = da.from_array(A_sparse, chunks=(N_arcs//N_chunks, N_points//N_chunks))

A

In [None]:
# Matrix multiplication is supported on the sparse array A,
# but inversion is not supported for sparse arrays.
# N (N_points x N_points) can be computed as a dense array, since it is smaller than A (N_points < N_arcs)
N = (A.T @ A).astype(np.int8)

# The Dask array N does not have a .todense() method,
# so densify N lazily, block by block, for the linear solve
N_dense = N.map_blocks(lambda x: x.todense(), dtype=N.dtype)

The following code is commented out because it gives a memory error.

In [None]:
# # If we directly use N_dense as a Dask array in da.linalg.solve, we will run into memory errors
# param_dummy = (
#     da.linalg.solve(
#         N_dense, da.eye(N_points, chunks=N_points // N_chunks, dtype=N.dtype)
#     )
#     @ A.T
#     @ y_dummy
# )

# param_dummy_comp = param_dummy.compute()

The following code works if we first compute the inverse of `N = A.T @ A` with `da.linalg.solve`, hold the result in memory, and then apply it to `A.T @ y`.

In [None]:
# The whole cell takes about 3 minutes to run
# Computing N_dense_inv explicitly seems to be necessary:
# da.linalg.solve(N_dense, A_dense) would cause memory issues since A_dense cannot fit in memory
# da.linalg.solve only works with dense matrices
N_dense_inv = da.linalg.solve(N_dense, da.eye(N_points, chunks=N_points//N_chunks, dtype=N.dtype))
N_dense_inv_comp = N_dense_inv.compute() # About 2mins 39s to compute

# Create a Dask Array from the computed inverse matrix in memory
da_N_dense_inv_comp = da.from_array(N_dense_inv_comp, chunks=N_points//N_chunks)

# Linear solve to find parameters
# param_dummy = N_dense_inv @ A.T @ y_dummy
param_dummy = da_N_dense_inv_comp @ A.T @ y_dummy
param_dummy_comp = param_dummy.compute() # This takes ~30s to compute
param_dummy_comp
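
A possible alternative to forming the explicit inverse is to evaluate `A^T y` first, which has the small shape `N_points` x `N_time`, and then solve `N x = A^T y` directly. The NumPy sketch below only shows that both routes give the same result on a small, hypothetical, connected network. Point 0 is dropped as a reference so that `N` is invertible, which is an assumption for this illustration. Whether this helps at the full Dask scale was not tested in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points, n_time = 50, 5

# Chain arcs keep the toy network connected; random arcs add redundancy
chain = np.column_stack([np.arange(n_points - 1), np.arange(1, n_points)])
extra = rng.integers(0, n_points, size=(200, 2))
extra = extra[extra[:, 0] != extra[:, 1]]  # drop self-loops
arcs = np.vstack([chain, extra])

rows = np.arange(len(arcs))
A_small = np.zeros((len(arcs), n_points))
A_small[rows, arcs[:, 0]] = -1
A_small[rows, arcs[:, 1]] = 1
A_small = A_small[:, 1:]  # ASSUMPTION: point 0 is the reference

# Made-up integer arc ambiguities
y_small = rng.integers(-2, 3, size=(len(arcs), n_time)).astype(float)

N_small = A_small.T @ A_small
rhs = A_small.T @ y_small  # shape (n_points - 1, n_time)

x_via_inverse = np.linalg.inv(N_small) @ rhs
x_via_solve = np.linalg.solve(N_small, rhs)

print(np.allclose(x_via_inverse, x_via_solve))
```

Grouping the product as `inv(N) @ (A.T @ y)` rather than `(inv(N) @ A.T) @ y` also avoids an intermediate of shape `N_points` x `N_arcs`.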

## Solution 2 (not working): Densify A first
This assumes `A` fits into memory. If so, we densify `A` from the beginning and create a Dask array from it. This is an attempt to accelerate the computation, but it is not successful.

In [None]:
# Create a sparse COO matrix
A_sparse_start = sparse.COO((xindex, yindex_start), np.full_like(xindex, -1, dtype=np.int8), shape=(N_arcs, N_points))
A_sparse_end = sparse.COO((xindex, yindex_end), np.full_like(xindex, 1, dtype=np.int8), shape=(N_arcs, N_points))
A_sparse = A_sparse_start + A_sparse_end
# Convert to Dask array with desired chunks
A = da.from_array(A_sparse.todense(), chunks=(N_arcs//N_chunks, N_points//N_chunks))

A

In [None]:
# This does not finish in 30 minutes
N = (A.T @ A).astype(np.int8)
param_dummy = (
    da.linalg.solve(
        N, da.eye(N_points, chunks=N_points // N_chunks, dtype=N.dtype)
    )
    @ A.T
    @ y_dummy
)
param_dummy_comp = param_dummy.compute()

## Solution 3 (not working): Create A directly as a dense Dask array, without sparsity

This method attempts to create `A` from the coordinates as a dense Dask array, without sparsity. Although there is no memory error, the computation is extremely slow. Computing `N` does not finish in 10 minutes.

Maybe there is a better way to create `A` without looping over the coordinates block by block?

In [None]:
# Create a dense Dask array of zeros
A = da.zeros((N_arcs, N_points), dtype=np.int8, chunks=(N_arcs//N_chunks, -1))

# Create a function to set the 1 and -1 entries in the Dask array
# TODO: Is there a better way to do this? It seems Dask Array does not support setting values by index
def set_numbers(block, xindex, yindex_start, yindex_end, block_info=None):
    # in block, set (xindex, yindex_start) to -1 and (xindex, yindex_end) to 1
    if block_info is None:
        return block

    # Get the block location in global space
    arr_loc = block_info[None]['array-location']
    row_start, row_end = arr_loc[0]
    col_start, col_end = arr_loc[1]
    xindex_block = xindex[row_start:row_end]
    yindex_start_block = yindex_start[row_start:row_end]
    yindex_end_block = yindex_end[row_start:row_end]

    # Set the values in the block
    result = block.copy()
    result[xindex_block - row_start, yindex_start_block] = -1
    result[xindex_block - row_start, yindex_end_block] = 1

    return result

# Apply
A = A.map_blocks(set_numbers, xindex, yindex_start, yindex_end, dtype=A.dtype).rechunk((N_arcs // N_chunks, N_points // N_chunks))

A

In [None]:
N = (A.T @ A).astype(np.int8)
N

In [None]:
# This does not finish in 30 mins
N_comp = N.compute()

In [None]:
N_inv = da.linalg.solve(N, da.eye(N_points, chunks=N_points//N_chunks, dtype=N.dtype))
param_dummy = N_inv @ A.T @ y_dummy
param_dummy

In [None]:
# This takes
param_dummy_comp = param_dummy.compute()