# This notebook analyses mouse data from hyperpolarized [1-$^{13}$C]pyruvate injections

**Authors: Wolfgang Gottwald, Luca Nagel (2024)**

# 1. Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib
# Serif/LaTeX text rendering for publication figures; text.usetex=True requires a working LaTeX installation.
matplotlib.rcParams.update({'font.size': 12,'font.family':'serif','font.serif':['Computer Modern'],"text.usetex" : True,})

import ipywidgets as widgets
import datetime
import pandas as pd
import os
import nmrglue
import hypermri
import hypermri.utils.utils_anatomical as ut_anat
import hypermri.utils.utils_spectroscopy as ut_spec
import hypermri.utils.utils_fitting as ut_fitting
import hypermri.utils.utils_general as utg
from scipy.optimize import curve_fit
from matplotlib import cm
from astropy.modeling import models, fitting
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
    
import sys
# Get the current working directory
cwd = os.getcwd()
# Make the directory two levels above the notebook folder importable
parent_dir = os.path.dirname(os.path.dirname(cwd))
sys.path.append(parent_dir)
# define paths:
sys.path.append('../')
import TEMPLATE
# get paths (repository, data base path, save path, publication figure path):
repopath, basepath, savepath,publication_path = TEMPLATE.import_all_packages(True)
# Fit results and ROI masks of this analysis are saved to / loaded from this sub-folder
savepath = os.path.join(savepath,'AnimalCSIResults')
from mpl_interactions import image_segmenter_overlayed


# Autoreload extension so that you don't have to restart the kernel every time something is changed in the hypermri or magritek folders
%load_ext autoreload
%autoreload 2

# Interactive (ipympl) matplotlib backend, needed for the ROI drawing widgets below
%matplotlib widget


#### Import fitting stuff

In [None]:
from hypermri.utils.utils_spectroscopy import find_npeaks as ut_find_npeaks
from hypermri.utils.utils_fitting import temperature_from_frequency

### Frequency to temperature calibration functions obtained from calibration measurements

In [None]:
from hypermri.utils.utils_fitting import temperature_from_frequency as temperature

# 2. Manually Input information about the experiment

In [None]:
# from docu sheet: reference temperatures (°C) noted for the two CSI scans
csi_high_temp_temperature = 38.3
csi_low_temp_temperature = 34.5

# animal details (fill in from the documentation sheet)
animal_ID = ''
animal_gender = ''
# Body weight in g. Placeholder is None until filled in; the previous line
# `animal_weight =  #g` had no value and was a SyntaxError.
animal_weight = None  # g
scan_date = ''
scan_date = ''

# 3. Load/reconstruct Scans

In [None]:
# Bruker study directory of this animal: append the study folder name below.
dirpath = basepath + ''
# Index all scans of the study (quietly).
scans = hypermri.BrukerDir(dirpath, verbose=False)


### Select scans and perform reconstruction

In [None]:
# Scan numbers within the Bruker study (adapt for each animal).
# CSI acquisitions at the higher / lower reference temperature:
csi_high_temp, csi_low_temp = scans[13], scans[22]

# anatomical images
coronal, axial, sagittal = scans[27], scans[15], scans[16]
# T2w image co-registered with the CSI
t2w_as_csi = scans[14]

# Reconstruct both CSI scans; full_reco also shifts the anatomical image.
for csi_scan in (csi_high_temp, csi_low_temp):
    csi_scan.full_reco(anatomical=t2w_as_csi)

# Folder that the CSI objects store figures and other outputs in.
for csi_scan in (csi_high_temp, csi_low_temp):
    csi_scan.savepath = basepath

# Plotting extents of the anatomical reference and of the CSI.
ax_ext, _, _ = utg.get_extent(data_obj=t2w_as_csi)
csi_ext, _, _ = utg.get_extent(data_obj=csi_high_temp)

# 4. Interpolate the CSI data (2x in-plane)

In [None]:
# Upsampling factor applied to both in-plane CSI dimensions (axes 2 and 3).
interpolation_factor = 2

In [None]:
# Cubic upsampling of the high-temperature CSI: axes 2 and 3 are multiplied by
# `interpolation_factor`; axes 0, 1, 4 and 5 keep their size.
_csi_shape = csi_high_temp.csi_image.shape
csi_high_temp_interpolated = utg.interpolate_dataset(
    input_data=csi_high_temp.csi_image,
    interp_size=tuple(
        _csi_shape[axis] * interpolation_factor if axis in (2, 3) else _csi_shape[axis]
        for axis in range(6)
    ),
    data_obj=csi_high_temp,
    interp_method="cubic",
)

In [None]:
# Cubic upsampling of the low-temperature CSI: axes 2 and 3 are multiplied by
# `interpolation_factor`; axes 0, 1, 4 and 5 keep their size.
_csi_shape = csi_low_temp.csi_image.shape
csi_low_temp_interpolated = utg.interpolate_dataset(
    input_data=csi_low_temp.csi_image,
    interp_size=tuple(
        _csi_shape[axis] * interpolation_factor if axis in (2, 3) else _csi_shape[axis]
        for axis in range(6)
    ),
    data_obj=csi_low_temp,
    interp_method="cubic",
)
                                               


