In [8]:
import os
# Change into the project root so the relative 'data/...' paths used below resolve.
# NOTE(review): this is a hard-coded, machine-specific absolute path; consider a
# configurable project/DATA_DIR setting so the notebook runs on other machines.
os.chdir('/home/victorhuang/projects/gtx/')
# os.chdir('/home/victorh/projects/gtx/')

import numpy as np
import mat73
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from numerical.op_props_util import *

In [9]:
# Spatial frequencies (mm^-1) of the SFDI acquisitions; index 0 (fx = 0) is the DC term.
fx = np.array([0, 0.05, 0.1, 0.15, 0.2, 0.25])

# Data files, relative to the current working directory.
dt_data_path = 'data/20241118_data_splited.mat'
phantom_data_path = 'data/phantom_data.mat'
cylinder_data_path_r1 = 'data/cylinder/h1_r5_d0.mat'
cylinder_data_path_r2 = 'data/cylinder/h1_r4_d2.mat'

In [10]:
def read_mat_data(path):
    """
    Load a .mat file with ``mat73.loadmat`` and expose its fields under descriptive keys.

    Parameters
    ----------
    path : str
        Path to the .mat file.

    Returns
    -------
    dict
        Keys 'fluorescence', 'op', 'depth', 'concentration_fluor', 'reflectance'
        holding the file's 'F', 'OP', 'DF', 'QF' and 'RE' fields respectively.
        A missing field raises KeyError.
    """
    raw = mat73.loadmat(path)
    # (output key, field name in the .mat file), in output order
    field_map = (
        ('fluorescence', 'F'),
        ('op', 'OP'),
        ('depth', 'DF'),
        ('concentration_fluor', 'QF'),
        ('reflectance', 'RE'),
    )
    return {out_key: raw[mat_key] for out_key, mat_key in field_map}

# Physics calculations (spatial-frequency-domain diffusion model)

In [11]:
def getReff_Haskell(n_in, n_out=1.0):
    """
    Compute the effective reflection coefficient Reff using Fresnel integrals (Haskell et al.).

        Reff  = (R_phi + R_J) / (2 - R_phi + R_J)
        R_phi = integral_0^{pi/2} 2 sin(t) cos(t)   R_F(t) dt
        R_J   = integral_0^{pi/2} 3 sin(t) cos(t)^2 R_F(t) dt

    R_F is the unpolarized Fresnel reflectance for light going from the tissue (n_in)
    into the outside medium (n_out). Beyond the critical angle R_F = 1 (total internal
    reflection); that part of both integrals is added analytically.

    Parameters
    ----------
    n_in : float
        Refractive index of the tissue
    n_out : float
        Refractive index of the outside medium (default: 1.0 for air)

    Returns
    -------
    Reff : float
        Effective reflection coefficient (0.0 for an index-matched boundary)
    """
    if n_in == n_out:
        return 0.0

    # Critical angle exists only when leaving a denser medium.
    ratio = n_out / n_in
    has_tir = ratio < 1.0
    oc = np.arcsin(ratio) if has_tir else np.pi / 2

    o = np.linspace(0, oc, 2000)
    coso = np.cos(o)
    sin_o = np.sin(o)
    # Clip: at o == oc the radicand is ~0 and rounding can make it slightly negative (-> NaN).
    cosop = np.sqrt(np.clip(1 - (n_in / n_out)**2 * sin_o**2, 0.0, None))

    r_fres1 = ((n_in * cosop - n_out * coso) / (n_in * cosop + n_out * coso))**2
    r_fres2 = ((n_in * coso - n_out * cosop) / (n_in * coso + n_out * cosop))**2
    r_fres = 0.5 * (r_fres1 + r_fres2)

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
    trapz = getattr(np, "trapezoid", None) or np.trapz
    r_phi = trapz(2 * sin_o * coso * r_fres, o)
    r_j = trapz(3 * sin_o * coso**2 * r_fres, o)

    if has_tir:
        # Region [oc, pi/2] where R_F = 1:
        #   integral 2 sin cos dt = cos(oc)^2,  integral 3 sin cos^2 dt = cos(oc)^3
        cos_c = np.cos(oc)
        r_phi += cos_c**2
        r_j += cos_c**3

    Reff = (r_phi + r_j) / (2 - r_phi + r_j)
    return Reff

