# In [None]:
import os
import nibabel as nib
import numpy as np

# Input directories (relative to the current working directory).
# NOTE(review): in the code visible here only `result_dir` is actually read;
# `label_dir` is defined but never referenced -- confirm whether the ground
# truth labels are meant to be split as well.
label_dir = "labels"
result_dir = "predict_results"

# Output root directories; one subdirectory per input file is created
# underneath the root that is passed to `split_labels`.
# NOTE(review): `out_label_root` is likewise not referenced in the visible code.
out_label_root = "split_label"
out_result_root = "split_predict_results"

# label to name mapping
def label_to_name(label):
    """Return the structure name used in output file names for a label value.

    Mapping:
        1-12  -> "rib_left_1" .. "rib_left_12"
        13-24 -> "rib_right_1" .. "rib_right_12"
        25    -> "sternum"
        26    -> "costal_cartilages"
        other -> "unknown_label_<label>"
    """
    # Single-value structures first, then the two rib ranges.
    if label == 25:
        return "sternum"
    if label == 26:
        return "costal_cartilages"
    if 1 <= label <= 12:
        return f"rib_left_{label}"
    if 13 <= label <= 24:
        # Right ribs are numbered 1..12 again, so shift the label down.
        return f"rib_right_{label - 12}"
    return f"unknown_label_{label}"


# split function (automatically create subdirectories)
def split_labels(input_path, output_root):
    """Split a multi-label NIfTI volume into one binary mask file per label.

    The volume at ``input_path`` is loaded and, for every non-zero label value
    it contains, a mask (1 where the voxel carries that label, 0 elsewhere) is
    written to ``<output_root>/<patient_id>/<structure_name>.nii.gz``.
    ``patient_id`` is the input file name without its ``.nii.gz`` extension and
    ``structure_name`` comes from ``label_to_name``. The output directory is
    created if it does not exist.

    Args:
        input_path: Path to the multi-label ``.nii.gz`` file.
        output_root: Directory under which the per-patient subdirectory is
            created.

    Returns:
        None. The masks are written to disk.
    """
    nii = nib.load(input_path)
    # get_fdata() always yields floats. Round to the nearest integer once so
    # that the set of labels and the voxel comparison below use the same
    # values; truncating only the unique values (as astype(int) does) would
    # drop or mismatch labels stored as e.g. 0.9999 after scaling.
    data = np.rint(nii.get_fdata()).astype(np.int32)

    # Strip only the trailing extension, not every ".nii.gz" in the name.
    file_name = os.path.basename(input_path)
    suffix = ".nii.gz"
    if file_name.endswith(suffix):
        patient_id = file_name[:-len(suffix)]
    else:
        patient_id = file_name
    out_dir = os.path.join(output_root, patient_id)
    os.makedirs(out_dir, exist_ok=True)

    # When a header is supplied, nibabel takes the on-disk dtype from that
    # header rather than from the array, so declare uint8 explicitly to avoid
    # writing binary masks with the source volume's (often wider) dtype.
    mask_header = nii.header.copy()
    mask_header.set_data_dtype(np.uint8)

    for lbl in np.unique(data):
        if lbl == 0:
            continue  # ignore background
        binary_mask = (data == lbl).astype(np.uint8)
        out_img = nib.Nifti1Image(binary_mask, nii.affine, mask_header)
        out_path = os.path.join(out_dir, f"{label_to_name(int(lbl))}.nii.gz")
        nib.save(out_img, out_path)


# Split every compressed NIfTI prediction found directly inside `result_dir`.
for fname in os.listdir(result_dir):
    if not fname.endswith(".nii.gz"):
        continue  # skip anything that is not a .nii.gz volume
    prediction_path = os.path.join(result_dir, fname)
    split_labels(prediction_path, out_result_root)