### Quick look at the data

In [None]:
plt.close('all')
# Interactive 2D view of the interpolated high-temperature CSI on the T2w image.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), tight_layout=True)
ax1, ax2 = axes
csi_high_temp.plot2d(
    axlist=(ax1, ax2),
    axial_image=t2w_as_csi,
    csi_data=csi_high_temp_interpolated,
    fig=fig,
)

In [None]:
plt.close('all')
# Interactive 2D view of the interpolated low-temperature CSI on the T2w image.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), tight_layout=True)
ax1, ax2 = axes
csi_low_temp.plot2d(
    axlist=(ax1, ax2),
    axial_image=t2w_as_csi,
    csi_data=csi_low_temp_interpolated,
    fig=fig,
)

# 5. Load or fit function to spectra


In [None]:
# Load the fit results written by the "Save fit results" cell below.
filename = animal_ID + '_fit_spectra_interp.pkl'
fit_results = utg.load_as_pkl(dir_path=savepath, filename=filename)

# (high temperature, low temperature) pairs of every fitted quantity:
# fitted spectra
fit_spectrums_high_temp_interp, fit_spectrums_low_temp_interp = (
    fit_results["fit_spectrums_high_temp"], fit_results["fit_spectrums_low_temp"])
# peak frequencies
fit_freqs_high_temp_interp, fit_freqs_low_temp_interp = (
    fit_results["fit_freqs_high_temp"], fit_results["fit_freqs_low_temp"])
# amplitudes
fit_amps_high_temp_interp, fit_amps_low_temp_interp = (
    fit_results["fit_amps_high_temp"], fit_results["fit_amps_low_temp"])
# T2* values
fit_t2s_high_temp_interp, fit_t2s_low_temp_interp = (
    fit_results["fit_t2s_high_temp"], fit_results["fit_t2s_low_temp"])
# fit uncertainties
fit_stds_high_temp_interp, fit_stds_low_temp_interp = (
    fit_results["fit_stds_high_temp"], fit_results["fit_stds_low_temp"])

# parameters the fits were run with
fit_params_low_temp, fit_params_high_temp = (
    fit_results["fit_params_low_temp"], fit_results["fit_params_high_temp"])

print('Loaded ', savepath + '/' + filename)


## If loading was successful, skip to section 6.

### Fitting: High temperature scan

#### Initialize fit params

In [None]:
# Metabolites in the spectral model; their order defines the metabolite axis of
# the fit results (0 = pyruvate, 1 = lactate, ...).
metabs = ['pyruvate', 'lactate', 'alanine', 'pyruvatehydrate', 'urea']

niter = 1         # number of iterations
npoints = 21      # number of tested points per iteration
rep_fitting = 11  # presumably the repetition that is fitted -- TODO confirm

# User-defined fit settings; def_fit_params fills in the remaining parameters
# from the high-temperature scan.
fit_params = ut_fitting.def_fit_params(
    fit_params={
        "cut_off": 0,
        "rep_fitting": rep_fitting,
        "zoomfactor": 1.5,
        "fit_range_repetitions": 1,
        "metab_t2s": 0.01,
        "max_t2_s": 0.05,
        "range_t2s_s": 0.2,
        "init_fit": False,
        "range_freqs_Hz": 45,
        "metabs": metabs,
        "niter": niter,
        "npoints": npoints,
        "use_all_cores": True,
    },
    data_obj=csi_high_temp,
)

#### Perform fitting

In [None]:
# Voxel-wise pseudo-inverse fit of the interpolated high-temperature CSI.
(fit_spectrums_high_temp_interp,
 fit_amps_high_temp_interp,
 fit_freqs_high_temp_interp,
 fit_t2s_high_temp_interp,
 fit_stds_high_temp_interp) = ut_fitting.fit_data_pseudo_inv(
    input_data=csi_high_temp_interpolated,
    fit_params=fit_params,
    data_obj=csi_high_temp,
    use_multiprocessing=True,
    dbmode=False,
)

In [None]:
plt.close('all')
fig, (ax1, ax2) = plt.subplots(1, 2, tight_layout=True, figsize=(8, 2.5))

# Fitted frequency maps of the first two metabolites in `metabs` (pyruvate, lactate).
freq_images = []
for axis, metab_idx, title in ((ax1, 0, 'Pyr freq'), (ax2, 1, 'Lac freq')):
    freq_images.append(axis.imshow(np.rot90(fit_freqs_high_temp_interp[0, 0, :, :, 0, 0, metab_idx])))
    axis.set_title(title)
im1, im2 = freq_images

fig.colorbar(im1, ax=ax1, label='f[Hz]')
fig.colorbar(im2, ax=ax2, label='f[Hz]')


#### Low temperature scan

In [None]:
# Fit parameters for the low-temperature scan (same user settings as for the
# high-temperature scan).
# define metabolites (order defines the metabolite axis of the fit results):
metabs = ['pyruvate', 'lactate', 'alanine', 'pyruvatehydrate', 'urea']

niter = 1  # number of iterations
npoints = 21  # number of tested points per iteration
rep_fitting = 11  # presumably the repetition that is fitted -- TODO confirm


