In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib
matplotlib.rcParams.update({'font.size': 12,'font.family':'serif','font.serif':['Computer Modern'],"text.usetex" : True,})

import ipywidgets as widgets
import datetime
import pandas as pd
import os
import nmrglue
import hypermri
import hypermri.utils.utils_anatomical as ut_anat
import hypermri.utils.utils_spectroscopy as ut_spec
import hypermri.utils.utils_fitting as ut_fitting
import hypermri.utils.utils_general as utg
from scipy.optimize import curve_fit
from matplotlib import cm
from astropy.modeling import models, fitting
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
    
import sys

# Make the directory two levels above the notebook's working directory
# importable (it holds the project-local TEMPLATE module). The absolute path and
# its relative spelling '../../' are both registered on sys.path.
cwd = os.getcwd()
parent_dir = os.path.dirname(os.path.dirname(cwd))
sys.path.extend([parent_dir, '../../'])
import TEMPLATE

# TEMPLATE.import_all_packages() returns (repopath, basepath, savepath);
# results of this analysis go into the 'AnimalCSIResults' sub-folder of savepath.
repopath, basepath, savepath = TEMPLATE.import_all_packages()
savepath = os.path.join(savepath, 'AnimalCSIResults')
from mpl_interactions import image_segmenter_overlayed
from hypermri.utils.utils_spectroscopy import find_npeaks as ut_find_npeaks
from hypermri.utils.utils_fitting import temperature_from_frequency
from hypermri.utils.utils_fitting import temperature_from_frequency as temperature
# Autoreload extension so that you dont have to reload the kernel every time something is changed in the hypermri or magritek folders
%load_ext autoreload
%autoreload 2

%matplotlib widget


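# 1. Define input parameters

Per-animal Bruker study directories, animal IDs and scan numbers (the high/low-temperature CSI pair and the T2-weighted anatomical reference), plus the spreadsheet with per-animal metadata.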

In [None]:
# Per-animal study information. The four lists below are index-aligned:
# entry i of every list belongs to the same animal.
dirpaths = ['',
            '',
            '',
            '',
            '',
            '',
            '',
            ''
            ]
animal_IDs = ['',
              '',
              '',
              '',
              '',
              '',
              '',
              ''
              ]
# CSI scan numbers per animal as [high-temperature scan, low-temperature scan].
csi_nums = [[9, 28],
            [23, 32],
            [24, 35],
            [13, 22],
            [13, 22],
            [13, 22],
            [13, 22],
            [13, 22]
            ]
# Scan number of the T2-weighted anatomical reference per animal.
t2w_nums = [34,
            34,
            38,
            14,
            19,
            28,
            19,
            37
            ]
# A length mismatch would silently pair scans with the wrong animal.
assert len(dirpaths) == len(animal_IDs) == len(csi_nums) == len(t2w_nums), \
    'dirpaths, animal_IDs, csi_nums and t2w_nums must have one entry per animal'

# Slice index into the T2w image used as the reference slice for the overlays.
csi_ref_slice = 3
# Per-animal metadata (date, gender, weight, rectal temperatures).
animal_info_df = pd.read_excel(os.path.join(savepath, 'Animal_information.xlsx'))

In [None]:
# Up-sampling factor applied to axes 2 and 3 of the CSI data.
interpolation_factor: int = 2
# Lower threshold for the bitwise ROI masking of the parameter maps.
bitwise_thld: float = 0.5
# Maximum accepted per-pixel temperature uncertainty (°C).
temp_error_thld: float = 0.5

In [None]:
# Accumulators filled by the per-animal analysis loop:
# temperature maps (high/low pairs per animal), organ ROI mask dicts and
# reference anatomical slices.
temperature_maps, organ_rois, anat_refs = [], [], []

In [None]:
# ---------------------------------------------------------------------------
# Helpers for the per-animal temperature analysis loop below.
# ---------------------------------------------------------------------------


def _animal_info_value(info_rows, column):
    """Return the value of ``column`` for the metadata rows of one animal.

    Uses ``np.squeeze`` on the column, so a single matching row yields a scalar.
    Returns NaN when no row matches, so the results table can still be built.
    """
    if info_rows.empty:
        return np.nan
    return np.squeeze(info_rows[column])


def _interpolate_csi(csi_obj, factor, interp_method="cubic"):
    """Interpolate ``csi_obj.csi_image`` with axes 2 and 3 enlarged by ``factor``.

    All other axes keep their size. Thin wrapper around ``utg.interpolate_dataset``.
    """
    shape = csi_obj.csi_image.shape
    interp_size = (shape[0], shape[1],
                   factor * shape[2], factor * shape[3],
                   shape[4], shape[5])
    return utg.interpolate_dataset(input_data=csi_obj.csi_image,
                                   interp_size=interp_size,
                                   data_obj=csi_obj,
                                   interp_method=interp_method)


def _freq_difference_maps(fit_freqs, data_obj):
    """Per-pixel frequency difference between fitted peak 0 and peak 1.

    Parameters
    ----------
    fit_freqs : array
        Fitted peak frequencies in Hz. After ``np.squeeze``, axes 0/1 index the
        pixel and axis 2 the peak.
    data_obj : hypermri scan object
        Passed to ``ut_spec.freq_Hz_to_ppm`` for the Hz -> ppm conversion.

    Returns
    -------
    freq_diff_Hz : ndarray
        ``|f0 - f1|`` in Hz.
    freq_diff_ppm : ndarray
        Signed ``ppm(f1) - ppm(f0)``, converted pixel by pixel.
    """
    freqs = np.squeeze(fit_freqs)  # squeeze once instead of twice per pixel
    freq_diff_Hz = np.abs(freqs[:, :, 0] - freqs[:, :, 1])
    freq_diff_ppm = np.zeros_like(freq_diff_Hz)
    # Converted per pixel, as before (the default-argument conversion is only
    # ever called on scalars in this notebook).
    for n in range(freq_diff_Hz.shape[0]):
        for m in range(freq_diff_Hz.shape[1]):
            ppm_peak1 = np.squeeze(ut_spec.freq_Hz_to_ppm(freq_Hz=freqs[:, :, 1][n, m], data_obj=data_obj))
            ppm_peak0 = np.squeeze(ut_spec.freq_Hz_to_ppm(freq_Hz=freqs[:, :, 0][n, m], data_obj=data_obj))
            freq_diff_ppm[n, m] = ppm_peak1 - ppm_peak0
    return freq_diff_Hz, freq_diff_ppm


def _mask_to_animal(param_map, mask_dict, bitwise_lower_threshold):
    """Mask ``param_map`` with the 'animal' ROI (``utg.apply_mask`` with return_nans=True) and squeeze it."""
    return np.squeeze(utg.apply_mask(param_map, mask_dict, 'animal',
                                     return_nans=True,
                                     provided_dims=(2, 3),
                                     mask_slice_ind=0,
                                     bitwise=True,
                                     bitwise_lower_threshold=bitwise_lower_threshold))


def _temperature_error(temp_map, freq_diff_ppm, fit_stds, data_obj, error_threshold):
    """Estimate the per-pixel temperature uncertainty and the derived error mask.

    The fitted frequency stds of both peaks are combined in quadrature and
    converted to ppm. They are then added to the ppm difference, and the result
    is passed through the same '5mM' calibration. The magnitude of the
    resulting temperature change is the uncertainty estimate.

    Returns
    -------
    tuple
        ``(freq_diff_std_Hz, freq_diff_std_ppm, temp_std, error_mask)``, where
        ``error_mask`` is 1 where ``temp_std <= error_threshold`` and NaN elsewhere.
    """
    # NOTE(review): assumes index 1 of the last axis is each peak's frequency std — TODO confirm.
    freq_diff_std_Hz = np.sqrt(np.abs(fit_stds[0, 0, :, :, 0, 0, 0, 1] ** 2) +
                               np.abs(fit_stds[0, 0, :, :, 0, 0, 1, 1] ** 2))
    freq_diff_std_ppm = ut_spec.freq_Hz_to_ppm(freq_diff_std_Hz,
                                               data_obj=data_obj,
                                               ppm_centered_at_0=True)
    # Unpack the calibration result explicitly instead of np.asarray() on the tuple.
    temp_shifted, _ = temperature(freq_diff_ppm + freq_diff_std_ppm, calibration_type='5mM')
    # The sign depends on the calibration slope; a negative value would always
    # pass the threshold, so use the magnitude.
    temp_std = np.abs(np.squeeze(temp_map - np.asarray(temp_shifted)))
    error_mask = np.where(temp_std <= error_threshold, 1, np.nan)
    return freq_diff_std_Hz, freq_diff_std_ppm, temp_std, error_mask


# ---------------------------------------------------------------------------
# Per-animal analysis
# ---------------------------------------------------------------------------

# Reset the accumulators so that re-running this cell does not append duplicates.
temperature_maps, organ_rois, anat_refs = [], [], []
# Per-animal result tables and the IDs actually processed. Animals with a
# missing ROI mask are skipped, so these may be shorter than animal_IDs.
animal_result_dfs = []
processed_animal_IDs = []

for animal_idx, animal_ID in enumerate(animal_IDs):
    plt.close('all')
    scans = hypermri.BrukerDir(basepath + dirpaths[animal_idx], verbose=False)

    csi_high_temp = scans[csi_nums[animal_idx][0]]
    csi_low_temp = scans[csi_nums[animal_idx][1]]
    t2w_as_csi = scans[t2w_nums[animal_idx]]

    # Metadata of the *current* animal only.
    animal_info = animal_info_df[animal_info_df['Animal'] == animal_ID]
    if animal_info.empty:
        print('No entry for', animal_ID, 'in Animal_information.xlsx')
    csi_high_temp_temperature = _animal_info_value(animal_info, 'High temp rect')
    csi_low_temp_temperature = _animal_info_value(animal_info, 'Low temp rect')
    scan_date = _animal_info_value(animal_info, 'Date')
    animal_gender = _animal_info_value(animal_info, 'Gender')
    animal_weight = _animal_info_value(animal_info, 'Weight')

    # Load both ROI masks up front. If either is missing, skip the animal
    # rather than silently reusing the previous animal's masks.
    try:
        mask_dict_animal = utg.load_as_pkl(savepath, filename=animal_ID + '_animal_mask.pkl')
        print('Loaded mask', animal_ID + '_animal_mask.pkl')
        # Union of the mask over axis 1, kept as a singleton axis.
        mask_dict_animal = {'animal': np.expand_dims(np.sum(mask_dict_animal['animal'], axis=1) > 0, axis=1)}
    except Exception as err:
        print('no animal ROI mask found in ' + savepath, f'({err!r}) - skipping', animal_ID)
        continue
    fig, ax = plt.subplots(1)
    ax.imshow(np.squeeze(mask_dict_animal['animal']))

    try:
        masks = ut_anat.load_mask(savepath, str(animal_ID) + '_organ_masks_vessel.npz', plot_res=True)
    except FileNotFoundError:
        print('Organ mask with vessel not found - skipping', animal_ID)
        continue
    roi_names = list(masks.keys())

    # Reconstruction, including shifting the anatomical image.
    csi_high_temp.full_reco(anatomical=t2w_as_csi)
    csi_low_temp.full_reco(anatomical=t2w_as_csi)

    # Where figures and other outputs of the scan objects are stored.
    csi_high_temp.savepath = basepath
    csi_low_temp.savepath = basepath

    ax_ext, _, _ = utg.get_extent(data_obj=t2w_as_csi)
    csi_ext, _, _ = utg.get_extent(data_obj=csi_high_temp)

    # NOTE(review): the interpolated datasets are not used further in this cell;
    # kept in case later cells rely on them.
    csi_high_temp_interpolated = _interpolate_csi(csi_high_temp, interpolation_factor)
    csi_low_temp_interpolated = _interpolate_csi(csi_low_temp, interpolation_factor)

    # Stored spectral fit results for this animal.
    filename = animal_ID + '_fit_spectra_interp.pkl'
    fit_results = utg.load_as_pkl(dir_path=savepath, filename=filename)
    fit_spectrums_high_temp_interp = fit_results["fit_spectrums_high_temp"]
    fit_spectrums_low_temp_interp = fit_results["fit_spectrums_low_temp"]
    fit_freqs_high_temp_interp = fit_results["fit_freqs_high_temp"]
    fit_freqs_low_temp_interp = fit_results["fit_freqs_low_temp"]
    fit_amps_high_temp_interp = fit_results["fit_amps_high_temp"]
    fit_amps_low_temp_interp = fit_results["fit_amps_low_temp"]
    fit_t2s_high_temp_interp = fit_results["fit_t2s_high_temp"]
    fit_t2s_low_temp_interp = fit_results["fit_t2s_low_temp"]
    try:
        fit_stds_high_temp_interp = fit_results["fit_stds_high_temp"]
        fit_stds_low_temp_interp = fit_results["fit_stds_low_temp"]
    except KeyError:
        # Some result files store the stds under '*_interp' keys instead.
        fit_stds_high_temp_interp = fit_results["fit_stds_high_temp_interp"]
        fit_stds_low_temp_interp = fit_results["fit_stds_low_temp_interp"]
    fit_params_low_temp = fit_results["fit_params_low_temp"]
    fit_params_high_temp = fit_results["fit_params_high_temp"]

    # Frequency-difference and temperature maps, masked with the animal ROI.
    freq_diff_map_high_temp_Hz, freq_diff_map_high_temp_ppm = _freq_difference_maps(fit_freqs_high_temp_interp, csi_high_temp)
    temp_map_high, _ = temperature(freq_diff_map_high_temp_ppm, calibration_type='5mM')
    freq_diff_map_high_temp_Hz = _mask_to_animal(freq_diff_map_high_temp_Hz, mask_dict_animal, bitwise_thld)
    freq_diff_map_high_temp_ppm = _mask_to_animal(freq_diff_map_high_temp_ppm, mask_dict_animal, bitwise_thld)
    temp_map_high = _mask_to_animal(temp_map_high, mask_dict_animal, bitwise_thld)

    freq_diff_map_low_temp_Hz, freq_diff_map_low_temp_ppm = _freq_difference_maps(fit_freqs_low_temp_interp, csi_low_temp)
    temp_map_low, _ = temperature(freq_diff_map_low_temp_ppm, calibration_type='5mM')
    freq_diff_map_low_temp_Hz = _mask_to_animal(freq_diff_map_low_temp_Hz, mask_dict_animal, bitwise_thld)
    freq_diff_map_low_temp_ppm = _mask_to_animal(freq_diff_map_low_temp_ppm, mask_dict_animal, bitwise_thld)
    temp_map_low = _mask_to_animal(temp_map_low, mask_dict_animal, bitwise_thld)

    # Per-pixel temperature uncertainty and the error-threshold masks.
    (fit_freqs_diff_stds_high_temp, fit_freqs_diff_stds_ppm_high_temp,
     fit_temps_stds_high_temp, temp_error_mask_high) = _temperature_error(
        temp_map_high, freq_diff_map_high_temp_ppm, fit_stds_high_temp_interp, csi_high_temp, temp_error_thld)
    (fit_freqs_diff_stds_low_temp, fit_freqs_diff_stds_ppm_low_temp,
     fit_temps_stds_low_temp, temp_error_mask_low) = _temperature_error(
        temp_map_low, freq_diff_map_low_temp_ppm, fit_stds_low_temp_interp, csi_low_temp, temp_error_thld)

    low_temp_indiv_res_freq_error = csi_low_temp.analyze_freq_maps(np.squeeze(temp_error_mask_low),
                                                                   freq_diff_map_low_temp_Hz,
                                                                   t2w_as_csi,
                                                                   csi_ref_slice,
                                                                   temperature_map=temp_map_low, colormap='jet', savepath=None)
    high_temp_indiv_res_freq_error = csi_high_temp.analyze_freq_maps(np.squeeze(temp_error_mask_high),
                                                                     freq_diff_map_high_temp_Hz,
                                                                     t2w_as_csi,
                                                                     csi_ref_slice,
                                                                     temperature_map=temp_map_high, colormap='jet', savepath=None)

    # Whole-body summary for this animal.
    output_dict_indiv_mask_error_thld = {
        'Date': [scan_date],
        'Animal': animal_ID,
        'Gender': animal_gender,
        'Weight': animal_weight,
        'High temp rect': csi_high_temp_temperature,
        'Low temp rect': csi_low_temp_temperature,
        'Mask': 'indiv',
        'Threshold': 'Error',
        'Whole body high Freq.': high_temp_indiv_res_freq_error['meaned_frq_all_pixels'],
        'Whole body low Freq.': low_temp_indiv_res_freq_error['meaned_frq_all_pixels'],
        'Whole body high Freq. std': high_temp_indiv_res_freq_error['std_frq_all_pixels'],
        'Whole body low Freq. std': low_temp_indiv_res_freq_error['std_frq_all_pixels'],
        'Whole body high Temp.': high_temp_indiv_res_freq_error['meaned_temp_all_pixels'],
        'Whole body low Temp.': low_temp_indiv_res_freq_error['meaned_temp_all_pixels'],
        'Whole body high Temp. std': high_temp_indiv_res_freq_error['std_temp_all_pixels'],
        'Whole body low Temp. std': low_temp_indiv_res_freq_error['std_temp_all_pixels'],
    }
    output_df_indiv_mask_error_thld = pd.DataFrame(data=output_dict_indiv_mask_error_thld)

    # Per-organ mean/std of each masked parameter map.
    output_dict_keys = [
        'high Freq.',
        'low Freq.',
        'high Temp.',
        'low Temp.'
    ]
    masked_param_maps_indiv_error_thld = [high_temp_indiv_res_freq_error['masked_frq_map'],
                                          low_temp_indiv_res_freq_error['masked_frq_map'],
                                          high_temp_indiv_res_freq_error['masked_temp_map'],
                                          low_temp_indiv_res_freq_error['masked_temp_map']]
    anat_data = np.squeeze(t2w_as_csi.seq2d_oriented)
    for param_key, param_map in zip(output_dict_keys, masked_param_maps_indiv_error_thld):
        meaned_params_indiv_error_thld, std_params_indiv_error_thld = ut_anat.mask_parameter_map(param_map, masks, weight_result=True)
        for ROI_name in masks:
            key_mean = str(ROI_name + ' ' + param_key)
            output_dict_indiv_mask_error_thld[key_mean] = meaned_params_indiv_error_thld[ROI_name]
            output_dict_indiv_mask_error_thld[key_mean + ' std'] = std_params_indiv_error_thld[ROI_name]

    # Figure: hot (left) and cold (right) temperature maps over the anatomy.
    anat_ref_image = np.rot90(t2w_as_csi.seq2d_oriented[0, csi_ref_slice, :, :, 0, 0])
    anat_shape = np.squeeze(t2w_as_csi.seq2d_oriented).shape
    # Extent taken from the displayed (rotated) slice itself, so it does not
    # depend on which axes np.squeeze happens to drop.
    anat_extent = [0, anat_ref_image.shape[1], anat_ref_image.shape[0], 0]

    gridspec = {'width_ratios': [1, 1, 0.1]}
    fig, ax = plt.subplots(1, 3, figsize=(12, 4), gridspec_kw=gridspec)
    panel_results = (high_temp_indiv_res_freq_error, low_temp_indiv_res_freq_error)
    panel_labels = ('high', 'low')
    for n, panel_res in enumerate(panel_results):
        ax[n].imshow(anat_ref_image, cmap="bone", extent=anat_extent)
        img = ax[n].imshow(panel_res['masked_temp_map'], cmap='jet', extent=anat_extent, alpha=0.4, vmin=30, vmax=42)

    roi_coords = ut_anat.get_roi_coords(masks)
    for n, label in enumerate(panel_labels):
        for roi_num, coords in enumerate(roi_coords):
            if len(coords) == 0:
                continue
            mask_ID = roi_names[roi_num]
            roi_mean = output_dict_indiv_mask_error_thld[str(mask_ID) + ' ' + label + ' Temp.']
            roi_std = output_dict_indiv_mask_error_thld[str(mask_ID) + ' ' + label + ' Temp. std']
            roi_temp = 'T=' + str(np.round(roi_mean, 1)) + ' +/- ' + str(np.round(roi_std, 1)) + '°C'
            roi_xy = np.squeeze(coords)
            ax[n].plot(
                roi_xy[:, 1],
                roi_xy[:, 0],
                linewidth=2,
                color="C" + str(roi_num),
                label=roi_temp,
            )
        ax[n].legend()
        ax[n].set_xticks([])
        ax[n].set_yticks([])
    ax[0].set_title(r'$T_{hot}$=(' + str(np.round(high_temp_indiv_res_freq_error['meaned_temp_all_pixels'], 1)) + r'$\pm$' + str(np.round(high_temp_indiv_res_freq_error['std_temp_all_pixels'], 1)) + ')°C')
    ax[1].set_title(r'$T_{cold}$=(' + str(np.round(low_temp_indiv_res_freq_error['meaned_temp_all_pixels'], 1)) + r'$\pm$' + str(np.round(low_temp_indiv_res_freq_error['std_temp_all_pixels'], 1)) + ')°C')

    cbar = fig.colorbar(img, cax=ax[2], orientation="vertical", label='T [°C]')
    cbar.solids.set(alpha=1)
    fig.suptitle('Rectal temperature: ' + str(output_dict_indiv_mask_error_thld['High temp rect']) + '°C vs. ' + str(output_dict_indiv_mask_error_thld['Low temp rect']) + '°C')

    # Collect per-animal outputs together, so the lists stay aligned with each other.
    # Both maps are appended after masking; the low map used to be appended unmasked.
    temperature_maps.append(temp_map_high)
    temperature_maps.append(temp_map_low)
    organ_rois.append(masks)
    anat_refs.append(anat_ref_image)

    final_df = pd.DataFrame(data=output_dict_indiv_mask_error_thld)
    animal_result_dfs.append(final_df)
    processed_animal_IDs.append(animal_ID)

    print('Animal ', animal_ID, 'finished')
    print('----------------------')
    print('----------------------')
    print('----------------------')

# All animals in one table (final_df only holds the last processed animal).
all_animals_df = (pd.concat(animal_result_dfs, ignore_index=True)
                  if animal_result_dfs else pd.DataFrame())