In [None]:
import SimpleITK as sitk
import numpy as np
import os
from PIL import Image
from skimage.measure import label, regionprops

def window_transform(sitk_image, win_width=80, win_center=40):
    """Apply an intensity display window to *sitk_image*.

    The window is ``[win_center - win_width / 2, win_center + win_width / 2]``.
    Both bounds are truncated with ``int()`` before being handed to
    ``sitk.IntensityWindowingImageFilter``. Intensities inside the window are
    rescaled linearly onto the filter's output range, which this function
    leaves at the filter's defaults. Intensities outside the window are
    clamped to that range.

    Returns the windowed ``sitk.Image``.
    """
    half_width = win_width / 2.0
    lower_bound = int(win_center - half_width)
    upper_bound = int(win_center + half_width)

    windowing = sitk.IntensityWindowingImageFilter()
    windowing.SetWindowMinimum(lower_bound)
    windowing.SetWindowMaximum(upper_bound)
    return windowing.Execute(sitk_image)

def rescale_image(image, scale_factor, interpolator=None):
    """Resample *image* onto a voxel grid that is ``scale_factor`` times finer.

    Parameters
    ----------
    image : sitk.Image
        Image to resample.
    scale_factor : float or sequence of float
        Grid magnification. A scalar scales every axis uniformly, including
        the slice axis of a 3-D volume. A sequence supplies one factor per
        axis in SimpleITK (x, y, z) order. Every factor must be > 0.
        The new size is ``int(old_size * factor)`` and the new spacing is
        ``old_spacing / factor``.
    interpolator : int, optional
        SimpleITK interpolator enum. Defaults to ``sitk.sitkBSpline``, which
        suits intensity images. Pass ``sitk.sitkNearestNeighbor`` for label or
        mask images so that no new label values are invented at boundaries.

    Returns
    -------
    sitk.Image
        The resampled image, with the input's origin and direction.

    Raises
    ------
    ValueError
        If any scale factor is not strictly positive, or if a sequence of
        factors does not match the image dimension.
    """
    if interpolator is None:
        interpolator = sitk.sitkBSpline

    dimension = image.GetDimension()
    factors = np.broadcast_to(np.asarray(scale_factor, dtype=float), (dimension,))
    # `not all(> 0)` also rejects NaN, which `any(<= 0)` would let through.
    if not np.all(factors > 0):
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    new_size = (np.asarray(image.GetSize(), dtype=int) * factors).astype(int)
    new_spacing = np.asarray(image.GetSpacing(), dtype=float) / factors

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing([float(s) for s in new_spacing])
    resample.SetSize([int(s) for s in new_size])
    # NOTE: the origin (centre of the first voxel) is reused as-is. The new
    # grid therefore starts at the old first voxel centre rather than being
    # centre-aligned with the original voxel extent.
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetInterpolator(interpolator)
    return resample.Execute(image)

def get_largest_connected_component(mask_array):
    """Keep only the largest connected foreground component of *mask_array*.

    Every non-zero element counts as foreground. A multi-valued or
    interpolated mask is therefore treated as one binary mask instead of
    being split into separate regions per value. Components are found with
    ``skimage.measure.label`` using its default connectivity, since no
    ``connectivity`` argument is passed. If several components tie for the
    largest size, the one with the lowest label is kept.

    Parameters
    ----------
    mask_array : numpy.ndarray
        Mask of any dimensionality.

    Returns
    -------
    numpy.ndarray
        Same shape and dtype as *mask_array*: 1 on the largest component and
        0 everywhere else. If the mask has no foreground at all, an all-zero
        array is returned.
    """
    largest_mask = np.zeros_like(mask_array)

    labeled_mask = label(np.asarray(mask_array) != 0)
    # counts[k] is the number of elements carrying label k; label 0 is background.
    counts = np.bincount(labeled_mask.ravel())
    if counts.size <= 1:
        # No foreground components. Without this guard, label 0 (background)
        # would be selected and the returned mask would be inverted.
        return largest_mask

    counts[0] = 0  # never pick the background
    largest_component = int(np.argmax(counts))  # first maximum -> lowest label on ties
    largest_mask[labeled_mask == largest_component] = 1
    return largest_mask

