In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib
params = {"text.usetex" : True,
          "font.family" : "serif",
          "font.serif" : ["Computer Modern Serif"],
         'font.size':12}
plt.rcParams.update(params)
import ipywidgets as widgets
import datetime
import pandas as pd
import os
import hypermri

from scipy.optimize import curve_fit
import scipy.io as sio
from matplotlib import cm

import hypermri.utils.utils_general as utg
import hypermri.utils.utils_spectroscopy as uts
import hypermri.utils.utils_fitting as utf
import sys
# define paths:
sys.path.append('../../../')
import Template_Cambridge
basepath,savepath = Template_Cambridge.import_all_packages(False)

# Autoreload extension, so that you don't have to restart the kernel every time something in the hypermri package changes
%load_ext autoreload
%autoreload 2

%matplotlib widget
def get_colors_from_cmap(cmap_name, N):
    """Sample ``N`` evenly spaced colours from a Matplotlib colormap.

    Parameters
    ----------
    cmap_name : str
        Name of the colormap passed to ``plt.get_cmap`` (e.g. ``'tab10'``).
    N : int
        Number of colours to sample.

    Returns
    -------
    numpy.ndarray
        RGBA colours obtained by evaluating the colormap at ``N`` points spread
        uniformly over [0, 1] (one row per sample, per the Matplotlib Colormap API).
    """
    sample_points = np.linspace(0, 1, N)
    return plt.get_cmap(cmap_name)(sample_points)


In [None]:
# Output folder for the revision figures and tables
# NOTE(review): the leading '...' looks like a redacted prefix — point it at the real folder
revision_path = r'.../Publication/Revision1/RefittedData/SP_Brain_GBM/'

In [None]:
# MATLAB export file stems, one per patient. Each patient's study folder is the
# upper-cased stem with '_' replaced by '-' (e.g. 'mgs_01' -> 'MGS-01').
file_names = ['mbt_02', 'mgs_01', 'mgs_03', 'mgs_07', 'mgs_09', 'mgs_10', 'mgs_11']
studyfolder_names = [stem.upper().replace('_', '-') for stem in file_names]

# Iterate through the patients by changing `pat_number` and then re-running all cells

In [None]:
# Index of the patient to analyse; selects the matching entry of file_names / studyfolder_names
pat_number = 6
studyfolder = studyfolder_names[pat_number]

## Analysis

In [None]:
# Name of the last folder in basepath. Split on both '\\' and '/' (and ignore a trailing
# separator) so this works for Windows- and POSIX-style paths; splitting only on '\\'
# returned the whole path for '/'-separated paths, which os.path.join then treated as absolute.
savefolder = basepath.replace('\\', '/').rstrip('/').split('/')[-1]

# Append "<savefolder>/Slicespec" to savepath only once. savepath is a global set in the
# first cell, so re-running this cell (e.g. when iterating over patients) must not nest it again.
slicespec_suffix = os.path.join(savefolder, "Slicespec")
if not os.path.normpath(savepath).endswith(slicespec_suffix):
    savepath = os.path.join(savepath, slicespec_suffix)

# Raw data ("A") and header ("header") of the selected patient:
# <basepath>/<studyfolder>/<file_name>.mat and <file_name>_header.mat
mat_stem = os.path.join(basepath, studyfolder, file_names[pat_number])
mgs_file = sio.loadmat(mat_stem + ".mat")["A"]
dataset_header_struct_arr = sio.loadmat(mat_stem + "_header.mat")['header']

# Convert the header struct loaded with scipy.io.loadmat into a dict of its top-level fields.
field_names = dataset_header_struct_arr.dtype.names

dataset_header = {}
for field in field_names:
    # [0][0][0] + squeeze strips the singleton nesting of the loaded struct for this file
    dataset_header[field] = np.squeeze(dataset_header_struct_arr[field][0][0][0])

# Unwrap the entries of the 'image' sub-header to plain values where possible.
for key in dataset_header['image'].dtype.names:
    temp = np.array(dataset_header['image'][key]).item()

    if not isinstance(temp, np.ndarray):
        # Already a plain value: leave the field as it is.
        continue
    if temp.ndim == 0:
        value = temp.item()  # 0-d array -> Python scalar
    else:
        try:
            value = temp[0][0]  # first element of a (nested) array
        except (IndexError, TypeError):
            # Cannot unwrap (e.g. empty or 1-D numeric array). Falling through here would
            # write the previous key's `value` into this field (or raise NameError on the
            # first key), so leave the field unchanged instead.
            continue
    dataset_header['image'][key] = value
