Notebook for extracting patches and annotations from the Monkey dataset

In [None]:
from tiatoolbox.wsicore.wsireader import WSIReader, VirtualWSIReader
from tiatoolbox.tools.patchextraction import get_patch_extractor
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import re
import cv2
import json
from tqdm import tqdm

In [3]:
def extract_id(file_name: str):
    """Return the case-ID prefix of a WSI file name, or ``None``.

    The name must *start* with ``<letter>_P<digits>_`` (case-insensitive);
    e.g. ``'A_P000001_PAS_CPG.tif'`` gives ``'A_P000001'``.
    """
    found = re.match(r"([A-Z]_P\d+)_", file_name, re.IGNORECASE)
    return found.group(1) if found else None


def parse_json_annotations(json_path: str):
    """Load and return the parsed contents of a JSON annotation file.

    The file is decoded explicitly as UTF-8, the encoding required for JSON
    interchange, instead of relying on the platform's locale default (which
    can differ between machines and mangle non-ASCII text).
    """
    with open(json_path, "r", encoding="utf-8") as f:
        annotations = json.load(f)
    return annotations


def filter_points_with_bb(points_dict, bb):
    """Select annotation points that lie inside a bounding box.

    Args:
        points_dict: iterable of annotation dicts, each with a ``"point"``
            entry whose first two values are the x and y coordinates
            (any further values are ignored).
        bb: ``[x_start, y_start, x_end, y_end]``. The box is half-open,
            ``x_start <= x < x_end`` and ``y_start <= y < y_end``, so a box
            of width ``w`` yields relative x offsets in ``[0, w)``.

    Returns:
        List of the selected ``"point"`` values, in input order.
    """
    x_start, y_start, x_end, y_end = bb[0], bb[1], bb[2], bb[3]
    selected_points = []
    for item in points_dict:
        point = item["point"]
        # Exclusive upper bound: a point at x == x_end (or y == y_end) would
        # sit one pixel past the last column (row) of the patch.
        if x_start <= point[0] < x_end and y_start <= point[1] < y_end:
            selected_points.append(point)
    return selected_points


def extract_rois_coords(annotation_rois: list[dict]):
    """Turn ROI annotations into ``np.int32`` vertex arrays.

    Each dict's ``"polygon"`` entry is converted with ``np.array(..., np.int32)``;
    the arrays are returned in the same order as the input.
    """
    return [np.array(roi["polygon"], np.int32) for roi in annotation_rois]


