# Imports

In [None]:
import os
import math
from PIL import Image
from openslide import OpenSlide
import numpy as np
import matplotlib.pyplot as plt
import tifffile

In [None]:
# Path to dataset as described in the appendix of the report. Should be root directory of the dataset.
# The cells below walk it as <root_dir>/<split>/<disease>/<case>/ and expect every case folder
# to contain the sub-folders "image_wsi" and "mask_wsi" holding the whole-slide files.
# NOTE: this is a machine-specific absolute Windows path - adjust it before running the notebook.
root_dir = r"D:\Datasets\CV_project\GlomDataset_WSI"

# Patching logic

In [None]:
# Parameters, change accordingly

patch_size = 2048  # edge length of the square patches in pixels
patch_overlap = 1024  # overlap between neighbouring patches in pixels
# a patch is kept only if at least 1% of its pixels are glomerulus pixels (mask value 255)
threshold = int(0.01 * patch_size * patch_size)
stride = patch_size - patch_overlap

# train, val and test splits
for split in os.listdir(root_dir):
    split_dir = os.path.join(root_dir, split)
    if not os.path.isdir(split_dir):
        continue  # ignore stray files lying next to the split folders
    # 56Nx, GN, NEP25 and normal
    for disease in os.listdir(split_dir):
        disease_path = os.path.join(split_dir, disease)
        if not os.path.isdir(disease_path):
            continue
        # different cases of disease
        for case in os.listdir(disease_path):
            case_path = os.path.join(disease_path, case)
            if not os.path.isdir(case_path):
                continue

            # tiff files are inside separate folders
            image_wsi_dir = os.path.join(case_path, "image_wsi")
            mask_wsi_dir = os.path.join(case_path, "mask_wsi")

            # skip if the image was already segmented (useful if not completed in one sitting)
            # NOTE: only the existence of the "image" folder is checked, and that folder is created
            # before the first patch is written. A case that was interrupted half-way is therefore
            # skipped as well - delete its "image"/"mask" folders to have it processed again.
            if os.path.exists(os.path.join(case_path, "image")):
                continue

            if not os.path.isdir(image_wsi_dir):
                print(f"[{case}] Skipped: no image_wsi folder found")
                continue

            for filename in os.listdir(image_wsi_dir):
                name_wo_ext = os.path.splitext(filename)[0]
                wsi_path = os.path.join(image_wsi_dir, filename)
                mask_path = os.path.join(mask_wsi_dir, filename.replace("_wsi", "_mask"))

                # paths to save results to (same structure as the original patch-level dataset)
                patch_image_dir = os.path.join(case_path, "image")
                patch_mask_dir = os.path.join(case_path, "mask")
                os.makedirs(patch_image_dir, exist_ok=True)
                os.makedirs(patch_mask_dir, exist_ok=True)

                # open image and mask with OpenSlide library; the context managers make sure the
                # native slide handles are released again once the slide has been processed
                with OpenSlide(wsi_path) as slide, OpenSlide(mask_path) as mask:
                    # calculate how many patches we will get for x and y direction
                    # max(1, ...): a slide smaller than one patch would otherwise yield a count <= 0
                    # and be dropped silently; it now produces a single zero-padded patch instead
                    w, h = slide.dimensions
                    n_x = max(1, math.ceil((w - patch_size) / stride) + 1)
                    n_y = max(1, math.ceil((h - patch_size) / stride) + 1)

                    print(f"[{case}] Started processing {filename}")

                    counter = 0
                    for iy in range(n_y):
                        for ix in range(n_x):
                            # the image size could not be divisible by our patch size which would result in the final patch in x and y direction to be smaller than the specified patch size
                            # in this case we create a canvas with the correct patch size and paste the region there resulting in a 0 value padding
                            x = ix * stride
                            y = iy * stride
                            read_w = min(patch_size, w - x)
                            read_h = min(patch_size, h - y)
                            needs_padding = (read_w, read_h) != (patch_size, patch_size)

                            # Read mask patch first: it decides whether the (more expensive) image read is needed at all
                            mask_patch = mask.read_region((x, y), 0, (read_w, read_h)).convert("L")
                            if needs_padding:
                                mask_canvas = Image.new("L", (patch_size, patch_size), 0)
                                mask_canvas.paste(mask_patch, (0, 0))
                                mask_patch = mask_canvas

                            # filter out trivial patches, skip if number of positive pixels falls below threshold
                            mask_np = np.array(mask_patch)
                            glom_pixel_count = np.count_nonzero(mask_np == 255)
                            if glom_pixel_count < threshold:
                                continue

                            # Read image patch
                            patch = slide.read_region((x, y), 0, (read_w, read_h)).convert("RGB")
                            if needs_padding:
                                canvas = Image.new("RGB", (patch_size, patch_size), (0, 0, 0))
                                canvas.paste(patch, (0, 0))
                                patch = canvas

                            # we add a suffix with the patch-id (counter) and the position of the upper left pixel for full-scale reconstruction of the mask and its prediction
                            patch.save(os.path.join(patch_image_dir, f'{name_wo_ext}_{counter}_{x}_{y}_img.png'))
                            mask_patch.save(os.path.join(patch_mask_dir, f'{name_wo_ext}_{counter}_{x}_{y}_mask.png'))

                            counter += 1

                print(f"[{case}] Processed {filename}: {counter} patches")


