# Working with medical images in Meerkat

To motivate Meerkat, let's consider the task of detecting pneumothorax (*i.e.* a collapsed lung) in chest X-rays ([Irvin *et al.*](https://arxiv.org/pdf/1901.07031.pdf), [Taylor *et al.*](https://journals.plos.org/plosmedicine/article?id=10.1371/journal.pmed.1002697)). As we develop a model for this task, we encounter data of different types – from X-ray images to structured metadata to embeddings extracted from a trained model. Meerkat provides the `DataPanel`, a columnar data structure (similar to a Pandas [DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html)) that can house all of these data under one roof. Keeping them together enables quicker, more adventurous model iteration, fine-grained error analysis, and easier data exploration and inspection.

**Time**: ~20 minutes

**Colab Runtime**: We recommend running this Colab with a GPU runtime. To change the runtime:
1. Click on `Runtime` on the top navigation bar
2. Select `Change runtime type`
3. Select `GPU` from the dropdown

# Setup

In [None]:
!pip install -q meerkat-ml[medimg,text]
!pip install kaggle
!pip install -q torchxrayvision
!pip install umap-learn
!python3 -m spacy download en_core_web_sm

import meerkat.version as mversion
import torch
print("meerkat version: ", mversion.__version__)
print("torch version: ", torch.__version__)

In [None]:
%load_ext autoreload
%autoreload 2

import os
import meerkat as mk
import numpy as np
import matplotlib.pyplot as plt
import torchxrayvision as xrv

%matplotlib inline

# Uncomment the lines below to see what's going on under the hood
# import logging
# logging.getLogger("meerkat").setLevel(logging.INFO)

## 💾 Downloading the data
We'll be using the dataset from the [SIIM-ACR Pneumothorax Segmentation Challenge](https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/data) (`meerkat.contrib.siim_cxr` provides utility functions for downloading the data). The downloaded dataset includes the inputs, a large number of chest X-ray files stored in [DICOM](https://www.dicomstandard.org/) format, and the targets, a CSV file mapping each file to its binary pneumothorax label.
- Download time: ~2 minutes
- Download size: 2.0 GB

In [None]:
from meerkat.contrib.siim_cxr import download_siim_cxr

# Read Kaggle credentials from the environment rather than hard-coding them.
download_siim_cxr(
    "./",
    kaggle_username=os.environ["KAGGLE_USERNAME"],
    kaggle_key=os.environ["KAGGLE_KEY"],
)

## 🔨 Building a `DataPanel`

In [None]:
dp = mk.DataPanel.from_csv("siim_cxr.csv")
dp.head()

So far, the `DataPanel` isn't providing anything we couldn't get with a Pandas `DataFrame`, because the columns in the CSV contain only strings and numbers.

Things get interesting when we start adding columns for objects that don't play nicely with Pandas: images, text, time-series, videos, and multi-dimensional arrays. Out of the box, Meerkat comes with a number of common column types, including `ImageColumn` for images, `VideoColumn` for videos, `NumpyArrayColumn` for (potentially multi-dimensional) NumPy `ndarray`s, and `TensorColumn` for PyTorch Tensors (see [here](https://github.com/robustness-gym/meerkat/blob/dev/README.md#supported-columns) for a full list of core columns).

To house the X-rays in the dataset, we'll be using the `MedicalVolumeColumn`, a column type similar to `ImageColumn` but optimized for medical images stored in [DICOM format](https://www.dicomstandard.org/). 

In [None]:
# Make a column of MedicalVolumeCells
from dosma import DicomReader
from meerkat.contrib.siim_cxr import cxr_transform, cxr_transform_pil

loader = DicomReader(group_by=None, default_ornt=("SI", "AP"))
dp["img"] = mk.MedicalVolumeColumn.from_filepaths(
    dp["filepath"], loader=loader, transform=cxr_transform_pil
)

## 📄 Adding in metadata

In [None]:
def unroll_metadata(dp):
    return dp["img"].get_metadata(
        as_raw_type=True,
        readable=True,
        ignore_bytes=True,
        force_load=True,
    )

dp = dp.update(unroll_metadata, materialize=False, pbar=True)
dp.head()

### 💫 Computing model predictions and activations
We'd like to perform inference and extract:
  
1. Output predictions  
2. Output class probabilities  
3. Model activations 

Note: to extract model activations, we'll use a [PyTorch forward hook](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks) and register it on the feature extractor (`model.features`) of the DenseNet. A forward hook is just a function that gets executed on the forward pass of a `torch.nn.Module`.
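To make the hook mechanism concrete, here is a minimal sketch on a toy network. The names `toy_model`, `captured`, and `save_activation` are hypothetical and stand in for the real X-ray model; only `register_forward_hook` and the hook signature come from PyTorch itself.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model (hypothetical, for illustration only).
toy_model = nn.Sequential(
    nn.Linear(4, 8),  # plays the role of the feature extractor
    nn.ReLU(),
    nn.Linear(8, 2),  # plays the role of the classifier head
)

captured = {}

def save_activation(module, inputs, output):
    # PyTorch calls this right after `module.forward`; keep a detached copy.
    captured["features"] = output.detach()

# `register_forward_hook` returns a handle that can remove the hook later.
handle = toy_model[0].register_forward_hook(save_activation)

with torch.no_grad():
    logits = toy_model(torch.randn(3, 4))

handle.remove()  # stop capturing once the activations are no longer needed

print(logits.shape)                # torch.Size([3, 2])
print(captured["features"].shape)  # torch.Size([3, 8])
```

Storing the activation in a dictionary instead of a module-level global keeps the hook self-contained, and calling `handle.remove()` avoids leaving the hook attached by accident.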

In [None]:
# 1. Load the pretrained DenseNet and move it to the GPU
model = xrv.models.DenseNet(weights="chex").to("cuda")

class_to_idx = {
    label: idx for idx, label in 
    enumerate(xrv.models.model_urls["chex"]["labels"])
}
model.eval()

# 2. Register the forward hook
embedding = None
def forward_hook(module, input, output):
  global embedding
  embedding = output

model.features.register_forward_hook(forward_hook)

In [None]:
import torchvision.transforms as transforms

transform = transforms.Compose([
  transforms.Lambda(lambda x: np.array(cxr_transform_pil(x))),
  transforms.Lambda(lambda x: xrv.datasets.normalize(x, 255)[None, :, :]),
  xrv.datasets.XRayCenterCrop(),
  xrv.datasets.XRayResizer(224), 
  transforms.Lambda(lambda x: torch.tensor(x)),
])

dp["input"] = mk.MedicalVolumeColumn.from_filepaths(
    dp["filepath"], loader=loader, transform=transform
)

In [None]:
import torch

@torch.no_grad()
def predict(batch: mk.DataPanel):
  global embedding
  x = batch["input"].data.to("cuda")
  out = model(x)  # Forward pass; the hook stores the feature map in `embedding`

  return {
    "output": mk.ClassificationOutputColumn(probs=out.cpu(), multi_label=True),
    # Average the feature map over its two spatial dimensions
    "embedding": mk.EmbeddingColumn(embedding.mean(dim=[-1, -2]).cpu()),
  }

dp = dp.update(
  function=predict, is_batched_fn=True, batch_size=16,
  num_workers=2, pbar=True, input_columns=["input"] 
)
dp.head()

In [None]:
from sklearn.metrics import roc_auc_score

# AUROC of the predicted pneumothorax probability against the binary label
pmx_probs = dp["output"].probabilities().data[:, class_to_idx["Pneumothorax"]]
roc_auc_score(dp["pmx"].data, pmx_probs)

In [None]:
umap = dp["embedding"].umap()

dp["umap_0"] = umap.embeddings[:, 0]
dp["umap_1"] = umap.embeddings[:, 1]

In [None]:
import seaborn as sns
plt.figure(figsize=(4,4))
sns.scatterplot(
    data=dp.lz[:1000].to_pandas(), 
    x="umap_0", 
    y="umap_1", 
    hue="Patient's Sex",
    #alpha=0.05
)
sns.despine()
# plt.savefig("fig.png")

### 📄 Radiologist reports (`SpacyColumn`)

In pneumothorax detection, as in other classification tasks, the binary label does not capture all of the nuance in the X-ray. Radiologists communicate that additional detail via natural language radiology reports that accompany each scan. For example, a sentence in a chest X-ray report may read "A medial pneumothorax is present adjacent to the heart." Increasingly, these reports are playing a starring role in machine learning for medical imaging. The reports are used to extract weak labels ([Dunnmon & Ratner, et al.](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7413132/), [Eyuboglu *et al.*](https://www.nature.com/articles/s41467-021-22018-1)) and perform contrastive learning on paired images and reports ([Zhang *et al.*](https://arxiv.org/pdf/2010.00747.pdf)). With Meerkat, we can store the radiology reports right alongside the X-rays in the same `DataPanel`. This allows us to experiment with multi-modal learning techniques without re-engineering our data pipelines.

Additionally, we can use the accompanying radiology reports to select critical subsets of the data and compute subgroup accuracy. For instance, say we're interested in the performance of our model on "severe" pneumothorax. Because the radiology reports are stored in a `SpacyColumn`, a column that holds preprocessed (*e.g.* tokenized) natural language data, it's easy to write a function `is_severe` that accepts a row as input and returns `True` if the X-ray exhibits pneumothorax and the words "pneumothorax" and "severe" appear in the same sentence.

In [None]:
dp["report_doc"] = mk.SpacyColumn.from_texts(dp["report"])

In [None]:
def is_severe(row: mk.DataPanel):
	""" Return `True` if the X-ray exhibits pneumothorax and it is described 
	as severe in the report (according to a simple rule-based heuristic)."""
	if row["pmx"] != 1:
		return False
	for sent in row["report_doc"].sents:
		if "pneumothorax" in str(sent) and "severe" in str(sent):
			return True
	return False

severe_dp = dp.filter(
	function=is_severe, is_batched_fn=False, input_columns=["report_doc", "pmx"], pbar=True
)

print(f"There are {len(severe_dp)} X-rays exhibiting severe pneumothorax.") 
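The same-sentence rule can be tried on plain strings without spaCy. The sketch below is only an approximation: `mentions_severe_pneumothorax` is a hypothetical helper, sentences are split naively on punctuation rather than with spaCy's `sents`, and the text is lowercased first, which the `is_severe` function above does not do.

```python
import re

def mentions_severe_pneumothorax(report: str) -> bool:
    """Return True if 'pneumothorax' and 'severe' co-occur in one sentence.

    Sentences are split naively on '.', '!' and '?'; a real pipeline would
    rely on a proper sentence segmenter.
    """
    for sentence in re.split(r"[.!?]", report.lower()):
        if "pneumothorax" in sentence and "severe" in sentence:
            return True
    return False

# Both words in the same sentence.
print(mentions_severe_pneumothorax(
    "Severe pneumothorax on the left. Heart size is normal."
))  # True

# Both words present, but in different sentences.
print(mentions_severe_pneumothorax(
    "Small pneumothorax. Severe scoliosis is noted."
))  # False
```

The second report shows why the rule checks sentence by sentence: a document-level keyword match would flag it even though "severe" describes a different finding.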

## 👓  Radiologist eye-tracking data (`GazeSequenceCell` and `CellColumn`)  

Our chest X-ray dataset includes an exciting, non-standard data modality, eye-tracking time-series, for which we'll implement a custom column. We have access to this data because a subset of the X-rays in the dataset were labeled by radiologists while their gaze was being recorded by an eye-tracker. This gaze signal can provide additional supervision when training a model or can be used to slice the dataset during evaluation.

Meerkat does **not** ship with a column type for eye-tracking data, so we'll have to write our own. In Meerkat, the easiest way to implement a new column is to use the `CellColumn` abstraction. The advantage of using `CellColumn` (or one of its subclasses) is that we can support new data types without dealing with the implementation complexity of a full column. Instead, we can think in terms of the individual elements in the column: the cells. We implement a cell by subclassing `AbstractCell` and adding functionality specific to the new data type.

In [None]:
# The gaze data is stored in JSON format
import json

with open("cxr_gaze_data.json", "rb") as f:
    gaze_data = json.load(f)

Below, we provide a simple implementation of a new cell type, `GazeSequenceCell`, that houses a sequence of eye-tracking coordinates. In addition to adding `__repr__` and `_state_keys` methods, useful for column inspection and serialization respectively, we implement the utility method `to_heatmap`, which produces a NumPy array representing the amount of time the radiologist's gaze fell on each patch of the image.

In [None]:
from typing import Sequence

class GazeSequenceCell(mk.AbstractCell):

  def __init__(self, gaze_x: Sequence, gaze_y: Sequence, time: Sequence):
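    """
    Args:
        gaze_x (Sequence): Horizontal gaze coordinates, one per sample.
        gaze_y (Sequence): Vertical gaze coordinates, one per sample.
        time (Sequence): Time value for each sample; `to_heatmap` sums
            these per patch.

    Coordinates are expected to be normalized to [0, 1), since
    `to_heatmap` bins them with `floor(coordinate * num_patches)`.
    """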
    self.gaze_coordinates = np.array([gaze_y, gaze_x])
    self.time = np.array(time)
  
  def get(self):
    return self
  
  def to_heatmap(self, num_patches: int = 16) -> np.ndarray:
    """ Convert the sequence to a heatmap showing the cumulative
    duration that the gaze fell on each patch of the image.  
    Args:
      num_patches (int): split the image into `num_patches` x `num_patches`
        patches.
    Returns:
      np.ndarray: an array with shape (num_patches, num_patches) where
        each entry is the cumulative gaze time on the corresponding patch.
    """
    heatmap = np.zeros(num_patches * num_patches)
    patches = (
        np.floor(self.gaze_coordinates[0] * num_patches) * num_patches + 
        np.floor(self.gaze_coordinates[1] * num_patches)
    )
    np.add.at(heatmap, patches.astype(int), self.time)
    return heatmap.reshape(num_patches, num_patches)
  
  def __repr__(self):
      return f"GazeSequence(length={self.gaze_coordinates.shape[-1]})"

  @classmethod
  def _state_keys(cls):
      return {"gaze_coordinates", "time"}
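To see how the patch binning works, here is a small worked example in plain Python. It is a sketch rather than the cell's implementation: `gaze_heatmap` is a hypothetical helper, and it assumes coordinates normalized to [0, 1), which is what the floor-based indexing in `to_heatmap` implies.

```python
import math

def gaze_heatmap(gaze_x, gaze_y, durations, num_patches=2):
    """Accumulate gaze durations on a num_patches x num_patches grid.

    Assumes coordinates normalized to [0, 1): the y coordinate picks the
    row and the x coordinate picks the column.
    """
    heatmap = [[0.0] * num_patches for _ in range(num_patches)]
    for x, y, t in zip(gaze_x, gaze_y, durations):
        row = math.floor(y * num_patches)
        col = math.floor(x * num_patches)
        heatmap[row][col] += t
    return heatmap

heatmap = gaze_heatmap(
    gaze_x=[0.1, 0.2, 0.9],
    gaze_y=[0.1, 0.3, 0.8],
    durations=[0.5, 0.25, 1.0],
)
print(heatmap)  # [[0.75, 0.0], [0.0, 1.0]]
```

The first two samples land in the top-left patch, so their durations add up, while the third falls in the bottom-right patch. A coordinate of exactly 1.0 would index past the grid, so real data may need clipping first.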

We create a full column by instantiating a `GazeSequenceCell` for each X-ray and passing them into a new `CellColumn`. Because we only have gaze data for a subset of the X-rays in the dataset, we store the gaze sequences in a new `DataPanel` alongside their corresponding `"image_id"` and then perform a database-style join (via `mk.merge`) to combine the original `DataPanel` with the gaze data.
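For readers less familiar with database joins, the following plain-Python sketch shows what an inner join on `"image_id"` does. The rows and the `inner_join` helper are hypothetical and only illustrate the semantics; they are not how Meerkat implements the merge.

```python
# Hypothetical rows standing in for the two tables.
xrays = [
    {"image_id": "a", "pmx": 0},
    {"image_id": "b", "pmx": 1},
    {"image_id": "c", "pmx": 1},
]
gaze = [
    {"image_id": "b", "gaze": "GazeSequence(length=120)"},
    {"image_id": "c", "gaze": "GazeSequence(length=87)"},
    {"image_id": "z", "gaze": "GazeSequence(length=5)"},
]

def inner_join(left, right, on):
    """Keep only rows whose key appears in both tables (keys assumed unique)."""
    right_by_key = {row[on]: row for row in right}
    return [
        {**row, **right_by_key[row[on]]}
        for row in left
        if row[on] in right_by_key
    ]

joined = inner_join(xrays, gaze, on="image_id")
print([row["image_id"] for row in joined])  # ['b', 'c']
```

X-ray `a` has no gaze recording and gaze record `z` has no matching X-ray, so both are dropped. This is why the merged panel contains only the subset of X-rays with eye-tracking data.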

In [None]:
image_ids, cells = zip(*[
    (row["image_id"], GazeSequenceCell(row["gaze_x"], row["gaze_y"], row["time"])) 
    for row in gaze_data
])
gaze_dp = mk.DataPanel.from_batch({
    "gaze": mk.CellColumn.from_cells(cells),
    "image_id": mk.NumpyArrayColumn(image_ids)
})
gaze_dp = mk.merge(dp, gaze_dp, how="inner", on="image_id")

In [None]:
NUM_PATCHES = 16
row = gaze_dp[4]
heatmap = row["gaze"].to_heatmap(num_patches=NUM_PATCHES)
height, width = np.array(row["img"]).shape
plt.imshow(row["img"], cmap="gray")
plt.imshow(
    # `repeat` needs integer counts, so use floor division
    heatmap.repeat(height // NUM_PATCHES, axis=0).repeat(width // NUM_PATCHES, axis=1),
    alpha=0.4
)

In [None]:
dp[["image_id", "pmx", "filepath", "Patient's Age", "Patient's Sex", "img", "output", "embedding", "umap_0", "umap_1", "report_doc"]]

In [None]:
gaze_dp["patient_age"] = np.array(gaze_dp["Patient's Age"])
gaze_dp["patient_sex"] = np.array(gaze_dp["Patient's Sex"])

In [None]:
gaze_dp[["image_id", "pmx", "filepath", "patient_age", "patient_sex", "img", "output", "embedding", "umap_0", "umap_1", "report_doc", "gaze"]]

## ✂️ Segmentations
Segmentations are useful for systematically communicating regions of interest (ROIs) in an image. These annotations can also help with standardized reporting and comparison of quantitative values. However, segmentations can be quite expensive to collect and difficult to interact with dynamically.

Meerkat simplifies the storage of, and dynamic interaction with, these visual labels. For example, we can use these segmentations to compute quantitative metrics, such as the area of an ROI. We can also visually compare the segmentations with gaze heatmaps to qualitatively inspect how well gaze data can serve as a proxy for segmentations.

In [None]:
import cv2

def rle2mask(rle, orig_dim, resize_dim = None, to_nan: bool = False):
  """Convert run length encoding (RLE) to 2D binary mask.

  Args:
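    rle (str): Space-separated run-length encoding, given as alternating
      (offset, length) values where each offset is relative to the end of
      the previous run.
    orig_dim (Tuple[int]): Shape of the image.
    resize_dim (Tuple[int]): Shape to resize to.
      Resizing is done with cubic interpolation.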
    to_nan (bool, optional): Convert 0s to np.nan.

  Returns:
    np.ndarray: The binary mask.
  """
  height, width = orig_dim
  mask = np.zeros(width * height)
  array = np.asarray([int(x) for x in rle.split()])
  starts = array[0::2]
  lengths = array[1::2]
  current_position = 0

  for index, start in enumerate(starts):
    current_position += start
    mask[current_position : current_position + lengths[index]] = 1
    current_position += lengths[index]
  mask = mask.reshape(width, height)

  if resize_dim is not None:
    mask = cv2.resize(mask, resize_dim, interpolation=cv2.INTER_CUBIC)
  if to_nan:
    mask[mask == 0] = np.nan
  return mask
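The relative-offset encoding is easier to follow on a tiny example. The sketch below mirrors the decoding loop in `rle2mask` using only plain Python lists; `decode_relative_rle` is a hypothetical helper, and the reshape and resize steps are left out.

```python
def decode_relative_rle(rle, num_pixels):
    """Decode 'offset length offset length ...' into a flat list of 0s and 1s.

    Each offset is counted from the end of the previous run.
    """
    values = [int(v) for v in rle.split()]
    flat = [0] * num_pixels
    position = 0
    for offset, length in zip(values[0::2], values[1::2]):
        position += offset                      # skip background pixels
        for i in range(position, position + length):
            flat[i] = 1                         # mark the run as foreground
        position += length                      # move to the end of the run
    return flat

flat = decode_relative_rle("2 3 1 2", num_pixels=10)
print(flat)  # [0, 0, 1, 1, 1, 0, 1, 1, 0, 0]
```

The string `"2 3 1 2"` reads as: skip 2 pixels, mark 3, skip 1 more, mark 2. The count of marked pixels (here 5) is what the area estimate multiplies by the per-pixel area.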

In [None]:
def estimate_pmx_area(row):
  """Estimate the pneumothorax area in mm^2."""
  img = row["img"]
  encoded_pixels = row["encoded_pixels"]
  if encoded_pixels == "-1":
    # No pneumothorax labeled
    return {"Area": 0.}
  spacing = row["Pixel Spacing"]
  pixel_area = np.prod([float(x) for x in spacing])  # Area per pixel in mm^2
  total_area = pixel_area * np.sum(rle2mask(encoded_pixels, img.size))
  return {"Area": total_area}

# Compute the pneumothorax ROI area for every example (0 when none is labeled)
dp = dp.update(
  function=estimate_pmx_area, is_batched_fn=False, batch_size=16,
  num_workers=2, pbar=True,
  input_columns=["img", "encoded_pixels", "Pixel Spacing"], 
)
dp.head()
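As a quick check of the arithmetic behind `estimate_pmx_area`, the snippet below works through one case by hand. The spacing and pixel count are made-up values for illustration, not numbers taken from the dataset.

```python
# Hypothetical values, for illustration only.
pixel_spacing_mm = (0.5, 0.5)  # spacing between pixel centers along each axis, in mm
num_mask_pixels = 1200         # pixels labeled as pneumothorax in the mask

# Area covered by a single pixel, in mm^2.
pixel_area_mm2 = pixel_spacing_mm[0] * pixel_spacing_mm[1]

# Total ROI area is the per-pixel area times the number of labeled pixels.
total_area_mm2 = pixel_area_mm2 * num_mask_pixels
print(total_area_mm2)  # 300.0
```

Because the result scales with the product of the two spacings, the same mask yields a very different physical area if the spacing metadata differs between scans.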

In [None]:
row = dp[1]
alpha = 0.4

# Plot segmentation
_, ax = plt.subplots(1,1, figsize=(5,5))
ax.imshow(row["img"], cmap="gray")
mask = rle2mask(row["encoded_pixels"], row["img"].size, to_nan=True)
ax.imshow(mask, alpha=alpha, cmap="jet")