In [12]:
def gtxDTBoundaryCondition(nrel, method='Groenhuis'):
    """
    Compute boundary correction constants for diffuse reflectance and fluence modeling.

    Parameters
    ----------
    nrel : float
        Relative refractive index (e.g., 1.4)
    method : str
        Method to compute internal reflectance. Options: 'Groenhuis' (default), 'Haskell'

    Returns
    -------
    Cnd : float
        Correction factor for fluence rate at boundary (2 * K)
    K : float
        Extrapolated boundary constant, (1 + rid) / (1 - rid)
    rid : float
        Internal diffuse reflectance (0.0 when nrel == 1.0)

    Raises
    ------
    ValueError
        If ``method`` is not 'Groenhuis' or 'Haskell' (checked for every nrel).
    """
    if method not in ('Groenhuis', 'Haskell'):
        raise ValueError(f"Unsupported boundary method: {method}")

    if nrel == 1.0:
        # Index-matched boundary: no internal reflection.
        rid = 0.0
    elif method == 'Groenhuis':
        # Empirical fit from Groenhuis (1983) as quoted by Cuccia (2009):
        #   R_eff ~= -1.440/n^2 + 0.710/n + 0.668 + 0.0636 n
        rid = -1.440 / (nrel**2) + 0.710 / nrel + 0.668 + 0.0636 * nrel
    else:
        rid = getReff_Haskell(nrel)

    K = (1 + rid) / (1 - rid)
    Cnd = 2 * K
    return Cnd, K, rid

In [13]:
def compute_mueff(mu_a, mu_s, fx):
    """
    Effective attenuation coefficient of the spatial-frequency-domain diffusion model.

        mu_eff(fx) = sqrt(mu_a / D + (2 * pi * fx)^2),   D = 1 / (3 * (mu_a + mu_s'))

    Parameters
    ----------
    mu_a : np.ndarray, shape (H, W)
        Absorption coefficient (mm⁻¹)
    mu_s : np.ndarray, shape (H, W)
        Reduced scattering coefficient (mm⁻¹)
    fx : np.ndarray, shape (F,)
        Spatial frequencies (mm⁻¹)

    Returns
    -------
    mu_eff : np.ndarray, shape (H, W, F)
        Effective attenuation for every pixel and spatial frequency
    """
    diffusion = 1.0 / (3.0 * (mu_a + mu_s))
    per_pixel = (mu_a / diffusion)[:, :, np.newaxis]   # (H, W, 1)
    per_freq = (2 * np.pi * fx) ** 2                    # (F,)
    return np.sqrt(per_pixel + per_freq[np.newaxis, np.newaxis, :])

