In [None]:
#==================================================================
#Program: PatchExtraction
#Version: 2.1
#Author(s): Tianling Niu, David Helminiak
#Date Created: 10 September 2024
#Date Last Modified: 15 May 2025
#Description: Extract .tif patch images for samples from the WSI .jpg files
#==================================================================

#Have the notebook fill more of the display width
#(injects CSS overrides for the notebook container and output areas)
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
display(HTML("<style>.output_result { max-width:80% !important; }</style>"))

#Should parallelization calls be used
#(True: samples are distributed across a ray worker pool; False: samples are processed one at a time in this process)
parallelization = True

#If parallelization is enabled, how many CPU threads should be used? (0 will use any/all available)
#Recommend starting at half of the available system threads if using hyperthreading,
#or 1-2 less than the number of system CPU cores if not using hyperthreading.
#Can adjust down to help manage RAM overhead, but it may have limited impact.
availableThreads = 14

#When splitting WSI images, what size should the resulting patches be (default - 4x: 400; 10x: 1000)
#Value is the side length in pixels of each square patch; WSI rows/columns that do not fill a whole patch are cropped off
patchSize = 400

#When exporting should background patch images be included
#(True: every patch is written to disk, bypassing the foreground check below)
backgroundPatches = False

#Threshold for considering a patch to be in the foreground (originally 0.8, 0.2 as of 4 June 2025)
#Expressed as a fraction in the range 0-1 (not a percentage): a patch is kept when the fraction of 
#its pixels with a last-channel value >= 5 (red, given the BGR ordering used on load) exceeds this value
backgroundThreshold = 0.2

    

In [None]:
#Raise the maximum image size for opencv; note that this can allow for decompression bomb DOS attacks if an untrusted image ends up as an input
#The limit is set to 2^40 pixels and is deliberately assigned before cv2 is imported below
#NOTE(review): presumably OpenCV only reads this variable when the library is loaded - keep this ordering unless confirmed otherwise
import os 
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = pow(2,40).__str__()

import copy
import cv2
import glob
import logging
import natsort
import numpy as np
import pandas as pd
import ray
import shutil
import time

from contextlib import nullcontext
from ray.util.multiprocessing import Pool
from tqdm import tqdm

#Define logging levels and behaviors
#Only messages of ERROR severity or higher are emitted by the root logger; this level is also handed to ray.init later on
logging.root.setLevel(logging.ERROR)
#Have the logging module silently ignore exceptions that occur while handling a log record
logging.raiseExceptions = False

#Ray actor acting as a shared counter, letting parallel workers report progress back to the driver
#Declared with num_cpus=0 so that holding the counter does not reserve a CPU from the worker pool
@ray.remote(num_cpus=0)
class SamplingProgress_Actor:

    #Begin counting from zero; the count is stored as a float
    def __init__(self):
        self.current = 0.0

    #Register the completion of one more unit of work
    def update(self):
        self.current = self.current + 1

    #Report how many units of work have been registered so far
    def getCurrent(self):
        return self.current

#Extract and save patches from a WSI sample
#Inputs:
#   dir_WSI_data: directory prefix (including trailing separator) holding the WSI images; the file loaded is dir_WSI_data + sampleName + '.jpg'
#   sampleFolder: existing directory into which the patch .tif files are written
#   sampleName: WSI filename without extension; also embedded in each output filename
#   backgroundPatches: if True, write every patch; otherwise only write patches passing the foreground check
#   samplingProgress_Actor: optional ray actor handle; its update method is invoked once when this sample is finished
#Also reads the module-level patchSize and backgroundThreshold settings
#Outputs: none returned; patches are saved as PS<sampleName>_<patchIndex>_<rowOffset>_<columnOffset>.tif with offsets in pixels
#A sample that cannot be loaded is reported and skipped rather than raising
def extractPatches(dir_WSI_data, sampleFolder, sampleName, backgroundPatches, samplingProgress_Actor=None):
    
    try:
        
        #Load the sample WSI; BGR ordering
        #cv2.imread generally reports an unreadable/missing file by returning None instead of raising, so both outcomes are checked
        try: imageWSI = cv2.imread(dir_WSI_data + sampleName + '.jpg', cv2.IMREAD_UNCHANGED)
        except cv2.error: imageWSI = None
        if imageWSI is None: 
            print('Error - Failed to load WSI image for sample: ', sampleName)
            return
        
        #Crop to the extent that patches can be extracted evenly
        blocksY, blocksX = imageWSI.shape[0]//patchSize, imageWSI.shape[1]//patchSize
        imageWSI = imageWSI[:blocksY*patchSize, :blocksX*patchSize]
        
        #Split the WSI into a (blocksY, blocksX) grid of patchSize x patchSize patches; BGR ordering
        #Assumes a 3-channel image; NOTE(review): a grayscale .jpg would presumably load as 2-D under IMREAD_UNCHANGED and fail this reshape - confirm if such inputs are possible
        patchImages = imageWSI.reshape(blocksY, patchSize, blocksX, patchSize, 3).swapaxes(1,2)
        
        #Save patches in row-major order; patchIndex starts at 1 and only counts the patches that are actually exported
        patchIndex = 1
        for posY in range(0, blocksY):
            for posX in range(0, blocksX):
                patchImage = patchImages[posY, posX]
                
                #Foreground check: fraction of pixels whose last (red) channel is >= 5 must exceed backgroundThreshold
                if backgroundPatches or (np.mean(patchImage[:,:,-1] >= 5) > backgroundThreshold): 
                    filenameOutput = os.path.join(sampleFolder, f'PS{sampleName}_{patchIndex}_{posY*patchSize}_{posX*patchSize}.tif')
                    
                    #TIFF compression scheme 1 stores the patch uncompressed
                    writeSuccess = cv2.imwrite(filenameOutput, patchImage, params=(cv2.IMWRITE_TIFF_COMPRESSION, 1))
                    if not writeSuccess: print('Error - Failed to write patch image to disk: ', filenameOutput)
                    patchIndex+=1
                    
    finally:
        
        #If running in parallel, update the recorded progress; done even when the sample was skipped or failed so the global count stays accurate
        if samplingProgress_Actor: samplingProgress_Actor.update.remote()