def mask_from_poly(poly_coords, canvas_width, canvas_height, scale_factor):
    """Rasterise polygons into a binary mask.

    Args:
        poly_coords: sequence of ``(N, 2)`` arrays of ``(x, y)`` vertices in
            the source coordinate frame. Not modified.
        canvas_width: width (number of columns) of the output mask.
        canvas_height: height (number of rows) of the output mask.
        scale_factor: vertices are floor-divided by this before drawing,
            mapping source coordinates onto the canvas.

    Returns:
        ``uint8`` array of shape ``(canvas_height, canvas_width)`` holding 1
        inside any polygon and 0 elsewhere.
    """
    # NumPy arrays are indexed (rows, cols) == (height, width), while cv2
    # polygon vertices are (x, y) == (col, row).
    canvas = np.zeros(shape=(canvas_height, canvas_width), dtype=np.uint8)
    # Build a new list rather than overwriting the caller's arrays in place.
    scaled_polys = [
        (np.asarray(poly) // scale_factor).astype(np.int32) for poly in poly_coords
    ]
    if scaled_polys:
        cv2.fillPoly(canvas, scaled_polys, 1)
    return canvas


def filter_coords_with_mask(xs, ys, binary_mask):
    """Keep only the (x, y) coordinates that fall on non-zero mask pixels.

    Args:
        xs: integer column indices.
        ys: integer row indices, same length as ``xs``.
        binary_mask: 2-D array indexed as ``binary_mask[row, col]``.

    Returns:
        ``(new_xs, new_ys)``: the coordinates that lie inside the mask's
        bounds and on a non-zero pixel, in input order.

    Raises:
        ValueError: if ``xs`` and ``ys`` differ in length.
    """
    if len(xs) != len(ys):
        raise ValueError(f"xs and ys differ in length ({len(xs)} != {len(ys)})")
    height, width = binary_mask.shape[0], binary_mask.shape[1]
    new_xs, new_ys = [], []
    for x, y in zip(xs, ys):
        # Explicit bounds check instead of catching errors: a negative index
        # would otherwise silently wrap around to the opposite edge.
        if not (0 <= x < width and 0 <= y < height):
            continue
        if binary_mask[y, x] != 0:
            new_xs.append(x)
            new_ys.append(y)
    return new_xs, new_ys


def get_relative_coords(base_coords, bb, mask):
    """Map annotation points inside ``bb`` to patch-relative pixel coordinates.

    Points are restricted to the bounding box via ``filter_points_with_bb``,
    shifted so that ``(bb[0], bb[1])`` becomes the origin, truncated with
    ``int``, and finally kept only where ``mask`` is non-zero according to
    ``filter_coords_with_mask``.

    Returns:
        ``(xs, ys)``: two lists of patch-relative coordinates.
    """
    origin_x, origin_y = bb[0], bb[1]
    inside = filter_points_with_bb(base_coords, bb)
    shifted_xs = [int(pt[0] - origin_x) for pt in inside]
    shifted_ys = [int(pt[1] - origin_y) for pt in inside]
    return filter_coords_with_mask(shifted_xs, shifted_ys, mask)


def save_data(
    file_name,
    patch_image_dir,
    cell_mask_dir,
    json_dir,
    patch_image,
    bb,
    lymphocyte_coords,
    monocyte_coords,
):
    """Write one patch and its cell annotations to disk.

    All three outputs share the stem ``"{file_name}_{bb0}_{bb1}_{bb2}_{bb3}"``
    and each target directory is created if missing:

    * ``patch_image_dir/<stem>.npy``: ``patch_image`` as given.
    * ``cell_mask_dir/<stem>.npy``: ``uint8`` label map sized to the first two
      dimensions of ``patch_image``, with 1 at each lymphocyte ``(x, y)``, 2 at
      each monocyte ``(x, y)`` (written second, so it wins on a shared pixel),
      and 0 elsewhere.
    * ``json_dir/<stem>.json``: ``{"lymphocytes": [...], "monocytes": [...]}``.
    """
    stem = f"{file_name}_{bb[0]}_{bb[1]}_{bb[2]}_{bb[3]}"

    os.makedirs(patch_image_dir, exist_ok=True)
    np.save(os.path.join(patch_image_dir, stem + ".npy"), patch_image)

    label_map = np.zeros(
        shape=(patch_image.shape[0], patch_image.shape[1]), dtype=np.uint8
    )
    for label, coords in ((1, lymphocyte_coords), (2, monocyte_coords)):
        for coord in coords:
            # Coordinates are (x, y); the array is indexed [row, col].
            label_map[coord[1], coord[0]] = label

    os.makedirs(cell_mask_dir, exist_ok=True)
    np.save(os.path.join(cell_mask_dir, stem + ".npy"), label_map)

    payload = {"lymphocytes": lymphocyte_coords, "monocytes": monocyte_coords}
    os.makedirs(json_dir, exist_ok=True)
    with open(os.path.join(json_dir, stem + ".json"), "w") as out_file:
        json.dump(payload, out_file)

In [None]:
# Path to folder containing all the target WSIs
images_folder = "/home/u1910100/Downloads/Monkey/images/pas-cpg"
# Path to folder containing the tissue mask ("<id>_mask.tif") for each WSI
tissue_masks_folder = "/home/u1910100/Downloads/Monkey/images/tissue-masks"
# Path to folder containing annotation json files
annotations_folder = "/home/u1910100/Downloads/Monkey/annotations/json"

# Output folders: patch images (.npy), per-pixel cell label maps (.npy) and
# per-patch point annotations (.json), all written by save_data.
patch_image_dir = "/home/u1910100/Documents/Monkey/patches_256/images"
cell_mask_dir = "/home/u1910100/Documents/Monkey/patches_256/annotations/masks"
json_dir = "/home/u1910100/Documents/Monkey/patches_256/annotations/json"

# Sliding-window geometry at level 0 (stride < size => overlapping patches).
PATCH_SIZE = 256
PATCH_STRIDE = 224

for wsi_image_name in tqdm(sorted(os.listdir(images_folder))):
    wsi_id = extract_id(wsi_image_name)
    if wsi_id is None:
        # e.g. hidden/system files that don't follow the "<X>_P<digits>_" naming;
        # without this guard the loop would try to open "None_..." files.
        tqdm.write(f"Skipping {wsi_image_name}: no case ID in file name")
        continue

    # Only lymphocyte and monocyte points are saved by save_data, so the
    # "<id>_inflammatory-cells.json" annotations are not loaded or filtered.
    lymphocyte_points = parse_json_annotations(
        os.path.join(annotations_folder, f"{wsi_id}_lymphocytes.json")
    )["points"]
    monocyte_points = parse_json_annotations(
        os.path.join(annotations_folder, f"{wsi_id}_monocytes.json")
    )["points"]

    # WSI path and tissue mask path
    wsi_path = os.path.join(images_folder, wsi_image_name)
    mask_path = os.path.join(tissue_masks_folder, f"{wsi_id}_mask.tif")

    # Read WSI and tissue mask; show thumbnails as a visual sanity check.
    wsi_reader = WSIReader.open(wsi_path)
    plt.imshow(wsi_reader.slide_thumbnail())
    plt.show()
    mask_reader = WSIReader.open(mask_path)
    binary_mask = mask_reader.slide_thumbnail()[:, :, 0]
    plt.imshow(binary_mask)
    plt.show()

    # Extract patches, restricted to tissue by the thumbnail mask.
    patch_extractor = get_patch_extractor(
        input_img=wsi_reader,
        input_mask=binary_mask,
        method_name="slidingwindow",
        patch_size=(PATCH_SIZE, PATCH_SIZE),
        stride=(PATCH_STRIDE, PATCH_STRIDE),
        resolution=0,
        units="level",
    )
    for idx, patch in enumerate(patch_extractor):
        # The iterator already yields the patch; re-reading it with
        # patch_extractor[idx] would double the slide I/O per patch.
        bb = patch_extractor.coordinate_list[idx]

        # NOTE(review): assumes the tissue mask shares the WSI's level-0
        # geometry so the same bb addresses the same tissue — confirm.
        mask_patch = mask_reader.read_rect(
            (bb[0], bb[1]),
            (PATCH_SIZE, PATCH_SIZE),
            resolution=0,
            units="level",
        )
        tissue_mask = mask_patch[:, :, 0]

        lymphocyte_xs, lymphocyte_ys = get_relative_coords(
            lymphocyte_points, bb, tissue_mask
        )
        monocyte_xs, monocyte_ys = get_relative_coords(
            monocyte_points, bb, tissue_mask
        )

        # Zero out non-tissue pixels. Presumably mask values are 0/1 (a 0/255
        # mask would overflow the uint8 product) — TODO confirm.
        masked_patch = patch * mask_patch

        save_data(
            file_name=wsi_id,
            patch_image_dir=patch_image_dir,
            cell_mask_dir=cell_mask_dir,
            json_dir=json_dir,
            patch_image=masked_patch,
            bb=bb,
            lymphocyte_coords=list(zip(lymphocyte_xs, lymphocyte_ys)),
            monocyte_coords=list(zip(monocyte_xs, monocyte_ys)),
        )