# define fit parameters:
fit_params = {}

fit_params["cut_off"] = 0
fit_params["rep_fitting"] = rep_fitting
fit_params["zoomfactor"] = 1.5
fit_params["fit_range_repetitions"] = 1
fit_params["metab_t2s"] = 0.01
fit_params["max_t2_s"] = 0.05
fit_params["range_t2s_s"] = 0.2
fit_params["init_fit"] = False
fit_params["range_freqs_Hz"] = 45

# define peak frequencies:
fit_params["metabs"] = metabs
fit_params["niter"] = niter
fit_params["npoints"] = npoints
fit_params["use_all_cores"] = True
# Fill missing parameters from the LOW-temperature scan. This previously passed
# csi_high_temp (copied from the high-temperature cell), so every default
# derived from the scan object came from the wrong acquisition.
fit_params = ut_fitting.def_fit_params(fit_params=fit_params,
                                       data_obj=csi_low_temp)

In [None]:
# Voxel-wise pseudo-inverse fit of the interpolated low-temperature CSI.
# NOTE(review): unlike the high-temperature call, `use_all_cores=True` is also
# passed as a keyword here (it is already part of fit_params); kept unchanged.
(fit_spectrums_low_temp_interp,
 fit_amps_low_temp_interp,
 fit_freqs_low_temp_interp,
 fit_t2s_low_temp_interp,
 fit_stds_low_temp_interp) = ut_fitting.fit_data_pseudo_inv(
    input_data=csi_low_temp_interpolated,
    fit_params=fit_params,
    data_obj=csi_low_temp,
    use_multiprocessing=True,
    dbmode=False,
    use_all_cores=True,
)


In [None]:
plt.close('all')
fig, (ax1, ax2) = plt.subplots(1, 2, tight_layout=True, figsize=(8, 2.5))

# Fitted frequency maps of the first two metabolites in `metabs` (pyruvate, lactate).
freq_images = []
for axis, metab_idx, title in ((ax1, 0, 'Pyr freq'), (ax2, 1, 'Lac freq')):
    freq_images.append(axis.imshow(np.rot90(fit_freqs_low_temp_interp[0, 0, :, :, 0, 0, metab_idx])))
    axis.set_title(title)
im1, im2 = freq_images

fig.colorbar(im1, ax=ax1, label='f[Hz]')
fig.colorbar(im2, ax=ax2, label='f[Hz]')


### Save fit results for faster processing later

In [None]:
# Collect all fit outputs plus experiment metadata so the fitting can be skipped
# later. The keys must include everything the loading cell in section 5 reads;
# previously "fit_stds_*" and "fit_params_*" were missing, so reloading a saved
# file raised KeyError.
fit_results = {
    "fit_spectrums_high_temp": fit_spectrums_high_temp_interp,
    "fit_spectrums_low_temp": fit_spectrums_low_temp_interp,
    "fit_freqs_high_temp": fit_freqs_high_temp_interp,
    "fit_freqs_low_temp": fit_freqs_low_temp_interp,
    "fit_amps_high_temp": fit_amps_high_temp_interp,
    "fit_amps_low_temp": fit_amps_low_temp_interp,
    "fit_t2s_high_temp": fit_t2s_high_temp_interp,
    "fit_t2s_low_temp": fit_t2s_low_temp_interp,
    # full fit uncertainties, needed for the temperature error maps (section 8.4)
    "fit_stds_high_temp": fit_stds_high_temp_interp,
    "fit_stds_low_temp": fit_stds_low_temp_interp,
    # `fit_params` is redefined by each fit-parameter cell. When the notebook is
    # run top to bottom it holds the low-temperature parameters at this point.
    # Both cells use identical user settings, so the same dict is stored under
    # every key the loader expects.
    "fit_params": fit_params,
    "fit_params_high_temp": fit_params,
    "fit_params_low_temp": fit_params,
    # frequency uncertainties only (index 1 of the last fit_stds axis), as stored previously
    "frequency_error_low_temp": fit_stds_low_temp_interp[:, :, :, :, :, :, :, 1],
    "frequency_error_high_temp": fit_stds_high_temp_interp[:, :, :, :, :, :, :, 1],
    "csi_high_temp_temperature": csi_high_temp_temperature,
    "csi_low_temp_temperature": csi_low_temp_temperature,
    "animal_ID": animal_ID,
    "animal_gender": animal_gender,
    "animal_weight": animal_weight,
    "scan_date": scan_date,
}


import pickle

utg.save_as_pkl(dir_path=savepath,
                filename=animal_ID + '_fit_spectra_interp',
                file=fit_results,
                file_keys=fit_results.keys(),
                use_timestamp=False)
                

# 6. Plot fit results on images

In [None]:
plt.close('all')
fig, axes = plt.subplots(1, 2, figsize=(12, 5), tight_layout=True)
ax1, ax2 = axes
# Fitted spectra summed over the last axis (presumably the metabolite
# components), keeping that axis as a singleton.
summed_fit_high_temp = np.sum(fit_spectrums_high_temp_interp, axis=-1, keepdims=True)
csi_high_temp.plot2d(
    axlist=(ax1, ax2),
    axial_image=t2w_as_csi,
    csi_data=csi_high_temp_interpolated,
    csi_fit_data=summed_fit_high_temp,
    fig=fig,
)

