In [None]:
import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
import pandas as pd
import os
import pydicom
import scipy.io as sio
import hypermri.utils.utils_general as utg
import hypermri.utils.utils_fitting as utf
import hypermri.utils.utils_anatomical as uta
import hypermri.utils.utils_spectroscopy as uts
import seaborn as sns

from hypermri.utils.utils_fitting import def_fit_params, fit_t2_pseudo_inv, fit_freq_pseudo_inv, fit_data_pseudo_inv, fit_func_pseudo_inv, plot_fitted_spectra, basefunc
from hypermri.utils.utils_general import get_gmr, calc_sampling_time_axis
from hypermri.utils.utils_spectroscopy import apply_lb, multi_dim_linebroadening, get_metab_cs_ppm, generate_fid, get_freq_axis, make_NDspec_6Dspec, freq_to_index, find_npeaks

import sys
# define paths:
# NOTE(review): the hypermri modules are already imported earlier in this cell,
# so this sys.path entry only affects imports that come after it (e.g.
# Template_Cambridge). Move it to the top of the cell if it is what makes
# hypermri importable on a fresh kernel.
sys.path.append('../../')

# NOTE(review): duplicate of the `uts` import earlier in this cell; redundant.
import hypermri.utils.utils_spectroscopy as uts


# Autoreload extension so that you don't have to reload the kernel every time something is changed in the hypermri or magritek folders
%load_ext autoreload
%autoreload 2

# Interactive (ipympl) figures.
%matplotlib widget

import Template_Cambridge
# Provides the two root directories used throughout the notebook:
#   basepath -> contains the 'MRE-<nnn>' raw-data folders,
#   savepath -> contains the 'fit_results_100Hz_400ms/' fit pickles.
# NOTE(review): meaning of the (False, True) flags lives in Template_Cambridge — confirm.
basepath,savepath = Template_Cambridge.import_all_packages(False,True)
def get_colors_from_cmap(cmap_name, N):
    """Sample ``N`` colours evenly (both ends inclusive) from a Matplotlib colormap.

    Parameters
    ----------
    cmap_name : str
        Name of a registered colormap, e.g. ``'tab10'``.
    N : int
        Number of colours to sample.

    Returns
    -------
    numpy.ndarray
        RGBA colours, one row per sample position.
    """
    positions = np.linspace(0, 1, N)
    return plt.get_cmap(cmap_name)(positions)


In [None]:
# Output directory for the per-exam Excel summary tables written by perform_analysis().
# NOTE(review): the leading '...' looks like a redacted/placeholder prefix — set the
# real location before running, otherwise the writes go to a relative '.../' folder.
revision_path=r'.../Publication/Revision1/RefittedData/Renal/'

In [None]:
def _find_fit_file(fit_dir, pattern):
    """Return the name of the fit-result pickle in ``fit_dir`` whose name contains ``pattern``.

    Only ``*.pkl`` entries are considered. If several match, the last one in
    sorted (lexicographic) order is returned and a warning is printed.

    Raises
    ------
    FileNotFoundError
        If no ``*.pkl`` file in ``fit_dir`` contains ``pattern`` (fail loudly
        here instead of hitting an unbound variable later).
    """
    matches = sorted(
        name for name in os.listdir(fit_dir)
        if name.endswith('.pkl') and pattern in name
    )
    if not matches:
        raise FileNotFoundError(
            f"No fit-result .pkl file containing {pattern!r} found in {fit_dir!r}"
        )
    if len(matches) > 1:
        print(f'Warning: {len(matches)} files match {pattern!r}; using {matches[-1]!r}')
    return matches[-1]