# Display image slide

In [None]:
# Slide to inspect; this example opens the mask WSI of one validation case
wsi_path = os.path.join(root_dir, r'validation\DN\11_361\mask_wsi\11-361_mask.tiff')

region_size = (2048, 2048)
level = 0  # Full resolution

# The context manager closes the slide handle again as soon as the region has been read;
# read_region returns an in-memory PIL image, so it stays usable afterwards.
with OpenSlide(wsi_path) as slide:
    w, h = slide.dimensions

    # Top-left corner of the region to extract, in level-0 pixel coordinates
    x, y = w//3, 4000 # example values, used for evaluation of results of patching logic

    # Read the region and convert to RGB
    region = slide.read_region((x, y), level, region_size).convert("RGB")

# Plot
plt.figure(figsize=(8, 8))
plt.imshow(region)
plt.axis('off')
# The title reports the actual position; the region is not taken from the slide's top-left corner
plt.title(f"{region_size[0]}x{region_size[1]} region at ({x}, {y})")
plt.show()

# Data curation: Remove trivial masks with corresponding image patches

In [None]:
# Our first patching logic used a different filtering logic which resulted in patches that had no glomeruli
# We used this logic to curate the dataset to only include patches with at least one pixel belonging to a glomerulus
# WARNING: destructive - mask/image patch pairs are deleted from disk.

# Iterate through masks of entire dataset
for split in os.listdir(root_dir):
    split_path = os.path.join(root_dir, split)
    if not os.path.isdir(split_path):
        continue  # ignore stray files lying next to the split folders
    for disease in os.listdir(split_path):
        disease_path = os.path.join(split_path, disease)
        if not os.path.isdir(disease_path):
            continue
        for image_dir in os.listdir(disease_path):
            case_path = os.path.join(disease_path, image_dir)
            image_patches_path = os.path.join(case_path, 'image')
            mask_patches_path = os.path.join(case_path, 'mask')
            if not os.path.isdir(mask_patches_path):
                continue  # nothing to curate: this case has no mask patches (yet)
            for mask_patch in os.listdir(mask_patches_path):
                mask_path = os.path.join(mask_patches_path, mask_patch)
                image_path = os.path.join(image_patches_path, mask_patch.replace('_mask', '_img'))
                # open mask as numpy array; the context manager guarantees the file handle is
                # released before the file may be deleted below (an open file cannot be removed on Windows)
                with Image.open(mask_path) as mask:
                    mask_np = np.array(mask)
                # if max value of array is not 255 then there is no pixels belonging to a glomerulus in the mask
                # -> remove mask and corresponding image patch
                if np.max(mask_np) != 255:
                    os.remove(mask_path)
                    if os.path.exists(image_path):
                        os.remove(image_path)
                    else:
                        # keep going instead of aborting the curation run half-way through the dataset
                        print(f"[WARNING] No image patch found for removed mask: {mask_path}")

In [None]:
# Further dataset curation using a threshold on the glomerulus pixel count to filter out
# patches that hold almost no relevant information for the model
# WARNING: destructive - mask/image patch pairs are deleted from disk.

# Threshold for minimal glomerulus pixel count (1% of the patch area)
patch_size = 2048
threshold = int(0.01 * patch_size * patch_size)

