In [None]:
import os
import traceback

import numpy as np
import nibabel as nib
import SimpleITK as sitk
from scipy.ndimage import zoom

def normalize_intensity(img, mask=None):
    """Z-score normalise an intensity array.

    The mean and standard deviation are computed over the voxels where
    ``mask > 0`` (or over every voxel when ``mask`` is None) and are then
    applied to the WHOLE array, so the result always has ``img``'s shape.

    Args:
        img: array of intensities.
        mask: optional array with the same shape as ``img``; voxels with a
            value > 0 define the normalisation statistics.

    Returns:
        ``(img - mean) / std``. All zeros (same shape as ``img``) when the
        selected voxels have zero spread or when no voxel is selected.
    """
    img = np.asarray(img)
    # Previously ``img`` itself was replaced by the masked voxels, which
    # returned a flattened subset instead of a normalised image.
    region = img if mask is None else img[np.asarray(mask) > 0]
    if region.size == 0:
        # No voxels to derive statistics from; avoid NaN mean/std.
        return img * 0
    mean = np.mean(region)
    std = np.std(region)
    # A constant region has no spread: return zeros rather than divide by 0.
    return (img - mean) / std if std != 0 else img * 0

def resample_to_isotropic(img_path, target_spacing=(1.0, 1.0, 1.0)):
    """Read an image from disk and resample it to ``target_spacing``.

    Each axis gets ``round(size * spacing / target_spacing)`` voxels, so the
    physical extent is approximately preserved. Origin and direction are
    copied from the input image and linear interpolation is used; the output
    pixel type is left at the filter's default.

    Args:
        img_path: path of the image to read with ``sitk.ReadImage``.
        target_spacing: output voxel spacing, one value per axis.

    Returns:
        The resampled SimpleITK image.
    """
    source = sitk.ReadImage(img_path)

    output_size = []
    for n_vox, spacing_in, spacing_out in zip(
        source.GetSize(), source.GetSpacing(), target_spacing
    ):
        output_size.append(int(round(n_vox * spacing_in / spacing_out)))

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetOutputOrigin(source.GetOrigin())
    resampler.SetOutputDirection(source.GetDirection())
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(output_size)
    return resampler.Execute(source)

def n4_bias_correction(img_sitk):
    """Apply N4 bias-field correction to a SimpleITK image.

    A foreground mask is built with Otsu thresholding (inside value 0,
    outside value 1, 200 histogram bins) and passed to the N4 filter.

    The N4 filter only supports real-valued pixel types, so non-float inputs
    (e.g. int16 NIfTI volumes) are cast to 32-bit float first. Float32 and
    float64 inputs are processed unchanged.

    Args:
        img_sitk: SimpleITK image to correct.

    Returns:
        The bias-corrected SimpleITK image.
    """
    mask_image = sitk.OtsuThreshold(img_sitk, 0, 1, 200)
    if img_sitk.GetPixelID() not in (sitk.sitkFloat32, sitk.sitkFloat64):
        # Without this cast Execute raises for integer pixel types.
        img_sitk = sitk.Cast(img_sitk, sitk.sitkFloat32)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    return corrector.Execute(img_sitk, mask_image)

def simple_skull_strip(img_np):
    """Crude intensity-based "skull strip": zero out the darkest voxels.

    Voxels at or below the 10th intensity percentile of the whole array are
    set to 0; all other voxels keep their value. This is only a percentile
    threshold and does not model anatomy.
    """
    cutoff = np.percentile(img_np, 10)
    keep = img_np > cutoff
    return np.where(keep, img_np, 0)

def sitk_to_numpy(img_sitk):
    """Return the voxel array of a 3-D SimpleITK image with its axes reversed.

    ``sitk.GetArrayFromImage`` returns axes in (z, y, x) order; reversing them
    makes numpy indexing follow SimpleITK's (x, y, z) index order.
    """
    zyx = sitk.GetArrayFromImage(img_sitk)
    return np.transpose(zyx, (2, 1, 0))

def numpy_to_sitk(img_np, reference_sitk):
    """Wrap an (x, y, z)-ordered array as a SimpleITK image with reference geometry.

    This is the inverse of ``sitk_to_numpy``. The axes are reversed back to
    (z, y, x) for ``sitk.GetImageFromArray``. Spacing, origin and direction
    are then copied from ``reference_sitk`` via ``CopyInformation``, so the
    array's shape must match the reference image's size.
    """
    zyx = img_np.transpose((2, 1, 0))
    image = sitk.GetImageFromArray(zyx)
    image.CopyInformation(reference_sitk)
    return image

