In [None]:
import math
from pathlib import Path
import os, shutil
import random 

import numpy as np
import matplotlib.pyplot as plt
import PIL
import openslide
import cv2

# Settings

In [None]:
# Root directory scanned by get_svs_files: .svs slides are expected one level down,
# inside the sub-directories of this folder.
data_root_dir: str = "/nfs_share/students/jinhyun/TCGA/LUAD"

In [None]:

def get_svs_files(base_dir: str):
    """
    List the ``.svs`` slides stored one directory level below ``base_dir``.

    Only the immediate sub-directories of ``base_dir`` are searched, and the search is not
    recursive. Files placed directly in ``base_dir``, or nested deeper than one level, are
    ignored.

    Args:
        base_dir (str): Root directory to scan.

    Returns:
        list of str: Slide paths sorted by sub-directory and then by file name, so the order
        (and anything indexed or sampled from it downstream) is reproducible across runs.
    """
    base_path = Path(base_dir)
    # Path.iterdir() and Path.glob() yield entries in arbitrary, filesystem-dependent order,
    # so both levels are sorted explicitly.
    return [
        str(svs_path)
        for subdir in sorted(base_path.iterdir())
        if subdir.is_dir()
        for svs_path in sorted(subdir.glob("*.svs"))
    ]

In [None]:
svs_file_list = get_svs_files(data_root_dir)
# Keep only slides whose path contains the "-DX1." tag
# (presumably the first diagnostic slide in TCGA file naming — verify).
svs_file_list_filtered = list(filter(lambda path: "-DX1." in path, svs_file_list))

print(f"Found {len(svs_file_list)} SVS files. Obtained {len(svs_file_list_filtered)} files after filtering")

# Function definitions

In [None]:
def print_slide_info(slide, slide_filepath):
    """
    Print a human-readable summary of an OpenSlide slide.

    Reports the number of pyramid levels, the per-level dimensions and downsample factors,
    the level-0 size, the format detected for ``slide_filepath`` and every metadata property.
    """
    print(f"Level count: {slide.level_count:d}")
    print(f"Level dimensions: {slide.level_dimensions}")
    print(f"Level downsamples: {slide.level_downsamples}")
    print(f"Slide dimensions (width, height): {slide.dimensions}")
    print(f"Format: {slide.detect_format(slide_filepath)}")
    print("Properties:")
    props = slide.properties
    for key in props.keys():
        print(f"  Property: {key}, value: {props.get(key)}")

def slide_to_scaled_pil_image(slide, level):
    """
    Read an entire pyramid level of a WSI as an RGB PIL image.

    The "scaling" comes from choosing a coarser pyramid ``level``; no resampling is done here.
    Sizes of the original slide and of the returned image are printed for inspection.

    Args:
        slide (OpenSlide object): The whole slide image.
        level (int): Pyramid level to read (0 = full resolution).

    Returns:
        PIL.Image.Image: The full level converted to RGB.
    """
    slide_width, slide_height = slide.dimensions
    image = slide.read_region((0, 0), level, slide.level_dimensions[level]).convert("RGB")
    level_width, level_height = image.size

    print(f"Original slide size (width, height) : {slide_width}, {slide_height}")
    print(f"PIL Image size at level {level} : {image.size}")
    # An RGB PIL image always converts to an (height, width, 3) array, so report that shape
    # instead of materialising a second full-size copy of the level just to print it.
    print(f"NumPy array shape at level {level} : {(level_height, level_width, 3)}")

    return image


