# 1. Importing packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib
#matplotlib.rcParams.update({'font.size': 12,'font.family':'serif','font.serif':['Computer Modern']})
from tqdm.auto import tqdm
import ipywidgets as widgets
import datetime
import pandas as pd
import os
import hypermri
import pydicom

from scipy.optimize import curve_fit
import scipy.io as sio
from matplotlib import cm

import hypermri.utils.utils_general as utg
import hypermri.utils.utils_spectroscopy as uts
import hypermri.utils.utils_fitting as utf

from hypermri.utils.utils_fitting import def_fit_params, fit_data_pseudo_inv
from hypermri.utils.utils_general import get_gmr
from hypermri.utils.utils_spectroscopy import get_metab_cs_ppm, make_NDspec_6Dspec


import sys
# define paths:
sys.path.append('../../')
import Template_utsw
basepath,savepath = Template_utsw.import_all_packages()

# Autoreload extension so that you don't have to restart the kernel every time something in the hypermri package changes
%load_ext autoreload
%autoreload 2

%matplotlib widget
def get_colors_from_cmap(cmap_name, N):
    """Sample ``N`` evenly spaced colours from a Matplotlib colormap.

    Parameters
    ----------
    cmap_name : str
        Colormap name handed to ``plt.get_cmap``.
    N : int
        Number of colours to draw; the sample positions run from 0 to 1
        inclusive.

    Returns
    -------
    np.ndarray
        The RGBA values produced by calling the colormap on the sample
        positions (one entry per sample).
    """
    sample_positions = np.linspace(0, 1, N)
    return plt.get_cmap(cmap_name)(sample_positions)

In [None]:
# Output directory for the per-injection Excel summaries and the .npy dump.
# It ends with a slash so that appending a file name yields a valid path.
revisionpath = '.../Publication/Revision1/RefittedData/SP_Brain_healthy/'

# Directory searched for fit-result ``.pkl`` files (``*_fit_spectra_study_*``).
fits_path = r'.../UTSW_brain/fit_results_400ms_100Hz_raw/'


## Select thresholds and the patients/injections to load

In [None]:
# Quality thresholds applied before the per-slice temperature statistics:
# - fit_error_threshold: maximum accepted temperature uncertainty (in the
#   same units as the computed temperature, plotted as degrees C) obtained
#   by propagating the fitted frequency standard deviations.
# - snr_threshold: minimum SNR that both fitted metabolites must reach.
fit_error_threshold = 3
snr_threshold = 5

# Studies to process; injections[i] lists the injection numbers of study_ids[i].
study_ids = ['3TC230', '3TC718', '3TC741', '3TC755']
injections = [[1], [1, 2], [1, 2], [1, 2]]

# Calibration identifier passed to utf.temperature_from_frequency.
calibration_type = '5mM'

# Set to True to draw per-channel temperature curves for every injection.
plot = False

# Collects one masked temperature array per processed injection
# (previously this list was initialised twice).
all_temperatures = []


## Load raw and fitted data and compute masked temperatures

In [None]:

# Each study needs its own list of injection numbers; zip() below would
# otherwise silently drop unmatched studies.
if len(study_ids) != len(injections):
    raise ValueError(
        f'study_ids ({len(study_ids)}) and injections ({len(injections)}) '
        'must have the same length'
    )

# The fit-result directory does not change while this cell runs, so list and
# sort its .pkl files once instead of once per injection.
fitted_files = sorted(name for name in os.listdir(fits_path) if name.endswith('.pkl'))