In [None]:
plt.close('all')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), tight_layout=True)
# Overlay the low-temperature data with the sum of its fitted spectra.
# Use the low-temperature scan object for plotting its own data, as in the quick
# look of section 4. This previously called csi_high_temp.plot2d, i.e. the
# metadata of the other acquisition.
csi_low_temp.plot2d(axlist=(ax1, ax2),
                    axial_image=t2w_as_csi,
                    csi_data=csi_low_temp_interpolated,
                    csi_fit_data=np.sum(fit_spectrums_low_temp_interp, axis=-1, keepdims=True),
                    fig=fig)

# 7. Find background noise region by drawing an ROI around the animal in a T2w scan

In [None]:
# Slice index of t2w_as_csi used as the anatomical background in the figures below.
csi_ref_slice = 3

## 8.2 Draw animal ROI for masking of outside pixels
#### Load animal ROI mask if present
#### If loading below was successful, skip drawing and retrieving the mask and continue with 8.4
We combine the mask across all slices (a pixel counts as inside the animal if it is inside in any slice), since the CSI slice is much thicker than the anatomical slices.
That way we can more accurately determine which CSI signals originate from within the animal.

In [None]:
# Best effort: load a previously drawn animal ROI. If none is available, draw it
# in the next cells instead.
mask_filename = animal_ID + '_animal_mask.pkl'
try:
    loaded_mask = utg.load_as_pkl(savepath, filename=mask_filename)
except Exception as err:  # narrower than the former bare `except:`; Ctrl-C still interrupts
    print('could not load ' + mask_filename + ': ' + repr(err))
    loaded_mask = None

if not isinstance(loaded_mask, dict) or 'animal' not in loaded_mask:
    print('no animal mask found in ' + savepath)
else:
    print('Loaded mask', mask_filename)
    # Combine the per-slice masks (a pixel belongs to the animal if it does in any
    # slice, axis 1) and restore the slice axis with size 1, because the CSI slice
    # is much thicker than the anatomical slices.
    mask_dict_animal = {'animal': np.expand_dims(np.sum(loaded_mask['animal'], axis=1) > 0, axis=1)}
    fig, ax = plt.subplots(1)
    ax.imshow(np.squeeze(mask_dict_animal['animal']))


### Draw mask around animal in all slices if you did not load one above

In [None]:
# Interactive segmenter on the CSI-geometry T2w image: a single ROI, no overlay
# (empty overlay image), images neither rotated nor flipped.
anatomical_images = t2w_as_csi.seq2d_oriented
segmenter_list_animal = ut_anat.get_segmenter_list_overlayed(
    anatomical_images,
    np.zeros_like(anatomical_images),
    n_rois=1,
    figsize=(6, 6),
    overlay=0.0,
    bssfp_cmap='magma',
    rot_images_deg=0,  # necessary for coronal images (so far, could change after an update)
    flip_images=False,
    masks_drawn_on="axial",
)
# draw the "animal" mask on the axial slices:
ut_anat.draw_masks_on_anatomical(segmenter_list_animal, ['animal'])

### Retrieve mask, save it and average it

In [None]:
# Read the drawn ROI back from the segmenters (one mask per slice).
mask_dict_animal = ut_anat.get_masks(segmenter_list_animal,
                                     plot_res=False,
                                     roi_keys=['animal'],
                                     masks_drawn_on='axial')

# Save the per-slice mask; the loading cell combines the slices itself.
utg.save_as_pkl(dir_path=savepath,
                filename=animal_ID + '_animal_mask',
                file=mask_dict_animal,
                use_timestamp=False)

# Combine the mask over all slices (axis 1) and restore that axis with size 1,
# exactly as done when loading the mask. The combined mask replaces
# mask_dict_animal because all following cells mask with mask_dict_animal and
# mask_slice_ind=0. Previously they used only the first drawn slice here, while
# the combined mask was computed but never used.
mask_dict_animal_avg = {'animal': np.expand_dims(np.sum(mask_dict_animal['animal'], axis=1) > 0, axis=1)}
mask_dict_animal = mask_dict_animal_avg

## 8.4 Threshold data according to fit accuracy. 
#### 1. Exclude pixels whose temperature fit error exceeds `temp_error_thld` (0.5 °C)

In [None]:
# lower threshold passed as bitwise_lower_threshold when applying the animal ROI
bitwise_thld = 0.5
# maximum accepted temperature uncertainty in °C
temp_error_thld = 0.5

In [None]:
# Lactate-pyruvate frequency difference and temperature map of the high-temperature scan.
# Squeeze the frequency array once instead of once per pixel; after squeezing,
# the last axis indexes the metabolites (0 = pyruvate, 1 = lactate, cf. `metabs`).
freqs_high_temp = np.squeeze(fit_freqs_high_temp_interp)
pyr_freqs_high_temp_Hz = freqs_high_temp[:, :, 0]
lac_freqs_high_temp_Hz = freqs_high_temp[:, :, 1]

# absolute frequency difference in Hz
freq_diff_map_high_temp_Hz = np.abs(pyr_freqs_high_temp_Hz - lac_freqs_high_temp_Hz)