def crop_and_save(image, mask, save_path, filename, win_width=90, win_center=45, scale_factor=3.5, save_as_rgb=True):
    """Crop the largest masked region from a rescaled volume and save PNG slices.

    Steps:
      1. ``image`` is resampled with ``rescale_image(image, scale_factor)``.
         With a scalar factor, every axis is scaled, including the slice axis.
      2. ``mask`` is resampled onto exactly the same grid using
         nearest-neighbour interpolation. Label values are preserved, and the
         two arrays are guaranteed to have identical shapes.
      3. The largest 3-D connected component of the non-zero mask is kept.
         The slice where that component covers the most pixels becomes "max".
      4. For "max" and its neighbours ("up" = index - 1, "low" = index + 1),
         the image slice is cropped to the component's 2-D bounding box and
         windowed with ``window_transform(..., win_width, win_center)``.
         It is then written to ``<save_path>/<filename>_<label>.png``.
         Neighbours outside the volume, or not touched by the component,
         are skipped.

    Parameters
    ----------
    image, mask : sitk.Image
        3-D intensity volume and its segmentation. Mask values are sampled at
        the physical positions of the rescaled image grid.
    save_path : str
        Output directory; created if it does not exist.
    filename : str
        Prefix for the output file names.
    win_width, win_center : number
        Display window forwarded to ``window_transform``.
    scale_factor : float
        Resampling factor forwarded to ``rescale_image``.
    save_as_rgb : bool
        If True, the gray channel is replicated into a 3-channel RGB image.

    Raises
    ------
    ValueError
        If the resampled mask contains no non-zero voxels.
    """
    image = rescale_image(image, scale_factor)

    # Put the mask on the rescaled image's exact grid (size, spacing, origin,
    # direction). Nearest-neighbour keeps it a clean label map; the B-spline
    # interpolator used for the image would blur and overshoot label edges.
    mask_resampler = sitk.ResampleImageFilter()
    mask_resampler.SetReferenceImage(image)
    mask_resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    mask = mask_resampler.Execute(mask)

    img_array = sitk.GetArrayFromImage(image)
    mask_array = sitk.GetArrayFromImage(mask)

    # Any non-zero label counts as foreground.
    foreground = (mask_array != 0).astype(np.uint8)
    if not foreground.any():
        raise ValueError(f"mask for {filename!r} has no foreground voxels")

    largest_mask_array = get_largest_connected_component(foreground)

    # Choose the "max" slice from the component that is actually cropped.
    # The slice with the most mask area overall can lie outside the largest
    # 3-D component, and would then yield no crop at all.
    # np.argmax returns the first index on ties.
    slice_areas = largest_mask_array.sum(axis=(1, 2))
    max_slice_index = int(np.argmax(slice_areas))

    slice_labels = {
        max_slice_index - 1: "up",
        max_slice_index: "max",
        max_slice_index + 1: "low",
    }
    num_slices = largest_mask_array.shape[0]

    os.makedirs(save_path, exist_ok=True)

    for slice_index, slice_label in slice_labels.items():
        if not 0 <= slice_index < num_slices:
            continue
        rows, cols = np.nonzero(largest_mask_array[slice_index])
        if rows.size == 0:
            continue  # the component does not reach this neighbouring slice

        rmin, rmax = rows.min(), rows.max()
        cmin, cmax = cols.min(), cols.max()
        # Column slicing yields a strided view; hand SimpleITK a contiguous copy.
        cropped_img = np.ascontiguousarray(img_array[slice_index, rmin:rmax + 1, cmin:cmax + 1])

        windowed = window_transform(sitk.GetImageFromArray(cropped_img), win_width, win_center)
        img_pil = Image.fromarray(sitk.GetArrayFromImage(windowed)).convert('L')

        if save_as_rgb:
            img_pil = Image.merge("RGB", (img_pil, img_pil, img_pil))

        img_pil.save(os.path.join(save_path, f'{filename}_{slice_label}.png'))