In [14]:
def fluence_total(mu_a, mu_s, fx, z_vals, nrel=1.4, method='Cuccia', boundary_method='Groenhuis'):
    """
    Compute fluence Φ(z, fx) at depth using spatial-frequency domain diffusion theory.

    Cuccia model (source amplitude taken as 1):

        Φ(z) = C0 exp(-mu_tr z) + C exp(-mu_eff z)
        C0   = 3 a' / (mu_eff^2 / mu_tr^2 - 1)
        C    = -3 a' (1 + 3A) / ((mu_eff^2 / mu_tr^2 - 1) (mu_eff / mu_tr + 3A))

    with mu_tr = mu_a + mu_s', a' = mu_s' / mu_tr and A = 1 / Cnd.

    Parameters
    ----------
    mu_a : np.ndarray, shape (H, W)
        Absorption coefficient
    mu_s : np.ndarray, shape (H, W)
        Reduced scattering coefficient
    fx : np.ndarray, shape (F,)
        Spatial frequencies (mm⁻¹)
    z_vals : array-like, shape (Z,)
        Depths to evaluate fluence (mm)
    nrel : float
        Relative refractive index
    method : str
        Method for fluence calculation: 'Cuccia' (implemented), 'Kim', or 'Gardner'
    boundary_method : str
        Internal-reflection model forwarded to gtxDTBoundaryCondition
        ('Groenhuis' by default, or 'Haskell').

    Returns
    -------
    phi : np.ndarray, shape (H, W, Z, F)
        Fluence values at each pixel, depth, and spatial frequency

    Raises
    ------
    NotImplementedError
        For method 'Kim' or 'Gardner'.
    ValueError
        For any other unknown method.
    """
    # Validate the model choice before doing any array work.
    if method in ('Kim', 'Gardner'):
        raise NotImplementedError(f"{method} model not implemented")
    if method != 'Cuccia':
        raise ValueError(f"Invalid method: {method}")

    # Cuccia model: valid for fx < 1/(3~4) mu_tr
    Cnd, _, _ = gtxDTBoundaryCondition(nrel, method=boundary_method)
    A = 1 / Cnd

    mutr = mu_a + mu_s
    ap = mu_s / mutr
    mu_eff = compute_mueff(mu_a, mu_s, fx)

    # Expand for broadcasting to (H, W, Z, F)
    z = np.asarray(z_vals)[None, None, :, None]   # (1, 1, Z, 1)
    mu_eff = mu_eff[:, :, None, :]                # (H, W, 1, F)
    mutr = mutr[:, :, None, None]                 # (H, W, 1, 1)
    ap = ap[:, :, None, None]                     # (H, W, 1, 1)

    # NOTE: singular where mu_eff == mu_tr (denominator is zero).
    denom = mu_eff**2 / mutr**2 - 1
    C0 = 3 * ap / denom
    C = -3 * ap * (1 + 3 * A) / (denom * (mu_eff / mutr + 3 * A))
    phi = C0 * np.exp(-mutr * z) + C * np.exp(-mu_eff * z)  # (H, W, Z, F)
    return phi

In [15]:
def compute_Tef(mu_a_x, mu_s_x, mu_a_m, mu_s_m, fx, z_vals, nrel=1.4, fx_dependence=True, layer='top'):
    """
    Compute the depth-resolved transport kernel Tef(z, fx).

    The integrand is phi_x(z, fx) * phi_m(z, fx) * dz / Cnd, with dz = np.gradient(z_vals).

    Parameters
    ----------
    mu_a_x : np.ndarray, shape (H, W)
        Absorption at excitation
    mu_s_x : np.ndarray, shape (H, W)
        Scattering at excitation
    mu_a_m : np.ndarray, shape (H, W)
        Absorption at emission
    mu_s_m : np.ndarray, shape (H, W)
        Scattering at emission
    fx : np.ndarray, shape (F,)
        Spatial frequencies
    z_vals : np.ndarray, shape (Z,)
        Depths to integrate over
    nrel : float
        Relative refractive index
    fx_dependence : bool
        If True, emission fluence is evaluated at every fx; if False, it is evaluated
        at fx = 0 and replicated across all frequencies (True for real data, False for sim).
    layer : {'top', 'bottom'}
        'top': Tef_z[..., k, :] accumulates from z_vals[0] through z_vals[k].
        'bottom': Tef_z[..., k, :] accumulates from z_vals[k] through z_vals[-1].

    Returns
    -------
    Tef : np.ndarray, shape (H, W, F)
        Integrand summed over all depths (independent of ``layer``)
    Tef_z : np.ndarray, shape (H, W, Z, F)
        Cumulative (depth-resolved) kernel, direction set by ``layer``
    phi_x : np.ndarray, shape (H, W, Z, F)
        Excitation fluence
    phi_m : np.ndarray, shape (H, W, Z, F)
        Emission fluence

    Raises
    ------
    ValueError
        If ``layer`` is not 'top' or 'bottom' (checked before any computation).
    """
    if layer not in ('top', 'bottom'):
        raise ValueError(f"Invalid layer: {layer}")

    Cnd, _, _ = gtxDTBoundaryCondition(nrel)

    # Excitation fluence: Φ_x(z, fx)
    phi_x = fluence_total(mu_a_x, mu_s_x, fx, z_vals, nrel=nrel, method='Cuccia')  # (H, W, Z, F)

    # Emission fluence: Φ_m(z, fx) or Φ_m(z, f=0) replicated along fx
    if fx_dependence:
        phi_m = fluence_total(mu_a_m, mu_s_m, fx, z_vals, nrel=nrel, method='Cuccia')
    else:
        phi_m = fluence_total(mu_a_m, mu_s_m, np.array([0.0]), z_vals, nrel=nrel, method='Cuccia')
        phi_m = np.repeat(phi_m, len(fx), axis=3)

    dz = np.gradient(z_vals)[None, None, :, None]  # (1, 1, Z, 1)
    product_term = phi_x * phi_m * dz

    if layer == 'top':
        Tef_z = np.cumsum(product_term, axis=2) / Cnd
    else:
        # Reverse along depth, accumulate, then restore the original depth order.
        Tef_z = np.cumsum(product_term[:, :, ::-1, :], axis=2)[:, :, ::-1, :] / Cnd

    Tef = np.sum(product_term, axis=2) / Cnd

    return Tef, Tef_z, phi_x, phi_m

