In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib

# Render all matplotlib text with LaTeX in a Computer Modern serif font.
# NOTE: text.usetex=True needs a working LaTeX installation on the machine running the notebook.
params = {"text.usetex" : True,
          "font.family" : "serif",
          "font.serif" : ["Computer Modern Serif"],
         'font.size':12}
plt.rcParams.update(params)
import ipywidgets as widgets
import datetime
import pandas as pd
import os
import math
import hypermri
from tqdm.auto import tqdm
from scipy.optimize import curve_fit
import scipy.io as sio
from matplotlib import cm
import pydicom
import hypermri.utils.utils_general as utg
import hypermri.utils.utils_spectroscopy as uts
import hypermri.utils.utils_fitting as utf
import sys
# define paths:
# Make the directory three levels up importable so that Template_Cambridge can be found.
sys.path.append('../../../')
import Template_Cambridge
# basepath: root data directory; study folders (HV-<num>) and fit-result folders are joined onto it below.
basepath,_ = Template_Cambridge.import_all_packages(scan_is_csi=True)
from hypermri.utils.utils_fitting import def_fit_params, fit_t2_pseudo_inv, fit_freq_pseudo_inv, fit_data_pseudo_inv, fit_func_pseudo_inv, plot_fitted_spectra, basefunc
from hypermri.utils.utils_general import get_gmr, calc_sampling_time_axis
from hypermri.utils.utils_spectroscopy import apply_lb, multi_dim_linebroadening, get_metab_cs_ppm, generate_fid, get_freq_axis, make_NDspec_6Dspec, freq_to_index, find_npeaks
# Autoreload extension so that you dont have to reload the kernel every time something is changed in the hypermri package
%load_ext autoreload
%autoreload 2

%matplotlib widget


In [None]:
# Output folder for the per-study Excel tables and result pickles written by the batch loop.
# NOTE(review): '...' is a placeholder - set the real path before running. Keep the trailing '/',
# because output file names are appended to this string by plain concatenation.
revision_path=r'.../Publication/Revision1/RefittedData/CSI_brain/'

# Input

In [None]:
# Calibration key passed to utf.temperature_from_frequency / utf.temp_from_ppm_and_concentration.
# NOTE(review): the chosen '5mM' is not in the options list below - confirm it is a supported calibration.
concentration = '5mM' # options: 10mM, 20mM, 50mM, 100mM, 200mM, 400mM, 600mM, todo: blood, ldh
# Voxel masking used by the batch loop: SNR mask, fit-error mask, or both combined.
error_thld='both' # options: both, SNR, fiterror
# In-plane upsampling factor for utg.interpolate_dataset (x and y matrix sizes are multiplied by it; 1 keeps the original size).
interpolation_factor=1

snr_threshold=5 # SNR; a voxel is masked when pyruvate or lactate SNR is below this
fit_error_threshold=3 #°C; a voxel is masked when its temperature uncertainty exceeds this
temp_error_thld=fit_error_threshold  # same threshold, read by analyze()
# Sub-folder of basepath holding the fitted .pkl results used by the batch loop.
fit_path='2025_refit_t2_400ms_freq_100Hz/'

# True: anatomical reference slices are matched at ctr_S - k*slthick;
# False: matched an additional half slice thickness further (ctr_S - k*slthick - slthick/2).
assume_ctr_s_is_center=True

# Study numbers (data folders HV-<num> under basepath).
studyfolder_nums=[122,127,128,129,130,131,132,133,134]
# NOTE(review): fit_folder_names, used by the analyze() calls below, is not defined in any cell
# above its first use - TODO define the list of fit-result folders here.


In [None]:
# Single study (HV-130) used by the fit-variant comparison cells below.
# The batch loop further down reuses this name as its loop variable and rebinds it.
studyfolder_num=130

In [None]:
def _find_fit_file(fit_dir, studyfolder_num):
    """Return the name of the fitted-results ``.pkl`` file for one study.

    A file matches when its name contains ``str(studyfolder_num)`` (substring
    match). When several files match, the last one in sorted order is returned.

    Raises
    ------
    FileNotFoundError
        If no ``.pkl`` file in ``fit_dir`` contains the study number.
    """
    study_tag = str(studyfolder_num)
    matches = sorted(name for name in os.listdir(fit_dir)
                     if name.endswith('.pkl') and study_tag in name)
    if not matches:
        raise FileNotFoundError(
            "No fitted .pkl file containing '" + study_tag + "' in " + fit_dir)
    return matches[-1]


def analyze(studyfolder_num,fit_folder_name):
    """Load one study's fit results and derive a masked pyruvate-lactate temperature map.

    Parameters
    ----------
    studyfolder_num : int
        Study number (HV-<num>). The ``.pkl`` file whose name contains it is loaded.
        NOTE(review): this is a substring match, so a number that also appears elsewhere
        in a file name (e.g. in a date) could match the wrong file.
    fit_folder_name : str
        Sub-folder of ``basepath`` that holds the fitted ``.pkl`` files.

    Returns
    -------
    tuple
        ``(input_data, fit_spectrums, fit_amps, fit_t2s, fit_freqs, fit_stds,
        fit_params, temp_map, output_dataframe, mask_pyr_lac)``.
        ``output_dataframe`` has one row per slice with columns
        ``fmean/fmedian/df`` (masked |pyr - lac| difference in ppm),
        ``Tmean/Tmedian/dT`` (masked temperature) and ``pixels`` (number of
        unmasked temperature voxels in that slice).
        ``mask_pyr_lac`` is 1 for kept voxels and NaN for masked ones.

    Raises
    ------
    FileNotFoundError
        If no matching fit file exists.

    Notes
    -----
    Reads the notebook globals ``basepath``, ``concentration`` and ``temp_error_thld``,
    and calls ``load``. NOTE(review): ``load`` is not defined in any visible cell -
    confirm where it comes from.
    """
    fit_path=fit_folder_name
    fit_dir = os.path.join(basepath, fit_path)
    file_to_load = _find_fit_file(fit_dir, studyfolder_num)

    fit_results=utg.load_as_pkl(dir_path=fit_dir, filename=file_to_load, global_vars=globals())

    print('----------------------------------------------------------------------')
    print('Wanted to load file responding to number HV-',studyfolder_num)
    print('Loaded file ',fit_dir,file_to_load)
    print('----------------------------------------------------------------------')
    fit_spectrums = fit_results['fit_spectrums']
    fit_amps = fit_results['fit_amps']
    fit_freqs = fit_results['fit_freqs']
    fit_t2s = fit_results['fit_t2s']
    fit_params = fit_results['fit_params']
    fit_stds = fit_results['fit_stds']

    # Fitted frequencies in ppm; |pyruvate - lactate| per slice.
    # Index layout assumed from usage: [0, slice, x, y, 0, 0, metabolite], metabolite 0/1 = pyr/lac - TODO confirm.
    fit_freqs_ppm = np.array(uts.freq_Hz_to_ppm(freq_Hz=fit_freqs, hz_axis=fit_params["freq_range_Hz"], ppm_axis=fit_params["freq_range_ppm"]))
    pyr_lac_freq_diff = np.array([np.abs(fit_freqs_ppm[0,slice_num,:,:,0,0,0]-fit_freqs_ppm[0,slice_num,:,:,0,0,1])
                                  for slice_num in range(fit_freqs_ppm.shape[1])])

    input_data, freq_range, _ = load(studyfolder_num)

    hz_axis = freq_range
    ppm_ax = fit_params['freq_range_ppm']
    unmasked_temp_map, _ = utf.temperature_from_frequency(pyr_lac_freq_diff,concentration,True)

    # Uncertainty of the lac-pyr frequency difference in Hz (per-metabolite stds added in quadrature;
    # last index 1 of fit_stds is taken as the frequency parameter).
    fit_freqs_diff_stds = np.sqrt(np.abs(fit_stds[0,:,:,:,0,0,0,1]**2) +
                                  np.abs(fit_stds[0,:,:,:,0,0,1,1]**2))
    # ... and the same uncertainty expressed in ppm.
    fit_freqs_diff_stds_ppm = uts.freq_Hz_to_ppm(fit_freqs_diff_stds,
                                                 hz_axis=hz_axis,
                                                 ppm_axis=ppm_ax,
                                                 ppm_centered_at_0=True)

    # Temperature uncertainty = |T(diff) - T(diff + 1 sigma)|. The abs() matters: the sign of the
    # difference follows the calibration slope, and a negative "uncertainty" would always pass
    # the threshold below.
    shifted_temp_map, _ = utf.temperature_from_frequency(
        pyr_lac_freq_diff+fit_freqs_diff_stds_ppm, calibration_type=concentration)
    fit_temps_stds = np.abs(np.squeeze(unmasked_temp_map - np.asarray(shifted_temp_map)))

    # Keep only voxels whose temperature uncertainty is within the threshold (1 = keep, NaN = masked).
    temp_error_mask = np.where(fit_temps_stds <= temp_error_thld, 1, np.nan)
    masked_pyr_lac_freq_diff = np.where(~np.isnan(temp_error_mask), pyr_lac_freq_diff, np.nan)
    mask_pyr_lac = temp_error_mask

    # Final per-slice temperature map (degC) from the masked frequency difference.
    temp_map = np.array([utf.temp_from_ppm_and_concentration(c=concentration,
                                                             ppm=masked_pyr_lac_freq_diff[k],
                                                             return_kelvin=False)
                         for k in range(masked_pyr_lac_freq_diff.shape[0])])

    # Unmasked voxels per slice (previously only the last slice's count was stored for every row).
    pixels = [np.count_nonzero(~np.isnan(temp_slice)) for temp_slice in temp_map]

    output_dataframe = pd.DataFrame({
        'fmean': [np.nanmean(f_slice) for f_slice in masked_pyr_lac_freq_diff],
        'fmedian': [np.nanmedian(f_slice) for f_slice in masked_pyr_lac_freq_diff],
        'df': [np.nanstd(f_slice) for f_slice in masked_pyr_lac_freq_diff],
        'Tmean': [np.nanmean(t_slice) for t_slice in temp_map],
        'Tmedian': [np.nanmedian(t_slice) for t_slice in temp_map],
        'dT': [np.nanstd(t_slice) for t_slice in temp_map],
        'pixels': pixels,
    })

    return input_data,fit_spectrums,fit_amps,fit_t2s,fit_freqs,fit_stds,fit_params, temp_map,output_dataframe, mask_pyr_lac

