In [None]:
"""
This python script is for SPICE uncertainty analysis

Author: Jason Tian Lyu <tian.lyu@ndcn.ox.ac.uk>

Copyright University of Oxford, 2025.
"""

%load_ext autoreload
%autoreload 2
# import libs
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.optimize import minimize
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import cg
from fsl_mrs.core import basis
from fsl_mrs.utils import mrs_io
from fsl_mrs.utils import synthetic as syn
from fsl_mrs.utils.misc import FIDToSpec
from fsl_mrs.utils.plotting import FID2Spec
from SPICE_Uncert_Ancillary import (
    createSpatialCurve_1D_Brain,
    make_1d_spectral_phantom,
    calc_F,
    gen_gt_ktspace,
    add_noise2kt,
    plot_and_rmse,
    read_training_data_from_csv,
    save_training_data_as_csv,
    constraints_to_B,
    Plot_W_WE,
    SPICEWithSpatialConstrain,
    Undersampe_zeroout,


    plot_spec_ana,
    plot_mc_compare_spec,
    plot_mc_compare_spat,
    runmc,
    calc_Covariance_spat,
    calc_std_uncert,
    plot_spec_mc,
    plot_spec_analyt,
    plot_spatial_mc,
    plot_spatial_mc_ana_combined,
    calc_Covariance_spat_overall,
    calc_Covariance_old,
    Create_laplacian_samples,

    fft_recon,
    fft_mc,

    plot_bm_and_bmFID,
    fit_mrs_spectrum_lstsq_batch_vbv,
    plot_popt,
    Sig_func_Multi_Peak,
    mc_basis,
)

In [None]:
'''
Configuration constants for the whole notebook (tune here; later cells read these names).
'''

# --- Phantom ground truth ---
Peak_lws_gt = [10, 10]  # Ground-truth linewidths of the two simulated peaks (presumably Hz, like lw_mean_training)
K_POINTS = 64  # Number of sampled k-space points
N_VOXEL = 64  # Number of voxels
Peak_CS = [-3.05, -1.5]  # Chemical-shift values passed to make_1d_spectral_phantom (units as defined there)

# Brain image input path
BRAIN_IMG_DIR = './Brain_img/'
BRAIN_IMG_FILE = 'brain_from_k_space_64.npy'

WATER_PEAK_PPM_CENTER = 4.65  # ppm position of the water peak (reference for PPM_AXIS)
NOISE_SNR = 100  # Target SNR: the noise SD is max|k-t signal| / NOISE_SNR

# Rough spectral indices of the two peaks (used for plotting / uncertainty read-out)
PEAK_0_ROUGH_IDX = 381
PEAK_1_ROUGH_IDX = 336


# --- Metabolite basis ---
N_SEQ_POINTS = 512  # Number of time-domain sampling points (FID length)
N_SEQ_BANDWIDTH = 2800  # Spectral bandwidth in Hz
BASIS_DIR = './Basis_Fit_ESMRMB2025/'
BASIS_Cho = 'Cho'
BASIS_Cr = 'Cr'
BASIS_NAA = 'NAA'
BASIS_Glu = 'Glu'


# --- Undersampling ---
NUM_SPICE_RANK = 5  # Rank of the SPICE low-rank spectral subspace
UNDER_SAMPLE_NUM = 16  # Number of undersampled k-space lines (passed to Undersampe_zeroout)
UNDER_SAMPLE_HDL = True  # Whether to perform undersampling
# Ratio UNDER_SAMPLE_NUM / K_POINTS (a fraction, not a percentage). It is computed,
# and written into `description`, regardless of UNDER_SAMPLE_HDL.
UNDERSAMPLE_RATIO = UNDER_SAMPLE_NUM / K_POINTS

# --- SPICE reconstruction ---
LAMBDA_WE_max = 500  # Maximum edge weight (W_max) for the edge-preserving regulariser
Lamda_1 = 0.05  # Spatial regularisation weight (lambda)
SUBSPACE_DATA_RW = True  # Passed as `savecondition` to save_training_data_as_csv
SAVE_DIR = './Train_data_SPICE/'  # Directory for the training-set CSV
CSV_FILE_NAME = 'SS250506'  # File name of the training-set CSV


# --- Plotting ---
scale_factor = 0.018  # Scale factor for displaying the water image
limits = [2e-3, 1.25e-2]  # Colour-scale limits for the Monte Carlo uncertainty maps


# --- Output (note: the saved arrays can be large) ---
SAVDE_DATA_DIR = './Saved_Data/'  # Directory for saved results (name kept as-is; used throughout)

