Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Train hovernet #76

Merged
merged 14 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,45 @@ A toolkit for computational pathology and machine learning.

## Installation

1. Clone repo

````
git clone https://github.com/Dana-Farber/pathml.git
cd pathml
````

2. Set Up Conda Environment

````
conda create --name pathml
conda activate pathml
````

3. Install CUDA. This step only applies if you want to use GPU acceleration for model training or other tasks. This guide should work, but for the most up-to-date instructions, refer to the [official PyTorch installation instructions](https://pytorch.org/get-started/locally/).

- Check the version of CUDA:

````
nvidia-smi
````

- Install correct version of `cudatoolkit`:

````
# update this command with your CUDA version number
conda install cudatoolkit=11.0
````


4. Install PathML

````
git clone https://github.com/Dana-Farber/pathml.git # clone repo
cd pathml # enter repo directory
conda env create -f environment.yml # create conda environment
conda activate pathml # activate conda environment
pip install -e . # install pathml in conda environment
conda env update -f environment.yml # install dependencies
pip install -e . # install pathml
````

>> to verify PyTorch installation with GPU support: `python -c "import torch; print(torch.cuda.is_available())"`

## Generate Documentation

This repo is not yet open to the public. Once we open source it, we will host documentation online.
Expand Down
1 change: 1 addition & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ These notebooks give examples of how to use various ``PathML`` features and deve

examples/link_stain_normalization
examples/link_nucleus_detection
examples/link_train_hovernet
3 changes: 3 additions & 0 deletions docs/source/examples/link_train_hovernet.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../examples/train_hovernet.ipynb"
}
12 changes: 6 additions & 6 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ Models
.. table::
:widths: 20, 20, 60

========= ============ =============
Model Reference Description
========= ============ =============
U-net [Unet]_ A model for segmentation in biomedical images
HoVer-Net [HoVerNet]_ A model for nucleus segmentation and classification in H&E images
========= ============ =============
===================================== ============ =============
Model Reference Description
===================================== ============ =============
U-net (in progress) [Unet]_ A model for segmentation in biomedical images
:class:`~pathml.ml.hovernet.HoVerNet` [HoVerNet]_ A model for nucleus segmentation and classification in H&E images
===================================== ============ =============

You can also use models from `torchvision.models <https://pytorch.org/docs/stable/torchvision/models.html>`_, or create your own!

Expand Down
746 changes: 746 additions & 0 deletions examples/train_hovernet.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pathml/datasets/pannuke.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ def train_dataloader(self):
return data.DataLoader(
dataset = self._get_dataset(fold_ix = self.split),
batch_size = self.batch_size,
shuffle = self.shuffle
shuffle = self.shuffle,
pin_memory=True
)

@property
Expand All @@ -340,7 +341,8 @@ def valid_dataloader(self):
return data.DataLoader(
self._get_dataset(fold_ix = fold_ix),
batch_size = self.batch_size,
shuffle = self.shuffle
shuffle = self.shuffle,
pin_memory=True
)

@property
Expand All @@ -356,5 +358,6 @@ def test_dataloader(self):
return data.DataLoader(
self._get_dataset(fold_ix = fold_ix),
batch_size = self.batch_size,
shuffle = self.shuffle
shuffle = self.shuffle,
pin_memory=True
)
82 changes: 30 additions & 52 deletions pathml/ml/hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def forward(self, inputs):
return out


class HoverNet(nn.Module):
class HoVerNet(nn.Module):
"""
Model for simultaneous segmentation and classification based on HoVer-Net.
Can also be used for segmentation only, if class labels are not supplied.
Expand Down Expand Up @@ -426,6 +426,11 @@ def _get_gradient_hv(hv_batch, kernel_size=5):
"""
assert hv_batch.shape[1] == 2, f"inputs have shape {hv_batch.shape}. Expecting tensor of shape (B, 2, H, W)"
h_kernel, v_kernel = get_sobel_kernels(kernel_size, dt = hv_batch.dtype)

# move kernels to same device as batch
h_kernel = h_kernel.to(hv_batch.device)
v_kernel = v_kernel.to(hv_batch.device)

