## Prov-GigaPath Demo

This notebook provides a quick walkthrough of the Prov-GigaPath models. We will start by demonstrating how to download the Prov-GigaPath models from HuggingFace. Next, we will show an example of pre-processing a slide. Finally, we will demonstrate how to run Prov-GigaPath on the sample slide.

### Prepare HF Token

To begin, please request access to the model from our HuggingFace repository: https://huggingface.co/prov-gigapath/prov-gigapath.

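Once your request is approved, set the `HF_TOKEN` environment variable (or make sure your token is cached in `~/.cache/huggingface/token`) so the model can be downloaded.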

In [None]:
import os

# Credentials for the gated Hugging Face repository.
# Either export HF_TOKEN before launching Jupyter, or uncomment the line below:
# os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN"

# Fail early if neither the environment variable nor a cached token file is present.
homedir_path = os.path.expanduser("~")
cached_token_file = f"{homedir_path}/.cache/huggingface/token"
token_in_env = "HF_TOKEN" in os.environ
assert token_in_env or os.path.exists(cached_token_file), (
    "Please set the HF_TOKEN environment variable to your Hugging Face API token or make sure the token is cached in ~/.cache/huggingface/token"
)

In [None]:
import matplotlib.pyplot as plt

import torch
import timm

## Constants

In [None]:
# Root of the project, relative to the directory the notebook runs in.
PROJECT_DIR = ".."

# Local folder (inside PROJECT_DIR) that holds the sample slide and all outputs.
local_dir_name = "sample_data"
local_dir = os.path.join(PROJECT_DIR, local_dir_name)

slide_file_name = "PROV-000-000001.ndpi"
# Path of the slide inside the Hugging Face repository. Repository paths always
# use "/" as separator, so do not build this one with os.path.join (which would
# produce "\" on Windows).
slide_hf_path = f"{local_dir_name}/{slide_file_name}"
# Path of the downloaded slide on the local file system.
slide_path = os.path.join(local_dir, slide_file_name)

# Tiles of a slide are looked up under <tile_save_dir>/output/<slide file name>.
tile_save_dir = os.path.join(local_dir, "outputs", "preprocessing")
specific_slide_tiles_dir = os.path.join(tile_save_dir, "output", slide_file_name)
os.makedirs(specific_slide_tiles_dir, exist_ok=True)

# Tile embeddings / coordinates are stored under <features_save_dir>/output/<slide file name>.
features_save_dir = os.path.join(local_dir, "outputs", "features")
specific_slide_features_dir = os.path.join(features_save_dir, "output", slide_file_name)
os.makedirs(specific_slide_features_dir, exist_ok=True)

### Download a sample slide

In [None]:
import huggingface_hub

# Fetch the sample whole-slide image from the model repository into PROJECT_DIR.
# `force_download` is deliberately left at its default (False): the slide is a
# large file, and re-downloading it on every "Run All" is wasteful when an
# up-to-date copy is already present locally. Pass force_download=True manually
# if the local copy is suspected to be corrupted.
huggingface_hub.hf_hub_download(
    repo_id="prov-gigapath/prov-gigapath",
    filename=slide_hf_path,
    local_dir=PROJECT_DIR,
)

### Tiling

Whole-slide images are giga-pixel in size. To efficiently process these enormous images, we use a tiling technique that divides them into smaller, more manageable tile images. As an example, we demonstrate how to process a single slide.

NOTE: Prov-GigaPath was trained on slides preprocessed at 0.5 MPP (microns per pixel). Make sure the `level` you pass to the tiling function is the pyramid level that corresponds to 0.5 MPP for your slide.

In [None]:
from gigapath.pipeline import tile_one_slide

# Remind the reader about the resolution requirement, then tile the sample slide
# into `tile_save_dir`.
mpp_reminder = "NOTE: Prov-GigaPath is trained with 0.5 mpp preprocessed slides. Please make sure to use the appropriate level for the 0.5 MPP"
print(mpp_reminder)
tile_one_slide(
    slide_path,
    save_dir=tile_save_dir,
    level=1,  # tile_size=256 is the default
)

### Load the tile images

In [None]:
# Gather the full path of every PNG tile written for this slide
# (directory listing order is kept as-is).
image_paths = []
for tile_name in os.listdir(specific_slide_tiles_dir):
    if not tile_name.endswith(".png"):
        continue
    image_paths.append(os.path.join(specific_slide_tiles_dir, tile_name))

print(f"Found {len(image_paths)} image tiles")

## Attempt to match an extracted tile to its real slide coordinates (unsuccessful)

In [None]:
# Pick the tile with the largest x coordinate (ties broken by the largest y) and show it.
# Tile files are expected to be named "<x>x_<y>y.png".
tile_paths = sorted(image_paths)
assert tile_paths, f"No tile images found in {specific_slide_tiles_dir}"

tile_file_names = [os.path.basename(tile_path) for tile_path in tile_paths]


def parse_tile_coordinates(tile_file_name):
    """Return the integer coordinates encoded in a "<x>x_<y>y.png" tile file name."""
    stem = tile_file_name.split(".png")[0]
    return tuple(
        int(part.replace("x", "").replace("y", "")) for part in stem.split("_")
    )


coordinates = [parse_tile_coordinates(name) for name in tile_file_names]

# Remember which file each coordinate pair came from. int() discards any
# zero-padding present in the file name, so rebuilding the name from the
# integers could point at a file that does not exist.
file_name_by_coordinates = dict(zip(coordinates, tile_file_names))

# Tuples compare lexicographically, so a single max() yields the largest x and,
# among the tiles sharing that x, the largest y — no sorting and no second pass needed.
max_coordinates = max(coordinates)
max_x, max_y = max_coordinates[0], max_coordinates[1]

max_tile_file_name = file_name_by_coordinates[max_coordinates]

print(max_tile_file_name)
plt.imshow(plt.imread(os.path.join(specific_slide_tiles_dir, max_tile_file_name)))

In [None]:
import openslide

# maybe we need to offset the coordinates by the min_x, min_y
# (min_x and min_y are taken independently, so they need not belong to the same tile)
min_x = min(coord[0] for coord in coordinates)
min_y = min(coord[1] for coord in coordinates)
print(f"Min coordinates: ({min_x}, {min_y})")

# does not match the max_x, max_y tile extracted above
# NOTE(review): OpenSlide documents `location` as level-0 pixel coordinates, and
# the tiling cell used level=1 while this call reads level 5 — this mismatch is a
# likely reason the region differs from the extracted tile; confirm.

# Open the slide in a context manager so the file handle is released afterwards.
with openslide.OpenSlide(slide_path) as sample_slide:
    # help() prints the documentation itself and returns None, so it must not be
    # wrapped in print() (that would emit a stray "None").
    help(sample_slide.read_region)
    sample_region = sample_slide.read_region((min_x, min_y), 5, (256, 256))

# Bare last expression so the notebook renders the region that was read.
sample_region

### Load the Prov-GigaPath model (tile and slide encoder models)

In [None]:
# Build the tile encoder via timm, pulling the pretrained weights from the
# Hugging Face Hub repository named in the model string.
tile_encoder = timm.create_model(
    "hf_hub:prov-gigapath/prov-gigapath",
    pretrained=True,
)

In [None]:
# from gigapath.pipeline import load_tile_slide_encoder

# Load the tile and slide encoder models
# NOTE: The CLS token is not trained during the slide-level pretraining.
# Here, we enable the use of global pooling for the output embeddings.

# tile_encoder, slide_encoder_model = load_tile_slide_encoder(global_pool=True)

### Run tile-level inference

In [None]:
from gigapath.pipeline import run_inference_with_tile_encoder

# Encode every tile image in batches of 32; the result is a dictionary whose
# values expose a `.shape`.
tile_encoder_outputs = run_inference_with_tile_encoder(
    image_paths,
    tile_encoder,
    batch_size=32,
)

# Report the shape of each returned entry.
for output_name, output_value in tile_encoder_outputs.items():
    print(f"tile_encoder_outputs[{output_name}].shape: {output_value.shape}")

In [None]:
# Display the tile encoder results (a bare last expression is rendered as the cell output).
tile_encoder_outputs

In [None]:
# Persist the tile embeddings and their coordinates as "<key>.pt" files in the
# per-slide features directory, so slide-level inference can reload them later.
for tensor_key in ("tile_embeds", "coords"):
    tensor_file = os.path.join(specific_slide_features_dir, f"{tensor_key}.pt")
    torch.save(tile_encoder_outputs[tensor_key], tensor_file)

In [None]:
# Free up GPU memory before the slide encoder is loaded.
# Drop this notebook's reference to the tile encoder ...
del tile_encoder
# ... then ask PyTorch to release the CUDA memory it is holding in its cache.
# NOTE: re-running this cell without re-creating `tile_encoder` raises NameError.
torch.cuda.empty_cache()

### Run slide-level inference

In [None]:
# Rebuild the tile encoder outputs from the "<key>.pt" files saved earlier
# (same two keys, same order), then display them.
tile_encoder_outputs = {
    tensor_key: torch.load(os.path.join(specific_slide_features_dir, f"{tensor_key}.pt"))
    for tensor_key in ("tile_embeds", "coords")
}

tile_encoder_outputs

In [None]:
import gigapath.slide_encoder as slide_encoder

# Create the slide encoder, loading its weights from the Hugging Face Hub.
slide_encoder_source = "hf_hub:prov-gigapath/prov-gigapath"
slide_encoder_arch = "gigapath_slide_enc12l768d"
slide_encoder_in_chans = 1536  # presumably the tile embedding width — TODO confirm

slide_encoder_model = slide_encoder.create_model(
    slide_encoder_source,
    slide_encoder_arch,
    slide_encoder_in_chans,
    # Global pooling is enabled, as in the commented-out `load_tile_slide_encoder`
    # call earlier in the notebook (its note: the CLS token is not trained during
    # slide-level pretraining).
    global_pool=True,
)

In [None]:
from gigapath.pipeline import run_inference_with_slide_encoder
# Feed the tile-level outputs (passed as keyword arguments) to the slide encoder.
slide_embeds = run_inference_with_slide_encoder(
    slide_encoder_model=slide_encoder_model,
    **tile_encoder_outputs,
)

# Report the shape of each returned embedding.
for embed_name, embed_value in slide_embeds.items():
    print(f"slide_embeds[{embed_name}].shape: {embed_value.shape}")

In [None]:
# Display the "last_layer_embed" entry of the slide encoder output.
slide_embeds["last_layer_embed"]

In [None]:
# from gigapath.slide_encoder function `coords_to_pos` - this is not needed to make it work, but it is useful to understand how the positional embeddings are calculated

# Width of the position grid used to flatten an (x, y) grid cell into one index.
slide_ngrids = 1000

# Convert pixel coordinates to grid cells by dividing by 256 and flooring
# (256 is the default tile size mentioned in the tiling cell).
coords_ = torch.floor(tile_encoder_outputs["coords"] / 256.0)
print(coords_)
# "\n" (not "/n") so that each summary starts on a fresh line.
print("\n coords_.min(axis=0)", coords_.min(axis=0))
print("\n coords_.max(axis=0)", coords_.max(axis=0))


# pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
pos = coords_[..., 0] * slide_ngrids + coords_[..., 1]

# return pos.long() + 1  # add 1 for the cls token
pos.long() + 1