# Create the output directories up front: np.save raises FileNotFoundError when the
# target directory is missing, which would otherwise only surface after the
# expensive Monte Carlo runs have finished.
for _output_dir in (SAVE_DIR, SAVDE_DATA_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Run timestamp (format: YYYY-MM-DD_HH), used in saved file names
timestamp = datetime.now().strftime("%Y-%m-%d_%H")

# Run description built from the key parameters, used in saved file names
description = f'SNR({NOISE_SNR})_lambda({Lamda_1})_Wmax({LAMBDA_WE_max})_US({UNDERSAMPLE_RATIO})'


In [None]:
'''
Build the 1D brain phantom, load the metabolite basis and simulate noisy
(optionally undersampled) k-t space data.
'''

# Seed for the noise drawn in this cell. `rng` is also reused by the training-set
# cell, so fixing the seed makes both reproducible under Restart & Run All.
PHANTOM_RNG_SEED = 2025


def _plot_labelled_curve(x, y, label, xlabel, ylabel, title):
    """Plot one labelled curve with legend and grid on the current figure, then show it."""
    plt.plot(x, y, label=label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Load the brain image reconstructed from k-space
brain_img = np.load(BRAIN_IMG_DIR + BRAIN_IMG_FILE)
print(brain_img.shape)  # Shape of the brain image
print(brain_img.dtype)  # Data type

# Display the magnitude image
plt.imshow(np.abs(brain_img), cmap='viridis')
plt.title('Brain Image from K-space')
plt.axis('on')
plt.show()

# Extract a 1D profile along the vertical axis (column index 25)
SLICE_1D = brain_img[:, 25]
plt.plot(np.abs(SLICE_1D))
plt.title("1D Slice of Brain Image (Magnitude)")
plt.xlabel("Row Index")
plt.ylabel("Intensity")
plt.grid(True)
plt.show()

# Round the profile to obtain quantised intensity labels.
# NOTE(review): if brain_img is complex, np.round rounds the real and imaginary parts
# separately and the plot below drops the imaginary part — confirm it is real-valued.
SLICE_SMOOTHED = np.round(SLICE_1D)
plt.plot(SLICE_SMOOTHED)
plt.title("Rounded/Quantized Slice")
plt.grid(True)
plt.show()

# Generate the 1D spatial phantom (including ground-truth amplitude profiles C0_GT, C1_GT)
phantom_res, phantom_size, x_hr, x_hr_pos, C0_GT, C1_GT = createSpatialCurve_1D_Brain(K_POINTS, SLICE_SMOOTHED)

# Load the full basis set and resample it to the simulation bandwidth / length.
# NOTE: this rebinds `basis`, shadowing the `fsl_mrs.core.basis` module imported above.
fullbasis = mrs_io.read_basis(BASIS_DIR)
basis = fullbasis.get_formatted_basis(bandwidth=N_SEQ_BANDWIDTH, points=N_SEQ_POINTS)

# Metabolite basis FIDs (complex), one column per metabolite.
# NOTE(review): the column indices are hard-coded and presumably follow the
# alphabetical order Cho, Cr, Glu, NAA — confirm against fullbasis.names.
Basis_NAA_FID = basis[:, 3]
Basis_Cho_FID = basis[:, 0]
Basis_Cr_FID = basis[:, 1]
Basis_Glu_FID = basis[:, 2]

# Frequency, ppm and time axes
sweepwidth = N_SEQ_BANDWIDTH
dwelltime = fullbasis.original_dwell
center_freq = fullbasis._cf
original_points = fullbasis.original_points

FREQ_AXIS = np.linspace(-sweepwidth / 2, sweepwidth / 2, N_SEQ_POINTS)
PPM_AXIS = -(FREQ_AXIS / center_freq) + WATER_PEAK_PPM_CENTER
# NOTE(review): TIME_AXIS spans the *original* basis duration with N_SEQ_POINTS samples.
# It only matches the resampled basis (dwell 1/N_SEQ_BANDWIDTH) if
# dwelltime * (original_points - 1) == (N_SEQ_POINTS - 1) / N_SEQ_BANDWIDTH — confirm.
TIME_AXIS = np.linspace(0.0, dwelltime * (original_points - 1), N_SEQ_POINTS)

# Basis 0: Glu. FID2Spec returns a complex spectrum, so plot the real part explicitly
# (matplotlib would otherwise drop the imaginary part with a ComplexWarning).
Basis0_FID = Basis_Glu_FID
Basis0_FID_SPEC = FID2Spec(Basis0_FID)
_plot_labelled_curve(PPM_AXIS, np.real(Basis0_FID_SPEC), 'Metabolite Glu',
                     'PPM axis', 'Signal Spectrum', 'Basis Glu Spectrum')
_plot_labelled_curve(TIME_AXIS, np.real(Basis0_FID), 'Metabolite Glu',
                     'Time axis', 'Signal', 'Basis Glu FID (Real Part)')

# Basis 1: Cho
Basis1_FID = Basis_Cho_FID
Basis1_FID_SPEC = FID2Spec(Basis1_FID)
_plot_labelled_curve(PPM_AXIS, np.real(Basis1_FID_SPEC), 'Metabolite Cho',
                     'PPM axis', 'Signal Spectrum', 'Basis Cho Spectrum')
_plot_labelled_curve(TIME_AXIS, np.real(Basis1_FID), 'Metabolite Cho',
                     'Time axis', 'Signal', 'Basis Cho FID (Real Part)')

# Basis FIDs used for every simulation below (order: Glu, Cho)
bm_FIDs = [Basis0_FID, Basis1_FID]

# Ground-truth image-time space (IT-space) data
GT_IT_SPACE = make_1d_spectral_phantom(C0_GT, C1_GT, Peak_CS, Peak_lws_gt, bm_FIDs, TIME_AXIS)

# k-t space from IT-space via the Fourier encoding matrix F
F = calc_F(K_POINTS, True)
GT_KT_SPACE = gen_gt_ktspace(GT_IT_SPACE, K_POINTS, F, TIME_AXIS, True)

# Add noise to k-t space; the noise SD is set relative to the peak k-t magnitude
rng = np.random.default_rng(PHANTOM_RNG_SEED)
NOISE_SD = np.max(np.abs(GT_KT_SPACE)) / NOISE_SNR
noisy_kt_space = add_noise2kt(GT_KT_SPACE, rng, NOISE_SD)

# Undersample k-space by zeroing out lines
noisy_kt_space_us_eg, F_us_eg = Undersampe_zeroout(
    noisy_kt_space, F, UNDER_SAMPLE_NUM, K_POINTS, TIME_AXIS, True, UNDER_SAMPLE_HDL
)

# Magnitude of the noisy k-t space, truncated to the first 50 time points.
# TIME_AXIS is built from the basis dwell time, so its unit is seconds.
k_x = np.arange(K_POINTS)
plt.pcolor(TIME_AXIS, k_x, np.abs(noisy_kt_space))
plt.title('Noisy k-t Space (Truncated in Time)')
plt.xlabel('Time (s)')
plt.xlim([0, TIME_AXIS[50]])
plt.ylabel('$k_x$')
plt.show()


In [None]:
'''
Generate the training set for subspace learning.
'''

# Upper bound of the uniform concentration distribution
cmax_training = 1.5

# Linewidth distribution parameters (normal distribution)
lw_sd_training = 2           # Linewidth standard deviation (Hz)
lw_mean_training = 10        # Linewidth mean (Hz)

# Number of training samples to generate
training_datasets = int(1E4)

# Concentrations of the two metabolites, drawn independently and uniformly from
# [0, cmax_training). There is no fixed ratio between the two metabolites.
# Uses `rng` created in the phantom cell above.
training_cs = rng.random((2, training_datasets)) * cmax_training

# Linewidths drawn independently per metabolite from N(lw_mean_training, lw_sd_training**2).
# NOTE(review): a normal draw can in principle be <= 0 (5 SD below the mean here, so very rare).
training_lw = lw_mean_training + rng.standard_normal((2, training_datasets)) * lw_sd_training

# Simulate one FID per training sample
training_dataset = Sig_func_Multi_Peak(bm_FIDs, training_lw, training_cs, TIME_AXIS, training_datasets)

# Plot the real part of the first 5 training spectra
plt.plot(PPM_AXIS, FID2Spec(training_dataset[:5, :].T).real)
plt.title('Example Spectra from Training Set')
plt.xlabel('PPM')
plt.ylabel('Signal (Real Part)')
plt.grid(True)
plt.tight_layout()
plt.show()

# Save the training set as CSV (only when SUBSPACE_DATA_RW is set)
save_training_data_as_csv(
    training_data=training_dataset,
    save_dir=SAVE_DIR,
    filename=CSV_FILE_NAME,
    savecondition=SUBSPACE_DATA_RW
)


In [None]:
'''
Take the SVD of the training data to learn the low-rank spectral subspace.
'''
# Read the training data back from the CSV written by the previous cell
training_dataset = read_training_data_from_csv(save_dir=SAVE_DIR, filename=CSV_FILE_NAME)

# Thin SVD: full_matrices=False returns the same singular values `s` and right
# singular vectors `vh` (the only outputs used downstream) but avoids building the
# square (n_samples x n_samples) `u` matrix, which is ~1.6 GB for 1e4 complex samples.
u, s, vh = np.linalg.svd(training_dataset, full_matrices=False)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Left: singular-value spectrum (first 20 components), used to choose NUM_SPICE_RANK
ax1.plot(s, 'x--')
ax1.set_xlim([0, 20])
ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))  # integer component indices only
ax1.set_xlabel('Component index')
ax1.set_ylabel('Singular value')
ax1.set_title('Singular values of the training set')

# Right: real part of the spectra of the retained right singular vectors
ax2.plot(PPM_AXIS, FID2Spec(vh[0:NUM_SPICE_RANK, :].T).real)
ax2.set_xlabel('PPM')
ax2.set_title(f'First {NUM_SPICE_RANK} subspace basis spectra (real part)')
plt.show()



In [None]:
'''
Run SPICE
'''
# 1. Anatomical prior: relabel a *copy* of the quantised slice (1 -> 0, then 2 -> 3).
#    Working on a copy leaves SLICE_SMOOTHED, which built the phantom above,
#    unchanged, so the earlier cells and this one can be re-run safely.
SLICE_SMOOTHED_extracted = SLICE_SMOOTHED.copy()
SLICE_SMOOTHED_extracted[SLICE_SMOOTHED_extracted == 1] = 0
SLICE_SMOOTHED_extracted[SLICE_SMOOTHED_extracted == 2] = 3

# Plot the relabelled anatomical slice
plt.figure(figsize=(8, 4))
plt.plot(SLICE_SMOOTHED_extracted, marker='o')
plt.title("SLICE_SMOOTHED with 1s set to 0")
plt.xlabel("Index")
plt.ylabel("Value")
plt.grid(True)
plt.tight_layout()
plt.show()

# Anatomical prior used for the spatial constraint
water_rou_1D = SLICE_SMOOTHED_extracted

# 2. Edge-preserving matrix for spatial regularisation
minpooling_Handler = True  # Min-pooling flag (NOTE(review): not passed to constraints_to_B below)
pool_size = 1              # Pooling window size
W_max = LAMBDA_WE_max      # Maximum edge weight
lamda_1 = Lamda_1          # Regularisation strength

