# **Segment Anything Model 2 (SAM 2)**
![SAM2](https://github.com/ibaiGorordo/ONNX-SAM2-Segment-Anything/raw/main/doc/img/sam2_annotation.gif)

## Installation !!Requires GPU runtime!!

In [1]:
%cd /content
!git clone https://github.com/facebookresearch/segment-anything-2.git
%cd /content/segment-anything-2
!pip3 install -e .
!pip3 install onnx onnxscript onnxsim onnxruntime

/content
fatal: destination path 'segment-anything-2' already exists and is not an empty directory.
/content/segment-anything-2
Obtaining file:///content/segment-anything-2
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: SAM-2
  Building editable for SAM-2 (pyproject.toml) ... [?25l[?25hdone
  Created wheel for SAM-2: filename=SAM_2-1.0-0.editable-cp311-cp311-linux_x86_64.whl size=13812 sha256=fd7249f8c158beb5d0360dd2a4430a9cec86a40c950f46f13958c0496256bc28
  Stored in directory: /tmp/pip-ephem-wheel-cache-45422yps/wheels/79/7c/e1/0da3f0d4adfcc74ea4d1578b1a77a5a1647d6dc06af87a30e7
Successfully built SAM-2
Installing collected packages: SAM-2
  Attempting uninstall: SAM-2
    Found existing installation: SAM-2 1.0
    Uninstalling SAM-2-1.0:

In [2]:
%cd /content/segment-anything-2/checkpoints
!./download_ckpts.sh

/content/segment-anything-2/checkpoints
Downloading sam2.1_hiera_tiny.pt checkpoint...
--2025-02-11 22:35:39--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.219.10, 13.227.219.70, 13.227.219.59, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.219.10|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 156008466 (149M) [application/vnd.snesdev-page-table]
Saving to: ‘sam2.1_hiera_tiny.pt’


2025-02-11 22:35:40 (226 MB/s) - ‘sam2.1_hiera_tiny.pt’ saved [156008466/156008466]

Downloading sam2.1_hiera_small.pt checkpoint...
--2025-02-11 22:35:40--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.219.10, 13.227.219.70, 13.227.219.59, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.219.10|:443... connected.
HTTP request sent,

In [3]:
%cd /content/segment-anything-2/
from typing import Optional, Tuple, Any
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_


from sam2.modeling.sam2_base import SAM2Base

class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
        backbone_out = self.image_encoder(x)
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )

        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels:]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels:]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
                 for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

        return feats[0], feats[1], feats[2]


class SAM2ImageDecoder(nn.Module):
    def __init__(
            self,
            sam_model: SAM2Base,
            multimask_output: bool
    ) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output

    @torch.no_grad()
    def forward(
            self,
            image_embed: torch.Tensor,
            high_res_feats_0: torch.Tensor,
            high_res_feats_1: torch.Tensor,
            point_coords: torch.Tensor,
            point_labels: torch.Tensor,
            mask_input: torch.Tensor,
            has_mask_input: torch.Tensor,
            img_size: torch.Tensor
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        self.sparse_embedding = sparse_embedding
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        high_res_feats = [high_res_feats_0, high_res_feats_1]
        image_embed = image_embed

        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=high_res_feats,
        )

        if self.multimask_output:
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(masks, iou_predictions)

        masks = torch.clamp(masks, -32.0, 32.0)
        print(masks.shape, iou_predictions.shape)

        masks = F.interpolate(masks, (img_size[0], img_size[1]), mode="bilinear", align_corners=False)

        return masks, iou_predictions

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:

        point_coords = point_coords + 0.5

        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (
                point_labels == -1
        )

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
                1 - has_mask_input
        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

/content/segment-anything-2


## Select  model parameters

In [4]:
model_type = 'sam2_hiera_tiny' #@param ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
# input_size = 768 #@param {type:"slider", min:160, max:4102, step:8}
input_size = 1024 # Bad output if anything else (for now)
multimask_output = False

if model_type == "sam2_hiera_tiny":
    model_cfg = "sam2_hiera_t.yaml"
elif model_type == "sam2_hiera_small":
    model_cfg = "sam2_hiera_s.yaml"