In [16]:
def interpolate_z_at_mval(m_val, z_vals, M_curve, type="all"):
    """
    Find z such that M_curve(z) == m_val using linear interpolation.

    Parameters
    ----------
    m_val : float
        Target value (e.g. a measured AC/DC ratio).
    z_vals : np.ndarray, shape (Nz,)
        Depth grid, e.g. np.linspace(0, 10, 100).
    M_curve : np.ndarray, shape (Nz,)
        Curve sampled on z_vals, e.g. M_Z[i, j, :, fAC].
    type : str
        Out-of-range policy.
        "all" (default): return z_vals[0] if m_val is outside [min(M_curve), max(M_curve)].
        anything else: return z_vals[-1] below the range and z_vals[0] above it.

    Returns
    -------
    float
        Interpolated depth at the first crossing of m_val along z_vals, or z_vals[0]
        when the curve has no crossing (e.g. it is flat and equal to m_val, or has a
        single sample).
    """
    lo, hi = np.min(M_curve), np.max(M_curve)
    if type != "all":
        if m_val < lo:
            return z_vals[-1]
        if m_val > hi:
            return z_vals[0]
    elif m_val < lo or m_val > hi:
        return z_vals[0]

    # Indices k where M_curve - m_val changes sign between samples k and k+1.
    sign_change = np.flatnonzero(np.diff(np.sign(M_curve - m_val)) != 0)
    if sign_change.size == 0:
        # No bracketing interval (flat curve equal to m_val, or a single sample).
        return z_vals[0]

    k = sign_change[0]
    z0, z1 = z_vals[k], z_vals[k + 1]
    m0, m1 = M_curve[k], M_curve[k + 1]

    # Linear interpolation (m0 != m1 because their signs relative to m_val differ)
    return z0 + (m_val - m0) * (z1 - z0) / (m1 - m0)

