## Postprocessing of masks to get rid of a few mask prediction errors

In [None]:
import os
import re
from pathlib import Path

import nibabel as nib
import numpy as np
import SimpleITK as sitk
from skimage.measure import label
# NOTE: the scipy import below must stay after the skimage one. Both bind the
# name ``label``; the island-removal code unpacks scipy's ``(labels, count)``
# return value, so scipy's version has to be the one left in scope.
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import label, binary_dilation

def process_slice(slice_mask, midline, max_dist=20):
    """Fix wrong-side label assignments in a 2D slice.

    Labels 1 and 2 are expected on the left of ``midline`` (anatomical right),
    labels 3 and 4 on the right (anatomical left). A connected component that
    lies entirely on the wrong side is relabelled to the nearest label of the
    opposite group, provided that label exists in the slice and is no further
    than ``max_dist`` pixels away. Components that touch or straddle the
    midline are left alone.

    Args:
        slice_mask: 2D integer label array; it is not modified.
        midline: Column index separating the two sides.
        max_dist: Largest Euclidean distance (in pixels) at which a
            replacement label is still accepted. Defaults to 20.

    Returns:
        A copy of ``slice_mask`` with wrong-side components relabelled.
    """
    # Imported locally under an alias: at module level the name ``label`` is
    # bound to whichever of skimage/scipy was imported last, and the two have
    # different return types (array vs. ``(array, count)`` tuple).
    from scipy.ndimage import label as _nd_label

    processed_mask = slice_mask.copy()

    # 8-connectivity, i.e. the default of skimage.measure.label for 2D input,
    # which is what the ``components.max()`` style of this code was written for.
    full_connectivity = np.ones((3, 3), dtype=bool)

    # Distance maps depend only on the (unmodified) input slice, so compute
    # each one at most once. ``None`` marks a label absent from the slice.
    dist_maps = {}

    for label_val in range(1, 5):
        expects_left = label_val in (1, 2)
        potential_new_labels = (3, 4) if expects_left else (1, 2)

        binary = slice_mask == label_val
        if not np.any(binary):
            continue

        components, num_components = _nd_label(binary, structure=full_connectivity)

        for comp_idx in range(1, num_components + 1):
            comp_mask = components == comp_idx
            cols = np.flatnonzero(comp_mask.any(axis=0))
            if cols.size == 0:
                continue

            min_x, max_x = cols[0], cols[-1]
            # Wrong side means the whole component lies beyond the midline.
            if expects_left:
                is_wrong_side = min_x > midline
            else:
                is_wrong_side = max_x < midline
            if not is_wrong_side:
                continue

            min_dist = float('inf')
            best_label = None

            for new_label in potential_new_labels:
                if new_label not in dist_maps:
                    target_mask = slice_mask == new_label
                    dist_maps[new_label] = (
                        distance_transform_edt(~target_mask)
                        if np.any(target_mask)
                        else None
                    )
                dist_map = dist_maps[new_label]
                if dist_map is None:
                    continue

                dist_to_label = dist_map[comp_mask].min()
                # Strict '<' keeps the first candidate on ties.
                if dist_to_label < min_dist and dist_to_label <= max_dist:
                    min_dist = dist_to_label
                    best_label = new_label

            if best_label is not None:
                processed_mask[comp_mask] = best_label

    return processed_mask

