In [None]:
import hypermri
import hypermri.utils.utils_anatomical as ut_anat
import sys
import hypermri.utils.utils_anatomical as ut_anat
import hypermri.utils.utils_fitting as ut_fitting
import hypermri.utils.utils_spectroscopy as ut_spec
from scipy.stats import rayleigh
import os
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
import pandas as pd
import datetime
from matplotlib.patches import Rectangle
from scipy.optimize import curve_fit 
from hypermri.utils.utils_sv_spectroscopy import Plot_Voxel_on_Anat

    
def get_colors_from_cmap(cmap_name, N):
    """Sample `N` evenly spaced RGBA colours from a matplotlib colormap.

    Parameters
    ----------
    cmap_name : str
        Name of a registered matplotlib colormap (e.g. ``'viridis'``).
    N : int
        Number of colours to return.

    Returns
    -------
    ndarray
        One RGBA row per sample, taken at positions ``np.linspace(0, 1, N)``.
    """
    sample_points = np.linspace(0, 1, N)
    return plt.get_cmap(cmap_name)(sample_points)

# NOTE(review): identical to the Plot_Voxel_on_Anat import earlier in this cell; harmless but redundant.
from hypermri.utils.utils_sv_spectroscopy import Plot_Voxel_on_Anat


# Autoreload extension so that you don't have to restart the kernel every time something is changed in the hypermri or magritek folders
%load_ext autoreload
%autoreload 2

# Interactive (ipympl) figures.
%matplotlib widget
import matplotlib
from matplotlib import rc
# Serif "Computer Modern" font with LaTeX text rendering; usetex needs a LaTeX installation on the machine.
rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
rc("text", usetex=True)
matplotlib.rcParams.update({"font.size": 11})


In [None]:
# define scan path
# `basepath` (the Bruker study directory) must be defined before this cell runs; it is not
# assigned in any earlier cell, so fail with an explicit message rather than a bare NameError.
if 'basepath' not in globals():
    raise NameError("`basepath` is not defined: set it to the Bruker study directory before loading scans.")
dirpath = basepath + '/'
scans = hypermri.BrukerDir(dirpath,verbose=True)


# Load Data

In [None]:
# Dissolution 1: sLASER scan and its companion scan (sp90 -- presumably a 90° pulse reference; confirm).
slaser_diss1, sp90diss1 = scans[17], scans[19]

# Anatomical reference images.
axial, coronal, sagittal = scans[11], scans[32], scans[16]

# Dissolution 2: PRESS scan and its companion sp90 scan.
press_diss2, sp90diss2 = scans[23], scans[24]

# Study identifiers (left blank here).
animal_ID, date = '', ''


In [None]:
# Acquisition parameters of the sLASER scan; TR is divided by 1000 so it is printed in s.
# NOTE: NR and nvox are read again by later cells as notebook-level globals.
TR, NR, FA, voxsize, nvox = (
    slaser_diss1.method['PVM_RepetitionTime'] / 1000,
    slaser_diss1.method['PVM_NRepetitions'],
    slaser_diss1.method['VoxPul1'][2],
    slaser_diss1.method['PVM_VoxArrSize'],
    slaser_diss1.method['PVM_NVoxels'],
)

print('TR=', TR, 's', 'NR=', NR, 'FA=', FA, 'Size', voxsize[0])

# Plot shim check

In [None]:
# Multi-voxel shim-check scan (scan 27); its spectra are plotted per voxel in the next cell.
mvpress_shimcheck=scans[27]