In [None]:
# Run analyze() on the same study for every fit variant; each output_* is the tuple returned by analyze().
# NOTE(review): fit_folder_names is not defined in any cell above - TODO add it to the config cell.
# NOTE(review): index 3 is bound to output_80Hz_t190 and index 2 to output_100Hz_t190 (out of order) - confirm the mapping.
output_old=analyze(studyfolder_num,fit_folder_name=fit_folder_names[0])
output_40Hz=analyze(studyfolder_num,fit_folder_name=fit_folder_names[1])
output_80Hz_t190=analyze(studyfolder_num,fit_folder_name=fit_folder_names[3])
output_100Hz_t190=analyze(studyfolder_num,fit_folder_name=fit_folder_names[2])
output_100Hz_t250=analyze(studyfolder_num,fit_folder_name=fit_folder_names[4])
output_150Hz_t350=analyze(studyfolder_num,fit_folder_name=fit_folder_names[5])
output_150Hz_t500=analyze(studyfolder_num,fit_folder_name=fit_folder_names[6])
output_150Hz_t750=analyze(studyfolder_num,fit_folder_name=fit_folder_names[7])
output_150Hz_t1000=analyze(studyfolder_num,fit_folder_name=fit_folder_names[8])
output_150Hz_t5000=analyze(studyfolder_num,fit_folder_name=fit_folder_names[9])
output_100Hz_t400=analyze(studyfolder_num,fit_folder_name=fit_folder_names[10])


In [None]:
# Re-runs only four of the variants computed in the previous cell with identical arguments.
# Presumably a quicker path when only these four are needed; redundant after the previous cell has run.
output_old=analyze(studyfolder_num,fit_folder_name=fit_folder_names[0])
output_40Hz=analyze(studyfolder_num,fit_folder_name=fit_folder_names[1])
output_150Hz_t5000=analyze(studyfolder_num,fit_folder_name=fit_folder_names[9])
output_100Hz_t400=analyze(studyfolder_num,fit_folder_name=fit_folder_names[10])


In [None]:
plt.close('all')
def plot_t2(output):
    """Plot masked T2 maps, the temperature map and T2(pyruvate) vs. temperature for one fit variant.

    Parameters
    ----------
    output : tuple
        Return value of ``analyze``. Uses ``output[3]`` (fitted T2s; after
        squeezing, last axis 0 = pyruvate, 1 = lactate; multiplied by 1000 for
        the ms labels), ``output[6]`` (fit parameters; reads ``'max_t2_s'`` and
        ``'range_freqs_Hz'``), ``output[7]`` (temperature map) and ``output[-1]``
        (voxel mask: 1 = keep, NaN = masked).

    Layout: one column per slice plus a narrow colorbar column. Rows are:
    pyruvate T2, lactate T2, temperature, and a scatter of pyruvate T2 against
    temperature. In the scatter, red marks voxels sitting at the T2 fit bound and
    orange squares mark voxels whose temperature value occurs more than once in
    that slice.

    Returns
    -------
    None
    """
    fit_params = output[6]
    t2_max_ms = fit_params['max_t2_s'][0]*1000
    # Just below the bound so voxels stuck at max T2 are drawn with the colormap's 'over' colour (red).
    vmax = (fit_params['max_t2_s'][0]-fit_params['max_t2_s'][0]/1000)*1000
    t2s = output[3].squeeze()
    masked_t2_pyr = t2s[:,:,:,0]*output[-1]*1000
    masked_t2_lac = t2s[:,:,:,1]*output[-1]*1000
    T_map = output[7].squeeze()
    n_slices = t2s.shape[0]
    cbar_col = n_slices  # last column holds the colorbars

    # One column per slice (previously fixed to 5, which collided with the colorbar column for more slices).
    fig, ax = plt.subplots(4, n_slices + 1, tight_layout=True, figsize=(8, 6),
                           width_ratios=(1,) * n_slices + (0.05,))
    cmap = plt.cm.viridis.copy()  # copy so set_over does not modify the shared colormap
    cmap.set_over('red')

    for n in range(n_slices):
        pyrt2img = ax[0,n].imshow(masked_t2_pyr[n],cmap=cmap,vmin=0,vmax=vmax)
        lact2img = ax[1,n].imshow(masked_t2_lac[n],cmap=cmap,vmin=0,vmax=vmax)
        timg = ax[2,n].imshow(T_map[n],cmap='jet',vmin=30,vmax=42)
        ax[2,n].set_title(str(np.nanmean(T_map[n]).round(1))+'±'+str(np.nanstd(T_map[n]).round(1))+'°C',fontsize=11)

        temps = np.ravel(T_map[n])
        t2_pyr = np.ravel(masked_t2_pyr[n])
        ax[3,n].scatter(temps,t2_pyr,s=5,c='k')
        # Voxels at the upper T2 bound; isclose instead of == so round-off does not hide them.
        at_bound = np.isclose(t2_pyr, t2_max_ms)
        ax[3,n].scatter(temps[at_bound],t2_pyr[at_bound],s=8,color='r')
        # Voxels sharing an identical temperature value with another voxel of the slice.
        unique, counts = np.unique(temps, return_counts=True)
        repeated = np.isin(temps, unique[counts > 1])
        ax[3,n].scatter(temps[repeated],t2_pyr[repeated],s=8,color='orange',marker='s')
        ax[3,n].set_xlabel('T [°C]')

        for m in range(3):
            ax[m,n].set_xticks([])
            ax[m,n].set_yticks([])

    ax[0,0].set_ylabel('Pyruvate T2')
    ax[1,0].set_ylabel('Lactate T2')
    ax[2,0].set_ylabel('Temperature')
    ax[3,0].set_ylabel(r'$T_{2,pyr}$[ms]')

    fig.colorbar(timg,cax=ax[2,cbar_col],ticks=[30,35,40],label='T [°C]')
    t2_ticks = [0,int(math.ceil(t2_max_ms/2)),int(t2_max_ms/1.01)]
    fig.colorbar(pyrt2img,cax=ax[0,cbar_col],ticks=t2_ticks,label=r'$T_{2,pyr}$[ms]')
    fig.colorbar(lact2img,cax=ax[1,cbar_col],ticks=t2_ticks,label=r'$T_{2,lac}$[ms]')
    ax[3,cbar_col].axis('off')
    plt.subplots_adjust(wspace=0,hspace=0)
    titlestr=r'$T_{2,max}$='+str(t2_max_ms)+r'ms, $f_{range}$='+str(fit_params['range_freqs_Hz'])+'Hz'+r' $T_{mean}=$'+str(np.nanmean(T_map).round(1))+'°C'
    fig.suptitle(titlestr)

    return None