# W_edge from the anatomical prior (_W and _P are only used for the visualisation)
W_edge, _W, _P = constraints_to_B(water_rou_1D, W_max=W_max, pool_size=pool_size)
Plot_W_WE(_W, _P)

# Truncated right singular vectors spanning the spectral subspace
V = vh[0:NUM_SPICE_RANK, :].conj().T  # (n_points, rank), used for projection
Vh = vh[0:NUM_SPICE_RANK, :]          # (rank, n_points)

# 3. SPICE reconstruction of the noisy data (undersampled data if UNDER_SAMPLE_HDL)
if UNDER_SAMPLE_HDL:
    spice_kt_input, spice_F_input = noisy_kt_space_us_eg, F_us_eg
else:
    spice_kt_input, spice_F_input = noisy_kt_space, F
spice_est_cg, est_U = SPICEWithSpatialConstrain(
    spice_kt_input, spice_F_input, V, K_POINTS,
    NUM_SPICE_RANK, W_edge, method='cg', lamda=lamda_1
)

# Reference: SPICE on the noise-free, fully sampled k-t data without regularisation (lamda = 0)
GT_spice_est_cg, GT_est_U = SPICEWithSpatialConstrain(
    GT_KT_SPACE, F, V, K_POINTS,
    NUM_SPICE_RANK, W_edge, method='cg', lamda=0
)

# Report the dominant spectral bin of each basis (Basis0 = Glu, Basis1 = Cho).
# After the loop, main_peak_index / main_peak_ppm hold the Basis1 values.
for basis_label, basis_spectrum in (("Basis0", Basis0_FID_SPEC), ("Basis1", Basis1_FID_SPEC)):
    main_peak_index = np.argmax(np.abs(basis_spectrum))
    print(f"{basis_label} Main peak index: {main_peak_index}")
    main_peak_ppm = PPM_AXIS[main_peak_index]
    print(f"{basis_label} Main peak PPM: {main_peak_ppm}")

# Compare the reconstruction against the ground truth
plot_and_rmse(
    spice_est_cg,        # Reconstructed image-time space
    GT_IT_SPACE,         # Ground truth
    PPM_AXIS,            # PPM axis for the spectral domain
    x_hr_pos,            # Spatial positions
    C0_GT, C1_GT         # Ground-truth metabolite amplitudes
)


In [None]:
'''
Monte Carlo (MC) uncertainty of the SPICE reconstruction, run once with the
conjugate-gradient (CG) solver and once with the analytical solver.
'''

# Number of MC noise realisations per solver
mc_iterations = 2000

# Settings shared by both runs; only `Solver` differs. Sharing them guarantees the
# CG and analytical results are directly comparable (same noise SD, undersampling, seed).
runmc_shared_kwargs = dict(
    add_noise=add_noise2kt,
    gen_undersample=Undersampe_zeroout,
    UNDER_SAMPLE_NUM=UNDER_SAMPLE_NUM,
    kt_space_gt=GT_KT_SPACE,
    F=F,
    K_POINTS=K_POINTS,
    time_axis=TIME_AXIS,
    handler=UNDER_SAMPLE_HDL,
    W_edge=W_edge,
    lamda_1=Lamda_1,
    iterations=mc_iterations,
    seed_mc=42,
    noise_SD=NOISE_SD,
    V=V,
    NUM_SPICE_RANK=NUM_SPICE_RANK,
)

# MC using the conjugate-gradient solver
spice_mc_cg, _ = runmc(SPICEWithSpatialConstrain, Solver='cg', **runmc_shared_kwargs)

# MC using the analytical solver
spice_mc_analy, _ = runmc(SPICEWithSpatialConstrain, Solver='analytical', **runmc_shared_kwargs)

# Save both MC results (file names carry the run description and timestamp)
for result_prefix, mc_result in (("spice_mc_cg", spice_mc_cg), ("spice_mc_analy", spice_mc_analy)):
    filename = f"{result_prefix}_{description}_{timestamp}.npy"
    np.save(SAVDE_DATA_DIR + filename, mc_result)
    print(f"Data saved as file: {filename}")


In [None]:
'''
Analytical SPICE reconstruction uncertainty (before fitting), compared against
the Monte Carlo results.
'''

# Voxel-wise standard deviation at every spectral index 0..N_SEQ_POINTS-1.
# This is N_SEQ_POINTS covariance evaluations, so it is computed exactly once.
uncert_array = []
for freq in np.arange(N_SEQ_POINTS):
    cov_1 = calc_Covariance_spat(NOISE_SD, Lamda_1, W_edge, K_POINTS, V, freq)  # Covariance at this index
    uncert_1 = calc_std_uncert(cov_1)  # Standard deviation per voxel
    uncert_array.append(np.array(uncert_1))

# Stack the per-index rows into a 2D array: [index, voxel]
uncert_array = np.vstack(uncert_array)
print('Shape of uncert_array:', uncert_array.shape)

# Uncertainty map (spectral index vs spatial location)
plt.matshow(np.abs(uncert_array))
plt.title('Analytical Uncertainty Map')
plt.xlabel('Voxel Index')
plt.ylabel('Frequency Index')
plt.colorbar()
plt.show()

# The 5 spectral indices whose uncertainty varies most across voxels, in descending order
variances = np.var(uncert_array, axis=1)
max_variance_index = np.argpartition(variances, -5)[-5:]
max_variance_index = max_variance_index[np.argsort(variances[max_variance_index])[::-1]]
print(f"Top 5 frequencies with highest spatial uncertainty variance: {max_variance_index}")

# Uncertainty at the two manually specified peak indices
cov_p1 = calc_Covariance_spat(NOISE_SD, Lamda_1, W_edge, K_POINTS, V, PEAK_0_ROUGH_IDX)
uncert_p1 = calc_std_uncert(cov_p1)

cov_p2 = calc_Covariance_spat(NOISE_SD, Lamda_1, W_edge, K_POINTS, V, PEAK_1_ROUGH_IDX)
uncert_p2 = calc_std_uncert(cov_p2)

# --- Plotting results ---

# Spectral uncertainty: MC (CG, analytical solver) vs analytical covariance
plot_mc_compare_spec(
    spice_mc_cg, spice_mc_analy,
    res_array3=uncert_array,
    plot_spec_mc=plot_spec_mc,
    plot_spec_analyt=plot_spec_analyt,
    ppm_axis=PPM_AXIS,
    limits=limits
)

# Spatial uncertainty maps (MC)
plot_mc_compare_spat(
    spice_mc_cg, spice_mc_analy,
    plot_spatial_mc=plot_spatial_mc,
    water_rou_1D=water_rou_1D,
    x_hr_pos=x_hr_pos,
    limits=limits
)

# Combined spatial uncertainty (analytical and MC) at the two key peaks
mc_rec, laplac = plot_spatial_mc_ana_combined(
    mc_cg=spice_mc_cg,
    mc_analyt=spice_mc_analy,
    std_uncert1_analyt=uncert_p1,
    std_uncert2_analyt=uncert_p2,
    water_rou_1D=water_rou_1D,
    x_hr_pos=x_hr_pos
)


In [None]:
"""
Monte Carlo simulation using DFT reconstruction
"""

# Same noise SD, undersampling, iteration count and seed as the SPICE MC runs
fft_mc_result = fft_mc(
    fft_recon,
    add_noise=add_noise2kt,
    gen_undersample=Undersampe_zeroout,
    UNDER_SAMPLE_NUM=UNDER_SAMPLE_NUM,
    kt_space_gt=GT_KT_SPACE,
    F=F,
    K_POINTS=K_POINTS,
    time_axis=TIME_AXIS,
    handler=UNDER_SAMPLE_HDL,
    iterations=2000,
    seed_mc=42,
    noise_SD=NOISE_SD
)

