# Experiment F02: undersampled radial with Compressed Sensing (CS)

In [None]:
experiment_id = 'exF02_undersampled_radial'

import numpy as np
# HACK: newer NumPy versions removed the deprecated aliases np.int, np.float
# and np.complex, but pypulseq still relies on them -> map them back onto the
# corresponding Python builtins before pypulseq is imported.
np.int, np.float, np.complex = int, float, complex

import MRzeroCore as mr0
import math
import torch
import matplotlib.pyplot as plt
import pypulseq as pp

(CS_reco)=
## 1) Setup system

In [None]:
# Hardware limits handed to pypulseq for building and timing-checking events.
system = pp.Opts(
    max_grad=28, grad_unit='mT/m',      # gradient amplitude limit
    max_slew=150, slew_unit='T/m/s',    # slew-rate limit
    rf_ringdown_time=20e-6,
    rf_dead_time=100e-6,
    adc_dead_time=20e-6,
    # NOTE(review): 50*10e-6 = 0.5 ms gradient raster, far coarser than a
    # typical scanner raster; presumably chosen to keep the simulation cheap.
    grad_raster_time=50*10e-6,
)

fov = 1000e-3            # field of view (pypulseq works in SI units, i.e. metres)
slice_thickness = 8e-3   # slice thickness in metres

Nread = 64    # frequency encoding steps/samples (ADC samples per spoke)
Nphase = 32   # phase encoding steps/samples (number of radial spokes)

## 2) Construct the Sequence

In [None]:
seq = pp.Sequence()
rf_phase = 180  # RF/ADC phase (degrees) used for the next excitation
rf_inc = 180    # phase increment (degrees) applied after every excitation

# Both sinc pulses share all settings except the flip angle: rf0 (3 deg, half
# of rf1) is played once before the loop, rf1 (6 deg) excites every spoke.
sinc_kwargs = dict(duration=1e-3, slice_thickness=slice_thickness,
                   apodization=0.5, time_bw_product=4, system=system)
rf0 = pp.make_sinc_pulse(flip_angle=6/2 * math.pi / 180, **sinc_kwargs)
rf1 = pp.make_sinc_pulse(flip_angle=6 * math.pi / 180, **sinc_kwargs)

# Template readout gradient: only its ramp time is used, to delay the ADC so
# that sampling starts on the flat top of the readout gradients.
gx_tmp = pp.make_trapezoid(channel='x', flat_area=Nread, flat_time=5e-3, system=system)
adc = pp.make_adc(num_samples=Nread, duration=5e-3, phase_offset=0*np.pi/180,
                  delay=gx_tmp.rise_time, system=system)

seq.add_block(rf0)
seq.add_block(pp.make_delay(3e-3))