# Basic patient metadata from the 'exam' and 'rdb_hdr' headers
patient_info = {
    'ID': studyfolder,
    # NOTE(review): code 0 -> 'male', any other code -> 'female'; confirm the header's sex encoding
    'sex': 'male' if dataset_header['exam']['patsex'] == 0 else 'female',
    # presumably patweight is stored in grams, converted to kg here — TODO confirm
    'weight': float(dataset_header['exam']['patweight'] / 1e3),
}
# Pyruvate volume scaled from body weight (factor 0.4 per kg; units presumably mL — TODO confirm)
patient_info['pyr_vol'] = patient_info['weight'] * 0.4
patient_info['scan_date'] = str(dataset_header['rdb_hdr']['scan_date'])
patient_info['scan_time'] = str(dataset_header['rdb_hdr']['scan_time'])
# Acquisition parameters from the raw-data ('rdb_hdr') and 'image' headers.
# (To inspect the available raw-data header fields: print(rdb_hdr.dtype.names))
rdb_hdr = dataset_header['rdb_hdr']

# Repetition time from the header (divided by 1e6, presumably us -> s).
tr_header = dataset_header['image']['tr'] / 1e6
# NOTE(review): the header TR above is overridden by this hard-coded value, whose units
# are not stated and do not obviously match the /1e6 header value — confirm before use.
tr = 4e3

# Spectral bandwidth; used as the sampling rate below (dt = 1 / bw)
bw = rdb_hdr['spectral_width']
# Centre frequency in Hz (header value divided by 10, presumably stored in 0.1 Hz units — TODO confirm)
freq_cent_hz = rdb_hdr['ps_mps_freq'] / 10.0
# 13C gyromagnetic ratio [MHz/T]
gmr_mhz_t = utg.get_gmr(nucleus="13c")
# B0 [T] implied by the centre frequency
b0_off_t = freq_cent_hz / (gmr_mhz_t * 1e6)

# Nominal field strength [T] and the corresponding 13C frequency [MHz]
b0_nominal_t = 3
freq0 = b0_nominal_t * gmr_mhz_t
# Offset of the centre frequency from the nominal frequency, in Hz and ppm
freq_off_hz = freq_cent_hz - freq0 * 1e6
freq_off_ppm = freq_off_hz / (freq0 * 1e6) * 1e6

# Sampling interval (seconds if bw is in Hz)
dt = 1. / bw
# Flip angle
fa = dataset_header['image']['mr_flip']
# FID selection: every 8th entry along axis 0, skip index 0 along axis 1, first index of axes 2 and 3.
# NOTE(review): the meaning of the stride-8 / skip-first selection is not visible here — confirm.
dyn_fid = mgs_file[::8, 1:, 0, 0, :]

# FFT along axis 1 (the "spec" dimension in provided_dims below), centre zero frequency,
# then flip and complex-conjugate so the spectrum orientation matches the convention used downstream.
dyn_spec = np.conj(
    np.flip(
        np.fft.fftshift(np.fft.fft(dyn_fid, axis=1), axes=(1,)),
        axis=1,
    )
)

n_spectral_points = dyn_spec.shape[1]
freq_range = np.squeeze(uts.get_freq_axis(npoints=n_spectral_points, sampling_dt=dt, unit='Hz'))
time_axis = utg.calc_sampling_time_axis(npoints=n_spectral_points, sampling_dt=dt)

# Bring the (reps, spec, z) array into the 6-D layout used by the hypermri utilities
input_data = uts.make_NDspec_6Dspec(input_data=dyn_spec, provided_dims=["reps", "spec", "z"])

# Loading pre-fitted data

In [None]:
# Directory holding the pre-computed fit results
fit_results_dir = os.path.join(savepath, 'fit_results_400ms_100Hz_refit2025')
all_files_in_dir = os.listdir(fit_results_dir)
fitted_files = [file for file in all_files_in_dir if file.endswith('.pkl')]
for file in fitted_files:
    print('Found files ', file)
# sort files by name (lexicographic)
fitted_files.sort()

# Pick the fit result of the selected study. Fail loudly when there is none: otherwise a
# `load_file` left over from a previously analysed patient would be reused silently.
load_prefix = str(studyfolder) + '_fit_spectra_2025'
matching_files = [file for file in fitted_files if file.startswith(load_prefix)]
if not matching_files:
    raise FileNotFoundError(f"No fit result starting with {load_prefix!r} in {fit_results_dir}")
# If several files match, the last one in sorted order is used
load_file = matching_files[-1]

