In [None]:
import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
import pandas as pd
import os
import pydicom
import scipy.io as sio
import hypermri.utils.utils_general as utg
import hypermri.utils.utils_fitting as utf
import hypermri.utils.utils_anatomical as uta
import hypermri.utils.utils_spectroscopy as uts
import seaborn as sns

from hypermri.utils.utils_fitting import def_fit_params, fit_t2_pseudo_inv, fit_freq_pseudo_inv, fit_data_pseudo_inv, fit_func_pseudo_inv, plot_fitted_spectra, basefunc
from hypermri.utils.utils_general import get_gmr, calc_sampling_time_axis
from hypermri.utils.utils_spectroscopy import apply_lb, multi_dim_linebroadening, get_metab_cs_ppm, generate_fid, get_freq_axis, make_NDspec_6Dspec, freq_to_index, find_npeaks

import sys
# define paths:
sys.path.append('../../')

import hypermri.utils.utils_spectroscopy as uts


# Autoreload extension so that you dont have to reload the kernel every time something is changed in the hypermri or magritek folders
%load_ext autoreload
%autoreload 2

%matplotlib widget

In [None]:
from pathlib import PureWindowsPath

# NOTE(review): study folder and file names are blank placeholders here; fill
# them in (one entry per patient, same order in both lists).
studyfolder_names = ['', '', '', '', '', '', '']
file_names = ['', '', '', '', '', '', '']

# One row per patient; filled with the per-patient metadata inside the loop.
all_patient_df = pd.DataFrame(index=range(len(studyfolder_names)))

# Patients (list indices) excluded from the analysis.
skip_patients = {3, 6}

# Nominal field strength in Tesla. Used both for the reference frequency and
# for the ppm axis, so they can no longer drift apart.
b0_t = 3

# `basepath` (raw-data root) and `savepath` (output root) are not defined in
# this cell and must be set before running it.
missing_paths = [name for name in ("basepath", "savepath") if name not in globals()]
if missing_paths:
    raise NameError(f"Define {', '.join(missing_paths)} before running this cell.")

# Output folder: <savepath>/<last folder of basepath>/Slicespec. It is computed
# once, into its own variable. Previously `savepath` was re-assigned inside the
# loop, so the path grew with every patient.
# PureWindowsPath accepts both '\\' and '/' and ignores a trailing separator.
savefolder = PureWindowsPath(basepath).name
slicespec_savepath = os.path.join(savepath, savefolder, "Slicespec")
os.makedirs(slicespec_savepath, exist_ok=True)

# Settings shared by all patients.
metabs = ['pyruvate', 'lactate', 'bicarbonate', 'pyruvatehydrate']
niter = 4  # number of iterations
npoints = 31  # number of tested points per iteration
rep_fitting = 15  # passed to the fit as fit_params["rep_fitting"]

# Gyromagnetic ratio of 13C in MHz/T.
gmr_mhz_t = utg.get_gmr(nucleus="13c")
gmr = gmr_mhz_t
# Nominal 13C Larmor frequency in MHz at b0_t.
freq0 = b0_t * gmr_mhz_t


def _load_header_mat(header_path):
    """Load a ``*_header.mat`` file into a dict of squeezed top-level fields.

    Entries of the ``'image'`` sub-structure that are stored as nested arrays
    are unwrapped in place to their first element, as in the original code.

    Parameters
    ----------
    header_path : str
        Path to the ``.mat`` file; it must contain a ``'header'`` variable.

    Returns
    -------
    dict
        Maps each header field name to ``np.squeeze(...)`` of its content.
    """
    header_struct_arr = sio.loadmat(header_path)['header']
    header = {
        field: np.squeeze(header_struct_arr[field][0][0][0])
        for field in header_struct_arr.dtype.names
    }

    image_hdr = header['image']
    for key in image_hdr.dtype.names:
        temp = np.array(image_hdr[key]).item()
        if not isinstance(temp, np.ndarray):
            continue
        if temp.ndim == 0:
            value = temp.item()  # 0-d array -> Python scalar
        else:
            try:
                value = temp[0][0]  # higher-dimensional array -> first element
            except (IndexError, TypeError):
                # Leave this field unchanged. The former bare `except: pass`
                # fell through and wrote the value left over from the PREVIOUS
                # key (or raised NameError on the first key).
                continue
        image_hdr[key] = value
    return header