# Signed difference lactate - pyruvate in ppm, converted pixel by pixel.
freq_diff_map_high_temp_ppm = np.zeros_like(freq_diff_map_high_temp_Hz)
for n in range(freq_diff_map_high_temp_Hz.shape[0]):
    for m in range(freq_diff_map_high_temp_Hz.shape[1]):
        lac_ppm = np.squeeze(ut_spec.freq_Hz_to_ppm(freq_Hz=lac_freqs_high_temp_Hz[n, m], data_obj=csi_high_temp))
        pyr_ppm = np.squeeze(ut_spec.freq_Hz_to_ppm(freq_Hz=pyr_freqs_high_temp_Hz[n, m], data_obj=csi_high_temp))
        freq_diff_map_high_temp_ppm[n, m] = lac_ppm - pyr_ppm

# Calculate the temperature map from the ppm difference (5 mM calibration).
temp_map_high, _ = temperature(freq_diff_map_high_temp_ppm, calibration_type='5mM')

# Mask frequency and temperature maps with the animal ROI, using the
# `bitwise_thld` defined above instead of a repeated literal 0.50.
animal_mask_kwargs = dict(return_nans=True, provided_dims=(2, 3), mask_slice_ind=0,
                          bitwise=True, bitwise_lower_threshold=bitwise_thld)
freq_diff_map_high_temp_Hz = np.squeeze(utg.apply_mask(freq_diff_map_high_temp_Hz, mask_dict_animal, 'animal', **animal_mask_kwargs))
freq_diff_map_high_temp_ppm = np.squeeze(utg.apply_mask(freq_diff_map_high_temp_ppm, mask_dict_animal, 'animal', **animal_mask_kwargs))
temp_map_high = np.squeeze(utg.apply_mask(temp_map_high, mask_dict_animal, 'animal', **animal_mask_kwargs))

In [None]:
# Lactate-pyruvate frequency difference and temperature map of the low-temperature scan.
# Squeeze the frequency array once instead of once per pixel; after squeezing,
# the last axis indexes the metabolites (0 = pyruvate, 1 = lactate, cf. `metabs`).
freqs_low_temp = np.squeeze(fit_freqs_low_temp_interp)
pyr_freqs_low_temp_Hz = freqs_low_temp[:, :, 0]
lac_freqs_low_temp_Hz = freqs_low_temp[:, :, 1]

# absolute frequency difference in Hz
freq_diff_map_low_temp_Hz = np.abs(pyr_freqs_low_temp_Hz - lac_freqs_low_temp_Hz)

# Signed difference lactate - pyruvate in ppm, converted pixel by pixel.
freq_diff_map_low_temp_ppm = np.zeros_like(freq_diff_map_low_temp_Hz)
for n in range(freq_diff_map_low_temp_Hz.shape[0]):
    for m in range(freq_diff_map_low_temp_Hz.shape[1]):
        lac_ppm = np.squeeze(ut_spec.freq_Hz_to_ppm(freq_Hz=lac_freqs_low_temp_Hz[n, m], data_obj=csi_low_temp))
        pyr_ppm = np.squeeze(ut_spec.freq_Hz_to_ppm(freq_Hz=pyr_freqs_low_temp_Hz[n, m], data_obj=csi_low_temp))
        freq_diff_map_low_temp_ppm[n, m] = lac_ppm - pyr_ppm

# Calculate the temperature map from the ppm difference (5 mM calibration).
temp_map_low, _ = temperature(freq_diff_map_low_temp_ppm, calibration_type='5mM')

# Mask frequency and temperature maps with the animal ROI, using the
# `bitwise_thld` defined above instead of a repeated literal 0.50.
animal_mask_kwargs = dict(return_nans=True, provided_dims=(2, 3), mask_slice_ind=0,
                          bitwise=True, bitwise_lower_threshold=bitwise_thld)
freq_diff_map_low_temp_Hz = np.squeeze(utg.apply_mask(freq_diff_map_low_temp_Hz, mask_dict_animal, 'animal', **animal_mask_kwargs))
freq_diff_map_low_temp_ppm = np.squeeze(utg.apply_mask(freq_diff_map_low_temp_ppm, mask_dict_animal, 'animal', **animal_mask_kwargs))
temp_map_low = np.squeeze(utg.apply_mask(temp_map_low, mask_dict_animal, 'animal', **animal_mask_kwargs))

### Hot Scan

In [None]:
plt.close('all')
# Fit uncertainties of the pyruvate (metabolite 0) and lactate (metabolite 1)
# frequencies in Hz. Index 1 of the last axis is the frequency uncertainty;
# the same slice is saved as "frequency_error_*".
pyr_freq_std_high_temp = fit_stds_high_temp_interp[0, 0, :, :, 0, 0, 0, 1]
lac_freq_std_high_temp = fit_stds_high_temp_interp[0, 0, :, :, 0, 0, 1, 1]
# Error of the lactate-pyruvate frequency difference (quadrature sum) in Hz
fit_freqs_diff_stds_high_temp = np.sqrt(np.abs(pyr_freq_std_high_temp**2) +
                                        np.abs(lac_freq_std_high_temp**2))