# Iterate through masks of entire dataset
for split in os.listdir(root_dir):
    split_path = os.path.join(root_dir, split)
    if not os.path.isdir(split_path):
        continue  # ignore stray files lying next to the split folders
    for disease in os.listdir(split_path):
        disease_path = os.path.join(split_path, disease)
        if not os.path.isdir(disease_path):
            continue
        for image_dir in os.listdir(disease_path):
            image_patches_path = os.path.join(disease_path, image_dir, 'image')
            mask_patches_path = os.path.join(disease_path, image_dir, 'mask')
            if not os.path.isdir(mask_patches_path):
                continue  # nothing to curate: this case has no mask patches (yet)
            for mask_patch in os.listdir(mask_patches_path):
                mask_path = os.path.join(mask_patches_path, mask_patch)
                image_path = os.path.join(image_patches_path, mask_patch.replace('_mask', '_img'))

                # open mask as numpy array; the context manager guarantees the file handle is
                # released before the file may be deleted below (an open file cannot be removed on Windows)
                with Image.open(mask_path) as mask:
                    mask_np = np.array(mask)

                # count glomerulus pixels (value 255)
                glom_pixel_count = np.count_nonzero(mask_np == 255)

                # remove mask and image if below threshold
                if glom_pixel_count < threshold:
                    os.remove(mask_path)
                    if os.path.exists(image_path):
                        os.remove(image_path)
                    else:
                        # keep going instead of aborting the curation run half-way through the dataset
                        print(f"[WARNING] No image patch found for removed mask: {mask_path}")


# Check integrity of dataset

In [None]:
# Walk every case folder and report image/mask patch folders that are out of sync.
# Nothing is modified on disk; problems are only printed.
for split in os.listdir(root_dir):
    split_root = os.path.join(root_dir, split)
    for disease in os.listdir(split_root):
        disease_root = os.path.join(split_root, disease)
        for image_dir in os.listdir(disease_root):
            case_root = os.path.join(disease_root, image_dir)
            image_patches_path = os.path.join(case_root, 'image')
            mask_patches_path = os.path.join(case_root, 'mask')

            image_names = os.listdir(image_patches_path)
            mask_names = os.listdir(mask_patches_path)

            # 1) both folders must hold the same number of files
            n_images, n_masks = len(image_names), len(mask_names)
            if n_images != n_masks:
                print(f"[WARNING] Mismatch in file counts for: {image_dir}")
                print(f"           -> Images: {n_images}, Masks: {n_masks}")

            # 2) every mask needs an image patch with the matching "_img" name
            for mask_name in mask_names:
                expected_image = os.path.join(image_patches_path, mask_name.replace('_mask', '_img'))
                if os.path.exists(expected_image):
                    continue
                print(f"[WARNING] Missing image for mask: {mask_name} in {image_patches_path}")

            # 3) every image needs a mask patch with the matching "_mask" name
            for image_name in image_names:
                expected_mask = os.path.join(mask_patches_path, image_name.replace('_img', '_mask'))
                if os.path.exists(expected_mask):
                    continue
                print(f"[WARNING] Missing mask for image: {image_name} in {mask_patches_path}")


# Get dimensions of largest WSI

In [None]:
# Find the whole-slide file with the largest pixel area and print its (width, height).
maxw = maxh = 0

# Only these sub-folders of a case hold whole-slide files. The patching cells add "image" and
# "mask" folders full of PNG patches to the same case folder, which OpenSlide cannot open.
wsi_subdirs = ("image_wsi", "mask_wsi")

for split in os.listdir(root_dir):
    base_split_path = os.path.join(root_dir, split)
    if not os.path.isdir(base_split_path):
        continue  # ignore stray files lying next to the split folders
    for disease in os.listdir(base_split_path):
        base_disease_path = os.path.join(base_split_path, disease)
        if not os.path.isdir(base_disease_path):
            continue
        for image_dir in os.listdir(base_disease_path):
            base_image_dir_path = os.path.join(base_disease_path, image_dir)
            if not os.path.isdir(base_image_dir_path):
                continue
            for tiff_dir in os.listdir(base_image_dir_path):
                if tiff_dir not in wsi_subdirs:
                    continue
                tiff_dir_path = os.path.join(base_image_dir_path, tiff_dir)
                for wsi_file in os.listdir(tiff_dir_path):
                    # only the dimensions are needed, so release the handle right away
                    with OpenSlide(os.path.join(tiff_dir_path, wsi_file)) as slide:
                        w, h = slide.dimensions
                    # save largest dimension (compared by pixel area)
                    if (w * h) > (maxw * maxh):
                        maxw, maxh = w, h

