In [None]:
import os
from pprint import pprint
import torch
from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader
from tiatoolbox.tools.patchextraction import get_patch_extractor
from tissue_masker_lite import get_mask
import matplotlib.pyplot as plt
import numpy as np
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
import segmentation_models_pytorch as smp
from tqdm.auto import tqdm
import cv2
import pickle
import skimage
import cv2
import skimage.measure
from torch.utils.data import DataLoader
import json
from multiprocessing import Pool
from tiatoolbox.annotation.storage import Annotation, SQLiteStore
from shapely import Point, Polygon

In [None]:
wsi_name = "106S.tif"
wsi_without_ext = os.path.splitext(wsi_name)[0]
masks_dir = "/home/u1910100/GitHub/TIAger-Torch/output/seg_out"
tumor_stroma_mask_path = os.path.join(masks_dir, f"{wsi_without_ext}_tumor_stroma.npy")

# Mask produced by the segmentation stage; it is passed as `input_mask` to the
# tile extractor further down so only masked tissue is processed.
tumor_stroma_mask = np.load(tumor_stroma_mask_path)

fig, ax = plt.subplots()
ax.imshow(tumor_stroma_mask)
ax.set(
    title=f"Tumour/stroma mask: {wsi_without_ext}",
    xlabel="x (mask pixels)",
    ylabel="y (mask pixels)",
)
plt.show()

In [None]:
detModel1 = "/home/u1910100/GitHub/TIAger-Torch/runs/cell/fold_1/model_55.pth"
detModel2 = "/home/u1910100/GitHub/TIAger-Torch/runs/cell/fold_2/model_40.pth"
detModel3 = "/home/u1910100/GitHub/TIAger-Torch/runs/cell/fold_3/model_30.pth"
detModel = [detModel1, detModel2, detModel3]

# One detection U-Net per cross-validation fold; their outputs are averaged
# as an ensemble in `detections_in_tile`.
models: list[torch.nn.Module] = []
for model_path in detModel:
    # The architecture must match the one the checkpoints were trained with.
    model = smp.Unet(
        encoder_name="efficientnet-b0",
        encoder_weights=None,  # no pretrained download: all weights come from the checkpoint
        in_channels=3,  # RGB input
        classes=1,  # single-channel output map (passed through a sigmoid at inference)
    )

    # Load tensors onto the CPU first so a checkpoint saved on a different GPU
    # index still loads; the model is then moved to the GPU explicitly.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model.to("cuda")
    model.eval()
    models.append(model)

In [None]:
def imagenet_normalise(img: torch.Tensor) -> torch.Tensor:
    """Normalise an image tensor with the ImageNet per-channel mean and std.

    Args:
        img: tensor whose channel axis (size 3) is third from the end, i.e.
            ``(N, 3, H, W)`` or ``(3, H, W)``. Values are expected to already be
            scaled to ``[0, 1]`` (in this notebook the caller divides by 255).

    Returns:
        A new floating-point tensor equal to ``(img - mean) / std``. ``img`` is
        not modified. Integer inputs are promoted to ``float32`` rather than
        having the result truncated back into an integer buffer.
    """
    if not torch.is_floating_point(img):
        img = img.float()

    # Shape (3, 1, 1) broadcasts over the spatial axes and any leading batch
    # axis; matching dtype/device avoids implicit promotion or device errors.
    mean = torch.tensor(
        [0.485, 0.456, 0.406], dtype=img.dtype, device=img.device
    ).view(3, 1, 1)
    std = torch.tensor(
        [0.229, 0.224, 0.225], dtype=img.dtype, device=img.device
    ).view(3, 1, 1)

    return (img - mean) / std


def mm2_to_px(mm2, mpp):
    """Convert an area in square millimetres to a pixel count at ``mpp`` microns per pixel."""
    um2_per_mm2 = 1e6
    pixel_area_um2 = mpp**2
    return (mm2 * um2_per_mm2) / pixel_area_um2


def dist_to_px(dist, mpp):
    """Convert a length (in the unit of ``mpp``'s numerator) to the nearest whole pixel count."""
    return int(round(dist / mpp))


def px_to_mm(px, mpp):
    """Convert a pixel distance to millimetres given ``mpp`` microns per pixel."""
    microns = px * mpp
    return microns / 1000


