Copyright (c) Technical University of Munich. All Rights Reserved.

<table>
<tr>
 <td><img src="https://raw.githubusercontent.com/CS4MS/CS4MS_W24/main/images/logo_CS_MS_final.png" height=200 style="float: left;"></td>
    <td>
        <h1>Computer Science for Medical Students</h1>
        <h2>Exercise 7: Vision-Language Model for Radiology</h2>
        <h3>Matthias Keicher &amp; Chantal Pellegrini</h3>
        <a href="https://www.cs.cit.tum.de/camp/research/vision-language/" target="_blank">Our Vision-Language Research Group @ CAMP</a>
        <br><br>
        <a href="https://github.com/CS4MS/" target="_blank">CS4MS GitHub Repository</a>
    </td>
</tr>
</table>

# Introduction



## Recap CLIP
![CLIP: Contrastive Language-Image Pretraining](https://github.com/openai/CLIP/raw/main/CLIP.png)

[OpenAI blog post with more details](https://openai.com/research/clip)
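
As a rough illustration of the training objective shown in the figure above, the following sketch computes a symmetric contrastive loss over a batch of matching image and text embeddings. The function name `clip_contrastive_loss`, the random stand-in embeddings, and the fixed temperature of 0.07 are assumptions made for this example (CLIP learns its temperature during training); this is not OpenAI's implementation.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(
    image_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.07
) -> torch.Tensor:
    """Symmetric cross-entropy over cosine similarities of N matching (image, text) pairs."""
    # L2-normalise so that the dot product equals the cosine similarity
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)

    # Pairwise similarities: row i = image i, column j = text j
    logits = image_emb @ text_emb.t() / temperature  # [N, N]

    # The matching text for image i sits on the diagonal
    targets = torch.arange(logits.shape[0])
    loss_image_to_text = F.cross_entropy(logits, targets)
    loss_text_to_image = F.cross_entropy(logits.t(), targets)
    return (loss_image_to_text + loss_text_to_image) / 2


torch.manual_seed(0)
emb = torch.randn(4, 8)  # hypothetical embeddings: 4 pairs, 8 dimensions

aligned_loss = clip_contrastive_loss(emb, emb)                    # every pair matches
shuffled_loss = clip_contrastive_loss(emb, emb.roll(1, dims=0))   # every pair is mismatched
print(f"aligned: {aligned_loss.item():.3f}, shuffled: {shuffled_loss.item():.3f}")
```

The loss is low when each image is most similar to its own text and high when the pairs are shuffled, which is the signal that pulls matching pairs together during pretraining.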

## BioViL: Making the Most of Text Semantics to Improve Biomedical Vision–Language Processing

![BioViL architecture](https://www.microsoft.com/en-us/research/uploads/prod/2022/07/BioVIL-1024x328.png)
- Contrastive language-image pretraining on chest X-rays and corresponding radiology reports proposed by Microsoft Research [1].
- Trained on the [MIMIC-CXR dataset](https://physionet.org/content/mimic-cxr/2.0.0/) with 377,110 radiographs and corresponding radiology reports.

- [BioViL publication page with more details](https://www.microsoft.com/en-us/research/publication/making-the-most-of-text-semantics-to-improve-biomedical-vision-language-processing/)

[1] Boecking, Benedikt, et al. "Making the most of text semantics to improve biomedical vision–language processing." European conference on computer vision. Cham: Springer Nature Switzerland, 2022.
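
Because images and reports are embedded into a shared space, a model trained this way can be queried with text prompts. The sketch below shows the idea with cosine similarity; it is only an illustration under stated assumptions: the embeddings are random stand-ins rather than real encoder outputs, the helper `zero_shot_scores` and the prompt strings are hypothetical, and the dimension of 128 is taken from the joint feature size used in this notebook.

```python
import torch
import torch.nn.functional as F

JOINT_DIM = 128  # assumed size of the joint image-text embedding space


def zero_shot_scores(image_emb: torch.Tensor, prompt_embs: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between one image embedding [D] and K prompt embeddings [K, D]."""
    image_emb = F.normalize(image_emb, dim=-1)
    prompt_embs = F.normalize(prompt_embs, dim=-1)
    return prompt_embs @ image_emb  # [K]


torch.manual_seed(0)
prompts = ["No evidence of pneumonia", "Findings consistent with pneumonia"]

# Stand-ins for text encoder outputs (random, for illustration only)
prompt_embs = torch.randn(len(prompts), JOINT_DIM)
# Stand-in image embedding placed close to the second prompt
image_emb = prompt_embs[1] + 0.1 * torch.randn(JOINT_DIM)

scores = zero_shot_scores(image_emb, prompt_embs)
predicted = prompts[int(scores.argmax())]
print(predicted)
```

With real encoders, the prompt with the highest similarity would be taken as the predicted finding without training a dedicated classifier.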

<br >

### Image Encoder
![CNN](https://upload.wikimedia.org/wikipedia/commons/6/63/Typical_cnn.png)
![ResNet50](https://upload.wikimedia.org/wikipedia/commons/9/98/ResNet50.png)

- [ResNet50](https://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html) CNN architecture
- Pretrained on chest X-rays with contrastive pretraining ([SimCLR](http://proceedings.mlr.press/v119/chen20j.html))
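
To make the relation between patch embeddings and the global image embedding concrete, here is a minimal sketch. It assumes a 448 x 448 input and a standard ResNet-50 without dilation (total stride 32, 2048 output channels), and it uses a random tensor in place of a real feature map, so the numbers carry no meaning beyond their shapes.

```python
import torch
import torch.nn.functional as F

image_size = 448    # assumed input resolution for this illustration
total_stride = 32   # a standard ResNet-50 halves the resolution five times (2**5)
grid = image_size // total_stride  # patches per side

# Random stand-in for the last ResNet-50 feature map: [batch, channels, height, width]
patch_emb = torch.randn(2, 2048, grid, grid)

# Global average pooling collapses the patch grid into one vector per image
global_emb = torch.flatten(F.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)

print(patch_emb.shape)   # patch embeddings, one 2048-d vector per grid cell
print(global_emb.shape)  # one 2048-d vector per image
```

Each cell of the grid summarises a region of the X-ray, while the pooled vector summarises the whole image.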

<br >

### Text Encoder

![BioViL CXR-BERT Text Encoder](https://www.microsoft.com/en-us/research/uploads/prod/2022/07/CXR-BERT.png)

- Based on [BERT Transformer encoder architecture](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)
- Pretrained on PubMed articles, [MIMIC clinical notes](https://physionet.org/content/mimiciii) and MIMIC-CXR radiology reports


# Vision-Language Model
This section installs the dependencies and loads the text and image encoders.

## Dependencies

In [None]:
# Install the libraries used below: transformers for the language model and timm for image-model building blocks
%pip install transformers timm --quiet

## Loading Image Encoder

In [None]:
#@title BioViL ResNet
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

# Adapted from https://github.com/microsoft/hi-ml/tree/c606808b20c88d6e2cc388bd650abf34ccae17cf/hi-ml-multimodal/src/health_multimodal/image/model

# `from __future__` imports must come before any other statement in the cell
from __future__ import annotations

import math

from dataclasses import dataclass
from contextlib import contextmanager
from functools import partial
from enum import Enum, unique
from typing import Callable, Optional, Any, List, Tuple, Type, Union, Set, Generator, Sequence
from abc import ABC, abstractmethod
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from torchvision.datasets.utils import download_url
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
from timm.models.layers import DropPath, Mlp, trunc_normal_


def get_module_device(module: torch.nn.Module) -> torch.device:
    """
    Returns the device of the module
    """
    device = next(module.parameters()).device  # type: ignore
    assert isinstance(device, torch.device)

    return device

@dataclass
class ImageModelOutput:
    img_embedding: torch.Tensor
    patch_embeddings: torch.Tensor
    projected_global_embedding: torch.Tensor
    class_logits: Optional[torch.Tensor]  # None when no downstream classifier is configured
    projected_patch_embeddings: torch.Tensor


@unique
class ImageEncoderType(str, Enum):
    RESNET18 = "resnet18"
    RESNET50 = "resnet50"
    RESNET18_MULTI_IMAGE = "resnet18_multi_image"
    RESNET50_MULTI_IMAGE = "resnet50_multi_image"

    @classmethod
    def get_members(cls, multi_image_encoders_only: bool) -> List[ImageEncoderType]:
        if multi_image_encoders_only:
            return [cls.RESNET18_MULTI_IMAGE, cls.RESNET50_MULTI_IMAGE]
        else:
            return [member for member in cls]


@unique
class ImageEncoderWeightTypes(str, Enum):
    RANDOM = "random"
    IMAGENET = "imagenet"
    BIOVIL = "biovil"
    BIOVIL_T = "biovil_t"


class MLP(nn.Module):
    """
    Fully connected layers to map between image embeddings and projection space where pairs of images are compared.

    :param input_dim: Input embedding feature size
    :param hidden_dim: Hidden layer size in MLP
    :param output_dim: Output projection size
    :param use_1x1_convs: Use 1x1 conv kernels instead of 2D linear transformations for speed and memory efficiency.
    """

    def __init__(
        self, input_dim: int, output_dim: int, hidden_dim: Optional[int] = None, use_1x1_convs: bool = False
    ) -> None:
        super().__init__()

        if use_1x1_convs:
            linear_proj_1_args = {'in_channels': input_dim, 'out_channels': hidden_dim, 'kernel_size': 1, 'bias': False}
            linear_proj_2_args = {'in_channels': hidden_dim, 'out_channels': output_dim, 'kernel_size': 1, 'bias': True}
            normalisation_layer: Callable = nn.BatchNorm2d
            projection_layer: Callable = nn.Conv2d
        else:
            linear_proj_1_args = {'in_features': input_dim, 'out_features': hidden_dim, 'bias': False}
            linear_proj_2_args = {'in_features': hidden_dim, 'out_features': output_dim, 'bias': True}
            normalisation_layer = nn.BatchNorm1d
            projection_layer = nn.Linear

        self.output_dim = output_dim
        self.input_dim = input_dim
        if hidden_dim is not None:
            self.model = nn.Sequential(
                projection_layer(**linear_proj_1_args),
                normalisation_layer(hidden_dim),
                nn.ReLU(inplace=True),
                projection_layer(**linear_proj_2_args),
            )
        else:
            self.model = nn.Linear(input_dim, output_dim)  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """forward pass of the multi-layer perceptron"""
        x = self.model(x)
        return x


class MultiTaskModel(nn.Module):
    """Torch module for multi-task classification heads. We create a separate classification head
    for each task and perform a forward pass on each head independently in forward(). Classification
    heads are instances of `MLP`.

    :param input_dim: Number of dimensions of the input feature map.
    :param classifier_hidden_dim: Number of dimensions of hidden features in the MLP.
    :param num_classes: Number of output classes per task.
    :param num_tasks: Number of classification tasks or heads required.
    """

    def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):
        super().__init__()

        self.num_classes = num_classes
        self.num_tasks = num_tasks

        for task in range(num_tasks):
            # TODO check if softmax not needed here.
            setattr(self, "fc_" + str(task), MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns [batch_size, num_classes, num_tasks] tensor of logits."""
        batch_size = x.shape[0]
        out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
        for task in range(self.num_tasks):
            classifier = getattr(self, "fc_" + str(task))
            out[:, :, task] = classifier(x)
        return out



TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


class ResNetHIML(ResNet):
    """Wrapper class of the original torchvision ResNet model.

    The forward function is updated to return the penultimate layer
    activations, which are required to obtain image patch embeddings.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def forward(
        self, x: torch.Tensor, return_intermediate_layers: bool = False
    ) -> Union[torch.Tensor, TypeSkipConnections]:
        """ResNetHIML forward pass. Optionally returns intermediate layers using the
        ``return_intermediate_layers`` argument.

        :param return_intermediate_layers: If ``True``, return layers x0-x4 as a tuple,
            otherwise return x4 only.
        """

        x0 = self.conv1(x)
        x0 = self.bn1(x0)
        x0 = self.relu(x0)
        x0 = self.maxpool(x0)

        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        if return_intermediate_layers:
            return x0, x1, x2, x3, x4
        else:
            return x4


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNetHIML:
    """Instantiate a custom :class:`ResNet` model.

    Adapted from :mod:`torchvision.models.resnet`.
    """
    model = ResNetHIML(block=block, layers=layers, **kwargs)
    if pretrained:
        raise NotImplementedError
    return model


def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    :param pretrained: Must be ``False``; ImageNet weights are not supported here and ``True`` raises ``NotImplementedError``.
    :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    :param pretrained: Must be ``False``; ImageNet weights are not supported here and ``True`` raises ``NotImplementedError``.
    :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


@dataclass
class MultiHeadAttentionOutput:
    mha_output: torch.Tensor
    attention: Optional[torch.Tensor] = None


class VisionTransformerPooler(nn.Module):
    """
    :param input_dim: Input feature dimension (i.e., channels in old CNN terminology)
    :param grid_shape: Shape of the grid of patches per image
    :param num_heads: Number of self-attention heads within the MHA block
    :param num_blocks: Number of blocks per attention layer
    :param norm_layer: Normalisation layer

    `self.type_embed`: Is used to characterise prior and current scans, and
                       create permutation variance across modalities/series.
    """

    def __init__(
        self,
        input_dim: int,
        grid_shape: Tuple[int, int],
        num_heads: int = 8,
        num_blocks: int = 3,
        norm_layer: Any = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()

        block_kwargs = dict(
            dim=input_dim,
            num_heads=num_heads,
            mlp_ratio=1.0,
            drop=0.10,
            attn_drop=0.10,
            drop_path=0.25,
            act_layer=nn.GELU,
            norm_layer=norm_layer,
        )
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
        self.norm_post = norm_layer(input_dim)
        self.grid_shape = grid_shape
        self.num_patches = grid_shape[0] * grid_shape[1]
        self.num_blocks = num_blocks

        # Temporal positional embeddings
        num_series: int = 2
        self.type_embed = nn.Parameter(torch.zeros(num_series, 1, input_dim))
        trunc_normal_(self.type_embed, std=0.02)

        # Positional embeddings 1 x L x C (L: Sequence length, C: Feature dimension)
        self.pos_drop = nn.Dropout(p=0.10)
        pos_embed_class = SinePositionEmbedding(embedding_dim=input_dim // 2, normalize=True)
        pos_embed = pos_embed_class(mask=torch.ones([1, grid_shape[0], grid_shape[1]]))  # 1 x L x C
        self.register_buffer("pos_embed", pos_embed, persistent=False)

        # Initialisation
        self.apply(self._init_weights)

    def no_weight_decay(self) -> Set[str]:
        return {'type_embed'}

    def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, C, H, W = current_image.shape
        assert H == self.grid_shape[0] and W == self.grid_shape[1], "Input and grid shapes do not match"

        # Flatten patch embeddings to have shape (B x L x C), L = H * W
        if previous_image is not None:
            assert previous_image.shape == current_image.shape, "current_image and previous_image shapes do not match"
            previous_image = previous_image.view(B, C, H * W).transpose(1, 2)
        current_image = current_image.view(B, C, H * W).transpose(1, 2)
        pos_embed = self.pos_embed.repeat(B, 1, 1)  # type: ignore

        # Final token activations (B x 2L x C)
        token_features = self.forward_after_reshape(x=current_image, pos_embed=pos_embed, x_previous=previous_image)

        # Extract the patch features of current image
        cur_img_token_id = 0
        current_token_features = token_features[:, cur_img_token_id : self.num_patches + cur_img_token_id]
        current_patch_features = current_token_features.transpose(1, 2).view(B, C, H, W)

        return current_patch_features

    def forward_after_reshape(
        self, x: torch.Tensor, pos_embed: torch.Tensor, x_previous: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, L, _ = x.shape  # Batch, Sequence length, Feature dimension

        # Positional and type embeddings
        type_embed = self.type_embed[0].expand(B, L, -1)
        if x_previous is not None:
            x = torch.cat((x, x_previous), dim=1)
            pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
            prev_type_embed = self.type_embed[1].expand(B, L, -1)
            type_embed = torch.cat((type_embed, prev_type_embed), dim=1)

        # Add positional and type embeddings (used in query and key matching)
        pos_and_type_embed = pos_embed + type_embed

        # Positional dropout
        x = self.pos_drop(x)

        # Multihead attention followed by MLP
        for block in self.blocks:
            x = block(x=x, pos_and_type_embed=pos_and_type_embed)
        x = self.norm_post(x)

        return x

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


class MultiHeadAttentionLayer(nn.Module):
    """
    Multi-head self attention module

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
        - Defines a custom `MultiHeadAttentionLayer` which does not only apply `self-attention` but it can be
            generalised to arbitrary (query, key, value) input tuples. This feature can be valuable to process
            more than 2 scans at a time.
        - `Self-attention` specific use-case can still be invoked by calling the `forward_as_mhsa` method.
    """

    def __init__(
        self, dim: int, num_heads: int = 8, qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"The embedding dim ({dim}) must be divisible by the number of heads ({num_heads})"
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.return_attention = False

        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> MultiHeadAttentionOutput:
        B, N, C = v.shape
        assert (
            C % self.num_heads == 0
        ), f"The embedding dim ({C}) must be divisible by the number of heads ({self.num_heads})"

        w_q = self.proj_q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        w_k = self.proj_k(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        w_v = self.proj_v(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (w_q @ w_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        o = (attn @ w_v).transpose(1, 2).reshape(B, N, C)
        o = self.proj(o)
        o = self.proj_drop(o)

        attention_output = attn if self.return_attention else None

        return MultiHeadAttentionOutput(mha_output=o, attention=attention_output)

    def forward_as_mhsa(self, input: torch.Tensor) -> MultiHeadAttentionOutput:
        return self(k=input, q=input, v=input)


class Block(nn.Module):
    """
    Encapsulates multi-layer perceptron and multi-head self attention modules into a block.

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
        - This implementation uses spatio-temporal positional embeddings instead of 2D positional embeddings only,
            and they are taken into account within the forward pass of each ViT block.
        - Utilises the custom defined `MultiHeadAttentionLayer` which does not apply `self-attention` only but can be
            generalised to arbitrary (query, key, value) tuples. This can be valuable to process more than 2 scans.

    Positional and type embeddings are handled in a similar fashion as DETR object localisation paper
    https://alcinos.github.io/detr_page/, where a fixed set of sine/cos positional embeddings are used
    in an additive manner to Q and K tensors.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 1.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(
            dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def with_pos_and_type_embed(self, tensor: torch.Tensor, emb: Optional[torch.Tensor]) -> torch.Tensor:
        # Add positional embeddings to key and query tensors
        return tensor if emb is None else tensor + emb

    def forward(self, x: torch.Tensor, pos_and_type_embed: Optional[torch.Tensor]) -> torch.Tensor:
        x_with_emb = self.with_pos_and_type_embed(self.norm1(x), emb=pos_and_type_embed)
        x = x + self.drop_path(self.attn.forward_as_mhsa(x_with_emb).mha_output)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class SinePositionEmbedding:
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self, embedding_dim: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        assert mask is not None, "No pixel mask provided"
        B, H, W = mask.shape
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)

        return pos


DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class ImageEncoder(nn.Module):
    """Image encoder trunk module for the ``ImageModel`` class.

    :param img_encoder_type : Type of image encoder model to use; any ``ImageEncoderType`` value, e.g.
                              ``"resnet18"`` or ``"resnet50"``.
    """

    def __init__(self, img_encoder_type: str):
        super().__init__()
        self.img_encoder_type = img_encoder_type
        self.encoder = self._create_encoder()

    def _create_encoder(self, **kwargs: Any) -> nn.Module:
        if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
            encoder_class = resnet18
        elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
            encoder_class = resnet50
        else:
            supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
            raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")

        encoder = encoder_class(pretrained=False, **kwargs)

        return encoder

    def forward(self, current_image: torch.Tensor, return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
        """Get image global and patch embeddings"""

        patch_emb = self.encoder(current_image)
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
        if return_patch_embeddings:
            return patch_emb, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        """Workaround for enabling dilated convolutions after model initialization.

        :param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
                                             in each layer in the last three blocks of ResNet architecture.
        """
        if self.img_encoder_type == ImageEncoderType.RESNET18:
            # resnet18 uses BasicBlock implementation, which does not support dilated convolutions.
            raise NotImplementedError("resnet18 does not support dilated convolutions")

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET

        device = next(self.encoder.parameters()).device
        new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)

        if self.encoder.training:
            new_encoder.train()
        else:
            new_encoder.eval()

        new_encoder.load_state_dict(self.encoder.state_dict())
        self.encoder = new_encoder


class MultiImageEncoder(ImageEncoder):
    """Multi-image encoder trunk module for the ``ImageModel`` class.
    It can be used to encode multiple images into combined latent representation.
    Currently it only supports two input images but can be extended to support more in future.

    :param img_encoder_type: Type of image encoder model to use: typically ``"resnet18_multi_image"`` or
                             ``"resnet50_multi_image"``.
    """

    def __init__(self, img_encoder_type: str):
        super().__init__(img_encoder_type)

        output_dim = 256  # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
        grid_shape = (14, 14)  # Spatial dimensions of patch grid.

        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))

        self.backbone_to_vit = nn.Conv2d(
            in_channels=backbone_output_feature_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)

        # Missing image embedding
        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
        trunc_normal_(self.missing_previous_emb, std=0.02)

    def forward(  # type: ignore[override]
        self,
        current_image: torch.Tensor,
        previous_image: Optional[torch.Tensor] = None,
        return_patch_embeddings: bool = False,
    ) -> ImageEncoderOutputType:
        batch_size = current_image.shape[0]

        if previous_image is not None:
            assert current_image.shape == previous_image.shape
            x = torch.cat([current_image, previous_image], dim=0)
            x = super().forward(x, return_patch_embeddings=True)[0]
            x = self.backbone_to_vit(x)
            patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
        else:
            x = super().forward(current_image, return_patch_embeddings=True)[0]
            patch_x = self.backbone_to_vit(x)
            B, _, H, W = patch_x.shape
            diff_x = self.missing_previous_emb.repeat(B, 1, H, W)

        patch_fused = torch.cat([patch_x, diff_x], dim=1)
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)

        if return_patch_embeddings:
            return patch_fused, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        raise NotImplementedError


@torch.no_grad()
def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
    """Calculate the output dimension of an encoder by making a single forward pass.

    :param module: Encoder module.
    :param device: Compute device to use.
    """
    # Target device
    assert isinstance(device, torch.device)

    x = torch.rand((1, 3, 448, 448)).to(device)

    # Extract the number of output feature dimensions
    with restore_training_mode(module):
        module.eval()
        representations = module(x)
    return representations.shape[1]


@contextmanager
def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
    """Restore the training mode of a module after some operation.

    :param module: PyTorch module.
    """
    training_mode = module.training
    try:
        yield
    finally:
        # Restore the original mode even if the wrapped operation raises
        module.train(mode=training_mode)


def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
    """Returns the encoder class for the given encoder type.

    :param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}
    """
    if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
        return MultiImageEncoder(img_encoder_type=img_encoder_type)
    else:
        return ImageEncoder(img_encoder_type=img_encoder_type)


class BaseImageModel(nn.Module, ABC):
    """Abstract class for image models."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
        raise NotImplementedError

    @abstractmethod
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        raise NotImplementedError


class ImageModel(BaseImageModel):
    """Image encoder module"""

    def __init__(
        self,
        img_encoder_type: str,
        joint_feature_size: int,
        freeze_encoder: bool = False,
        pretrained_model_path: Optional[Union[str, Path]] = None,
        **downstream_classifier_kwargs: Any,
    ):
        super().__init__()

        # Initiate encoder, projector, and classifier
        self.encoder = get_encoder_from_type(img_encoder_type)
        self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
        self.projector = MLP(
            input_dim=self.feature_size,
            output_dim=joint_feature_size,
            hidden_dim=joint_feature_size,
            use_1x1_convs=True,
        )
        self.downstream_classifier_kwargs = downstream_classifier_kwargs
        self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None

        # Initialise the mode of modules
        self.freeze_encoder = freeze_encoder
        self.train()

        if pretrained_model_path is not None:
            if not isinstance(pretrained_model_path, (str, Path)):
                raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def train(self, mode: bool = True) -> Any:
        """Switch the model between training and evaluation modes."""
        super().train(mode=mode)
        if self.freeze_encoder:
            self.encoder.train(mode=False)
            self.projector.train(mode=False)
        return self

    def forward(self, x: torch.Tensor) -> ImageModelOutput:  # type: ignore[override]
        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)

    def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
        with torch.set_grad_enabled(not self.freeze_encoder):
            projected_patch_embeddings = self.projector(patch_x)
            projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))

        logits = self.classifier(pooled_x) if self.classifier else None
        return ImageModelOutput(
            img_embedding=pooled_x,
            patch_embeddings=patch_x,
            class_logits=logits,
            projected_patch_embeddings=projected_patch_embeddings,
            projected_global_embedding=projected_global_embedding,
        )

    def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
        """Create the classification module for the downstream task."""
        downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
        return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)

    @torch.no_grad()
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        """Get patch-wise projected embeddings from the CNN model.

        :param input_img: input tensor image [B, C, H, W].
        :param normalize: If ``True``, the embeddings are L2-normalized.
        :returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size].
        """
        assert not self.training, "This function is only implemented for evaluation mode"
        outputs = self.forward(input_img)
        projected_embeddings = outputs.projected_patch_embeddings.detach()  # type: ignore
        if normalize:
            projected_embeddings = F.normalize(projected_embeddings, dim=1)
        projected_embeddings = projected_embeddings.permute([0, 2, 3, 1])  # B D H W -> B H W D (D: Features)
        return projected_embeddings


class MultiImageModel(ImageModel):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"

    def forward(  # type: ignore[override]
        self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
    ) -> ImageModelOutput:
        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(
                current_image=current_image, previous_image=previous_image, return_patch_embeddings=True
            )
        return self.forward_post_encoder(patch_x, pooled_x)



JOINT_FEATURE_SIZE = 128

BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
BIOMED_VLP_BIOVIL_T = "microsoft/BiomedVLP-BioViL-T"
HF_URL = "https://huggingface.co"

CXR_BERT_COMMIT_TAG = "v1.1"
BIOVIL_T_COMMIT_TAG = "v1.0"

BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = f"{HF_URL}/{BIOMED_VLP_CXR_BERT_SPECIALIZED}/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"  # noqa: E501
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"

BIOVIL_T_IMAGE_WEIGHTS_NAME = "biovil_t_image_model_proj_size_128.pt"
BIOVIL_T_IMAGE_WEIGHTS_URL = (
    f"{HF_URL}/{BIOMED_VLP_BIOVIL_T}/resolve/{BIOVIL_T_COMMIT_TAG}/{BIOVIL_T_IMAGE_WEIGHTS_NAME}"  # noqa: E501
)
BIOVIL_T_IMAGE_WEIGHTS_MD5 = "a83080e2f23aa584a4f2b24c39b1bb64"


def _download_biovil_image_model_weights() -> Path:
    """Download image model weights from Hugging Face.

    More information available at https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.
    """
    root_dir = tempfile.gettempdir()
    download_url(
        BIOVIL_IMAGE_WEIGHTS_URL,
        root=root_dir,
        filename=BIOVIL_IMAGE_WEIGHTS_NAME,
        md5=BIOVIL_IMAGE_WEIGHTS_MD5,
    )
    return Path(root_dir, BIOVIL_IMAGE_WEIGHTS_NAME)


def _download_biovil_t_image_model_weights() -> Path:
    """Download image model weights from Hugging Face.

    More information available at https://huggingface.co/microsoft/BiomedVLP-BioViL-T.
    """
    root_dir = tempfile.gettempdir()
    download_url(
        BIOVIL_T_IMAGE_WEIGHTS_URL, root=root_dir, filename=BIOVIL_T_IMAGE_WEIGHTS_NAME, md5=BIOVIL_T_IMAGE_WEIGHTS_MD5
    )
    return Path(root_dir, BIOVIL_T_IMAGE_WEIGHTS_NAME)


def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
    """Download weights from Hugging Face and instantiate the image model."""
    resnet_checkpoint_path = _download_biovil_image_model_weights() if pretrained else None

    image_model = ImageModel(
        img_encoder_type=ImageEncoderType.RESNET50,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=resnet_checkpoint_path,
    )
    return image_model


def get_biovil_t_image_encoder() -> ImageModel:
    """Download weights from Hugging Face and instantiate the image model."""

    biovilt_checkpoint_path = _download_biovil_t_image_model_weights()
    model_type = ImageEncoderType.RESNET50_MULTI_IMAGE
    image_model = ImageModel(
        img_encoder_type=model_type,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=biovilt_checkpoint_path,
    )
    return image_model


def get_imagenet_init_encoder() -> ImageModel:
    """Download ImageNet pre-trained weights and instantiate the image model."""
    raise NotImplementedError
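
The two download helpers above pass an expected MD5 checksum (`BIOVIL_IMAGE_WEIGHTS_MD5`, `BIOVIL_T_IMAGE_WEIGHTS_MD5`) to `download_url`, presumably so that a corrupted or incomplete download is detected. As a rough illustration of what such a check involves, the sketch below computes the MD5 digest of a file with Python's standard `hashlib` module and compares it to an expected value. The helper name `file_md5` and the temporary demo file are made up for this example; this is not the code that `download_url` runs internally.

```python
import hashlib
import tempfile
from pathlib import Path

def file_md5(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file, reading it in chunks (hypothetical helper)."""
    digest = hashlib.md5()
    with open(path, 'rb') as handle:
        # read until an empty bytes object signals the end of the file
        for chunk in iter(lambda: handle.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# demo on a small temporary file instead of the real checkpoint
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir, 'demo.bin')
    demo_path.write_bytes(b'hello')
    demo_digest = file_md5(demo_path)

# the expected value would normally be a published constant such as BIOVIL_IMAGE_WEIGHTS_MD5
expected_digest = hashlib.md5(b'hello').hexdigest()
checksum_ok = demo_digest == expected_digest
print(checksum_ok)
```

Reading in chunks keeps memory use low, which matters for checkpoint files that are far larger than this demo file.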


Load model and move to GPU

In [None]:
from torchvision.transforms import Resize, CenterCrop, ToTensor, Compose
import torch

# wrapper for image encoder of BioVil for convenience
class ImageEncoderBioVil(nn.Module):
    """Instantiate image model with pre-trained weights.
    :param backbone: Select one of `biovil`, `biovil_t`
    """
    def __init__(self, backbone='biovil'):
        super().__init__()
        if backbone == ImageEncoderWeightTypes.BIOVIL:
            self.backbone = get_biovil_image_encoder()
        elif backbone == ImageEncoderWeightTypes.BIOVIL_T:
            self.backbone = get_biovil_t_image_encoder()
        else:
            raise ValueError(f"Weights option not found: {backbone}")

    def forward(self, image):
        return self.backbone(image)

# image transformations used by BioVil
image_transforms = Compose([Resize(512, antialias=True), CenterCrop(512), ToTensor()])

# load image encoder
image_encoder = ImageEncoderBioVil('biovil_t')

# set device for models - ideally using a GPU for fast performance (=cuda)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def print_red(text):
    print('\033[91m' + text + '\033[0m')

def print_green(text):
    print('\033[92m' + text + '\033[0m')

if device.type == 'cpu':
    print_red('No GPU accelerator found, falling back to CPU - model inference will be slow')
else:
    print_green('Running on GPU! (fast inference)')

image_encoder.eval()
image_encoder.to(device)

# helper method for getting image embedding from BioVil
def get_image_embeddings(images):
    with torch.no_grad():
        transformed_images = torch.stack([image_transforms(image) for image in images])
        image_model_output = image_encoder(transformed_images.to(device))
        image_embeddings = image_model_output.projected_global_embedding
        return image_embeddings

## Loading Text Encoder

In [None]:
#@title BioVil BERT
import torch
from transformers import AutoModel, AutoTokenizer

# Load the model and tokenizer
url = "microsoft/BiomedVLP-BioViL-T"
tokenizer = AutoTokenizer.from_pretrained(url, trust_remote_code=True)
text_encoder = AutoModel.from_pretrained(url, trust_remote_code=True)
text_encoder.eval()
text_encoder.to(device)

# Tokenize and compute the sentence embeddings
def get_text_embeddings(text_prompts):
    with torch.no_grad():
        tokenizer_output = tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=text_prompts,
            add_special_tokens=True,
            padding='longest',
            return_tensors='pt',
        ).to(device)
        text_embeddings = text_encoder(
            input_ids=tokenizer_output.input_ids,
            attention_mask=tokenizer_output.attention_mask,
            output_cls_projected_embedding=True,
        )
        return text_embeddings.cls_projected_embedding

## Helper functions

### Cosine Similarity
The cosine similarity measures how similar two vectors are in direction, ignoring their length. It is computed as the dot product of the two vectors after each has been normalized to unit length, and it ranges from -1 (opposite direction) to 1 (same direction).

More information on Wikipedia: https://en.wikipedia.org/wiki/Cosine_similarity

In [None]:
def calculate_cosine_similarity(embedding_1, embedding_2):
    # Normalize the vectors to a length of one and then perform a matrix multiplication (dot product)
    return F.normalize(embedding_1) @ F.normalize(embedding_2).T
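
To make the definition concrete, here is a minimal pure-Python sketch on toy 2D vectors. The function name `cosine_similarity` and the example vectors are chosen only for illustration; the notebook itself uses the batched PyTorch version above.

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity of two equal-length vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

same_direction = cosine_similarity([1.0, 2.0], [2.0, 4.0])   # parallel, different length
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 3.0])       # perpendicular
opposite = cosine_similarity([1.0, 2.0], [-1.0, -2.0])       # opposite direction
print(same_direction, orthogonal, opposite)
```

Up to floating-point rounding, the three values are 1, 0 and -1: doubling the length of a vector does not change the result, only its direction does.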

### Download image
Helper method to download an image from a given url.

In [None]:
import urllib.request
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image

def load_image(url, cache_directory='images'):
    # get file name of chest x-ray
    parsed_url = urlparse(url)
    image_name = Path(parsed_url.path).name
    # create 'images' folder as cache directory in case it does not exist yet
    cache_path = Path(cache_directory)
    cache_path.mkdir(exist_ok=True)
    # define local image path
    image_path = cache_path / image_name
    # if the image is not there yet download it
    if not image_path.exists():
        with urllib.request.urlopen(url) as response, image_path.open('wb') as image_file:
            image_file.write(response.read())
    # open image with the Pillow library and return it
    return Image.open(image_path)

# test with a random sample from the Indiana chest x-ray dataset
url = 'https://openi.nlm.nih.gov/imgs/512/276/677/CXR677_IM-2249-1001.png?keywords=Catheters,%20Indwelling,Lung,Density,Density,Density,Pleural%20Effusion,Pneumonia'
print('downloaded url:', url)
display(load_image(url))

### Plot results
This is another helper method to visualize the cosine similarities between text and images as well as softmax probabilities.

In [None]:
#@title Visualization method for text-image cosine similarity

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, AnchoredOffsetbox

# Creating a dataframe for seaborn heatmap
def plot_similarities(similarities, text_prompts, images, image_descriptions, plot_probabilities=False, size_unit = 1):
    if plot_probabilities:
        title = 'Softmax probabilities'
        # apply softmax to similarities for each image
        probabilities = similarities.softmax(dim=-1)
        matrix = probabilities
        # define range of probabilities 0%-100%
        vmin = 0
        vmax = 1
        # formatting
        fmt = '.0%'
    else:
        title = 'Cosine similarities'
        matrix = similarities
        # define range of cosine similarity
        vmin = -1
        vmax = 1
        # format
        fmt =  '.2f'
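    # break long prompts into one sentence per line so the row labels stay readable
    text_prompts = [prompt.replace('. ', '.\n') if len(prompt) > 40 else prompt for prompt in text_prompts]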
    # create table with matrix values and column / row names
    df = pd.DataFrame(matrix.T.cpu().detach().numpy(), columns=image_descriptions, index=text_prompts)
    # adjust fig size to number of images and prompts
    fig, ax = plt.subplots(figsize=(size_unit+(size_unit*len(images)),size_unit+size_unit*len(text_prompts)))
    # plot heatmap
    heatmap = sns.heatmap(df, annot=True, cbar=False, cmap='RdYlGn', vmin=vmin, vmax=vmax, linewidths=0.5, linecolor='black', square=True, annot_kws={"fontsize": 12}, fmt=fmt)
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0)

    # Set labels
    heatmap.set_xlabel("Patient images", weight='bold')  # x-axis label
    heatmap.set_ylabel("Text prompts", weight='bold')  # y-axis label

    # Calculate the positions of the cells for placing images
    y_positions = ax.get_yticks()
    x_positions = ax.get_xticks()

    # Add images to the top of each column
    for i, img in enumerate(images):

        # Make the image the same size as a cell
        imagebox = OffsetImage(img, zoom=0.11 * size_unit)

        # Place the image at the top of the column
        ab = AnnotationBbox(imagebox, (x_positions[i], 0), box_alignment=(0.5, -0.05), frameon=False)
        ax.add_artist(ab)
    # plt.title(title, fontsize=16, weight='bold')
    # fig.suptitle(title, fontsize=16, weight='bold')
    fig.text(0.5, 1.2, title, fontsize=12, weight='bold', ha='center')  # Add a title to the figure
    plt.show()

# Downstream Tasks

## Test patients - Indiana University, Chest X-rays

Patient class

In [None]:
class Patient():
    def __init__(self, image_url, report):
        # store image
        self.image = load_image(image_url)
        parsed_url = urlparse(image_url)
        # patient identifier
        self.patient_id = Path(parsed_url.path).stem
        # extract keywords from url
        self.keywords = parsed_url.query.split('=')[1].replace('%20', ' ')
        self.url = 'https://openi.nlm.nih.gov/detailedresult?img=' + self.patient_id
        # A radiology report consists of different sections, including indication, findings and impression
        # For the purpose of this exercise, we only look at findings that provide the most detailed account of the image
        self.report = report

    def __str__(self):
        # just for convenience to see patient details
        return f"""
Patient ID: {self.patient_id}
Patient URL: {self.url}

RADIOLOGY REPORT
-------------------------------------------
{self.report}
-------------------------------------------
\n\n
"""

    def __repr__(self):
        # display the image
        display(self.image)
        # Call __str__ method
        return self.__str__()
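
The `Patient` class reads the identifier and the keywords straight from the image URL. As a sketch of that step in isolation, the snippet below does the same with `urllib.parse.parse_qs`, which also decodes percent-escapes such as `%20`. This is an alternative to the manual `split('=')` and `replace('%20', ' ')` used in the class, not what the class itself does, and the URL is a shortened version of the pneumonia example used only for illustration.

```python
from pathlib import Path
from urllib.parse import urlparse, parse_qs

# shortened example URL (assumption: same layout as the Open-i URLs in this notebook)
url = ('https://openi.nlm.nih.gov/imgs/512/276/677/CXR677_IM-2249-1001.png'
       '?keywords=Pleural%20Effusion,Pneumonia')

parsed_url = urlparse(url)
# file name without extension serves as the patient identifier
patient_id = Path(parsed_url.path).stem
# parse_qs returns a dict of lists; fall back to an empty string if the key is missing
keywords = parse_qs(parsed_url.query).get('keywords', [''])[0]
print(patient_id, '|', keywords)
```

Using `.get('keywords', [''])` avoids an exception for URLs that carry no `keywords` parameter.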

### Patient Dataset

In [None]:
import torch

# dictionary with patients
patients = {}

patients['healthy'] = Patient(
    image_url = 'https://openi.nlm.nih.gov/imgs/512/393/3200/CXR3200_IM-1512-1001.png?keywords=normal',
    report = 'Heart size is normal and the lungs are clear.'
)

patients['pneumonia'] = Patient(
    image_url = 'https://openi.nlm.nih.gov/imgs/512/276/677/CXR677_IM-2249-1001.png?keywords=Catheters,%20Indwelling,Lung,Density,Density,Density,Pleural%20Effusion,Pneumonia',
    report = 'PICC line catheter tip XXXX in the right atrium. Heart is not enlarged. Trachea and XXXX bronchi appear normal. Lungs are mildly under expanded. No pneumothorax. There are small areas of patchy density in the left lower lung XXXX. There is a larger area of XXXX patchy density in the right mid and lower lungs with right-sided pleural effusion.',
)

patients['atelectasis'] = Patient(
    image_url = 'https://openi.nlm.nih.gov/imgs/512/242/1445/CXR1445_IM-0287-4004.png?keywords=Diaphragm,Pulmonary%20Atelectasis,Consolidation,Pleural%20Effusion,Catheters,%20Indwelling,Tube,%20Inserted,Airspace%20Disease',
    report = 'Stable cardiomediastinal silhouette. There has been interval removal of right chest tube with increased elevation of the right hemidiaphragm and XXXX right basilar atelectasis. Left basilar consolidation and pleural effusions seen. No XXXX focal consolidation or pneumothorax. There is a stable left PICC with tip overlying the mid SVC and large XXXX feeding tube courses below the diaphragm.'
)

patients['cardiomegaly'] = Patient(
    image_url = 'https://openi.nlm.nih.gov/imgs/512/309/1111/CXR1111_IM-0077-4004.png?keywords=Technical%20Quality%20of%20Image%20Unsatisfactory%20,Cardiomegaly',
    report = 'Lordotic projection and large body habitus. Limited mediastinal evaluation. Severe cardiomegaly. No visualized pneumothorax. No large effusion or airspace disease. No fracture.',
)

patients['Nodules'] = Patient(
    image_url = 'https://openi.nlm.nih.gov/imgs/512/22/1626/CXR1626_IM-0407-1001.png?keywords=Nodule,Nodule',
    report = 'The heart is normal in size. The mediastinal contours are within normal limits. There are numerous bilateral pulmonary nodules of varying sizes. The largest is noted in the left lower lobe, posteriorly measuring approximately 7.0 cm. No acute infiltrate or pleural effusion are appreciated.'
)

for i, (description, patient) in enumerate(patients.items(), 1):
    print(f'Patient {i}:', description)
    display(patient)

patient_images = [patient.image for patient in patients.values()]
patient_reports = [patient.report for patient in patients.values()]
patient_descriptions = [description for description in patients.keys()]

## Zero-shot X-Ray Classification

### Contrastive binary classification
The basic idea of contrastive zero-shot classification is to encode both a positive and a negative description of the class to be predicted (e.g. its presence and its absence). The image is then encoded into the same embedding space, and we check whether its embedding is closer to the positive or to the negative text embedding. The [softmax function](https://en.wikipedia.org/wiki/Softmax_function) turns the two similarities into a probability estimate for this prediction.
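
As a small worked example of the decision step, the sketch below applies softmax to two made-up cosine similarities, one for the positive and one for the negative prompt. The numbers and the pure-Python `softmax` helper are assumptions for illustration only; in the notebook the same step is done with `similarities.softmax(dim=-1)`.

```python
import math

def softmax(scores):
    """Turn a list of scores into probabilities that sum to 1."""
    highest = max(scores)
    # subtracting the maximum does not change the result but avoids overflow
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical cosine similarities of one image to a positive and a negative prompt
similarity_positive, similarity_negative = 0.40, 0.20
probabilities = softmax([similarity_positive, similarity_negative])
prediction = 'positive' if probabilities[0] > probabilities[1] else 'negative'
print([f'{p:.1%}' for p in probabilities], prediction)
```

Because cosine similarities lie between -1 and 1, a softmax applied to them directly stays fairly close to 50% for two prompts. The ranking of the two prompts is therefore more informative here than the absolute probability.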

In [None]:
def get_similarities_from_text_and_images(text_prompts, images):
    text_embeddings = get_text_embeddings(text_prompts)
    image_embeddings = get_image_embeddings(images)
    similarities = calculate_cosine_similarity(image_embeddings, text_embeddings)
    return similarities

#### Basic Prompting

##### Healthy

In [None]:
text_prompts = [
    'healthy',
    'not healthy',
]
similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Cardiomegaly

In [None]:
text_prompts = [
    'cardiomegaly',
    'no cardiomegaly',
]
similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Pleural effusion

In [None]:
text_prompts = [
    'pleural effusion',
    'no pleural effusion',
]
similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Pneumonia

In [None]:
text_prompts = [
    'pneumonia',
    'no pneumonia',
]
similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Atelectasis

In [None]:
text_prompts = [
    'atelectasis',
    'no atelectasis',
]
similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

#### Report Style Prompting

In [None]:
# dictionary to save prompts
prompts = {}

##### Healthy

In [None]:
prompts['healthy'] = '''
The lungs are clear.
Normal heart size and shape.
No abnormal fluid buildup.
No visible tumors or masses. No pneumothorax.
'''

prompts['not healthy'] = '''
There is an area of increased opacity and consolidation indicating pneumonia.
Enlargement of the heart silhouette indicating cardiomegaly.
There is a loss in volume indicating atelectasis.
'''

text_prompts = [
    prompts['healthy'],
    prompts['not healthy']
]
similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Cardiomegaly

In [None]:
prompts['cardiomegaly'] = '''
Increased size of the heart shadow.
Enlargement of the heart silhouette.
Increased diameter of the heart border.
Increased cardiothoracic ratio.
'''
prompts['no cardiomegaly'] = '''
The heart shadow size is normal.
The heart silhouette is normal.
Normal diameter of the heart border.
Normal cardiothoracic ratio.
'''

text_prompts = [
    prompts['cardiomegaly'],
    prompts['no cardiomegaly']
]

similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Pleural Effusion

In [None]:
prompts['pleural effusion'] = '''
Blunting of costophrenic angles.
Opacity in the lower lung fields.
Mediastinal shift.
Reduced lung volume.
Presence of meniscus sign or veil-like appearance.
'''

prompts['no pleural effusion'] = '''
No blunting of costophrenic angles.
No opacity in the lower lung fields.
The lungs are clear.
No mediastinal shift.
No presence of meniscus sign or veil-like appearance.
'''

text_prompts = [
    prompts['pleural effusion'],
    prompts['no pleural effusion']
]

similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Pneumonia

In [None]:
prompts['pneumonia'] = 'There is an area of increased opacity and consolidation indicating pneumonia.'
prompts['no pneumonia'] = 'There are no opacities, no consolidation and no pleural effusion. No signs of pneumonia.'

text_prompts = [
    prompts['pneumonia'],
    prompts['no pneumonia']
]

similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

##### Atelectasis

In [None]:
prompts['atelectasis'] = '''
The lung appears reduced in volume supporting the presence of atelectasis.
The mediastinum shows a shift, consistent with volume loss associated with atelectasis.
These findings are consistent with atelectasis.
'''

prompts['no atelectasis'] = '''
Both lungs are well-expanded with clear lung fields.
The lung volumes appear normal and symmetric, with no apparent reduction in the size of either lung.
The mediastinum is positioned centrally, without evidence of mediastinal shift.
No radiographic signs of atelectasis are present. The lungs appear normally aerated and expanded.
'''

text_prompts = [
    prompts['atelectasis'],
    prompts['no atelectasis']
]

similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)

## Standardized Report Generation

#### Structured reporting template definition

In [None]:
reporting_template = {
        'healthy' : 'The lungs are clear. No findings.',
        'not healthy' :
            [
                {
                    'no cardiomegaly' : 'The heart is normal in size.',
                    'cardiomegaly' : 'There is cardiomegaly.',
                },
                {
                    'no pleural effusion': 'There is no pleural effusion.',
                    'pleural effusion': 'There is pleural effusion.'
                },
            ],
}

#### Report generation

In [None]:
# define recursive report generation method

def generate_report(image, reporting_template):
    report_sentences = []
    if isinstance(reporting_template, dict):
        choices = list(reporting_template.keys())
        text_prompts = [prompts[k] for k in choices]
        similarities = get_similarities_from_text_and_images(text_prompts, [image])
        decision = choices[similarities.argmax()]
        decision_content = reporting_template[decision]
        if isinstance(decision_content, str):
            report_sentences.append(decision_content)
        else:
            report_sentences.extend(
                generate_report(image, decision_content)
            )
    elif isinstance(reporting_template, list):
        for sub_report in reporting_template:
            report_sentences.extend(
                generate_report(image, sub_report)
            )
    return report_sentences
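
To see how the recursion walks the template without loading the model, the sketch below replaces the image-text similarities with a hand-written score table. The names `toy_scores`, `toy_template` and `walk_template` and all score values are hypothetical; only the traversal logic mirrors `generate_report`: a dictionary means "pick the best-scoring key", a list means "process every entry".

```python
# Hypothetical similarity scores standing in for the model: higher means a better match.
toy_scores = {
    'healthy': 0.1, 'not healthy': 0.6,
    'no cardiomegaly': 0.2, 'cardiomegaly': 0.5,
    'no pleural effusion': 0.4, 'pleural effusion': 0.3,
}

toy_template = {
    'healthy': 'The lungs are clear. No findings.',
    'not healthy': [
        {'no cardiomegaly': 'The heart is normal in size.',
         'cardiomegaly': 'There is cardiomegaly.'},
        {'no pleural effusion': 'There is no pleural effusion.',
         'pleural effusion': 'There is pleural effusion.'},
    ],
}

def walk_template(template, scores):
    """Dict: follow the best-scoring key. List: visit every entry. Str: emit the sentence."""
    if isinstance(template, str):
        return [template]
    if isinstance(template, dict):
        best_key = max(template, key=lambda key: scores[key])
        return walk_template(template[best_key], scores)
    sentences = []
    for entry in template:
        sentences.extend(walk_template(entry, scores))
    return sentences

toy_report = ' '.join(walk_template(toy_template, toy_scores))
print(toy_report)
```

With these scores, the walk first chooses "not healthy" and then makes one decision per finding, which yields one sentence about the heart and one about pleural effusion.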

In [None]:
# generate report for all patients

for i, image in enumerate(patient_images, 1):
    print('Patient', i)
    display(image)
    report = ' '.join(generate_report(image, reporting_template))
    print('Generated report:', report,'\n\n')

#### Open tasks

1. Why is cardiomegaly detected in patient 3 even though it is not present? How could this be fixed?

2. Change the prompts and observe the change in similarities.

3. Add a new classification task, e.g. lung opacity, and come up with a report-style prompt (try whether ChatGPT can come up with good descriptors).

4. Extend the reporting template with more choices, e.g. adding diagnoses or a severity assessment of cardiomegaly.

# Additional Resources

## Radiology Report retrieval

Given a large database of reports, the embedding of an image can be used to retrieve the reports most similar to that image. In this example, each image is compared to the ground-truth reports of all test patients, including its own:

In [None]:
# write all ground truth reports in a list as text prompts
text_prompts = [report.replace('. ', '.\n') for report in patient_reports]

similarities = get_similarities_from_text_and_images(text_prompts, patient_images)
print('Cosine similarities ranging from -1 to 1:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=False)
print('\nSoftmax probabilities ranging from 0% to 100%:')
plot_similarities(similarities, text_prompts, patient_images, patient_descriptions, plot_probabilities=True)
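
The heatmap shows all pairwise similarities. Actual retrieval only needs the best matches per image. The sketch below ranks a few reports for a single image by similarity and keeps the top results. The helper `rank_reports`, the shortened report texts and the similarity values are made up for illustration; on the tensors used in this notebook, `similarities.topk(k, dim=-1)` should give the corresponding indices.

```python
def rank_reports(similarity_row, reports, top_k=2):
    """Return the top_k (report, similarity) pairs for one image, best match first."""
    ranked = sorted(zip(reports, similarity_row), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]

# hypothetical similarities of one image to three stored reports
toy_reports = [
    'Heart size is normal and the lungs are clear.',
    'Severe cardiomegaly.',
    'Right-sided pleural effusion.',
]
toy_similarities = [0.12, 0.58, 0.31]

top_matches = rank_reports(toy_similarities, toy_reports)
for report, similarity in top_matches:
    print(f'{similarity:.2f}  {report}')
```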

## Phrase grounding

BioVil can also visualize which image regions are most similar to a given text phrase, a task referred to as "phrase grounding" [1]:

![BioVil phrase grounding examples](https://www.microsoft.com/en-us/research/uploads/prod/2022/07/MS-CXR-2048x472.png)

[Link to phrase grounding notebook](https://github.com/microsoft/hi-ml/blob/main/hi-ml-multimodal/notebooks/phrase_grounding.ipynb)

[BioVil code on GitHub](https://github.com/microsoft/hi-ml/tree/main/hi-ml-multimodal)


## Xplainer: From X-Ray Observations to Explainable Zero-Shot Diagnosis

![Xplainer graphical abstract](https://raw.githubusercontent.com/ChantalMP/Xplainer/master/figures/model_overview.png)

We propose a new way of explainability for zero-shot diagnosis prediction in the clinical domain. Instead of directly predicting a diagnosis, we prompt the model to classify the existence of descriptive observations, which a radiologist would look for on an X-Ray scan, and use the descriptor probabilities to estimate the likelihood of a diagnosis, making our model explainable by design. For this we leverage BioVil, a pretrained CLIP model for X-rays and apply contrastive observation-based prompting. We evaluate Xplainer on two chest X-ray datasets, CheXpert and ChestX-ray14, and demonstrate its effectiveness in improving the performance and explainability of zero-shot diagnosis.

Pellegrini, Chantal, et al. "Xplainer: From X-Ray Observations to Explainable Zero-Shot Diagnosis." accepted at MICCAI 2023.

[MICCAI 2023 Paper](https://link.springer.com/chapter/10.1007/978-3-031-43904-9_41)

[Hugging Face Demo](https://huggingface.co/spaces/CAMP-ViL/Xplainer)

[Preprint on arXiv](https://arxiv.org/abs/2303.13391)

[Code on GitHub](https://github.com/ChantalMP/Xplainer)