def perform_analysis(studyfolder,exam_num,error_thld='both',fit_error_threshold = 3,snr_threshold=5):
    """Analyse one exam: temperature from the pyruvate-lactate shift and lactate/pyruvate AUCR.

    Loads the raw FIDs of study ``'MRE-' + studyfolder`` and the matching
    pre-computed fit results, draws QC figures (spectra with summed fit,
    masked temperature per slice/channel, normalised amplitude curves) and
    writes a per-slice/per-channel summary table as Excel into ``revision_path``.

    Parameters
    ----------
    studyfolder : str
        Study number as used in the file names, e.g. ``'012'``.
    exam_num : int
        Exam index. Selects the raw/fit files for studies 012, 014 and 017
        and is always used in the output file name and the ``'exam'`` column.
    error_thld : {'both', 'SNR', 'fiterror'}
        Which mask defines the accepted temperature points: SNR only, fit
        error only, or both combined.
    fit_error_threshold : float
        Maximum accepted propagated temperature uncertainty (plotted in °C).
        Also drawn as the dashed reference line in the dT panels.
    snr_threshold : float
        A point is rejected if the SNR of either metabolite (index 0 or 1)
        is below this value.

    Returns
    -------
    pandas.DataFrame
        The table written to disk, one row per (slice, channel), with columns
        ID, exam, slice, channel, T (nanmean over repetitions of accepted
        temperatures), dT (their nanstd), AUCR, dAUCR, nT (accepted points).

    Raises
    ------
    KeyError
        If ``error_thld`` is not one of the accepted values (checked before
        any file I/O).
    FileNotFoundError
        If no matching fit-result ``.pkl`` file exists.

    Notes
    -----
    Relies on the notebook globals ``basepath``, ``savepath``,
    ``revision_path`` and ``get_colors_from_cmap``.
    """
    valid_modes = ('both', 'SNR', 'fiterror')
    # Validate first so a typo fails fast instead of after loading and plotting.
    if error_thld not in valid_modes:
        raise KeyError(f"error_thld must be one of {valid_modes}, got {error_thld!r}")

    animal_id = 'MRE-' + studyfolder
    fit_dir = os.path.join(savepath, 'fit_results_100Hz_400ms/')

    # These studies store one raw/fit file per exam; the others have a single file.
    if studyfolder in ('012', '014', '017'):
        raw_name = animal_id + '_raw_fid_' + str(exam_num) + '.mat'
        fit_pattern = animal_id + '_fit_spectra_100Hz_400ms_' + str(exam_num)
    else:
        raw_name = animal_id + '_raw_fid.mat'
        fit_pattern = animal_id + '_fit_spectra_100Hz_400ms__'

    raw_data = sio.loadmat(os.path.join(basepath, animal_id, raw_name))['raw_data']
    dyn_fid = raw_data[0::8, :, 0, 0, :]
    load_file = _find_fit_file(fit_dir, fit_pattern)
    print('-----')
    print('Selected study:', studyfolder)
    print('Loading', load_file)

    # FFT along the spectral axis (axis 1), centre it, then flip + conjugate.
    dyn_spec = np.conj(np.flip(np.fft.fftshift(np.fft.fft(dyn_fid, axis=1), axes=(1,)), axis=1))
    input_data = uts.make_NDspec_6Dspec(input_data=dyn_spec, provided_dims=["reps", "spec", "z", "chans"])
    # NOTE(review): load_as_pkl unpickles the file — only use with trusted result files.
    fit_results = utg.load_as_pkl(dir_path=fit_dir, filename=load_file, global_vars=globals())
    print('Loaded data', os.path.join(fit_dir, load_file))
    fit_spectrums = fit_results['fit_spectrums']
    fit_amps = fit_results['fit_amps']
    fit_freqs = fit_results['fit_freqs']
    fit_params = fit_results['fit_params']
    fit_stds = fit_results['fit_stds']

    # Axis layout of input_data as used below: 1 = slices, 4 = repetitions, 5 = channels.
    n_slices = input_data.shape[1]
    n_reps = input_data.shape[4]
    n_chans = input_data.shape[5]
    freq_ppm = fit_params['freq_range_ppm']

    # --- QC: measured spectrum vs. sum of fitted metabolite spectra -------------
    spec_sq = np.squeeze(input_data)
    fit_total_sq = np.sum(np.squeeze(fit_spectrums), axis=-1)  # sum over metabolites
    for chan in range(n_chans):
        fig, axes = plt.subplots(n_slices, n_reps, figsize=(15, 3), squeeze=False)
        for slic in range(n_slices):
            for rep in range(n_reps):
                ax = axes[slic, rep]
                ax.plot(freq_ppm, np.abs(spec_sq[:, slic, rep, chan]))
                ax.plot(freq_ppm, np.abs(fit_total_sq[:, slic, rep, chan]))
                ax.set_xlim([160, 190])
                ax.set_xticks([])
                ax.set_yticks([])
                if rep == 0:
                    ax.set_ylabel('Slice ' + str(slic))
                if slic == 0:
                    ax.set_title('Rep ' + str(rep))
        fig.suptitle('Channel ' + str(chan))
        fig.subplots_adjust(wspace=0, hspace=0)

    # --- Temperature from the metabolite 1 - metabolite 0 frequency difference ---
    freq_diff_hz = np.take(fit_freqs, indices=1, axis=-1) - np.take(fit_freqs, indices=0, axis=-1)
    freq_diff_ppm = uts.freq_Hz_to_ppm(freq_Hz=freq_diff_hz,
                                       hz_axis=fit_params["freq_range_Hz"],
                                       ppm_axis=freq_ppm,
                                       ppm_centered_at_0=True)
    temp, _ = utf.temperature_from_frequency(frequency=freq_diff_ppm,
                                             calibration_type='5mM',
                                             frequency_is_ppm=True)

    # fit_stds[..., metabolite, parameter]; parameter 1 is the frequency.
    freq_diff_hz_std = np.sqrt(fit_stds[..., 0, 1]**2 + fit_stds[..., 1, 1]**2)
    freq_diff_ppm_std = uts.freq_Hz_to_ppm(freq_Hz=freq_diff_hz_std,
                                           hz_axis=fit_params["freq_range_Hz"],
                                           ppm_axis=freq_ppm,
                                           ppm_centered_at_0=True)
    # Propagate by evaluating the calibration one standard deviation away.
    temp_plus_std, _ = utf.temperature_from_frequency(frequency=freq_diff_ppm + freq_diff_ppm_std,
                                                      calibration_type='5mM',
                                                      frequency_is_ppm=True)
    temp_std = np.abs(temp_plus_std - temp)

    # --- Masking: NaN marks rejected points ---------------------------------
    snr, _ = uts.compute_snr_from_fit(input_data, fit_spectrums)
    snr_mask = np.where((snr[..., 0] < snr_threshold) | (snr[..., 1] < snr_threshold), np.nan, 1)
    snr_masked_temp = snr_mask * temp
    fit_error_masked_temp = np.where(temp_std <= fit_error_threshold, temp, np.nan)
    snr_and_fit_error_masked_temp = snr_mask * fit_error_masked_temp

    masked_by_mode = {
        'both': snr_and_fit_error_masked_temp,
        'SNR': snr_masked_temp,
        'fiterror': fit_error_masked_temp,
    }
    masked_T = masked_by_mode[error_thld]
    inverse_masked_T = np.where(np.isnan(masked_T), temp, np.nan)
    masked_T_sq = np.squeeze(masked_T)
    mean_t_per_slice = np.nanmean(masked_T_sq, axis=1)   # over repetitions
    d_mean_t_per_slice = np.nanstd(masked_T_sq, axis=1)

    colors = get_colors_from_cmap('tab10', 10)
    line_end = n_reps - 1  # reference lines span the plotted repetition range
    for chan in range(n_chans):
        fig, ax = plt.subplots(3, n_slices, figsize=(10, 5), tight_layout=True, squeeze=False)
        for n in range(n_slices):
            ax[0, n].plot(inverse_masked_T[0, n, 0, 0, :, chan], 'o', markersize=3, color='r', label='Excluded points')
            ax[0, n].plot(masked_T[0, n, 0, 0, :, chan], 'o', markersize=3, color=colors[0])
            ax[1, n].plot(np.abs(temp_std[0, n, 0, 0, :, chan]), 'o', markersize=3, color=colors[1])
            ax[2, n].plot(np.abs(freq_diff_hz_std[0, n, 0, 0, :, chan]), 'o', markersize=3, color=colors[2])
            # Show the temperature-error cut actually applied, plus a 1 Hz guide on df.
            ax[1, n].hlines(fit_error_threshold, 0, line_end, linestyle='dashed', color='g')
            ax[2, n].hlines(1, 0, line_end, linestyle='dashed', color='g')

            ax[0, n].set_ylim([20, 45])
            ax[0, n].set_title('Slice ' + str(n) + ', T=' + str(np.round(mean_t_per_slice[n, chan], 1)) + '±' + str(np.round(d_mean_t_per_slice[n, chan], 1)) + r'$^\circ$C')
            ax[2, n].set_xlabel('Repetition')
            ax[0, n].set_ylabel(r'T[$^\circ$C]')
            ax[1, n].set_ylabel(r'dT[$^\circ$C]')
            ax[2, n].set_ylabel(r'df[Hz]')
            ax[0, n].legend(fontsize=6)
        fig.suptitle('MRE-' + str(studyfolder) + ' Temperature filtered')

    # --- Amplitudes and lactate/pyruvate area-under-curve ratio ---------------
    # Metabolite index 0 = pyruvate, 1 = lactate; fit_stds parameter 0 = amplitude.
    pyr_amp = np.abs(np.squeeze(fit_amps[..., 0]))
    lac_amp = np.abs(np.squeeze(fit_amps[..., 1]))
    d_pyr_amp = np.squeeze(fit_stds[..., 0, 0])
    d_lac_amp = np.squeeze(fit_stds[..., 1, 0])

    sum_pyr = np.sum(pyr_amp, axis=1)   # over repetitions
    sum_lac = np.sum(lac_amp, axis=1)
    d_sum_pyr = np.sqrt(np.sum(d_pyr_amp**2, axis=1))
    d_sum_lac = np.sqrt(np.sum(d_lac_amp**2, axis=1))

    AUCR = sum_lac / sum_pyr
    d_AUCR = np.abs(AUCR * np.sqrt((d_sum_pyr / sum_pyr)**2 + (d_sum_lac / sum_lac)**2))

    rep_axis = np.arange(n_reps)
    for chan in range(n_chans):
        fig, ax = plt.subplots(n_slices, 1, tight_layout=True, figsize=(5, 7), squeeze=False)
        for slic in range(n_slices):
            panel = ax[slic, 0]
            # Both curves are normalised to the pyruvate maximum of this slice/channel.
            pyr_max = np.max(pyr_amp[slic, :, chan])
            panel.errorbar(rep_axis, pyr_amp[slic, :, chan] / pyr_max, yerr=np.abs(d_pyr_amp[slic, :, chan]) / pyr_max, label='Pyruvate')
            panel.errorbar(rep_axis, lac_amp[slic, :, chan] / pyr_max, yerr=np.abs(d_lac_amp[slic, :, chan]) / pyr_max, label='Lactate')
            panel.set_ylabel('Slice ' + str(slic))
            panel.legend()
            panel.set_xlabel('Repetition')
            # Same precision for value and uncertainty.
            panel.set_title('AUCR=' + str(np.round(AUCR[slic, chan], 3)) + '±' + str(np.round(d_AUCR[slic, chan], 3)))
        fig.suptitle('MRE-' + str(studyfolder) + ', chan ' + str(chan))

    # --- Summary table: one row per (slice, channel) --------------------------
    columns = ['ID', 'exam', 'slice', 'channel', 'T', 'dT', 'AUCR', 'dAUCR', 'nT']
    records = [
        {
            'ID': animal_id,
            'exam': exam_num,
            'slice': s,
            'channel': c,
            'T': mean_t_per_slice[s, c],
            'dT': d_mean_t_per_slice[s, c],
            'AUCR': AUCR[s, c],
            'dAUCR': d_AUCR[s, c],
            'nT': int((~np.isnan(masked_T[0, s, 0, 0, :, c])).sum()),
        }
        for s in range(n_slices)
        for c in range(n_chans)
    ]
    df = pd.DataFrame(records, columns=columns)
    out_file = os.path.join(revision_path, 'MRE-' + studyfolder + '_' + str(exam_num) + '_results_thld_snr_fit.xlsx')
    df.to_excel(out_file)
    print('Saved df to ', out_file)
    return df

In [None]:
# Zero-padded study numbers (folder/file prefix 'MRE-<nnn>') to analyse.
studyfolder_nums = [f'{num:03d}' for num in (1, 6, 8, 10, 12, 13, 14, 16, 17, 19, 20, 23)]

In [None]:
# Studies 012 and 014 are analysed for exams 1 and 2; every other study only for exam 1.
for studyfolder in studyfolder_nums:
    exam_nums = range(1, 3) if studyfolder in ['012', '014'] else (1,)
    for exam_num in exam_nums:
        perform_analysis(studyfolder, exam_num)