In [17]:
def inverse_fluorescence_depth(F, fx, mua_x, mus_x, mua_m, mus_m, nrel, numerical_vars):
    """
    Estimate depth and fluorescence yield from fluorescence SFDI images.

    For every AC frequency, the measured ratio F[..., f] / F[..., 0] is matched against
    the model ratio Tef_z[..., f] / Tef_z[..., 0] to get a depth, and the DC image is
    divided by the DC kernel at that depth to get the yield.

    Parameters
    ----------
    F : ndarray of shape (H, W, F)
        Measured fluorescence images at different spatial frequencies (index 0 = DC).
    fx : ndarray of shape (F,)
        Spatial frequencies.
    mua_x, mus_x : ndarray of shape (H, W)
        Absorption and scattering at excitation.
    mua_m, mus_m : ndarray of shape (H, W)
        Absorption and scattering at emission.
    nrel : float
        Relative refractive index.
    numerical_vars : dict
        Optional keys (defaults in parentheses):
          fxDependence (True)  -- forwarded to compute_Tef as fx_dependence
          hFl (20.0), zDelta (0.05) -- depth grid np.arange(0, hFl, zDelta)
          layerFl ('top')      -- forwarded to compute_Tef as layer
          masked (True)        -- combine per-frequency depths with a zero-masked minimum
          idx (4)              -- only used when masked is False: idx == 0 returns the
                                  per-frequency stacks unaveraged
        Other keys (e.g. depth_offset) are ignored.

    Returns
    -------
    z_est : ndarray of shape (H, W), or (F, H, W) when masked is False and idx == 0
        Estimated depth.
    qF_est : ndarray, same shape convention as z_est
        Estimated fluorescence yield.
    M_Z : ndarray of shape (H, W, Z, F)
        Model ratio per depth (DC slice left at zero).
    M_Z_Data : ndarray of shape (H, W, F)
        Measured ratio with NaN replaced by 0 (DC slice left at zero).
    phi_x : ndarray of shape (H, W, Z, F)
        Excitation fluence.
    Tef_z : ndarray of shape (H, W, Z, F)
        Depth-resolved kernel.
    z_vals : ndarray of shape (Z,)
        Depth grid.

    Raises
    ------
    ValueError
        If F has no AC frequency (fewer than 2 images).
    """
    fx_dependence = numerical_vars.get("fxDependence", True)
    hFl = numerical_vars.get("hFl", 20.0)
    zDelta = numerical_vars.get("zDelta", 0.05)
    layerFl = numerical_vars.get("layerFl", "top")
    masked = numerical_vars.get("masked", True)
    idx = numerical_vars.get("idx", 4)

    H, W, n_freq = F.shape
    if n_freq < 2:
        raise ValueError("F must contain a DC image plus at least one AC frequency")

    # The same depth grid is used for 'top' and 'bottom' layers.
    z_vals = np.arange(0, hFl, zDelta)

    _, Tef_z, phi_x, _ = compute_Tef(
        mua_x, mus_x,
        mua_m, mus_m,
        fx, z_vals,
        nrel=nrel,
        fx_dependence=fx_dependence,
        layer=layerFl
    )

    fDC = 0
    F_DC = F[..., fDC]
    Tef_dc = Tef_z[..., fDC]

    z_est = np.zeros((n_freq, H, W))
    qF_est = np.zeros((n_freq, H, W))
    M_Z = np.zeros((H, W, len(z_vals), n_freq))
    M_Z_Data = np.zeros((H, W, n_freq))

    for fAC in range(1, n_freq):
        # Pixels with F_DC == 0 would warn on division; NaN ratios become 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            M_Z_Data[..., fAC] = np.nan_to_num(F[..., fAC] / F_DC, nan=0.0)
        M_Z[..., fAC] = Tef_z[..., fAC] / Tef_dc

        for i in range(H):
            for j in range(W):
                z_interp = interpolate_z_at_mval(M_Z_Data[i, j, fAC], z_vals, M_Z[i, j, :, fAC])
                z_est[fAC, i, j] = z_interp

                Tef_dc_interp = np.interp(z_interp, z_vals, Tef_dc[i, j, :], left=1e-6, right=1e-6)
                qF_est[fAC, i, j] = F_DC[i, j] / Tef_dc_interp if Tef_dc_interp > 0 else 0.0

    # Row fDC (= 0) of z_est / qF_est is never filled, so it is excluded from every
    # average; including it would bias the result by a factor (F - 1) / F.
    if masked:
        # Zero depths (the unfilled DC row, and interpolate_z_at_mval's z_vals[0] = 0
        # out-of-range result) are treated as missing; all-missing pixels become 0.
        z_est = np.ma.min(np.ma.masked_equal(z_est, 0), axis=0).filled(0)
        qF_est = np.average(qF_est[1:], axis=0)
    elif idx != 0:
        z_est = np.average(z_est[1:], axis=0)
        qF_est = np.average(qF_est[1:], axis=0)
    # masked is False and idx == 0: per-frequency stacks are returned unaveraged.

    return z_est, qF_est, M_Z, M_Z_Data, phi_x, Tef_z, z_vals