In [None]:
# Loop (instead of a bare tuple of calls) so the cell does not echo "(None, None, None, None)".
for t2_output in (output_old, output_40Hz, output_100Hz_t400, output_150Hz_t5000):
    plot_t2(t2_output)

In [None]:
# One T2/temperature figure per fit variant; a loop avoids echoing a tuple of None as cell output.
for t2_output in (output_old, output_40Hz, output_80Hz_t190, output_100Hz_t190, output_100Hz_t250,
                  output_150Hz_t350, output_100Hz_t400, output_150Hz_t500, output_150Hz_t750,
                  output_150Hz_t1000, output_150Hz_t5000):
    plot_t2(t2_output)

In [None]:
plt.close('all')

# Compare the fitted spectra of two fit variants with the raw spectrum in one voxel.
fig,ax=plt.subplots(1)
s,x,y=0,1,3  # voxel: slice index, x, y


data1=output_150Hz_t5000
data2=output_old

freq_hz=data1[6]['freq_range_Hz']  # frequency axis in Hz
raw=data1[0].squeeze()
fit=data1[1].squeeze()

# Legend labels are derived from each variant's fit settings instead of being hard-coded.
ax.plot(freq_hz,np.abs(raw[:,s,x,y]),color='k',label='Raw')
ax.plot(freq_hz,np.sum(np.abs(fit[:,s,x,y,:]),axis=1),color='r',
        label=f"t2*max {data1[6]['max_t2_s'][0]*1000:g}ms",linewidth=2)

print('T=',data1[7].squeeze()[s,x,y])
print('fs=',data1[4].squeeze()[s,x,y,0:2])
print('T2s=',data1[3].squeeze()[s,x,y,0:2])
print('amps=',np.abs(data1[2].squeeze()[s,x,y,0:2]))

print('--------')
# Only the fit of the second variant is drawn (the raw spectrum above comes from the first one).
fit=data2[1].squeeze()
ax.plot(data2[6]['freq_range_Hz'],np.sum(np.abs(fit[:,s,x,y,:]),axis=1),color='b',
        label=f"t2*max {data2[6]['max_t2_s'][0]*1000:g}ms",linewidth=2)

print('T=',data2[7].squeeze()[s,x,y])
print('fs=',data2[4].squeeze()[s,x,y,0:2])
print('T2s=',data2[3].squeeze()[s,x,y,0:2])
print('amps=',np.abs(data2[2].squeeze()[s,x,y,0:2]))
ax.legend()
ax.set_xlabel('Hz');

## Looking at an interpolated fit with more points

In [None]:
# Scratch check: sampling interval divided by 5.
# NOTE(review): relies on `data1` from the previous cell (output_150Hz_t5000 under Restart & Run All).
data1[6]['sampling_dt']/5

In [None]:
fig,ax=plt.subplots(1,2,figsize=(9,4.5),tight_layout=True)
s,x,y=0,1,3  # voxel: slice index, x, y
data1=output_100Hz_t400
# Re-synthesise a noise-free FID from the fitted amplitudes, T2s and frequencies using more points
# than were acquired (same sampling interval, i.e. a longer time window).
n_itpl=5000
sampling_dt=data1[6]['sampling_dt']
itp_fit = uts.generate_fid(amplitude=data1[2],
                           T2_s=data1[3],
                           freq0_Hz=data1[4],
                           sampling_dt=sampling_dt,
                           npoints=n_itpl,
                           noise_amplitude=0,
                           sum_fids=True)
# Build the time axis from the sample count: np.arange(0, N*dt, dt) can return N+1 samples through
# floating-point round-off, which would make the plot below fail with a length mismatch.
itpl_timeax=np.arange(itp_fit.shape[0])*sampling_dt
itp_voxel_fid=itp_fit[:,0,s,x,y,0,0]

freq_hz=data1[6]['freq_range_Hz']  # measured frequency axis in Hz
raw_spec=data1[0].squeeze()
raw_fid=np.fft.ifft(np.fft.ifftshift(raw_spec,axes=(0,)),axis=0)
fit_fid=np.sum(np.fft.ifft(np.fft.ifftshift(data1[1].squeeze(),axes=(0,)),axis=0),axis=-1)

# Real parts are plotted explicitly (matplotlib would otherwise drop the imaginary part with a ComplexWarning).
ax[0].plot(itpl_timeax,np.real(itp_voxel_fid),label='Itpl')
ax[0].plot(data1[6]['time_axis'],np.real(raw_fid[:,s,x,y]),color='k',label='Raw')
ax[0].plot(data1[6]['time_axis'],np.real(fit_fid[:,s,x,y]),color='r',label='Fit')

ax[0].legend()
# NOTE(review): confirm that time_axis / sampling_dt are in ms, as this label states.
ax[0].set_xlabel('t [ms]')
ax[1].set_xlabel('f [Hz]')
ax[1].plot(freq_hz,np.abs(raw_spec[:,s,x,y]),color='k',label='Raw')
# Frequency axis for the longer synthetic spectrum, spanning the same range as the measured axis.
itpl_freq_hz=np.linspace(np.min(freq_hz),np.max(freq_hz),itp_voxel_fid.shape[0])
ax[1].plot(itpl_freq_hz,np.abs(np.fft.fftshift(np.fft.fft(itp_voxel_fid,axis=0),axes=(0,))),color='r',label='Fit itpl')
ax[1].plot(freq_hz,np.sum(np.abs(data1[1].squeeze()[:,s,x,y,:]),axis=1),color='b',label='Fit')

ax[1].legend()
print('T=',data1[7].squeeze()[s,x,y])
print('fs=',data1[4].squeeze()[s,x,y,0:2])
print('T2s=',data1[3].squeeze()[s,x,y,0:2])
print('amps=',np.abs(data1[2].squeeze()[s,x,y,0:2]))
ax[1].set_xlim([-200,200])


In [None]:
fig,ax=plt.subplots(1,figsize=(6,3),tight_layout=True)
s,x,y=0,1,3  # voxel: slice index, x, y
data1=output_old
# Noise-free FID re-synthesised from the fit parameters over a longer window than the acquisition.
n_itpl=1000
sampling_dt=data1[6]['sampling_dt']
itp_fit = uts.generate_fid(amplitude=data1[2],
                           T2_s=data1[3],
                           freq0_Hz=data1[4],
                           sampling_dt=sampling_dt,
                           npoints=n_itpl,
                           noise_amplitude=0,
                           sum_fids=True)
# Time axis from the sample count (np.arange with a float step can yield one extra sample).
itpl_timeax=np.arange(itp_fit.shape[0])*sampling_dt

raw_fid=np.fft.ifft(np.fft.ifftshift(data1[0].squeeze(),axes=(0,)),axis=0)

ax.plot(itpl_timeax,np.real(itp_fit[:,0,s,x,y,0,0]),label='Fit longer range',color='r')
ax.plot(data1[6]['time_axis'],np.real(raw_fid[:,s,x,y]),color='k',label='Raw')

ax.legend()
# NOTE(review): confirm that time_axis / sampling_dt are in ms, as this label states.
ax.set_xlabel('t [ms]')

print('T=',data1[7].squeeze()[s,x,y])
print('fs=',data1[4].squeeze()[s,x,y,0:2])
print('T2s=',data1[3].squeeze()[s,x,y,0:2])
print('amps=',np.abs(data1[2].squeeze()[s,x,y,0:2]))



In [None]:
plt.close('all')

# Time-domain comparison: magnitude of the raw FID vs. the FIDs of two fit variants in one voxel.
fig,ax=plt.subplots(1)
s,x,y=0,1,3  # voxel: slice index, x, y


data1=output_150Hz_t5000
data2=output_old

# Spectrum -> FID exactly as for the fits: ifftshift along the spectral axis, then ifft.
# (Previously the raw data used ifftshift AFTER the ifft, which rolled the raw FID by half its length.)
raw=np.fft.ifft(np.fft.ifftshift(data1[0].squeeze(),axes=(0,)),axis=0)
fit=np.fft.ifft(np.fft.ifftshift(data1[1].squeeze(),axes=(0,)),axis=0)