In [None]:
fig, ax = plt.subplots(2, 1, tight_layout=True, figsize=(8, 8))
# Voxel labels of the shim-check scan in voxel order; later cells also use them as voxel titles.
locations = ['Kidney l', 'Kidney r', 'Kidney r', 'Kidney l', 'Artery', 'Liver', 'Muscle', 'Heart', 'Abdomen', 'Abdomen']
proton_shim_mvpress = mvpress_shimcheck.get_fids_spectra(0, 0)[0]
shim_magnitude = np.abs(np.squeeze(proton_shim_mvpress))
# Normalise by the largest *magnitude*: np.max on a complex array orders by real part first
# (not by peak height) and would leave the normalised trace complex.
max_val = np.max(shim_magnitude)
for n in range(2):
    ax[n].plot(mvpress_shimcheck.ppm_axis, shim_magnitude[:, n] / max_val)
    ax[n].set_title(locations[n])

# Compare both dissolutions (sLASER and PRESS) in one figure

In [None]:
# Spectra and ppm axes of both dissolutions, processed with identical arguments.
press_spec, ppm_press = press_diss2.get_fids_spectra(5, 70)[0], press_diss2.get_ppm(70)
slaser_spec, ppm_slaser = slaser_diss1.get_fids_spectra(5, 70)[0], slaser_diss1.get_ppm(70)

In [None]:
plt.close('all')
fig, ax = plt.subplots(2, 2, figsize=(5, 5), subplot_kw={"projection": "3d"}, tight_layout=True)

# Every panel is divided by the sLASER maximum, so PRESS is shown on the sLASER amplitude scale.
max_val = np.max(np.abs(slaser_spec))

# (row, column) -> (scan, spectra, ppm axis, voxel index).
# Row 0: sLASER (dissolution 1), row 1: PRESS (dissolution 2); the column is the voxel.
wireframe_panels = {
    (0, 0): (slaser_diss1, slaser_spec, ppm_slaser, 0),
    (0, 1): (slaser_diss1, slaser_spec, ppm_slaser, 1),
    (1, 0): (press_diss2, press_spec, ppm_press, 0),
    (1, 1): (press_diss2, press_spec, ppm_press, 1),
}