# add extra dims so we can convolve with a batch
h_kernel = h_kernel.unsqueeze(0).unsqueeze(0)
v_kernel = v_kernel.unsqueeze(0).unsqueeze(0)
Expand All @@ -436,7 +441,10 @@ def _get_gradient_hv(hv_batch, kernel_size=5):

h_grad = F.conv2d(h_inputs, h_kernel, stride = 1, padding = 2)
v_grad = F.conv2d(v_inputs, v_kernel, stride = 1, padding = 2)


del h_kernel
del v_kernel

return h_grad, v_grad


Expand Down Expand Up @@ -478,7 +486,7 @@ def _loss_hv_mse(hv_out, true_hv):
return loss


def loss_HoVerNet(outputs, ground_truth, n_classes=None):
def loss_hovernet(outputs, ground_truth, n_classes=None):
"""
Compute loss for HoVer-Net.
Equation (1) in Graham et al.
Expand All @@ -505,7 +513,6 @@ def loss_HoVerNet(outputs, ground_truth, n_classes=None):
Medical Image Analysis, 58, p.101563.
"""
true_mask, true_hv = ground_truth
true_hv = true_hv.float()
# unpack outputs, and also calculate nucleus masks
if n_classes is None:
np_out, hv = outputs
Expand Down Expand Up @@ -574,8 +581,8 @@ def _post_process_single_hovernet(np_out, hv_out, small_obj_size_thresh=10, kern
https://github.com/vqdang/hover_net/blob/14c5996fa61ede4691e87905775e8f4243da6a62/models/hovernet/post_proc.py#L27

Args:
np_out: Output of NP branch. Tensor of shape (2, H, W) of logit predictions for binary classification
hv_out: Output of HV branch. Tensor of shape (2, H, W) of predictions for horizontal/vertical maps
np_out (torch.Tensor): Output of NP branch. Tensor of shape (2, H, W) of logit predictions for binary classification
hv_out (torch.Tensor): Output of HV branch. Tensor of shape (2, H, W) of predictions for horizontal/vertical maps
small_obj_size_thresh (int): Minimum number of pixels in regions. Defaults to 10.
kernel_size (int): Width of Sobel kernel used to compute horizontal and vertical gradients.
h (float): hyperparameter for thresholding nucleus probabilities. Defaults to 0.5.
Expand All @@ -584,7 +591,7 @@ def _post_process_single_hovernet(np_out, hv_out, small_obj_size_thresh=10, kern
"""
# compute pixel probabilities from logits, apply threshold, and get into np array
np_preds = F.softmax(np_out, dim = 0)[1, :, :]
np_preds = np_preds.detach().clone().numpy()
np_preds = np_preds.numpy()

np_preds[np_preds >= h] = 1
np_preds[np_preds < h] = 0
Expand All @@ -596,8 +603,9 @@ def _post_process_single_hovernet(np_out, hv_out, small_obj_size_thresh=10, kern
tau_q_h = np_preds

# normalize hv predictions, and compute horizontal and vertical gradients, and normalize again
h_out = hv_out[0, ...].detach().clone().numpy().astype(np.float32)
v_out = hv_out[1, ...].detach().clone().numpy().astype(np.float32)
hv_out = hv_out.numpy().astype(np.float32)
h_out = hv_out[0, ...]
v_out = hv_out[1, ...]
# https://stackoverflow.com/a/39037135
h_normed = cv2.normalize(h_out, None, alpha = 0, beta = 1, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
v_normed = cv2.normalize(v_out, None, alpha = 0, beta = 1, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
Expand Down Expand Up @@ -657,25 +665,31 @@ def post_process_batch_hovernet(outputs, n_classes, small_obj_size_thresh=10, ke
segmentation. Defaults to 0.5.

Returns:
If n_classes is None, returns det_out. In classification setting, returns (det_out, class_out).
np.ndarray: If n_classes is None, returns det_out. In classification setting, returns (det_out, class_out).

- det_out is a Tensor of shape (B, H, W)
- class_out is a Tensor of shape (B, n_classes, H, W)
- det_out is np.ndarray of shape (B, H, W)
- class_out is np.ndarray of shape (B, n_classes, H, W)

Each pixel is labelled from 0 to n, where n is the number of individual nuclei detected. 0 pixels indicate
background. Pixel values i indicate that the pixel belongs to the ith nucleus.
"""

assert len(outputs) in {2, 3}, f"outputs has size {len(outputs)}. Must have size 2 (for segmentation) or 3 (for " \
f"classification)"

if n_classes is None:
np_out, hv_out = outputs
# send ouputs to cpu
np_out = np_out.detach().cpu()
hv_out = hv_out.detach().cpu()
classification = False
else:
assert len(outputs) == 3, f"n_classes={n_classes} but outputs has {len(outputs)} elements. Expecting a list " \
f"of length 3, one for each of np, hv, and nc branches"
np_out, hv_out, nc_out = outputs
# send ouputs to cpu
np_out = np_out.detach().cpu()
hv_out = hv_out.detach().cpu()
nc_out = nc_out.detach().cpu()
classification = True

batchsize = hv_out.shape[0]
Expand All @@ -691,12 +705,13 @@ def post_process_batch_hovernet(outputs, n_classes, small_obj_size_thresh=10, ke
# get the pixel-level class predictions from the logits
nc_out_preds = F.softmax(nc_out, dim = 1).argmax(dim = 1)

out_classification = np.zeros_like(nc_out, dtype = np.uint8)
out_classification = np.zeros_like(nc_out.numpy(), dtype = np.uint8)

for batch_ix, nuc_preds in enumerate(out_detection_list):
# get labels of nuclei from nucleus detection
nucleus_labels = list(np.unique(nuc_preds))
nucleus_labels.remove(0) # 0 is background
if 0 in nucleus_labels:
nucleus_labels.remove(0) # 0 is background
nucleus_class_preds = nc_out_preds[batch_ix, ...]

out_class_preds_single = out_classification[batch_ix, ...]
Expand Down Expand Up @@ -769,40 +784,3 @@ def _vis_outputs_single(images, preds, n_classes, index=0, ax=None, markersize=5
x, y = segmentation_lines(nuclei_mask.astype(np.uint8))
ax.scatter(x, y, color = palette[i], marker = ".", s = markersize)
ax.axis("off")


def vis_outputs(images, preds, n_classes, n_images, markersize=5, palette=None):
"""
Plot the results of HoVer-Net predictions for multiple images in a batch, overlayed over
original images.

Args:
images: Input RGB image batch. Tensor of shape (B, 3, H, W).
preds: Postprocessed outputs of HoVer-Net. From post_process_batch_hovernet(). Each pixel should be either 0
for background or an integer n indicating that the pixel is part of the nth nucleus. Can be either:
- Tensor of shape (B, H, W), in the context of nucleus detection.
- Tensor of shape (B, n_classes, H, W), in the context of nucleus classification.
n_classes (int): Number of classes for classification setting, or None to indicate detection setting.
n_images (int): number of images to plot. Must be a multiple of 4.
markersize: Size of markers used to outline nuclei
palette (list): list of colors to use for plotting. If None, uses matplotlib.colors.TABLEAU_COLORS.
Defaults to None
"""
if palette is None:
palette = list(TABLEAU_COLORS.values())

if n_classes is not None:
assert n_classes == preds.shape[1], f"preds dimension {preds.shape[1]} doesn't match n_classes {n_classes}"
assert len(palette) >= n_classes, f"len(palette)={len(palette)} < n_classes={n_classes}."

assert len(preds.shape) in [3, 4], f"Preds shape is {preds.shape}. Must be (B, H, W) or (B, n_classes, H, W)"
assert n_images <= images.shape[0], f"input n_images {n_images} is larger than batch size {images.shape[0]}"
assert n_images % 4 == 0, f"input n_images {n_images} must be a multiple of 4"

nr = int(np.ceil(n_images / 4))
fig, ax = plt.subplots(nrows = nr, ncols = 4, figsize = (10, 4 * nr))

for i, ax in enumerate(ax.ravel()):
_vis_outputs_single(images, preds, n_classes, ax = ax, index = i, markersize = markersize, palette = palette)

plt.tight_layout()
63 changes: 63 additions & 0 deletions pathml/ml/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Utilities for ML module
import torch
from torch.nn import functional as F
import numpy as np


def center_crop_im_batch(batch, dims, batch_order = "BCHW"):
Expand Down Expand Up @@ -76,6 +77,31 @@ def dice_loss(true, logits, eps=1e-3):
return loss


def dice_score(pred, truth, eps=1e-3):
"""
Calculate dice score for two tensors of the same shape.
If tensors are not already binary, they are converted to bool by zero/non-zero.

Args:
pred (np.ndarray): Predictions
truth (np.ndarray): ground truth
jacob-rosenthal marked this conversation as resolved.
Show resolved Hide resolved
eps (float, optional): Constant used for numerical stability to avoid divide-by-zero errors. Defaults to 1e-3.

Returns:
float: Dice score
"""
assert isinstance(truth, np.ndarray) and isinstance(pred, np.ndarray), \
f"pred is of type {type(pred)} and truth is type {type(truth)}. Both must be np.ndarray"
assert pred.shape == truth.shape, f"pred shape {pred.shape} does not match truth shape {truth.shape}"
# turn into binary if not already
pred = pred != 0
truth = truth != 0

num = 2 * np.sum(pred.flatten() * truth.flatten())
denom = np.sum(pred) + np.sum(truth) + eps
return float(num / denom)


def get_sobel_kernels(size, dt=torch.float32):
"""
Create horizontal and vertical Sobel kernels for approximating gradients
Expand All @@ -96,3 +122,40 @@ def get_sobel_kernels(size, dt=torch.float32):

return kernel_h, kernel_v


def wrap_transform_multichannel(transform):
"""
Wrapper to make albumentations transform compatible with a multichannel mask.
Channel should be in first dimension, i.e. (n_mask_channels, H, W)

Args:
transform: Albumentations transform. Must have 'additional_targets' parameter specified with
a total of `n_channels` key,value pairs. All values must be 'mask' but the keys don't matter.
e.g. for a mask with 3 channels, you could use:
`additional targets = {'mask1' : 'mask', 'mask2' : 'mask', 'pathml' : 'mask'}`

Returns:
function that can be called with a multichannel mask argument
"""
targets = transform.additional_targets
n_targets = len(targets)

# make sure that everything is correct so that transform is correctly applied
assert all([v == "mask" for v in targets.values()]), \
f"error all values in transform.additional_targets must be 'mask'."

def transform_out(*args, **kwargs):
mask = kwargs.pop("mask")
assert mask.ndim == 3, f"input mask shape {mask.shape} must be 3-dimensions ()"
assert mask.shape[0] == n_targets, \
f"input mask shape {mask.shape} doesn't match additional_targets {transform.additional_targets}"

mask_to_dict = {key : mask[i, :, :] for i, key in enumerate(targets.keys())}
kwargs.update(mask_to_dict)
out = transform(*args, **kwargs)
mask_out = np.stack([out.pop(key) for key in targets.keys()], axis=0)
assert mask_out.shape == mask.shape
out["mask"] = mask_out
return out

return transform_out
32 changes: 32 additions & 0 deletions pathml/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TABLEAU_COLORS

from pathml.preprocessing.utils import segmentation_lines


def plot_segmentation(ax, masks, palette=None, markersize=5):
jacob-rosenthal marked this conversation as resolved.
Show resolved Hide resolved
"""
Plot segmentation contours. Supports multi-class masks.

Args:
ax: matplotlib axis
masks (np.ndarray): Mask array of shape (n_masks, H, W). Zeroes are background pixels.
palette: color palette to use. if None, defaults to matplotlib.colors.TABLEAU_COLORS
markersize (int): Size of markers used on plot. Defaults to 5
"""
assert masks.ndim == 3
n_channels = masks.shape[0]

if palette is None:
palette = list(TABLEAU_COLORS.values())

nucleus_labels = list(np.unique(masks))
if 0 in nucleus_labels:
nucleus_labels.remove(0) # background
# plot each individual nucleus
for label in nucleus_labels:
for i in range(n_channels):
nuclei_mask = masks[i, ...] == label
x, y = segmentation_lines(nuclei_mask.astype(np.uint8))
ax.scatter(x, y, color = palette[i], marker = ".", s = markersize)
Loading