Copyright (c) 2025 University of Michigan. All rights reserved.  
Licensed under the MIT License. See LICENSE for license information.

# Notebook for running Mask R-CNN model inference on a subset of the data


In [None]:
from glob import glob

import torch
from torchvision.transforms.functional import adjust_contrast, adjust_brightness

import matplotlib.pyplot as plt

from ds.datasets.db_improc import process_read_srh
from ds.eval.inference import get_model, get_xform
from ds.eval.common import score_threshold_with_matrix_nms, output_mask_to_images

In [None]:
def normalize_im(x, contrast_factor=2, brightness_factor=2):
    """Return a contrast- and brightness-boosted ``torch.uint8`` copy of ``x`` for display.

    Args:
        x: image tensor accepted by torchvision's ``adjust_contrast`` /
            ``adjust_brightness``. NOTE(review): the ``* 255`` scaling assumes
            a float image in [0, 1] (as produced by ``aug``) — confirm; a
            uint8 input would overflow.
        contrast_factor: factor passed to ``adjust_contrast`` (default 2, the
            value previously hard-coded).
        brightness_factor: factor passed to ``adjust_brightness`` (default 2,
            the value previously hard-coded).

    Returns:
        The adjusted image multiplied by 255 and cast to ``torch.uint8``.
    """
    # Contrast is adjusted first, then brightness, matching the original order.
    adjusted = adjust_brightness(adjust_contrast(x, contrast_factor), brightness_factor)
    return (adjusted * 255).to(torch.uint8)

In [None]:
# Class names; the model's num_classes is derived from the length of this list.
# NOTE(review): index 0 ("na") is presumably the background class — confirm.
classes = ["na", "nuclei", "cyto", "rbc", "mp"]
# Trained checkpoint to restore; point this at a real file before running.
ckpt_path = "/path/to/elucidate_model.ckpt"

# Build the model from the checkpoint, and fetch the transform that is applied
# to every raw image before inference.
model = get_model(ckpt_path, num_classes=len(classes))
aug = get_xform()


In [None]:
# Reading images
image_list = [
    "/path/to/patch1.tif",
    "/path/to/patch2.tif"
]

# Or use glob to get all images in a directory
#image_list = glob("/path/to/patches/*.tif")

# Preprocess images
raw_ims = [process_read_srh(i) for i in image_list]
ims = [aug(i, {})[0] for i in raw_ims]

In [None]:
# Inference on image
# Run the model over `ims` in mini-batches and collect one prediction dict per
# image, with every tensor moved to the CPU.
batch_size = 16  # images per forward pass; lower this if device memory runs out

# Send inputs to whichever device holds the model weights instead of
# hard-coding "cuda", so the cell also works on CPU-only machines.
# NOTE(review): assumes `model` is a torch.nn.Module with at least one parameter.
device = next(model.parameters()).device
# Make sure the model is in inference mode; this is a no-op if get_model
# already called eval().
model.eval()

results = []
with torch.inference_mode():
    # Stack one batch at a time rather than the whole dataset up front. This
    # keeps peak memory bounded, and an empty `ims` simply yields no results
    # instead of torch.stack raising on an empty list.
    for start in range(0, len(ims), batch_size):
        batch = torch.stack(ims[start:start + batch_size]).to(device)
        outputs = model(batch)
        # Each output is a dict of tensors; copy every entry to the CPU.
        # No .detach() is needed: nothing is tracked under inference_mode.
        results.extend(
            {key: value.to("cpu") for key, value in out.items()} for out in outputs
        )

# Post-process each prediction with the project's score-threshold + matrix-NMS
# helper, keeping detections at or above the 0.50 confidence cutoff it applies.
results = [
    score_threshold_with_matrix_nms(r, confidence_threshold=0.50)
    for r in results
]

In [None]:
# Visualization
# Render the predicted masks and boxes for each image on top of its
# contrast-enhanced version.
if len(results) != len(ims):
    # Pairing is positional, so a length mismatch means predictions and
    # images no longer line up.
    raise ValueError(
        f"Expected one prediction per image, got {len(results)} predictions "
        f"for {len(ims)} images"
    )

mask_box_images = [
    output_mask_to_images(normalize_im(im), res["masks"], res["boxes"])
    for im, res in zip(ims, results)
]
# Element 0 of each rendering is the mask image and element 1 the box image.
# Both are permuted from channel-first to channel-last for matplotlib.
# NOTE(review): assumes output_mask_to_images returns CHW tensors — confirm.
mask_img = [rendered[0].permute(1, 2, 0) for rendered in mask_box_images]
box_img = [rendered[1].permute(1, 2, 0) for rendered in mask_box_images]

In [None]:
# Show images
# One row of three panels per image: enhanced input, mask overlay, box overlay.
panel_titles = ("Input", "Masks", "Boxes")
for im, mask, box in zip(ims, mask_img, box_img):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (normalize_im(im).permute(1, 2, 0), mask, box)
    for ax, panel, title in zip(axes, panels, panel_titles):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    # Display each figure as it is produced. With the notebook inline backend
    # this also closes it, so figures do not accumulate until the cell ends
    # (matplotlib warns once more than 20 figures are open).
    plt.show()