# Speed Test

In [None]:
from datetime import datetime
from os.path import join, dirname
from time import perf_counter

import torch
import numpy as np
from PIL import Image
import cv2
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from fvcore.nn import FlopCountAnalysis, parameter_count_table, parameter_count

import usam
from usam.patch_sam2 import patch_sam2


# Repository root: the parent directory of the installed `usam` package folder.
root = dirname(usam.__path__[0])

# Model selection. Exactly one (config, checkpoint) pair must be active;
# the commented pairs are the smaller SAM2 variants kept for quick switching.
model_cfg = join(root, "models", "sam", "configs_2.0", "sam2_hiera_l.yaml")
checkpoint = join(root, "models", "sam", "checkpoints_2.0", "sam2_hiera_large.pt")
#model_cfg = join(root, "models", "sam", "configs_2.0", "sam2_hiera_b+.yaml")
#checkpoint = join(root, "models", "sam", "checkpoints_2.0", "sam2_hiera_base_plus.pt")
#model_cfg = join(root, "models", "sam", "configs_2.0", "sam2_hiera_s.yaml")
#checkpoint = join(root, "models", "sam", "checkpoints_2.0", "sam2_hiera_small.pt")
#model_cfg = join(root, "models", "sam", "configs_2.0", "sam2_hiera_t.yaml")
#checkpoint = join(root, "models", "sam", "checkpoints_2.0", "sam2_hiera_tiny.pt")

# Directory handed to patch_sam2() in the USAM cell.
mlp_dir = join(root, "models", "mlps", "sam2.0")


# Load the sample image as an RGB array. The context manager closes the
# underlying file handle deterministically instead of leaving it to the GC.
with Image.open(join(root, "assets", "dog_sample.png")) as sample_image:
    image = np.array(sample_image.convert("RGB"))

# Single point prompt shared by every benchmark below.
# NOTE(review): presumably (x, y) pixel coordinates with label 1 meaning
# "foreground" per SAM's prompt convention -- confirm against the predictor.
input_point = np.array([[256, 256]])
input_label = np.array([1])


# Number of timed repetitions per benchmark (also used for the initial warmup).
num_iterations = 50
    
    
# Load predictor
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
print(parameter_count_table(predictor.model))
# parameter_count() returns a mapping from name prefix to parameter count;
# the "" key holds the total for the whole model. Print that single number
# rather than dumping the entire per-module dictionary (the table above
# already shows the per-module breakdown).
print("Parameters: ", parameter_count(predictor.model)[""])

# Do some warmups to make the GPU ready: run the full embed + predict path
# num_iterations times so one-off start-up costs are not attributed to the
# first timed benchmark.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    for _ in range(num_iterations):
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )

## Speed Test for MC-based image augmented SAM2

In [None]:
# Number of augmented copies of the image evaluated per timed iteration.
num_img_augmentations = 5

# Run one iteration for warmup (a single pass over all augmentations).
# Previously this repeated the whole benchmark num_iterations times, which
# doubled the cell's runtime and contradicted the other cells' warmups.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    for _ in range(num_img_augmentations):
        aug_image = cv2.flip(image, 1)
        predictor.set_image(aug_image)
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )

# Run speed test
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # CUDA work is queued asynchronously: drain the queue before starting and
    # before stopping the clock so the interval covers exactly the timed work.
    # (predict() may already synchronize when it hands results back to the
    # host; the explicit sync makes the measurement independent of that.)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = perf_counter()  # monotonic, high-resolution; unaffected by clock adjustments
    for _ in range(num_iterations):
        for _ in range(num_img_augmentations):
            # The flip stays inside the loop on purpose: producing the
            # augmented image is part of the per-sample cost being measured.
            aug_image = cv2.flip(image, 1)
            predictor.set_image(aug_image)
            masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
            )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = perf_counter()
    # Mean seconds per iteration, where one iteration = num_img_augmentations
    # full embed + predict passes.
    t_image_aug = (t1 - t0) / num_iterations
    print(f"MC-simulation with image augmentation in SAM2: {t_image_aug:.5f} s per iteration")

## Speed Test for MC-based prompt augmented SAM2

In [None]:
# Number of prompt evaluations per timed iteration; the image embedding is
# computed once per iteration and reused for all of them.
num_prompt_augmentations = 8

# Run one iteration for warmup
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    for _ in range(num_prompt_augmentations):
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )

# Run speed test
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # CUDA work is queued asynchronously: drain the queue before starting and
    # before stopping the clock so the interval covers exactly the timed work.
    # (predict() may already synchronize when it hands results back to the
    # host; the explicit sync makes the measurement independent of that.)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = perf_counter()  # monotonic, high-resolution; unaffected by clock adjustments
    for _ in range(num_iterations):
        predictor.set_image(image)
        for _ in range(num_prompt_augmentations):
            # The same prompt is reused every time: only the cost of the
            # repeated decoder calls is being simulated here.
            masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
            )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = perf_counter()
    # Mean seconds per iteration, where one iteration = one set_image() plus
    # num_prompt_augmentations predict() calls.
    t_prompt_aug = (t1 - t0) / num_iterations
    print(f"MC-simulation with prompt augmentation in SAM2: {t_prompt_aug:.5f} s per iteration")

