In [None]:
import os
import glob
import re
import pandas as pd
import numpy as np
import SimpleITK as sitk
import scipy.ndimage as nd
from tqdm.auto import tqdm

In [None]:
def collapse_4D_to_3D(arr):
    """Flatten a layered segmentation array into a single label array.

    The shortest axis is taken to be the layer/channel axis (the first one wins
    on ties). Layers are merged with a voxel-wise maximum, so where layers
    overlap the larger label value is kept.

    Parameters
    ----------
    arr : numpy.ndarray
        Layered label array (typically 4D).

    Returns
    -------
    numpy.ndarray
        ``arr`` reduced by ``max`` over its shortest axis.
    """
    shortest_axis = min(range(arr.ndim), key=lambda ax: arr.shape[ax])
    return arr.max(axis=shortest_axis)

# Per-segment header keys of a Slicer segmentation look like Segment<N>_<Field>.
name_pattern    = re.compile(r'Segment(\d+)_Name')
id_pattern      = re.compile(r'Segment(\d+)_ID')
label_pattern   = re.compile(r'Segment(\d+)_LabelValue')
channel_pattern = re.compile(r'Segment(\d+)_Layer')

def match_key(key, meta_image):
    """Classify a single metadata key of a segmentation image.

    Parameters
    ----------
    key : str
        Metadata key name; it must match one of the patterns in full.
    meta_image
        Object exposing ``GetMetaData(key)`` (e.g. a SimpleITK image).

    Returns
    -------
    tuple or None
        ``(segment_index, field, value)``, where ``field`` is one of
        "Segment Name", "Segment ID", "LabelValue" or "Layer". The last two
        are converted to ``int``; the others are returned as read. ``None``
        is returned for any other key.
    """
    # (pattern, field name in the parsed dict, converter for the raw value or None)
    rules = (
        (name_pattern, "Segment Name", None),
        (id_pattern, "Segment ID", None),
        (label_pattern, "LabelValue", int),
        (channel_pattern, "Layer", int),
    )
    for pattern, field, convert in rules:
        hit = pattern.fullmatch(key)
        if hit is None:
            continue
        raw = meta_image.GetMetaData(key)
        return int(hit.group(1)), field, (raw if convert is None else convert(raw))
    return None


def parse_slicer_seg_metadata(sitk_image):
    """Group a segmentation image's per-segment header entries by segment index.

    Each metadata key is classified with ``match_key``; keys it does not
    recognise are ignored.

    Parameters
    ----------
    sitk_image
        Image exposing ``GetMetaDataKeys()`` and ``GetMetaData(key)``.

    Returns
    -------
    dict[int, dict[str, object]]
        ``{segment_index: {field: value}}``. Segment indices appear in the
        order they are first encountered.
    """
    segments = {}
    for key in sitk_image.GetMetaDataKeys():
        parsed = match_key(key, sitk_image)
        if parsed is not None:
            index, field, value = parsed
            segments.setdefault(index, {})[field] = value
    return segments

def write_slicer_seg_metadata(sitk_out, segments):
    """Write parsed per-segment metadata onto an image header.

    Parameters
    ----------
    sitk_out
        Object exposing ``SetMetaData(key, value)`` (e.g. a SimpleITK image).
    segments : dict[int, dict]
        ``{segment_index: info}`` as produced by ``parse_slicer_seg_metadata``.
        Entries flagged ``"erodedAway": True`` are skipped entirely.

    Notes
    -----
    Segments are written in ascending index order. Each surviving segment gets
    ``Segment<N>_Name``, ``_ID``, ``_LabelValue`` and ``_Layer`` for whichever of
    those fields it has, plus ``Segment<N>_Index``. All values are written as
    strings.
    """
    # Parsed field name -> header key suffix.
    field_suffixes = (
        ("Segment Name", "Name"),
        ("Segment ID", "ID"),
        ("LabelValue", "LabelValue"),
        ("Layer", "Layer"),
    )
    for seg_index in sorted(segments):
        seg_info = segments[seg_index]
        if seg_info.get("erodedAway", False):
            continue  # Skip segments that became empty.

        base_idx = f"Segment{seg_index}"
        for field, suffix in field_suffixes:
            if field in seg_info:
                sitk_out.SetMetaData(f"{base_idx}_{suffix}", str(seg_info[field]))

        # NOTE(review): `_Index` is not one of the keys parse_slicer_seg_metadata reads back;
        # presumably informational — confirm downstream readers tolerate the extra key.
        # NOTE(review): skipping eroded segments leaves gaps in the Segment<N> numbering;
        # confirm the consumer of the output does not require contiguous indices.
        sitk_out.SetMetaData(f"{base_idx}_Index", str(seg_index))


