In [None]:
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

# Phrase grounding

This notebook demonstrates the usage of the BioViL-T image and text models in a multimodal phrase grounding setting.
Given a chest X-ray and a radiology text phrase, the joint model grounds the phrase in the image, i.e., highlights the regions of the image that share features similar to the phrase.
Please refer to [our ECCV and CVPR papers](https://hi-ml.readthedocs.io/en/latest/multimodal.html#credit) for further details.

The notebook can also be run on Binder, without the need for any coding or local installation:

[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/microsoft/hi-ml/HEAD?labpath=hi-ml-multimodal%2Fnotebooks%2Fphrase_grounding.ipynb)

This demo is solely for research evaluation purposes; it is not intended to be a medical product or to be used clinically.

## Setup

Let's first install the `hi-ml-multimodal` Python package, which will allow us to import the `health_multimodal` Python module.

In [1]:
# Package specifier that is interpolated into the `%pip install` command in the next cell
pip_source = "hi-ml-multimodal"

In [2]:
# Install the package named by `pip_source`; `%pip` (rather than `!pip`) targets the kernel's environment
%pip install {pip_source}

Collecting hi-ml-multimodal
  Downloading hi_ml_multimodal-0.2.1-py3-none-any.whl (36 kB)
Collecting timm==0.6.5
  Downloading timm-0.6.5-py3-none-any.whl (512 kB)
[K     |████████████████████████████████| 512 kB 42.8 MB/s eta 0:00:01
[?25hCollecting torch==1.9.0
  Downloading torch-1.9.0-cp38-cp38-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 37 kB/s s eta 0:00:01     |████████▉                       | 228.5 MB 45.4 MB/s eta 0:00:14     |████████████████▌               | 429.4 MB 83.8 MB/s eta 0:00:05
[?25hCollecting pillow==9.3.0
  Downloading Pillow-9.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
[K     |████████████████████████████████| 3.2 MB 77.2 MB/s eta 0:00:01
[?25hCollecting transformers==4.17.0
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 71.4 MB/s eta 0:00:01
[?25hCollecting torchvision<=0.10.0,>0.9
  Downloading torchvision-0.10.0-cp38-cp38-

Successfully installed PyWavelets-1.4.1 SimpleITK-2.1.1 contourpy-1.1.1 cycler-0.12.1 filelock-3.13.0 fonttools-4.43.1 hi-ml-multimodal-0.2.1 huggingface-hub-0.6.0 imageio-2.31.6 importlib-resources-6.1.0 joblib-1.3.2 kiwisolver-1.4.5 matplotlib-3.7.3 networkx-3.1 pillow-9.3.0 pydicom-2.2.2 pyparsing-3.1.1 regex-2023.10.3 sacremoses-0.0.53 scikit-image-0.18.1 scipy-1.10.1 tifffile-2023.7.10 timm-0.6.5 tokenizers-0.14.1 torch-1.9.0 torchvision-0.10.0 tqdm-4.66.1 transformers-4.17.0 typing-extensions-4.8.0 zipp-3.17.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
from typing import List
from typing import Tuple

import tempfile
from pathlib import Path

import torch
from IPython.display import display
from IPython.display import Markdown

from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map
from health_multimodal.text import get_bert_inference
from health_multimodal.text.utils import BertEncoderType
from health_multimodal.image import get_image_inference
from health_multimodal.image.utils import ImageModelType
from health_multimodal.vlp import ImageTextInferenceEngine

## Load multimodal model

Load the text and image models from [Hugging Face 🤗](https://aka.ms/biovil-models) and instantiate the inference engines:

In [None]:
# Text inference engine, built for the BioViL-T BERT encoder type
text_inference = get_bert_inference(BertEncoderType.BIOVIL_T_BERT)
# Image inference engine, built for the BioViL-T image model type
image_inference = get_image_inference(ImageModelType.BIOVIL_T)

Instantiate the joint inference engine:

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Wrap the image and text inference engines in one joint engine, then move it to `device`
image_text_inference = ImageTextInferenceEngine(
    text_inference_engine=text_inference,
    image_inference_engine=image_inference,
)
image_text_inference.to(device)

## Helper visualization functions

In [None]:
# Type alias for one bounding box: a tuple of four floats.
# NOTE(review): presumably (x, y, width, height) in image pixels — confirm against
# `plot_phrase_grounding_similarity_map`.
TypeBox = Tuple[float, float, float, float]

def plot_phrase_grounding(image_path: Path, text_prompt: str, bboxes: List[TypeBox]) -> None:
    """Compute the similarity map between an image and a text phrase, and plot it.

    The map is obtained from the global ``image_text_inference`` engine using bilinear
    interpolation and is then handed to ``plot_phrase_grounding_similarity_map``.

    :param image_path: Path to the image file.
    :param text_prompt: Text phrase used as the query.
    :param bboxes: Bounding boxes forwarded unchanged to the plotting function.
    """
    grounding_map = image_text_inference.get_similarity_map_from_raw_data(
        image_path=image_path, query_text=text_prompt, interpolation="bilinear"
    )
    plot_phrase_grounding_similarity_map(
        image_path=image_path, similarity_map=grounding_map, bboxes=bboxes
    )

def plot_phrase_grounding_from_url(image_url: str, text_prompt: str, bboxes: List[TypeBox]) -> None:
    """Download an image with ``curl`` and plot the phrase grounding for it.

    The image is written to a fixed file name in the system temporary directory
    (overwriting any previous download) and then passed to ``plot_phrase_grounding``.

    :param image_url: URL of the image to download.
    :param text_prompt: Text phrase used as the query.
    :param bboxes: Bounding boxes forwarded unchanged to ``plot_phrase_grounding``.
    :raises FileNotFoundError: If the ``curl`` executable is not available.
    :raises subprocess.CalledProcessError: If ``curl`` exits with a non-zero status.
    """
    import subprocess  # local import so this cell does not depend on an edit to the imports cell

    # `tempfile.tempdir` stays `None` until `tempfile.gettempdir()` has been called, which
    # would make `Path(None, ...)` raise `TypeError`; ask for the directory explicitly.
    image_path = Path(tempfile.gettempdir(), "downloaded_chest_xray.jpg")
    # Pass an argument list (no shell) so characters such as `&` or `?` in the URL are not
    # interpreted by a shell, and fail loudly if the download command itself fails.
    subprocess.run(["curl", "-s", "-L", "-o", str(image_path), image_url], check=True)
    plot_phrase_grounding(image_path, text_prompt, bboxes)

## Inference

We will run inference on a chest X-ray from [Open-i](https://openi.nlm.nih.gov/detailedresult?img=CXR111_IM-0076-1001&req=4), but any other chest X-ray image in DICOM or JPEG format can be used for research purposes.


In [None]:
# Image to download and the phrase to ground in it.
# NOTE(review): the markdown above links to Open-i image CXR111_IM-0076-1001, while this URL
# points to CXR1445_IM-0287-4004 — confirm which one is intended.
image_url = "https://openi.nlm.nih.gov/imgs/512/242/1445/CXR1445_IM-0287-4004.png"
text_prompt = "Left basilar consolidation seen"

# Ground-truth bounding box annotation(s) for the input text prompt
bboxes = [(306, 168, 124, 101)]

# Short caption rendered as Markdown above the figures
text = "".join(
    [
        'The ground-truth bounding box annotation for the phrase',
        f' *{text_prompt}* is shown in the middle figure (in black).',
    ]
)
display(Markdown(text))

plot_phrase_grounding_from_url(image_url, text_prompt, bboxes)