## Speed Test for standard SAM2 with entropy calculation

In [None]:
def entropy(mask):
    """Return the mask-weighted mean binary entropy of *mask*, in bits.

    Values are first clipped to [1e-6, 1 - 1e-6] so both logarithms stay
    finite. The per-element binary entropy is then averaged using the
    clipped values themselves as weights. The input array is not modified.
    """
    eps = 1e-6
    prob = np.clip(mask, eps, 1 - eps)
    complement = 1 - prob
    # Per-element binary (Bernoulli) entropy, base 2.
    per_pixel = -prob * np.log2(prob) - complement * np.log2(complement)
    # Weighted mean with the clipped probabilities as weights.
    weighted_total = np.sum(prob * per_pixel)
    return weighted_total / np.sum(prob)

# Run one iteration for warmup
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    e = entropy(masks[0])

# Run speed test
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # CUDA work is queued asynchronously: drain the queue before starting and
    # before stopping the clock so the interval covers exactly the timed work.
    # (predict() may already synchronize when it hands results back to the
    # host; the explicit sync makes the measurement independent of that.)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = perf_counter()  # monotonic, high-resolution; unaffected by clock adjustments
    for _ in range(num_iterations):
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        # The entropy computation is part of the timed work for this variant.
        e = entropy(masks[0])
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = perf_counter()
    # Mean seconds per iteration (embed + predict + entropy).
    t_entropy = (t1 - t0) / num_iterations
    print(f"Entropy of SAM2: {t_entropy:.5f} s per iteration")

## Speed Test for normal SAM2 

In [None]:


# Run one iteration for warmup
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

# Run speed test (baseline: one embed + one predict per iteration)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # CUDA work is queued asynchronously: drain the queue before starting and
    # before stopping the clock so the interval covers exactly the timed work.
    # (predict() may already synchronize when it hands results back to the
    # host; the explicit sync makes the measurement independent of that.)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = perf_counter()  # monotonic, high-resolution; unaffected by clock adjustments
    for _ in range(num_iterations):
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = perf_counter()
    # Mean seconds per iteration; reported again in the summary cell.
    t_standard = (t1 - t0) / num_iterations
    print(f"Normal SAM2: {t_standard:.5f} s per iteration")
    

## Speed Test for USAM extended SAM2

In [None]:
# Patch SAM2 with USAM extension. After patching, predict() returns a fourth
# value (unpacked as mlp_scores below).
# NOTE(review): re-running this cell patches the same predictor again --
# confirm patch_sam2() tolerates that, or restart the kernel first.
patch_sam2(predictor, mlp_dir)

# Print MLP Parameters
for name, mlp in predictor.model.sam_mask_decoder.regression_models.items():
    print(f"MLP {name}: {parameter_count_table(mlp)}") 
    # parameter_count() returns a per-prefix mapping; the "" key is the total.
    # Print that number instead of dumping the whole dictionary.
    print("Parameters: ", parameter_count(mlp)[""])

# Run one iteration for warmup
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    masks, scores, logits, mlp_scores = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    # Kept for the summary cell.
    num_mlps = len(mlp_scores)
    print(f"There are {num_mlps} MLPs.")

# Run speed test
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # CUDA work is queued asynchronously: drain the queue before starting and
    # before stopping the clock so the interval covers exactly the timed work.
    # (predict() may already synchronize when it hands results back to the
    # host; the explicit sync makes the measurement independent of that.)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = perf_counter()  # monotonic, high-resolution; unaffected by clock adjustments
    for _ in range(num_iterations):
        predictor.set_image(image)
        masks, scores, logits, mlp_scores = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = perf_counter()
    # Mean seconds per iteration with all USAM MLP heads active.
    t_usam = (t1 - t0) / num_iterations
    print(f"USAM extension in SAM2: {t_usam:.5f} s per iteration")

## Summarize results

In [None]:
# Collect one report line per benchmark variant, then emit them in order.
# All timings are mean seconds per iteration measured by the cells above.
summary_lines = (
    f"SAM2: {t_standard:.5f} s per iteration",
    f"SAM2 with Entropy: {t_entropy:.5f} s per iteration",
    f"SAM2 with {num_img_augmentations} MC image augmentations: {t_image_aug:.5f} s per iteration",
    f"SAM2 with {num_prompt_augmentations} MC prompt augmentations: {t_prompt_aug:.5f} s per iteration",
    f"SAM2 with all {num_mlps} (!) MLPs proposed in USAM: {t_usam:.5f} s per iteration",
)
for summary_line in summary_lines:
    print(summary_line)
