# AI_campus_ProstateSeg 
# Module 1: Loading of Slides and Masks from a subset of the PANDA dataset and conducting basic validation and characterization
### PANDA: Prostate cANcer graDe Assessment (PANDA) Challenge

In this notebook, we load a few whole-slide image (WSI) and mask pairs from a subset of the PANDA dataset on Kaggle (https://www.kaggle.com/competitions/prostate-cancer-grade-assessment/data) and run some basic validation checks on them. All of these image and mask pairs come from the Radboud subset of the data.

## 1. Import the necessary packages

In [90]:
import os
import tifffile
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from typing import List, Union, Optional

## 2. Define folder and sub-folder directory names

In [4]:
# Dataset layout: the slides live in <ROOT_FOLDER>/<IMAGE_SUBFOLDER> and the
# label masks live in <ROOT_FOLDER>/<MASK_SUBFOLDER>.
ROOT_FOLDER = "sample_data"
IMAGE_SUBFOLDER = "train_images"
MASK_SUBFOLDER = "train_label_masks"

# Build both directory paths relative to the shared root folder.
image_dir, mask_dir = (
    os.path.join(ROOT_FOLDER, subfolder)
    for subfolder in (IMAGE_SUBFOLDER, MASK_SUBFOLDER)
)

## 3. List all the files contained in the image folder (the images have the .tiff extension; on macOS a hidden .DS_Store entry may also appear)

In [67]:
# Print every entry of the image folder in sorted order, one name per line.
# Nothing is filtered here, so non-image entries are printed as well.
image_listing = sorted(os.listdir(image_dir))
for image_file in image_listing:
    print(image_file)

.DS_Store
0018ae58b01bdadc8e347995b69f99aa.tiff
004dd32d9cd167d9cc31c13b704498af.tiff
0068d4c7529e34fd4c9da863ce01a161.tiff
006f6aa35a78965c92fffd1fbd53a058.tiff
007433133235efc27a39f11df6940829.tiff
0076bcb66e46fb485f5ba432b9a1fe8a.tiff
008069b542b0439ed69b194674051964.tiff
00928370e2dfeb8a507667ef1d4efcbb.tiff
00951a7fad040bf7e90f32e81fc0746f.tiff
00a26aaa82c959624d90dfb69fcf259c.tiff


## 4. List all the mask files contained in the mask folder

In [68]:
# Print every entry of the mask folder in sorted order, one name per line.
# Nothing is filtered here, so non-mask entries are printed as well.
mask_listing = sorted(os.listdir(mask_dir))
for mask_file in mask_listing:
    print(mask_file)

.DS_Store
0018ae58b01bdadc8e347995b69f99aa_mask.tiff
004dd32d9cd167d9cc31c13b704498af_mask.tiff
0068d4c7529e34fd4c9da863ce01a161_mask.tiff
006f6aa35a78965c92fffd1fbd53a058_mask.tiff
007433133235efc27a39f11df6940829_mask.tiff
0076bcb66e46fb485f5ba432b9a1fe8a_mask.tiff
008069b542b0439ed69b194674051964_mask.tiff
00928370e2dfeb8a507667ef1d4efcbb_mask.tiff
00951a7fad040bf7e90f32e81fc0746f_mask.tiff
00a26aaa82c959624d90dfb69fcf259c_mask.tiff


## 5. Write a function to validate the dataset such that there is a mask corresponding to each image, and every image is paired with its matching mask

Each file named "X.tiff" in the image folder must be paired with a file named "X_mask.tiff" in the mask folder.

In [78]:
def verify_pair_match(image_dir: str,
                      mask_dir: str,
                      mask_ext: str = "_mask.tiff") -> bool:
    """
    Verifies that every image file in the provided image directory has a corresponding
    mask file in the mask directory. The expected mask name is the image file name with
    its final extension removed, followed by `mask_ext` (default is '_mask.tiff').

    Only the image -> mask direction is checked: extra files in the mask directory
    do not cause a failure.

    Parameters:
    ----------
    image_dir : str
        The directory containing image files.

    mask_dir : str
        The directory containing mask files.

    mask_ext : str, optional (default is "_mask.tiff")
        The suffix (including extension) appended to the image name stem to form
        the expected mask file name.

    Returns:
    -------
    bool
        Returns True if every image file has a corresponding mask file, otherwise
        returns False. An image directory with no image files yields True.

    """

    # A set gives constant-time membership tests for each image.
    mask_files = set(os.listdir(mask_dir))

    for file in os.listdir(image_dir):
        # Skip the macOS Finder metadata file; it is not an image.
        if file == '.DS_Store':
            continue

        # splitext removes only the final extension, so names without any dot
        # or with several dots keep their full stem.
        stem = os.path.splitext(file)[0]
        if stem + mask_ext not in mask_files:
            return False

    return True

In [79]:
# Run the pairing check on the sample dataset. As the last expression of the
# cell, the boolean result is shown as the cell output.
verify_pair_match(image_dir=image_dir, mask_dir=mask_dir)

True

## 6. Load and print out the dimensions of each of the images contained in the image folder

In [80]:
# Read each slide in the image folder (sorted by name) and print its array shape.
for image_file in sorted(os.listdir(image_dir)):
    if image_file == ".DS_Store":
        # Skip the macOS Finder metadata file; it is not an image.
        continue
    image_path = os.path.join(ROOT_FOLDER, IMAGE_SUBFOLDER, image_file)
    img = tifffile.imread(image_path)
    print(img.shape)

(25344, 5888, 3)
(22528, 8192, 3)
(10496, 6912, 3)
(7680, 2048, 3)
(24320, 9472, 3)
(14848, 11776, 3)
(8704, 23808, 3)
(36352, 10752, 3)
(8192, 11520, 3)
(20736, 18688, 3)


## 7. Load and print out the dimensions of each of the masks contained in the mask folder

In [139]:
# Read each mask in the mask folder (sorted by name) and print its array shape.
for mask_file in sorted(os.listdir(mask_dir)):
    if mask_file == ".DS_Store":
        # Skip the macOS Finder metadata file; it is not a mask.
        continue
    mask_path = os.path.join(ROOT_FOLDER, MASK_SUBFOLDER, mask_file)
    mask = tifffile.imread(mask_path)
    print(mask.shape)

(25344, 5888, 3)
(22528, 8192, 3)
(10496, 6912, 3)
(7680, 2048, 3)
(24320, 9472, 3)
(14848, 11776, 3)
(8704, 23808, 3)
(36352, 10752, 3)
(8192, 11520, 3)
(20736, 18688, 3)


## 8. Write a function to validate the dataset such that each image-mask pair has the same size along the specified dimension indices

In [150]:
def verify_dim_match(image_dir: str,
                     mask_dir: str,
                     dims: Optional[List[int]] = None) -> bool:
    """
    Verifies that each image and mask pair have the same size along the specified
    list of dimension indexes in dims.

    Images and masks are paired by their position in the sorted directory listings
    (after dropping '.DS_Store'), not by name.
    NOTE(review): positional pairing assumes both listings sort in the same order;
    confirm this holds for the naming scheme in use.

    Parameters:
    ----------
    image_dir : str
        The directory containing image files. Images are read from this directory.

    mask_dir : str
        The directory containing mask files. Masks are read from this directory.

    dims: List[int], optional
        List of integers specifying the dimension indices to compare.
        Default (None) is treated as [0,1,2]: each image and mask pair will be
        checked for a match at dimension indices [0,1,2]. Duplicates are ignored.

    Returns:
    -------
    bool
        Returns True if the size of each image and mask pair matches along all of the
        specified dimension indices, False otherwise. Returns True when both
        directories contain no files to compare.

    Raises:
    ------
    AssertionError
        If the two directories do not contain the same number of files.

    """

    # None instead of a list literal avoids a shared mutable default argument.
    if dims is None:
        dims = [0, 1, 2]
    dims = list(set(dims))

    image_files = sorted(file for file in os.listdir(image_dir) if file != '.DS_Store')
    mask_files = sorted(file for file in os.listdir(mask_dir) if file != '.DS_Store')

    assert len(image_files) == len(mask_files), \
        "image and mask directories contain a different number of files"

    for image_file, mask_file in zip(image_files, mask_files):
        # Load from the directories passed in, not from module-level globals,
        # so the function validates the dataset the caller actually asked for.
        img = tifffile.imread(os.path.join(image_dir, image_file))
        mask = tifffile.imread(os.path.join(mask_dir, mask_file))

        if any(img.shape[d] != mask.shape[d] for d in dims):
            return False

    return True

In [326]:
# Run the dimension check on the sample dataset with the default dimension
# indices. As the last expression of the cell, the result is shown as output.
verify_dim_match(image_dir=image_dir, mask_dir=mask_dir)

True

## 9. Write function to validate and characterize masks

In [325]:
def validate_three_channel_mask(mask: np.ndarray) -> bool:

    """
    Since the PANDA dataset masks have 3 channels, this is a basic
    function to validate an input mask. Only one of the three channels
    of the mask could be effectively used as the official mask of the
    corresponding image, so every channel that carries any non-zero
    label must carry exactly the same labels as the others that do.

    Parameters:
    ----------
    mask : np.ndarray
        Input three-channel mask (channels on axis 2) with integer labels
        possibly ranging from 0-5 across one or all channels

    For the Radboud study of the PANDA dataset:
        0: background (non tissue) or unknown
        1: stroma (connective tissue, non-epithelium tissue)
        2: healthy (benign) epithelium
        3: cancerous epithelium (Gleason 3)
        4: cancerous epithelium (Gleason 4)
        5: cancerous epithelium (Gleason 5)

    Returns:
    -------
    bool
        Returns True if the three-channel mask is valid, False otherwise.
        A mask is valid in exactly these cases:
          1. All channels contain only 0s (all background).
          2. Exactly one channel contains non-zero labels.
          3. Exactly two channels contain non-zero labels, they are identical,
             and the remaining channel contains only 0s.
          4. All three channels contain non-zero labels and are identical.
        The upper bound of the label range is not checked.

    Raises:
    ------
    AssertionError
        If mask is not a numpy array, does not have 3 channels on axis 2,
        or contains negative labels.
    """

    # Verify that the mask to validate contains 3 channels
    assert isinstance(mask, np.ndarray), "mask must be a numpy.ndarray"
    assert mask.shape[2] == 3, "mask must have exactly 3 channels on axis 2"
    # The size guard keeps an empty mask from failing inside min().
    assert mask.size == 0 or mask.min() >= 0, "mask labels must be non-negative"

    # Keep only the channels that express some non-zero signal. Testing the
    # entries directly (rather than channel sums) cannot be fooled by values
    # that cancel out or wrap around.
    channels = [mask[:, :, index] for index in range(3)]
    active = [channel for channel in channels if np.any(channel)]

    # Cases 1 and 2: no signal at all, or a single channel carries it.
    if len(active) <= 1:
        return True

    # Cases 3 and 4: every channel carrying signal must agree with the others.
    # Channels that are entirely 0 are ignored, so two identical channels are
    # accepted only when the third is all background or also identical.
    reference = active[0]
    return all(np.array_equal(reference, other) for other in active[1:])

In [322]:
def validate_mask_dir(mask_dir: str) -> bool:

    """
    Validate all masks in mask directory using
    validate_three_channel_mask

    Every entry of the directory except '.DS_Store' is read as a mask,
    in sorted order, and checking stops at the first invalid mask.

    Parameters:
    ----------
    mask_dir : str
        The directory containing mask files. Masks are read from this directory.

    Returns:
    -------
    bool
        Returns True if all masks in directory are valid
        according to validate_three_channel_mask (also True when
        the directory contains no masks), False otherwise

    """

    for file in sorted(os.listdir(mask_dir)):
        # Skip the macOS Finder metadata file; it is not a mask.
        if file == '.DS_Store':
            continue

        # Load from the directory passed in, not from module-level globals,
        # so the function validates the directory the caller asked for.
        mask = tifffile.imread(os.path.join(mask_dir, file))
        if not validate_three_channel_mask(mask):
            return False

    return True

In [324]:
# Validate every mask in the sample mask folder. As the last expression of
# the cell, the boolean result is shown as the cell output.
validate_mask_dir(mask_dir=mask_dir)

True