In [None]:
import datetime
import math
import os
import sys

import ipywidgets as widgets
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io as scio
import seaborn as sns
from matplotlib.patches import Rectangle
from scipy.optimize import curve_fit

import hypermri
import hypermri.utils.utils_anatomical as ut_anat
import hypermri.utils.utils_general as utg
from hypermri.utils.utils_sv_spectroscopy import Plot_Voxel_on_Anat


def get_colors_from_cmap(cmap_name, N):
    """Sample ``N`` colors from a matplotlib colormap.

    The colormap named ``cmap_name`` is evaluated at ``N`` evenly spaced
    positions between 0 and 1 (both ends included) and the resulting color
    array is returned.
    """
    colormap = plt.get_cmap(cmap_name)
    return colormap(np.linspace(0, 1, N))


# Global matplotlib styling: serif "Computer Modern" font, text rendered through
# LaTeX (usetex=True needs a working LaTeX installation) and an 11 pt base font size.
from matplotlib import rc
rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
rc("text", usetex=True)
matplotlib.rcParams.update({"font.size": 11})

# Project-local helper; returns the three paths unpacked below.
from Template import import_paths

basepath,savepath,publication_path=import_paths()

# Apply os.path.dirname twice to the configured base path and append the '2022' folder.
basepath=os.path.dirname(os.path.dirname(basepath))+'/2022/'
print(basepath)
# Autoreload extension so that you don't have to restart the kernel every time something is changed in the hypermri or magritek folders
%load_ext autoreload
%autoreload 2

# Interactive figure backend for the notebook.
%matplotlib widget

In [None]:
# Base figure width; the plotting cells below derive their figsize tuples from
# fractions of this value (matplotlib figure sizes are given in inches).
figsize=6.9

# 1. In vitro comparison of CSI and PRESS

In [None]:
# define scan path
dirpath = basepath + '/'
ID='PyruvateSphere'
# Index the study directory once; individual scans are then picked by their index.
scans=hypermri.BrukerDir(dirpath,verbose=False)
# Reference images, named after their orientation.
coronal=scans[81]
sagittal=scans[82]
axial=scans[83]

# Scans named after PRESS and an orientation (presumably localizer images for the PRESS voxel; verify).
axial_press=scans[94]
sagit_press=scans[95]
coronal_press = scans[92]

# Spectroscopy scans used in the comparison below:
# mv_press -> PRESS spectrum, csi -> CSI data set, csi_ref_img -> image shown next to the CSI SNR map.
mv_press=scans[90]
sp2_press=scans[96]
csi_ref_img=scans[89]
sp2_csi=scans[77]
csi=scans[78]

## 1.3 Plot Proton reference and CSI SNR map

In [None]:
def FT_kspace_csi(csi, k_space, LB=0, cut_off=0):
    """Fourier transform CSI k-space data along both spatial axes and the FID axis.

    Parameters
    ----------
    csi : object
        Scan object. Only ``csi.method["PVM_SpecAcquisitionTime"]`` and
        ``csi.method["PVM_SpecMatrix"]`` are read.
    k_space : ndarray
        3D array with the FID samples on axis 0 and the two k-space
        dimensions on axes 1 and 2. Axis 0 must have
        ``PVM_SpecMatrix - cut_off`` entries.
    LB : float, optional
        Line broadening; the FID axis is multiplied by
        ``exp(-2 * pi * LB * t)`` before its transform.
    cut_off : int, optional
        Number of FID points the caller has already removed from the start of
        every FID; used only to size the time axis.

    Returns
    -------
    ndarray
        Complex array of the same shape as ``k_space``.
    """

    def _centered_fft(data, axis):
        # fftshift -> fft -> fftshift along a single axis.
        return np.fft.fftshift(
            np.fft.fft(np.fft.fftshift(data, axes=(axis,)), axis=axis),
            axes=(axis,),
        )

    ac_time = csi.method["PVM_SpecAcquisitionTime"]
    ac_points = csi.method["PVM_SpecMatrix"]

    # Time axis of the (shortened) FID, scaled by 1/1000 (presumably ms -> s; confirm).
    # NOTE(review): the axis still spans the full acquisition time although
    # `cut_off` points were dropped - confirm this is intended.
    time_ax = np.linspace(0, ac_time, ac_points - cut_off) / 1000
    sigma = 2 * np.pi * LB
    apodization = np.exp(-sigma * time_ax)[:, None, None]

    # Spatial transforms first (axis 1, then axis 2), then the apodized FID axis.
    spatial = _centered_fft(_centered_fft(k_space, 1), 2)
    return _centered_fft(spatial * apodization, 0)
    