# One radial spoke per iteration: ii runs from -Nphase/2 to Nphase/2 - 1
# (here -16 ... 15) and the spoke angle is phi = ii / Nphase * pi, so the
# Nphase spokes are spread over 180 degrees.
for ii in range(-Nphase // 2, Nphase // 2):
    phase_rad = rf_phase / 180 * np.pi
    rf1.phase_offset = phase_rad   # current RF phase
    adc.phase_offset = phase_rad   # receiver follows the RF phase
    rf_phase = (rf_phase + rf_inc) % 360.0   # advance the phase for the next spoke

    seq.add_block(rf1)

    phi = ii / Nphase * np.pi
    # pulseq 1.3.1post1 treats flat_area=0 as "not set", so the spoke with
    # sin(phi) = 0 uses a negligible 1e-7 instead.
    gx_flat_area = -Nread * np.sin(phi) if ii != 0 else 1e-7
    gx = pp.make_trapezoid(channel='x', flat_area=gx_flat_area, flat_time=5e-3, system=system)
    gy = pp.make_trapezoid(channel='y', flat_area=Nread * np.cos(phi), flat_time=5e-3, system=system)

    # Half-area lobes: played before the readout as prephaser and again after
    # it as rewinder, so the net gradient area of each spoke is zero.
    gx_pre = pp.make_trapezoid(channel='x', area=-gx.area / 2, duration=1e-3, system=system)
    gy_pre = pp.make_trapezoid(channel='y', area=-gy.area / 2, duration=1e-3, system=system)

    seq.add_block(gx_pre, gy_pre)
    seq.add_block(adc, gx, gy)
    seq.add_block(gx_pre, gy_pre)

## 3) Check, Plot and Write the sequence as .seq

In [None]:
import os

ok, error_report = seq.check_timing()  # Check whether the timing of the sequence is correct
if ok:
    print('Timing check passed successfully')
else:
    print('Timing check failed. Error listing follows:')
    for e in error_report:
        print(e)

# PLOT sequence
seq.plot()

# Prepare the sequence output for the scanner.
# pypulseq expects the FOV definition in metres (SI units). Multiplying the
# Python list by 1000 (a MATLAB habit for converting to mm) would instead
# repeat the list 1000 times, so the three values are passed as they are.
seq.set_definition('FOV', [fov, fov, slice_thickness])
seq.set_definition('Name', 'gre')
# Make sure the output directory exists before writing the .seq files.
os.makedirs('out', exist_ok=True)
seq.write('out/external.seq')
seq.write('out/' + experiment_id + '.seq')

## 4) Simulate the .seq file

In [None]:
%%capture

sz = (64, 64)

phantom = mr0.VoxelGridPhantom.brainweb("subject05.npz")
phantom = phantom.interpolate(64, 64, 32).slices([16])
data = phantom.build()

# Simulate the sequence written above on the phantom slice.
# NOTE(review): the 200 / 1e-3 arguments presumably limit the number and the
# minimum magnitude of tracked states in the PDG simulation - confirm.
seq = mr0.Sequence.from_seq_file("out/external.seq")
graph = mr0.compute_graph(seq, data, 200, 1e-3)
signal = mr0.execute_graph(graph, seq, data)
kspace_loc = seq.get_kspace()


kspace_adc = torch.reshape((signal), (Nphase, Nread)).clone().t()
spectrum = kspace_adc

space = torch.zeros_like(spectrum)


import scipy.interpolate
grid = kspace_loc[:, :2]   # (kx, ky) of every ADC sample
Nx = 64
Ny = 64

# Cartesian target grid with integer coordinates -N/2 ... N/2 - 1.
X, Y = np.meshgrid(np.linspace(0, Nx - 1, Nx) - Nx / 2, np.linspace(0, Ny - 1, Ny) - Ny / 2)
grid = np.double(grid.numpy())
grid[np.abs(grid) < 1e-3] = 0   # snap numerically tiny coordinates to exactly 0

# NOTE(review): this cell runs under %%capture, which may also swallow the
# plots shown here - confirm they are meant to be hidden.
plt.subplot(347); plt.plot(grid[:, 0].ravel(), grid[:, 1].ravel(), 'rx', markersize=3);  plt.plot(X, Y, 'k.', markersize=2);
plt.show()

# Interpolate the radial samples onto the Cartesian grid (real and imaginary
# part separately). Because X, Y come from np.meshgrid with its default 'xy'
# indexing, the result is indexed [row = ky, column = kx].
spectrum_resampled_x = scipy.interpolate.griddata((grid[:, 0].ravel(), grid[:, 1].ravel()), np.real(signal.ravel()), (X, Y), method='cubic')
spectrum_resampled_y = scipy.interpolate.griddata((grid[:, 0].ravel(), grid[:, 1].ravel()), np.imag(signal.ravel()), (X, Y), method='cubic')

kspace_r = spectrum_resampled_x + 1j * spectrum_resampled_y
kspace_r[np.isnan(kspace_r)] = 0   # points outside the sampled convex hull



# k-space sampling pattern needed for the CS algorithms.
# It must use the same [row = ky, column = kx] order as kspace_r. Coordinates
# are truncated toward zero; negative coordinates wrap around (index -1 is the
# last row/column), which yields the pattern directly in the ifftshift-ed
# layout (DC at [0, 0]) that the reconstruction below applies to kspace_r.
pattern_resampled = np.zeros(kspace_r.shape)
gridx = grid[:, 0].ravel()
gridy = grid[:, 1].ravel()
pattern_resampled[gridy.astype(int), gridx.astype(int)] = 1
plt.imshow(pattern_resampled)
plt.show()
# end sampling pattern
# end sampling pattern

## 5) Compressed Sensing MR reconstruction of undersampled signal

### 5.1) Function definitions

In [None]:
import pywt
from skimage.restoration import denoise_tv_chambolle


def shrink(coeff, epsilon):
    """Soft-threshold the array ``coeff`` in place by ``epsilon``.

    Entries with ``|c| < epsilon`` are set to 0, entries ``>= epsilon`` are
    reduced by ``epsilon`` and entries ``<= -epsilon`` are increased by
    ``epsilon``. Nothing is returned; ``coeff`` itself is modified.
    """
    # All masks are evaluated on the untouched input before any update, so an
    # entry can never be moved into another category by an earlier step.
    zero_mask, upper_mask, lower_mask = (
        abs(coeff) < epsilon,
        coeff >= epsilon,
        coeff <= -epsilon,
    )
    coeff[zero_mask] = 0
    coeff[upper_mask] -= epsilon
    coeff[lower_mask] += epsilon

# Print every wavelet family known to pywt together with its members, then the
# properties of the Haar wavelet used below.
# help: https://www2.isye.gatech.edu/~brani/wp/kidsA.pdf
for family in pywt.families():
    members = ', '.join(pywt.wavelist(family))
    print("%s family: " % family + members)
print(pywt.Wavelet('haar'))


def waveletShrinkage(current, epsilon):
    """Denoise ``current`` by soft-thresholding its one-level 2-D Haar transform.

    The approximation band and the three detail bands are each shrunk by
    ``epsilon`` (see ``shrink``) before transforming back to image space.
    """
    approx, details = pywt.dwt2(current, 'haar')
    # shrink works in place, so the bands inside (approx, details) are updated
    for band in (approx, *details):
        shrink(band, epsilon)
    return pywt.idwt2((approx, details), 'haar')


def updateData(k_space, pattern, current, step,i):
    """Take one data-consistency step of the iterative reconstruction.

    The image estimate ``current`` is transformed to k-space, compared with the
    measured ``k_space`` on the points selected by ``pattern``, and the
    residual is transformed back and added to ``current`` scaled by ``step``.

    The printed "consistency RMSEpc" value is 100 times the sum of absolute
    residuals (an L1 measure), tagged with the iteration number ``i``.

    Returns the updated (complex) image estimate.
    """
    # image -> k-space (this notebook uses ifft2 as the forward transform)
    predicted = np.fft.ifft2(np.fft.fftshift(current))
    # mismatch with the measurement, restricted to the sampled points
    residual = k_space - predicted * pattern
    print("i: {}, consistency RMSEpc: {:3.6f}".format(i, np.abs(residual).sum() * 100))
    # k-space -> image and move the estimate along the residual
    correction = np.fft.fftshift(np.fft.fft2(residual))
    return current + step * correction

### 5.2) Preparation and conventional reconstruction of the gridded (interpolated) k-space

In [None]:
# Move DC from the array centre to index [0, 0] (high frequencies end up in the
# centre), which is the layout np.fft expects.
kspace_full = np.fft.ifftshift(kspace_r)
kspace = kspace_full

# Reconstruction of the complete gridded k-space, re-centred for display.
recon_nufft = np.fft.fftshift(np.fft.fft2(kspace_full))

### 5.3) Undersampling and undersampled reconstruction

In [None]:
# kspace_full= kspace_full/ np.linalg.norm(kspace_full[:])   # normalization of the data sometimes helps

# parameters of iterative reconstruction using total variation denoising
denoising_strength = 5e-5
number_of_iterations = 8000

# parameters of random subsampling pattern (alternative to the radial pattern below)
# percent = 0.25        # this is the amount of data that is randomly measured
# square_size = 16      # size of square in center of k-space


# # generate a random subsampling pattern
# np.random.seed(np.random.randint(100))
# pattern = np.random.random_sample(kspace.shape)
# pattern=pattern<percent  # random data

# pattern[sz[0]//2-square_size//2:sz[0]//2+square_size//2,sz[0]//2-square_size//2:sz[0]//2+square_size//2] = 1   # square in center of k-space
# pattern = np.fft.fftshift(pattern) # high  frequencies centered as kspace and as FFT needs it

# Use the sampling pattern of the simulated radial acquisition.
pattern = pattern_resampled

kspace = kspace_full * pattern  # apply the undersampling pattern

# percentage of grid points that were actually measured
actual_measured_percent = np.count_nonzero(pattern) / pattern.size * 100

## actual iterative reconstruction algorithm
# A full step (step=1) from a zero image is simply the zero-filled
# reconstruction of the sampled k-space.
current = np.zeros(kspace.shape)
first = updateData(kspace, pattern, current, 1, 0)
current_shrink = first

# Magnitude of the estimate after every data-consistency step (for plotting).
# `current` is complex; assigning it to this real array directly would silently
# discard the imaginary part (NumPy ComplexWarning), so |current| is stored.
all_iter = np.zeros((kspace.shape[0], kspace.shape[1], number_of_iterations))

for i in range(number_of_iterations):
    # enforce consistency with the measured samples ...
    current = updateData(kspace, pattern, current_shrink, 0.1, i)
    # ... then apply the sparsity prior (TV denoising of the magnitude image)
    current_shrink = denoise_tv_chambolle(abs(current), denoising_strength)
    # current_shrink = waveletShrinkage(abs(current), denoising_strength)

    all_iter[:, :, i] = np.abs(current)
## Plotting

# Sampling pattern moved back to the centred layout for display.
pattern_vis = np.fft.fftshift(pattern * 256)

fig = plt.figure(figsize=(5, 10), dpi=90)

# (subplot position, image, y-label); k-space panels are centred and log-scaled.
panels = (
    (321, abs(recon_nufft), 'recon_full'),
    (322, abs(pattern_vis), "pattern_vis"),
    (323, abs(first), 'first iter (=NUFFT)'),
    (325, abs(current_shrink), 'final recon'),
    (324, np.log(abs(np.fft.fftshift(kspace_full))), 'kspace_nufft'),
    (326, np.log(abs(np.fft.fftshift((kspace)))), 'kspace*pattern'),
)
for position, image, label in panels:
    plt.subplot(position)
    plt.set_cmap(plt.gray())
    plt.imshow(image)
    plt.ylabel(label)
    if position == 322:
        plt.title("{:.1f} % sampled".format(actual_measured_percent))
plt.show()

In [None]:
# Plot some iterations
# 25 iteration indices spread evenly from the first (0) to the last iteration.
idx = (np.linspace(1, all_iter.shape[2], 25) - 1).astype(int)
red_iter = all_iter[:, :, idx]   # choose them from all iters
Tot = red_iter.shape[2]
Rows = math.ceil(Tot / 5)        # grid of 5 panels per row

fig = plt.figure()
for k in range(Tot):
    ax = fig.add_subplot(Rows, 5, k + 1)
    ax.imshow(abs(red_iter[:, :, k]))
    ax.set_title('iter {}'.format(idx[k]))
plt.show()