# fft_mc_result shape: (n_iter, n_voxel, n_points). Unpack into lower-case locals so
# the N_VOXEL configuration constant is not overwritten.
n_iter, n_voxel_mc, n_points_mc = fft_mc_result.shape
# NOTE(review): the last axis is plotted against PPM_AXIS below, while the comparison
# cell applies FIDToSpec only to the SPICE results — confirm whether fft_mc returns
# spectra or FIDs.

# Standard deviation across voxels at each spectral index -> shape (n_iter, n_points)
std_across_voxels = np.std(fft_mc_result, axis=1)

# Mean and spread of those per-index stds across iterations -> shape (n_points,)
mean_std = np.mean(std_across_voxels, axis=0)
std_std = np.std(std_across_voxels, axis=0)

# Per-index uncertainty with a +-1 STD band
plt.figure(figsize=(10, 5))
plt.plot(mean_std, label="Mean STD across voxels")
plt.fill_between(range(n_points_mc), mean_std - std_std, mean_std + std_std, alpha=0.3, label="±1 STD band")
plt.title("Per-frequency Uncertainty (STD across voxels)")
plt.xlabel("Frequency index")
plt.ylabel("STD across voxels")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# --- Voxel-wise spectral uncertainty (3D plot) ---
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')

# STD of the real part across MC samples, per voxel and spectral index -> (n_voxel, n_points)
uncertainty = np.std(np.real(fft_mc_result), axis=0)

# A few voxels for the 3D visualisation
voxel_indices = [2, 8, 16, 24, 30]
for voxel in voxel_indices:
    y = np.full_like(PPM_AXIS, voxel)
    z = uncertainty[voxel, :]
    ax.plot(PPM_AXIS, y, z, label=f'Voxel {voxel}')

ax.set_xlim(PPM_AXIS[-1], PPM_AXIS[0])  # Reversed ppm axis (conventional display)
ax.set_xlabel('$\\delta$ / ppm')
ax.set_ylabel('Voxel #')
ax.set_zlabel('Uncertainty')
ax.view_init(elev=15, azim=-80)
ax.legend()
plt.tight_layout()
plt.show()

# --- Spatial uncertainty at the two spectral peaks ---
fig, ax = plt.subplots(figsize=(8, 4))

# STD of the magnitude across MC samples at the two peak indices -> (n_voxel,)
peak1est = np.std(np.abs(fft_mc_result[:, :, PEAK_0_ROUGH_IDX]), axis=0)
peak2est = np.std(np.abs(fft_mc_result[:, :, PEAK_1_ROUGH_IDX]), axis=0)

xaxis = x_hr_pos  # Voxel positions
ax.plot(xaxis, peak1est, label='Peak 1 STD')
ax.plot(xaxis, peak2est, label='Peak 2 STD')

ax.set_xlabel('x (voxel position)')
ax.set_ylabel('Uncertainty')
ax.set_title('Spatial Uncertainty at Two Peaks')
ax.set_ylim([0, max(np.max(peak1est), np.max(peak2est)) * 1.1])
ax.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Save the MC result
filename = f"fft_mc_result_{description}_{timestamp}.npy"
np.save(SAVDE_DATA_DIR + filename, fft_mc_result)
print(f"Data saved as file: {filename}")


In [None]:
"""
Uncertainty Comparison between SPICE and DFT before Spectral Fitting
"""

import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from matplotlib.font_manager import FontProperties

# Bold font; only needed by the optional legend call at the bottom of this cell
bold_font = FontProperties()
bold_font.set_weight('bold')

# Seaborn theme (global: it also applies to every figure drawn afterwards)
sns.set_theme(style="whitegrid", font_scale=1.2)

# 3D figure holding one uncertainty trace per (method, voxel)
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')

# The five darkest shades of each palette: reds for DFT/FFT, blues for SPICE
colors_dft, colors_spice = (sns.color_palette(palette_name, 15)[-5:] for palette_name in ("Reds", "Blues"))

# Previously saved MC results — edit these file names to match your local runs
spice_mc_cg_W_for_plt = np.load(SAVDE_DATA_DIR + "spice_mc_cg_SNR(100)_lambda(0.05)_Wmax(500)_US(0)_2025-05-13_14.npy")
spice_mc_cg_WO_for_plt = np.load(SAVDE_DATA_DIR + "spice_mc_cg_SNR(100)_lambda(0)_Wmax(500)_US(0)_2025-05-13_14.npy")
fft_mc_result_for_plt = np.load(SAVDE_DATA_DIR + "fft_mc_result_SNR(100)_lambda(0)_Wmax(500)_US(0)_2025-05-13_14.npy")

# Per-voxel, per-point standard deviation across MC realisations (axis 0)
uncertainty = np.std(np.real(fft_mc_result_for_plt), axis=0)
res_array_spec1 = FIDToSpec(spice_mc_cg_W_for_plt, axis=-1).real.std(axis=0)   # SPICE, lambda = 0.05 run
res_array_spec2 = FIDToSpec(spice_mc_cg_WO_for_plt, axis=-1).real.std(axis=0)  # SPICE, lambda = 0 run
# NOTE(review): res_array_spec1 (SPICE with prior) is computed but not drawn in this figure.

# Voxels to draw
voxel_indices = [4, 16, 28, 40, 52]

# Restrict to 0-5 ppm to avoid the noisy spectral edges; the window is the same for every voxel
ppm_in_window = (PPM_AXIS >= 0) & (PPM_AXIS <= 5)
ppm_shown = PPM_AXIS[ppm_in_window]

for trace_no, voxel in enumerate(voxel_indices):
    voxel_coord = np.full_like(ppm_shown, voxel)
    is_first = trace_no == 0

    # DFT/FFT reconstruction uncertainty
    ax.plot(ppm_shown, voxel_coord, uncertainty[voxel, ppm_in_window],
            label='FFT' if is_first else None, color=colors_dft[trace_no], linewidth=2)

    # SPICE without anatomical prior (lambda = 0)
    ax.plot(ppm_shown, voxel_coord, res_array_spec2[voxel, ppm_in_window],
            label='SPICE (No Prior)' if is_first else None, color=colors_spice[trace_no],
            alpha=0.5, linewidth=2)


# Axes: reversed ppm axis, bold labels
ax.set_xlim([5, 0])
ax.set_xlabel('$\\delta$  (ppm)', fontweight='bold')
ax.set_ylabel('Voxel', fontweight='bold')
ax.set_zlabel('Uncertainty', fontweight='bold', labelpad=20)
ax.zaxis.set_label_coords(-0.3, 0.5)

# Viewing angle
ax.view_init(elev=15, azim=-40)

# Optional legend:
# ax.legend(loc='center left', bbox_to_anchor=(1.05, 0.5), fontsize=9, prop=bold_font)

plt.tight_layout()
plt.show()


In [None]:
# --- Monte Carlo-based spectral fitting of the FFT-reconstructed data ---
cm1_fft, cm2_fft, lw1_fft, lw2_fft = mc_basis(fft_mc_result, bm_FIDs, TIME_AXIS)

# Per-voxel means and standard deviations over the MC realisations (axis 0)
cm1_mean, cm2_mean, lw1_mean, lw2_mean = (
    np.mean(fit_param, axis=0) for fit_param in (cm1_fft, cm2_fft, lw1_fft, lw2_fft)
)
cm1_fft_std, cm2_fft_std, lw1_fft_std, lw2_fft_std = (
    np.std(fit_param, axis=0) for fit_param in (cm1_fft, cm2_fft, lw1_fft, lw2_fft)
)


def _plot_fft_voxel_pair(curve1, curve2, label1, label2, prior_scale, ylabel, title):
    """Draw two per-voxel curves plus the scaled anatomical prior on a new 8x5 figure."""
    plt.figure(figsize=(8, 5))
    plt.plot(curve1, marker='o', linestyle='-', label=label1)
    plt.plot(curve2, marker='s', linestyle='-', label=label2)
    plt.plot(water_rou_1D * prior_scale, label='Anatomical Prior (scaled)', linestyle=':')
    plt.xlabel("Voxel Index")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Mean fitted concentrations