# Depth evaluation (plotting and solver driver)

In [18]:
def plot_depth(data, z_est, type='depth'):
    """
    Show a ground-truth map next to its estimate on a shared color scale.

    Parameters
    ----------
    data : dict
        Holds the ground-truth map under the key equal to ``type``
        ('depth' or 'concentration').
    z_est : np.ndarray
        Estimated map (depth or concentration, matching ``type``).
    type : {'depth', 'concentration'}
        Which quantity is being compared.

    Raises
    ------
    ValueError
        For any other ``type``.
    """
    # type -> (ground-truth title, estimate title, colorbar label)
    panels = {
        'depth': ('Depth', 'Estimated Depth', 'Depth (mm)'),
        'concentration': ('Concentration', 'Estimated Concentration', 'Concentration (mg/ml)'),
    }
    if type not in panels:
        raise ValueError(f"Invalid plot type: {type}")
    truth_title, est_title, cbar_label = panels[type]

    truth = data.get(type)
    # Shared limits so both panels use the same color mapping.
    vmin = min(truth.min(), z_est.min())
    vmax = max(truth.max(), z_est.max())

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    im0 = ax[0].imshow(truth, vmin=vmin, vmax=vmax)
    ax[1].imshow(z_est, vmin=vmin, vmax=vmax)
    ax[0].set_title(truth_title)
    ax[1].set_title(est_title)
    cbar = fig.colorbar(im0, ax=ax[:], orientation='vertical')
    cbar.set_label(cbar_label)
    plt.show()

In [None]:
def plot_graphs(Tef_z, M_Z, phi_x, z_vals, M_Z_Data, number, xy, random=False, freq_labels=None):
    """
    Diagnostic plots of the depth inversion at selected pixels.

    For each pixel, the top row shows Tef_z, phi_x and M_Z versus depth for every
    spatial frequency. The bottom row has one panel per AC frequency showing the model
    curve M_Z, the measured ratio M_Z_Data as a horizontal line, and the depth found by
    interpolate_z_at_mval (red dashed line). The interpolated depths are printed.

    Parameters
    ----------
    Tef_z, M_Z, phi_x : np.ndarray, shape (H, W, Z, F)
        Kernel, model ratio and excitation fluence (index 0 of the last axis = DC).
    z_vals : np.ndarray, shape (Z,)
        Depth grid.
    M_Z_Data : np.ndarray, shape (H, W, F)
        Measured ratios.
    number : int
        Number of pixels to draw when ``random`` is True.
    xy : sequence of two ints
        (x, y) = (column, row) of the pixel drawn when ``random`` is False.
    random : bool
        Draw ``number`` uniformly random pixels (unseeded np.random) instead of ``xy``.
    freq_labels : list of str, optional
        Legend label per frequency index. Defaults to "fDC", "f=0.05", ..., "f=0.25"
        when there are 6 frequencies, otherwise "f[k]".
    """
    n_freq = M_Z.shape[-1]
    if freq_labels is None:
        default_labels = ["fDC", "f=0.05", "f=0.1", "f=0.15", "f=0.2", "f=0.25"]
        if n_freq == len(default_labels):
            freq_labels = default_labels
        else:
            freq_labels = [f"f[{k}]" for k in range(n_freq)]

    # (x, y) pixel pairs; for random picks x is drawn before y for each pixel.
    if random:
        pixels = [
            (np.random.randint(0, M_Z.shape[1]), np.random.randint(0, M_Z.shape[0]))
            for _ in range(number)
        ]
    else:
        pixels = [(xy[0], xy[1])]

    profile_panels = ((Tef_z, "Tef_z"), (phi_x, "phi_x"), (M_Z, "M_Z"))
    n_cols = max(len(profile_panels), n_freq - 1)

    for xi, yi in pixels:
        print("--------------------------------")
        print(f"x: {xi}, y: {yi}")
        fig = plt.figure(figsize=(15, 6))
        gs = gridspec.GridSpec(2, n_cols, figure=fig)

        # Top row: depth profiles at every spatial frequency.
        for col, (profile, ylabel) in enumerate(profile_panels):
            ax = fig.add_subplot(gs[0, col])
            for k in range(n_freq):
                ax.plot(z_vals, profile[yi, xi, :, k])
            ax.legend(freq_labels)
            ax.set_xlabel("Depth (mm)")
            ax.set_ylabel(ylabel)

        # Bottom row: one panel per AC frequency with the interpolated depth.
        results = []
        for k in range(1, n_freq):
            m_val = M_Z_Data[yi, xi, k]
            z_hit = interpolate_z_at_mval(m_val, z_vals, M_Z[yi, xi, :, k])
            results.append(z_hit)

            ax = fig.add_subplot(gs[1, k - 1])
            ax.plot(z_vals, M_Z[yi, xi, :, k], label=freq_labels[k])
            ax.plot(z_vals, np.full(len(z_vals), m_val), label="m_val")
            ax.axvline(x=z_hit, color='r', linestyle='--', label='Interpolated Depth')
            ax.legend()  # after axvline so its label is included
            ax.set_xlabel("Depth (mm)")
            ax.set_ylabel("M_Z")

        print("results: " + ", ".join(str(r) for r in results))

        plt.show()
        plt.close(fig)  # avoid accumulating open figures when drawing many pixels


