In [11]:
import os
import pandas as pd
import numpy as np
import nibabel as nib
from neuroHarmonize import harmonizationLearn
import pickle

In [14]:
# Directory that is scanned for reference-site NIfTI volumes (*.nii.gz) when
# the harmonization model is fitted in the data-preparation cell.
# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data root so the notebook runs on other machines.
reference_dir ='/home/dual4090/lab/github/synth7T-MICCAI/data/original/3T/UCSF/'

In [19]:
# NOTE(review): `masked_data` is not assigned by any executable statement in
# the visible cells, so this inspection cell fails on a fresh kernel
# (Restart & Run All). The recorded output (21023600,) equals 260 * 311 * 260,
# i.e. one flattened volume at the target shape used by the inference stage.
masked_data[0].shape

(21023600,)

In [20]:
# 1. Data Preparation
# Sort the listing: os.listdir() returns entries in arbitrary order, and both
# the row order of `image_data` and the choice of the mask-defining "first"
# image depend on it. Sorting makes the fitted model reproducible.
filepaths = sorted(
    os.path.join(reference_dir, f)
    for f in os.listdir(reference_dir)
    if f.endswith('.nii.gz')
)
if not filepaths:
    raise ValueError("No NIFTI files found in the reference directory.")

# 2. Covariates
# neuroHarmonize reads the site label from a column that must be named 'SITE';
# every reference scan belongs to the same (single) site here.
covars = pd.DataFrame({
    'SITE': ['reference_site'] * len(filepaths),
})

# 3. Feature Extraction (Brain Masking)
# The mask is the set of non-zero voxels of the FIRST image and is reused for
# every other image, so all rows have the same length and voxel ordering.
first_img_data = nib.load(filepaths[0]).get_fdata()
brain_voxel_indices = first_img_data != 0

# Extract the masked voxels of every image as one row each.
image_data_list = []
for f in filepaths:
    img_data = nib.load(f).get_fdata()
    if img_data.shape != brain_voxel_indices.shape:
        # Fail with a message that names the offending file instead of the
        # generic boolean-index IndexError numpy would raise.
        raise ValueError(
            f"Image {f} has shape {img_data.shape}, expected "
            f"{brain_voxel_indices.shape} (shape of {filepaths[0]})."
        )
    image_data_list.append(img_data[brain_voxel_indices])

# Shape: (number of scans, number of mask voxels).
image_data = np.vstack(image_data_list)

# 4. Learn the Harmonization Model
print("Learning harmonization model from reference data...")
# The site column is taken from covars['SITE']; harmonizationLearn has no
# batch-column argument. Its third positional parameter is `eb`, so the former
# call `harmonizationLearn(image_data, covars, 'SITE')` silently passed the
# string as the empirical-Bayes flag (truthy, i.e. the default behaviour).
# The harmonized training data (second return value) is not needed here.
model, _ = harmonizationLearn(image_data, covars)

Learning harmonization model from reference data...


  s_data = ((X- stand_mean - mod_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, n_sample))))


In [21]:
# Persist the fitted harmonization model so the inference stage can reload it.
model_path = '/home/dual4090/lab/github/Synthetic_7T_MRI_release/src/dataloader/harmonize.pkl'
with open(model_path, 'wb') as model_file:
    pickle.dump(model, model_file)
print(f"Harmonization model successfully saved to: {model_path}")

Harmonization model successfully saved to: /home/dual4090/lab/github/Synthetic_7T_MRI_release/src/dataloader/harmonize.pkl


In [22]:
# Write the boolean brain mask into the same directory as the model file so
# that both artifacts can be loaded together at inference time.
indices_path = os.path.join(os.path.split(model_path)[0], 'brain_voxel_indices.pkl')
with open(indices_path, 'wb') as indices_file:
    pickle.dump(brain_voxel_indices, indices_file)

In [31]:
# stage2_apply_harmonization_at_inference.py
import os
import pandas as pd
import numpy as np
import pickle
import nibabel as nib
from scipy.ndimage import zoom
from neuroHarmonize import harmonizationApply