def reco_CSI(csi,cut_off=0,LB=0):
    """Reconstruct complex CSI spectra from the raw FIDs of a scan.

    The raw data are cut into one FID per phase-encoding step, sorted into a
    k-space array according to the encoding steps stored in the method file,
    multiplied by a phase ramp derived from the phase offsets, and Fourier
    transformed with ``FT_kspace_csi``.

    Parameters
    ----------
    csi : object
        Scan object providing ``method`` (keys ``PVM_SpecMatrix``,
        ``PVM_Matrix``, ``PVM_EncSteps0/1``, ``PVM_Phase0Offset``,
        ``PVM_Phase1Offset``, ``PVM_Fov`` and those read by
        ``FT_kspace_csi``), ``rawdatajob0`` and ``seq2d``.
    cut_off : int, optional
        Number of points dropped from the start of every FID.
    LB : float, optional
        Line broadening forwarded to ``FT_kspace_csi``.

    Returns
    -------
    ndarray
        Complex array with the shape of ``csi.seq2d[cut_off:, :, :, 0]``,
        flipped along axis 0 after the transform.
    """
    fidsz = csi.method["PVM_SpecMatrix"]
    ysz = csi.method["PVM_Matrix"][1]
    xsz = csi.method["PVM_Matrix"][0]
    # Shift the encoding steps so they can be used as array indices.
    y_indices = csi.method["PVM_EncSteps1"] + math.floor(ysz / 2.0)
    x_indices = csi.method["PVM_EncSteps0"] + math.floor(xsz / 2.0)

    # acquisition_order[ny, nx] holds the running number of the FID that was
    # acquired at that k-space position.
    acquisition_order = np.zeros((ysz, xsz))
    c = 0
    for ny in y_indices:
        for nx in x_indices:
            acquisition_order[ny, nx] = c
            c += 1

    # Cut the long raw-data vector into one (shortened) FID per encoding step.
    raw = csi.rawdatajob0
    get_fids = np.array(
        [raw[cut_off + n : n + fidsz] for n in np.arange(0, raw.shape[0], fidsz)]
    )
    fidsz = fidsz - cut_off

    # Place every FID at its k-space position.
    k_space_array = np.zeros_like(csi.seq2d[cut_off:, :, :, 0],
                                  dtype=np.complex128)
    for idx in np.arange(0, get_fids.shape[0], 1):
        placement_idx = np.where(acquisition_order == idx)
        ny = placement_idx[0][0]
        nx = placement_idx[1][0]
        k_space_array[:, ny, nx] = get_fids[int(idx), :]

    # Phase offsets converted from FOV units into voxels (sign as in the original reco).
    shift_x = (
        -csi.method["PVM_Phase0Offset"]
        * csi.method["PVM_Matrix"][0]
        / csi.method["PVM_Fov"][0]
    )
    shift_y = (
        -csi.method["PVM_Phase1Offset"]
        * csi.method["PVM_Matrix"][1]
        / csi.method["PVM_Fov"][1]
    )
    # Linear phase ramps that shift the image by (shift_x, shift_y).
    # NOTE(review): np.linspace(0, n, n) yields non-integer sample positions
    # (step n / (n - 1)); confirm this is intended rather than np.arange(n).
    Wx_1d = np.exp(
        (1j * 2 * np.pi * np.linspace(0, xsz, xsz) * shift_x) / xsz
    )
    Wy_1d = np.exp(
        (1j * 2 * np.pi * np.linspace(0, ysz, ysz) * shift_y) / ysz
    )
    # Broadcast both ramps to (fidsz, ysz, xsz).
    Wx = np.tile(Wx_1d.T, [fidsz, ysz, 1])
    Wy = np.transpose(np.tile(Wy_1d.T, [fidsz, xsz, 1]), [0, 2, 1])

    ordered_k_space = np.flip(np.flip((k_space_array * Wx) * Wy,1),2)
    # The result is flipped along axis 0 ("for reco fixing" in the original code).
    return np.flip(
        FT_kspace_csi(csi, ordered_k_space, cut_off=cut_off, LB=LB), axis=0
    )

### Perform complex reco

In [None]:
# Reconstruct the in vitro CSI; the first 70 points of every FID are dropped.
complex_csi_reco=reco_CSI(csi,cut_off=70)

### Get extents for plots

In [None]:
# Only the second value returned by get_extent is kept for each scan; it is
# passed as `extent` to imshow in the figures below.
_,  sag_ext, _   = utg.get_extent(data_obj=csi_ref_img)
_,csi_ext, _   = utg.get_extent(data_obj=csi)

### Compute SNR map of CSI

In [None]:
# Per-voxel SNR: the first 50 spectral points of the magnitude spectrum serve
# as background; every spectrum is offset by the background mean and scaled by
# the background standard deviation.
csi_snr_map=np.zeros_like((complex_csi_reco[0,:,:]),dtype=float)
csi_normed_snr = np.zeros_like((complex_csi_reco),dtype=float)
bg_mean_map=np.zeros_like(csi_snr_map)
bg_std_map=np.zeros_like(csi_snr_map)