# the same uncertainty of the frequency difference in ppm
fit_freqs_diff_stds_ppm_high_temp = ut_spec.freq_Hz_to_ppm(fit_freqs_diff_stds_high_temp,
                                                           data_obj=csi_high_temp,
                                                           ppm_centered_at_0=True)

# Temperature uncertainty is the change of the calibrated temperature when the
# ppm difference is shifted by its uncertainty.
# - `temperature` returns a 2-tuple (see the temperature-map cell), so unpack it
#   explicitly instead of broadcasting the whole tuple against the map.
# - abs() keeps the uncertainty non-negative whatever the sign of the
#   calibration slope; a negative value would always pass the threshold below.
temp_map_high_shifted, _ = temperature(freq_diff_map_high_temp_ppm + fit_freqs_diff_stds_ppm_high_temp,
                                       calibration_type='5mM')
fit_temps_stds_high_temp = np.abs(np.squeeze(temp_map_high - np.asarray(temp_map_high_shifted)))

# Keep only pixels whose temperature uncertainty is <= temp_error_thld (others -> NaN).
temp_error_mask = np.where(fit_temps_stds_high_temp <= temp_error_thld, 1, np.nan)
temp_map_high = temp_map_high * temp_error_mask
dT = np.squeeze(utg.apply_mask(temp_error_mask * fit_temps_stds_high_temp, mask_dict_animal, 'animal',
                               return_nans=True, provided_dims=(2, 3), mask_slice_ind=0,
                               bitwise=True, bitwise_lower_threshold=bitwise_thld))


fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, tight_layout=True, figsize=(10, 8))

im1 = ax1.imshow(np.rot90(np.real(np.squeeze(pyr_freq_std_high_temp))), vmin=0, vmax=5, cmap='jet')
ax1.set_title('Pyr sigma freq')

im2 = ax2.imshow(np.rot90(np.real(np.squeeze(lac_freq_std_high_temp))), vmin=0, vmax=5, cmap='jet')
ax2.set_title('Lac sigma freq')

im3 = ax3.imshow(np.rot90(fit_freqs_diff_stds_high_temp), cmap='jet', vmin=0, vmax=5)
# raw strings: '\D' is an invalid escape sequence in a normal string literal
ax3.set_title(r'$\Delta f$')

im4 = ax4.imshow(np.rot90(dT), cmap='jet', vmin=0, vmax=2 * temp_error_thld)
ax4.set_title(r'$\Delta T<$' + str(temp_error_thld) + '°C+animal ROI')

im5 = ax5.imshow(np.rot90(temp_map_high), vmin=30, vmax=42, cmap='jet')

ax6.axis('off')

fig.colorbar(im1, ax=ax1, label='Hz', shrink=0.65)
fig.colorbar(im2, ax=ax2, label='Hz', shrink=0.65)
fig.colorbar(im3, ax=ax3, label='Hz', shrink=0.65)
fig.colorbar(im4, ax=ax4, label='°C', shrink=0.65)
fig.colorbar(im5, ax=ax5, label='°C', shrink=0.65)


# Temperature map on the anatomical reference slice (csi_ref_slice, was a
# hard-coded 3) and histogram of the remaining pixel temperatures.
fig, ax = plt.subplots(1, 2, tight_layout=True, figsize=(8, 4))
ax[0].imshow(np.rot90(np.squeeze(t2w_as_csi.seq2d_oriented)[csi_ref_slice]), extent=ax_ext, cmap='bone')
ax[0].imshow(np.rot90(temp_map_high), vmin=30, vmax=42, cmap='jet', alpha=0.4, extent=ax_ext)

ax[1].hist(np.ravel(temp_map_high)[~np.isnan(np.ravel(temp_map_high))], bins=15)
ax[0].set_title('Temp map')

ax[1].set_xlabel('T [°C]')
ax[1].set_ylabel('Num of pixels')



### Cold scan

In [None]:
plt.close('all')
# Fit uncertainties of the pyruvate (metabolite 0) and lactate (metabolite 1)
# frequencies in Hz. Index 1 of the last axis is the frequency uncertainty;
# the same slice is saved as "frequency_error_*".
pyr_freq_std_low_temp = fit_stds_low_temp_interp[0, 0, :, :, 0, 0, 0, 1]
lac_freq_std_low_temp = fit_stds_low_temp_interp[0, 0, :, :, 0, 0, 1, 1]
# Error of the lactate-pyruvate frequency difference (quadrature sum) in Hz
fit_freqs_diff_stds_low_temp = np.sqrt(np.abs(pyr_freq_std_low_temp**2) +
                                       np.abs(lac_freq_std_low_temp**2))

# the same uncertainty of the frequency difference in ppm
fit_freqs_diff_stds_ppm_low_temp = ut_spec.freq_Hz_to_ppm(fit_freqs_diff_stds_low_temp,
                                                          data_obj=csi_low_temp,
                                                          ppm_centered_at_0=True)

