# Using `einsum` for matrix multiplication and convolution

This notebook provides examples of using `numpy.einsum` for efficient matrix multiplication and convolution operations.

The goal is to replace the nested loops in the [VCE (Variance Component Estimation)](https://github.com/TUDelftGeodesy/DePSI/blob/stable/main/ps_vce.m) step of DePSI.
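
For readers new to `einsum`: each operand gets a string of index labels, a label that appears in the inputs but not after `->` is summed over, and the labels after `->` set the output layout. A minimal sketch with small random arrays (the names and sizes are illustrative only) showing a matrix product, a product applied to every slice of a 3-D stack, and a trace:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((3, 4))
B = rng.random((4, 5))

# Matrix product: j appears in both inputs but not in the output, so it is summed
C = np.einsum('ij,jk->ik', A, B)

# Stacked product: apply A to every slice S[:, :, k] of a 3-D array,
# the same (n, n, K) layout used for Qy1 in this notebook
S = rng.random((4, 4, 6))
T = np.einsum('ij,jlk->ilk', A, S)

# Trace: a repeated label on one operand with an empty output sums the diagonal
M = rng.random((4, 4))
tr = np.einsum('ii->', M)

print(np.allclose(C, A @ B), np.allclose(T[:, :, 2], A @ S[:, :, 2]), np.isclose(tr, np.trace(M)))
```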

In [None]:
import numpy as np

## Compute QPQy1QP and N

The einsum computation below is equivalent to this nested loop:
```python
QPQy1QP = np.full((Nifgs, Nifgs, Nsig), np.nan)
N = np.full((Nsig, Nsig), np.nan)
for k in range(Nsig):
    QPQy1QP[:, :, k] = QP @ Qy1[:, :, k] @ QP
    for j in range(Nsig):
        # N[k, j] = trace(QP @ Qy1_k @ QP @ Qy1_j)
        N[k, j] = np.trace(QPQy1QP[:, :, k] @ Qy1[:, :, j])
```
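
The second einsum relies on the identity trace(X @ Y) = sum over a and b of X[a, b] * Y[b, a], so each trace is obtained without forming an n x n product for every (k, j) pair. A small self-contained check (random stand-ins, sizes chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(1)
n, K = 5, 3
A = rng.random((n, n, K))  # stand-in for the stack QPQy1QP
B = rng.random((n, n, K))  # stand-in for the stack Qy1

# N[k, j] = trace(A_k @ B_j) = sum over a, b of A[a, b, k] * B[b, a, j]
N_small = np.einsum('abk,baj->kj', A, B)

# Reference: explicit products and traces, as in the nested loop
N_ref = np.array([[np.trace(A[:, :, k] @ B[:, :, j]) for j in range(K)]
                  for k in range(K)])

print(np.allclose(N_small, N_ref))
```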

In [None]:
# We randomly generate the data to verify the correctness of the einsum method.
Nsig = 150 # Number of sigmas, equal to Nifgs or Nifgs+1
Nifgs = 150 # Number of interferograms
rng = np.random.default_rng(44)
Qy1 = rng.random((Nifgs, Nifgs, Nsig))
QP = rng.random((Nifgs, Nifgs))

In [None]:
%%time
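# Einsum method
# QPQy1QP[:, :, k] = QP @ Qy1[:, :, k] @ QP for all k at once
QPQy1QP = np.einsum('ij,jlk,lm->imk', QP, Qy1, QP, optimize=True)
# N[k, j] = trace(QPQy1QP[:, :, k] @ Qy1[:, :, j])
N = np.einsum('abk,baj->kj', QPQy1QP, Qy1, optimize=True)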

In [None]:
%%time
# Original nested loop method
# For >400 epochs, this does not finish within 5 minutes
# For 150 epochs, this takes ~3 minutes
QPQy1QP2 = np.full((Nifgs, Nifgs, Nsig), np.nan)
N2 = np.full((Nsig, Nsig), np.nan)
for k in range(Nsig):
    QPQy1QP2[:, :, k] = QP @ Qy1[:, :, k] @ QP
    for j in range(Nsig):
        N2[k, j] = np.trace(QPQy1QP2[:, :, k] @ Qy1[:, :, j])

In [None]:
# Intermediate matrix and results should match
print(np.allclose(QPQy1QP, QPQy1QP2))
print(np.allclose(N, N2))

## Compute sig2

For `sig2`, the efficiency of the einsum method depends on the product `Narcs_vce` x `Nsig`. With 352 independent arcs, the einsum method is only significantly faster than the nested loops when `Nsig` exceeds 350. With fewer than 200 arcs, the nested loops typically take about 0.7 s, while einsum takes about 1 s.

We nevertheless recommend the einsum method, since its run time is more consistent and its efficiency loss for small `Nsig` is minor.
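
The dependence of the cost on both `Narcs_vce` and `Nsig` can be inspected with `np.einsum_path`, which reports the pairwise contraction order einsum would use and an estimated FLOP count. A minimal sketch with small random stand-ins (the sizes and variable names are illustrative assumptions, not DePSI values):

```python
import numpy as np

rng = np.random.default_rng(2)
# Illustrative sizes only (not DePSI values)
n_arcs, n_ifgs, n_sig = 50, 40, 40
Y = rng.random((n_arcs, n_ifgs))         # stand-in for phase_unwrapped
Q = rng.random((n_ifgs, n_ifgs, n_sig))  # stand-in for QPQy1QP

# Ask einsum which contraction order it would use and what it costs
path, report = np.einsum_path('ij,jmk,mi->ki', Y, Q, Y.T, optimize='optimal')
print(report)

# Reuse the precomputed path so the search is not repeated on every call
l_vec = np.einsum('ij,jmk,mi->ki', Y, Q, Y.T, optimize=path)
```

Rerunning with a larger `n_arcs` or `n_sig` and comparing the reported FLOP counts is one way to see how the cost scales with each dimension.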

In [None]:
Ninv = np.linalg.inv(N)

In [None]:
Narcs_vce = 352
# Narcs_vce = 17152 # larger number of arcs for testing
rng = np.random.default_rng(41)
phase_unwrapped = rng.random((Narcs_vce, Nifgs))

In [None]:
%%time
# Nested loops method
# For 400 epochs and 352 arcs, this takes ~4 seconds
sig2 = np.full((Nsig, Narcs_vce), np.nan)
l = np.full((Nsig, 1), np.nan)
for v in range(Narcs_vce):
    y = phase_unwrapped[v, :].reshape(-1, 1)
    for k in range(Nsig):
        l[k, 0] = (y.T @ QPQy1QP[:, :, k] @ y).squeeze()
    sig2[:, v] = (Ninv @ l).flatten()

In [None]:
%%time
# Compute all l in one go, then perform the matrix multiplication
# This is an intermediate step to reach the einsum method
sig22 = np.full((Nsig, Narcs_vce), np.nan)
l2 = np.full((Nsig, Narcs_vce), np.nan)
for v in range(Narcs_vce):
    y = phase_unwrapped[v, :].reshape(-1, 1)
    for k in range(Nsig):
        l2[k, v] = (y.T @ QPQy1QP[:, :, k] @ y).squeeze()

sig22 = Ninv @ l2
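
Each entry `l[k, v]` is the quadratic form y_v^T Q_k y_v for arc v, i.e. the diagonal of Y Q_k Y^T. Because the einsum formulation uses a single label for the arc index, it never forms the full `Narcs_vce` x `Narcs_vce` product. A small sketch comparing the two (random stand-ins: `Y` and `Q` play the roles of `phase_unwrapped` and `QPQy1QP`; the subscripts `'vj,jmk,vm->kv'` are an equivalent rewrite of the notebook's string that passes `Y` instead of `Y.T`):

```python
import numpy as np

rng = np.random.default_rng(3)
n_arcs, n_ifgs, n_sig = 6, 4, 3
Y = rng.random((n_arcs, n_ifgs))         # one row per arc
Q = rng.random((n_ifgs, n_ifgs, n_sig))  # one matrix per variance component

# einsum evaluates only the terms y_v^T Q_k y_v, one per (k, v) pair
l_einsum = np.einsum('vj,jmk,vm->kv', Y, Q, Y)

# Naive alternative: form the full n_arcs x n_arcs product, then keep its diagonal
l_diag = np.stack([np.diag(Y @ Q[:, :, k] @ Y.T) for k in range(n_sig)])

print(np.allclose(l_einsum, l_diag))
```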

In [None]:
%%time
# Compute using einsum: l3[k, i] = y_i^T @ QPQy1QP[:, :, k] @ y_i for every arc i
l3 = np.einsum('ij,jmk,mi->ki', phase_unwrapped, QPQy1QP, phase_unwrapped.T, optimize='optimal')
sig23 = Ninv @ l3

In [None]:
print(np.allclose(l3, l2))  # the l vectors should match
print(np.allclose(sig2, sig22))  # the two-step result should match the nested loops
print(np.allclose(sig2, sig23))  # the einsum result should match the nested loops
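
Since `Ninv` is only ever applied to `l`, one option not used in the original notebook is `np.linalg.solve`, which solves N X = l for all arcs (columns) at once without forming the explicit inverse and is generally the more numerically stable choice. A minimal sketch on a small stand-in for `N` that is assumed to be symmetric and well conditioned:

```python
import numpy as np

rng = np.random.default_rng(4)
n_sig, n_arcs = 5, 8
A = rng.random((n_sig, n_sig))
N = A @ A.T + n_sig * np.eye(n_sig)  # assumed well-conditioned stand-in for N
l = rng.random((n_sig, n_arcs))      # one column of l per arc

sig2_inv = np.linalg.inv(N) @ l      # approach used in the notebook
sig2_solve = np.linalg.solve(N, l)   # solves N @ X = l for every column at once

print(np.allclose(sig2_inv, sig2_solve))
```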