<a href="https://colab.research.google.com/github/nye0/SAM-Med2D/blob/main/predictor_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SAM Comparison
This notebook contains a stripped down version of the predictor_example.ipynb file from the SAM repo as well as code to implement my model and thus verify that they perform identically.

In [1]:
# Change path to be one folder up to permit appropriate imports
import sys
import os
sys.path.insert(0, os.path.abspath('..'))

## Set-up

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    


## Compare point-prompt based segmentations
### Use extracted code from predictor_example.ipynb to generate a segmentation

In [6]:
# Example image
image = cv2.imread('images/truck.jpg')

# First, load the SAM model and predictor.
from segment_anything import sam_model_registry, SamPredictor


sam_checkpoint = '/home/t722s/Desktop/UniversalModels/TrainedModels/sam_vit_h_4b8939.pth'
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

# Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction.
predictor.set_image(image)

# To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image.
input_point = np.array([[500, 375]])
input_label = np.array([1])

# Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,
)

### Now generate the same segmentation using code from this package

In [12]:
import numpy as np
import cv2
from classes.SAMClass import SAMInferer
from utils.base_classes import Points

# Load Model
device = 'cuda'
checkpoint_path = '/home/t722s/Desktop/UniversalModels/TrainedModels/sam_vit_h_4b8939.pth'
inferer = SAMInferer(checkpoint_path, device)

# My model takes 3d grayscale images, so take the image, remove the color dimension and give it a z dimension
img_path = 'images/truck.jpg'
img_2d = cv2.imread(img_path)
img_3d = img_2d[:,:,0][None]


# Obtain same prompt as in demo
input_point = np.array([500, 375])
input_point_3d = np.concatenate([input_point, [0]]) # include a z dimension for the point
input_point_3d = input_point_3d[::-1] # Reverse order so xyz->zyx
input_point_3d = input_point_3d[None] # Give N dimension
input_label = np.array([1])

prompt = Points(coords = input_point_3d, labels = input_label)

# Segment
segmentation = inferer.predict(img_3d, prompt)

# Convert back to 2d to compare with original code and verify equality
seg_2d = segmentation[0] # Select slice for z=0

seg_original = masks[0] # Select first mask (only one mask returned in this case; look at comment)

print(f'Arrays equal? {np.array_equal(seg_2d, seg_original)}')

Performing inference on slices: 100%|██████████| 1/1 [00:00<00:00,  1.40it/s]

Arrays equal? False





In [13]:
np.count_nonzero(seg_2d != seg_original)

2067

## Compare boxes
### From predictor_example.ipynb:

In [11]:
image = cv2.imread('data_demo/images/s0114_111.png')
predictor.set_image(image)
input_box = np.array([89,43,113,64]) #

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box,
    multimask_output=True,
)

## From this package:

In [12]:
from utils.base_classes import Boxes

# Make image 3d grayscale
img_path = "data_demo/images/s0114_111.png"
img_2d = cv2.imread(img_path)
img_3d = img_2d[:,:,0][None]

# Obtain prompt
input_box = np.array([89,43,113,64])
prompt = Boxes({0:input_box})

# Segment
segmentation = inferer.predict(img_3d, prompt)

## Convert back to 2d to compare with original code and verify equality
seg_2d = segmentation[0]
seg_original = masks[0]

print(f'Arrays equal? {np.array_equal(seg_2d, seg_original)}')

Using previously generated image embeddings


Performing inference on slices: 100%|██████████| 1/1 [00:00<00:00, 25.99it/s]

Arrays equal? True