def px_to_um2(px, mpp):
    """Convert a pixel count to an area in square microns at ``mpp`` microns per pixel."""
    return px * mpp**2


def detections_in_tile(image_tile, det_models):
    """Run the detection-model ensemble over one tile with a sliding window.

    The tile is wrapped in a ``VirtualWSIReader`` and cut into 128x128 patches
    taken every 28 px. Every model in ``det_models`` is applied to each batch of
    patches, the sigmoid outputs are summed and divided by the number of
    models, giving one averaged probability map per patch.

    Args:
        image_tile: image array for a single tile (as yielded by the tile
            extractor in this notebook).
        det_models: non-empty sequence of models; each is expected to map a
            ``(N, 3, 128, 128)`` batch to a ``(N, 1, 128, 128)`` logit map.
            Inference runs on the device of the first model's parameters.

    Returns:
        ``(predictions, coordinate_list)``: ``predictions`` is a list of
        ``(128, 128)`` float32 arrays and ``coordinate_list`` is the patch
        extractor's list of matching patch locations.

    Raises:
        ValueError: if ``det_models`` is empty.
    """
    if len(det_models) == 0:
        raise ValueError("det_models must contain at least one model")

    patch_size = 128
    # NOTE(review): despite its name this value is passed as the sliding-window
    # *stride*, so neighbouring patches overlap by patch_size - 28 = 100 px.
    # Confirm whether a stride of patch_size - overlap was intended.
    overlap = 28
    batch_size = 256
    # Use the device the models actually live on instead of hard-coding "cuda".
    device = next(det_models[0].parameters()).device

    tile_reader = VirtualWSIReader(image_tile, power=20)
    patch_extractor = get_patch_extractor(
        input_img=tile_reader,
        method_name="slidingwindow",
        patch_size=(patch_size, patch_size),
        stride=(overlap, overlap),
        resolution=20,
        units="power",
    )

    predictions = []
    dataloader = DataLoader(patch_extractor, batch_size=batch_size, shuffle=False)

    for imgs in tqdm(dataloader, leave=False):
        # Patches arrive channels-last; move channels to axis 1, scale to [0, 1]
        # and apply ImageNet normalisation.
        imgs = torch.permute(imgs, (0, 3, 1, 2))
        imgs = imgs / 255
        imgs = imagenet_normalise(imgs)
        imgs = imgs.to(device).float()

        val_predicts = np.zeros(
            shape=(imgs.size()[0], patch_size, patch_size), dtype=np.float32
        )

        with torch.no_grad():
            for det_model in det_models:
                temp_out = torch.sigmoid(det_model(imgs))
                val_predicts += temp_out.detach().cpu().numpy()[:, 0, :, :]

        # Mean over the ensemble: divide by the actual number of models.
        val_predicts = val_predicts / len(det_models)
        predictions.extend(list(val_predicts))

    return predictions, patch_extractor.coordinate_list


def tile_detection_stats(
    predictions,
    coordinate_list,
    x,
    y,
    tile_shape=(1024, 1024),
    threshold=0.99,
    mpp=0.5,
):
    """Merge patch predictions for one tile and turn them into point detections.

    Patch probability maps are stitched onto a ``tile_shape`` canvas with
    ``SemanticSegmentor.merge_prediction``, thresholded, and split into
    connected components. Each component becomes one detection at its rounded
    centroid, shifted by the tile offset ``(x, y)``.

    Args:
        predictions: per-patch probability maps (as returned by
            ``detections_in_tile``).
        coordinate_list: patch locations matching ``predictions``.
        x: horizontal offset added to every centroid column.
        y: vertical offset added to every centroid row.
        tile_shape: canvas shape used for merging; must match the tile size of
            the tile extractor (1024x1024 in this notebook).
        threshold: merged probability above which a pixel is foreground.
        mpp: microns per pixel used when converting coordinates to millimetres.

    Returns:
        ``(annotations, output_points)``: ``annotations`` is a list of integer
        ``(col, row)`` tuples in offset pixel coordinates; ``output_points`` is
        a list of dicts with a ``"point"`` (x, y in mm plus a constant third
        value) and a ``"probability"`` equal to the component's mean merged
        probability.
    """
    tile_prediction = SemanticSegmentor.merge_prediction(
        tile_shape, predictions, coordinate_list
    )
    tile_prediction_mask = tile_prediction > threshold

    mask_label = skimage.measure.label(tile_prediction_mask)
    stats = skimage.measure.regionprops(mask_label, intensity_image=tile_prediction)

    output_points = []
    annotations = []
    for region in stats:
        # regionprops centroids are (row, col); round once to whole pixels.
        centroid = np.round(region["centroid"]).astype(int)
        r, c = centroid[0], centroid[1]
        confidence = region["mean_intensity"]

        c1 = c + x
        r1 = r + y
        prediction_record = {
            "point": [
                float(px_to_mm(c1, mpp)),
                float(px_to_mm(r1, mpp)),
                # Constant third value carried over unchanged.
                # NOTE(review): presumably the slide spacing; confirm with the consumer.
                float(0.5009999871253967),
            ],
            "probability": float(confidence),
        }

        output_points.append(prediction_record)
        annotations.append((int(c1), int(r1)))
    return annotations, output_points

