# Benchmark segment anything on validation set

## Segment anything on brain segmentation dataset

To run this notebook, following instruction at https://github.com/facebookresearch/segment-anything to install dependency and backbone model.

Model weights should be placed at `./weights/sam_vit_b_01ec64.pth` for example.

This notebook has following sections:

1. Benchmark Segment Anything model with central prompt

In [15]:
import sys

%load_ext autoreload
%autoreload 2
      
sys.path.append('../')

import torch
from dataset import data_loaders, BrainSegmentationDataset
from utils import postprocess_per_volume, dsc_distribution, plot_dsc, gray2rgb, outline
from skimage.io import imsave, imshow
import numpy as np
from matplotlib import pyplot as plt
import cv2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Dataloader for validation set only

In [4]:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

batch_size = 16
epochs = 50
lr = 0.0001
workers = 2
weights = "./"
image_size = 224
aug_scale = 0.05
aug_angle = 15

_, loader_valid = data_loaders(batch_size, workers, image_size, aug_scale, aug_angle, path="../kaggle_3m", valid_only=True)

reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset


## Load Segment Anything Model

In [5]:
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "../weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

## Use central point as input

The original size is (224, 224), so central point is (112, 112)

In [6]:
input_point = np.array([[112, 112]])
input_label = np.array([1])

## Evaluate Segment Anything Model on `loader_vaild`

Important shapes:

x: (16, 3, 224, 224), where 3 is channel number

y_true: (16, 1, 224, 224)

16 is default batch size, which can be manipulated by `batch_size` argument in `data_loaders` function.


**Notice that the range of y_true entry is [0, 255]**, where 0 is background (black), 255 is mask (white, groundtruth)

In [None]:
input_list = []
pred_list = []
true_list = []

predictor = SamPredictor(sam)

In [20]:



for i, data in enumerate(loader_valid):
    x, y_true = data
    x, y_true = x.to(device), y_true.to(device)

    for j in range(batch_size):
    
        with torch.set_grad_enabled(False):
            print(f"{x[j].shape = }")
            # load image into predictor
            im_now = np.transpose(x[j], (1, 2, 0))
            im_now = im_now.detach().cpu().numpy().astype(int)

            print(f"{im_now.shape = }")

            print(f"{im_now = }")

            im_now = cv2.cvtColor(im_now, cv2.COLOR_BGR2RGB)
            
            print(f"{im_now.shape = }")
            predictor.set_image(im_now)

            # predict masks
            masks_central, score_centrals, logits_central = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )
            
            # choose the mask with largest score as prediction
            print(f"{masks_central.shape = }")
            y_pred = masks_central[0] # (3, 224, 224)
            # y_pred = np.transpose(y_pred, (1, 2, 0)) # (3, 224, 224) -> (224, 224, 3)
            print(f"y_pred.shape: {y_pred.shape}")

            # y_pred_np = y_pred.detach().cpu().numpy()
            pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
            y_true_np = y_true.detach().cpu().numpy()
            true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
            x_np = x.detach().cpu().numpy()
            input_list.extend([x_np[s] for s in range(x_np.shape[0])])
        
    break


x[j].shape = torch.Size([3, 224, 224])
im_now.shape = (224, 224, 3)
im_now = array([[[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       ...,

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]])


error: OpenCV(4.7.0) /Users/xperience/GHA-OCV-Python/_work/opencv-python/opencv-python/opencv/modules/imgproc/src/color.simd_helpers.hpp:94: error: (-2:Unspecified error) in function 'cv::impl::(anonymous namespace)::CvtHelper<cv::impl::(anonymous namespace)::Set<3, 4, -1>, cv::impl::(anonymous namespace)::Set<3, 4, -1>, cv::impl::(anonymous namespace)::Set<0, 2, 5>, cv::impl::(anonymous namespace)::NONE>::CvtHelper(cv::InputArray, cv::OutputArray, int) [VScn = cv::impl::(anonymous namespace)::Set<3, 4, -1>, VDcn = cv::impl::(anonymous namespace)::Set<3, 4, -1>, VDepth = cv::impl::(anonymous namespace)::Set<0, 2, 5>, sizePolicy = cv::impl::(anonymous namespace)::NONE]'
> Unsupported depth of input image:
>     'VDepth::contains(depth)'
> where
>     'depth' is 4 (CV_32S)