def erode_one_voxel_6conn_3D(binary_arr):
    """Shrink a 3D binary mask by one voxel using a 6-connected (face-neighbour) cross.

    Voxels outside the array count as background (``border_value=0``), so
    foreground that touches the array boundary is eroded too.

    Parameters
    ----------
    binary_arr : numpy.ndarray
        3D mask; any non-zero value counts as foreground.

    Returns
    -------
    numpy.ndarray
        Eroded mask cast back to ``binary_arr.dtype``.
    """
    # Rank 3, connectivity 1: centre voxel plus its six face neighbours (±z, ±y, ±x).
    cross = nd.generate_binary_structure(3, 1)
    shrunk = nd.binary_erosion(binary_arr, structure=cross, iterations=1, border_value=0)
    return shrunk.astype(binary_arr.dtype)

def _find_channel_axis(img, arr, num_segments):
    """Return the layer axis of a 4D segmentation array, or None when it cannot be determined."""
    if img.GetNumberOfComponentsPerPixel() > 1:
        # SimpleITK returns a multi-component image as (z, y, x, components): layers are last.
        return arr.ndim - 1
    # Scalar 4D image: fall back to matching an axis length against the segment count.
    if arr.shape[0] == num_segments:
        return 0
    if arr.shape[-1] == num_segments:
        return arr.ndim - 1
    return None


def _erode_label_volume(arr, arr_out, seg_info, in_path):
    """Erode each metadata-described label of a 3D label array into arr_out; return indices that vanished."""
    label_to_seg_index = {}
    for seg_index, info in seg_info.items():
        if "LabelValue" in info:
            label_to_seg_index[info["LabelValue"]] = seg_index

    vanished = []
    for lbl in np.unique(arr):
        if lbl == 0:
            continue
        seg_index = label_to_seg_index.get(lbl)
        if seg_index is None:
            # Voxels of an undocumented label are not copied to the output; say so.
            print(f"WARNING: Label {lbl} in {in_path} has no segment metadata; omitted from output")
            continue
        eroded = erode_one_voxel_6conn_3D(arr == lbl).astype(bool)
        if eroded.any():
            arr_out[eroded] = lbl
        else:
            vanished.append(seg_index)
    return vanished


def _erode_layered_volume(arr, arr_out, seg_info, channel_axis):
    """Erode each segment inside its own layer of a 4D array into arr_out; return indices that vanished."""
    vanished = []
    for seg_index, info in seg_info.items():
        if "Layer" not in info:
            continue
        # An integer on the channel axis plus full slices elsewhere is basic indexing,
        # so arr_out[layer_index] is a view and writes through it update arr_out.
        layer_index = tuple(
            info["Layer"] if axis == channel_axis else slice(None) for axis in range(arr.ndim)
        )
        layer_in = arr[layer_index]
        eroded = erode_one_voxel_6conn_3D(layer_in != 0).astype(bool)
        if eroded.any():
            # Copy the input voxel values rather than a constant, so a segment stored
            # with a label value other than 1 keeps that value in the output.
            arr_out[layer_index][eroded] = layer_in[eroded]
        else:
            vanished.append(seg_index)
    return vanished