In [None]:
wsi_dir = "/home/u1910100/Documents/Tiger_Data/wsitils/images/"
wsi_path = os.path.join(wsi_dir, wsi_name)
wsi = WSIReader.open(wsi_path)

det_out_dir = "/home/u1910100/GitHub/TIAger-Torch/output/det_out"
# Create the output folder before the slow inference loop so a missing
# directory fails fast instead of after all tiles have been processed.
os.makedirs(det_out_dir, exist_ok=True)

# 1024x1024 tiles at 20x, restricted to the area covered by the tumour/stroma mask.
tile_extractor = get_patch_extractor(
    input_img=wsi,
    method_name="slidingwindow",
    patch_size=(1024, 1024),
    resolution=20,
    units="power",
    input_mask=tumor_stroma_mask,
)

annotations = []  # integer (x, y) pixel points, consumed by the NMS step below
output_dict = {
    "type": "Multiple points",
    "version": {"major": 1, "minor": 0},
    "points": [],
}

for i, tile in enumerate(tqdm(tile_extractor)):
    bounding_box = tile_extractor.coordinate_list[i]  # (x_start, y_start, x_end, y_end)
    predictions, coordinates = detections_in_tile(tile, models)
    annotations_tile, output_points_tile = tile_detection_stats(
        predictions, coordinates, bounding_box[0], bounding_box[1]
    )
    annotations.extend(annotations_tile)
    output_dict["points"].extend(output_points_tile)

detections_json_path = os.path.join(det_out_dir, f"{wsi_without_ext}.json")
with open(detections_json_path, "w") as fp:
    json.dump(output_dict, fp, indent=4)

points_json_path = os.path.join(det_out_dir, f"{wsi_without_ext}_points.json")
with open(points_json_path, "w") as fp:
    json.dump(annotations, fp, indent=4)

In [None]:
def point_to_box(x, y, size):
    """Build an ``[x_min, y_min, x_max, y_max]`` box reaching ``size`` past ``(x, y)`` on every side."""
    x_min, y_min = x - size, y - size
    x_max, y_max = x + size, y + size
    return np.array([x_min, y_min, x_max, y_max])


def get_centerpoints(box, dist):
    """Return the centre of a box whose top-left corner lies ``dist`` above and left of it."""
    x_min, y_min = box[0], box[1]
    return (x_min + dist, y_min + dist)


def non_max_suppression_fast(boxes, overlapThresh):
    """Greedy non-maximum suppression over axis-aligned boxes (pyimagesearch recipe).

    Boxes are visited from the largest bottom edge (y2) downwards. Each visited
    box is kept, and every not-yet-visited box whose overlap with it exceeds
    ``overlapThresh`` is discarded. Overlap is the intersection area divided by
    the *other* box's own area.

    Args:
        boxes: array of shape (N, 4) with rows ``[x1, y1, x2, y2]``.
        overlapThresh: overlap ratio above which a box is suppressed.

    Returns:
        ``[]`` when ``boxes`` is empty, otherwise the kept rows as an int array.
    """
    if len(boxes) == 0:
        return []

    # Promote integer boxes so the overlap ratio below is a true division.
    if boxes.dtype.kind == "i":
        boxes = boxes.astype("float")

    left, top, right, bottom = (boxes[:, col] for col in range(4))
    # Extents are treated as inclusive, hence the +1 on each side length.
    box_area = (right - left + 1) * (bottom - top + 1)

    keep = []
    order = np.argsort(bottom)
    while len(order) > 0:
        current = order[-1]
        keep.append(current)
        others = order[:-1]

        inter_w = np.maximum(
            0,
            np.minimum(right[current], right[others])
            - np.maximum(left[current], left[others])
            + 1,
        )
        inter_h = np.maximum(
            0,
            np.minimum(bottom[current], bottom[others])
            - np.maximum(top[current], top[others])
            + 1,
        )
        ratio = (inter_w * inter_h) / box_area[others]

        # Survivors are those NOT above the threshold (written as a negation so
        # NaN ratios survive, exactly as with the delete-based formulation).
        order = others[~(ratio > overlapThresh)]

    return boxes[keep].astype("int")


