# Demonstration of histomics_stream

Click to open in [[GitHub](https://github.com/DigitalSlideArchive/HistomicsStream/tree/master/example/pytorch.ipynb)] [[Google Colab](https://colab.research.google.com/github/DigitalSlideArchive/HistomicsStream/blob/master/example/pytorch_stream.ipynb)]

The `histomics_stream` Python package sits at the start of a machine learning workflow. This notebook demonstrates it with the PyTorch machine learning library. The package efficiently accesses the input image data that is used either to fit a new machine learning model or to predict regions of interest in novel inputs with an already-trained model.

## Installation

If you are running this notebook on Google Colab, or on another system where `histomics_stream` and its dependencies are not yet installed, you can install them with the following commands. Image readers other than openslide are also supported. For example, use `large_image[bioformats,ometiff,openjpeg,openslide,tiff]` in the `pip install` command below.

In [None]:
# Get histomics_stream and its dependencies
!apt update
!apt install -y python3-openslide openslide-tools
!pip install 'large_image[openslide,tiff]' --find-links https://girder.github.io/large_image_wheels
!pip install histomics_stream

# Get other packages used in this notebook
# N.B. itkwidgets works with jupyter<=3.0.0
!apt install -y libcudnn8 libcudnn8-dev
!pip install pooch itkwidgets
!jupyter labextension install @jupyter-widgets/jupyterlab-manager jupyter-matplotlib jupyterlab-datawidgets itkwidgets

print(
    "\nNOTE!: On Google Colab you may need to choose 'Runtime->Restart runtime' for these updates to take effect."
)

## Fetching and creating the test data
The demonstrations in this notebook use the files `TCGA-AN-A0G0-01Z-00-DX1.svs` (365 MB) and `TCGA-AN-A0G0-01Z-00-DX1.mask.png` (4 kB). The `pooch` commands download them only if they are not already in the local cache.

In [None]:
import os
import pooch

# download whole slide image
wsi_path = pooch.retrieve(
    fname="TCGA-AN-A0G0-01Z-00-DX1.svs",
    url="https://drive.google.com/uc?export=download&id=19agE_0cWY582szhOVxp9h3kozRfB4CvV&confirm=t&uuid=6f2d51e7-9366-4e98-abc7-4f77427dd02c&at=ALgDtswlqJJw1KU7P3Z1tZNcE01I:1679111148632",
    known_hash="d046f952759ff6987374786768fc588740eef1e54e4e295a684f3bd356c8528f",
    path=str(pooch.os_cache("pooch")) + os.sep + "wsi",
)
print(f"Have {wsi_path}")

# download binary mask image
mask_path = pooch.retrieve(
    fname="TCGA-AN-A0G0-01Z-00-DX1.mask.png",
    url="https://drive.google.com/uc?export=download&id=17GOOHbL8Bo3933rdIui82akr7stbRfta",
    known_hash="bb657ead9fd3b8284db6ecc1ca8a1efa57a0e9fd73d2ea63ce6053fbd3d65171",
    path=str(pooch.os_cache("pooch")) + os.sep + "wsi",
)
print(f"Have {mask_path}")
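`pooch.retrieve` compares each downloaded file against its `known_hash`. Without an algorithm prefix, pooch treats that value as a SHA256 digest. The sketch below shows what that check amounts to, using only the standard library. It is useful if you copy the slide in by hand. `file_sha256` and `verify_download` are hypothetical helpers, not part of pooch or histomics_stream.

```python
import hashlib


def file_sha256(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks so a multi-hundred-MB slide never has to
        # fit in memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, known_hash):
    """Raise ValueError if the file's SHA256 digest differs from known_hash."""
    actual = file_sha256(path)
    if actual != known_hash:
        raise ValueError(f"{path}: expected {known_hash}, got {actual}")
    return True
```

For example, `verify_download(mask_path, "bb657ead9fd3b8284db6ecc1ca8a1efa57a0e9fd73d2ea63ce6053fbd3d65171")` should return `True` for an intact mask file.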

## Creating a study for use with histomics_stream

We describe the input and desired parameters using standard Python lists and dictionaries.  Here we give a high-level configuration; selection of tiles is done subsequently.

N.B.: __*all*__ values that are numbers of pixels are based on the `target_magnification` that is supplied to `FindResolutionForSlide`. This covers the pixel sizes of a slide, chunk, or tile, as well as the pixel coordinates of a chunk or tile. It applies both to numbers supplied to histomics_stream and to numbers returned by it. However, if the `magnification_source` is not `exact`, the `returned_magnification` may differ from the `target_magnification`. To get the corresponding number of pixels at the `returned_magnification`, these numbers of pixels are typically multiplied by the ratio `returned_magnification / target_magnification`. In particular, the *pixel size of the returned tiles* will be the requested size times this ratio.
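The ratio rule above can be checked with a little arithmetic. `returned_pixels` is a hypothetical helper. The 40x returned magnification is an invented value used only to show a ratio other than 1. With `magnification_source="exact"`, as used in this notebook, the ratio should be 1.

```python
def returned_pixels(target_pixels, target_magnification, returned_magnification):
    """Convert a pixel count expressed at target_magnification into the
    equivalent pixel count at returned_magnification (hypothetical helper)."""
    ratio = returned_magnification / target_magnification
    return target_pixels * ratio


# Exact match: a 256-pixel tile stays 256 pixels.
print(returned_pixels(256, 20, 20))  # 256.0
# If a 40x level were returned for a 20x request, the tile would be 512 pixels.
print(returned_pixels(256, 20, 40))  # 512.0
```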

In [None]:
import histomics_stream as hs
import histomics_stream.pytorch
import torch

In [None]:
# Create a study and insert study-wide information.
# Add a slide to the study, including slide-wide information with it.
my_study0 = dict(
    version="version-1",
    tile_height=256,
    tile_width=256,
    overlap_height=0,
    overlap_width=0,
    slides=dict(
        Slide_0=dict(
            filename=wsi_path,
            slide_name=os.path.splitext(os.path.split(wsi_path)[1])[0],
            slide_group="Group 3",
            chunk_height=2048,
            chunk_width=2048,
        )
    ),
)

# For each slide, find the appropriate resolution given the target_magnification and
# magnification_tolerance.  In this example, we use the same parameters for each slide,
# but this is not required generally.
find_slide_resolution = hs.configure.FindResolutionForSlide(
    my_study0, target_magnification=20, magnification_source="exact"
)
for slide in my_study0["slides"].values():
    find_slide_resolution(slide)
print(f"my_study0 = {my_study0}")

## Tile selection

We demonstrate several approaches to choosing tiles. Each approach starts with its own deep copy of the `my_study0` dictionary built so far.
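A deep copy matters here because the tile selectors add a `"tiles"` entry inside each nested slide dictionary. The `TilesByList` cell reads it back as `my_study_by_grid["slides"]["Slide_0"]["tiles"]`. With a shallow copy, every approach would write into the same shared slide dictionaries. The miniature study below is a hypothetical stand-in that shows the difference.

```python
import copy

# A hypothetical miniature study with one nested slide dictionary.
base = dict(tile_height=256, slides=dict(Slide_0=dict(chunk_height=2048)))

shallow = copy.copy(base)
deep = copy.deepcopy(base)

# Adding tiles through the shallow copy mutates the shared nested dict...
shallow["slides"]["Slide_0"]["tiles"] = {"tile_0": {}}
print("tiles" in base["slides"]["Slide_0"])  # True: base was changed too
# ...whereas the deep copy owns independent nested dictionaries.
print("tiles" in deep["slides"]["Slide_0"])  # False
```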

In [None]:
import copy

In [None]:
# Demonstrate TilesByGridAndMask without a mask
my_study_by_grid = copy.deepcopy(my_study0)
tiles_by_grid = hs.configure.TilesByGridAndMask(
    my_study_by_grid, overlap_height=32, overlap_width=32, randomly_select=5
)
# We could apply this to a subset of the slides, but we will apply it to all slides in
# this example.
for slide in my_study_by_grid["slides"].values():
    tiles_by_grid(slide)
# Take a look at what we have made
print(f"==== The entire dictionary is now ==== \nmy_study_by_grid = {my_study_by_grid}")
just_tiles = tiles_by_grid.get_tiles(my_study_by_grid)
print(f"==== A quick look at just the tiles is now ====\njust_tiles = {just_tiles}")
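To see how `overlap_height` and `overlap_width` affect a grid, consider a plain-Python sketch. Neighboring tiles on an overlapping grid are spaced by the tile size minus the overlap, and `randomly_select` then keeps a random subset. `grid_tile_origins` is hypothetical and is not histomics_stream's implementation. In particular, the package may handle partial tiles at the slide edges differently.

```python
import random


def grid_tile_origins(slide_height, slide_width, tile_height, tile_width,
                      overlap_height=0, overlap_width=0):
    """Return the (top, left) corner of every full tile on an overlapping grid."""
    stride_h = tile_height - overlap_height
    stride_w = tile_width - overlap_width
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError("overlap must be smaller than the tile size")
    # Only tiles that fit entirely inside the slide are kept in this sketch.
    return [
        (top, left)
        for top in range(0, slide_height - tile_height + 1, stride_h)
        for left in range(0, slide_width - tile_width + 1, stride_w)
    ]


# A hypothetical 1024 x 1024 region, 256-pixel tiles, 32 pixels of overlap.
origins = grid_tile_origins(1024, 1024, 256, 256, overlap_height=32, overlap_width=32)
print(len(origins))  # 16
print(origins[:3])  # [(0, 0), (0, 224), (0, 448)]
# Analogous to randomly_select=5: keep a random subset of the grid.
chosen = random.sample(origins, 5)
```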

In [None]:
# Demonstrate TilesByGridAndMask with a mask
my_study_by_grid_and_mask = copy.deepcopy(my_study0)
tiles_by_grid_and_mask = hs.configure.TilesByGridAndMask(
    my_study_by_grid_and_mask, mask_filename=mask_path, randomly_select=10
)
# We could apply this to a subset of the slides, but we will apply it to all slides in
# this example.
for slide in my_study_by_grid_and_mask["slides"].values():
    tiles_by_grid_and_mask(slide)
# Take a look at what we have made
print(
    f"==== The entire dictionary is now ==== \nmy_study_by_grid_and_mask = {my_study_by_grid_and_mask}"
)
just_tiles = tiles_by_grid_and_mask.get_tiles(my_study_by_grid_and_mask)
print(f"==== A quick look at just the tiles is now ====\njust_tiles = {just_tiles}")
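One way a mask can filter grid tiles is to keep a tile when the fraction of its pixels marked as tissue in the mask meets a threshold. The later cell passes `mask_threshold=0.5`. The sketch below assumes that interpretation of the threshold. It also assumes the mask has already been resampled to the tiles' pixel grid. histomics_stream's actual rule and resampling may differ, and the helper names are hypothetical.

```python
import numpy as np


def tile_mask_fraction(mask, top, left, tile_height, tile_width):
    """Fraction of nonzero mask pixels inside one tile."""
    window = mask[top : top + tile_height, left : left + tile_width]
    return float(np.count_nonzero(window)) / window.size


def select_by_mask(mask, origins, tile_height, tile_width, threshold=0.5):
    """Keep the tile origins whose mask coverage is at least threshold."""
    return [
        (top, left)
        for top, left in origins
        if tile_mask_fraction(mask, top, left, tile_height, tile_width) >= threshold
    ]


# Hypothetical 8x8 mask whose left half is tissue; 4x4 tiles, no overlap.
mask = np.zeros((8, 8), dtype=bool)
mask[:, :4] = True
origins = [(0, 0), (0, 4), (4, 0), (4, 4)]
print(select_by_mask(mask, origins, 4, 4))  # [(0, 0), (4, 0)]
```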

In [None]:
# Demonstrate TilesByList
my_study_by_list = copy.deepcopy(my_study0)
tiles_by_list = hs.configure.TilesByList(
    my_study_by_list,
    randomly_select=5,
    tiles_dictionary=my_study_by_grid["slides"]["Slide_0"]["tiles"],
)
# We could apply this to a subset of the slides, but we will apply it to all slides in
# this example.
for slide in my_study_by_list["slides"].values():
    tiles_by_list(slide)
# Take a look at what we have made
print(f"==== The entire dictionary is now ==== \nmy_study_by_list = {my_study_by_list}")
just_tiles = tiles_by_list.get_tiles(my_study_by_list)
print(f"==== A quick look at just the tiles is now ====\njust_tiles = {just_tiles}")
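`TilesByList` starts from an existing tiles dictionary rather than from a grid. The sketch below shows random subsetting of such a dictionary in plain Python. `subset_tiles` and the tile fields are hypothetical. In this sketch, asking for at least as many tiles as exist keeps them all. That matches the cell above, which requests 5 tiles from the 5 that `my_study_by_grid` selected.

```python
import random


def subset_tiles(tiles_dictionary, randomly_select, seed=None):
    """Return a new dict holding randomly_select entries of tiles_dictionary."""
    rng = random.Random(seed)
    keys = list(tiles_dictionary)
    # Keep everything if no more tiles exist than were requested.
    if randomly_select >= len(keys):
        return dict(tiles_dictionary)
    chosen = rng.sample(keys, randomly_select)
    return {key: tiles_dictionary[key] for key in chosen}


# Hypothetical tiles dictionary keyed by tile name.
tiles = {f"tile_{i}": {"top": 224 * i, "left": 0} for i in range(8)}
subset = subset_tiles(tiles, 3, seed=0)
print(sorted(subset))
```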

In [None]:
# Demonstrate TilesRandomly
my_study_randomly = copy.deepcopy(my_study0)
tiles_randomly = hs.configure.TilesRandomly(my_study_randomly, randomly_select=10)
# We could apply this to a subset of the slides, but we will apply it to all slides in
# this example.
for slide in my_study_randomly["slides"].values():
    tiles_randomly(slide)
# Take a look at what we have made
print(
    f"==== The entire dictionary is now ==== \nmy_study_randomly = {my_study_randomly}"
)
just_tiles = tiles_randomly.get_tiles(my_study_randomly)
print(f"==== A quick look at just the tiles is now ====\njust_tiles = {just_tiles}")
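Random selection can be sketched as drawing each tile's top-left corner uniformly, so that the whole tile lies inside the slide. Unlike grid tiles, such tiles may overlap one another. `random_tile_origins` is hypothetical. Whether histomics_stream snaps random tiles to a grid or forbids overlaps is not stated here.

```python
import random


def random_tile_origins(slide_height, slide_width, tile_height, tile_width,
                        randomly_select, seed=None):
    """Draw (top, left) corners uniformly so each tile lies inside the slide."""
    if tile_height > slide_height or tile_width > slide_width:
        raise ValueError("tile is larger than the slide")
    rng = random.Random(seed)
    # randint is inclusive, so the last valid corner puts the tile flush
    # against the bottom or right edge of the slide.
    return [
        (rng.randint(0, slide_height - tile_height),
         rng.randint(0, slide_width - tile_width))
        for _ in range(randomly_select)
    ]


print(random_tile_origins(4096, 4096, 256, 256, randomly_select=3, seed=0))
```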

## Creating a Dataset

We select the tiles indicated by the mask and create a PyTorch DataLoader. Each of its elements holds the image data for tiles together with each tile's associated parameters, such as its location.

In [None]:
# Demonstrate TilesByGridAndMask with a mask
my_study = copy.deepcopy(my_study0)
tiles_by_grid_and_mask = hs.configure.TilesByGridAndMask(
    my_study, mask_filename=mask_path, mask_threshold=0.5, randomly_select=100
)
for slide in my_study["slides"].values():
    tiles_by_grid_and_mask(slide)
print("Finished selecting tiles.")

create_pytorch_dataloader = hs.pytorch.CreateTorchDataloader()
tiles = create_pytorch_dataloader(my_study)
print("Finished with CreateTorchDataloader")
# print(f"{tiles = }")
# print(f"... with tile shape = {next(iter(tiles))[0].shape}")

## Fetch a model for prediction

We build an arbitrary but reasonable model for demonstration purposes.

Because each element of our Dataset is a tuple `(rgb_image_data, dictionary_of_annotation)`, a typical model that accepts only the former as its input needs to be wrapped.
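The wrapping idea is independent of PyTorch. The wrapper unpacks the pair, gives only the image data to the model, and returns the annotation dictionary unchanged, so each prediction stays paired with its tile. The plain-Python sketch below uses a hypothetical class name and a stand-in "model". The notebook's `WrapModel` does the same thing as a `torch.nn.Module`.

```python
class PassMetadataThrough:
    """Wrap a callable that accepts only image data so that it accepts
    (image_data, annotations) pairs and returns (prediction, annotations)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, element):
        image_data, annotations = element
        # Only the image reaches the wrapped model; the annotations pass through.
        return self.model(image_data), annotations


# Hypothetical stand-in "model": sums the pixel values of a tiny image.
wrapped = PassMetadataThrough(lambda pixels: sum(pixels))
print(wrapped(([1, 2, 3], {"slide_name": "Slide_0"})))
# (6, {'slide_name': 'Slide_0'})
```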

In [None]:
class MyTorchModel(torch.nn.Module):
    def __init__(
        self, in_channels, tile_height, tile_width, num_categories, kernel_size
    ):
        super().__init__()
        print(f"{in_channels = }")
        print(f"{tile_height = }")
        print(f"{tile_width = }")
        print(f"{num_categories = }")
        print(f"{kernel_size = }")
        out1_channels = 2 * in_channels
        # Padding of (k - 1) // 2 preserves height and width for odd kernel sizes
        padding = tuple((k - 1) // 2 for k in kernel_size)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out1_channels, kernel_size, padding=padding
        )
        out2_channels = 4 * in_channels
        self.conv2 = torch.nn.Conv2d(
            out1_channels, out2_channels, kernel_size, padding=padding
        )
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2, 2)
        # Two 2x2 max-pools each halve the height and width before flattening
        self.flat_size = out2_channels * (tile_height // 4) * (tile_width // 4)
        self.fc1 = torch.nn.Linear(self.flat_size, 128)
        self.fc2 = torch.nn.Linear(128, num_categories)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, self.flat_size)
        x = self.relu(self.fc1(x))
        # Return raw scores (logits); no ReLU on the output layer
        return self.fc2(x)


unwrapped_model = MyTorchModel(
    in_channels=3,
    tile_height=my_study["tile_height"],
    tile_width=my_study["tile_width"],
    num_categories=2,
    kernel_size=(5, 5),
)

# At this point it would be standard to train the model.  This demonstration model
# stays untrained, so its predictions are meaningless; we skip training here.


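class WrapModel(torch.nn.Module):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the model that was passed in, not a global variable
        self.model = model

    def forward(self, x):
        # x is a (rgb_image_data, dictionary_of_annotation) pair; only the image
        # goes to the wrapped model and the annotations are passed through.
        p = self.model(x[0])
        return p, x[1]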


model = WrapModel(unwrapped_model)
print("Model created")
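The size of the flattened layer follows from PyTorch's documented output-size formulas. For `Conv2d` with dilation 1, the output is `(size + 2 * padding - kernel) // stride + 1`. `MaxPool2d` uses the same formula with no padding. The helpers below are hypothetical and apply these formulas to the model above. They confirm that padding `(k - 1) // 2` preserves the tile size and that the two pools reduce 256 to 64.

```python
def conv2d_out(size, kernel, padding, stride=1):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool_out(size, kernel=2, stride=2):
    """Output length of a max-pool along one axis (no padding)."""
    return (size - kernel) // stride + 1


def flat_size(in_channels, tile_height, tile_width, kernel_size=(5, 5)):
    """Flattened feature count after conv1/pool and conv2/pool in MyTorchModel."""
    padding = tuple((k - 1) // 2 for k in kernel_size)
    h, w = tile_height, tile_width
    channels = in_channels
    for multiplier in (2, 4):  # conv1 doubles, conv2 quadruples in_channels
        h = maxpool_out(conv2d_out(h, kernel_size[0], padding[0]))
        w = maxpool_out(conv2d_out(w, kernel_size[1], padding[1]))
        channels = multiplier * in_channels
    return channels * h * w


print(flat_size(3, 256, 256))  # 49152 = 12 * 64 * 64
```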

## Make predictions

In [None]:
import time

print("Starting predictions")
# Evaluation mode and no gradient tracking, because we are only predicting
model.eval()
start_time = time.time()
# Consider adding a batch factor to the data loader
predictions = []
num_inputs = 0
with torch.no_grad():
    for tile in tiles:
        predictions.append(model(tile))
        # Each element of the data loader is a batch; count the tiles in it
        num_inputs += tile[0].shape[0]
end_time = time.time()
print("Done predicting")
num_predictions = len(predictions)
print(
    f"Made {num_predictions} batched predictions for {num_inputs} tiles "
    f"in {end_time - start_time} s."
)
print(f"Average of {(end_time - start_time) / num_inputs} s per tile.")

## Look at internals

In [None]:
tile_iter = iter(tiles)
tile = next(tile_iter)
print(f"                       {type(tiles) = }")
print(f"               {type(tiles.dataset) = }")
print(f"         {type(iter(tiles.dataset)) = }")
print(f"                   {type(tile_iter) = }")
print(f"                        {type(tile) = }")
print(f"                         {len(tile) = }")
print(f"                     {type(tile[0]) = }")
print(f"                     {tile[0].shape = }")
print(f"                     {type(tile[1]) = }")
print(f"{tile[0][0,0,0,0].to(torch.float32) = }")
pred = predictions[0]
print(f"                 {type(predictions) = }")
print(f"                  {len(predictions) = }")
print(f"                        {type(pred) = }")
print(f"                         {len(pred) = }")
print(f"                     {type(pred[0]) = }")
print(f"                     {pred[0].shape = }")
print(f"                           {pred[0] = }")
print(f"                     {type(pred[1]) = }")
print(f"                    {pred[1].keys() = }")