# ======================================================
# 🔄 Rigid registration functions
# ======================================================
def rigid_register_images(fixed_path, moving_path):
    """Estimate a rigid (Euler 3-D) transform aligning the moving image to the fixed one.

    Both images are read from disk as 32-bit float. The transform is first
    initialised by aligning the geometric centres of the two images. It is
    then optimised with gradient descent on Mattes mutual information.

    Args:
        fixed_path: path of the fixed (target) image.
        moving_path: path of the moving image.

    Returns:
        The transform returned by ``ImageRegistrationMethod.Execute``.
    """
    fixed_img = sitk.ReadImage(fixed_path, sitk.sitkFloat32)
    moving_img = sitk.ReadImage(moving_path, sitk.sitkFloat32)

    reg = sitk.ImageRegistrationMethod()

    # Metric: Mattes MI with 32 histogram bins, evaluated on a random 1 % of voxels.
    # No seed is passed, so the sampled voxels (and the result) may vary between runs.
    reg.SetMetricAsMattesMutualInformation(32)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    reg.SetMetricSamplingPercentage(0.01)
    reg.SetInterpolator(sitk.sitkLinear)

    # Optimizer: learning rate 1.0, 100 iterations, convergence 1e-6 over a
    # 10-iteration window; parameter scales derived from physical shifts.
    reg.SetOptimizerAsGradientDescent(1.0, 100, 1e-6, 10)
    reg.SetOptimizerScalesFromPhysicalShift()

    # Starting point: rigid transform that aligns the geometric centres.
    start = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
    reg.SetInitialTransform(start, inPlace=False)

    return reg.Execute(fixed_img, moving_img)

def apply_transform(image_path, transform, reference_path, is_label=False):
    """Resample the image at ``image_path`` onto the grid of ``reference_path``.

    Args:
        image_path: image to warp. It is read as sitkUInt8 when ``is_label``
            is True, otherwise as sitkFloat32.
        transform: SimpleITK transform passed to ``sitk.Resample``.
        reference_path: image whose grid (size, spacing, origin, direction)
            defines the output.
        is_label: use nearest-neighbour interpolation, which keeps label
            values discrete, instead of linear interpolation.

    Returns:
        The resampled image, with the same pixel type as the loaded input.
        Voxels that fall outside the input are filled with 0.
    """
    pixel_type = sitk.sitkUInt8 if is_label else sitk.sitkFloat32
    interp = sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear

    moving = sitk.ReadImage(image_path, pixel_type)
    grid = sitk.ReadImage(reference_path)
    return sitk.Resample(moving, grid, transform, interp, 0.0, moving.GetPixelID())

# ======================================================
# 💾 Saving images as NIfTI
# ======================================================
def save_nifti(image_sitk, output_path):
    """Write ``image_sitk`` to ``output_path`` and print a confirmation line.

    SimpleITK picks the output format from the path's extension
    (e.g. ``.nii.gz`` for compressed NIfTI).
    """
    sitk.WriteImage(image_sitk, output_path)
    message = f"✅ Salvat: {output_path}"
    print(message)

# ======================================================
# 🚀 Full pipeline: register FLAIR + mask
# ======================================================
def _strip_and_normalize(img_sitk):
    """Skull-strip and z-score normalise ``img_sitk``, keeping its geometry."""
    voxels = sitk_to_numpy(img_sitk)
    voxels = simple_skull_strip(voxels)
    # NOTE(review): statistics are computed over the whole volume, including
    # the voxels just zeroed by the skull strip, so the background ends up non-zero.
    voxels = normalize_intensity(voxels)
    return numpy_to_sitk(voxels, img_sitk)


def _preprocess_for_registration(img_sitk):
    """N4 bias correction followed by skull strip + normalisation."""
    return _strip_and_normalize(n4_bias_correction(img_sitk))