ax.plot(np.abs(raw[:,s,x,y]),color='k',label='Raw')
# Magnitude of the summed component FIDs is what corresponds to |raw FID|: the components
# interfere in the time domain, so a sum of magnitudes would miss the beating.
ax.plot(np.abs(np.sum(fit[:,s,x,y,:],axis=1)),color='r',
        label=f"t2*max {data1[6]['max_t2_s'][0]*1000:g}ms",linewidth=2)

print('T=',data1[7].squeeze()[s,x,y])
print('fs=',data1[4].squeeze()[s,x,y,0:2])
print('T2s=',data1[3].squeeze()[s,x,y,0:2])
print('amps=',np.abs(data1[2].squeeze()[s,x,y,0:2]))

print('--------')
fit=np.fft.ifft(np.fft.ifftshift(data2[1].squeeze(),axes=(0,)),axis=0)


ax.plot(np.abs(np.sum(fit[:,s,x,y,:],axis=1)),color='b',
        label=f"t2*max {data2[6]['max_t2_s'][0]*1000:g}ms",linewidth=2)

print('T=',data2[7].squeeze()[s,x,y])
print('fs=',data2[4].squeeze()[s,x,y,0:2])
print('T2s=',data2[3].squeeze()[s,x,y,0:2])
print('amps=',np.abs(data2[2].squeeze()[s,x,y,0:2]))
ax.legend()
# The x axis is the FID sample index (time domain), not a frequency.
ax.set_xlabel('FID sample index');

In [None]:
# NOTE(review): plot_spectra is neither defined nor imported in any cell above, so this raises
# NameError on Restart & Run All - TODO import or define it.
plot_spectra(output_old)

In [None]:
# NOTE(review): plot_csi_fit_and_raw is neither defined nor imported in any cell above - TODO import or define it.
# Arguments: raw input data, fitted spectra and the frequency axis (Hz) of the 150 Hz / 5000 ms variant.
plot_csi_fit_and_raw(output_150Hz_t5000[0],output_150Hz_t5000[1],output_150Hz_t5000[6]['freq_range_Hz'])

# 1. Analyze all datasets

In [None]:
# Per-study accumulators for the batch loop below.
# NOTE(review): none of these lists is appended to in the first part of the loop - confirm they are filled further down.
all_temp_maps = []
all_anats = []
all_fit_amps=[]
all_c13exts=[]
all_anat_exts=[]