In [None]:
def run_numerical_solver(filepath, type_data, numerical_vars):
    """
    Load one sample, run the inverse depth solver, print depth errors and plot the result.

    Parameters
    ----------
    filepath : str
        Data file; used for 'cylinder' (also parsed for the depth offset) and 'dt'.
        NOTE(review): ignored for 'phantom' (the module-level ``phantom_data_path`` is
        loaded) and for 'else' ('data/sample_iceberg.mat' is loaded).
    type_data : {'cylinder', 'phantom', 'dt', 'else'}
        Which loader to use.
    numerical_vars : dict
        Optional keys (defaults in parentheses): nrel (1.33), zDelta (0.01), hFl (1),
        depth_offset (0), layerFl ('top'), fxDependence (True), extrapolation (False),
        add_Chrom (True), masked (True), idx (None; 0 is used for 'phantom' and 'dt'),
        plot (False).

    Notes
    -----
    * For 'cylinder' the file name is split on '_' (e.g. 'h1_r4_d2.mat'); the third
      token minus its first character is the depth offset, divided by 10 if above 10.
    * A positive depth offset switches the solver to layerFl = 'bottom'.
    * Uses the module-level spatial frequencies ``fx``.

    Raises
    ------
    ValueError
        For an unknown ``type_data``.
    """

    def _standardize(raw):
        # Map the loader's MATLAB field names onto the keys used below.
        return {
            'fluorescence': raw['F'],
            'optical_props': raw['OP'],
            'reflectance': raw['RE'],
            'depth': raw['DF'],
            'concentration': raw['QF'],
        }

    nrel = numerical_vars.get("nrel", 1.33)

    zDelta = numerical_vars.get("zDelta", 0.01)
    hFl = numerical_vars.get("hFl", 1)
    depth_offset = numerical_vars.get("depth_offset", 0)
    layerFl = numerical_vars.get("layerFl", "top")
    fxDependence = numerical_vars.get("fxDependence", True)
    extrapolation = numerical_vars.get("extrapolation", False)

    add_Chrom = numerical_vars.get("add_Chrom", True)
    masked = numerical_vars.get("masked", True)
    idx = numerical_vars.get("idx", None)
    plot = numerical_vars.get("plot", False)

    if type_data == 'cylinder':
        name_parts = filepath.split('/')[-1].split('.')[0].split('_')
        depth_offset = float(name_parts[2][1:])
        if depth_offset > 10:
            depth_offset /= 10

    if depth_offset > 0:
        layerFl = 'bottom'

    if type_data == 'cylinder':
        sample_data = _standardize(load_cylinder_data(filepath))
    elif type_data == 'phantom':
        phantom_data = _standardize(load_phantom_data(phantom_data_path))
        idx = idx if idx is not None else 0
        # Select one phantom along the leading axis of every field.
        sample_data = {key: value[idx] for key, value in phantom_data.items()}
    elif type_data == 'dt':
        # Same default as 'phantom'; previously idx=None made data_list[idx] fail.
        idx = idx if idx is not None else 0
        sample_data = _standardize(load_dt_data(filepath)[idx])
    elif type_data == 'else':
        sample_data = _standardize(load_phantom_data("data/sample_iceberg.mat"))
    else:
        raise ValueError(f"Invalid data type: {type_data}")

    F = sample_data['fluorescence']
    op = sample_data['optical_props']
    op_shape = op[..., 0].shape

    if add_Chrom:
        # Fixed optical properties from opt_prop_hb at 630 / 700 (presumably the
        # excitation / emission wavelengths in nm -- confirm), plus a 0.033 absorption offset.
        mua_x = np.ones(op_shape) * opt_prop_hb(630, 1.5, 0.95) + np.ones(op_shape) * 0.033
        mus_x = np.ones(op_shape)

        mua_m = np.ones(op_shape) * opt_prop_hb(700, 1.5, 0.95)
        mus_m = np.ones(op_shape)
    else:
        mua_x = op[..., 0]
        mus_x = op[..., 1]
        if extrapolation:
            mua_m, mus_m = extrapolate_opt_prop(mua_x, mus_x, 630, 700, absorber='IndiaInk', scatterer='Intralipid')
        else:
            mua_m = mua_x
            mus_m = mus_x

    assert mua_m.shape == mua_x.shape
    assert mus_m.shape == mus_x.shape

    print(f"File: {filepath}")
    print(f"Depth offset: {depth_offset}, Layer: {layerFl}")
    print(f"mua_x: {np.max(mua_x)}, mua_m: {np.max(mua_m)}")
    print(f"mus_x: {np.max(mus_x)}, mus_m: {np.max(mus_m)}")

    solver_vars = {
        'fxDependence': fxDependence,
        'hFl': hFl,
        'zDelta': zDelta,
        'layerFl': layerFl,
        'depth_offset': depth_offset,
        'masked': masked,
        'idx': idx
    }

    z_est, qf_est, M_Z, M_Z_Data, phi_x, Tef_z, z_vals = inverse_fluorescence_depth(
        F, fx, mua_x, mus_x, mua_m, mus_m, nrel, solver_vars)

    raw_depth = sample_data['depth']

    # Ground truth for display: for a surface cylinder every positive depth is shown
    # as the maximum depth.
    gt_depth = raw_depth
    if depth_offset == 0 and type_data == 'cylinder':
        gt_depth = np.where(gt_depth > 0, np.max(gt_depth), 0)

    # Errors over pixels with a positive ground-truth depth. The "relative" figure
    # printed below is the mean signed error (bias), not a ratio.
    depth_pixels = np.where(raw_depth > 0)
    n_pixels = np.sum(raw_depth > 0)
    depth_err = z_est[depth_pixels] - raw_depth[depth_pixels]
    if n_pixels > 0:
        mean_abs_err = np.sum(np.abs(depth_err)) / n_pixels
        mean_signed_err = np.sum(depth_err) / n_pixels
    else:
        mean_abs_err = mean_signed_err = np.nan

    print(f"Depth absolute error: {mean_abs_err}, Depth relative error: {mean_signed_err}")

    if plot:
        plot_graphs(Tef_z, M_Z, phi_x, z_vals, M_Z_Data, 3, [50, 50], random=False)

    test_data = {
        'depth': gt_depth,
        'concentration': sample_data['concentration']
    }
    plot_depth(test_data, z_est, type='depth')

    
    

In [None]:
# Load one dataset via read_mat_data; the result is a dict with keys
# 'fluorescence', 'op', 'depth', 'concentration_fluor' and 'reflectance'.
small_shallow_2000_data = read_mat_data("data/newKAN/2000_small_shallow.mat")