for (row, col), (scan, spec, ppm_axis, voxel) in wireframe_panels.items():
    panel = ax[row, col]
    tr_ms = scan.method['PVM_RepetitionTime']
    n_rep = scan.method['PVM_NRepetitions']
    # One time point per repetition (ms -> s); the second voxel's axis starts one TR later.
    time_s = (np.arange(n_rep) + voxel) * tr_ms / 1000
    # Rows = repetitions, columns = spectral points, matching the meshgrid below.
    # The repetition count comes from the scan, so the shapes agree for any NR.
    z = np.abs(spec[:, voxel, 0, 0, :n_rep, 0]).T / max_val
    X, Y = np.meshgrid(ppm_axis, time_s)
    panel.plot_wireframe(X, Y, z, rstride=1, cstride=0, color='k', linewidth=0.7)

    # ppm axis drawn from high to low, using this panel's own ppm axis.
    panel.set_xlim([np.max(ppm_axis), np.min(ppm_axis)])
    panel.set_zlabel('I [a.u.]')
    panel.set_yticks([])
    panel.set_zlim([0, 0.7])
    panel.set_zticks([0, 0.5])
    if row == 1:
        panel.set_xlabel(r'$\sigma$ [ppm]')
    else:
        panel.set_xticks([])

    # Transparent panes and grid lines; the grid colour relies on matplotlib's private `_axinfo`.
    for axis in (panel.xaxis, panel.yaxis, panel.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        axis._axinfo["grid"]['color'] = (1, 1, 1, 0)



# Voxel positions on anatomical images

In [None]:
plt.close('all')
vox_names = ['Kidney L', 'Kidney R']
fig, ax = plt.subplots(2, 2, figsize=(4, 8))
matplotlib.rcParams.update({"font.size": 11})

# Top row: PRESS voxels drawn on the coronal image, with a white bar labelled "5 mm".
for idx in range(len(vox_names)):
    top = ax[0, idx]
    Plot_Voxel_on_Anat(press_diss2, coronal, top, idx, 0, 15, vox_color=f'C{idx}')
    top.axis('off')
    top.set_title(None)
    top.set_xlim([12, -12])
    top.add_patch(Rectangle((-9, -23), 5, 1, fc='w'))
    top.text(-3, -21.5, '5 mm', color='w')

# Bottom row: the same voxel indices drawn on the axial image (default extent).
for idx in range(len(vox_names)):
    bottom = ax[1, idx]
    Plot_Voxel_on_Anat(press_diss2, axial, bottom, idx, 0, 15, vox_color=f'C{idx}')
    bottom.axis('off')
    bottom.set_title(None)

plt.subplots_adjust(hspace=0, wspace=0.1)


# Fit spectra and plot time curves

In [None]:
def fit_mvpress(experiment,animal_id,date):
    """Fit pyruvate, lactate, pyruvate hydrate and alanine in a multi-voxel spectroscopy scan.

    The first ``cut_off`` FID points are discarded. The rest is Fourier transformed along the
    time axis and passed to ``ut_fitting.fit_data_pseudo_inv``.

    Parameters
    ----------
    experiment : hypermri scan object
        Must provide ``method['PVM_NRepetitions']`` and ``complex_fids``.
    animal_id, date : str
        Currently unused; kept so existing call sites keep working.

    Returns
    -------
    tuple
        ``(fit_spectrums, fit_amps, fit_freqs, fit_t2s, fit_stds, fit_freqs_ppm,
        fit_stds_ppm, AUCR, d_AUCR, cut_off_spec)`` where ``AUCR`` is the
        lactate/pyruvate amplitude ratio and ``d_AUCR`` its propagated uncertainty.
    """
    metabs = ['pyruvate', 'lactate', 'pyruvatehydrate', 'alanine']

    fit_params = {
        "zoomfactor": 1.5,
        "max_t2_s": 0.5,
        "min_t2_s": 0.001,
        "range_t2s_s": 0.05,
        "metabs": metabs,
        # Repetition count of *this* scan. A notebook-level NR read from a different scan
        # would mis-size the fit when sLASER and PRESS have different repetition counts.
        "fit_range_repetitions": experiment.method['PVM_NRepetitions'],
        "range_freqs_Hz": 25,
        "cut_off": 70,        # number of leading FID points discarded before the FFT
        "niter": 1,           # number of iterations
        "npoints": 21,        # number of tested points per iteration
        "rep_fitting": 11,    # NOTE(review): purpose not visible in this notebook -- confirm in ut_fitting
        "provided_dims": ["fid", "repetitions"],
    }
    # Complete the parameter set once (a second identical call was redundant).
    fit_params = ut_fitting.def_fit_params(fit_params=fit_params, data_obj=experiment)

    # Drop the first `cut_off` FID points, FFT along time (axis 0) and centre zero frequency.
    cut_off_spec = np.fft.fftshift(
        np.fft.fft(experiment.complex_fids[fit_params["cut_off"]:, :], axis=0), axes=(0,)
    )

    fit_spectrums, fit_amps, fit_freqs, fit_t2s, fit_stds = ut_fitting.fit_data_pseudo_inv(
        input_data=cut_off_spec,
        data_obj=experiment,
        fit_params=fit_params,
        use_multiprocessing=True,
    )

    fit_freqs_ppm = ut_spec.freq_Hz_to_ppm(freq_Hz=np.squeeze(fit_freqs), hz_axis=fit_params["freq_range_Hz"], ppm_axis=fit_params["freq_range_ppm"], ppm_axis_flipped=False)
    fit_stds_ppm = ut_spec.freq_Hz_to_ppm(freq_Hz=np.squeeze(fit_stds), hz_axis=fit_params["freq_range_Hz"], ppm_axis=fit_params["freq_range_ppm"], ppm_axis_flipped=False)

    # Last axis of fit_amps presumably follows `metabs` order: 0 = pyruvate, 1 = lactate.
    pyr_amp = np.abs(np.squeeze(fit_amps[:, :, :, :, :, :, 0]))
    lac_amp = np.abs(np.squeeze(fit_amps[:, :, :, :, :, :, 1]))

    # Assumes fit_stds ends in (metabolite, fit parameter) with parameter 0 = amplitude -- TODO confirm.
    d_pyr_amp = np.squeeze(fit_stds[..., 0, 0])
    d_lac_amp = np.squeeze(fit_stds[..., 1, 0])

    # Lactate/pyruvate ratio with first-order (uncorrelated) error propagation.
    AUCR = lac_amp / pyr_amp
    d_AUCR = np.abs(AUCR * np.sqrt((d_pyr_amp / pyr_amp) ** 2 + (d_lac_amp / lac_amp) ** 2))

    return fit_spectrums, fit_amps, fit_freqs, fit_t2s, fit_stds, fit_freqs_ppm, fit_stds_ppm, AUCR, d_AUCR, cut_off_spec



In [None]:
(fit_spectrums_sLASER, fit_amps_sLASER, fit_freqs_sLASER, fit_t2s_sLASER,
 fit_stds_sLASER, fit_freqs_ppm_sLASER, fit_stds_ppm_sLASER,
 AUCR_sLASER, d_AUCR_sLASER, cut_off_spec_sLASER) = fit_mvpress(slaser_diss1, animal_ID, date)

In [None]:
# Measured (grey) vs. summed fitted (red) sLASER spectra: one row per repetition, one column per voxel.
# ppm_slaser comes from slaser_diss1.get_ppm(70), i.e. the same 70-point cut-off used by fit_mvpress.
fit_sum_sLASER = np.abs(np.sum(np.squeeze(fit_spectrums_sLASER), axis=-1))  # summed over peaks
measured_sLASER = np.abs(np.squeeze(cut_off_spec_sLASER))
# Number of voxels actually present in the data (not the length of `locations`).
n_vox_sLASER = measured_sLASER.shape[1]

fig, ax = plt.subplots(NR, n_vox_sLASER, tight_layout=True, figsize=(6, 15), squeeze=False)
for rep in range(NR):
    for n in range(n_vox_sLASER):
        panel = ax[rep, n]
        panel.plot(ppm_slaser, fit_sum_sLASER[:, n, rep], color='r')
        panel.plot(ppm_slaser, measured_sLASER[:, n, rep], color='k', alpha=0.5)
        panel.set_xlim([185, 165])
        panel.set_xticks([185, 180, 175, 170, 165])
        panel.set_title(locations[n])


In [None]:
(fit_spectrums_PRESS, fit_amps_PRESS, fit_freqs_PRESS, fit_t2s_PRESS,
 fit_stds_PRESS, fit_freqs_ppm_PRESS, fit_stds_ppm_PRESS,
 AUCR_PRESS, d_AUCR_PRESS, cut_off_spec_PRESS) = fit_mvpress(press_diss2, animal_ID, date)

In [None]:
fig, ax = plt.subplots(1, 4, figsize=(8, 2), tight_layout=True)
# Both sequences are divided by the sLASER maximum so their amplitudes are directly comparable.
max_sig = np.max(np.abs(fit_amps_sLASER))

# Panels 0-1: sLASER voxels 0/1; panels 2-3: PRESS voxels 0/1.
time_curve_panels = [
    (slaser_diss1, fit_amps_sLASER, 0),
    (slaser_diss1, fit_amps_sLASER, 1),
    (press_diss2, fit_amps_PRESS, 0),
    (press_diss2, fit_amps_PRESS, 1),
]
for panel, (scan, amps, voxel) in zip(ax, time_curve_panels):
    # After squeeze: [voxel, repetition, metabolite]; metabolite 0 = Pyr, 1 = Lac.
    voxel_amps = np.abs(np.squeeze(amps)[voxel])
    for metab_idx, label in ((0, 'Pyr'), (1, 'Lac')):
        panel.plot(scan.time_ax_array[voxel], voxel_amps[:, metab_idx] / max_sig, label=label)
    panel.set_xlabel('t [s]')
    panel.set_yticks([0, 0.5, 1])
    panel.set_xticks([0, 6, 12, 18])
    panel.set_xlim([0, 20])

# Figure-level decorations are set once rather than on every loop iteration.
ax[0].set_ylabel('AUC [a.u.]')
ax[3].legend()
ax[0].set_title('semi-LASER')
ax[2].set_title('PRESS')


In [None]:
def _rayleigh_noise_stats(noise_spec):
    """Fit a Rayleigh distribution with loc fixed at 0 and return (scale, mean, std)."""
    # floc=0 keeps the fit consistent with the loc=0 pdf and moments used afterwards;
    # a free loc would give a scale that does not describe a zero-located distribution.
    _, scale = rayleigh.fit(noise_spec, floc=0)
    fitted = rayleigh(scale=scale)
    return scale, fitted.mean(), fitted.std()


def compute_snr(experiment, cutoff_spec, fit_spectrums, nvox=2,
                metabs=('pyruvate', 'lactate', 'pyruvatehydrate', 'alanine'),
                n_noise_points=150, cut_off=70):
    """Estimate noise floors and per-peak SNRs for every voxel and repetition, plotting both.

    Parameters
    ----------
    experiment : hypermri scan object
        Provides ``method['PVM_NRepetitions']`` and ``get_ppm(cut_off)``.
    cutoff_spec : ndarray
        Measured cut-off spectra; after ``np.squeeze`` indexed ``[point, voxel, repetition]``.
    fit_spectrums : ndarray
        Fitted per-peak spectra; after ``np.squeeze`` indexed
        ``[point, voxel, repetition, peak]``.
    nvox : int
        Number of voxels to analyse.
    metabs : sequence of str
        Peak names, in the order of the last axis of ``fit_spectrums``.
    n_noise_points : int
        Number of leading spectral points used as the noise window (assumed signal-free).
    cut_off : int
        Cut-off passed to ``experiment.get_ppm`` for the ppm axis of the SNR plots; it should
        match the cut-off used to build ``cutoff_spec``.

    Returns
    -------
    peak_snrs : ndarray, shape (len(metabs), nvox, NR)
        ``(max |fitted peak| - mean(real noise)) / std(real noise)``, rounded to 2 decimals.
    noise_floor : ndarray, shape (nvox, NR)
        Mean + standard deviation of the Rayleigh distribution fitted to the standardised
        magnitude noise.

    Notes
    -----
    Closes all open figures first. Column-0 titles of the SNR figure come from the
    notebook-level ``locations`` list.
    """
    plt.close('all')
    NR = experiment.method['PVM_NRepetitions']
    spec = np.squeeze(cutoff_spec)
    fits = np.squeeze(fit_spectrums)
    noise_window = slice(0, n_noise_points)

    # --- Noise floor: one diagnostic figure per voxel ---
    noise_floor = np.ones((nvox, NR))
    for voxel in range(nvox):
        fig, ax = plt.subplots(2, NR, tight_layout=True, figsize=(3 * NR, 3), squeeze=False)
        for n in range(NR):
            trace = spec[:, voxel, n]
            noise = trace[noise_window]

            # Top row: spectrum standardised with magnitude-noise statistics, noise window overlaid.
            standardised = np.abs((trace - np.mean(np.abs(noise))) / np.std(np.abs(noise)))
            ax[0, n].plot(standardised)
            ax[0, n].plot(standardised[noise_window])

            # Bottom row: noise standardised with real-channel statistics, histogram + Rayleigh fit.
            noise_spec = np.abs((noise - np.mean(np.real(noise))) / np.std(np.real(noise)))
            scale, mean, std_dev = _rayleigh_noise_stats(noise_spec)
            noise_floor[voxel, n] = mean + std_dev

            x = np.linspace(0, np.max(noise_spec), 100)
            ax[1, n].hist(noise_spec, bins=30, density=True, alpha=0.3, color='C0', edgecolor='black')
            ax[1, n].plot(x, rayleigh.pdf(x, loc=0, scale=scale), 'r-', label=f'Rayleigh fit (scale={scale:.2f})')
            ax[1, n].set_title(str(np.round(mean, 1)) + '±' + str(np.round(std_dev, 1)))
        ax[0, 0].set_title('Voxel ' + str(voxel))

    # --- Peak SNR: one panel per (voxel, repetition) ---
    ppm = experiment.get_ppm(cut_off)
    real_noise = np.real(spec[noise_window, :nvox, :NR])
    mean_noise = np.mean(real_noise, axis=0)  # (nvox, NR)
    std_noise = np.std(real_noise, axis=0)    # (nvox, NR)
    peak_snrs = np.full((len(metabs), nvox, NR), np.nan)

    fig, ax = plt.subplots(nvox, NR, tight_layout=True, figsize=(NR, 2), squeeze=False)
    for rep in range(NR):
        for voxel in range(nvox):
            panel = ax[voxel, rep]
            for peak, metab in enumerate(metabs):
                # Fitted peak of this voxel *and this repetition*.
                peak_fit = np.abs(fits[:, voxel, rep, peak])
                snr = np.round((np.max(peak_fit) - mean_noise[voxel, rep]) / std_noise[voxel, rep], 2)
                peak_snrs[peak, voxel, rep] = snr
                panel.plot(ppm, peak_fit, label=metab + ',' + str(snr))
            panel.set_title(str(np.round(noise_floor[voxel, rep], 0)))
            panel.set_xlim([195, 155])
            panel.legend(fontsize=4)

    # The first column is titled with the voxel location instead of the noise floor.
    for voxel in range(nvox):
        ax[voxel, 0].set_title(locations[voxel])

    return peak_snrs, noise_floor


In [None]:
peak_snrs_slaser, noise_slaser = compute_snr(
    slaser_diss1,
    cutoff_spec=cut_off_spec_sLASER,
    fit_spectrums=fit_spectrums_sLASER,
)

In [None]:
peak_snrs_press, noise_press = compute_snr(
    press_diss2,
    cutoff_spec=cut_off_spec_PRESS,
    fit_spectrums=fit_spectrums_PRESS,
)

In [None]:
# PRESS SNR summary: one row per (voxel, repetition) in voxel-major order, which is the
# C-order flattening of peak_snrs_press[peak] and noise_press (both shaped (voxel, repetition)).
# Sizes come from the PRESS results themselves, not from the sLASER-derived global NR.
snr_columns = ['SNR Pyr', 'SNR Lac', 'SNR Hydr', 'SNR Ala']
n_vox_press, n_rep_press = peak_snrs_press.shape[1:]
output_df_press = pd.DataFrame({
    'Repetition': np.tile(np.arange(n_rep_press), n_vox_press),
    'Voxel': np.repeat(locations[:n_vox_press], n_rep_press),
    **{col: peak_snrs_press[i].flatten() for i, col in enumerate(snr_columns)},
    'Noise': noise_press.flatten(),
})
output_df_press.round(0)

In [None]:
# sLASER SNR summary: one row per (voxel, repetition) in voxel-major order (all repetitions
# of voxel 0, then voxel 1, ...), which is the C-order flattening of peak_snrs_slaser[peak]
# and noise_slaser (both shaped (voxel, repetition)). Sizes come from the results themselves.
snr_columns = ['SNR Pyr', 'SNR Lac', 'SNR Hydr', 'SNR Ala']
n_vox_slaser, n_rep_slaser = peak_snrs_slaser.shape[1:]
output_df_slaser = pd.DataFrame({
    'Repetition': np.tile(np.arange(n_rep_slaser), n_vox_slaser),
    'Voxel': np.repeat(locations[:n_vox_slaser], n_rep_slaser),
    **{col: peak_snrs_slaser[i].flatten() for i, col in enumerate(snr_columns)},
    'Noise': noise_slaser.flatten(),
})
output_df_slaser.round(0)