def contract_segmentation(in_path, out_path, lost_lesions_df, scan_name=None):
    """Erode every segment of a segmentation file by one voxel and write the result.

    Parameters
    ----------
    in_path : str
        Segmentation readable by ``sitk.ReadImage`` (e.g. a Slicer ``.seg.nrrd``).
    out_path : str
        Destination file, written with compression.
    lost_lesions_df : pandas.DataFrame
        Mutated in place: one row is appended per segment the erosion removed
        completely (columns 'Path', 'Scan_Name', 'Segment index',
        'Segment Name', 'Segment ID').
    scan_name : str, optional
        Stored in the 'Scan_Name' column of appended rows.

    Returns
    -------
    tuple[dict, str]
        The parsed segment metadata, where segments that vanished carry
        ``"erodedAway": True``, and ``"4D"`` or ``"3D"`` describing the array
        layout written to ``out_path``.

    Notes
    -----
    4D input whose segments share layers is collapsed to one 3D label volume
    (voxel-wise max across layers) before eroding, and the written ``Layer``
    metadata is set to 0 to match. The returned metadata keeps the input layers.
    """
    img = sitk.ReadImage(in_path)
    arr = sitk.GetArrayFromImage(img)
    seg_info = parse_slicer_seg_metadata(img)

    if len(seg_info) == 0:
        print(f"WARNING: No segment metadata found in {in_path}")

    is_vector = img.GetNumberOfComponentsPerPixel() > 1
    is_4D = (arr.ndim == 4)
    collapsed = False
    # If 4D but the layer values are non-unique, collapse to 3D multi-label.
    if is_4D:
        layers = [info["Layer"] for info in seg_info.values() if "Layer" in info]
        if len(layers) != len(set(layers)):
            print(f"WARNING: Non-unique layer values in 4D segmentation; treating as 3D multi-label for {in_path}")
            # Max across layers: where segments overlap, only the larger label survives.
            arr = arr.max(axis=-1) if is_vector else collapse_4D_to_3D(arr)
            is_4D = False
            collapsed = True

    arr_out = np.zeros_like(arr)
    if not is_4D:
        vanished = _erode_label_volume(arr, arr_out, seg_info, in_path)
    else:
        channel_axis = _find_channel_axis(img, arr, len(seg_info))
        if channel_axis is None:
            print(f"WARNING: Unable to determine channel axis for 4D segmentation in {in_path}")
            vanished = []
        else:
            vanished = _erode_layered_volume(arr, arr_out, seg_info, channel_axis)

    for seg_index in vanished:
        seg_info[seg_index]["erodedAway"] = True

    # A multi-component input must be rebuilt as a multi-component 3D image; without
    # isVector the 4D array becomes a 4D image and CopyInformation rejects the mismatch.
    out_img = sitk.GetImageFromArray(arr_out, isVector=(is_vector and arr_out.ndim == 4))
    out_img.CopyInformation(img)

    # Copy the header, including per-segment keys parse_slicer_seg_metadata does not parse
    # (presumably colour/extent/tag entries), but drop every key of a vanished segment.
    vanished_set = set(vanished)
    for key in img.GetMetaDataKeys():
        seg_key = re.match(r'Segment(\d+)_', key)
        if seg_key is not None and int(seg_key.group(1)) in vanished_set:
            continue
        out_img.SetMetaData(key, img.GetMetaData(key))

    if collapsed:
        # The collapsed output has a single layer, so every segment now lives in layer 0.
        seg_to_write = {
            idx: ({**info, "Layer": 0} if "Layer" in info else info)
            for idx, info in seg_info.items()
        }
    else:
        seg_to_write = seg_info
    write_slicer_seg_metadata(out_img, seg_to_write)

    sitk.WriteImage(out_img, out_path, True)

    for seg_index in vanished:
        info = seg_info[seg_index]
        lost_lesions_df.loc[len(lost_lesions_df)] = {
            'Path': in_path,
            'Scan_Name': scan_name,
            'Segment index': seg_index,
            'Segment Name': info.get("Segment Name", f"Segment{seg_index}"),
            'Segment ID': info.get("Segment ID", "?"),
        }

    # Report the array layout written: a multi-component image is 3D to SimpleITK but layered.
    dimensionality = "4D" if arr_out.ndim == 4 else "3D"
    return seg_info, dimensionality


