In [None]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import torch

# Compute device: a CUDA GPU when torch can see one, otherwise the CPU.
# NOTE(review): `device` is not referenced by the cells shown here — confirm it is still needed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input locations on the CIFS share:
#   pred_dir    - denormalized NIfTI predictions (ablation variant `6_noStructnoCmu`, per the path)
#   gt_base_dir - one sub-folder per subject ID, each holding that subject's ground-truth map
pred_dir = '/home/maia-user/cifs/Datasets/PD_Private/chrol/ParkMRE/ablation_nopreweight_masked3.0/6_noStructnoCmu/niftimaps_denormalized'
gt_base_dir = '/home/maia-user/cifs/Datasets/PD_Private/chrol/ParkMRE/MRE_T1toMNI_202402'

# Helper: load a NIfTI image and keep only its finite (non-NaN, non-inf) voxel values
def load_valid_voxels(nifti_path):
    """Return the finite voxel values of a NIfTI image as a flat array.

    Parameters
    ----------
    nifti_path : str
        Path to an image readable by ``nib.load``.

    Returns
    -------
    numpy.ndarray
        1-D float array (as produced by ``get_fdata``) holding every voxel
        that is neither NaN nor +/-inf, in C (row-major) order.
    """
    volume = nib.load(nifti_path).get_fdata()
    finite_mask = np.isfinite(volume)
    # Boolean-mask indexing already yields a 1-D copy; ravel() keeps it 1-D.
    return volume[finite_mask].ravel()

# Collect the finite voxel values of every prediction and of its matching
# ground-truth (GT) stiffness map, grouped by cohort. The subject ID is the part
# of the prediction filename before '_prediction'; the cohort is read from it.
control_preds = []
control_gts = []
patient_preds = []
patient_gts = []

# os.listdir returns entries in arbitrary order; sorting makes the per-group
# concatenation order reproducible across machines and runs.
for fname in sorted(os.listdir(pred_dir)):
    if not fname.endswith('_prediction_denormalized.nii.gz'):
        continue

    pred_path = os.path.join(pred_dir, fname)
    subj_id = fname.split('_prediction')[0]

    # "Control" is checked first, so an ID containing both tokens counts as Control.
    if "Control" in subj_id:
        group = "Control"
    elif "Patient" in subj_id:
        group = "Patient"
    else:
        print(f"Skipping unknown group: {fname}")
        continue

    # GT lives under a per-subject folder with the same ID as the prediction.
    gt_path = os.path.join(gt_base_dir, subj_id, 'MRE_stiffness_ToT1_202402_t1_to_MNI.nii.gz')
    if not os.path.exists(gt_path):
        print(f"GT not found for: {subj_id}")
        continue

    # Prediction and GT are filtered independently, so their voxel counts may
    # differ; they are only compared as distributions, not voxel-by-voxel.
    pred_voxels = load_valid_voxels(pred_path)
    gt_voxels = load_valid_voxels(gt_path)

    if group == "Control":
        control_preds.append(pred_voxels)
        control_gts.append(gt_voxels)
    else:
        patient_preds.append(pred_voxels)
        patient_gts.append(gt_voxels)


def _concat_or_empty(arrays, label):
    """Concatenate a list of 1-D arrays, or return an empty array if the list is empty.

    np.concatenate([]) raises an uninformative ValueError; a missing cohort
    should instead produce a visible warning and an empty (plottable) array.
    """
    if not arrays:
        print(f"Warning: no {label} subjects found; using an empty array.")
        return np.empty(0)
    return np.concatenate(arrays)


# Concatenate all distributions into a single array per group
print(f"Loaded {len(control_preds)} Control and {len(patient_preds)} Patient subjects")
control_preds_all = _concat_or_empty(control_preds, "Control")
control_gts_all = _concat_or_empty(control_gts, "Control")
patient_preds_all = _concat_or_empty(patient_preds, "Patient")
patient_gts_all = _concat_or_empty(patient_gts, "Patient")

# Plot: overlay the voxel-value distributions of GT and predictions for both cohorts.
# All four histograms share the same bin edges so bar heights are directly
# comparable; plt.hist(..., bins=300) alone would derive separate edges from each
# array's own min/max. The edges span the displayed x-range, so values outside it
# (which were never visible) are simply not binned.
X_RANGE_PA = (0, 4000)
Y_MAX_VOXELS = 1e5
N_BINS = 300
bin_edges = np.linspace(X_RANGE_PA[0], X_RANGE_PA[1], N_BINS + 1)