for n, m in np.ndindex(csi_snr_map.shape):
    voxel_magnitude = np.abs(complex_csi_reco[:,n,m])
    background = voxel_magnitude[0:50]
    bg_mean = np.mean(background)
    bg_std = np.std(background)
    bg_mean_map[n,m] = bg_mean
    bg_std_map[n,m] = bg_std

    # Normalized spectrum of this voxel; its maximum is the voxel's SNR.
    voxel_snr = (voxel_magnitude-bg_mean)/bg_std
    csi_normed_snr[:,n,m] = voxel_snr
    csi_snr_map[n,m] = np.max(voxel_snr)

In [None]:
# Quality-control figure with four stacked panels: background mean, background
# std, summed magnitude and the resulting SNR map.
fig,ax=plt.subplots(4,tight_layout=True)
im1=ax[0].imshow(bg_mean_map)
im2=ax[1].imshow(bg_std_map)
im3=ax[2].imshow(np.sum(np.abs(complex_csi_reco),axis=0))
ax[0].set_title('Mean of background')
ax[1].set_title('Std of background')
ax[2].set_title('Intensity')

fig.colorbar(im1,ax=ax[0],label='I [a.u.]')
fig.colorbar(im2,ax=ax[1],label='I [a.u.]')
fig.colorbar(im3,ax=ax[2],label='I [a.u.]')

# SNR map (this panel has no title or colorbar).
ax[3].imshow(csi_snr_map)

In [None]:
# Figure: reference image (left) next to the CSI SNR map (right).
plt.close('all')
fig,ax=plt.subplots(1,2,figsize=(figsize/2,figsize/4),tight_layout=True)
# Slice index 1 of the reference image, transposed for display.
ax[0].imshow(csi_ref_img.seq2d[:,:,1].T,cmap='gray',extent=sag_ext)

snr_map=ax[1].imshow(csi_snr_map.T,extent=csi_ext,cmap='magma')


for n in range(2):
    ax[n].axis('off')
    
fig.colorbar(snr_map,ax=ax[1],label='SNR',ticks=[500,1000,1500])


# Scale bar: horizontal line from x=8.5 to x=3.5 (length 5 in extent units), labelled '5mm'.
ax[0].hlines(-13.5,8.5,3.5,linewidth=3,color='w')
ax[0].text(8,-12.5,'5mm',color='w',fontsize=8)

In [None]:
# Same layout as the SNR figure, but the right panel shows the scanner-side
# reconstruction (csi.seq2d) summed over its first axis instead of the SNR map.
plt.close('all')
fig,ax=plt.subplots(1,2,figsize=(figsize/2,figsize/4),tight_layout=True)
ax[0].imshow(csi_ref_img.seq2d[:,:,1].T,cmap='gray',extent=sag_ext)

# NOTE: the name `snr_map` is reused here although the image shows summed intensity.
snr_map=ax[1].imshow(np.sum(csi.seq2d,axis=0).squeeze().T,extent=csi_ext,cmap='magma')


for n in range(2):
    ax[n].axis('off')
    
fig.colorbar(snr_map,ax=ax[1],label='I [a.u.]')


# Scale bar: horizontal line from x=8.5 to x=3.5 (length 5 in extent units), labelled '5mm'.
ax[0].hlines(-13.5,8.5,3.5,linewidth=3,color='w')
ax[0].text(8,-12.5,'5mm',color='w',fontsize=8)

## 1.4 Plot spectra

In [None]:
# Indices of the CSI voxel whose spectrum is compared with the PRESS spectrum.
# NOTE: `nx` and `ny` are re-assigned in the PSF simulation section further
# down, so this cell must be re-run before the spectra plots are re-run.
nx=4
ny=8
ppm_csi=csi.get_ppm(70)
# PRESS magnitude spectrum: drop the first 70 FID points, FFT, shift, abs.
spec_press=np.abs(np.fft.fftshift(np.fft.fft(np.squeeze(mv_press.complex_fids[70:,0]))))
ppm_press=mv_press.get_ppm(70)
ppm_sp_press=sp2_press.get_ppm(70)
spec_sp_press = np.abs(sp2_press.complex_spec)

ppm_sp_csi=sp2_csi.get_ppm(70)
spec_sp_csi = np.abs(sp2_csi.complex_spec)

# Peak SNR of the selected CSI voxel (spectrum already normalized above).
csi_snr_raw=np.max(csi_normed_snr[:,nx,ny])

# Normalize the PRESS spectrum the same way: first 50 points are the background.
press_normed_snr=(spec_press-np.mean(spec_press[0:50]))/np.std(spec_press[0:50])
press_snr_raw = np.max(press_normed_snr)

print(csi_snr_raw,'CSI SNR')
print(press_snr_raw,'PRESS SNR')