In [None]:
#Store directory references
dir_data = '.' + os.path.sep + 'DATA' + os.path.sep
dir_WSI_data = dir_data + 'INPUT_WSI' + os.path.sep
dir_patches_data = dir_data + 'PATCHES' + os.path.sep

#Ensure output directory exists and is empty; note that any outputs from a prior run are permanently deleted
if os.path.exists(dir_patches_data): shutil.rmtree(dir_patches_data)
os.makedirs(dir_patches_data)

#Obtain naturally sorted list of sample names of available WSI .jpg images
#Entries with any other extension are ignored, as extractPatches only ever loads '<sampleName>.jpg'
#A set is used so that names differing only in extension case cannot produce a duplicate sample/output folder
sampleNames = natsort.natsorted({os.path.splitext(name)[0] for name in os.listdir(dir_WSI_data) if os.path.splitext(name)[1].lower() == '.jpg'})

#Create subfolders for the extracted patch image outputs of each WSI/sample (the parent directory was emptied above)
sampleFolders = [dir_patches_data + sampleName + os.path.sep for sampleName in sampleNames]
for sampleFolder in sampleFolders: os.makedirs(sampleFolder)

#Extract patches for each sample
if parallelization: 
    
    #Honor the documented meaning of availableThreads = 0 (use any/all available); None lets ray determine the count itself
    numThreads = availableThreads if availableThreads else None
    
    #Initialize ray instance without a dashboard interface, limiting its log output to the root logger level
    _ = ray.init(num_cpus=numThreads, logging_level=logging.root.level, include_dashboard=False)
    try: 
        
        #Initialize a global progress bar; one unit of progress per sample
        maxProgress = len(sampleNames)
        samplingProgress_Actor = SamplingProgress_Actor.remote()
        pbar = tqdm(total=maxProgress, desc='Extracting', leave=True, ascii=False)
        
        #Start parallel extraction operations; one argument tuple per sample
        futures = [(dir_WSI_data, sampleFolder, sampleName, backgroundPatches, samplingProgress_Actor) for sampleFolder, sampleName in zip(sampleFolders, sampleNames)]
        computePool = Pool(numThreads)
        results = computePool.starmap_async(extractPatches, futures)
        computePool.close()
        
        #Regularly monitor global progress of parallel operations till completion
        while True:
            pbar.n = min(int(ray.get(samplingProgress_Actor.getCurrent.remote())), maxProgress)
            pbar.refresh()
            if results.ready(): break
            time.sleep(0.1)
        
        #Retrieve the results so that an exception raised inside a worker is reported instead of being silently discarded
        #The bar is only forced to completion on success, as the final progress updates may not have arrived yet
        try: 
            results.get()
            pbar.n = maxProgress
            pbar.refresh()
        except Exception as error: 
            print('Error - Failed to extract patches for one or more samples: ', error)
        pbar.close()
        computePool.join()
        del samplingProgress_Actor, futures
        
    finally: 
        
        #Shutdown ray instance, even if an error occurred above
        ray.shutdown()
    
else: 
    for sampleFolder, sampleName in tqdm(zip(sampleFolders, sampleNames), total=len(sampleNames), desc='Samples', leave=True, ascii=False): 
        extractPatches(dir_WSI_data, sampleFolder, sampleName, backgroundPatches)
    
    