def slide_nms(slide_path, annotation_path, tile_size):
    """Apply non-maximum suppression to detected points, one slide tile at a time.

    Points are read from the JSON file at ``annotation_path`` (a list of
    ``[x, y]`` pairs) and indexed in an annotation store. They are then
    processed in non-overlapping ``tile_size`` squares covering the slide's
    level-0 extent. Inside each tile every point becomes a box of half-width 8 px.
    Greedy NMS (overlap threshold 0.7) then removes near-duplicates.

    NMS only runs within a tile, so duplicates straddling a tile border are not
    merged; tiles should therefore be much larger than the model's inference
    patches.

    Args:
        slide_path: path of the whole-slide image (used only for its size).
        annotation_path: JSON file of detected ``[x, y]`` points.
        tile_size: side length, in level-0 pixels, of the NMS tiles.

    Returns:
        List of ``(x, y)`` centre points that survived NMS.
    """
    # Open WSI and detection points file
    wsi = WSIReader.open(slide_path)
    with open(annotation_path, "r") as fp:
        cell_points = json.load(fp)
    print(f"{len(cell_points)} before nms")
    store = points_to_annotation_store(cell_points)
    # (width, height) of the slide at level 0
    shape = wsi.slide_dimensions(resolution=0, units="level")

    center_nms_points = []
    box_size = 8  # half-width: each point becomes a 16x16 box
    nms_threshold = 0.7

    # Non-overlapping tile_size x tile_size tiles
    for y_pos in range(0, shape[1], tile_size):
        for x_pos in range(0, shape[0], tile_size):
            x_end, y_end = x_pos + tile_size, y_pos + tile_size
            query_box = [x_pos, y_pos, x_end, y_end]
            # The store's spatial query ("intersects" by default in tiatoolbox)
            # also returns points lying exactly on the tile border, which would
            # then be counted in two neighbouring tiles. Keep only points in the
            # half-open tile [x_pos, x_end) x [y_pos, y_end).
            patch_points = [
                point
                for point in get_points_within_box(store, query_box)
                if x_pos <= point[0] < x_end and y_pos <= point[1] < y_end
            ]

            # A tile holding a single point must still contribute it; only
            # empty tiles are skipped.
            if not patch_points:
                continue

            boxes = np.array(
                [point_to_box(point[0], point[1], box_size) for point in patch_points]
            )
            for kept_box in non_max_suppression_fast(boxes, nms_threshold):
                center_nms_points.append(get_centerpoints(kept_box, box_size))
    return center_nms_points


def points_to_annotation_store(points: list):
    """Build a ``SQLiteStore`` holding one point annotation per coordinate.

    Args:
        points (list): sequence of ``(x, y)`` coordinates; only the first two
            items of each coordinate are used.

    Returns:
        SQLiteStore: a store created with no path, containing one ``Point``
        annotation per input coordinate, each with properties ``{"class": 1}``.
    """
    annotation_store = SQLiteStore()

    annotations = [
        Annotation(geometry=Point(coord[0], coord[1]), properties={"class": 1})
        for coord in points
    ]
    # Insert everything in one batch instead of one `append` (and transaction)
    # per point; a slide can contain a very large number of detections.
    if annotations:
        annotation_store.append_many(annotations)

    return annotation_store