print('-----')
print('Selected study:', studyfolder)
print('Loading', load_file)

In [None]:
# Load the stored fit results of the selected study.
# NOTE(review): load_as_pkl presumably unpickles the file — only load trusted result files.
fit_results_path = savepath + '/fit_results_400ms_100Hz_refit2025/'
fit_results = utg.load_as_pkl(
    dir_path=fit_results_path,
    filename=load_file,
    global_vars=globals(),
)
print('Loaded data')

# Individual fit outputs used by the cells below
fit_spectrums = fit_results['fit_spectrums']
fit_amps      = fit_results['fit_amps']
fit_freqs     = fit_results['fit_freqs']
fit_t2s       = fit_results['fit_t2s']
fit_params    = fit_results['fit_params']
fit_stds      = fit_results['fit_stds']

# Plot fit results

In [None]:
# Overview of measured vs. fitted spectra
fitted_spectra_plot_params = {'figsize': (9, 2), 'ylim': [-1000, 1000]}
utf.plot_fitted_spectra(measured_data=input_data,
                        fitted_data=fit_spectrums,
                        fit_params=fit_params,
                        plot_params=fitted_spectra_plot_params)

In [None]:
fig, ax = plt.subplots(1)


@widgets.interact(slic=(0, input_data.shape[1] - 1, 1), rep=(0, input_data.shape[4] - 1, 1))
def update(rep=0, slic=0):
    """Show measured vs. summed fitted magnitude spectrum for one slice and repetition.

    The title reports the two frequency fit uncertainties (fit_stds[..., 0/1, 1])
    combined in quadrature ("df") and the first fitted frequency ("f").
    """
    ax.cla()
    ppm_axis = fit_params['freq_range_ppm']

    measured = np.squeeze(input_data)[:, slic, rep]
    ax.plot(ppm_axis, np.abs(measured))
    # total fit = sum of all fitted components along the last axis
    fitted_total = np.sum(np.squeeze(fit_spectrums)[:, slic, rep, :], axis=1)
    ax.plot(ppm_axis, np.abs(fitted_total))
    ax.set_xlim([160, 190])

    freq_std = np.sqrt(fit_stds[0, slic, 0, 0, rep, 0, 0, 1] ** 2
                       + fit_stds[0, slic, 0, 0, rep, 0, 1, 1] ** 2)
    df_label = 'df=' + str(np.round(np.abs(freq_std), 1)) + ' Hz'
    f_label = ', f=' + str(np.round(fit_freqs[0, slic, 0, 0, rep, 0, 0], 1)) + ' Hz'
    ax.set_title(df_label + f_label)

In [None]:
plt.close('all')
n_slices, n_reps = input_data.shape[1], input_data.shape[4]
fig, axes = plt.subplots(n_slices, n_reps, figsize=(12, 3))

# One tick-less panel per (slice, repetition): measured magnitude spectrum and summed fit
for slic in range(n_slices):
    for rep in range(n_reps):
        ax = axes[slic, rep]
        ax.plot(fit_params['freq_range_ppm'], np.abs(np.squeeze(input_data)[:, slic, rep]))
        ax.plot(fit_params['freq_range_ppm'],
                np.abs(np.sum(np.squeeze(fit_spectrums)[:, slic, rep, :], axis=1)))
        ax.set_xlim([160, 190])
        for clear_ticks in (ax.set_xticks, ax.set_yticks, ax.set_xticklabels, ax.set_yticklabels):
            clear_ticks([])
        for clear_text in (ax.set_title, ax.set_xlabel, ax.set_ylabel):
            clear_text('')

# Label only the outer panels: slice index on the first column, repetition on the first row
for slic in range(n_slices):
    axes[slic, 0].set_ylabel('Slice ' + str(slic))
for rep in range(n_reps):
    axes[0, rep].set_title('Rep ' + str(rep))

plt.subplots_adjust(wspace=0, hspace=0)

plt.savefig(revision_path + str(studyfolder) + '_fit_results.png', dpi=300)

# 5. Compute temperature

In [None]:
# Frequency difference between the two fitted peaks, index 1 (lactate) minus index 0 (pyruvate) [Hz]
freq_diff_hz = fit_freqs[..., 1] - fit_freqs[..., 0]

# Same difference in ppm
freq_diff_ppm = uts.freq_Hz_to_ppm(freq_Hz=freq_diff_hz,
                                   hz_axis=fit_params["freq_range_Hz"],
                                   ppm_axis=fit_params['freq_range_ppm'],
                                   ppm_centered_at_0=True)