def example_process(df, out_dir, out_excel_all, out_excel_lost):
    """Contract every mask listed in ``df`` and save two Excel summaries.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per case. Uses columns 'Mask' (segmentation path) and 'Image'
        (copied into the summary). An optional 'Scan_Name' names the output
        sub-folder; when it is missing, empty or NaN, the mask's parent folder
        name is used instead.
    out_dir : str
        Output root, created if needed. Contracted masks are written to
        ``out_dir/<scan_name>/<root>_contracted.seg.nrrd``.
    out_excel_all : str
        File name (without ``.xlsx``) of the one-row-per-segment summary.
    out_excel_lost : str
        File name (without ``.xlsx``) of the table of segments eroded away.

    Rows whose mask file is missing, or whose processing raises, are reported
    with ``print`` and skipped, so one bad case does not stop the batch.
    """
    os.makedirs(out_dir, exist_ok=True)

    lost_lesions_df = pd.DataFrame(columns=['Path', 'Scan_Name', 'Segment index', 'Segment Name', 'Segment ID'])
    all_segments_data = []

    for idx, row in tqdm(df.iterrows(), total=df.shape[0], desc='Processing rows'):
        try:
            mask_path = row['Mask']
            image_path = row['Image']

            if not os.path.isfile(mask_path):
                print(f"Row {idx}: mask not found at {mask_path}")
                continue

            scan_name = row.get('Scan_Name', None)
            # NaN (e.g. a blank cell) is truthy, so check for it before the emptiness test.
            if pd.isna(scan_name) or not scan_name:
                scan_name = os.path.basename(os.path.dirname(mask_path))
            scan_name = str(scan_name)
            anon_name = scan_name[:14]  # presumably the anonymised-ID prefix — confirm the 14-char width

            root = os.path.splitext(os.path.basename(mask_path))[0]
            if root.endswith('.seg'):  # "case.seg.nrrd" -> "case"
                root = root[:-len('.seg')]
            scan_out_dir = os.path.join(out_dir, scan_name)
            os.makedirs(scan_out_dir, exist_ok=True)
            out_path = os.path.join(scan_out_dir, f"{root}_contracted.seg.nrrd")

            seg_info, dimensionality = contract_segmentation(mask_path, out_path, lost_lesions_df, scan_name=scan_name)

            for seg_idx, seg in seg_info.items():
                all_segments_data.append({
                    'Anon_Name': anon_name,
                    'Scan_Name': scan_name,
                    'Segment index': seg_idx,
                    'Segment Name': seg.get("Segment Name", ""),
                    'Segment ID': seg.get("Segment ID", ""),
                    'Label': seg.get("LabelValue", ""),
                    'Label_channel': seg.get("Layer", ""),
                    'Dimensionality': dimensionality,
                    'Image': image_path,
                    'Mask': mask_path,
                    'ErodedAway': seg.get("erodedAway", False)
                })
        except Exception as e:
            # Best-effort batch: report the failing row and move on to the next one.
            print(f"Row {idx} failed: {e}")

    cols = ['Anon_Name', 'Scan_Name', 'Segment index', 'Segment Name', 'Segment ID',
            'Label', 'Label_channel', 'Dimensionality', 'Image', 'Mask', 'ErodedAway']
    # columns= fixes the schema even when nothing was recorded; selecting these
    # columns from an empty DataFrame would raise KeyError instead.
    all_segments_df = pd.DataFrame(all_segments_data, columns=cols)
    all_excel_path = os.path.join(out_dir, out_excel_all + '.xlsx')
    all_segments_df.to_excel(all_excel_path, index=False)
    print(f"All segments summary saved to {all_excel_path}")

    lost_excel_path = os.path.join(out_dir, out_excel_lost + '.xlsx')
    lost_lesions_df.to_excel(lost_excel_path, index=False)
    print(f"Lost lesions summary saved to {lost_excel_path}")

In [None]:
# Paths to configure before running (placeholders in this copy).
SEG_GLOB = r"LOCAL PATH"    # glob pattern selecting the segmentation files to contract
OUTPUT_DIR = r"LOCAL PATH"  # root folder for contracted masks and the Excel summaries

seg_files = sorted(glob.glob(SEG_GLOB))
output_dir = OUTPUT_DIR

# Each segmentation path fills both the 'Image' and 'Mask' columns; Scan_Name is its parent folder name.
df = pd.DataFrame({'Image': seg_files, 'Mask': seg_files})
df['Scan_Name'] = df['Image'].apply(lambda path: os.path.basename(os.path.dirname(path)))

example_process(
    df,
    output_dir,
    out_excel_all="all_segments_summary",
    out_excel_lost="lost_segments_summary",
)