In [None]:
# SNR-normalized CSI and PRESS spectra around the signal peak (x axis descending from 174 to 168 ppm).
fig,ax=plt.subplots(1,figsize=(figsize/2,figsize/3),tight_layout=True)
# The CSI spectrum is reversed with np.flip before plotting against ppm_csi.
ax.plot(ppm_csi,np.flip(csi_normed_snr[:,nx,ny]),label='CSI',color='k')
ax.plot(ppm_press,press_normed_snr,label='PRESS',color='r')
ax.set_xlim([174,168])
ax.set_yticks([0,500,1000,1500])
ax.legend()
ax.set_ylabel('SNR')
ax.set_xticks([174,171,168])
ax.set_xlabel(r'$\sigma$ [ppm]')


In [None]:
# Same two spectra, zoomed to the 156-159 ppm window, with y ticks hidden.
fig,ax=plt.subplots(1,figsize=(figsize/2,figsize/3),tight_layout=True)
ax.plot(ppm_csi,np.flip(csi_normed_snr[:,nx,ny]),label='CSI',color='k')
ax.plot(ppm_press,press_normed_snr,label='PRESS',color='r')
# NOTE(review): unlike the previous figure, this x axis runs in ascending ppm order - confirm intended.
ax.set_xlim([156,159])
# optional fixed y range, currently disabled
ax.set_yticks([])
ax.legend()
ax.set_xlabel(r'$\sigma$ [ppm]')


# 2. SNR analysis of in vivo measurement

## Measure MV-PRESS and CSI in the same region to compare the SNR in vivo in a PDAC animal using hyperpolarized $^{13}$C pyruvate

In [None]:
# Study directory of the in vivo measurement.
# NOTE(review): this is built exactly like `dirpath` above (basepath + '/'),
# so both sections read from the same directory - confirm the path.
dirpath2=basepath+'/'

scans2=hypermri.BrukerDir(dirpath2)

In [None]:
# Pick the in vivo scans by index.
# NOTE: this re-binds names used in section 1 (coronal, axial, mv_press,
# sp2_press, csi, sp2_csi); re-run the section 1 loading cell before re-running
# any section 1 analysis.
animal=''
coronal=scans2[6] 
axial=scans2[8]
mv_press=scans2[24]
t2w_csi=scans2[35]
sp2_press=scans2[25]
# The CSI scan is loaded directly from its experiment folder instead of via scans2.
csi = hypermri.BrukerExp(dirpath2+'31/')
sp2_csi=scans2[30]
sp90_press=scans2[26]
sp90_csi=scans2[33]
# names of the positions where the PRESS voxel was placed
vox_names=['Kidney','Tumor','Kidney','Liver','Muscle']

## 2.1 Plot Coronal and axial images with PRESS voxel and CSI overlayed

In [None]:
# Compute the anchor point (px, pz) of the rectangle that visualizes the CSI
# slice on the coronal image in the next cell.
fov_csi = csi.method["PVM_Fov"]
fov_cor = coronal.method["PVM_Fov"]

slice_thick = csi.method["PVM_SliceThick"]
# first find out which way the CSI slice is positioned
rot_matrix = csi.method[
    "PVM_SPackArrGradOrient"
]  # rotated slice in axial direction around y axis
# rot_matrix is a matrix of the form (cos(a)   0   sin(a))
#                                    (0        1   0     )
#                                    (-sin(a)  0   cos(a))
# so alpha is arccos(rot_matrix[0][0][0]), i.e. *180/pi ---> 36.137 degrees for this scan
rot_angle = np.arccos(rot_matrix[0][0][0])
offset = csi.method[
    "PVM_SPackArrSliceOffset"
]  # offset of center of slice from (0,0,0)
orient = csi.method["PVM_SPackArrSliceOrient"]  # axial

# Change the reference point from the center of the rectangle to its edge
# (more work than strictly necessary, but it was done this way).
s = slice_thick  # the division by np.sin(rot_angle) is disabled


# With s == slice_thick this is always 0, so eps and phi reduce to
# (fov_csi[0] / 2) * sin(rot_angle) and (fov_csi[0] / 2) * cos(rot_angle).
df = np.sqrt(s**2 - slice_thick**2)
eps = (df + fov_csi[0] / 2.0) * np.sin(rot_angle)

phi = (df + fov_csi[0] / 2.0) * np.cos(rot_angle)

# Now figure out the coordinates of the bottom-left corner of the rectangle visualizing the CSI;
# some geometric drawing and trigonometry reveal that it has the coordinates:
px = (
    phi + 2 * np.cos(rot_angle) / fov_csi[0]
)  # this might have to be adjusted if the FOV is not quadratic
pz = (
    eps
    + np.abs(offset)
    + slice_thick / 2.0
    + 2 * np.sin(rot_angle) / fov_csi[1]
)


