In [None]:
!pip install numpy tqdm cupy-cuda12x nvidia-dali-cuda120 pillow

In [None]:
import os
import glob
import numpy as np
import cupy as cp
from nvidia.dali import pipeline, types
import nvidia.dali.fn as fn
import threading
from queue import Queue
from PIL import Image

# --- CUDA Kernel for Tile Filtering ---
@cp.fuse()
def filter_empty_tiles(tiles, threshold=0.95):
    """
    Filters out tiles that are mostly empty based on a background pixel threshold.

    Args:
        tiles: A batch of CuPy arrays representing tiles.
        threshold: The threshold for the percentage of background pixels.

    Returns:
        A boolean mask array indicating which tiles should be kept.
    """
    background_pixels = cp.sum(tiles > 250, axis=(1, 2, 3))
    percentage_background = background_pixels / (tiles.shape[1] * tiles.shape[2] * tiles.shape[3])
    return percentage_background < threshold


# --- DALI Pipeline for Loading and Preprocessing ---
class WSITilingPipeline(pipeline.Pipeline):
    def __init__(self, input_dir, batch_size, device_id=0, num_threads=4, seed=42):
        super(WSITilingPipeline, self).__init__(batch_size, num_threads, device_id, seed)
        self.input = fn.readers.file(file_root=input_dir, random_shuffle=True)
        self.decode = fn.decoders.image(device="mixed")

    def define_graph(self):
        inputs, _ = self.input()
        images = fn.peek_image_shape(inputs)
        images = fn.resize(images, resize_x=tile_size, resize_y=tile_size)
        return images


# --- Worker Thread Function ---
def worker_thread(input_dir, output_dir, tile_size, threshold, queue, batch_size, device_id):
    pipe = WSITilingPipeline(input_dir, batch_size=batch_size, device_id=device_id)
    pipe.build()

    while True:
        wsi_file = queue.get()
        if wsi_file is None:
            break  # Signal to terminate the thread

        wsi_name = os.path.splitext(os.path.basename(wsi_file))[0]
        output_wsi_dir = os.path.join(output_dir, wsi_name)
        os.makedirs(output_wsi_dir, exist_ok=True)

        tile_count = 0
        while True:
            try:
                images = pipe.run()[0].as_cpu().as_array()
            except RuntimeError:
                break

            # Move images to GPU
            images_gpu = cp.asarray(images)

            # Generate tiles
            tiles = cp.array(
                [
                    images_gpu[:, y:y+tile_size, x:x+tile_size]
                    for y in range(0, images_gpu.shape[1], tile_size)
                    for x in range(0, images_gpu.shape[2], tile_size)
                ]
            )

            # Filter out empty tiles
            keep_tiles = filter_empty_tiles(tiles, threshold)
            filtered_tiles = tiles[keep_tiles]

            # Save tiles using Pillow (PIL)
            for j, tile in enumerate(filtered_tiles):
                tile_filename = f"{wsi_name}_tile_{tile_count + j}.png"
                tile_path = os.path.join(output_wsi_dir, tile_filename)
                
                # Transfer tile to CPU and save with Pillow
                tile_cpu = cp.asnumpy(tile)
                Image.fromarray(tile_cpu.astype(np.uint8)).save(tile_path)

            tile_count += len(filtered_tiles)

        queue.task_done()


# --- Main Tiling Function ---
def tile_wsi_images(
    input_dir,
    output_dir,
    tile_size=512,
    batch_size=8,
    threshold=0.95,
    num_threads=4,
):
    """
    Tiles WSI images using CUDA and DALI, saving tiles to disk.

    Args:
        input_dir: Directory containing the WSI images.
        output_dir: Directory to save the tiled images.
        tile_size: Size of each tile.
        batch_size: Batch size for processing.
        threshold: Threshold for filtering empty tiles.
        num_threads: Number of worker threads to use for parallel processing.
    """

    # Get the list of WSI image files
    wsi_files = glob.glob(os.path.join(input_dir, "*.png"))
    print(f"Found {len(wsi_files)} WSI images.")

    # Create a queue for tasks
    queue = Queue()

    # Create and start worker threads
    threads = []
    for i in range(num_threads):
        thread = threading.Thread(
            target=worker_thread,
            args=(input_dir, output_dir, tile_size, threshold, queue, batch_size, i),
        )
        thread.start()
        threads.append(thread)

    # Add tasks to the queue
    for wsi_file in wsi_files:
        queue.put(wsi_file)

    # Wait for all tasks to complete
    queue.join()

    # Signal threads to terminate
    for _ in range(num_threads):
        queue.put(None)

    # Wait for threads to finish
    for thread in threads:
        thread.join()

    print("Tiling complete!")