def register_and_save_all(fixed_flair_path, moving_flair_path, moving_mask_path, output_dir):
    """Rigidly register the time02 FLAIR and its mask to the time01 FLAIR.

    Steps:

    1. Estimate the rigid transform on preprocessed copies of both FLAIRs.
       Preprocessing is resampling with ``resample_to_isotropic``'s default
       spacing, N4 bias correction, percentile skull strip and z-score
       normalisation.
    2. Apply the transform to the ORIGINAL moving FLAIR and mask, resampling
       both onto the grid of the original fixed FLAIR.
    3. Skull-strip and normalise the registered FLAIR. The mask is resampled
       with nearest-neighbour interpolation.

    Two temporary files (``tmp_fixed.nii.gz`` and ``tmp_moving.nii.gz``) are
    written to ``output_dir`` for the registration step. They are always
    removed, even if writing or registration fails.

    Args:
        fixed_flair_path: FLAIR at time01 (registration target).
        moving_flair_path: FLAIR at time02 (image to align).
        moving_mask_path: label image warped with the same transform as the
            moving FLAIR.
        output_dir: directory for the outputs; created if missing.

    Returns:
        ``(flair_out_path, mask_out_path)``: the paths of the two saved files.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("🔄 Registrare rigidă între time01 și time02...")

    # Preprocessed copies, used only to estimate the transform.
    fixed_final = _preprocess_for_registration(resample_to_isotropic(fixed_flair_path))
    moving_final = _preprocess_for_registration(resample_to_isotropic(moving_flair_path))

    # rigid_register_images takes file paths, so round-trip through temp files.
    tmp_fixed_path = os.path.join(output_dir, "tmp_fixed.nii.gz")
    tmp_moving_path = os.path.join(output_dir, "tmp_moving.nii.gz")
    try:
        sitk.WriteImage(fixed_final, tmp_fixed_path)
        sitk.WriteImage(moving_final, tmp_moving_path)
        transform = rigid_register_images(tmp_fixed_path, tmp_moving_path)
    finally:
        # Clean up even when writing or registration raised; previously the
        # temp files were left in the patient folder on any error.
        for tmp_path in (tmp_fixed_path, tmp_moving_path):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # Warp the ORIGINAL moving FLAIR onto the original fixed grid, then post-process it.
    registered_flair = apply_transform(moving_flair_path, transform, fixed_flair_path, is_label=False)
    registered_flair = _strip_and_normalize(registered_flair)

    # Warp the mask with the same transform (nearest neighbour, no post-processing).
    registered_mask = apply_transform(moving_mask_path, transform, fixed_flair_path, is_label=True)

    flair_out_path = os.path.join(output_dir, "flair_time02_registered_to_time01.nii.gz")
    mask_out_path = os.path.join(output_dir, "mask_time02_registered_to_time01.nii.gz")
    save_nifti(registered_flair, flair_out_path)
    save_nifti(registered_mask, mask_out_path)

    return flair_out_path, mask_out_path
def batch_register_all_patients(root_dir):
    """Run ``register_and_save_all`` on every patient sub-folder of ``root_dir``.

    Each immediate sub-directory must contain
    ``flair_time01_on_middle_space.nii.gz``,
    ``flair_time02_on_middle_space.nii.gz`` and ``ground_truth.nii.gz``.
    Outputs are written back into that same sub-directory.

    Folders are visited in sorted name order so runs are reproducible.
    Non-directory entries are ignored. Folders with missing files, or whose
    processing raises, are counted and reported but do not stop the batch.

    Args:
        root_dir: directory containing one sub-folder per patient.

    Returns:
        ``(processed, empty_or_invalid)`` counts, which are also printed.
    """
    empty_or_invalid = 0
    processed = 0

    # Context manager closes the directory handle deterministically;
    # sorting gives a stable processing order.
    with os.scandir(root_dir) as entries:
        patient_dirs = sorted(
            (entry for entry in entries if entry.is_dir()),
            key=lambda entry: entry.name,
        )

    for folder in patient_dirs:
        folder_path = folder.path
        flair_time01 = os.path.join(folder_path, 'flair_time01_on_middle_space.nii.gz')
        flair_time02 = os.path.join(folder_path, 'flair_time02_on_middle_space.nii.gz')
        mask_time02 = os.path.join(folder_path, 'ground_truth.nii.gz')

        if not all(os.path.isfile(p) for p in (flair_time01, flair_time02, mask_time02)):
            print(f"⚠️ Sar peste {folder.name} — lipsesc fișiere necesare.")
            empty_or_invalid += 1
            continue

        print(f"\n📂 Procesez pacientul: {folder.name}")
        try:
            register_and_save_all(flair_time01, flair_time02, mask_time02, folder_path)
        except Exception as e:
            # Batch boundary: report the failure with its traceback and move on
            # to the next patient instead of aborting the whole run.
            print(f"❌ Eroare la procesarea {folder.name}: {e}")
            traceback.print_exc()
            empty_or_invalid += 1
        else:
            processed += 1

    print(f"\n✅ Pacienți procesați: {processed}")
    print(f"❌ Cazuri omise sau cu erori: {empty_or_invalid}")
    return processed, empty_or_invalid


# Root folder holding one sub-directory per patient
# (presumably a Google Drive mount inside Colab - adjust for other environments).
root_dir = '/content/drive/MyDrive/training_small/'
batch_register_all_patients(root_dir=root_dir)