In [3]:
import torch

# Record the environment: the PyTorch build and whether a CUDA device is visible.
print("PyTorch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())

# SAM model constructors / automatic mask generator, and the ONNX-exportable SAM wrapper.
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.utils.onnx import SamOnnxModel

# ONNX Runtime plus its quantization helpers (used below to quantize the exported model).
import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

PyTorch version: 2.0.1+cu117
CUDA is available: True


In [None]:
# List the packages installed in the active conda environment (e.g. to record library versions).
! conda list

In [None]:
# Paths and model settings for the SAM -> ONNX export pipeline (local Windows layout).
parameters = dict(
    # dataset:
    image_path="D:/Project/WovenBagDetection/Datasets/MiniBatchForTest/",
    save_path="D:/Project/WovenBagDetection/WBDetectionWithPytorch/Test_Result",
    # model:
    SAM_checkpoint_path="D:/Project/WovenBagDetection/ModelCheckpoint_Pytorch/SegmentAnything/SAM_vit_b/sam_vit_b_01ec64.pth",
    model_type="vit_b",
    ONNX_model_path="D:/Project/WovenBagDetection/ONNX/sam_onnx_example.onnx",
)

In [None]:
# Build the SAM variant named by `model_type` from the configured checkpoint.
# It stays on the CPU: the ONNX export cell below traces it with CPU dummy tensors,
# so moving it to the GPU would require moving those inputs as well.
sam = sam_model_registry[parameters['model_type']](
    checkpoint=parameters['SAM_checkpoint_path'],
)

# Automatic mask generator over a 32-point-per-side grid, with IoU / stability
# thresholds and one crop layer.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.8,
    stability_score_thresh=0.9,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)

In [None]:
import warnings

# Wrap `sam` in segment_anything's ONNX-exportable module. Judging by the prompt-style
# inputs below, this is presumably the prompt encoder + mask decoder (the image encoder
# is exported separately further down). return_single_mask=True requests one output mask.
onnx_model = SamOnnxModel(sam, return_single_mask = True)

# Let the number of prompt points (axis 1 of coords/labels) vary in the exported graph.
dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}

# Dummy-input shapes come from the loaded model's prompt encoder so the traced graph
# matches this checkpoint; the mask prompt is 4x the embedding's spatial size.
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
# Tracing inputs: only shapes/dtypes matter, values are random. Insertion order matters:
# the values are passed positionally (as a tuple) and the keys become the ONNX input
# names, so this order presumably must match SamOnnxModel.forward's parameter order.
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]