def process_volume(input_path):
    """Process a 3D volume to correct wrong-side labels and save.

    Every slice along the third axis is passed through ``process_slice`` and
    the result is written next to the input with ``.nii.gz`` replaced by
    ``_postprocessed.nii.gz``.

    Args:
        input_path: Path (str or path-like) of the ``.nii.gz`` segmentation.

    Returns:
        The total number of voxels whose label changed, as an int.
    """
    print(f"\nProcessing: {os.path.basename(input_path)}")

    img = nib.load(str(input_path))
    seg = img.get_fdata().astype(np.uint8)
    processed_volume = np.zeros_like(seg)
    # ``process_slice`` receives each slice transposed (see below), so the
    # columns it inspects run along volume axis 0. The midline therefore has
    # to come from shape[0]; shape[1] is only right for square slices.
    midline = seg.shape[0] // 2
    total_changes = 0

    for z in range(seg.shape[2]):
        orig_slice = seg[:, :, z]
        # rot90(k=-1) followed by fliplr is exactly a transpose.
        slice_mask = orig_slice.T

        processed_slice = process_slice(slice_mask, midline)

        # Undo the transpose to get back to the volume's own orientation.
        processed_slice = processed_slice.T

        processed_volume[:, :, z] = processed_slice
        total_changes += int(np.count_nonzero(processed_slice != orig_slice))

    output_path = str(input_path).replace('.nii.gz', '_postprocessed.nii.gz')
    processed_img = nib.Nifti1Image(processed_volume, img.affine, img.header)
    # The reused header still carries the source file's on-disk dtype; store
    # the labels as uint8, as the relabelling step of this pipeline does.
    processed_img.set_data_dtype(np.uint8)
    nib.save(processed_img, output_path)

    print(f"  Total changes: {total_changes}")
    print(f"  Avg changes per slice: {total_changes / seg.shape[2]:.2f}")
    print(f"  Saved to: {os.path.basename(output_path)}")

    return total_changes

# === Main script ===
folder_path = r"-----INSERT PATH HERE -----"
os.makedirs(folder_path, exist_ok=True)

# Gather the raw predictions, leaving out outputs written by an earlier run.
nii_files = [
    candidate
    for candidate in Path(folder_path).glob("*.nii.gz")
    if not str(candidate).endswith("_postprocessed.nii.gz")
]

print(f"Found {len(nii_files)} .nii.gz files to process")

# Maps file name -> number of changed voxels, for successfully handled files only.
total_files_changes = {}
for file_path in nii_files:
    try:
        changes = process_volume(file_path)
    except Exception as e:
        # Best effort: report the failure and carry on with the next file.
        print(f"❌ Error processing {file_path}: {str(e)}")
    else:
        total_files_changes[os.path.basename(file_path)] = changes

# === Summary ===
print("\nProcessing Summary:")
print(f"Successfully processed {len(total_files_changes)} files")
print("Files with most changes:")
ranked = sorted(
    total_files_changes.items(),
    key=lambda entry: entry[1],
    reverse=True,
)
for filename, changes in ranked[:5]:
    print(f"  {filename}: {changes} changes")


## Post-postprocessing all masks to remove the isolated islands

In [None]:
# === Utility: Get largest connected component ===
def get_largest_component(mask):
    """Return a boolean mask selecting the largest connected component.

    If ``mask`` contains no foreground at all, ``mask`` itself is returned
    unchanged. When several components share the largest size, the one with
    the lowest component id wins.
    """
    components, count = label(mask)
    if not count:
        return mask
    # Rank the foreground bins only (bin 0 is background), then shift the
    # winning position back to its component id.
    voxel_counts = np.bincount(components.ravel())[1:]
    return components == (voxel_counts.argmax() + 1)

# === Utility: Pick the label that occupies most of an island's border ===
def reassign_to_largest_border_component(island_mask, full_mask):
    """Return the label that occupies most of the island's immediate border.

    The island is grown by one dilation step; the labelled (non-zero) voxels
    gained by that step form the border. The most frequent label among them
    is returned, or 0 when the border contains no labelled voxel.
    """
    grown = binary_dilation(island_mask, iterations=1)
    ring = grown & ~island_mask & (full_mask > 0)
    candidates, freq = np.unique(full_mask[ring], return_counts=True)
    if candidates.size == 0:
        return 0
    return candidates[freq.argmax()]