def resize_3d_matrix(input_matrix, target_shape, order=3):
    """Resample an array to ``target_shape`` using spline interpolation.

    Args:
        input_matrix: Array to resample; its number of dimensions must equal
            ``len(target_shape)``.
        target_shape: Desired output shape, one entry per array axis.
        order: Spline order forwarded to ``scipy.ndimage.zoom``. Defaults to
            3 (cubic), the value that was previously hard-coded. Use 0 for
            nearest-neighbour resampling of label or mask volumes.

    Returns:
        The resampled array produced by ``scipy.ndimage.zoom``.

    Raises:
        ValueError: If the array rank and ``len(target_shape)`` differ.
    """
    input_matrix = np.asarray(input_matrix)
    target_shape = tuple(int(n) for n in target_shape)
    if input_matrix.ndim != len(target_shape):
        raise ValueError(
            f"Cannot resize array of shape {input_matrix.shape} to "
            f"{target_shape}: number of dimensions differs."
        )
    # Per-axis scale factor; zoom() rounds input_shape * factor to obtain the
    # output shape, which reproduces target_shape.
    zoom_factors = np.array(target_shape) / np.array(input_matrix.shape)
    return zoom(input_matrix, zoom_factors, order=order)

def adjust_affine(original_affine, original_shape, new_shape):
    """
    Return a copy of an affine rescaled for a resized voxel grid.

    Each of the first three columns of the upper-left 3x3 block keeps its
    direction (so axis orientation and flips are preserved) while its length,
    i.e. the voxel spacing along that axis, is multiplied by
    ``original_shape / new_shape``. The translation column is copied unchanged.

    Args:
        original_affine: 4x4 voxel-to-world matrix of the original image.
        original_shape: Shape of the original image; only the first three
            entries are used, so 4-D shapes such as (x, y, z, 1) are accepted.
        new_shape: Shape of the resized image; only the first three entries
            are used.

    Returns:
        A new float 4x4 array; ``original_affine`` is not modified.
    """
    # Work in float: copying an integer-typed affine would keep its dtype and
    # silently truncate fractional spacings when they are written back.
    affine = np.array(original_affine, dtype=float)
    old_dims = np.asarray(original_shape, dtype=float)[:3]
    new_dims = np.asarray(new_shape, dtype=float)[:3]

    # Column norms of the 3x3 block are the voxel sizes along each array axis.
    original_spacing = np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))
    new_spacing = original_spacing * (old_dims / new_dims)

    new_affine = affine.copy()
    # Normalise each column to unit length, then scale it to the new spacing.
    new_affine[:3, :3] = affine[:3, :3] / original_spacing * new_spacing
    # The previous implementation's fill_diagonal() call was overwritten by
    # the assignment above except for this element, which it forced to 1.
    new_affine[3, 3] = 1.0
    # NOTE(review): the translation column is left as-is and the scale uses
    # N/M; confirm this matches the grid convention of the resampler in use.
    return new_affine