def create_tissue_masks(slide, mask_level):
    """
    Create a binary tissue mask for one pyramid level using Otsu's thresholding.

    The whole level is read once. Its RGB version is returned as the thumbnail, and the
    saturation channel of its HSV version is thresholded with Otsu's method to build the mask.

    Args:
        slide (OpenSlide object): The whole slide image.
        mask_level (int): Pyramid level at which the thumbnail and the mask are computed.

    Returns:
        tuple: ``(slide_thumbnail, tissue_mask)``
            slide_thumbnail: RGB array of the whole level, shape (height, width, 3).
            tissue_mask: single-channel array produced by ``cv2.threshold``; 255 where the
                saturation is above the Otsu threshold, 0 elsewhere.
    """
    # Read the level only once; both outputs are derived from this region.
    slide_region = slide.read_region((0, 0), mask_level, slide.level_dimensions[mask_level])
    # Dropping the alpha channel here yields the same pixels as PIL's RGBA -> RGB convert
    # (both discard alpha rather than compositing).
    slide_rgb = cv2.cvtColor(np.array(slide_region), cv2.COLOR_RGBA2RGB)

    slide_hsv = cv2.cvtColor(slide_rgb, cv2.COLOR_RGB2HSV)
    # Saturation measures colour intensity; tissue regions tend to be more saturated than background.
    saturation_channel = slide_hsv[:, :, 1]
    _, tissue_mask = cv2.threshold(saturation_channel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # slide_rgb is no longer needed here, so it doubles as the (caller-modifiable) thumbnail.
    slide_thumbnail = slide_rgb
    return slide_thumbnail, tissue_mask

def draw_masks(slide_thumbnail, tissue_mask):
    """
    Display the level thumbnail and its tissue mask side by side in a 1 x 2 grid.

    2-D arrays are drawn with the 'gray' colormap; ``vmin=0`` and ``vmax=255`` are passed for
    both panels.
    """
    _, axes = plt.subplots(1, 2, figsize=(10, 10))
    panels = (
        ('Slide Thumbnail', slide_thumbnail),
        ('Tissue Mask', tissue_mask),
    )

    for ax, (title, img) in zip(axes.ravel(), panels):
        # Single-channel masks need an explicit colormap; RGB images are drawn as-is.
        colormap = 'gray' if img.ndim == 2 else None
        ax.imshow(img, cmap=colormap, vmin=0, vmax=255)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def extract_patches(slide, slide_thumbnail, tumor_mask, mask_level, patch_config = None, save_path = None):
    """
    Extract level-0 patches whose footprint on ``tumor_mask`` is mostly foreground.

    The slide is tiled into a non-overlapping grid of ``patch_size`` x ``patch_size`` level-0
    tiles. For each tile, the matching window of ``tumor_mask`` (at ``mask_level``) is inspected;
    tiles whose foreground fraction exceeds ``tumor_area_threshold`` become candidates. At most
    ``max_tumor_patches`` candidates are kept (random subsample). Each kept tile is read at
    level 0 and, when ``save_path`` is given, saved as ``t_<x>_<y>.png`` (level-0 coordinates).
    It is also outlined on ``slide_thumbnail``.

    Args:
        slide (OpenSlide object): The whole slide image.
        slide_thumbnail (numpy array): Image at ``mask_level``; modified in place (rectangles
            are drawn on it).
        tumor_mask (numpy array): 2-D mask at ``mask_level``; any non-zero pixel counts as
            foreground (0/255, 0/1 and boolean masks all work).
        mask_level (int): Pyramid level of ``slide_thumbnail`` and ``tumor_mask``.
        patch_config (dict, optional): Extraction settings:
            'patch_size' (int, default 448): tile edge length in level-0 pixels.
            'tumor_area_threshold' (float, default 0.8): a tile is kept only if its foreground
                fraction is strictly greater than this.
            'max_tumor_patches' (int, default 1000): cap on the number of kept tiles.
            'seed' (int, optional): seeds a private RNG for the subsampling; when absent, the
                module-level ``random`` generator is used.
            'patch_output_size' is accepted but currently ignored: patches are saved at
                ``patch_size`` without resizing.
        save_path (str or Path, optional): Directory for the PNG files; created if needed.
            Nothing is saved when it is None or empty.

    Returns:
        list of tuple: ``(x_slide, y_slide, x_mask, y_mask)`` for every selected tile.

    Raises:
        ValueError: if ``patch_size`` is smaller than the rounded downsample factor of
            ``mask_level``, i.e. a tile would cover zero mask pixels.
    """
    if patch_config is None:
        patch_config = {}

    patch_size = patch_config.get('patch_size', 448)  # tile edge length in level-0 pixels
    tumor_area_threshold = patch_config.get('tumor_area_threshold', 0.8)  # minimum foreground fraction
    max_tumor_patches = patch_config.get('max_tumor_patches', 1000)  # cap on kept tiles
    seed = patch_config.get('seed')
    # A private RNG makes subsampling reproducible when a seed is given; otherwise keep using the
    # module-level generator so an explicit global random.seed(...) still applies.
    rng = random if seed is None else random.Random(seed)

    downsample_factor = round(slide.level_downsamples[mask_level])
    mask_step_size = patch_size // downsample_factor  # tile edge length at the mask level
    if mask_step_size <= 0:
        raise ValueError(
            f"patch_size ({patch_size}) must be at least the downsample factor "
            f"of mask level {mask_level} ({downsample_factor})"
        )
    mask_window_pixels = mask_step_size * mask_step_size

    slide_width, slide_height = slide.level_dimensions[0]
    num_patches_x = slide_width // patch_size
    num_patches_y = slide_height // patch_size
    total_patches = num_patches_x * num_patches_y

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    ## First pass: collect candidate tiles (x outer, y inner)
    tumor_patch_candidates = []
    patches_processed = 0
    for i in range(num_patches_x):
        x_mask = i * mask_step_size
        x_slide = i * patch_size
        for j in range(num_patches_y):
            y_mask = j * mask_step_size
            y_slide = j * patch_size

            tumor_mask_region = tumor_mask[y_mask:y_mask + mask_step_size, x_mask:x_mask + mask_step_size]
            # Dividing by the full window area means a window clipped at the mask border
            # counts its missing pixels as background.
            tumor_area_ratio = np.count_nonzero(tumor_mask_region) / mask_window_pixels
            if tumor_area_ratio > tumor_area_threshold:
                tumor_patch_candidates.append((x_slide, y_slide, x_mask, y_mask))
            patches_processed += 1

    ## Random subsample if there are too many candidates
    if len(tumor_patch_candidates) > max_tumor_patches:
        tumor_patch_candidates_selected = rng.sample(tumor_patch_candidates, max_tumor_patches)
    else:
        tumor_patch_candidates_selected = tumor_patch_candidates

    ## Second pass: read, save and outline the selected tiles
    for x_slide, y_slide, x_mask, y_mask in tumor_patch_candidates_selected:
        patch = slide.read_region((x_slide, y_slide), 0, (patch_size, patch_size))
        if save_path:
            patch.save(os.path.join(save_path, f"t_{x_slide}_{y_slide}.png"))
        cv2.rectangle(slide_thumbnail, (x_mask, y_mask), (x_mask + mask_step_size, y_mask + mask_step_size), (0, 0, 255), 2)

    print(f'Processed {patches_processed}/{total_patches} patches.')
    print(f'Extracted {len(tumor_patch_candidates_selected)} tumor patches out of {len(tumor_patch_candidates)} candidate tumor patches')
    return tumor_patch_candidates_selected


def draw_patches(patch, slide_thumbnail_patch, tumor_mask_patch):
    """
    Debug view: show a patch next to its thumbnail crop and mask crop in a 1 x 3 grid.

    2-D arrays are drawn with the 'gray' colormap; ``vmin=0`` and ``vmax=255`` are passed for
    every panel.
    """
    _, axes = plt.subplots(1, 3, figsize=(10, 10))
    panels = (
        ('Patch', patch),
        ('Slide Thumbnail', slide_thumbnail_patch),
        ('Tumor Mask', tumor_mask_patch),
    )

    for ax, (title, img) in zip(axes.ravel(), panels):
        # Single-channel crops need an explicit colormap; colour crops are drawn as-is.
        colormap = 'gray' if img.ndim == 2 else None
        ax.imshow(img, cmap=colormap, vmin=0, vmax=255)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def scale_image(image, target_width=512):
    """
    Resize an image array to ``target_width`` pixels wide, keeping the aspect ratio.

    The new height is the original height times ``target_width / width``, truncated to an int.
    Resampling uses ``cv2.INTER_AREA``.
    """
    src_height, src_width = image.shape[0], image.shape[1]
    target_size = (target_width, int(src_height * (target_width / src_width)))
    return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

# Extract random patches

In [None]:
# Output root for single-scale patches; one sub-directory is created per slide-name prefix.
PATCH_ROOT_DIR = Path("/nfs_share/students/jinhyun/TCGA/patches")
# Set to True to print slide metadata and plot the tissue mask / selected tiles for every slide.
SHOW_DEBUG_PLOTS = False

patch_config = {
    'patch_size': 448,
    'patch_output_size' : 448,
    'tumor_area_threshold': 0.5,
    'max_tumor_patches': 100,
}

# Directories already wiped during this run (see the rmtree guard below).
cleared_patch_dirs = set()

for i, slide_filepath in enumerate(svs_file_list_filtered):
    # Group patches by the first 12 characters of the file name
    # (presumably the TCGA patient barcode, e.g. "TCGA-XX-XXXX" — verify).
    patch_save_dir = PATCH_ROOT_DIR / Path(slide_filepath).name[:12]
    # Remove stale output from earlier runs, but only once per run, so two slides that share
    # a prefix do not delete each other's patches.
    if patch_save_dir not in cleared_patch_dirs:
        if patch_save_dir.exists():
            shutil.rmtree(patch_save_dir)
        cleared_patch_dirs.add(patch_save_dir)
    # parents=True so a missing PATCH_ROOT_DIR is created on first use instead of crashing.
    patch_save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Index {i+1}/{len(svs_file_list_filtered)}\t: Opening Slide {slide_filepath}")
    slide = openslide.OpenSlide(slide_filepath)
    try:
        # Use level 2, or the coarsest level when the pyramid is shallower.
        mask_level = min(slide.level_count - 1, 2)
        if SHOW_DEBUG_PLOTS:
            print_slide_info(slide, slide_filepath)

        ## Tissue mask
        slide_thumbnail, tissue_mask = create_tissue_masks(slide, mask_level = mask_level)
        print("Mask shape: ", tissue_mask.shape)
        assert tissue_mask.ndim == 2, "Tissue mask should be a single-channel image"
        assert slide_thumbnail.shape[:2] == tissue_mask.shape, "Tissue mask shape is inconsistent with the slide thumbnail"
        if SHOW_DEBUG_PLOTS:
            draw_masks(slide_thumbnail, tissue_mask)

        ## Extract patches: the tissue mask is used as the selection mask.
        extract_patches(slide, slide_thumbnail, tumor_mask = tissue_mask, mask_level = mask_level, patch_config = patch_config, save_path = patch_save_dir)

        if SHOW_DEBUG_PLOTS:
            # extract_patches has outlined the selected tiles on slide_thumbnail.
            plt.figure(figsize = (16,16))
            plt.imshow(scale_image(slide_thumbnail, target_width=512))
            plt.axis('off')
            plt.show()
    finally:
        # Release the OpenSlide handle; otherwise one stays open per processed slide.
        slide.close()


# Extract multi-scale patches

In [None]:
def extract_patches_multi_scale(slide, slide_thumbnail, tumor_mask, mask_level, patch_config = None, save_path = None):
    """
    Extract a stack of multi-resolution patches around each foreground tile of ``tumor_mask``.

    Tile selection:
        The slide is tiled into a non-overlapping grid of ``patch_size`` x ``patch_size``
        level-0 tiles. A tile becomes a candidate when the non-zero fraction of its window in
        ``tumor_mask`` (at ``mask_level``) is strictly greater than ``tumor_area_threshold``.
        At most ``max_tumor_patches`` candidates are kept (random subsample).

    Reading:
        For every kept tile and every pyramid level ``l``, a square of
        ``output_patch_sizes[l]`` level-``l`` pixels is read. Only levels that have an entry in
        ``output_patch_sizes`` are read. The square is centred on the tile, but shifted
        right/down when it would start at a negative coordinate.

    Output:
        Each patch is saved as ``level_<l>_<x>_<y>.png``, named by the tile's level-0
        coordinates, when ``save_path`` is given. It is also outlined on ``slide_thumbnail``.

    Args:
        slide (OpenSlide object): The whole slide image.
        slide_thumbnail (numpy array): Image at ``mask_level``; modified in place (rectangles
            are drawn on it).
        tumor_mask (numpy array): 2-D mask at ``mask_level``; any non-zero pixel counts as
            foreground.
        mask_level (int): Pyramid level of ``slide_thumbnail`` and ``tumor_mask``.
        patch_config (dict, optional): Extraction settings:
            'patch_size' (int, default 448): tile edge length in level-0 pixels.
            'output_patch_sizes' (list of int, default ``[patch_size] * number_of_levels``):
                edge length, in pixels of that level, of the patch read at each level.
            'tumor_area_threshold' (float, default 0.8): minimum foreground fraction.
            'max_tumor_patches' (int, default 1000): cap on the number of kept tiles.
            'seed' (int, optional): seeds a private RNG for the subsampling; when absent, the
                module-level ``random`` generator is used.
        save_path (str or Path, optional): Directory for the PNG files; created if needed.
            Nothing is saved when it is None or empty.

    Returns:
        list of tuple: ``(x_slide, y_slide, x_mask, y_mask)`` for every selected tile.

    Raises:
        ValueError: if ``patch_size`` is smaller than the rounded downsample factor of
            ``mask_level``, i.e. a tile would cover zero mask pixels.
    """
    if patch_config is None:
        patch_config = {}

    patch_size_highest_res = patch_config.get('patch_size', 448)  # tile edge length in level-0 pixels
    num_levels = len(slide.level_downsamples)
    # One read size per level, from the highest resolution to the lowest.
    output_patch_sizes = patch_config.get('output_patch_sizes', [patch_size_highest_res] * num_levels)

    tumor_area_threshold = patch_config.get('tumor_area_threshold', 0.8)  # minimum foreground fraction
    max_tumor_patches = patch_config.get('max_tumor_patches', 1000)  # cap on kept tiles
    seed = patch_config.get('seed')
    # A private RNG makes subsampling reproducible when a seed is given; otherwise keep using the
    # module-level generator so an explicit global random.seed(...) still applies.
    rng = random if seed is None else random.Random(seed)

    downsample_factor = round(slide.level_downsamples[mask_level])
    mask_step_size = patch_size_highest_res // downsample_factor  # tile edge length at the mask level
    if mask_step_size <= 0:
        raise ValueError(
            f"patch_size ({patch_size_highest_res}) must be at least the downsample factor "
            f"of mask level {mask_level} ({downsample_factor})"
        )
    mask_window_pixels = mask_step_size * mask_step_size

    slide_width, slide_height = slide.level_dimensions[0]
    num_patches_x = slide_width // patch_size_highest_res
    num_patches_y = slide_height // patch_size_highest_res
    total_patches = num_patches_x * num_patches_y

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    ## First pass: collect candidate tiles (x outer, y inner)
    tumor_patch_candidates = []
    patches_processed = 0
    for i in range(num_patches_x):
        x_mask = i * mask_step_size
        x_slide = i * patch_size_highest_res
        for j in range(num_patches_y):
            y_mask = j * mask_step_size
            y_slide = j * patch_size_highest_res

            tumor_mask_region = tumor_mask[y_mask:y_mask + mask_step_size, x_mask:x_mask + mask_step_size]
            # Dividing by the full window area means a window clipped at the mask border
            # counts its missing pixels as background.
            tumor_area_ratio = np.count_nonzero(tumor_mask_region) / mask_window_pixels
            if tumor_area_ratio > tumor_area_threshold:
                tumor_patch_candidates.append((x_slide, y_slide, x_mask, y_mask))
            patches_processed += 1

    ## Random subsample if there are too many candidates
    if len(tumor_patch_candidates) > max_tumor_patches:
        tumor_patch_candidates_selected = rng.sample(tumor_patch_candidates, max_tumor_patches)
    else:
        tumor_patch_candidates_selected = tumor_patch_candidates

    ## Second pass: read, save and outline one patch per level for each selected tile
    for x_slide, y_slide, x_mask, y_mask in tumor_patch_candidates_selected:
        # zip stops at the shorter sequence: levels without a configured size are skipped
        # instead of raising IndexError on deep pyramids.
        for level, output_size in zip(range(num_levels), output_patch_sizes):
            scale_factor = round(slide.level_downsamples[level])
            # Level-0 extent covered by this read, and the shift that centres it on the tile.
            context_size = output_size * scale_factor
            offset = (context_size - patch_size_highest_res) // 2

            x_min = max(0, x_slide - offset)
            y_min = max(0, y_slide - offset)
            # read_region takes the location in level-0 coordinates and the size in pixels of `level`.
            patch = slide.read_region((x_min, y_min), level, (output_size, output_size))
            if save_path:
                patch.save(os.path.join(save_path, f"level_{level}_{x_slide}_{y_slide}.png"))

            # Outline the read area on the thumbnail, in mask-level coordinates.
            x_mask_vis = max(0, x_mask - offset // downsample_factor)
            y_mask_vis = max(0, y_mask - offset // downsample_factor)
            vis_size = mask_step_size * scale_factor * output_size // patch_size_highest_res
            cv2.rectangle(slide_thumbnail, (x_mask_vis, y_mask_vis),
                          (x_mask_vis + vis_size, y_mask_vis + vis_size),
                          (0, 0, 255), 2)

    print(f'Processed {patches_processed}/{total_patches} patches.')
    print(f'Extracted {len(tumor_patch_candidates_selected)} tumor patches out of {len(tumor_patch_candidates)} candidate tumor patches')
    return tumor_patch_candidates_selected



In [None]:
# Output root for multi-scale patches; one sub-directory is created per slide-name prefix.
MULTISCALE_PATCH_ROOT_DIR = Path("/nfs_share/students/jinhyun/TCGA/patches_multiscale")
# Set to True to print slide metadata and plot the tissue mask / selected tiles for every slide.
SHOW_DEBUG_PLOTS = False

patch_config = {
    'patch_size': 448,
    'tumor_area_threshold': 0.5,
    'max_tumor_patches': 5,
}

# Directories already wiped during this run (see the rmtree guard below).
cleared_patch_dirs = set()

for i, slide_filepath in enumerate(svs_file_list_filtered):
    # Group patches by the first 12 characters of the file name
    # (presumably the TCGA patient barcode, e.g. "TCGA-XX-XXXX" — verify).
    patch_save_dir = MULTISCALE_PATCH_ROOT_DIR / Path(slide_filepath).name[:12]
    # Remove stale output from earlier runs, but only once per run, so two slides that share
    # a prefix do not delete each other's patches.
    if patch_save_dir not in cleared_patch_dirs:
        if patch_save_dir.exists():
            shutil.rmtree(patch_save_dir)
        cleared_patch_dirs.add(patch_save_dir)
    # parents=True so a missing root directory is created on first use instead of crashing.
    patch_save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Index {i+1}/{len(svs_file_list_filtered)}\t: Opening Slide {slide_filepath}")
    slide = openslide.OpenSlide(slide_filepath)
    try:
        # Use level 2, or the coarsest level when the pyramid is shallower.
        mask_level = min(slide.level_count - 1, 2)
        if SHOW_DEBUG_PLOTS:
            print_slide_info(slide, slide_filepath)

        ## Tissue mask
        slide_thumbnail, tissue_mask = create_tissue_masks(slide, mask_level = mask_level)
        assert tissue_mask.ndim == 2, "Tissue mask should be a single-channel image"
        assert slide_thumbnail.shape[:2] == tissue_mask.shape, "Tissue mask shape is inconsistent with the slide thumbnail"
        if SHOW_DEBUG_PLOTS:
            print("Mask shape: ", tissue_mask.shape)
            draw_masks(slide_thumbnail, tissue_mask)

        ## Extract patches: the tissue mask is used as the selection mask.
        extract_patches_multi_scale(slide, slide_thumbnail, tumor_mask = tissue_mask, mask_level = mask_level, patch_config = patch_config, save_path = patch_save_dir)

        if SHOW_DEBUG_PLOTS:
            # extract_patches_multi_scale has outlined the read areas on slide_thumbnail.
            plt.figure(figsize = (16,16))
            plt.imshow(scale_image(slide_thumbnail, target_width=512))
            plt.axis('off')
            plt.show()
    finally:
        # Release the OpenSlide handle; otherwise one stays open per processed slide.
        slide.close()