# Temperature uncertainty is the change of the calibrated temperature when the
# ppm difference is shifted by its uncertainty.
# - `temperature` returns a 2-tuple (see the temperature-map cell), so unpack it
#   explicitly instead of broadcasting the whole tuple against the map.
# - abs() keeps the uncertainty non-negative whatever the sign of the
#   calibration slope; a negative value would always pass the threshold below.
temp_map_low_shifted, _ = temperature(freq_diff_map_low_temp_ppm + fit_freqs_diff_stds_ppm_low_temp,
                                      calibration_type='5mM')
fit_temps_stds_low_temp = np.abs(np.squeeze(temp_map_low - np.asarray(temp_map_low_shifted)))

# Keep only pixels whose temperature uncertainty is <= temp_error_thld (others -> NaN).
temp_error_mask = np.where(fit_temps_stds_low_temp <= temp_error_thld, 1, np.nan)
temp_map_low = temp_map_low * temp_error_mask
dT = np.squeeze(utg.apply_mask(temp_error_mask * fit_temps_stds_low_temp, mask_dict_animal, 'animal',
                               return_nans=True, provided_dims=(2, 3), mask_slice_ind=0,
                               bitwise=True, bitwise_lower_threshold=bitwise_thld))


fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, tight_layout=True, figsize=(10, 8))

im1 = ax1.imshow(np.rot90(np.real(np.squeeze(pyr_freq_std_low_temp))), vmin=0, vmax=5, cmap='jet')
ax1.set_title('Pyr sigma freq')

im2 = ax2.imshow(np.rot90(np.real(np.squeeze(lac_freq_std_low_temp))), vmin=0, vmax=5, cmap='jet')
ax2.set_title('Lac sigma freq')

im3 = ax3.imshow(np.rot90(fit_freqs_diff_stds_low_temp), cmap='jet', vmin=0, vmax=5)
# raw strings: '\D' is an invalid escape sequence in a normal string literal
ax3.set_title(r'$\Delta f$')

im4 = ax4.imshow(np.rot90(dT), cmap='jet', vmin=0, vmax=2 * temp_error_thld)
ax4.set_title(r'$\Delta T<$' + str(temp_error_thld) + '°C+animal ROI')

im5 = ax5.imshow(np.rot90(temp_map_low), vmin=30, vmax=42, cmap='jet')

ax6.axis('off')

fig.colorbar(im1, ax=ax1, label='Hz', shrink=0.65)
fig.colorbar(im2, ax=ax2, label='Hz', shrink=0.65)
fig.colorbar(im3, ax=ax3, label='Hz', shrink=0.65)
fig.colorbar(im4, ax=ax4, label='°C', shrink=0.65)
fig.colorbar(im5, ax=ax5, label='°C', shrink=0.65)


# Temperature map on the anatomical reference slice (csi_ref_slice, was a
# hard-coded 3) and histogram of the remaining pixel temperatures.
fig, ax = plt.subplots(1, 2, tight_layout=True, figsize=(8, 4))
ax[0].imshow(np.rot90(np.squeeze(t2w_as_csi.seq2d_oriented)[csi_ref_slice]), extent=ax_ext, cmap='bone')
ax[0].imshow(np.rot90(temp_map_low), vmin=30, vmax=42, cmap='jet', alpha=0.4, extent=ax_ext)

ax[1].hist(np.ravel(temp_map_low)[~np.isnan(np.ravel(temp_map_low))], bins=15)
ax[0].set_title('Temp map')

ax[1].set_xlabel('T [°C]')
ax[1].set_ylabel('Num of pixels')



In [None]:
# Width of the publication figures in inches; panels use fractions of it.
fig_width = 6.9

In [None]:
# Publication panel: anatomical reference slice in the CSI plotting extent.
panel_size = (fig_width / 3, fig_width / 3)
fig, ax = plt.subplots(1, figsize=panel_size, tight_layout=True)

anatomical_slice = t2w_as_csi.seq2d_oriented[0, csi_ref_slice, :, :, 0, 0]
ax.imshow(
    np.rot90(anatomical_slice),
    cmap='bone',
    extent=hypermri.utils.utils_general.get_plotting_extent(data_obj=csi_high_temp),
)
ax.axis('off')
plt.savefig(publication_path + '/03_Figure_csi_example/03_Figure_anatomical.svg')

In [None]:
# Publication panel: pyruvate amplitude map of the low-temperature scan, shown
# as |amplitude / std_noise - 1|.
# NOTE(review): `std_noise` is not defined anywhere in this notebook (section 7
# only sets csi_ref_slice), so this cell raises NameError unless it is defined
# elsewhere. Presumably it is the standard deviation of a background-noise ROI -- TODO confirm.
fig,ax=plt.subplots(1,figsize=(fig_width/3,fig_width/3),tight_layout=True)

# metabolite index 0 = pyruvate; plotted in the extent of the high-temperature CSI
im1=ax.imshow(np.rot90(np.abs(fit_amps_low_temp_interp[0,0,:,:,0,0,0]/std_noise-1)), cmap='magma',
                     extent=hypermri.utils.utils_general.get_plotting_extent(data_obj=csi_high_temp))