for index,studyfolder_num in enumerate([122,127,128,129,130,131,132,133,134]):
    
    studyfolder ='HV-'+str(studyfolder_num)
    savepath=os.path.join(basepath,studyfolder)
    animal_id  = studyfolder
    file_name_prefix = 'hv_'+str(studyfolder_num)
    dataset = sio.loadmat(os.path.join(basepath, studyfolder, file_name_prefix+'.mat'))['spec']
    dataset_header_struct_arr = sio.loadmat(os.path.join(basepath, studyfolder, file_name_prefix+"_header.mat"))['header']
    
    anat_folder = os.path.join(basepath,studyfolder,'Anatomical/')
    all_files = os.listdir(anat_folder)
    image_data=[]
    anat_img_pos=[]
    anat_slice_pos=[]
    # Anatomical DICOMs are opened as 0001.dcm, 0002.dcm, ...; count only .dcm files so a stray
    # non-DICOM file in the folder does not produce a non-existent file name.
    n_dicom_files = sum(1 for f in all_files if f.endswith('.dcm'))
    for file_num in range(1,n_dicom_files+1):
        # pydicom.dcmread: read_file was deprecated and removed in pydicom 3.0.
        file=pydicom.dcmread(anat_folder+'%04i.dcm'%file_num)
        image_data.append(file.pixel_array)
        anat_slice_pos.append(file.SliceLocation)
        anat_img_pos.append(file.ImagePositionPatient)
    anatomical_images=np.array(image_data)
    anat_slice_pos=np.array(anat_slice_pos)
    anat_img_pos=np.array(anat_img_pos)

    # Get field names from the structured array
    field_names = dataset_header_struct_arr.dtype.names

    dataset_header = {}

    # Iterate over the field names and process the data
    for field in field_names:
        # Access the data for each field
        dataset_header[field] = np.squeeze(dataset_header_struct_arr[field][0][0][0])
    # Unwrap the array-wrapped values of the 'image' header struct (loaded via sio.loadmat) into scalars.
    for key in dataset_header['image'].dtype.names:
        temp = np.array(dataset_header['image'][key]).item()

        # Only NumPy arrays need unwrapping; other values are left untouched.
        if not isinstance(temp, np.ndarray):
            continue
        if temp.ndim == 0:
            value = temp.item()  # 0-dimensional array -> scalar
        else:
            try:
                value = temp[0][0]  # first element of a nested array
            except (IndexError, TypeError):
                # Cannot be unwrapped this way (e.g. 1-D or empty array): leave the field unchanged.
                # The former bare `except: pass` wrote back a stale `value` from an earlier key or study.
                continue
        dataset_header['image'][key] = value
    patient_info = {}
    patient_info['ID'] = animal_id
    if dataset_header['exam']['patsex'] == 0:
        patient_info['sex'] = 'male'
    else:
        patient_info['sex'] = 'female'
    patient_info['weight'] = float(dataset_header['exam']['patweight'] / 1e3)
    patient_info['pyr_vol'] = patient_info['weight'] * 0.4
    patient_info['scan_date'] = str(dataset_header['rdb_hdr']['scan_date'])
    patient_info['scan_time'] = str(dataset_header['rdb_hdr']['scan_time'])
    slice_offsets_c13 = dataset_header['image']['ctr_A']
    slice_thickness_c13 = dataset_header['image']['slthick']
    
    
    rdb_hdr = dataset_header['rdb_hdr']
    fields =  rdb_hdr.dtype.names
    tr = dataset_header['image']['tr'] / 1e6
    tr = 4e3
    bw = dataset_header['rdb_hdr']['spectral_width']
    freq_cent_hz = dataset_header['rdb_hdr']['ps_mps_freq'] / 10.0 
    gmr_mhz_t = get_gmr(nucleus="13c")
    b0_off_t = freq_cent_hz / (gmr_mhz_t * 1e6)
    freq0 = 3 *  (gmr_mhz_t)
    freq_off_hz = freq_cent_hz - freq0*1e6
    freq_off_ppm = (freq_cent_hz - freq0*1e6) / (freq0*1e6) * 1e6
    dt = 1./bw
    fa = dataset_header['image']['mr_flip']
    dyn_fid = dataset
    # spectrum has to be flipped (and complex conjugated to have proper FID) (see flip_spec_complex_
    dyn_spec = dyn_fid

    freq_range = np.squeeze(get_freq_axis(npoints=dyn_spec.shape[0], sampling_dt=dt, unit='Hz'))
    time_axis = calc_sampling_time_axis(npoints=dyn_spec.shape[0], sampling_dt=dt)
    input_data= make_NDspec_6Dspec(input_data=dyn_spec, provided_dims=["spec", "x", "y","z"])

    input_data_raw = np.conj(np.flip(input_data, axis=0))
    
    ## Interpolate CSI data
    input_data = utg.interpolate_dataset(input_data=input_data_raw,
                                               interp_size=(input_data_raw.shape[0],
                                                            input_data_raw.shape[1],
                                                            interpolation_factor*input_data_raw.shape[2],
                                                            interpolation_factor*input_data_raw.shape[3],
                                                            input_data_raw.shape[4],
                                                            input_data_raw.shape[5]),
                                               interp_method="cubic")
    
    # Locate this study's fitted-results pickle: the last .pkl (sorted) whose name contains the study number.
    all_files_in_dir = os.listdir(os.path.join(basepath,fit_path))
    fitted_files = sorted(f for f in all_files_in_dir if f.endswith('.pkl'))
    # Search all pickles. The former `range(9)` raised IndexError with fewer than 9 files, ignored files
    # beyond the 9th, and silently kept the previous study's `file_to_load` when nothing matched.
    matching_fit_files = [f for f in fitted_files if str(studyfolder_num) in f]
    if not matching_fit_files:
        raise FileNotFoundError('No fitted .pkl file for HV-'+str(studyfolder_num)+' in '+os.path.join(basepath,fit_path))
    file_to_load = matching_fit_files[-1]


    fit_results=utg.load_as_pkl(dir_path=os.path.join(basepath, fit_path), filename=file_to_load, global_vars=globals())


    print('----------------------------------------------------------------------')
    print('Wanted to load file responding to number HV-',studyfolder_num)
    print('Loaded file ',file_to_load)
    print('----------------------------------------------------------------------')
    fit_spectrums = fit_results['fit_spectrums']
    fit_amps = fit_results['fit_amps']
    fit_freqs = fit_results['fit_freqs']
    fit_t2s = fit_results['fit_t2s']
    fit_params = fit_results['fit_params']
    fit_stds=fit_results['fit_stds']
    fit_freqs_ppm  = np.array(uts.freq_Hz_to_ppm(freq_Hz=fit_freqs, hz_axis=fit_params["freq_range_Hz"], ppm_axis=fit_params["freq_range_ppm"]))
    pyr_lac_freq_diff = np.abs(fit_freqs_ppm[...,0]-fit_freqs_ppm[...,1])
    spectra_for_norming = np.squeeze(input_data)
    # shape: 256,5,10,10
    ppm_ax = fit_params['freq_range_ppm']
    pyr_snr_map=np.zeros_like(spectra_for_norming[0,:,:,:],dtype=float)
    lac_snr_map=np.zeros_like(spectra_for_norming[0,:,:,:],dtype=float)
    for sl in range(spectra_for_norming.shape[1]):
        for n in range(spectra_for_norming.shape[2]):
            for m in range(spectra_for_norming.shape[3]):
                test_spec= np.real(spectra_for_norming[:,sl,n,m])
                #fig,ax=plt.subplots(1)
                normed_test_spec = np.abs((test_spec-np.mean(test_spec[15:30]))/np.std(test_spec[15:30]))

                pyr_peak_roi=[np.argmin(np.abs(ppm_ax-170)),np.argmin(np.abs(ppm_ax-180))]
                lac_peak_roi=[np.argmin(np.abs(ppm_ax-180)),np.argmin(np.abs(ppm_ax-190))]

                pyr_max_peak_val = np.max(normed_test_spec[pyr_peak_roi[0]:pyr_peak_roi[1]])
                lac_max_peak_val = np.max(normed_test_spec[lac_peak_roi[0]:lac_peak_roi[1]])

                pyr_snr_map[sl,n,m]=pyr_max_peak_val
                lac_snr_map[sl,n,m]=lac_max_peak_val
    

    hz_axis = freq_range
    ppm_ax = fit_params['freq_range_ppm']
    temp_map,_ =  utf.temperature_from_frequency(pyr_lac_freq_diff,concentration,True)
    # Error of difference between lac and pyruvate in Hz
    fit_freqs_diff_stds = np.sqrt(np.abs(fit_stds[...,0,1]**2) +
                                            np.abs(fit_stds[...,1,1]**2))
    # get the uncertainty of the frequency difference in ppm:
    fit_freqs_diff_stds_ppm = uts.freq_Hz_to_ppm(fit_freqs_diff_stds,
                                                 hz_axis=hz_axis,
                                                 ppm_axis=ppm_ax,
                                                 ppm_centered_at_0=True)


    # calculate the temperature and frequency errors:
    # Temperature uncertainty = |T(diff) - T(diff + 1 sigma)|. abs() because the sign of the difference
    # follows the calibration slope; a negative value would always pass the fit-error threshold below.
    shifted_temp_map, _ = utf.temperature_from_frequency(
        pyr_lac_freq_diff+fit_freqs_diff_stds_ppm, calibration_type=concentration)
    fit_temps_stds = np.abs(temp_map - np.asarray(shifted_temp_map))
        
    ## SNR masking
    snr,noise=uts.compute_snr_from_fit(input_data,fit_spectrums)

    snr_mask_cond = (snr[..., 0] < snr_threshold) | (snr[..., 1] < snr_threshold)
    snr_mask = np.where(snr_mask_cond, np.nan, 1)
    snr_masked_temp=snr_mask*temp_map


    ## Fit error masking
    fit_error_mask=np.where(fit_temps_stds <= fit_error_threshold, 1, np.nan)
    fit_error_masked_temp=np.where(~np.isnan(fit_error_mask), temp_map, np.nan)


    # combine and report all three variants: 1. SNR masked, 2. Fit error masked and 3. combined
    snr_and_fit_error_masked_temp=snr_mask*fit_error_masked_temp

    if error_thld=='both':
        temp_map=snr_and_fit_error_masked_temp
        mask_pyr_lac=snr_mask*fit_error_mask
    elif error_thld=='SNR':
        temp_map=snr_masked_temp
        mask_pyr_lac=snr_mask
    elif error_thld=='fiterror':
        temp_map=fit_error_masked_temp
        mask_pyr_lac=fit_error_mask
    else:
        raise KeyError("Error thresholding must be selected")
    
    

    fig,ax=plt.subplots(2,5,tight_layout=True,figsize=(12,4))

    for slice_num in range(5):
        ax[0,slice_num].cla()
        ax[1,slice_num].cla()
        ax[0,slice_num].set_xticks([])
        ax[0,slice_num].set_yticks([])

        im=ax[0,slice_num].imshow(temp_map.squeeze()[slice_num],vmin=20,vmax=45,cmap='plasma')
        fig.colorbar(im,ax=ax[0,slice_num],label='°C')
        hist_x,hist_y,hist_bins = utg.Get_Hist(temp_map.squeeze()[slice_num],15)

        ax[1,slice_num].bar(hist_x,hist_y,hist_bins)
        ax[0,slice_num].set_title(r'$T_{mean}=$'+str(np.round(np.nanmean(temp_map.squeeze()[slice_num]),2))+' °C')
        ax[1,slice_num].set_title(r'$T_{std}=$'+str(np.round(np.nanstd(temp_map.squeeze()[slice_num]),2))+' °C')

        ax[1,slice_num].set_xlabel('T [°C]')
        ax[1,slice_num].set_ylabel('Pixel count')

    
    plt.close('all')
    from scipy.stats import norm
    Tns=[]
    fig, ax = plt.subplots(2,fit_freqs.shape[1], figsize=(12,8),tight_layout=True)
    for s in range(fit_freqs.shape[1]):
        fdata = (fit_freqs[0,s,:,:,0,0,0] - fit_freqs[0,s,:,:,0,0,1])*mask_pyr_lac.squeeze()[s]
        fmean, fstd = np.nanmean(fdata), np.nanstd(fdata)
        fn=np.count_nonzero(~np.isnan(fdata))
        temp_data = temp_map.squeeze()[s]
        Tmean, Tstd = np.nanmean(temp_data), np.nanstd(temp_data)
        Tn=np.count_nonzero(~np.isnan(temp_data))
        # Plot Gaussian
        ax[0,s].errorbar(1, fmean, yerr=fstd, fmt='o', color='red', ecolor='lightgray', elinewidth=3, capsize=7,alpha=0.7, label='Average')
        ax[0,s].scatter(np.ones_like(fdata), fdata, label='Data', color='black', marker='.',alpha=1)

        title = f"{fmean:.1f} +/- {fstd:.1f} Hz"

        ax[0,s].set_title(title)
        ax[0,s].set_xticklabels([])
        ax[0,s].set_ylabel('Freq. diff. Pyr - Lac [Hz]')
        ax[0,s].legend()

        ax[1,s].errorbar(1, Tmean, yerr=Tstd, fmt='o', color='red', ecolor='lightgray', elinewidth=3, capsize=7,alpha=0.7, label='Average')
        ax[1,s].scatter(np.ones_like(temp_data), temp_data, label='Data', color='black', marker='.',alpha=1)

        title = f"{Tmean:.1f} +/- {Tstd:.1f} °C"

        ax[1,s].set_title(title)
        ax[1,s].set_xticklabels([])
        ax[1,s].set_ylabel('T[°C]')
        ax[1,s].legend()


        Tns.append(Tn)
    print(Tns)

    fig.suptitle('Nr. Pixels: '+str(Tns))

   
    plt.close('all')

    output_dataframe = pd.DataFrame()
    mean_values = np.nanmean(temp_map, axis=(0, 2, 3, 4, 5))  # Using nanmean to ignore NaNs
    median_values = np.nanmedian(temp_map, axis=(0, 2, 3, 4, 5))  # Using nanmean to ignore NaNs
    
    # Compute standard deviation similarly
    std_values = np.nanstd(temp_map, axis=(0, 2, 3, 4, 5))
    # Compute count of non-NaN values
    count_values = np.sum(~np.isnan(temp_map), axis=(0, 2, 3, 4, 5))

    output_dataframe['Tmean']=mean_values
    output_dataframe['Tmedian']=median_values
    output_dataframe['dT']=std_values
    output_dataframe['pixels']=count_values


    output_dataframe.to_excel(revision_path+str(studyfolder)+'_temp_freq_results_SNR_thresholding_'+str(concentration)+'.xlsx')

    output_dict={}
    output_dict.update({'patient_info':patient_info,'temperature_map':temp_map,'frequency_map':pyr_lac_freq_diff,'pyruvate_snr_map':pyr_snr_map,
                        'lactate_snr_map':lac_snr_map,'calibrationfunction':concentration,'thresholding':error_thld})

    utg.save_as_pkl(revision_path,str(studyfolder)+'_results_dictionary_'+str(concentration)+'_SNR_thresholding',output_dict,use_timestamp=False)
    
    

    # reading slice position data
    
    
    ## getting rid of temp_map dimensions
    temp_map=temp_map.squeeze()
    
    c13_slice_fov = float(dataset_header['rdb_hdr']['fov'])
    c13_slice_centers = (float(dataset_header['image']['ctr_R']),float(dataset_header['image']['ctr_A']),float(dataset_header['image']['ctr_S']))
    c13_slice_upper_right = float(dataset_header['image']['tlhc_R']),float(dataset_header['image']['tlhc_A']),float(dataset_header['image']['tlhc_S'])
    # so R is x in axial, A is y in axial and S is slices
    c13_slice_upper_left = (c13_slice_upper_right[0]*-1,c13_slice_upper_right[1])
    c13_slice_pos=c13_slice_upper_right[2]
    c13_slice_thickness = float(dataset_header['image']['slthick'])
    # pydicom.dcmread: read_file was deprecated and removed in pydicom 3.0.
    meta_data = pydicom.dcmread(anat_folder+'0001.dcm')
    anat_slice_thick = float(meta_data.SliceThickness)
    anat_fov = float(meta_data.ReconstructionDiameter)
    anat_upper_left = np.array(anat_img_pos[0][0:2])
    # computing shift of anatomicals to matxh csi data
    from hypermri.utils.utils_general import calc_mat_origin_diff
    from scipy.ndimage import shift
    fov_csi=[c13_slice_fov,c13_slice_fov]
    fov_anat=[anat_fov,anat_fov]
    mat_csi=[input_data_raw.shape[2],input_data_raw.shape[3]]
    mat_anat = [anatomical_images[0].shape[0],anatomical_images[0].shape[1]]

    res_csi = [a / b for a, b in zip(fov_csi, mat_csi)]
    res_anat = [a / b for a, b in zip(fov_anat, mat_anat)]

    # calc necessary shift in all directions:
    shift_vox_list = calc_mat_origin_diff(
        res_metab=res_csi, res_anat=res_anat, fov_anat=fov_anat, mat_anat=mat_anat
    )

    shift_vox_list=[shift_vox_list[0]*-1,shift_vox_list[1]]
    
    #plotting
    nrows, ncols = 7, 6
    width_ratios = [1, 1, 1, 1, 1, 0.05]
    total_width = 6.9
    subplot_width = total_width / (sum(width_ratios) - 0.1)
    subplot_height = subplot_width * 3 / 4
    total_height = subplot_height * nrows

    fig, ax = plt.subplots(nrows, ncols, figsize=(total_width, total_height*1.35), width_ratios=width_ratios)

    # Calculate the total height
    total_height = subplot_height * nrows
    image_width_mm = anat_fov
    image_height_mm = anat_fov
    # Top-left-hand corner coordinates
    tlhc_R = anat_img_pos[0][0]
    tlhc_A = -anat_img_pos[0][1]
    # Calculate extent to center the image
    anat_extent = [tlhc_R, tlhc_R + image_width_mm, tlhc_A - image_height_mm, tlhc_A]
    
    if assume_ctr_s_is_center == True:
        ref_slices=np.array([np.argmin(np.abs(anat_slice_pos-((float(c13_slice_centers[2]))))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-c13_slice_thickness))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-2*c13_slice_thickness))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-3*c13_slice_thickness))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-4*c13_slice_thickness)))])
    else:

        ref_slices=np.array([np.argmin(np.abs(anat_slice_pos-((float(c13_slice_centers[2]))-c13_slice_thickness/2))),
                              np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-c13_slice_thickness-c13_slice_thickness/2))),
                              np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-2*c13_slice_thickness-c13_slice_thickness/2))),
                              np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-3*c13_slice_thickness-c13_slice_thickness/2))),
                              np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])-4*c13_slice_thickness-c13_slice_thickness/2)))])

    if studyfolder_num==129:
        if assume_ctr_s_is_center == True:
            ref_slices=np.array([np.argmin(np.abs(anat_slice_pos-((float(c13_slice_centers[2]))))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+c13_slice_thickness))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+2*c13_slice_thickness))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+3*c13_slice_thickness))),
                      np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+4*c13_slice_thickness)))])
        else:

            ref_slices=np.array([np.argmin(np.abs(anat_slice_pos-((float(c13_slice_centers[2]))+c13_slice_thickness/2))),
                                  np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+c13_slice_thickness+c13_slice_thickness/2))),
                                  np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+2*c13_slice_thickness+c13_slice_thickness/2))),
                                  np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+3*c13_slice_thickness+c13_slice_thickness/2))),
                                  np.argmin(np.abs(anat_slice_pos-(float(c13_slice_centers[2])+4*c13_slice_thickness+c13_slice_thickness/2)))])

    else:
        pass
    [[ax[n,idx].imshow(shift(anatomical_images[ref_slices[idx]], shift=shift_vox_list, mode="wrap"),vmin=0,vmax=1400,cmap='bone',extent=anat_extent) for idx,ref_slice in enumerate(ref_slices)] for n in range(7)]

    img1=ax[0,0].imshow(shift(anatomical_images[ref_slices[0]], shift=shift_vox_list, mode="wrap"),vmin=0,vmax=1400,cmap='bone',extent=anat_extent)

    [ax[0,idx].set_title(str(np.round(anat_slice_pos[ref_slices[idx]]))+' mm') for idx in range(5)]


    image_width_mm = c13_slice_fov
    image_height_mm = c13_slice_fov
    # Top-left-hand corner coordinates
    tlhc_R = c13_slice_upper_left[0]
    tlhc_A = c13_slice_upper_left[1]
    # Calculate extent to center the image
    c13_extent = [tlhc_R, tlhc_R + image_width_mm, tlhc_A - image_height_mm, tlhc_A]

    [ax[1,csi_slice].imshow(np.abs(fit_amps[0,csi_slice,:,:,0,0,1]),cmap='magma',
                                    extent=c13_extent) for csi_slice in range(5)]
    img2=ax[1,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,1]),cmap='magma',
                                    extent=c13_extent)


    [ax[2,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,1]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
    img3=ax[2,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,1]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent)

    [ax[3,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,0]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
    img4=ax[3,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,0]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent)


    [ax[4,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,2]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
    img5=ax[4,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,2]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent)


    [ax[5,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,3]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
    img6=ax[5,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,3]),cmap='magma',alpha=0.75,interpolation='spline36',
                                    extent=c13_extent)


    [ax[6,idx].imshow(temp_map[idx],cmap='jet',vmin=25,vmax=42,extent=c13_extent,alpha=0.7) for idx in range(0,5)]
    temp_img = ax[6,4].imshow(temp_map[4],cmap='jet',vmin=25,vmax=42,extent=c13_extent,alpha=0.7)


    [[ax[n,idx].set_xlim([c13_extent[0],c13_extent[1]]) for idx in range(5)] for n in range(7)]
    [[ax[n,idx].set_ylim([c13_extent[2],c13_extent[3]]) for idx in range(5)] for n in range(7)]
    [[ax[n,idx].set_xticks([]) for idx in range(5)] for n in range(7)]
    [[ax[n,idx].set_yticks([]) for idx in range(5)] for n in range(7)]




    ax[0,0].set_ylabel('T1W',rotation=90,labelpad=10)
    ax[1,0].set_ylabel('Lactate 2x',rotation=90,labelpad=10)
    ax[2,0].set_ylabel('Lactate',rotation=90,labelpad=10)
    ax[3,0].set_ylabel('Pyruvate',rotation=90,labelpad=10)
    ax[4,0].set_ylabel('Bicarbonate',rotation=90,labelpad=10)
    ax[5,0].set_ylabel('Hydrate',rotation=90,labelpad=10)
    ax[6,0].set_ylabel('Temperature',rotation=90,labelpad=10)


    fig.colorbar(img1,cax=ax[0,5],label='I [a.u.]')
    fig.colorbar(img2,cax=ax[1,5],label='I [a.u.]')
    fig.colorbar(img3,cax=ax[2,5],label='I [a.u.]')
    fig.colorbar(img4,cax=ax[3,5],label='I [a.u.]')
    fig.colorbar(img5,cax=ax[4,5],label='I [a.u.]')
    fig.colorbar(img6,cax=ax[5,5],label='I [a.u.]')
    fig.colorbar(temp_img,cax=ax[6,5],label='T [$^\circ$C]')


    [[ax[n,m].set_xticks([])for n in range(7)] for m in range(5)]
    [[ax[n,m].set_yticks([])for n in range(7)] for m in range(5)]

    plt.subplots_adjust(wspace=0.05, hspace=0.05)


    if studyfolder_num==131:
        nrows, ncols = 6, 6
        overlay_alpha=1
        width_ratios = [1, 1, 1, 1, 1, 0.05]
        total_width = 6.9
        subplot_width = total_width / (sum(width_ratios) - 0.1)
        subplot_height = subplot_width * 3 / 4
        total_height = subplot_height * nrows
        fig, ax = plt.subplots(nrows, ncols, figsize=(total_width, total_height*1.35), width_ratios=width_ratios)

        [[ax[n,idx].imshow(shift(anatomical_images[ref_slices[idx]], shift=shift_vox_list, mode="wrap")/np.max(shift(anatomical_images[ref_slices[idx]], shift=shift_vox_list, mode="wrap")),
                           vmin=0,vmax=0.7,cmap='bone',extent=anat_extent) for idx,ref_slice in enumerate(ref_slices)] for n in range(1)]
        img1=ax[0,0].imshow(shift(anatomical_images[ref_slices[0]], shift=shift_vox_list, mode="wrap")/np.max(shift(anatomical_images[ref_slices[0]], 
                                                                                    shift=shift_vox_list, mode="wrap")),vmin=0,vmax=0.7,cmap='bone',extent=anat_extent)

        [ax[1,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,1])/np.max(np.abs(fit_amps[0,idx,:,:,0,0,1])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
        img3=ax[1,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,1])/np.max(np.abs(fit_amps[0,0,:,:,0,0,1])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent,vmin=0)

        [ax[2,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,0])/np.max(np.abs(fit_amps[0,idx,:,:,0,0,0])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
        img4=ax[2,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,0])/np.max(np.abs(fit_amps[0,0,:,:,0,0,0])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent,vmin=0)


        [ax[3,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,2])/np.max(np.abs(fit_amps[0,idx,:,:,0,0,2])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
        img5=ax[3,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,2])/np.max(np.abs(fit_amps[0,0,:,:,0,0,2])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent,vmin=0)


        [ax[4,idx].imshow(np.abs(fit_amps[0,idx,:,:,0,0,3])/np.max(np.abs(fit_amps[0,idx,:,:,0,0,3])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent) for idx,ref_slice in enumerate(ref_slices)]
        img6=ax[4,0].imshow(np.abs(fit_amps[0,0,:,:,0,0,3])/np.max(np.abs(fit_amps[0,0,:,:,0,0,3])),cmap='magma',alpha=overlay_alpha,
                                        extent=c13_extent,vmin=0)

        [ax[5,idx].imshow(shift(anatomical_images[ref_slices[idx]], shift=shift_vox_list, mode="wrap")/np.max(shift(anatomical_images[ref_slices[idx]], shift=shift_vox_list, mode="wrap")),
                           vmin=0,vmax=0.7,cmap='bone',extent=anat_extent) for idx,ref_slice in enumerate(ref_slices)]
        [ax[5,idx].imshow(temp_map[idx],cmap='jet',vmin=28,vmax=42,extent=c13_extent,alpha=overlay_alpha) for idx in range(0,5)]
        temp_img = ax[5,4].imshow(temp_map[4],cmap='jet',vmin=28,vmax=42,extent=c13_extent,alpha=overlay_alpha)


        [[ax[n,idx].set_xlim([c13_extent[0],c13_extent[1]]) for idx in range(5)] for n in range(6)]
        [[ax[n,idx].set_ylim([c13_extent[2],c13_extent[3]]) for idx in range(5)] for n in range(6)]
        [[ax[n,idx].set_xticks([]) for idx in range(5)] for n in range(6)]
        [[ax[n,idx].set_yticks([]) for idx in range(5)] for n in range(6)]




        ax[0,0].set_ylabel('A',rotation=0,labelpad=10)
        ax[1,0].set_ylabel('B',rotation=0,labelpad=10)
        ax[2,0].set_ylabel('C',rotation=0,labelpad=10)
        ax[3,0].set_ylabel('D',rotation=0,labelpad=10)
        ax[4,0].set_ylabel('E',rotation=0,labelpad=10)
        ax[5,0].set_ylabel('F',rotation=0,labelpad=10)


        fig.colorbar(img1,cax=ax[0,5],label='I [a.u.]',ticks=([0,1]))
        fig.colorbar(img3,cax=ax[1,5],label='I [a.u.]',ticks=([0,1]))
        fig.colorbar(img4,cax=ax[2,5],label='I [a.u.]',ticks=([0,1]))
        fig.colorbar(img5,cax=ax[3,5],label='I [a.u.]',ticks=([0,1]))
        fig.colorbar(img6,cax=ax[4,5],label='I [a.u.]',ticks=([0,1]))
        fig.colorbar(temp_img,cax=ax[5,5],label='T [$^\circ$C]',ticks=([30,35,40]))


        [[ax[n,m].set_xticks([])for n in range(6)] for m in range(5)]
        [[ax[n,m].set_yticks([])for n in range(6)] for m in range(5)]

        plt.subplots_adjust(wspace=0.05, hspace=0.05)






    ##append temp map and anatomicals to plot all in one setting
    all_temp_maps.append(temp_map)
    all_anats_per_pat=[shift(anatomical_images[ref_slices[idx]], shift=shift_vox_list, mode="wrap") for idx,ref_slice in enumerate(ref_slices)]
    all_anats.append(all_anats_per_pat)
    all_anat_exts.append(anat_extent)
    all_c13exts.append(c13_extent)
    all_fit_amps.append(fit_amps)

## Plot Figure S4: temperature maps overlaid on anatomy (all volunteers, 5 slices)

In [None]:
plt.close('all')
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Figure S4: temperature maps (fixed 28-42 degC scale) overlaid on the shifted anatomical
# slices, one row per volunteer and one column per CSI slice, plus a colorbar column.
n_subjects = len(studyfolder_nums)
n_slices = 5
# all_temp_maps / all_anats are appended by analyze(); re-running that loop without
# resetting the lists would otherwise silently plot stale entries.
assert len(all_temp_maps) == len(all_anats) == len(all_anat_exts) == len(all_c13exts) == n_subjects, (
    'per-subject result lists do not match studyfolder_nums - re-run the analyze() loop from a clean state')

fig, ax = plt.subplots(n_subjects, n_slices + 1, figsize=(6.9, 12 * n_subjects / 9),
                       width_ratios=(1,) * n_slices + (0.05,), squeeze=False)
for pat_index in range(n_subjects):
    for slice_num in range(n_slices):
        panel = ax[pat_index, slice_num]
        panel.imshow(all_anats[pat_index][slice_num], cmap='bone', extent=all_anat_exts[pat_index], vmax=1200)
        img = panel.imshow(all_temp_maps[pat_index][slice_num], vmin=28, vmax=42, cmap='jet',
                           extent=all_c13exts[pat_index], alpha=0.9)
        panel.set_xticks([])
        panel.set_yticks([])

    # Per-row decorations are drawn once per row, not once per slice.
    ax[pat_index, 0].set_ylabel('HV-' + str(studyfolder_nums[pat_index]), rotation=90)
    # Every panel uses the same fixed vmin/vmax, so the last image describes the whole row.
    fig.colorbar(img, cax=ax[pat_index, n_slices], label=r'T[$^\circ$C]', ticks=[30, 35, 40])

plt.subplots_adjust(wspace=0, hspace=0.1)


## Plot AUCR (lactate/pyruvate amplitude ratio), lactate, and pyruvate maps

In [None]:
# AUCR (lactate/pyruvate amplitude ratio), lactate and pyruvate maps for every volunteer and slice.
# fit_amps[..., k] index convention follows analyze(): 0 = pyruvate, 1 = lactate.
# Also builds masked_aucr / masked_pyruvate (NaN where the temperature map is NaN),
# which the correlation cell below consumes.
n_subjects = len(studyfolder_nums)
n_slices = 5
assert len(all_fit_amps) == len(all_temp_maps) == len(all_anats) == n_subjects, (
    'per-subject result lists do not match studyfolder_nums - re-run the analyze() loop from a clean state')
row_labels = [chr(ord('A') + i) for i in range(n_subjects)]


def _finite_limits(maps):
    """Return (vmin, vmax) over all finite values in `maps`, or (None, None) if there are none.

    imshow autoscales each panel independently, but only one colorbar is drawn per row;
    sharing the limits across a subject's slices makes that colorbar valid for every panel.
    Non-finite values (e.g. inf from a zero pyruvate amplitude) are ignored, as imshow does.
    """
    stacked = np.asarray(maps, dtype=float)
    finite = stacked[np.isfinite(stacked)]
    if finite.size == 0:
        return None, None
    return finite.min(), finite.max()


def _plot_map_grid(maps_per_subject, title, cbar_label, alpha):
    """Overlay maps_per_subject[subject][slice] on the anatomy: one row per subject, one column per slice.

    Returns the (fig, ax) pair of the new figure.
    """
    fig, ax = plt.subplots(n_subjects, n_slices + 1, figsize=(6.9, 12 * n_subjects / 9),
                           width_ratios=(1,) * n_slices + (0.05,), squeeze=False)
    for pat_index, slice_maps in enumerate(maps_per_subject):
        vmin, vmax = _finite_limits(slice_maps)
        for slice_num, slice_map in enumerate(slice_maps):
            panel = ax[pat_index, slice_num]
            panel.imshow(all_anats[pat_index][slice_num], cmap='bone', extent=all_anat_exts[pat_index], vmax=1300)
            img = panel.imshow(slice_map, cmap='magma', alpha=alpha, extent=all_c13exts[pat_index],
                               vmin=vmin, vmax=vmax)
            panel.set_xticks([])
            panel.set_yticks([])
        ax[pat_index, 0].set_ylabel('HV-' + str(studyfolder_nums[pat_index]), rotation=90)
        ax[pat_index, 0].set_title(row_labels[pat_index], x=0.11, y=0.7)
        fig.colorbar(img, cax=ax[pat_index, n_slices], label=cbar_label)
    fig.suptitle(title)
    fig.subplots_adjust(wspace=0, hspace=0.1)
    return fig, ax


aucr_maps, lactate_maps, pyruvate_maps = [], [], []
masked_aucr = []
masked_pyruvate = []
for pat_index in range(n_subjects):
    amps = all_fit_amps[pat_index]
    aucr_per_pat, lactate_per_pat, pyruvate_per_pat = [], [], []
    masked_aucr_per_pat = []
    masked_pyruvate_per_pat = []
    for slice_num in range(n_slices):
        # NaN wherever the temperature map was rejected, 1 elsewhere.
        mask = np.where(np.isnan(all_temp_maps[pat_index][slice_num]), np.nan, 1)
        lactate = amps[0, slice_num, :, :, 0, 0, 1]
        pyruvate = amps[0, slice_num, :, :, 0, 0, 0]
        # Voxels with zero pyruvate give inf/NaN ratios; silence the expected warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            aucr = np.abs(lactate / pyruvate)
        aucr_per_pat.append(aucr)
        lactate_per_pat.append(np.abs(lactate))
        pyruvate_per_pat.append(np.abs(pyruvate))
        masked_aucr_per_pat.append(aucr * mask)
        masked_pyruvate_per_pat.append(np.abs(pyruvate * mask))
    aucr_maps.append(aucr_per_pat)
    lactate_maps.append(lactate_per_pat)
    pyruvate_maps.append(pyruvate_per_pat)
    masked_aucr.append(masked_aucr_per_pat)
    masked_pyruvate.append(masked_pyruvate_per_pat)

_plot_map_grid(aucr_maps, f'AUCR maps of {n_subjects} patients in {n_slices} slices', r'AUCR', alpha=0.75)
_plot_map_grid(lactate_maps, f'Lactate maps of {n_subjects} patients in {n_slices} slices', r'I[a.u.]', alpha=0.65)
fig, ax = _plot_map_grid(pyruvate_maps, f'Pyruvate maps of {n_subjects} patients in {n_slices} slices',
                         r'I[a.u.]', alpha=0.65)

## Figure S5: correlation between pyruvate signal and temperature

In [None]:
plt.close('all')

# Explicit submodule import: a bare `import scipy` does not guarantee `scipy.stats` is loaded.
import scipy
import scipy.stats
import seaborn as sns


def lin(x, k, M):
    """Straight line y = k*x + M used for the pyruvate-vs-temperature fit."""
    return x * k + M


# Subjects (indices into studyfolder_nums) that are scattered but get no regression.
# NOTE(review): the reason for excluding these two subjects is not stated here - document it.
EXCLUDED_FROM_FIT = (0, 2)

n_subjects = len(studyfolder_nums)
n_cols = 3
n_rows = (n_subjects + n_cols - 1) // n_cols

person_vals = []                   # scipy PearsonRResult per fitted subject
rsquared = np.zeros(n_subjects)    # R^2 of the linear fit (0 for unfitted subjects)
slopes = np.zeros_like(rsquared)   # fitted slope k
dslope = np.zeros_like(slopes)     # 1-sigma uncertainty of k from the covariance matrix

fig, ax = plt.subplots(n_rows, n_cols, figsize=(6.9, 6.9 * n_rows / 3), tight_layout=True, squeeze=False)
for pat in range(n_subjects):
    nx, ny = divmod(pat, n_cols)

    temps_all = np.ravel(all_temp_maps[pat])
    pyr_all = np.ravel(masked_pyruvate[pat])
    # One joint mask keeps temperature and pyruvate paired voxel-by-voxel; filtering each
    # array with its own NaN mask could misalign (or mismatch the length of) the two series.
    valid = ~np.isnan(temps_all) & ~np.isnan(pyr_all)
    temps = temps_all[valid]
    pa_map = pyr_all[valid]
    ax[nx, ny].scatter(pa_map, temps, color='k', s=3)

    # A line fit and Pearson correlation need at least two points.
    if pat not in EXCLUDED_FROM_FIT and len(pa_map) > 1:
        # seed makes the bootstrapped 95% confidence band reproducible
        sns.regplot(x=pa_map, y=temps, ax=ax[nx, ny], scatter=False, fit_reg=True, ci=95, color='k', seed=0)

        coeff, cov = curve_fit(lin, pa_map, temps)
        slopes[pat] = coeff[0]
        dslope[pat] = np.sqrt(np.diag(cov))[0]

        residuals = temps - lin(pa_map, *coeff)
        ss_res = np.sum(residuals**2)
        ss_tot = np.sum((temps - np.mean(temps))**2)
        # Constant temperatures make R^2 undefined; record NaN instead of dividing by zero.
        rsquared[pat] = 1 - (ss_res / ss_tot) if ss_tot > 0 else np.nan

        pearson = scipy.stats.pearsonr(temps, pa_map)
        person_vals.append(pearson)
        ax[nx, ny].set_title('p=' + str(np.round(pearson.pvalue, 3)) + ', PCC=' + str(np.round(pearson.statistic, 3)),
                             fontsize=11)

    ax[nx, ny].yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useMathText=True))
    ax[nx, ny].ticklabel_format(style='sci', axis='x', scilimits=(0, 0))  # Force scientific notation

# Hide grid cells left empty when the subject count is not a multiple of n_cols.
for unused in ax.flat[n_subjects:]:
    unused.set_visible(False)

for col in range(n_cols):
    ax[n_rows - 1, col].set_xlabel('$I_{pyr}$ [a.u.]')
for row in range(n_rows):
    ax[row, 0].set_ylabel(r'T[$^\circ$C]')


## Save all temperature maps to an .npy file

In [None]:
# Stack the per-subject temperature maps once; `data` keeps a flat copy for quick inspection.
temp_maps_stacked = np.array(all_temp_maps)
data = temp_maps_stacked.flatten()
np.save(revision_path + 'human_csi_healthy_brain_temp_values_snr_thld.npy', temp_maps_stacked)