# Temperature [°C] from the ppm difference (5 mM calibration)
temp, _ = utf.temperature_from_frequency(frequency=freq_diff_ppm,
                                         calibration_type='5mM',
                                         frequency_is_ppm=True)

# Std of the frequency difference [Hz]: frequency stds of both peaks in quadrature.
# fit_stds[..., metabolite, parameter]: metabolite 0/1 = pyr/lac, parameter 1 = frequency
freq_diff_hz_std = np.sqrt(
    fit_stds[..., 0, 1] ** 2
    + fit_stds[..., 1, 1] ** 2
)

freq_diff_ppm_std = uts.freq_Hz_to_ppm(freq_Hz=freq_diff_hz_std,
                                       hz_axis=fit_params["freq_range_Hz"],
                                       ppm_axis=fit_params['freq_range_ppm'],
                                       ppm_centered_at_0=True)

# Temperature uncertainty: shift of the temperature when the difference moves up by one std
temp_plus_std, _ = utf.temperature_from_frequency(frequency=freq_diff_ppm + freq_diff_ppm_std,
                                                  calibration_type='5mM',
                                                  frequency_is_ppm=True)

temp_std = np.abs(temp_plus_std - temp)

# Selecting only values where the temperature fit error is at most 3 °C and both the pyruvate and lactate SNR are at least 5

In [None]:
# Rejection thresholds used by the masking cell:
#   fit_error_threshold - largest accepted temperature fit error [°C]
#   snr_threshold       - smallest accepted SNR, required for both pyruvate and lactate
fit_error_threshold, snr_threshold = 3, 5

In [None]:
## SNR masking
snr, noise = uts.compute_snr_from_fit(input_data, fit_spectrums)

# Reject points where the pyruvate (index 0) or lactate (index 1) SNR is below the threshold
snr_mask_cond = np.logical_or(snr[..., 0] < snr_threshold, snr[..., 1] < snr_threshold)
snr_mask = np.where(snr_mask_cond, np.nan, 1)  # NaN = rejected, 1 = kept
snr_masked_temp = temp * snr_mask

## Fit error masking: keep points whose temperature uncertainty is at most the threshold
fit_error_mask = np.where(temp_std <= fit_error_threshold, 1, np.nan)
fit_error_masked_temp = np.where(np.isnan(fit_error_mask), np.nan, temp)

# Combined SNR + fit-error masking; this is the temperature used from here on
snr_and_fit_error_masked_temp = snr_mask * fit_error_masked_temp
masked_temp = snr_and_fit_error_masked_temp

# Per-slice statistics: reduce every axis except axis 1 (slice), ignoring rejected (NaN) points
slice_reduce_axes = (0, 2, 3, 4, 5)
mean_values = np.nanmean(masked_temp, axis=slice_reduce_axes)
std_values = np.nanstd(masked_temp, axis=slice_reduce_axes)
count_values = (~np.isnan(masked_temp)).sum(axis=slice_reduce_axes)

# Temperatures of the rejected points (plotted as "Excluded points")
inverse_masked_T = np.where(np.isnan(snr_mask * fit_error_mask), temp, np.nan)

In [None]:
plt.close('all')
colors = get_colors_from_cmap('tab10', 10)
n_slices = input_data.shape[1]
n_reps = input_data.shape[4]
# Reference line for the frequency-uncertainty panel [Hz] (display only, not used for masking)
freq_std_reference_hz = 1

# Rows: temperature, temperature error, frequency-difference error. One column per slice
# (previously fixed at 3 columns); squeeze=False keeps `ax` 2-D for a single slice too.
fig, ax = plt.subplots(3, n_slices, figsize=(7 * n_slices / 3, 5), tight_layout=True, squeeze=False)

for n in range(n_slices):
    temp_ax, dtemp_ax, dfreq_ax = ax[:, n]

    temp_ax.plot(inverse_masked_T[0, n, 0, 0, :, 0], 'o', markersize=3, color='r', label='Excluded points')
    temp_ax.plot(masked_temp[0, n, 0, 0, :, 0], 'o', markersize=3, color=colors[0])

    dtemp_ax.plot(np.abs(temp_std[0, n, 0, 0, :, 0]), 'o', markersize=3, color=colors[1])
    dfreq_ax.plot(np.abs(freq_diff_hz_std[0, n, 0, 0, :, 0]), 'o', markersize=3, color=colors[2])
    # Draw the fit-error cut-off actually used for masking, across all repetitions
    dtemp_ax.hlines(fit_error_threshold, 0, n_reps - 1, linestyle='dashed', color='g')
    dfreq_ax.hlines(freq_std_reference_hz, 0, n_reps - 1, linestyle='dashed', color='g')

    temp_ax.set_ylim([20, 45])
    temp_ax.set_title('Slice ' + str(n) + ', T=' + str(np.round(mean_values[n], 1)) + '±'
                      + str(np.round(std_values[n], 1)) + r'$^\circ$C')
    dfreq_ax.set_xlabel('Repetition')
    temp_ax.set_ylabel(r'T[$^\circ$C]')
    dtemp_ax.set_ylabel(r'dT[$^\circ$C]')
    dfreq_ax.set_ylabel(r'df[Hz]')
    temp_ax.legend(fontsize=6)