_plot_fft_voxel_pair(cm1_mean, cm2_mean, "Cm1 mean", "Cm2 mean", 0.02,
                     "Fitted Cm Value", "Mean Concentrations")

# Concentration uncertainty (STD across MC realisations)
_plot_fft_voxel_pair(cm1_fft_std, cm2_fft_std, "Cm1 spatial uncertainty", "Cm2 spatial uncertainty", 0.02,
                     "Uncertainty", "Concentration Uncertainty")

# Linewidth uncertainty (STD across MC realisations)
_plot_fft_voxel_pair(lw1_fft_std, lw2_fft_std, "Lw1 spatial uncertainty", "Lw2 spatial uncertainty", 2,
                     "STD of Lw Value", "MC-Based MRS Fitting: Linewidth Uncertainty")


# --- Save the fitted parameter arrays ---
def save_array(arr, name):
    """Save `arr` to SAVDE_DATA_DIR as '<name>_<description>_<timestamp>.npy' and report the file name."""
    filename = f"{name}_{description}_{timestamp}.npy"
    np.save(SAVDE_DATA_DIR + filename, arr)
    print(f"Saved: {filename}")


for fitted_array, array_name in (
    (cm1_fft, "cm1_fft"),
    (cm2_fft, "cm2_fft"),
    (lw1_fft, "lw1_fft"),
    (lw2_fft, "lw2_fft"),
    (cm1_fft_std, "cm1_fft_std"),
    (cm2_fft_std, "cm2_fft_std"),
    (lw1_fft_std, "lw1_fft_std"),
    (lw2_fft_std, "lw2_fft_std"),
):
    save_array(fitted_array, array_name)


In [None]:
'''
Laplacian MC for the fitting uncertainty of SPICE
'''

# Overall covariance of the SPICE reconstruction from the current and the older implementation
cov_overall = calc_Covariance_spat_overall(NOISE_SD, Lamda_1, W_edge, K_POINTS, V)
cov_overall_old = calc_Covariance_old(NOISE_SD, Lamda_1, W_edge, K_POINTS)

# Cross-check: both implementations must agree
assert np.allclose(cov_overall, cov_overall_old, atol=1e-6), "cov_overall and cov_overall_old are not close enough!"

# 2000 MC samples generated from the SPICE estimate (est_U, Vh) and cov_overall
Laplac_data = Create_laplacian_samples(est_U, Vh, cov_overall, 2000)

print(Laplac_data.shape)
print(spice_mc_cg.shape)


def save_array(data, name):
    """Save `data` to SAVDE_DATA_DIR as '<name>_<description>_<timestamp>.npy'."""
    filename = f"{name}_{description}_{timestamp}.npy"
    np.save(SAVDE_DATA_DIR + filename, data)
    # Message reads: "data saved as file: <filename>"
    print(f"数据已保存为文件：{filename}")


save_array(cov_overall, "cov_overall")
save_array(Laplac_data, "Laplac_data")

# Diagonal of W^H W (per-voxel regularisation weight) against the anatomical prior
result = np.diag((W_edge.conj().T @ W_edge).toarray())
plt.figure(figsize=(8, 5))
plt.plot(result, label='diag(W^H W)')
plt.plot(water_rou_1D * 2, label='Water ROU 1D', linestyle=':')
plt.xlabel("Voxel Index")
plt.legend()
plt.show()

# Fit every Laplacian-generated MC sample
cm1_laplac, cm2_laplac, lw1_laplac, lw2_laplac = mc_basis(Laplac_data, bm_FIDs, TIME_AXIS)

# --- Mean fitted values (over MC samples) ---
cm1_mean = np.mean(cm1_laplac, axis=0)
cm2_mean = np.mean(cm2_laplac, axis=0)
lw1_mean = np.mean(lw1_laplac, axis=0)
lw2_mean = np.mean(lw2_laplac, axis=0)

plt.figure(figsize=(8, 5))
plt.plot(cm1_mean, marker='o', linestyle='-', label="Cm1 mean")
plt.plot(cm2_mean, marker='s', linestyle='-', label="Cm2 mean")
plt.plot(water_rou_1D * 0.02, label='Water ROU 1D', linestyle=':')
plt.xlabel("Voxel Index")
plt.ylabel("Fitted Cm Value")
plt.title("MC-Based MRS Fitting Results average")
plt.legend()
plt.show()

# --- Concentration uncertainty (STD over MC samples) ---
cm1_laplac_std = np.std(cm1_laplac, axis=0)
cm2_laplac_std = np.std(cm2_laplac, axis=0)

plt.figure(figsize=(8, 5))
plt.plot(cm1_laplac_std, marker='o', linestyle='-', label="Cm1 spatial uncert")
plt.plot(cm2_laplac_std, marker='s', linestyle='-', label="Cm2 spatial uncert")
plt.plot(water_rou_1D * 0.02, label='Water ROU 1D', linestyle=':')
plt.xlabel("Voxel Index")
plt.ylabel("STD of Cm Value")
plt.title("Laplacian-approx MC-Based MRS Fitting Results with Uncertainty")
plt.legend()
plt.show()

# --- Linewidth uncertainty (STD over MC samples) ---
lw1_laplac_std = np.std(lw1_laplac, axis=0)
lw2_laplac_std = np.std(lw2_laplac, axis=0)

plt.figure(figsize=(8, 5))
plt.plot(lw1_laplac_std, marker='o', linestyle='-', label="Lw1 spatial uncert")
plt.plot(lw2_laplac_std, marker='s', linestyle='-', label="Lw2 spatial uncert")
plt.plot(water_rou_1D * 0.02, label='Water ROU 1D', linestyle=':')
plt.xlabel("Voxel Index")
plt.ylabel("STD of Lw Value")
plt.title("Laplacian-approx MC-Based MRS Lw Fitting Results with Uncertainty")
plt.legend()
plt.show()

# Save all Laplacian-derived results
save_array(cm1_laplac, "cm1_laplac")
save_array(cm2_laplac, "cm2_laplac")
save_array(lw1_laplac, "lw1_laplac")
save_array(lw2_laplac, "lw2_laplac")
save_array(cm1_laplac_std, "cm1_laplac_std")
save_array(cm2_laplac_std, "cm2_laplac_std")
save_array(lw1_laplac_std, "lw1_laplac_std")
save_array(lw2_laplac_std, "lw2_laplac_std")


In [None]:
''' 
MC style fitting uncertainty for SPICE
'''
# mc_basis: extract concentration and linewidth from SPICE Monte Carlo results
cm1, cm2, lw1, lw2 = mc_basis(spice_mc_cg, bm_FIDs, TIME_AXIS)

cm1_mean = np.mean(cm1, axis=0)
cm2_mean = np.mean(cm2, axis=0)
lw1_mean = np.mean(lw1, axis=0)
lw2_mean = np.mean(lw2, axis=0)

# Plot results with error bars
plt.figure(figsize=(8, 5))
plt.plot(cm1_mean, marker='o', linestyle='-',  label="Cm1 mean")
plt.plot(cm2_mean, marker='s', linestyle='-',  label="Cm2 mean")
plt.plot(water_rou_1D * 0.02, label='Water ROU 1D', linestyle=':')

plt.xlabel("Voxel Index")
plt.ylabel("Fitted Cm Value")
plt.title("MC-Based MRS Fitting Results average")
plt.legend()
plt.show()

cm1_std = np.std(cm1, axis=0)
cm2_std = np.std(cm2, axis=0)

# Plot results with error bars
plt.figure(figsize=(8, 5))
plt.plot(cm1_std, marker='o', linestyle='-',  label="Cm1 spatial uncert")
plt.plot(cm2_std, marker='s', linestyle='-',  label="Cm2 spatial uncert")
plt.plot(water_rou_1D * 0.02, label='Water ROU 1D', linestyle=':')

plt.xlabel("Voxel Index")
plt.ylabel("Fitted Cm Value")
plt.title("MC-Based MRS Fitting Results with Uncertainty")
plt.legend()
plt.show()