def harmonize_scans_at_inference(new_scans_dir, model_path, indices_path, output_dir,
                                 target_shape=(260, 311, 260)):
    """Harmonize every NIfTI scan in a directory with a previously fitted model.

    Each scan is resized to ``target_shape``, reduced to the voxels selected by
    the stored brain mask, passed through ``harmonizationApply`` and written
    back as a ``target_shape`` volume (zeros outside the mask) named
    ``harmonized_<original file name>`` in ``output_dir``.

    Args:
        new_scans_dir: Directory searched (non-recursively) for ``.nii`` and
            ``.nii.gz`` files.
        model_path: Pickle file holding the model returned by
            ``harmonizationLearn``.
        indices_path: Pickle file holding the boolean brain mask saved next to
            the model; its shape must equal ``target_shape``.
        output_dir: Directory for the harmonized volumes; created if missing.
        target_shape: Grid every scan is resized to before masking. Defaults
            to the previously hard-coded (260, 311, 260).

    Raises:
        ValueError: If the mask shape differs from ``target_shape``, if no
            NIfTI file is found, or if none of the files could be processed.

    Files that fail to load or resize are reported and skipped; the remaining
    files are still harmonized.
    """
    print("--- Inference Time Harmonization ---")
    os.makedirs(output_dir, exist_ok=True)
    target_shape = tuple(target_shape)

    # 1. Load the pre-trained model and voxel mask from Stage 1.
    # NOTE: pickle.load can execute arbitrary code; only load files that were
    # produced by the training stage of this project.
    print(f"Loading model from {model_path}...")
    with open(model_path, 'rb') as f:
        model = pickle.load(f)

    print(f"Loading brain voxel indices from {indices_path}...")
    with open(indices_path, 'rb') as f:
        brain_voxel_indices = pickle.load(f)

    # The mask is applied to resized volumes, so it must already have the
    # target shape.
    if brain_voxel_indices.shape != target_shape:
        raise ValueError(f"The brain_voxel_indices mask shape {brain_voxel_indices.shape} "
                         f"does not match the target shape {target_shape}. "
                         "The model must be trained on resized data.")

    # 2. Collect the new scans (sorted, because os.listdir order is arbitrary).
    filepaths = sorted(
        os.path.join(new_scans_dir, f)
        for f in os.listdir(new_scans_dir)
        if f.endswith(('.nii.gz', '.nii'))
    )
    if not filepaths:
        raise ValueError("No NIFTI files found in the new scans directory.")

    print(f"Found {len(filepaths)} new scan(s) to harmonize.")

    # 3. Resize first, then extract the masked voxels.
    print("Resizing scans and extracting brain voxels...")
    image_data_list = []
    # One (filepath, affine, shape) entry per row of image_data_list. Keeping
    # this in lock-step with the rows is what ties each harmonized row to the
    # right output file when some inputs are skipped. Only the small metadata
    # is kept, not the image object, so voxel data cached by get_fdata() can
    # be released after each iteration.
    processed = []

    for filepath in filepaths:
        try:
            original_nifti = nib.load(filepath)
            resized_img_data = resize_3d_matrix(original_nifti.get_fdata(), target_shape)
            masked_voxels = resized_img_data[brain_voxel_indices]
            scan_affine = original_nifti.affine
            scan_shape = original_nifti.shape
        except Exception as e:
            # Best effort: report the failure and continue with the other scans.
            print(f"Error processing file {filepath}: {e}")
            continue
        image_data_list.append(masked_voxels)
        processed.append((filepath, scan_affine, scan_shape))

    if not image_data_list:
        raise ValueError("None of the NIFTI files in the new scans directory could be processed.")

    image_data = np.vstack(image_data_list)
    # NOTE(review): the scans are labelled 'new_site' while the training cell
    # fitted the model with 'reference_site'; confirm harmonizationApply
    # treats a site label absent from the model as intended.
    covars = pd.DataFrame({'SITE': ['new_site'] * len(image_data)})

    # 4. Apply harmonization (the "inference" step). The site column is read
    # from covars['SITE']; no batch-column argument is passed.
    print("Applying harmonization model...")
    harmonized_data = harmonizationApply(data=image_data, covars=covars, model=model)

    # 5. Reconstruct and save harmonized images at the target shape.
    print("Reconstructing and saving harmonized images...")
    for harmonized_row, (filepath, scan_affine, scan_shape) in zip(harmonized_data, processed):
        # Voxels outside the mask stay zero.
        reconstructed_img_data = np.zeros(target_shape, dtype=np.float32)
        reconstructed_img_data[brain_voxel_indices] = harmonized_row

        # The output lives on the resized grid, so the voxel spacing stored in
        # the affine has to be rescaled accordingly.
        new_affine = adjust_affine(scan_affine, scan_shape, target_shape)

        harmonized_nifti = nib.Nifti1Image(reconstructed_img_data, affine=new_affine)

        base_filename = os.path.basename(filepath)
        output_filename = os.path.join(output_dir, f"harmonized_{base_filename}")
        nib.save(harmonized_nifti, output_filename)


if __name__ == '__main__':
    # --- Per-run configuration: edit these paths before every inference run ---

    # Folder holding the new scans that have not been harmonized yet
    new_scans_to_harmonize_dir = '/mnt/hdd0/download/sample_GRIP_data/Input/BrainStripped/'

    # Stage 1 artifacts: the pickled model and the pickled brain-voxel mask
    model_file_path = r'/home/dual4090/lab/github/Synthetic_7T_MRI_release/src/dataloader/harmonize.pkl'
    indices_file_path = r'/home/dual4090/lab/github/Synthetic_7T_MRI_release/src/dataloader/brain_voxel_indices.pkl'

    # Destination folder for the harmonized scans
    output_directory = '/mnt/hdd0/download/sample_GRIP_data/Input/harmonized/'

    # --------------------------------------------------------------------------

    harmonize_scans_at_inference(
        new_scans_dir=new_scans_to_harmonize_dir,
        model_path=model_file_path,
        indices_path=indices_file_path,
        output_dir=output_directory,
    )

--- Inference Time Harmonization ---
Loading model from /home/dual4090/lab/github/Synthetic_7T_MRI_release/src/dataloader/harmonize.pkl...
Loading brain voxel indices from /home/dual4090/lab/github/Synthetic_7T_MRI_release/src/dataloader/brain_voxel_indices.pkl...
Found 6 new scan(s) to harmonize.
Resizing scans and extracting brain voxels...
Applying harmonization model...


  s_data = ((X- stand_mean - mod_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, n_sample))))


Reconstructing and saving harmonized images...