In [None]:
# Two stacked panels: PRESS voxel on the coronal image (top, with the CSI
# slice drawn as a rotated rectangle) and on the axial image (bottom).
fig,ax=plt.subplots(2,1,figsize=(figsize/3,figsize/2),height_ratios=(30,25))
Plot_Voxel_on_Anat(mv_press,coronal,ax[0],vmin=0,vmax=160,vox_color='r')
ax[0].set_xlim([13,-13])
ax[0].set_ylim([-13,13])
ax[0].axis('off')
ax[0].set_title(None)
# CSI slice outline: anchored at (-px, -pz), fov_csi[0] wide, slice_thick high,
# rotated by rot_angle (converted to degrees); unfilled with a red edge.
ax[0].add_patch(Rectangle(
                    (-px, -pz),
                    fov_csi[0],
                    slice_thick,
                    angle=rot_angle * 180 / np.pi,
                    alpha=1,
                    ec="r",
                    color="None",linewidth=1,linestyle='solid'))
Plot_Voxel_on_Anat(mv_press,axial,ax[1],vmin=0,vmax=160)
ax[1].set_xlim([13,-13])
ax[1].set_ylim([-7,15])
ax[1].axis('off')
ax[1].set_title(None)


## 2.2 Plot spectra

In [None]:
def Reco_CSI_animal(csi, Plot_QA=True, cut_off=70, LB=0):
    """Reconstruct centric-encoded CSI data, honouring the read orientation.

    The raw data are cut into one FID per phase-encoding step, sorted into
    k-space according to the encoding steps from the method file, multiplied
    by a phase ramp derived from the phase offsets and Fourier transformed
    with ``FT_kspace_csi``.

    Parameters
    ----------
    csi : object
        Scan object providing ``method`` (keys ``PVM_EncOrder``,
        ``PVM_SpecMatrix``, ``PVM_Matrix``, ``PVM_EncSteps0/1``,
        ``PVM_SPackArrReadOrient``, ``PVM_Phase0Offset``,
        ``PVM_Phase1Offset``, ``PVM_Fov`` and those read by
        ``FT_kspace_csi``), ``rawdatajob0`` and ``seq2d``
        (plus ``dual_channel_flag`` for linearly encoded data).
    Plot_QA : bool, optional
        Currently unused; kept so existing calls keep working.
    cut_off : int, optional
        Number of points dropped from the start of every FID.
    LB : float, optional
        Line broadening forwarded to ``FT_kspace_csi``.

    Returns
    -------
    ndarray
        Complex reconstructed data; for read orientation ``'A_P'`` the two
        spatial axes are swapped back after the transform.

    Raises
    ------
    NotImplementedError
        For ``"LINEAR_ENC LINEAR_ENC"`` data, whose reconstruction does not
        exist yet (previously this ended in an UnboundLocalError).
    ValueError
        For any other unsupported encoding order or read orientation
        (previously also an UnboundLocalError).
    """
    enc_order = csi.method["PVM_EncOrder"]
    fidsz = csi.method["PVM_SpecMatrix"]
    ysz = csi.method["PVM_Matrix"][1]
    xsz = csi.method["PVM_Matrix"][0]
    # Shift the encoding steps so they can be used as array indices.
    y_indices = csi.method["PVM_EncSteps1"] + math.floor(ysz / 2.0)
    x_indices = csi.method["PVM_EncSteps0"] + math.floor(xsz / 2.0)
    # depends how encoding is done: either A_P or L_R (default)
    read_orient = csi.method["PVM_SPackArrReadOrient"]

    if enc_order == "LINEAR_ENC LINEAR_ENC":
        print("Encoding order", enc_order)
        if csi.dual_channel_flag is True:
            print(
                "This is dual channel data, performing phasing first, then csi reco..."
            )
        else:
            print("This is single channel data")
        # TODO: linear encoding reco still has to be implemented.
        raise NotImplementedError(
            "Reconstruction of linearly encoded CSI data is not implemented yet."
        )
    if enc_order != "CENTRIC_ENC CENTRIC_ENC":
        raise ValueError(f"Unsupported encoding order: {enc_order!r}")
    print("Encoding order", enc_order)

    # For 'A_P' the FIDs are written with swapped spatial indices and the
    # array is transposed afterwards (and once more after the transform).
    if read_orient == 'L_R':
        swap_spatial_axes = False
    elif read_orient == 'A_P':
        swap_spatial_axes = True
    else:
        raise ValueError(f"Unknown read orientation: {read_orient!r}")

    # This applies to measurements without two receive channels (e.g. 31 mm coil).
    # acquisition_order[row, col] holds the running number of the FID acquired
    # at that k-space position.
    acquisition_order = np.zeros((ysz, xsz))
    c = 0
    for ny in y_indices:
        for nx in x_indices:
            acquisition_order[ny, nx] = c
            c += 1

    # Cut the long raw-data vector into one (shortened) FID per encoding step.
    raw = csi.rawdatajob0
    get_fids = np.array(
        [raw[cut_off + n : n + fidsz] for n in np.arange(0, raw.shape[0], fidsz)]
    )
    # From here on fidsz is the length of the shortened FIDs.
    fidsz = fidsz - cut_off

    k_space_array = np.zeros_like(csi.seq2d[cut_off:, :, :, 0],
                                  dtype=np.complex128)
    print(k_space_array.shape)
    # The first FID in the raw data is the center of k-space, and so on outwards.
    for idx in np.arange(0, get_fids.shape[0], 1):
        placement_idx = np.where(acquisition_order == idx)
        row = placement_idx[0][0]
        col = placement_idx[1][0]
        if swap_spatial_axes:
            k_space_array[:, row, col] = get_fids[int(idx), :]
        else:
            k_space_array[:, col, row] = get_fids[int(idx), :]
    if swap_spatial_axes:
        k_space_array = np.transpose(k_space_array,[0,2,1])
    print('k-space_shape',k_space_array.shape)

    # Account for the phase offset: convert it from FOV units into voxels.
    # The minus sign sets the shift direction; this might have to be updated
    # for future measurements.
    shift_x = (
        -csi.method["PVM_Phase0Offset"]
        * csi.method["PVM_Matrix"][0]
        / csi.method["PVM_Fov"][0]
    )
    shift_y = (
        -csi.method["PVM_Phase1Offset"]
        * csi.method["PVM_Matrix"][1]
        / csi.method["PVM_Fov"][1]
    )
    print("Voxel shift x:", shift_x, "Voxel shift y:", shift_y)
    # Linear phase ramps that shift the image by (shift_x, shift_y).
    # NOTE(review): np.linspace(0, n, n) yields non-integer sample positions
    # (step n / (n - 1)); confirm this is intended rather than np.arange(n).
    Wx_1d = np.exp(
        (1j * 2 * np.pi * np.linspace(0, xsz, xsz) * shift_x) / xsz
    )
    Wy_1d = np.exp(
        (1j * 2 * np.pi * np.linspace(0, ysz, ysz) * shift_y) / ysz
    )
    # Broadcast both ramps to (fidsz, xsz, ysz).
    Wx = np.transpose(np.tile(Wx_1d.T, [fidsz, ysz, 1]), [0, 2, 1])
    Wy = np.tile(Wy_1d.T, [fidsz, xsz, 1])
    print(Wx.shape,Wy.shape,k_space_array.shape)

    # apply shift in k space
    ordered_k_space = np.flip(np.flip((k_space_array * Wx) * Wy,1),2)

    shifted_final = FT_kspace_csi(csi, ordered_k_space, cut_off=cut_off, LB=LB)
    if swap_spatial_axes:
        shifted_final = np.transpose(shifted_final, [0, 2, 1])
    return shifted_final