def process_all_patients(image_dir, mask_dir, save_path, save_as_rgb=True):
    """Run ``crop_and_save`` on every ``*.nii.gz`` image in *image_dir*.

    Each image is paired with the mask of the same file name in *mask_dir*.
    A case counts as a failure if its mask is missing, or if reading or
    cropping raises an exception. Failures are listed with their reason at
    the end, and the remaining cases are still processed. Output PNGs go to
    *save_path*, which is created if needed.

    Parameters
    ----------
    image_dir : str
        Directory containing the ``.nii.gz`` image volumes.
    mask_dir : str
        Directory containing masks with the same file names as the images.
    save_path : str
        Output directory for the cropped PNGs.
    save_as_rgb : bool
        Forwarded to ``crop_and_save``.
    """
    os.makedirs(save_path, exist_ok=True)

    suffix = '.nii.gz'
    # os.listdir order is arbitrary; sort so runs and reports are reproducible.
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(suffix))

    total_files = len(image_files)
    failed = []  # (file name, reason) pairs

    for image_file in image_files:
        image_path = os.path.join(image_dir, image_file)
        mask_path = os.path.join(mask_dir, image_file)

        if not os.path.exists(mask_path):
            failed.append((image_file, f"mask not found: {mask_path}"))
            continue

        # Strip only the trailing suffix. str.replace would also remove a
        # '.nii.gz' occurring earlier in the name.
        filename_without_ext = image_file[:-len(suffix)]

        try:
            image = sitk.ReadImage(image_path)
            mask = sitk.ReadImage(mask_path)
            crop_and_save(image, mask, save_path, filename_without_ext, save_as_rgb=save_as_rgb)
        except Exception as e:  # batch boundary: record the reason, continue with next case
            failed.append((image_file, f"{type(e).__name__}: {e}"))

    print(f"Processing completed: {total_files - len(failed)}/{total_files} succeeded")

    if failed:
        print(f"Failed files: {', '.join(name for name, _ in failed)}")
        for name, reason in failed:
            print(f"  {name}: {reason}")

def convert_existing_grayscale_to_rgb(input_dir, output_dir):
    """Re-save every PNG/JPEG in *input_dir* as a 3-channel grayscale RGB image.

    Each matching file (``.png``, ``.jpg`` or ``.jpeg`` extension) is
    converted to 8-bit grayscale ("L"). The gray channel is then replicated
    into R, G and B, and the result is written to *output_dir* under the same
    file name. *output_dir* is created if needed.

    This is best-effort: a file that cannot be read or written is reported
    and skipped, and the remaining files are still converted. A summary count
    is printed at the end.
    """
    os.makedirs(output_dir, exist_ok=True)

    converted_count = 0

    for filename in os.listdir(input_dir):
        if not filename.endswith(('.png', '.jpg', '.jpeg')):
            continue

        img_path = os.path.join(input_dir, filename)
        try:
            # The context manager closes the file handle as soon as the pixels
            # have been converted (convert() loads the data), rather than
            # leaving it to garbage collection.
            with Image.open(img_path) as img:
                gray = img.convert('L')

            img_rgb = Image.merge("RGB", (gray, gray, gray))
            img_rgb.save(os.path.join(output_dir, filename))
        except Exception as e:  # best-effort batch: report and move on
            print(f"Skipped {filename}: {type(e).__name__}: {e}")
            continue

        converted_count += 1

    print(f"Conversion completed: {converted_count} images")

if __name__ == "__main__":
    # Paths are relative to the current working directory. Each mask in
    # masks/ must share its file name with the image it belongs to.
    image_dir, mask_dir, save_path = 'images', 'masks', 'cropped_images'

    process_all_patients(
        image_dir,
        mask_dir,
        save_path,
        save_as_rgb=True,
    )