In [48]:
import numpy as np
import matplotlib.pyplot as plt
from cbi_toolbox import reconstruct as cbir
from cbi_toolbox import parallel as cbip

import scipy.signal as signal
import scipy.fft as fft
import scipy.sparse as sparse
import scipy.linalg as linalg

%matplotlib widget

In [2]:

measure = np.load('../out/exp/measure.npy')

measure = measure[:, ::2, ...]

In [3]:
wimage = cbir.dwt_preprocess(measure, drop_hi=2, parallel=True)
del measure

In [4]:
# Resampling

n_time = 100
wimage = signal.resample(wimage, n_time, axis=0)

wimage.shape

(100, 128, 87, 87)

In [168]:

coords = np.mgrid[:64, :64] - 32

wimage = np.zeros((20, 10, 64, 64))
for n in range(wimage.shape[0]):
    wimage[n, 0, ...] = (np.abs(coords[0]) < n+5) * (np.abs(coords[1]) < n+8)


shifts = np.random.randint(0, 20, size=10)
shifts[0] = 0
for n in range(wimage.shape[1]):
    wimage[:, n, ...] = np.roll(wimage[:, 0, ...], shifts[n], axis=0)

print(shifts)

[ 0 12 19 15 12 16 14 13 19  7]


In [186]:
max_dk = 2
w_k = [1, 1]

n_depth = wimage.shape[1]

Q_k_dk = np.empty((wimage.shape[0], wimage.shape[1], max_dk + 1))

workers = cbip.max_workers()

wfft = fft.rfft(wimage, axis=0, overwrite_x=False, workers=workers)


for delta_k in range(max_dk+1):
    corr = fft.irfft(wfft * np.conj(np.roll(wfft, -delta_k, axis=1)), axis=0, workers=workers, overwrite_x=True)
    Q_k_dk[..., delta_k] = corr.sum((-1, -2))

del wfft

S_k_dk = np.argmax(Q_k_dk, axis=0)

# plt.figure()
# # plt.imshow(Q_kk[..., 0])
# plt.imshow(S_k_dk)
# plt.colorbar()
# plt.show()
A = []
s = []
w = []

A0 = np.zeros((1, n_depth))
A0[:, 0] = 1

A.append(A0)
s.append(np.atleast_1d(0))
w.append(1)

for delta_k in range(max_dk):
    delta_k += 1
    A_d = sparse.diags_array((1, -1), offsets=(0, delta_k), shape=(n_depth - delta_k, n_depth)).toarray()
    A.append(A_d)
    s.append(S_k_dk[:n_depth - delta_k, delta_k])
    w.extend((w_k[delta_k - 1],) * (n_depth - delta_k))
     
A = np.vstack(A)
s = np.hstack(s)
W = np.sqrt(np.diag(w))

A = W @ A
s = s @ W

sol, _, _, _ = linalg.lstsq(A, s)
# sol = sol.astype(int) % n_depth

sol = np.round((sol)).astype(int) % 20

print(shifts)
print(sol)

s

[ 0 12 19 15 12 16 14 13 19  7]
[ 0  0 11  6  3  7  6  3 13  9]


array([ 0.,  8., 13.,  4.,  3., 16.,  2.,  1., 14., 12.,  1., 17.,  7.,
       19., 18.,  3., 15.,  6.])