lw1_std = np.std(lw1, axis=0)
lw2_std = np.std(lw2, axis=0)

# Plot results with error bars
plt.figure(figsize=(8, 5))
plt.plot(lw1_std, marker='o', linestyle='-',  label="Cm1 spatial uncert")
plt.plot(lw2_std, marker='s', linestyle='-',  label="Cm2 spatial uncert")
plt.plot(water_rou_1D * 0.02, label='Water ROU 1D', linestyle=':')

plt.xlabel("Voxel Index")
plt.ylabel("Fitted lw Value")
plt.title("MC-Based MRS Lw Fitting Results with Uncertainty")
plt.legend()
plt.show()

# Save data arrays
def save_array(data, name, desc=None, stamp=None, save_dir=None):
    """Save ``data`` with ``np.save`` as ``<save_dir><name>_<desc>_<stamp>.npy``.

    NOTE(review): this redefines ``save_array`` from an earlier cell (it is
    already called above); keep the definitions in sync or define it once in
    the configuration cell.

    Parameters
    ----------
    data : array_like
        Array to write.
    name : str
        Leading part of the file name (e.g. ``"cm1_std"``).
    desc : str, optional
        Run description; defaults to the notebook global ``description``.
    stamp : str, optional
        Timestamp string; defaults to the notebook global ``timestamp``.
    save_dir : str, optional
        Output directory prefix; defaults to the notebook global
        ``SAVDE_DATA_DIR``. It is prepended verbatim (no separator inserted),
        matching how the loading cell builds ``SAVDE_DATA_DIR + name`` paths,
        so it must end with a path separator.
    """
    # Fall back to the notebook-wide settings only when not given explicitly.
    if desc is None:
        desc = description
    if stamp is None:
        stamp = timestamp
    if save_dir is None:
        save_dir = SAVDE_DATA_DIR
    filename = f"{name}_{desc}_{stamp}.npy"
    np.save(save_dir + filename, data)
    # Message reads "Data saved to file: <filename>".
    print(f"数据已保存为文件：{filename}")

# Persist the MC fits and their voxel-wise std maps, one file per pair below.
for mc_array, mc_name in (
    (cm1, "cm1"),
    (cm2, "cm2"),
    (lw1, "lw1"),
    (lw2, "lw2"),
    (cm1_std, "cm1_std"),
    (cm2_std, "cm2_std"),
    (lw1_std, "lw1_std"),
    (lw2_std, "lw2_std"),
):
    save_array(mc_array, mc_name)
del mc_array, mc_name  # keep the notebook namespace clean


In [None]:

# Loading Data 64 voxel
#
# Result files are named  SAVDE_DATA_DIR + <quantity><tag> + <idf>,  where <idf>
# encodes the run settings and save time, e.g.
#     "_SNR(100)_lambda(0.05)_Wmax(500)_US(0)_2025-05-11_22.npy".
# Each cm1_*_idf variable is shared by the cm1 and cm2 files of its run.


def load_result(name, idf):
    """Load the array saved as ``SAVDE_DATA_DIR + name + idf``."""
    return np.load(SAVDE_DATA_DIR + name + idf)


def load_std_set(tag, idf):
    """Load the (cm1, cm2, lw1, lw2) std maps saved as ``<quantity><tag><idf>``.

    All four file names are built from one template, so each returned array is
    guaranteed to come from its own quantity's file (lw2 from the lw2 file, ...).
    """
    return tuple(load_result(quantity + tag, idf)
                 for quantity in ("cm1", "cm2", "lw1", "lw2"))


# --- SPICE, standard run: lambda 0.05, Wmax 500, fully sampled ---
cm1_standard_idf = "_SNR(100)_lambda(0.05)_Wmax(500)_US(0)_2025-05-11_22.npy"
(cm1_std_standard, cm2_std_standard,
 lw1_std_standard, lw2_std_standard) = load_std_set("_laplac_std", cm1_standard_idf)
cm1_standard = load_result("cm1", cm1_standard_idf)
cm2_standard = load_result("cm2", cm1_standard_idf)
cm1_standard_mean = np.mean(cm1_standard, axis=0)
cm2_standard_mean = np.mean(cm2_standard, axis=0)
cm1_std_standard_cg = load_result("cm1_std", cm1_standard_idf)
cm2_std_standard_cg = load_result("cm2_std", cm1_standard_idf)

# --- FFT reconstruction, fully sampled ---
# NOTE(review): the FFT std maps come from the 2025-05-14_21 run while the FFT fits
# come from the lambda(0) 2025-05-14_19 run -- confirm they are meant to be paired.
(cm1_std_fft, cm2_std_fft,
 lw1_std_fft, lw2_std_fft) = load_std_set(
    "_fft_std", "_SNR(100)_lambda(0.05)_Wmax(500)_US(0)_2025-05-14_21.npy")
cm1_fft_idf = "_SNR(100)_lambda(0)_Wmax(500)_US(0)_2025-05-14_19.npy"
cm1_fft = load_result("cm1_fft", cm1_fft_idf)
cm2_fft = load_result("cm2_fft", cm1_fft_idf)
cm1_fft_mean = np.mean(cm1_fft, axis=0)
cm2_fft_mean = np.mean(cm2_fft, axis=0)

# --- SPICE, no constraint: lambda 0 ---
cm1_Zero_Lambda_idf = "_SNR(100)_lambda(0)_Wmax(500)_US(0)_2025-05-11_22.npy"
(cm1_std_Zero_Lambda, cm2_std_Zero_Lambda,
 lw1_std_Zero_Lambda, lw2_std_Zero_Lambda) = load_std_set("_laplac_std", cm1_Zero_Lambda_idf)
cm1_Zero_Lambda = load_result("cm1", cm1_Zero_Lambda_idf)
cm2_Zero_Lambda = load_result("cm2", cm1_Zero_Lambda_idf)
cm1_Zero_Lambda_mean = np.mean(cm1_Zero_Lambda, axis=0)
cm2_Zero_Lambda_mean = np.mean(cm2_Zero_Lambda, axis=0)
cm1_std_Zero_Lambda_cg = load_result("cm1_std", cm1_Zero_Lambda_idf)
cm2_std_Zero_Lambda_cg = load_result("cm2_std", cm1_Zero_Lambda_idf)

# --- SPICE, lambda 5 ---
cm1_mid_Lambda_idf = "_SNR(100)_lambda(5)_Wmax(500)_US(0)_2025-05-11_22.npy"
(cm1_std_mid_Lambda, cm2_std_mid_Lambda,
 lw1_std_mid_Lambda, lw2_std_mid_Lambda) = load_std_set("_laplac_std", cm1_mid_Lambda_idf)
cm1_mid_Lambda = load_result("cm1", cm1_mid_Lambda_idf)
cm2_mid_Lambda = load_result("cm2", cm1_mid_Lambda_idf)
cm1_mid_Lambda_mean = np.mean(cm1_mid_Lambda, axis=0)
cm2_mid_Lambda_mean = np.mean(cm2_mid_Lambda, axis=0)

# --- SPICE, lambda 50 ---
cm1_high_Lambda_idf = "_SNR(100)_lambda(50)_Wmax(500)_US(0)_2025-05-11_22.npy"
(cm1_std_high_Lambda, cm2_std_high_Lambda,
 lw1_std_high_Lambda, lw2_std_high_Lambda) = load_std_set("_laplac_std", cm1_high_Lambda_idf)
cm1_high_Lambda = load_result("cm1", cm1_high_Lambda_idf)
cm2_high_Lambda = load_result("cm2", cm1_high_Lambda_idf)
cm1_high_Lambda_mean = np.mean(cm1_high_Lambda, axis=0)
cm2_high_Lambda_mean = np.mean(cm2_high_Lambda, axis=0)

# --- SPICE, Wmax 50000 ---
cm1_high_Wmax_idf = "_SNR(100)_lambda(0.05)_Wmax(50000)_US(0)_2025-05-11_23.npy"
(cm1_std_high_Wmax, cm2_std_high_Wmax,
 lw1_std_high_Wmax, lw2_std_high_Wmax) = load_std_set("_laplac_std", cm1_high_Wmax_idf)