elif model_type == "sam2_hiera_base_plus":
    model_cfg = "sam2_hiera_b+.yaml"
else:
    model_cfg = "sam2_hiera_l.yaml"


## Export Encoder

In [8]:

%cd /content/segment-anything-2/
import torch
from sam2.build_sam import build_sam2

sam2_checkpoint = f"/content/segment-anything-2/checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")

img=torch.randn(1, 3, input_size, input_size).cpu()

sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)
print(high_res_feats_0.shape)
print(high_res_feats_1.shape)
print(image_embed.shape)

torch.onnx.export(sam2_encoder,
                  img,
                  f"{model_type}_encoder.onnx",
                  export_params=True,
                  opset_version=17,
                  do_constant_folding=True,
                  input_names = ['image'],
                  output_names = ['high_res_feats_0', 'high_res_feats_1', 'image_embed']
                )

/content/segment-anything-2
torch.Size([1, 32, 256, 256])
torch.Size([1, 64, 128, 128])
torch.Size([1, 256, 64, 64])


  if pad_h > 0 or pad_w > 0:
  if Hp > H or Wp > W:


## Export Decoder

In [9]:
%cd /content/segment-anything-2/


sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).cpu()

embed_dim = sam2_model.sam_prompt_encoder.embed_dim
embed_size = (sam2_model.image_size // sam2_model.backbone_stride, sam2_model.image_size // sam2_model.backbone_stride)
mask_input_size = [4 * x for x in embed_size]
print(embed_dim, embed_size, mask_input_size)

point_coords = torch.randint(low=0, high=input_size, size=(1, 5, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float)
mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
has_mask_input = torch.tensor([1], dtype=torch.float)
orig_im_size = torch.tensor([input_size, input_size], dtype=torch.int32)

masks, scores = sam2_decoder(image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size)


torch.onnx.export(sam2_decoder,
                  (image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size),
                  "decoder.onnx",
                  export_params=True,
                  opset_version=16,
                  do_constant_folding=True,
                  input_names = ['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input', 'orig_im_size'],
                  output_names = ['masks', 'iou_predictions'],
                  dynamic_axes = {"point_coords": {0: "num_labels", 1: "num_points"},
                                  "point_labels": {0: "num_labels", 1: "num_points"},
                                  "mask_input": {0: "num_labels"},
                                  "has_mask_input": {0: "num_labels"}
                  }
                )


/content/segment-anything-2
256 (64, 64) [256, 256]
torch.Size([1, 1, 256, 256]) torch.Size([1, 1])


  assert image_embeddings.shape[0] == tokens.shape[0]
  image_pe.size(0) == 1


torch.Size([1, 1, 256, 256]) torch.Size([1, 1])




## Simplify models

In [10]:
%cd /content/segment-anything-2/
!onnxsim {model_type}_encoder.onnx {model_type}_encoder.onnx
!onnxsim decoder.onnx decoder.onnx

/content/segment-anything-2
[1;35mYour model contains "Tile" ops or/and "ConstantOfShape" ops. Folding these ops can make the [0m
[1;35msimplified model much larger. If it is not expected, please specify "--no-large-tensor" (which will [0m
[1;35mlose some optimization chances)[0m
Simplifying[33m...[0m
Finish! Here is the difference:
┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1m                  [0m[1m [0m┃[1m [0m[1mOriginal Model[0m[1m [0m┃[1m [0m[1mSimplified Model[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ Abs                │ 1              │ [1;32m0               [0m │
│ Add                │ 127            │ [1;32m90              [0m │
│ Cast               │ 155            │ [1;32m0               [0m │
│ Concat             │ 119            │ [1;32m0               [0m │
│ Constant           │ 1205           │ [1;32m218             [0m │
│ ConstantOfShape    │ 11             │ [1;32m0          

## Optional, mount GDrive for faster model download (Copy it to your Google Drive and then download)

In [11]:
from google.colab import drive
drive.mount('/content/gdrive',force_remount=True)

Mounted at /content/gdrive


In [12]:
%cd /content/segment-anything-2/
!cp {model_type}_encoder.onnx '/content/gdrive/My Drive/'
!cp decoder.onnx '/content/gdrive/My Drive/'

/content/segment-anything-2
