# Task 1.3: Exploring Segment Anything (SA)

In this final segmentation task, we explore how to apply a state-of-the-art segmentation model to medical images.

The Vision Transformer (ViT) was introduced to address the limited ability of CNNs to capture long-range semantic dependencies. The Segment Anything (SA) project builds on this: its Segment Anything Model (SAM) uses a ViT as its image encoder. The plots generated in Task 1.1 suggest that the model could segment the desired area more accurately if it were given some kind of "guidance". This idea can be realized through **prompts**, such as points or boxes. SAM is designed and trained to be promptable, so it can transfer zero-shot to new image distributions and tasks.

> Kirillov, Alexander, et al. "Segment Anything." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023.

The `segment_anything` folder is provided for you. It can also be obtained from the official [repository](https://github.com/facebookresearch/segment-anything).

We will use the ViT-B SAM model, which is the variant with the fewest parameters. Please download the corresponding checkpoint, listed as `vit_b` in the GitHub repository, and place it at the path given by `sam_checkpoint` below (`model/sam_vit_b_01ec64.pth`).
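If you have not downloaded the checkpoint yet, the sketch below shows one way to fetch it using only the standard library. It assumes the checkpoint file names and download host listed in the segment-anything README, and it assumes the `model/` folder used by `sam_checkpoint` in the next cell. `checkpoint_url` and `ensure_checkpoint` are hypothetical helpers and are not part of the `segment_anything` package.

```python
import os
import urllib.request

# Checkpoint file names as listed in the segment-anything README (assumed unchanged).
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",  # smallest variant, used in this task
}
# Assumption: the official download host referenced by the README.
BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"


def checkpoint_url(model_type: str) -> str:
    """Return the download URL for a SAM model type (hypothetical helper)."""
    return BASE_URL + SAM_CHECKPOINTS[model_type]


def ensure_checkpoint(model_type: str = "vit_b", folder: str = "model") -> str:
    """Download the checkpoint into `folder` unless it already exists; return the local path."""
    path = os.path.join(folder, SAM_CHECKPOINTS[model_type])
    if not os.path.exists(path):
        os.makedirs(folder, exist_ok=True)
        # The ViT-B file is a few hundred MB, so this can take a while.
        urllib.request.urlretrieve(checkpoint_url(model_type), path)
    return path
```

Calling `ensure_checkpoint("vit_b")` returns `model/sam_vit_b_01ec64.pth`, which is the same path that is hard-coded in the next cell.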

In [None]:
import sys
sys.path.append('..')  # assumes the shared utils/ package lives in the parent directory
from utils.Task1_utils import show_box, show_points, show_mask
import torch
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "model/sam_vit_b_01ec64.pth"  # ViT-B checkpoint from the repository
model_type = "vit_b"  # must match the checkpoint above
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # build SAM and load the weights
sam.to(device=device)
predictor = SamPredictor(sam)  # embed an image once, then query it with different prompts

For simplicity, we visualize the second slice of the end-diastolic (ED) frame of the first patient in the ACDC dataset, because the first slice contains no annotation.

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import cv2

# todo: Enter the same path as in the previous task.
path_ACDC_all = None

nim_image = os.path.join(path_ACDC_all, "patient001", "patient001_frame01.nii.gz")
nim_mask = os.path.join(path_ACDC_all, "patient001", "patient001_frame01_gt.nii.gz")
image = nib.load(nim_image).get_fdata()[:, :, 1]
# normalize to [0, 255] and resize; SAM expects an 8-bit RGB (H, W, 3) image
image = (image - image.min()) / (image.max() - image.min()) * 255
image = image.astype(np.uint8)
image = cv2.resize(image, (512, 512))
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
mask = nib.load(nim_mask).get_fdata()[:, :, 1]
# nearest-neighbour resize keeps the integer labels intact (bilinear would blend class ids)
mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
mask = np.where(mask == 0, np.nan, mask)  # hide the background (label 0) in the overlay

plt.imshow(image, cmap="gray")
plt.imshow(mask, cmap="jet", alpha=0.5)

# compute the image embedding once; all prompts below reuse it
predictor.set_image(image)
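The preprocessing involves two separate concerns. First, intensities are min-max scaled to 0-255 to produce SAM's 8-bit RGB input. Second, label masks must be resized without interpolating between class ids. For example, a bilinear resize at a border between labels 1 and 3 produces in-between values such as 2, which look like myocardium. Below is a NumPy-only sketch of both steps for a 2-D slice. `to_uint8_rgb` and `resize_labels_nearest` are hypothetical helpers. The floor-based index mapping is similar in spirit to `cv2.INTER_NEAREST`, but it is not guaranteed to match OpenCV pixel for pixel.

```python
import numpy as np


def to_uint8_rgb(slice2d: np.ndarray) -> np.ndarray:
    """Min-max scale a 2-D slice to [0, 255] and stack it into 3 identical channels (H, W, 3)."""
    lo, hi = float(slice2d.min()), float(slice2d.max())
    scale = (hi - lo) if hi > lo else 1.0  # guard against a constant slice (division by zero)
    img = ((slice2d - lo) / scale * 255).astype(np.uint8)
    return np.stack([img, img, img], axis=-1)


def resize_labels_nearest(labels: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize: each output pixel copies one input label, so no new values appear."""
    in_h, in_w = labels.shape
    rows = np.arange(out_h) * in_h // out_h  # source row for every output row
    cols = np.arange(out_w) * in_w // out_w  # source column for every output column
    return labels[rows[:, None], cols[None, :]]
```

For example, `resize_labels_nearest(raw_mask, 512, 512)` (where `raw_mask` is the unmodified `get_fdata()[:, :, 1]` slice) gives a label-preserving resize comparable to `cv2.resize(..., interpolation=cv2.INTER_NEAREST)`.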

Your job is to segment the left ventricle, the myocardium, and the right ventricle using the SA model.
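To compare the SAM masks with the annotation, you need the ACDC label convention (0 = background, 1 = right ventricle cavity, 2 = myocardium, 3 = left ventricle cavity). Overlap can then be scored with the Dice coefficient, 2|A∩B| / (|A| + |B|). Here is a minimal sketch; `ACDC_LABELS` and `dice` are illustrative names and are not defined elsewhere in this notebook.

```python
import numpy as np

# ACDC ground-truth label ids.
ACDC_LABELS = {"RV": 1, "MYO": 2, "LV": 3}


def dice(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice coefficient between two binary masks of the same shape (1.0 = perfect overlap)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return float(2.0 * np.logical_and(pred, target).sum() / total)
```

For example, `dice(masks[0], mask == ACDC_LABELS["LV"])` scores a left-ventricle prediction. This assumes `mask` holds the resized integer labels. Comparisons against the NaN background evaluate to `False`, so background pixels are excluded automatically.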

## Segmenting Left Ventricle

As the plot shows, the left ventricle is roughly circular. Please provide a bounding box around it as the prompt for the SA model.

**Note: With a bounding box alone, it is almost inevitable that the mask will also include the myocardium!**
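`SamPredictor.predict` takes `box` as a length-4 array in XYXY format. The coordinates are in pixels of the image passed to `set_image`, which here is 512 × 512. With `plt.imshow`, x is the column (horizontal axis) and y is the row (vertical axis), so you can read coordinates directly off the plot axes. The small hypothetical helper below orders and clips a hand-picked box; the numbers themselves are still up to you.

```python
import numpy as np


def make_box(x0: float, y0: float, x1: float, y1: float, size: int = 512) -> np.ndarray:
    """Return [x_min, y_min, x_max, y_max] clipped to a size x size image (hypothetical helper)."""
    xs = np.clip(sorted((x0, x1)), 0, size - 1)  # order the two x values, then keep them in bounds
    ys = np.clip(sorted((y0, y1)), 0, size - 1)  # same for the y values
    return np.array([xs[0], ys[0], xs[1], ys[1]], dtype=float)
```

Usage: `input_box = make_box(x0, y0, x1, y1)`, with the four values read from the plot.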

In [None]:
# todo: Give one working bounding box
# format: np.array([x_min, y_min, x_max, y_max]) in pixels of the 512 x 512 image
input_box = None

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box,
    multimask_output=False,  # return a single mask
)

plt.figure(figsize=(10, 8))
plt.imshow(image, cmap="gray")
show_box(input_box, plt.gca())
show_mask(masks[0], plt.gca())
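A box around the left ventricle is ambiguous: it can mean the cavity alone or the cavity plus the myocardium. It may therefore help to ask SAM for several candidates. With `multimask_output=True`, `predictor.predict` returns three masks along with predicted quality scores, and a common default is to keep the highest-scoring one. The helper below is a sketch tested on dummy arrays. In the notebook it would be applied after `masks, scores, _ = predictor.predict(box=input_box, multimask_output=True)`. Looking at all three candidates is often more informative than trusting the score alone.

```python
import numpy as np


def pick_mask(masks: np.ndarray, scores: np.ndarray, index=None):
    """Select one mask from SAM's (C, H, W) output: the top-scoring one, or `index` if given."""
    i = int(np.argmax(scores)) if index is None else int(index)
    return masks[i], float(scores[i])
```

For example, `best, score = pick_mask(masks, scores)` keeps the top candidate, and `pick_mask(masks, scores, index=0)` lets you override the choice after a visual check with `show_mask`.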

## Segmenting Myocardium

You may have failed to exclude the myocardium from the left ventricle segmentation. Fortunately, the SA model lets us specify points to exclude from the mask. Please draw a bounding box to segment the myocardium, and provide exclusion points so that the left ventricle is not included.
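Point prompts are passed as two arrays. `point_coords` is an (N, 2) array of (x, y) pixel positions. `point_labels` is a length-N array in which 1 marks a foreground (include) point and 0 marks a background (exclude) point. The hypothetical helper below builds both arrays from two lists of points. It does not choose the points for you.

```python
import numpy as np


def build_points(include=(), exclude=()):
    """Return (point_coords, point_labels) for SamPredictor.predict; labels 1 = include, 0 = exclude."""
    pts = list(include) + list(exclude)
    coords = np.array(pts, dtype=float).reshape(-1, 2)  # (N, 2) array of (x, y)
    labels = np.array([1] * len(include) + [0] * len(exclude), dtype=int)  # (N,)
    return coords, labels
```

For the myocardium, one could call `input_point, input_label = build_points(exclude=[(x, y)])`, where `(x, y)` is a point inside the left-ventricle cavity.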

In [None]:
# todo: Give one working bounding box and one set of exclusion points
# input_point: (N, 2) array of (x, y) pixel coordinates
# input_label: (N,) array, 1 = include (foreground), 0 = exclude (background)
input_box = None
input_point = None
input_label = None

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

plt.figure(figsize=(10, 8))
plt.imshow(image, cmap="gray")
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
show_mask(masks[0], plt.gca())
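Once you have masks for the three structures, you can merge them into one label map in the ACDC convention (1 = RV, 2 = myocardium, 3 = LV) and overlay it the same way as the ground truth. The sketch below assumes boolean (H, W) masks named `rv`, `myo` and `lv`; these are hypothetical names for the per-structure SAM outputs. Where masks overlap, the later structure in the list wins, which is an arbitrary choice.

```python
import numpy as np


def compose_labels(structures):
    """Merge [(label_id, bool_mask), ...] into one float label map with a NaN background."""
    shape = structures[0][1].shape
    out = np.full(shape, np.nan)  # NaN background, matching the ground-truth overlay
    for label_id, m in structures:  # later entries overwrite earlier ones where masks overlap
        out[m.astype(bool)] = label_id
    return out
```

For example: `plt.imshow(compose_labels([(1, rv), (2, myo), (3, lv)]), cmap="jet", alpha=0.5, vmin=1, vmax=3)`. Fixing `vmin` and `vmax` keeps the colours consistent with a ground-truth overlay drawn the same way.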

Question: The right ventricle tends to show significant shape variation. For such a problem, can you think of an approach other than changing the model architecture?

One possible solution is to train with a **boundary loss** instead of a purely pixel-level or region-level loss. You can refer to the following review paper for more information.

> Azad, R., Heidary, M., Yilmaz, K., Hüttemann, M., Karimijafarbigloo, S., Wu, Y., Schmeink, A., & Merhof, D. (2023). Loss Functions in the Era of Semantic Segmentation: A Survey and Outlook (No. arXiv:2312.05391). arXiv. http://arxiv.org/abs/2312.05391
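For intuition, here is a minimal NumPy sketch of one common boundary-loss formulation, the signed-distance-map loss of Kervadec et al. Each pixel's predicted foreground probability is weighted by its signed distance to the ground-truth contour: negative inside the object and positive outside it. As a result, probability placed far outside the true shape costs the most. The distance computation here is brute force and only suitable for small toy arrays; for real images, `scipy.ndimage.distance_transform_edt` is the usual tool. In practice this term is usually combined with a region loss such as Dice rather than used alone.

```python
import numpy as np


def _dist_to(target: np.ndarray) -> np.ndarray:
    """Euclidean distance from every pixel to the nearest True pixel in `target` (brute force)."""
    grid = np.indices(target.shape).reshape(2, -1).T.astype(float)  # (H*W, 2) pixel coordinates
    pts = np.argwhere(target).astype(float)  # (K, 2) coordinates of target pixels
    d = np.sqrt(((grid[:, None, :] - pts[None, :, :]) ** 2).sum(-1))  # (H*W, K) distances
    return d.min(axis=1).reshape(target.shape)


def signed_distance(gt: np.ndarray) -> np.ndarray:
    """Signed distance map: positive outside the object, <= 0 inside.

    Object pixels 4-adjacent to the background get 0.
    """
    pos = gt.astype(bool)
    if not pos.any() or pos.all():
        return np.zeros(gt.shape)  # the contour is undefined without both regions
    outside = _dist_to(pos)  # background pixels: distance to the nearest object pixel
    inside = _dist_to(~pos) - 1.0  # object pixels: distance to the nearest background pixel, minus 1
    return np.where(pos, -inside, outside)


def boundary_loss(probs: np.ndarray, gt: np.ndarray) -> float:
    """Mean of foreground probabilities weighted by the ground-truth signed distance map."""
    return float(np.mean(probs * signed_distance(gt)))
```

A prediction that matches the ground truth gets a lower (negative) loss than an empty prediction. Predicting foreground everywhere is penalized the most.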