cm1_high_Wmax = load_result("cm1", cm1_high_Wmax_idf)
cm2_high_Wmax = load_result("cm2", cm1_high_Wmax_idf)
cm1_high_Wmax_mean = np.mean(cm1_high_Wmax, axis=0)
cm2_high_Wmax_mean = np.mean(cm2_high_Wmax, axis=0)

# --- SPICE, Wmax 5 ---
cm1_low_Wmax_idf = "_SNR(100)_lambda(0.05)_Wmax(5)_US(0)_2025-05-11_23.npy"
(cm1_std_low_Wmax, cm2_std_low_Wmax,
 lw1_std_low_Wmax, lw2_std_low_Wmax) = load_std_set("_laplac_std", cm1_low_Wmax_idf)
cm1_low_Wmax = load_result("cm1", cm1_low_Wmax_idf)
cm2_low_Wmax = load_result("cm2", cm1_low_Wmax_idf)
cm1_low_Wmax_mean = np.mean(cm1_low_Wmax, axis=0)
cm2_low_Wmax_mean = np.mean(cm2_low_Wmax, axis=0)

# --- SPICE undersampling series (lambda 0.05, Wmax 50); std maps use the "_std" tag ---
cm1_US_50_idf = "_SNR(100)_lambda(0.05)_Wmax(50)_US(50)_2025-05-11_23.npy"
(cm1_std_US_50, cm2_std_US_50,
 lw1_std_US_50, lw2_std_US_50) = load_std_set("_std", cm1_US_50_idf)
cm1_US_50 = load_result("cm1", cm1_US_50_idf)
cm2_US_50 = load_result("cm2", cm1_US_50_idf)
cm1_US_50_mean = np.mean(cm1_US_50, axis=0)
cm2_US_50_mean = np.mean(cm2_US_50, axis=0)

cm1_US_25_idf = "_SNR(100)_lambda(0.05)_Wmax(50)_US(25)_2025-05-12_00.npy"
(cm1_std_US_25, cm2_std_US_25,
 lw1_std_US_25, lw2_std_US_25) = load_std_set("_std", cm1_US_25_idf)
cm1_US_25 = load_result("cm1", cm1_US_25_idf)
cm2_US_25 = load_result("cm2", cm1_US_25_idf)
cm1_US_25_mean = np.mean(cm1_US_25, axis=0)
cm2_US_25_mean = np.mean(cm2_US_25, axis=0)

cm1_full_idf = "_SNR(100)_lambda(0.05)_Wmax(50)_US(0)_2025-05-11_23.npy"
(cm1_std_full, cm2_std_full,
 lw1_std_full, lw2_std_full) = load_std_set("_std", cm1_full_idf)
cm1_full = load_result("cm1", cm1_full_idf)
cm2_full = load_result("cm2", cm1_full_idf)
cm1_full_mean = np.mean(cm1_full, axis=0)
cm2_full_mean = np.mean(cm2_full, axis=0)

# --- FFT undersampling series ---
(cm1_std_fft_US_25, cm2_std_fft_US_25,
 lw1_std_fft_US_25, lw2_std_fft_US_25) = load_std_set(
    "_fft_std", "_SNR(100)_lambda(0.05)_Wmax(500)_US(25)_2025-05-14_21.npy")
cm1_fft_US_25_idf = "_SNR(100)_lambda(0)_Wmax(500)_US(25)_2025-05-14_22.npy"
cm1_fft_US_25 = load_result("cm1_fft", cm1_fft_US_25_idf)
cm2_fft_US_25 = load_result("cm2_fft", cm1_fft_US_25_idf)
cm1_fft_US_25_mean = np.mean(cm1_fft_US_25, axis=0)
cm2_fft_US_25_mean = np.mean(cm2_fft_US_25, axis=0)

(cm1_std_fft_US_50, cm2_std_fft_US_50,
 lw1_std_fft_US_50, lw2_std_fft_US_50) = load_std_set(
    "_fft_std", "_SNR(100)_lambda(0.05)_Wmax(500)_US(50)_2025-05-14_21.npy")
# NOTE(review): unlike the full and US 25 % FFT fits (lambda(0) files), the US 50 %
# FFT fits are read from the lambda(0.05) file -- confirm intended.
cm1_fft_US_50_idf = "_SNR(100)_lambda(0.05)_Wmax(500)_US(50)_2025-05-14_21.npy"
cm1_fft_US_50 = load_result("cm1_fft", cm1_fft_US_50_idf)
cm2_fft_US_50 = load_result("cm2_fft", cm1_fft_US_50_idf)
cm1_fft_US_50_mean = np.mean(cm1_fft_US_50, axis=0)
cm2_fft_US_50_mean = np.mean(cm2_fft_US_50, axis=0)




In [None]:
"""  
Undersampling Comparison
"""

# Colour encodes the reconstruction (blue = SPICE, red = FFT); opacity encodes the
# undersampling level (1.0 = full, 0.7 = 25 %, 0.3 = 50 %).
# NOTE(review): the SPICE "Full" curves use the lambda(0.05)/Wmax(500) run
# (cm1_std_standard / cm1_standard_mean) while the SPICE US 25 %/50 % curves come
# from Wmax(50) runs; the fully-sampled Wmax(50) run is loaded as cm1_std_full /
# cm1_full_mean -- confirm which one belongs in this comparison.


def plot_voxel_curves(curves):
    """Draw each ``(values, label, colour, alpha)`` entry as a voxel-indexed line.

    The x axis is derived from each curve's own length starting at voxel 0, so
    x and y always come from the same array and all curves share one index.
    """
    for values, label, colour, alpha in curves:
        sns.lineplot(x=range(len(values)), y=values, label=label,
                     color=colour, linewidth=2, alpha=alpha)


# --- Concentration uncertainty ---
plt.figure(figsize=(9, 5))
plot_voxel_curves([
    (cm1_std_standard, "Glu SPICE (Full)", 'tab:blue', 1.0),
    (cm1_std_fft, "Glu FFT (Full)", 'tab:red', 1.0),
    (cm1_std_US_25, "Glu SPICE (US 25%)", 'tab:blue', 0.7),
    (cm1_std_fft_US_25, "Glu FFT (US 25%)", 'tab:red', 0.7),
    (cm1_std_US_50, "Glu SPICE (US 50%)", 'tab:blue', 0.3),
    (cm1_std_fft_US_50, "Glu FFT (US 50%)", 'tab:red', 0.3),
])
# Anatomical reference: water profile scaled onto the uncertainty axis.
sns.lineplot(x=range(len(water_rou_1D)), y=water_rou_1D * 0.06,
             label='Anatomical Reference', linestyle=':', color='black', linewidth=2)

plt.xlabel("Voxel Index", fontweight='bold')
plt.ylabel("Uncertainty", fontweight='bold')
plt.title("Concentration Uncertainty", fontweight='bold')
plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), prop={'weight': 'bold'})
plt.tight_layout()
plt.show()

# --- Concentration mean ---
# Every curve, including the fully-sampled FFT one, uses the same 0-based voxel
# index as the ground truth.
plt.figure(figsize=(9, 5))
plot_voxel_curves([
    (cm1_fft_mean, "Glu FFT (Full)", 'tab:red', 1.0),
    (cm1_fft_US_25_mean, "Glu FFT (US 25%)", 'tab:red', 0.7),
    (cm1_fft_US_50_mean, "Glu FFT (US 50%)", 'tab:red', 0.3),
    (cm1_standard_mean, "Glu SPICE (Full)", 'tab:blue', 1.0),
    (cm1_US_25_mean, "Glu SPICE (US 25%)", 'tab:blue', 0.7),
    (cm1_US_50_mean, "Glu SPICE (US 50%)", 'tab:blue', 0.3),
])
sns.lineplot(x=range(len(C0_GT)), y=C0_GT, label='Ground Truth',
             color='black', linewidth=2, alpha=1, linestyle=':')