for pat_number, studyfolder in enumerate(studyfolder_names):
    if pat_number in skip_patients:
        continue

    print(studyfolder)
    animal_id = studyfolder
    study_dir = os.path.join(basepath, studyfolder)

    # NOTE(review): assumes variable "A" holds the raw FID data - TODO confirm.
    mgs_file = sio.loadmat(os.path.join(study_dir, file_names[pat_number] + ".mat"))["A"]
    dataset_header = _load_header_mat(os.path.join(study_dir, file_names[pat_number] + "_header.mat"))
    exam_hdr = dataset_header['exam']
    image_hdr = dataset_header['image']
    rdb_hdr = dataset_header['rdb_hdr']

    # --- patient metadata -------------------------------------------------
    patient_info = {
        'ID': animal_id,
        # NOTE(review): assumes patsex == 0 encodes male - confirm with header spec.
        'sex': 'male' if exam_hdr['patsex'] == 0 else 'female',
        # Divided by 1e3; presumably grams -> kg - verify.
        'weight': float(exam_hdr['patweight'] / 1e3),
    }
    patient_info['pyr_vol'] = patient_info['weight'] * 0.4
    patient_info['scan_date'] = str(rdb_hdr['scan_date'])
    patient_info['scan_time'] = str(rdb_hdr['scan_time'])
    # Store the metadata. all_patient_df was created for this but never filled.
    for info_key, info_val in patient_info.items():
        all_patient_df.loc[pat_number, info_key] = info_val

    # --- acquisition parameters -------------------------------------------
    # Repetition time from the header (divided by 1e6, presumably us -> s).
    tr_header = image_hdr['tr'] / 1e6
    # NOTE(review): the header TR is overridden by a hard-coded constant and
    # `tr` is not used in this cell; confirm the intended value and unit.
    tr = 4e3
    # Bandwidth
    bw = rdb_hdr['spectral_width']
    # Center frequency (header value / 10; presumably stored in 0.1 Hz units).
    freq_cent_hz = rdb_hdr['ps_mps_freq'] / 10.0
    # B0 in Tesla implied by the center frequency.
    b0_off_t = freq_cent_hz / (gmr_mhz_t * 1e6)
    # Offset of the center frequency from the nominal 13C frequency.
    freq_off_hz = freq_cent_hz - freq0 * 1e6
    freq_off_ppm = freq_off_hz / (freq0 * 1e6) * 1e6
    # Sampling time
    dt = 1. / bw
    # Flip angle
    fa = image_hdr['mr_flip']

    # --- spectra ----------------------------------------------------------
    # NOTE(review): keeps every 8th entry along axis 0 and drops the first
    # entry along axis 1 (the spectral axis used below) - confirm the intent.
    dyn_fid = mgs_file[0::8, 1:, 0, 0, :]

    # The spectrum has to be flipped and complex-conjugated to get the proper
    # orientation (see flip_spec_complex_* in hypermri).
    dyn_spec = np.conj(np.flip(np.fft.fftshift(np.fft.fft(dyn_fid, axis=1), axes=(1,)), axis=1))

    freq_range_Hz = np.squeeze(uts.get_freq_axis(npoints=dyn_spec.shape[1], sampling_dt=dt, unit='Hz'))
    time_axis = utg.calc_sampling_time_axis(npoints=dyn_spec.shape[1], sampling_dt=dt)

    input_data = uts.make_NDspec_6Dspec(input_data=dyn_spec, provided_dims=["reps", "spec", "z"])

    # ppm axis: Hz / (MHz/T) / T = ppm, offset so the axis center sits at
    # 175 ppm (assumed position of the center frequency - TODO confirm).
    freq_range_ppm = freq_range_Hz / gmr / b0_t + 175

    # --- reference metabolite frequencies to the measured pyruvate peak ------
    # Pick the repetition with the largest summed magnitude. Take its strongest
    # peak as pyruvate (assumption) and shift all literature shifts by the
    # measured-vs-literature difference.
    max_timepoint = np.argmax(np.sum(np.abs(input_data[:, 0, 0, 0, :, 0]), axis=0))
    max_peakindex = uts.find_npeaks(input_data=np.squeeze(np.abs(input_data[:, 0, 0, 0, max_timepoint, 0])),
                                    npeaks=1, plot=False, freq_range=freq_range_ppm)

    pyr_ppm_lit = uts.get_metab_cs_ppm(metab="pyruvate")
    pyr_ppm_meas = freq_range_ppm[max_peakindex]
    print('Measured pyr peak at:', pyr_ppm_meas)
    print('Expected at:', pyr_ppm_lit)
    pyr_ppm_diff_lit_meas = pyr_ppm_lit - pyr_ppm_meas

    # --- fit parameters ---------------------------------------------------
    fit_params = {
        'signal_domain': 'spectral',
        'metabs': metabs,
        'b0': b0_t,
        'init_fit': True,
        'coff': 0,
        'zoomfactor': 1.5,
        'metab_t2s': [0.03 for _ in metabs],
        'max_t2_s': [0.07 for _ in metabs],
        'min_t2_s': [0.01 for _ in metabs],
        'range_t2s_s': 0.1,
        'freq_range_ppm': freq_range_ppm,
        'freq_range_Hz': freq_range_Hz,
        'range_freqs_Hz': 40.,
        'show_tqdm': True,
        'metabs_freqs_ppm': [uts.get_metab_cs_ppm(metab=m) - pyr_ppm_diff_lit_meas for m in metabs],
        'niter': niter,  # number of iterations
        'npoints': npoints,  # number of tested points per iteration
        'rep_fitting': rep_fitting,
        'fit_range_repetitions': range(0, input_data.shape[4]),
    }
    print(fit_params["metabs_freqs_ppm"])
    fit_params = def_fit_params(fit_params=fit_params)

    # --- fit and save -----------------------------------------------------
    fit_spectrums, fit_amps, fit_freqs, fit_t2s, fit_stds = fit_data_pseudo_inv(input_data=input_data,
                                                                                fit_params=fit_params,
                                                                                dbplot=False,
                                                                                use_multiprocessing=True)
    fit_results = {
        "fit_spectrums": fit_spectrums,
        "fit_freqs": fit_freqs,
        "fit_amps": fit_amps,
        "fit_t2s": fit_t2s,
        "fit_params": fit_params,
        "animal_id": animal_id,
        "fit_stds": fit_stds,
    }

    utg.save_as_pkl(dir_path=slicespec_savepath,
                    filename=animal_id + '_fit_spectra',
                    file=fit_results,
                    use_timestamp=True)