def get_points_within_box(annotation_store, box):
    """Return the first coordinate of every annotation the store matches against ``box``.

    Args:
        annotation_store: annotation store supporting ``query(geometry=...)``
            that returns a key -> annotation mapping.
        box: ``[x_min, y_min, x_max, y_max]`` query rectangle.

    Returns:
        list of coordinate tuples, one per matching annotation.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    query_poly = Polygon.from_bounds(x_min, y_min, x_max, y_max)
    matches = annotation_store.query(geometry=query_poly)
    return [annotation.coords[0] for annotation in matches.values()]


def get_mask_area(mask_path, downsample=6):
    """Estimate the full-resolution pixel area covered by a saved mask.

    Args:
        mask_path: path of a ``.npy`` mask array.
        downsample: linear scale factor between the mask and the target
            resolution; each non-zero mask pixel counts as ``downsample**2``
            target pixels. The default of 6 is the value this notebook has
            always used. NOTE(review): confirm it matches how the segmentation
            masks are saved.

    Returns:
        Area, in target-resolution pixels, of all non-zero mask entries (any
        non-zero label counts, not only the value 1).
    """
    mask = np.load(mask_path)
    counts = np.count_nonzero(mask)
    return counts * downsample**2


def create_til_score(wsi_path, cell_points_path, tumor_stroma_mask_path):
    """Compute a slide's TIL score: estimated TIL area as a percentage of stroma area.

    Detected points are de-duplicated with ``slide_nms`` (2048 px tiles). Each
    surviving point is credited a square area of ``dist_to_px(4, 0.5) ** 2``
    pixels. The total is divided by the mask area from ``get_mask_area``.

    Args:
        wsi_path: path of the whole-slide image.
        cell_points_path: JSON file of detected ``[x, y]`` points.
        tumor_stroma_mask_path: ``.npy`` mask used as the stroma area.

    Returns:
        int: the score, truncated and clipped to ``[0, 100]``. It is ``0`` when
        the mask area is zero, where the ratio is undefined.
    """
    nms_points = slide_nms(
        slide_path=wsi_path, annotation_path=cell_points_path, tile_size=2048
    )

    cell_counts = len(nms_points)
    print(f"TIL counts = {cell_counts}")

    # dist_to_px(4, 0.5) == 8, so each TIL contributes an 8x8 = 64 px area.
    til_area = dist_to_px(4, 0.5) ** 2
    tils_area = cell_counts * til_area

    stroma_area = get_mask_area(tumor_stroma_mask_path)
    print(f"stroma area = {stroma_area}")

    if int(stroma_area) <= 0:
        # Empty mask: report 0 instead of raising ZeroDivisionError.
        tilscore = 0
    else:
        tilscore = int((100 / int(stroma_area)) * int(tils_area))
        tilscore = min(100, max(0, tilscore))
    print(f"tilscore = {tilscore}")
    return tilscore

In [None]:
# Inputs for the TIL score, re-declared so this cell only needs the outputs of
# the detection and segmentation stages on disk.
wsi_without_ext = "106S"

cell_points_dir = "/home/u1910100/GitHub/TIAger-Torch/output/det_out"
wsi_dir = "/home/u1910100/Documents/Tiger_Data/wsitils/images/"
masks_dir = "/home/u1910100/GitHub/TIAger-Torch/output/seg_out"

cell_points_path = os.path.join(cell_points_dir, f"{wsi_without_ext}_points.json")
wsi_path = os.path.join(wsi_dir, f"{wsi_without_ext}.tif")
tumor_stroma_mask_path = os.path.join(masks_dir, f"{wsi_without_ext}_tumor_stroma.npy")

create_til_score(wsi_path, cell_points_path, tumor_stroma_mask_path)

In [None]:
# Define the slide name before building any paths from it, so this cell does
# not depend on a `wsi_without_ext` left in the kernel by an earlier cell.
wsi_name = "106S.tif"
wsi_without_ext = os.path.splitext(wsi_name)[0]

cell_points_dir = "/home/u1910100/GitHub/TIAger-Torch/output/det_out"
cell_points_path = os.path.join(cell_points_dir, f"{wsi_without_ext}_points.json")
wsi_dir = "/home/u1910100/Documents/Tiger_Data/wsitils/images/"
wsi_path = os.path.join(wsi_dir, wsi_name)

# Number of raw detections before NMS
with open(cell_points_path, "r") as fp:
    cell_points = json.load(fp)
print(len(cell_points))

# Number of detections remaining after tile-wise NMS
new_points = slide_nms(
    slide_path=wsi_path, annotation_path=cell_points_path, tile_size=2048
)
print(len(new_points))