## Perform reco

In [None]:
# Reconstruct the in vivo CSI: drop the first 70 FID points and pass LB=5 as line broadening.
complex_csi_reco_in_vivo=Reco_CSI_animal(csi,cut_off=70,LB=5)

In [None]:
# Visual check: scanner-side reconstruction (left) versus the reconstruction
# above (right), both summed over their first axis.
fig,ax=plt.subplots(1,2)
ax[0].imshow(np.sum(csi.seq2d,axis=0))
ax[1].imshow(np.abs(np.sum(np.abs(complex_csi_reco_in_vivo),axis=0)))

In [None]:
# Compare the MV-PRESS spectrum with the spectrum of the CSI voxel at the same
# position. Both are normalized with their first 50 points as background.
print('Voxel Number [4,8] in the CSI corresponds to the PRESS voxel')
plt.close('all')
fig,ax=plt.subplots(1,figsize=(figsize/2,figsize/3),tight_layout=True)

press_ppm = mv_press.get_ppm(70)
press_spec = np.abs(np.squeeze(mv_press.get_fids_spectra(5,70)[0])[:,0])
csi_spec = np.abs(complex_csi_reco_in_vivo[:,8,4])
ppm_csi=hypermri.utils.utils_spectroscopy.get_freq_axis(csi,cut_off=70)

# SNR-normalized spectra, computed once and reused for plot and printout.
press_snr_in_vivo = (press_spec-np.mean(press_spec[0:50]))/np.std(press_spec[0:50])
csi_snr_in_vivo = (csi_spec-np.mean(csi_spec[0:50]))/np.std(csi_spec[0:50])

ax.plot(press_ppm,press_snr_in_vivo,label='MV-PRESS',color='r')
ax.plot(ppm_csi,csi_snr_in_vivo,label='CSI',color='k')
ax.set_ylabel('SNR')
ax.set_xlabel(r"$\sigma$[ppm]")
ax.set_xlim([190,160])
ax.legend()

# Peak SNR in the two halves of each spectrum (split at index 120 for the CSI
# and 512 for MV-PRESS); the first entry is reported as pyruvate, the second as lactate.
csi_snrs = [np.max(csi_snr_in_vivo[0:120]), np.max(csi_snr_in_vivo[120:])]
mvpress_snrs = [np.max(press_snr_in_vivo[0:512]), np.max(press_snr_in_vivo[512:])]
print('CSI-SNR Kidney Pyruvate= ',csi_snrs[0])
print('MV-PRESS SNR-Kidney Pyruvate = ',mvpress_snrs[0])
print('CSI-SNR Kidney Lactate = ',csi_snrs[1])
print('MV-PRESS SNR-Kidney Lactate = ',mvpress_snrs[1])