print(mean_values)
print(std_values)
fig.suptitle(str(studyfolder) + ' Temperature filtered')

plt.savefig(revision_path + str(studyfolder) + '_temperature_results_SNR_thld.png', dpi=300)

# Compute the per-slice AUCR (lactate/pyruvate ratio of the amplitudes summed over all repetitions) and its propagated error

In [None]:
# Fitted amplitude magnitudes and amplitude fit stds per metabolite
# (metabolite index 0 = pyruvate, 1 = lactate; fit_stds[..., metabolite, 0] = amplitude std)
pyr_amp, lac_amp = (np.abs(np.squeeze(fit_amps[:, :, :, :, :, :, met])) for met in (0, 1))
d_pyr_amp, d_lac_amp = (np.squeeze(fit_stds[..., met, 0]) for met in (0, 1))

# Area under the curve per slice: sum over repetitions (axis 1 after squeezing)
sum_pyr = pyr_amp.sum(axis=1)
sum_lac = lac_amp.sum(axis=1)

# Uncertainty of each sum: amplitude stds added in quadrature
d_sum_pyr, d_sum_lac = (np.sqrt((d ** 2).sum(axis=1)) for d in (d_pyr_amp, d_lac_amp))

AUCR = sum_lac / sum_pyr

# Error of the ratio: relative errors of numerator and denominator combined in quadrature
d_AUCR = np.abs(AUCR * np.sqrt((d_sum_pyr / sum_pyr) ** 2 + (d_sum_lac / sum_lac) ** 2))

print("AUCR:", AUCR)
print("Error in AUCR:", d_AUCR)


In [None]:
plt.close('all')
n_slices = input_data.shape[1]
rep_axis = np.arange(0, input_data.shape[4], 1)
fig, ax = plt.subplots(n_slices, 1, tight_layout=True, figsize=(5, 7))

for slic in range(n_slices):
    # Normalise both curves and their error bars to this slice's pyruvate maximum
    pyr_norm = np.max(pyr_amp[slic, :])
    ax[slic].errorbar(rep_axis, pyr_amp[slic, :] / pyr_norm,
                      yerr=np.abs(d_pyr_amp[slic, :]) / pyr_norm, label='Pyruvate')
    ax[slic].errorbar(rep_axis, lac_amp[slic, :] / pyr_norm,
                      yerr=np.abs(d_lac_amp[slic, :]) / pyr_norm, label='Lactate')
    ax[slic].set_ylabel('Slice ' + str(slic))
    ax[slic].legend()
    ax[slic].set_xlabel('Repetition')
    # Report AUCR with the same precision (3 decimals) as its uncertainty; rounding the value
    # to 1 decimal while quoting its error to 3 decimals hid most of the value's digits.
    ax[slic].set_title('AUCR=' + str(np.round(AUCR[slic], 3)) + '±' + str(np.round(d_AUCR[slic], 3)))

fig.suptitle(str(studyfolder) + ' AUCR')

plt.savefig(revision_path + str(studyfolder) + '_AUCR_results_SNR_thld.png', dpi=300)

# Output dataframe

In [None]:
# One row per slice. The row count follows the number of slices instead of a hard-coded 3,
# which raised a length mismatch for studies with a different slice count.
n_slices = len(mean_values)
output_df = pd.DataFrame({
    'ID': [patient_info['ID']] * n_slices,
    'sex': [patient_info['sex']] * n_slices,
    'Slice': range(n_slices),
    'T': mean_values,       # mean masked temperature [°C]
    'dT': std_values,       # std of the masked temperature [°C]
    'AUCR': AUCR,
    'dAUCR': d_AUCR,
    'n': count_values,      # number of unmasked points per slice
})
output_df.to_excel(revision_path + studyfolder + '_results_SNR_thld.xlsx')