print((maxw, maxh))

# Legacy patching logic using distance from background color to filter trivial patches

In [None]:
# Legacy variant of the patching cell: a patch is kept when enough of its pixels differ from
# both a reference background colour and pure black (the padding / out-of-slide colour).
# The mask is not consulted for filtering here.
patch_size = 2048
patch_overlap = 1024
stride = patch_size - patch_overlap

# Parameters of the triviality filter (loop-invariant, therefore defined once up front)
thumb_scale = 16  # down-sampling factor of the thumbnail the filter is evaluated on
bg_rgb = np.array([222, 208, 212])  # reference colour treated as slide background
black_rgb = np.array([0, 0, 0])
tolerance = 15  # max. Euclidean RGB distance for a pixel to count as background / black
min_content_ratio = 0.1  # minimum share of non-trivial thumbnail pixels for a patch to be kept

for split in os.listdir(root_dir):
    split_dir = os.path.join(root_dir, split)
    if not os.path.isdir(split_dir):
        continue
    for disease in os.listdir(split_dir):
        disease_path = os.path.join(split_dir, disease)
        if not os.path.isdir(disease_path):
            continue

        for case in os.listdir(disease_path):
            case_path = os.path.join(disease_path, case)
            if not os.path.isdir(case_path):
                continue
            image_wsi_dir = os.path.join(case_path, "image_wsi")
            mask_wsi_dir = os.path.join(case_path, "mask_wsi")

            # skip cases that already have a patch folder
            # NOTE: the folder is created before the first patch is written, so a case that was
            # interrupted half-way is skipped as well
            if os.path.exists(os.path.join(case_path, "image")):
                continue

            if not os.path.isdir(image_wsi_dir):
                print(f"[{case}] Skipped: no image_wsi folder found")
                continue

            for filename in os.listdir(image_wsi_dir):
                if not filename.endswith((".tiff", ".png", ".jpg")):
                    continue

                name_wo_ext = os.path.splitext(filename)[0]
                wsi_path = os.path.join(image_wsi_dir, filename)
                mask_path = os.path.join(mask_wsi_dir, filename.replace("_wsi", "_mask"))

                patch_image_dir = os.path.join(case_path, "image")
                patch_mask_dir = os.path.join(case_path, "mask")
                os.makedirs(patch_image_dir, exist_ok=True)
                os.makedirs(patch_mask_dir, exist_ok=True)

                # context managers release the native slide handles once the slide is done
                with OpenSlide(wsi_path) as slide, OpenSlide(mask_path) as mask:
                    w, h = slide.dimensions
                    # max(1, ...): a slide smaller than one patch would otherwise yield a count <= 0
                    # and be dropped silently; it now produces a single zero-padded patch instead
                    n_x = max(1, math.ceil((w - patch_size) / stride) + 1)
                    n_y = max(1, math.ceil((h - patch_size) / stride) + 1)

                    counter = 0
                    for iy in range(n_y):
                        for ix in range(n_x):
                            x = ix * stride
                            y = iy * stride
                            read_w = min(patch_size, w - x)
                            read_h = min(patch_size, h - y)
                            needs_padding = (read_w, read_h) != (patch_size, patch_size)

                            # read the region once; the thumbnail for the filter is derived from
                            # it instead of reading the same full-resolution region a second time
                            patch = slide.read_region((x, y), 0, (read_w, read_h)).convert("RGB")

                            # === Thumbnail check ===
                            thumb_patch = patch.resize(
                                (read_w // thumb_scale, read_h // thumb_scale), Image.BILINEAR)
                            thumb_np = np.array(thumb_patch)

                            # per-pixel colour distance to background and to black
                            dist_bg = np.linalg.norm(thumb_np - bg_rgb, axis=2)
                            dist_black = np.linalg.norm(thumb_np - black_rgb, axis=2)
                            trivial_mask_thumb = (dist_bg < tolerance) | (dist_black < tolerance)
                            ratio_thumb = np.count_nonzero(~trivial_mask_thumb) / trivial_mask_thumb.size

                            if ratio_thumb < min_content_ratio:
                                continue  # Skip clearly empty

                            # border patches are pasted onto a black canvas of full patch size
                            if needs_padding:
                                canvas = Image.new("RGB", (patch_size, patch_size), (0, 0, 0))
                                canvas.paste(patch, (0, 0))
                                patch = canvas

                            patch.save(os.path.join(patch_image_dir, f'{name_wo_ext}_{counter}_{x}_{y}_img.png'))

                            mask_patch = mask.read_region((x, y), 0, (read_w, read_h)).convert("L")
                            if needs_padding:
                                mask_canvas = Image.new("L", (patch_size, patch_size), 0)
                                mask_canvas.paste(mask_patch, (0, 0))
                                mask_patch = mask_canvas

                            mask_patch.save(os.path.join(patch_mask_dir, f'{name_wo_ext}_{counter}_{x}_{y}_mask.png'))

                            counter += 1

                print(f"[{case}] Processed {filename}: {counter} patches")


# Analyze color schemes of .jpg and .tiff files

In [None]:
def analyze_image_color_scheme(file_path):
    """Print basic format and colour-model information about an image file.

    The file type is chosen from the (case-insensitive) file extension:

    * ``.jpg`` / ``.jpeg`` / ``.png`` / ``.bmp`` are opened with PIL and their
      format, mode, size and a readable description of the mode are printed.
    * ``.tif`` / ``.tiff`` are opened with ``tifffile`` and shape, dtype, size,
      samples per pixel, photometric interpretation and planar configuration
      of the first page are printed.
    * any other extension prints "Unsupported file format.".

    Errors raised while reading are caught and printed rather than propagated.
    Nothing is returned.
    """
    suffix = os.path.splitext(file_path)[-1].lower()
    print(f"\n--- {file_path} ---")

    is_pil_format = suffix in ('.jpg', '.jpeg', '.png', '.bmp')
    is_tiff_format = suffix in ('.tif', '.tiff')

    if not (is_pil_format or is_tiff_format):
        print("Unsupported file format.")
        return

    if is_pil_format:
        # human-readable names for the PIL modes we expect to encounter
        pil_mode_names = {
            "1": "1-bit black & white",
            "L": "Grayscale (8-bit)",
            "RGB": "Red, Green, Blue",
            "RGBA": "RGB + Alpha (transparency)",
            "CMYK": "Cyan, Magenta, Yellow, Black",
            "YCbCr": "Luminance + Chrominance",
            "I": "32-bit signed integer",
            "F": "32-bit float",
        }
        try:
            with Image.open(file_path) as img:
                print(f"Format     : {img.format}")
                print(f"Mode       : {img.mode}")
                print(f"Size       : {img.size}")
                interpreted = pil_mode_names.get(img.mode, 'Unknown/Custom')
                print(f"Interpreted: {interpreted}")
        except Exception as e:
            print(f"Error reading image: {e}")
        return

    # remaining case: TIFF, inspected through its tags without decoding pixel data here
    try:
        with tifffile.TiffFile(file_path) as tif:
            first_page = tif.pages[0]  # only the first page is examined
            print(f"TIFF Shape         : {first_page.shape}")
            print(f"TIFF Dtype         : {first_page.dtype}")
            print(f"TIFF Size          : {first_page.imagelength} x {first_page.imagewidth}")
            print(f"Samples per Pixel  : {first_page.samplesperpixel}")
            print(f"Photometric        : {first_page.photometric.name}")
            print(f"Planar Configuration: {first_page.planarconfig.name}")
    except Exception as e:
        print(f"Error reading TIFF metadata: {e}")

# Example usage: inspect one JPEG patch and one WSI TIFF.
# NOTE: machine-specific absolute paths - adjust before running.
for example_path in (
    r"D:\Datasets\CV_project\GlomDataset\train\56Nx\12_116\img\56Nx_12_116_4_4096_0_img.jpg",
    r"D:\Datasets\CV_project\GlomDataset_WSI_backup\train\56Nx\12_116\image_wsi\12-116_wsi.tiff",
):
    analyze_image_color_scheme(example_path)