# 3. PSF simulation

### Load simulated RF pulse profile from Matlab tool

In [None]:
# Load the simulated pulse profile from a .mat file in the save directory.
# `scio` is scipy.io and has to be imported in the import cell at the top.
press_180=scio.loadmat(savepath+'180_ref_pulse_PRESS_simul_2mm_1ms.mat')
# Spatial range of the simulation: stored under 'yrange_cm' and multiplied by 10 (cm -> mm).
y_range = press_180['parameters']['yrange_cm'][0][0][0][0] * 10
y_pts = press_180['parameters']['yrange_cm_npoints'][0][0][0][0]

# Position axis centered on zero, and the 'mz' profile taken at column index 1199
# (presumably the last simulated time point; verify against the .mat file).
x_ax_press_simul=np.linspace(-y_range/2,y_range/2,y_pts)
ref_profile_180=press_180['data']['mz'][0][0][:,1199]

In [None]:
def generate_bruker_centric_centric_high_res(ysz=10,xsz=10):
    """Build a centric-centric sampling pattern of shape ``(ysz, xsz)``.

    Along each dimension the positions are visited starting at the center
    index ``floor(size / 2)`` and then alternating below and above it
    (center, center-1, center+1, center-2, ...). The first dimension is the
    outer loop. Each entry of the returned float array holds the running
    number (starting at 0) at which that position is visited.
    """

    def _centric_order(size):
        # Offsets 0, -1, 1, -2, 2, ... moved to array indices around the center.
        center = math.floor(size / 2.0)
        return [
            (step // 2 if step % 2 == 0 else -(step // 2 + 1)) + center
            for step in range(size)
        ]

    row_order = _centric_order(ysz)
    col_order = _centric_order(xsz)

    pattern = np.zeros((ysz, xsz))
    visits = ((row, col) for row in row_order for col in col_order)
    for counter, (row, col) in enumerate(visits):
        pattern[row, col] = counter

    return pattern
def generate_psf(sampling_pattern, Mz, alpha_deg):
    """Weight every k-space position by the signal of the excitation that sampled it.

    Positions whose value in ``sampling_pattern`` equals ``i`` receive
    ``sin(alpha_deg) * Mz[i]``; positions with no matching entry in ``Mz``
    stay 0. Returns a float array shaped like ``sampling_pattern``.
    """
    excitation_weight = np.sin(np.deg2rad(alpha_deg))
    psf = np.zeros_like(sampling_pattern, dtype=float)
    for excitation, magnetization in enumerate(Mz):
        psf[sampling_pattern == excitation] = excitation_weight * magnetization
    return psf
def create_sample_image(nx, ny, shape, size):
    """Create a complex ``(nx, ny)`` test image containing a simple shape of value 1.

    ``size`` scales the shape relative to the smaller image dimension
    (``radius = int(min(nx, ny) * size / 2)``). Supported ``shape`` values
    are ``"square"``, ``"2 squares"`` and ``"circle"``; any other value
    returns an all-zero image. The radius (and, for ``"2 squares"``, the
    slice bounds) are printed.
    """
    image = np.zeros((nx, ny), dtype="complex")
    cx = nx // 2
    cy = ny // 2
    radius = int(min(nx, ny) * size / 2)
    print(radius)

    if shape == "square":
        # One centered square with half-width `radius`.
        image[cx - radius : cx + radius, cy - radius : cy + radius] = 1 + 0j
    elif shape == "2 squares":
        # Two squares placed diagonally on either side of the center.
        near = radius // 2
        far = 3 * radius // 2
        image[cx + near : cx + far, cy + near : cy + far] = 1 + 0j
        image[cx - far : cx - near, cy - far : cy - near] = 1 + 0j
        print(cx + near, cx + far, cy + near, cy + far)
        print(cx - far, cx - near, cy - far, cy - near)
    elif shape == "circle":
        # Disc of all pixels within `radius` of the center.
        rows, cols = np.ogrid[-cx : nx - cx, -cy : ny - cy]
        image[rows**2 + cols**2 <= radius**2] = 1 + 0j

    return image



## Create sampling pattern and set parameters

In [None]:

# Simulation matrix size.
# NOTE: this re-binds `nx` and `ny`, which section 1.4 uses as voxel indices.
nx=40
ny=26
# The pattern is generated with ysz=ny and xsz=nx, i.e. its shape is (ny, nx).
sampling_pattern=generate_bruker_centric_centric_high_res(ny,nx)
M0_au=1e5
# Repetition time taken from the CSI scan loaded last, divided by 1000 (presumably ms -> s).
TR_s=csi.method['PVM_RepetitionTime']/1000
T1_s=30
alpha_deg=2
# Shape and relative size passed to create_sample_image.
sample_shape="square"
sample_size=0.4
# Standard deviation of the complex Gaussian noise added to the sample image (0 = no noise).
noise_level=0.0
cmap='magma'

## Simulate PSF

In [None]:
# Simulate the magnetization available at every excitation, turn it into a
# k-space weighting (PSF) and apply it to a synthetic sample image.
# NOTE(review): hyperpolarized_Mz_flipangle_T1, hyperpolarized_Mz_T1 and
# hyperpolarized_Mz_flipangle are neither defined nor imported in the visible
# part of this notebook - they must be brought into scope before this cell runs.
n_excitations = nx * ny
Mz, sampling_time = hyperpolarized_Mz_flipangle_T1(
    M0_au=M0_au,
    n_excitations=n_excitations,
    TR_s=TR_s,
    T1_s=T1_s,
    alpha_deg=alpha_deg,
    plot=False,
    interactive=False,
)
Mz_T1 = hyperpolarized_Mz_T1(M0_au=M0_au, t_s=sampling_time, T1_s=T1_s)
Mz_flipangle = hyperpolarized_Mz_flipangle(
    M0_au=M0_au, n_excitations=n_excitations, alpha_deg=alpha_deg
)

# Generate PSF (k-space weighting), scaled by the number of k-space points
psf = generate_psf(sampling_pattern, Mz, alpha_deg)
psf = psf / psf.shape[0] / psf.shape[1]

# Create sample image, add complex Gaussian noise and normalize to unit total magnitude
sample_image = create_sample_image(nx, ny, sample_shape, sample_size)
noise = np.random.normal(
    0, noise_level, sample_image.shape
) + 1j * np.random.normal(0, noise_level, sample_image.shape)
sample_image += noise
sample_image = sample_image / np.sum(np.abs(sample_image))

# Calculate PSF FT and convolved image
psf_ft = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(psf)))
sample_kspace = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(sample_image)))
# Weight the sample's k-space with the PSF; np.rot90 brings the (ny, nx)
# pattern to the (nx, ny) shape of the sample image.
convolved_image = np.abs(
    np.fft.ifftshift(np.fft.ifft2(sample_kspace * np.rot90(psf)))
)
convolved_image = convolved_image / np.sum(np.abs(convolved_image))