plt.xlabel("Voxel Index", fontweight="bold")
plt.ylabel("Concentration", fontweight="bold")
plt.title("Concentration mean", fontweight="bold")
plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), prop={'weight': 'bold'})
plt.tight_layout()
plt.show()



In [None]:
"""  
Constraint Weighting λ Comparison
"""

# Set seaborn style (global: also applies to every figure drawn after this cell)
sns.set(style="whitegrid")

# Per-setting line style shared by the three λ figures (colour is always tab:blue).
# NOTE(review): "High λ" is the lambda(5) run (cm1_*_mid_Lambda); the lambda(50) run
# is loaded as cm1_*_high_Lambda but not plotted here -- confirm intended.
lambda_line_styles = {
    "zero": dict(label="No Constraint(λ = 0)", linestyle='--', alpha=0.3),
    "moderate": dict(label="Moderate λ", linestyle='-', alpha=0.7),
    "high": dict(label="High λ", linestyle='-', alpha=1),
}


def finish_lambda_figure(ylabel, title):
    """Bold axis labels and title, legend outside on the right, then render."""
    plt.xlabel("Voxel Index", fontweight='bold')
    plt.ylabel(ylabel, fontweight='bold')
    plt.title(title, fontweight='bold')
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), prop={'weight': 'bold'})
    plt.tight_layout()
    plt.show()


# === Plot Cm uncertainty ===
plt.figure(figsize=(9, 5))
for std_map, setting in ((cm1_std_Zero_Lambda, "zero"),
                         (cm1_std_standard, "moderate"),
                         (cm1_std_mid_Lambda, "high")):
    sns.lineplot(x=range(len(std_map)), y=std_map,
                 color='tab:blue', linewidth=2, **lambda_line_styles[setting])
# Anatomical reference: water profile scaled onto the uncertainty axis
sns.lineplot(x=range(len(water_rou_1D)), y=water_rou_1D * 0.04,
             label='Anatomical Reference', linestyle=':', color='gray', linewidth=2)
finish_lambda_figure("Uncertainty", "Concentration Uncertainty (λ)")

# === Plot Cm mean (full view, then zoomed-in view) ===
lambda_mean_maps = ((cm1_Zero_Lambda_mean, "zero"),
                    (cm1_standard_mean, "moderate"),
                    (cm1_mid_Lambda_mean, "high"))

plt.figure(figsize=(9, 5))
sns.lineplot(x=range(len(C0_GT)), y=C0_GT,
             label='Ground Truth', color='black', linewidth=2, alpha=1, linestyle=':')
for mean_map, setting in lambda_mean_maps:
    # In the full view the unconstrained curve is drawn with linewidth 3.
    sns.lineplot(x=range(len(mean_map)), y=mean_map, color='tab:blue',
                 linewidth=3 if setting == "zero" else 2,
                 **lambda_line_styles[setting])
finish_lambda_figure("concentration", "Concentration Mean (λ)")

# Zoomed-in view: same curves, no legend labels, uniform linewidth
plt.figure(figsize=(6, 5))
sns.lineplot(x=range(len(C0_GT)), y=C0_GT,
             color='black', linewidth=2, alpha=1, linestyle=':')
for mean_map, setting in lambda_mean_maps:
    unlabelled_style = {key: val for key, val in lambda_line_styles[setting].items()
                        if key != 'label'}
    sns.lineplot(x=range(len(mean_map)), y=mean_map,
                 color='tab:blue', linewidth=2, **unlabelled_style)
plt.xlim(30, 40)  # Zoom into voxel index range
plt.ylim(0.15, 0.6)
plt.tight_layout()
plt.show()


In [None]:
"""  
Constraint matrix Dw Comparison (Tuning Wmax)
""" 

def finish_dw_figure(ylabel, title):
    """Bold axis labels and title, legend outside on the right, then render."""
    plt.xlabel("Voxel Index", fontweight='bold')
    plt.ylabel(ylabel, fontweight='bold')
    plt.title(title, fontweight='bold')
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), prop={'weight': 'bold'})
    plt.tight_layout()
    plt.show()


def draw_dw_means(figsize, with_labels):
    """Open a figure and draw ground truth plus High/Low Dw mean concentrations.

    Legend labels are attached only when ``with_labels`` is true.
    """
    plt.figure(figsize=figsize)
    mean_curves = (
        (C0_GT, 'Ground Truth', dict(color='black', linewidth=2, alpha=1, linestyle='--')),
        (cm1_high_Wmax_mean, 'High Dw', dict(color='tab:blue', linewidth=2, alpha=1.0)),
        (cm1_low_Wmax_mean, 'Low Dw', dict(color='tab:blue', linewidth=2, alpha=0.4)),
    )
    for curve, curve_label, curve_style in mean_curves:
        if with_labels:
            curve_style = dict(curve_style, label=curve_label)
        sns.lineplot(x=range(len(curve)), y=curve, **curve_style)


# === Plot Cm uncertainty (Dw): High Dw = Wmax(50000) run, Low Dw = Wmax(5) run ===
plt.figure(figsize=(9, 5))
for wmax_std, dw_label, dw_alpha in ((cm1_std_high_Wmax, "High Dw", 1),
                                     (cm1_std_low_Wmax, "Low Dw", 0.3)):
    sns.lineplot(x=range(len(wmax_std)), y=wmax_std, label=dw_label,
                 linestyle='-', color='tab:blue', linewidth=2, alpha=dw_alpha)
# Anatomical reference: water profile scaled onto the uncertainty axis
sns.lineplot(x=range(len(water_rou_1D)), y=water_rou_1D * 0.04,
             label='Anatomical Reference', linestyle=':', color='gray', linewidth=2)
finish_dw_figure("Uncertainty", "Concentration Uncertainty (Dw)")

# === Plot Cm mean (Dw), full view ===
draw_dw_means((9, 5), with_labels=True)
finish_dw_figure("concentration", "Concentration Mean (Dw)")

# === Zoomed-in view of the mean ===
draw_dw_means((6, 5), with_labels=False)
plt.xlim(35, 42)
plt.ylim(0.4, 0.8)
plt.tight_layout()
plt.show()


In [None]:
"""  
FFT and SPICE Comparison
""" 

# Curves drawn in order: FFT std maps, SPICE without constraint (lambda 0, dashed),
# SPICE with the standard constraint, then the scaled anatomical reference.
fit_uncertainty_curves = (
    (cm1_std_fft, dict(label="Glu (FFT)", color='tab:red', linewidth=2, alpha=1)),
    (cm2_std_fft, dict(label="Cho (FFT)", color='tab:orange', linewidth=2, alpha=0.5)),
    (cm1_std_Zero_Lambda_cg, dict(label="Glu (SPICE)", color='tab:blue',
                                  linewidth=3, alpha=0.5, linestyle='--')),
    (cm2_std_Zero_Lambda_cg, dict(label="Cho (SPICE)", color='tab:cyan',
                                  linewidth=3, alpha=0.5, linestyle='--')),
    (cm1_std_standard_cg, dict(label="Glu (SPICE + constraint)", color='tab:blue',
                               linewidth=2, alpha=1)),
    (cm2_std_standard_cg, dict(label="Cho (SPICE + constraint)", color='tab:cyan',
                               linewidth=2, alpha=1)),
    (water_rou_1D * 0.03, dict(label='Anatomical Reference', linestyle=':',
                               color='gray', linewidth=2)),
)

plt.figure(figsize=(9, 5))
for std_map, line_style in fit_uncertainty_curves:
    sns.lineplot(x=range(len(std_map)), y=std_map, **line_style)

plt.xlabel("Voxel Index", fontweight='bold')
plt.ylabel("Uncertainty", fontweight='bold')
plt.title("Concentration Uncertainty after Fitting", fontweight='bold')
plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), prop={'weight': 'bold'})
plt.tight_layout()
plt.show()