def plot_temp_result(masked_temp, study_id=None, inj=None):
    """Optionally plot per-channel temperature curves and summarise each slice.

    Parameters
    ----------
    masked_temp : np.ndarray
        Temperature array with NaN wherever a point was rejected by masking.
        It is indexed as 6-D: axis 1 = slice, axis 4 = repetition,
        axis 5 = channel (axes 0, 2, 3 are only read at index 0 for plotting).
    study_id, inj : optional
        Study ID and injection number written into the table. When omitted,
        the current module-level ``study_ID`` / ``injection`` loop variables
        are used, matching the original closure behaviour.

    Returns
    -------
    pd.DataFrame
        One row per slice with columns ID, slice, T (NaN-ignoring mean),
        dT (NaN-ignoring std), n (count of non-NaN values) and inj.
    """
    if study_id is None:
        study_id = study_ID
    if inj is None:
        inj = injection

    if plot:
        colors = get_colors_from_cmap('tab10', 10)
        fig, ax = plt.subplots(2, 2, figsize=(7, 5), tight_layout=True)
        # NOTE(review): the 2x2 grid holds at most 4 slices and tab10 only
        # provides 10 channel colours.
        for n in range(masked_temp.shape[1]):
            nx, ny = divmod(n, 2)
            for chan in range(masked_temp.shape[5]):
                ax[nx, ny].plot(masked_temp[0, n, 0, 0, :, chan], '-o', markersize=3,
                                color=colors[chan], label='chan ' + str(chan))
            ax[nx, ny].set_xlabel('Repetition')
            ax[nx, ny].set_ylabel(r'T[$^\circ$C]')
        ax[0, 1].legend(ncol=3, fontsize=8)

    # Reduce everything except the slice axis; NaNs (masked points) are ignored.
    reduce_axes = (0, 2, 3, 4, 5)
    mean_values = np.nanmean(masked_temp, axis=reduce_axes)
    std_values = np.nanstd(masked_temp, axis=reduce_axes)
    count_values = np.sum(~np.isnan(masked_temp), axis=reduce_axes)

    rows = [
        {
            'ID': study_id,
            'inj': inj,
            'slice': slic,
            'T': mean_values[slic],
            'dT': std_values[slic],
            'n': count_values[slic],
        }
        for slic in range(masked_temp.shape[1])
    ]
    # Build the frame directly (concatenating onto an empty DataFrame is
    # deprecated in recent pandas); the column order matches what the old
    # concat produced, so the exported Excel layout is unchanged.
    return pd.DataFrame(rows, columns=['ID', 'slice', 'T', 'dT', 'n', 'inj'])