# (values, label, colour) in drawing order; later series are drawn on top.
# Alternative palette previously tried: '#0072B2' (blue), '#009E73' (green),
# '#E69F00' (orange), '#FF7F50' (coral).
hist_series = [
    (patient_gts_all, 'Patient GT (MRE)', 'blue'),
    (control_gts_all, 'Control GT (MRE)', 'red'),
    (control_preds_all, 'Control Prediction', 'orange'),
    (patient_preds_all, 'Patient Prediction', 'green'),
]

fig, ax = plt.subplots(figsize=(14, 7))
for values, label, color in hist_series:
    ax.hist(values, bins=bin_edges, alpha=0.5, label=label, color=color, edgecolor='black')
ax.set_title("Combined Histogram of Predictions vs Ground Truth (MRE)")
ax.set_xlabel("Intensity Value (Pa)")
ax.set_ylabel("Number of Voxels")
ax.legend()
ax.set_ylim(0, Y_MAX_VOXELS)
ax.set_xlim(*X_RANGE_PA)
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Frequency analysis: share of spectral energy above a radial-frequency cutoff, per image and on average
import os
import numpy as np
import nibabel as nib

folder = "/home/maia-user/cifs/Datasets/PD_Private/chrol/ParkMRE/MRE_T1toMNI_202402/MRE_test/"

# Spectrum entries whose radial frequency exceeds this fraction of the largest
# radial frequency in the volume are counted as "high frequency".
HIGH_FREQ_CUTOFF_FRACTION = 0.05


def _spectral_energies(data, vox_sizes, cutoff_fraction=HIGH_FREQ_CUTOFF_FRACTION):
    """Split the spectral energy of a 3-D volume into total and high-frequency parts.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array of image values. Non-finite values (NaN/inf) are replaced by 0
        before the FFT, because a single NaN would otherwise turn every Fourier
        coefficient into NaN.
    vox_sizes : sequence of 3 floats
        Voxel spacing along each axis; frequencies are in cycles per unit of
        this spacing (``numpy.fft.fftfreq`` convention).
    cutoff_fraction : float
        Entries with radial frequency strictly greater than
        ``cutoff_fraction * max radial frequency`` count as high frequency.

    Returns
    -------
    (float, float)
        ``(total_energy, high_energy)``. Power is ``|FFT|**2 / N`` with ``N`` the
        number of voxels, so by Parseval's theorem ``total_energy`` equals
        ``sum(data**2)`` (after the non-finite replacement).

    Raises
    ------
    ValueError
        If ``data`` is not 3-D.
    """
    data = np.nan_to_num(np.asarray(data, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)
    if data.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got shape {data.shape}")
    nx, ny, nz = data.shape

    # Unshifted frequencies pair element-wise with the unshifted output of fftn;
    # fftshift is unnecessary because only sums over masks are taken.
    # Broadcasting builds the (nx, ny, nz) radius grid without three full meshgrids.
    fx = np.fft.fftfreq(nx, d=vox_sizes[0])[:, None, None]
    fy = np.fft.fftfreq(ny, d=vox_sizes[1])[None, :, None]
    fz = np.fft.fftfreq(nz, d=vox_sizes[2])[None, None, :]
    freq_radius = np.sqrt(fx**2 + fy**2 + fz**2)
    high_mask = freq_radius > cutoff_fraction * freq_radius.max()

    power = np.abs(np.fft.fftn(data)) ** 2 / data.size
    return power.sum(), power[high_mask].sum()


# All .nii.gz files in the folder, in a reproducible (sorted) order.
files = sorted(f for f in os.listdir(folder) if f.endswith(".nii.gz"))

energies_high = []
energies_total = []

for filename in files:
    img = nib.load(os.path.join(folder, filename))
    total_energy, high_energy = _spectral_energies(img.get_fdata(), img.header.get_zooms()[:3])

    energies_total.append(total_energy)
    energies_high.append(high_energy)

    print(f"{filename}: Total energy = {total_energy:.3e}, high frequency energy = {high_energy:.3e}")

# Average high frequency and total energy
if files:
    mean_high_energy = np.mean(energies_high)
    mean_total_energy = np.mean(energies_total)
else:
    # np.mean([]) would warn and return NaN with no explanation; say why instead.
    print(f"No .nii.gz files found in {folder}")
    mean_high_energy = mean_total_energy = float("nan")

# Ratio of the means, i.e. an energy-weighted fraction (not the mean of per-image fractions).
mean_fraction = mean_high_energy / mean_total_energy if mean_total_energy != 0 else 0

print("\n--- Average over all images ---")
print(f"Mean total energy: {mean_total_energy:.3e}")
print(f"Mean high frequency energy: {mean_high_energy:.3e}")
print(f"High frequency energy : {mean_high_energy:.3e} ({mean_fraction*100:.2f} % of total)")