ax.axis('off')
# colorbar in a narrow axis attached to the right of the image
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
cbar=fig.colorbar(im1, cax=cax, orientation='vertical', label='I [a.u.]')
# fixed ticks, presumably chosen for this data set -- adapt per animal
cbar.set_ticks([0,20,40])
# label at data coordinates (0, 5)
ax.text(0,5,'Pyruvate',color='w')
plt.savefig(publication_path+'/03_Figure_csi_example/03_Figure_pyruvate.svg')

In [None]:
# Publication panel: lactate amplitude map of the low-temperature scan, shown
# as |amplitude / std_noise - 1|.
# NOTE(review): `std_noise` is not defined anywhere in this notebook (section 7
# only sets csi_ref_slice), so this cell raises NameError unless it is defined
# elsewhere. Presumably it is the standard deviation of a background-noise ROI -- TODO confirm.
fig,ax=plt.subplots(1,figsize=(fig_width/3,fig_width/3),tight_layout=True)

# metabolite index 1 = lactate; plotted in the extent of the high-temperature CSI
im1=ax.imshow(np.rot90(np.abs(fit_amps_low_temp_interp[0,0,:,:,0,0,1]/std_noise-1)), cmap='magma',
                     extent=hypermri.utils.utils_general.get_plotting_extent(data_obj=csi_high_temp))
ax.axis('off')
# colorbar in a narrow axis attached to the right of the image
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
cbar=fig.colorbar(im1, cax=cax, orientation='vertical', label='I [a.u.]')
# fixed ticks, presumably chosen for this data set -- adapt per animal
cbar.set_ticks([0,7.5,15])
# label at data coordinates (0, 5)
ax.text(0,5,'Lactate',color='w')
plt.savefig(publication_path+'/03_Figure_csi_example/03_Figure_lactate.svg')

In [None]:
plt.close('all')
# Publication panels: masked temperature map (opaque) over the anatomical
# reference slice, one figure each for the hot and the cold scan.
csi_plot_extent = hypermri.utils.utils_general.get_plotting_extent(data_obj=csi_high_temp)
temperature_ticks = [30, 32, 34, 36, 38, 40, 42]

for temp_map, scan_label in ((temp_map_high, 'hot'), (temp_map_low, 'cold')):
    fig, ax = plt.subplots(1, figsize=(fig_width / 2, fig_width / 2), tight_layout=True)
    ax.imshow(np.rot90(t2w_as_csi.seq2d_oriented[0, csi_ref_slice, :, :, 0, 0]),
              cmap='bone', extent=csi_plot_extent)
    img = ax.imshow(np.rot90(temp_map), alpha=1, cmap='jet', vmin=30, vmax=42,
                    extent=csi_plot_extent)

    # colorbar in a narrow axis attached to the right of the image
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cbar = fig.colorbar(img, cax=cax, orientation='vertical', label=r'T[$^\circ$C]')
    cbar.set_ticks(temperature_ticks)
    ax.axis('off')

    plt.savefig(publication_path + '/Final/03_Figure_example_' + scan_label + '.svg')



In [None]:
def _aligned_bin_edges(values, width):
    """Return histogram bin edges of exactly `width`, aligned to multiples of `width`.

    Aligned edges make the hot and cold histograms directly comparable. The
    former approach (a bin *count* of round(range / width)) gave only
    approximately `width`-wide bins with different offsets per map. It also
    failed for empty maps (int(nan)) and constant maps (0 bins).
    """
    if values.size == 0:
        return np.array([0.0, width])
    lower = np.floor(np.min(values) / width) * width
    upper = np.ceil(np.max(values) / width) * width
    if upper <= lower:  # all values identical and on a bin edge
        upper = lower + width
    # + width / 2 makes arange include `upper` despite floating-point rounding
    return np.arange(lower, upper + width / 2, width)


fig, ax = plt.subplots(1, figsize=(fig_width * 0.9, 2.5), tight_layout=True)
bin_size = 0.5  # histogram bin width in °C

hot_map_plot = temp_map_high
cold_map_plot = temp_map_low
# temperatures of the pixels that survived masking (NaN marks removed pixels)
hot_values = np.ravel(hot_map_plot)[~np.isnan(np.ravel(hot_map_plot))]
cold_values = np.ravel(cold_map_plot)[~np.isnan(np.ravel(cold_map_plot))]
bins_hot = _aligned_bin_edges(hot_values, bin_size)
bins_cold = _aligned_bin_edges(cold_values, bin_size)
ax.hist(hot_values, bins=bins_hot, ec='k', color='C3', label='T=' + str(csi_high_temp_temperature) + r'$^\circ$C')
ax.hist(cold_values, bins=bins_cold, ec='k', color='C0', label='T=' + str(csi_low_temp_temperature) + r'$^\circ$C')

ax.set_xticks([30, 32, 34, 36, 38, 40, 42, 44])
ax.set_yticks([0, 10, 20, 30, 40])
ax.set_xlabel(r'T[$^\circ$C]')
ax.set_ylabel('Voxels per bin')
ax.legend()
plt.savefig(publication_path + '/Final/03_Figure_histogram.svg')
# mean and standard deviation (°C) of each masked temperature map
print(np.nanmean(hot_map_plot).round(1), np.nanstd(hot_map_plot).round(1))
print(np.nanmean(cold_map_plot).round(1), np.nanstd(cold_map_plot).round(1))