for study_ID, study_injections in tqdm(zip(study_ids, injections), total=len(study_ids)):
    for injection in study_injections:
        print('Looking at', study_ID, ', injection', injection)

        # Parse the raw .mat file once and read every needed field from it
        # (previously the same file was loaded three times).
        raw_mat = sio.loadmat(
            os.path.join(basepath, str(study_ID) + '_raw_' + str(injection) + '.mat')
        )
        fid = np.squeeze(raw_mat['fid'])
        time = np.squeeze(raw_mat['time'])  # not used below; kept for inspection
        ppm = np.array(np.squeeze(raw_mat['ppm']))

        bw = 5000  # Hz
        dwelltime = 1 / bw
        nsample_points = fid.shape[0]
        freq_range_Hz = uts.get_freq_axis(unit="Hz", sampling_dt=dwelltime, npoints=nsample_points)
        gmr = get_gmr(nucleus="13c")
        # NOTE(review): ppm[4095] is a hard-coded reference index and the
        # division by 3 presumably converts Hz to ppm at 3 T -- confirm both.
        # freq_range_ppm is not used below; the fit_params axes are used instead.
        freq_range_ppm = freq_range_Hz / gmr / 3 + ppm[4095]

        input_fid = make_NDspec_6Dspec(input_data=fid, provided_dims=["spec", "reps", "chan", "z"])
        input_spec = np.fft.fftshift(np.fft.fft(input_fid, axis=0), axes=(0,))

        # Conjugate and time-reverse the FID, then flip the resulting spectrum;
        # this spectrum is the one the SNR is computed from below.
        mod_fid = np.conj(np.flip(input_fid, axis=0))
        mod_spec = np.flip(np.fft.fftshift(np.fft.fft(mod_fid, axis=0), axes=(0,)), axis=0)

        # Select the fit-result file whose name starts with this study/injection
        # prefix; with several matches the last one in sorted order wins, as
        # before. NOTE(review): prefix "..._study_1" would also match
        # "..._study_10..." -- confirm the naming scheme rules this out.
        prefix = str(study_ID) + '_fit_spectra_study_' + str(injection)
        matches = [name for name in fitted_files if name.startswith(prefix)]
        if not matches:
            # Previously load_file silently kept the previous injection's file.
            raise FileNotFoundError(f'No fit result starting with {prefix!r} in {fits_path!r}')
        load_file = matches[-1]

        print('Loading', load_file)
        fit_results = utg.load_as_pkl(dir_path=fits_path, filename=load_file, global_vars=globals())
        fit_spectrums = fit_results['fit_spectrums']
        fit_amps = fit_results['fit_amps']
        fit_freqs = fit_results['fit_freqs']
        fit_t2s = fit_results['fit_t2s']
        fit_params = fit_results['fit_params']
        fit_stds = fit_results['fit_stds']

        # Frequency difference [Hz] between metabolite 1 and metabolite 0:
        freq_diff_hz = np.take(fit_freqs, indices=1, axis=-1) - np.take(fit_freqs, indices=0, axis=-1)

        # Frequency difference [ppm]:
        freq_diff_ppm = uts.freq_Hz_to_ppm(freq_Hz=freq_diff_hz,
                                           hz_axis=fit_params["freq_range_Hz"],
                                           ppm_axis=fit_params['freq_range_ppm'],
                                           ppm_centered_at_0=True)

        # Temperature (plotted in degrees C) from the ppm frequency difference:
        temp, _ = utf.temperature_from_frequency(frequency=freq_diff_ppm,
                                                 calibration_type=calibration_type,
                                                 frequency_is_ppm=True)

        # Std of the frequency difference [Hz], assuming uncorrelated errors.
        # fit_stds[..., metabolite, parameter]: metabolite 0/1 = pyr/lac,
        # parameter 1 = frequency.
        freq_diff_hz_std = np.sqrt(fit_stds[..., 0, 1]**2 +
                                   fit_stds[..., 1, 1]**2)

        freq_diff_ppm_std = uts.freq_Hz_to_ppm(freq_Hz=freq_diff_hz_std,
                                               hz_axis=fit_params["freq_range_Hz"],
                                               ppm_axis=fit_params['freq_range_ppm'],
                                               ppm_centered_at_0=True)

        # Temperature uncertainty: shift of the temperature when the frequency
        # difference is moved by one standard deviation.
        temp_plus_std, _ = utf.temperature_from_frequency(frequency=freq_diff_ppm + freq_diff_ppm_std,
                                                          calibration_type=calibration_type,
                                                          frequency_is_ppm=True)
        temp_std = np.abs(temp_plus_std - temp)

        # SNR masking: reject a point if either of the first two metabolites
        # (presumably pyr/lac) falls below the SNR threshold.
        snr, noise = uts.compute_snr_from_fit(mod_spec, fit_spectrums)
        snr_mask_cond = (snr[..., 0] < snr_threshold) | (snr[..., 1] < snr_threshold)
        snr_mask = np.where(snr_mask_cond, np.nan, 1)
        snr_masked_temp = snr_mask * temp

        # Fit-error masking: reject a point whose temperature uncertainty exceeds
        # the threshold (NaN uncertainties are rejected too, since NaN <= x is False).
        fit_error_mask = np.where(temp_std <= fit_error_threshold, 1, np.nan)
        fit_error_masked_temp = np.where(~np.isnan(fit_error_mask), temp, np.nan)

        # Combined mask: only temperatures passing BOTH the SNR and the
        # fit-error criteria are collected and exported.
        snr_and_fit_error_masked_temp = snr_mask * fit_error_masked_temp

        all_temperatures.append(snr_and_fit_error_masked_temp)

        df_snr_fit_err_masked = plot_temp_result(snr_and_fit_error_masked_temp, study_ID, injection)

        df_output = df_snr_fit_err_masked
        df_output.to_excel(os.path.join(
            revisionpath,
            str(study_ID) + '_results_rawdata_inj_' + str(injection) + '_thld_snr_fit.xlsx',
        ))


In [None]:
# Flatten every collected temperature value into one 1-D array. Concatenating
# the ravelled arrays yields exactly np.array(...).flatten() when all arrays
# share a shape, and it still works when studies produce different shapes.
if all_temperatures:
    data = np.concatenate([np.ravel(t) for t in all_temperatures])
else:
    data = np.array([])

# Stack the per-injection arrays for saving. NumPy >= 1.24 raises ValueError
# for ragged input, so fall back to a 1-D object array in that case
# (loading such a file then requires np.load(..., allow_pickle=True)).
try:
    stacked_temperatures = np.array(all_temperatures)
except ValueError:
    stacked_temperatures = np.empty(len(all_temperatures), dtype=object)
    for i, t in enumerate(all_temperatures):
        stacked_temperatures[i] = t

np.save(os.path.join(revisionpath, 'human_slicespec_healthy_brain_temp_values_snr_thld.npy'),
        stacked_temperatures)