In [None]:
import numpy as np
import matplotlib.pyplot as plt
from load_dcm_folder import load_dcm_folder

In [None]:
# Load the DICOM folder and keep only its first entry.
# NOTE(review): load_dcm_folder is project-local; its return value is assumed
# to be a mapping (presumably one entry per series) -- confirm.
data_folder = '../../data/cine/NIH_SA_RealTime_CINE_base_RETRO_54/'
series_by_key = load_dcm_folder(data_folder)
first_key = next(iter(series_by_key))
res = series_by_key[first_key]
# Element 0 with singleton dimensions dropped; later cells index the result as
# imt[row, col, frame] (axis 2 is iterated as the time frame).
imt = res[0].squeeze()

In [None]:
def fourier_shift(im, shift=0, axis=0):
    """Circularly shift an array by a (possibly fractional) number of pixels.

    The shift uses the Fourier shift theorem. The data are transformed along
    ``axis`` only, multiplied by a linear phase ramp, and transformed back.
    For an integer ``shift`` the result matches
    ``np.roll(im, -shift, axis=axis)`` up to floating-point rounding, so a
    positive ``shift`` moves the content towards lower indices.

    Parameters
    ----------
    im : array_like
        Input array, e.g. a 2-D image. Any number of dimensions is accepted.
    shift : float, optional
        Shift in pixels along ``axis``; may be fractional. Default 0.
    axis : int, optional
        Axis along which to shift. Negative values count from the end.
        Default 0.

    Returns
    -------
    numpy.ndarray
        Real part of the shifted array, same shape as ``im``.

    Raises
    ------
    ValueError
        If ``axis`` is out of range for ``im``.
    """
    im = np.asarray(im)
    if not -im.ndim <= axis < im.ndim:
        raise ValueError(
            f"axis {axis} is out of range for an array with {im.ndim} dimensions"
        )
    n = im.shape[axis]

    # Frequencies in cycles/pixel, in native (unshifted) FFT order. fftfreq is
    # correct for both even and odd n, so no fftshift bookkeeping is needed.
    freqs = np.fft.fftfreq(n)
    ramp_shape = [1] * im.ndim
    ramp_shape[axis] = n
    ramp = np.exp(2j * np.pi * shift * freqs).reshape(ramp_shape)

    # Only the shifted axis needs transforming; the others are untouched.
    shifted = np.fft.ifft(np.fft.fft(im, axis=axis) * ramp, axis=axis)
    # For even n and a fractional shift, the Nyquist bin picks up a complex
    # factor, leaving an imaginary component; only the real part is returned.
    return np.real(shifted)

In [None]:
# Visual check of fourier_shift: frame 0 next to the same frame shifted by
# 30 pixels along axis 0.
fig, (ax_orig, ax_shift) = plt.subplots(1, 2, dpi=200)
ax_orig.imshow(imt[:, :, 0], cmap='gray')
ax_orig.set(title='Frame 0', xlabel='column [px]', ylabel='row [px]')
ax_shift.imshow(fourier_shift(imt[:, :, 0], 30, axis=0), cmap='gray')
ax_shift.set(title='Frame 0, shifted 30 px (axis 0)', xlabel='column [px]', ylabel='row [px]')
plt.show()

In [None]:
# Testing a very simple respiratory shift selector: a unit-amplitude sine with
# period resp_dur, sampled every 1 ms over 6000 ms.

resp_dur = 3000  # duration of a respiratory cycle [ms]
tt = np.arange(6000)  # ms
shift = np.sin(tt/resp_dur * 2 * np.pi)

fig, ax = plt.subplots()
ax.plot(tt, shift)
ax.set(xlabel='time [ms]', ylabel='relative shift', title='Respiratory shift model')
plt.show()


In [None]:
# Simulate a segmented CINE acquisition with respiratory motion.
# Each acquisition slot reads one k-space line. Within a heartbeat, N_seg
# consecutive lines are read for each of the Nt frames, so one heartbeat spans
# N_seg * Nt slots, and ceil(Nk / N_seg) heartbeats are needed to fill k-space.
Nt = imt.shape[2]     # number of time frames
axis = 0              # axis of the k-space line index; also the motion axis
Nk = imt.shape[axis]  # number of k-space lines to fill
N_seg = 4             # lines per segment (per frame, per heartbeat)
seg_dur = 8  # ms -- time per acquisition slot, i.e. per k-space line
RR_t = N_seg * seg_dur * Nt  # ms, simulated duration of one heartbeat
print('RR_t:', RR_t)

resp_dur = 1600  # ms, respiratory period
resp_mag = 4  # pixels, peak respiratory displacement

slots_per_hb = N_seg * Nt
n_hb = -(-Nk // N_seg)  # integer ceil(Nk / N_seg): heartbeats needed
total_acq = slots_per_hb * n_hb  # Round up by heartbeat

kt_m = np.zeros(imt.shape, complex)

for it in range(total_acq):
    tt = it * seg_dur  # Scan time in ms

    i_hb, slot_in_hb = divmod(it, slots_per_hb)  # heartbeat number, slot within it
    i_frame, i_seg = divmod(slot_in_hb, N_seg)   # timeframe, position within segment
    i_line = i_hb * N_seg + i_seg  # k-space line number

    if i_line >= Nk:
        continue  # the last heartbeat may be only partly used

    # Displace the frame by the respiratory position at this instant and keep
    # only the single k-space line acquired now.
    shift = resp_mag * np.sin(tt / resp_dur * 2 * np.pi)
    im_shift = fourier_shift(imt[:, :, i_frame], shift, axis=axis)
    # Centered FFT: fftshift(fft2(ifftshift(x))) puts DC at index n//2 for
    # odd and even n (for even n, fftshift and ifftshift coincide).
    k = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(im_shift)))
    if axis == 0:
        kt_m[i_line, :, i_frame] = k[i_line, :]
    elif axis == 1:
        kt_m[:, i_line, i_frame] = k[:, i_line]

# Centered inverse FFT over the two spatial axes (mirrors the forward
# convention above), then the magnitude of every frame.
imt_m = np.abs(
    np.fft.fftshift(
        np.fft.ifftn(np.fft.ifftshift(kt_m, axes=(0, 1)), axes=(0, 1)),
        axes=(0, 1),
    )
)

In [None]:
# Compare the input frame with the reconstruction from motion-corrupted k-space.
fig, (ax_input, ax_motion) = plt.subplots(1, 2, dpi=200)
ax_input.imshow(imt[:, :, 0], cmap='gray')
ax_input.set(title='Frame 0, input', xlabel='column [px]', ylabel='row [px]')
ax_motion.imshow(imt_m[:, :, 0], cmap='gray')
ax_motion.set(title='Frame 0, simulated respiratory motion', xlabel='column [px]', ylabel='row [px]')
plt.show()