# Silence tracer warnings during export. NOTE(review): the UserWarning filter is broad
# and hides every UserWarning raised inside this block, not only tracing noise.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    # Write the graph (weights embedded, constant folding on, opset 17) to the configured path.
    with open(parameters["ONNX_model_path"], "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

In [None]:
from pathlib import Path

# Put the quantized model next to the FP32 export instead of in the kernel's current
# working directory, so its location does not depend on where Jupyter was started.
onnx_model_quantized_path = str(
    Path(parameters["ONNX_model_path"]).with_name("sam_onnx_quantized_example.onnx")
)

# Dynamic quantization of the exported model with unsigned 8-bit weights.
# NOTE(review): `optimize_model` may not be accepted by newer onnxruntime releases;
# confirm against the installed version if this call raises TypeError.
quantize_dynamic(
    model_input=parameters["ONNX_model_path"],
    model_output=onnx_model_quantized_path,
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

# NOTE(review): the test section loads test_param['ONNX_model_path'] (the unquantized
# export), not this variable; confirm which model is meant to be evaluated.
onnx_model_path = onnx_model_quantized_path

# Test: load the exported model and SAM predictor, and inspect the embedding of a sample image

In [None]:
import cv2
import onnxruntime
import matplotlib.pyplot as plt

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

In [None]:
# Inputs for the test section: one sample image, the exported ONNX model and the SAM checkpoint.
test_param = dict(
    test_image_path="D:/Project/WovenBagDetection/Datasets/MiniBatchForTest/NG/202363016245851.bmp",
    ONNX_model_path="D:/Project/WovenBagDetection/ONNX/sam_onnx_example.onnx",
    model_type='vit_b',
    SAM_checkpoint_path="D:/Project/WovenBagDetection/ModelCheckpoint_Pytorch/SegmentAnything/SAM_vit_b/sam_vit_b_01ec64.pth",
)

In [None]:
# cv2.imread reports failure (missing file, unreadable format) by returning None rather
# than raising, which would otherwise surface as an opaque cvtColor error.
image_bgr = cv2.imread(test_param['test_image_path'])
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {test_param['test_image_path']}")

# OpenCV decodes to BGR; convert to RGB channel order for the predictor below
# (presumably its expected input format; confirm against SamPredictor.set_image).
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

In [None]:
# ONNX Runtime >= 1.9 builds with several execution providers (e.g. onnxruntime-gpu)
# refuse to create a session unless `providers` is given explicitly. Passing every
# available provider keeps ORT's default preference order and also works on CPU-only builds.
ort_session = onnxruntime.InferenceSession(
    test_param['ONNX_model_path'],
    providers=onnxruntime.get_available_providers(),
)

In [None]:
# Use the GPU when one is visible, otherwise fall back to CPU instead of crashing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

sam = sam_model_registry[test_param['model_type']](checkpoint = test_param['SAM_checkpoint_path'])
sam.to(device = device)
predictor = SamPredictor(sam)
# Encode the RGB test image; the resulting embedding is read back via get_image_embedding().
predictor.set_image(image)

In [None]:
# Fetch the embedding computed for the image passed to set_image, as a host-side NumPy array.
image_embedding = predictor.get_image_embedding().to("cpu").numpy()

In [None]:
# Show the embedding's shape (expected to mirror the (1, embed_dim, H, W) ONNX dummy input; confirm).
tuple(image_embedding.shape)

In [None]:
# The embedding is 4-D, (batch, channel, H, W), like the `image_embeddings` dummy input
# used for the ONNX export, but plt.imshow only accepts 2-D or 3-D image data.
# Visualise a single channel of the first (only) batch element instead.
EMBEDDING_CHANNEL = 0

fig, ax = plt.subplots()
im = ax.imshow(image_embedding[0, EMBEDDING_CHANNEL], cmap="viridis")
fig.colorbar(im, ax=ax)
ax.set(title=f"SAM image embedding, channel {EMBEDDING_CHANNEL}", xlabel="x", ylabel="y")
plt.show()

In [1]:
# Export the vit_b SAM checkpoint with segment-anything's export_onnx_model.py script
# (presumably the prompt encoder + mask decoder, given the `pmencoder` output name),
# targeting ONNX opset 12. The trailing `\` joins the lines into one shell command.
! python C:\Users\92736\segment-anything\scripts\export_onnx_model.py \
    --checkpoint D:\Project\WovenBagDetection\GenerateMask\SAMModelCheckpoint\SegmentAnything\SAM_vit_b\sam_vit_b_01ec64.pth \
    --output D:\Project\WovenBagDetection\GenerateMask\SAM_ONNX\sam_vit_b_pmencoder.onnx \
    --model-type vit_b \
    --opset 12


Loading model...
Exporting onnx model to D:\Project\WovenBagDetection\GenerateMask\SAM_ONNX\sam_vit_b_pmencoder.onnx...
verbose: False, log level: Level.ERROR

Model has successfully been run with ONNXRuntime.


In [3]:
import torch

from collections import OrderedDict
from functools import partial
from segment_anything.modeling.image_encoder import ImageEncoderViT

# Raw strings: these Windows paths contain sequences such as "\P" and "\s" that are
# invalid escapes in ordinary string literals (SyntaxWarning on Python 3.12+).
CHECKPOINT_PATH = r"D:\Project\WovenBagDetection\GenerateMask\SAMModelCheckpoint\SegmentAnything\SAM_vit_b\sam_vit_b_01ec64.pth"
ENCODER_ONNX_PATH = r"D:\Project\WovenBagDetection\GenerateMask\SAM_ONNX\sam_vit_b_encoder.onnx"
# Prefix of the image-encoder entries in the full SAM state dict.
ENCODER_PREFIX = "image_encoder."

# ViT-B image encoder. These hyperparameters must match the checkpoint loaded below,
# because load_state_dict is strict by default and fails on any key/shape mismatch.
encoder = ImageEncoderViT(
    depth=12,
    embed_dim=768,
    img_size=1024,
    mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12,
    patch_size=16,
    qkv_bias=True,
    use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11],
    window_size=14,
    out_chans=256,
)

# map_location="cpu": the encoder is built on the CPU, and this lets the checkpoint
# load even if it was saved from GPU tensors and no CUDA device is present.
param = torch.load(CHECKPOINT_PATH, map_location="cpu")

# Keep only the image-encoder weights, stripping the prefix so the keys match
# ImageEncoderViT's own parameter names.
d = OrderedDict(
    (k[len(ENCODER_PREFIX):], v) for k, v in param.items() if k.startswith(ENCODER_PREFIX)
)

encoder.load_state_dict(d)
encoder.eval()

# Trace with one dummy 3-channel image at the encoder's fixed img_size (1024x1024).
x = torch.randn((1, 3, 1024, 1024))
torch.onnx.export(
    encoder,
    x,
    ENCODER_ONNX_PATH,
    opset_version=12,
    input_names=["input"],
    output_names=["output"],
)


verbose: False, log level: Level.ERROR



In [2]:
# Export the vit_b image encoder to ONNX (opset 12) with the export_image_encoder.py
# script from the local segment-anything checkout. The trailing `\` joins the lines
# into one shell command.
! python D:\Project\WovenBagDetection\GenerateMask\segment-anything\scripts\export_image_encoder.py \
--checkpoint D:\Project\WovenBagDetection\GenerateMask\SAMModelCheckpoint\SegmentAnything\SAM_vit_b\sam_vit_b_01ec64.pth \
--output D:\Project\WovenBagDetection\GenerateMask\SAM_ONNX\sam_vit_b_imgencoder.onnx \
--model-type vit_b \
--opset 12

Loading model...
Exporting onnx model to D:\Project\WovenBagDetection\GenerateMask\SAM_ONNX\sam_vit_b_imgencoder.onnx...
verbose: False, log level: Level.ERROR

Model has successfully been run with ONNXRuntime.