# --- Example Usage ---
if __name__ == "__main__":
    input_directory = "/home/input"
    output_directory = "/home/output"
    tile_size = 512  # Example tile size
    batch_size = 8  # Example batch size
    threshold = 0.95  # Example threshold
    num_threads = 4  # Example number of threads

    tile_wsi_images(
        input_directory,
        output_directory,
        tile_size,
        batch_size,
        threshold,
        num_threads,
    )


In [13]:
import os
import glob
import numpy as np
import cupy as cp
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import threading
from queue import Queue
from PIL import Image

# --- CUDA Kernel for Tile Filtering ---
@cp.fuse()
def filter_empty_tiles(tiles, threshold=0.95):
    """
    Filters out tiles that are mostly empty based on a background pixel threshold.
    """
    background_pixels = cp.sum(tiles > 250, axis=(1, 2, 3))
    percentage_background = background_pixels / (tiles.shape[1] * tiles.shape[2] * tiles.shape[3])
    return percentage_background < threshold

# --- DALI Pipeline for Loading and Preprocessing ---
class WSITilingPipeline(Pipeline):
    def __init__(self, input_files, batch_size, tile_size, device_id=0, num_threads=4):
        super(WSITilingPipeline, self).__init__(batch_size, num_threads, device_id)
        self.input = fn.external_source(source=input_files, num_outputs=1)
        self.decode = fn.decoders.image(device="mixed", output_type=types.RGB)
        self.resize = fn.resize(resize_x=tile_size, resize_y=tile_size, device="gpu")

    def define_graph(self):
        inputs = self.input()
        images = self.decode(inputs)
        images_resized = self.resize(images)
        return images_resized

# --- Worker Thread Function ---
def worker_thread(input_files, output_dir, tile_size, threshold, batch_size, device_id):
    pipe = WSITilingPipeline(input_files=input_files, batch_size=batch_size, tile_size=tile_size, device_id=device_id)
    pipe.build()

    for file_idx, input_file in enumerate(input_files):
        wsi_name = os.path.splitext(os.path.basename(input_file))[0]
        output_wsi_dir = os.path.join(output_dir, wsi_name)
        os.makedirs(output_wsi_dir, exist_ok=True)

        tile_count = 0
        while True:
            try:
                images = pipe.run()[0].as_cpu().as_array()
            except RuntimeError:
                break

            # Move images to GPU
            images_gpu = cp.asarray(images)

            # Generate tiles
            tiles = cp.array(
                [
                    images_gpu[:, y:y+tile_size, x:x+tile_size]
                    for y in range(0, images_gpu.shape[1], tile_size)
                    for x in range(0, images_gpu.shape[2], tile_size)
                ]
            )

            # Filter out empty tiles
            keep_tiles = filter_empty_tiles(tiles, threshold)
            filtered_tiles = tiles[keep_tiles]

            # Save tiles using Pillow (PIL)
            for j, tile in enumerate(filtered_tiles):
                tile_filename = f"{wsi_name}_tile_{tile_count + j}.png"
                tile_path = os.path.join(output_wsi_dir, tile_filename)
                
                # Transfer tile to CPU and save with Pillow
                tile_cpu = cp.asnumpy(tile)
                Image.fromarray(tile_cpu.astype(np.uint8)).save(tile_path)

            tile_count += len(filtered_tiles)

# --- Main Tiling Function ---
def tile_wsi_images(
    input_dir,
    output_dir,
    tile_size=512,
    batch_size=8,
    threshold=0.95,
    num_threads=4,
):
    """
    Tiles WSI images using CUDA and DALI, saving tiles to disk.
    """
    wsi_files = glob.glob(os.path.join(input_dir, "*.png"))
    print(f"Found {len(wsi_files)} WSI images.")

    # Divide files among threads
    chunks = [wsi_files[i::num_threads] for i in range(num_threads)]

    threads = []
    for i, chunk in enumerate(chunks):
        thread = threading.Thread(
            target=worker_thread,
            args=(chunk, output_dir, tile_size, threshold, batch_size, i),
        )
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    print("Tiling complete!")

# --- Example Usage ---
if __name__ == "__main__":
    input_directory = "/home/input"
    output_directory = "/home/output"
    tile_size = 512  # Example tile size
    batch_size = 8  # Example batch size
    threshold = 0.95  # Example threshold
    num_threads = 4  # Example number of threads

    tile_wsi_images(
        input_directory,
        output_directory,
        tile_size,
        batch_size,
        threshold,
        num_threads,
    )


Exception in thread Thread-41 (worker_thread):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
Exception in thread Thread-42 (worker_thread):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
    self.run()
  File "/home/tiler/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
Exception in thread Thread-43 (worker_thread):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
    self.run()
  File "/home/tiler/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/usr/lib/python3.12/threading.py", line 1010, in run
Exception in thread Thread-44 (worker_thread):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
    self._target(*self._args, **self._kwargs)
  File "/tm

Found 3 WSI images.
Tiling complete!