# Panels: A sampling pattern, B sample, C weighted image, D difference (C - B) in percent
fig,ax=plt.subplots(2,2,tight_layout=True,figsize=(figsize/1.5,figsize/1.5))

img1=ax[0,0].imshow(sampling_pattern.T,cmap=cmap)
img2=ax[0,1].imshow(np.abs(sample_image)/np.max(np.abs(sample_image)),cmap=cmap)
img3=ax[1,1].imshow(100*((convolved_image/np.max(convolved_image))-(np.abs(sample_image)/np.max(np.abs(sample_image)))),cmap='coolwarm')
img4=ax[1,0].imshow(convolved_image/np.max(convolved_image),cmap=cmap,vmin=0,vmax=1)

fig.colorbar(img1,ax=ax[0,0],label='k-space points',ticks=[0,750])
fig.colorbar(img2,ax=ax[0,1],label='I [a.u.]',ticks=[0,1])
fig.colorbar(img4,ax=ax[1,0],label='I [a.u.]',ticks=[0,1])
# Raw string: '\%' is not a valid Python escape sequence but is needed by LaTeX.
fig.colorbar(img3,ax=ax[1,1],label=r'I [\%]',ticks=[-30,0,30])

ax[0,0].set_title('A',loc='left')
ax[0,1].set_title('B',loc='left')
ax[1,0].set_title('C',loc='left')
ax[1,1].set_title('D',loc='left')

for n in range(2):
    for m in range(2):
        ax[n,m].axis('off')


In [None]:
# Line profiles through the central row: PSF-weighted CSI image, simulated
# PRESS refocusing profile and the original sample.
fig,ax=plt.subplots(1,figsize=(1.6*5,1.6*3),tight_layout=True)
center_row = sample_image.shape[0] // 2
x_ax=np.linspace(-2.5,2.5,sample_image.shape[1])

csi_profile = convolved_image[center_row,:]
# sample_image is complex; plot its magnitude (as in panel B of the previous
# figure) instead of handing complex values to np.max and matplotlib.
sample_profile = np.abs(sample_image[center_row,:])

ax.plot(x_ax,csi_profile/np.max(csi_profile),label='CSI',color='k')
# Map the simulated mz profile from [1, -1] onto [0, 1].
ax.plot(x_ax_press_simul,(-1*ref_profile_180+1)/2,color='r',label='PRESS')

ax.plot(x_ax,sample_profile/np.max(sample_profile),label='Sample')
# Dashed markers at x = -1 and x = +1, labelled 'Voxel' in the legend.
ax.vlines(-1,0,1,color='g',linestyle='dashed',label='Voxel')
ax.vlines(1,0,1,color='g',linestyle='dashed')

ax.set_ylabel('I [a.u.]')
ax.legend()