# === Process all *_postprocessed.nii.gz ===
for fname in os.listdir(folder_path):
    if not fname.endswith("_postprocessed.nii.gz"):
        continue

    in_path = os.path.join(folder_path, fname)
    out_path = in_path.replace("_postprocessed.nii.gz", "_relabelled.nii.gz")

    nii = nib.load(in_path)
    data = nii.get_fdata().astype(np.uint8)
    modified_slices = set()

    # Only labels 1-3 are cleaned here.
    for lbl in (1, 2, 3):
        labeled_cc, num = label(data == lbl)
        if num <= 1:
            # Nothing to merge: the label is absent or already one piece.
            continue

        # The biggest component is kept; bin 0 (background) is skipped.
        main_cc = np.bincount(labeled_cc.ravel())[1:].argmax() + 1

        for cc_id in range(1, num + 1):
            if cc_id == main_cc:
                continue
            island = labeled_cc == cc_id
            # The replacement is computed against the current state of
            # ``data``, so earlier reassignments influence later ones.
            data[island] = reassign_to_largest_border_component(island, data)
            # Record which indices along the third axis were touched.
            modified_slices.update(np.unique(np.nonzero(island)[2]))

    # Save result
    new_nii = nib.Nifti1Image(data, nii.affine, nii.header)
    new_nii.set_data_dtype(np.uint8)
    nib.save(new_nii, out_path)

    print(f"\n🧠 Processed: {fname}")
    print(f"✅ Saved: {os.path.basename(out_path)}")
    print(f"✏️ Modified slices: {sorted(modified_slices)}")
    print(f"🔎 Unique labels: {np.unique(data)}")


## Converting the nii.gz files into mhd for submission

In [None]:
# === Input and Output Folders ===
input_dir = r"----- INSERT PATH HERE ------"
output_dir = os.path.join(input_dir, "Pre_postprocessing")
os.makedirs(output_dir, exist_ok=True)

# Accept purely numeric names such as 01.nii.gz or 42.nii.gz; derived files
# like 42_postprocessed.nii.gz do not match.
pattern = re.compile(r"^\d+\.nii\.gz$")

for fname in os.listdir(input_dir):
    if not pattern.match(fname):
        print(f"✘ Skipped: {fname}")
        continue

    img = sitk.ReadImage(os.path.join(input_dir, fname))

    # The pattern guarantees the name is "<digits>.nii.gz", so everything
    # before the first dot is the case id (e.g. '04' from '04.nii.gz').
    case_id = fname.partition(".")[0]
    output_base = "lola11-" + case_id.zfill(2)

    # With compression enabled the .mhd header is written with a .zraw payload.
    output_mhd = os.path.join(output_dir, output_base + ".mhd")
    sitk.WriteImage(img, output_mhd, useCompression=True)

    print(f"✔ Converted: {fname} → {output_base}.mhd and .zraw")


## Converting the labels on the mhd files

In [None]:
# Define label mapping
# Source label -> submission label. Any value not listed here ends up as 0.
label_map = {
    1: 20,
    2: 21,
    3: 22,
    4: 10,
    5: 11,
}

# Rewrite every .mhd file in the output folder in place.
for fname in os.listdir(output_dir):
    if not fname.endswith(".mhd"):
        continue

    full_path = os.path.join(output_dir, fname)
    print(f"Relabeling: {fname}")

    source_img = sitk.ReadImage(full_path)
    source_arr = sitk.GetArrayFromImage(source_img)

    # Start from an all-zero array so unmapped values are cleared.
    remapped = np.zeros_like(source_arr)
    for source_label, target_label in label_map.items():
        remapped[source_arr == source_label] = target_label

    # Carry the image metadata over from the source image to the new one.
    remapped_img = sitk.GetImageFromArray(remapped)
    remapped_img.CopyInformation(source_img)
    sitk.WriteImage(remapped_img, full_path, useCompression=True)

print("\n✅ All masks relabeled to LOLA11 format.")


## Zipping the mhd files into 1 file

In [None]:
import os
import zipfile

# The archive is written next to the output folder, named after it.
output_zip = f"{output_dir}.zip"

# Only the header files and their compressed payloads go into the archive.
submission_suffixes = ('.mhd', '.zraw')

with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for fname in os.listdir(output_dir):
        if not fname.endswith(submission_suffixes):
            continue
        # Store each file flat, without its directory prefix.
        zipf.write(os.path.join(output_dir, fname), arcname=fname)
        print(f"✔ Added: {fname}")

print(f"\n✅ Zip created at: {output_zip}")