In [1]:
import torch
from torch2trt import torch2trt
import warnings
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

from Module.models.timm_image_encoder import TimmImageEncoder
from torch2trt import torch2trt

ModuleNotFoundError: No module named 'segment_anything'

In [None]:
from mobile_sam.modeling.sam import Sam
import onnxruntime

In [None]:
## 0. Settings
# ONNX opset version passed to torch.onnx.export for the decoder and encoder exports.
# NOTE(review): make sure no later cell reassigns `opset`, or this setting is silently ignored.
opset = 17

## 1. Export the SAM decoder to ONNX
- Export SAM's prompt encoder and mask decoder (wrapped in `SamOnnxModel`) as a single ONNX model.

In [None]:
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the host.

    The tensor is detached first, so tensors that require grad (e.g. model
    outputs) convert instead of raising. It is then copied to CPU memory if needed.

    Args:
        tensor: any torch.Tensor.

    Returns:
        numpy.ndarray with the same shape and dtype as ``tensor``.
    """
    return tensor.detach().cpu().numpy()

In [30]:
# SAM backbone type and checkpoint used to build the model whose decoder is exported below.
# Alternative ViT-L configuration (kept for reference, not active):
#   model_type_t = 'vit_l'
#   checkpoint_t = './runs/241107_SAM_ViT_L_ft_v1/best.pth'
model_type_t = 'vit_b'
checkpoint_t = './runs/241028_SAM_FT_water_10e/241028_SAM_FT_10e_.pth'

In [31]:
# Derive the decoder filename from the configured backbone so the artifact
# name always matches the SAM variant actually exported (e.g. 'vit_b' -> 'vit-b').
output = f"241112_test_sam_{model_type_t.replace('_', '-')}_decoder.onnx"

In [32]:
# Build the SAM architecture selected by `model_type_t` and load `checkpoint_t` into it.
sam = sam_model_registry[model_type_t](
    checkpoint=checkpoint_t,
)

In [33]:
# Wrap the loaded SAM model for ONNX export: single-mask output,
# no stability score, no extra metric outputs.
onnx_model = SamOnnxModel(
    model=sam,
    return_extra_metrics=False,
    use_stability_score=False,
    return_single_mask=True,
)

In [34]:
# Input shapes for the exported decoder are taken from the loaded SAM prompt encoder.
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
# mask_input spatial size: 4x the embedding grid in each dimension.
mask_input_size = [4 * x for x in embed_size]

# NOTE: `opset` is deliberately NOT reassigned here. It comes from the settings
# cell at the top, so both the decoder and encoder exports use one value
# regardless of the order in which cells are run.

# Let the number of prompt points (axis 1) vary in the exported graph.
dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}

# Random tracing inputs. Only their shapes/dtypes matter for export;
# the values themselves are arbitrary.
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.int32),
}

In [35]:
# Eager sanity check of the wrapper on the dummy inputs before tracing.
# no_grad: the result is discarded, so don't build an autograd graph over the model's parameters.
with torch.no_grad():
    _ = onnx_model(**dummy_inputs)

In [36]:
# Names assigned, in order, to the exported decoder graph's outputs.
output_names = "masks iou_predictions low_res_masks".split()

In [37]:
# Trace `onnx_model` with the dummy inputs and write the ONNX graph to `output`.
# TracerWarning/UserWarning are suppressed only inside this context.
with warnings.catch_warnings(), open(output, "wb") as onnx_file:
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    print(f"Exporting onnx model to {output}...")
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        onnx_file,
        export_params=True,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,
        input_names=list(dummy_inputs.keys()),
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )

Exporting onnx model to 241108_test_sam_vit-b_decoder.onnx...
verbose: False, log level: Level.ERROR



In [38]:
# Feed the same dummy inputs (as NumPy arrays) to the exported decoder via ONNX Runtime on CPU.
ort_inputs = {name: to_numpy(tensor) for name, tensor in dummy_inputs.items()}
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(output, providers=providers)

In [39]:
# Execute the ONNX graph once; output_names=None requests all graph outputs.
_ = ort_session.run(output_names=None, input_feed=ort_inputs)

In [40]:
# Confirmation message for the ONNX Runtime check above.
print(
    "Model has successfully been run with ONNXRuntime."
)

Model has successfully been run with ONNXRuntime.


## 2. Export the ResNet-18 image encoder (TensorRT and ONNX)

In [8]:
# Student image encoder: backbone name and trained weights.
model_type_s = 'resnet18'
weight_s_path = "./runs/241031_vit-b_to_resnet18_v2/Nanosam_encoder.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use the configured backbone name rather than a second hard-coded copy of it.
# NOTE(review): `pretrained=True` presumably downloads backbone weights that the
# strict load_state_dict below overwrites; confirm whether pretrained=False suffices.
model_s = TimmImageEncoder(model_type_s, pretrained=True)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine (where `device` falls back to "cpu"); the model is moved to `device` next.
model_s.load_state_dict(torch.load(weight_s_path, map_location="cpu")["model"])
model_s.to(device)
model_s.eval();

In [9]:
# Square resolution (height = width) of the dummy image used for the encoder conversions.
# The encoder output path is set in the export cell itself, so the decoder
# `output` path is not clobbered here before it is needed.
input_size = 1024

In [10]:
# Random dummy image batch (1, 3, input_size, input_size) on the encoder's device,
# used for both the torch2trt conversion and ONNX tracing.
data = torch.randn((1, 3, input_size, input_size)).to(device)

In [11]:
# torch2trt builds a TensorRT engine, which requires CUDA. When the encoder fell
# back to CPU, skip the conversion instead of failing the rest of the notebook.
# NOTE(review): model_trt is not used or saved in the visible cells; persist it or drop this cell.
model_trt = torch2trt(model_s, [data]) if data.is_cuda else None

In [12]:
# Export the image encoder `model_s` to ONNX. No dynamic_axes are given, so the
# traced input shape is fixed to that of the dummy `data` batch.
# NOTE(review): the filename says "vit-b" while `model_s` is built from
# `model_type_s`; confirm the naming is intentional.
output = "241107_test_sam_vit-b_encoder.onnx"

torch.onnx.export(
    model_s,
    (data,),
    output,
    opset_version=opset,
    input_names=["image"],
    output_names=["image_embeddings"],
)

verbose: False, log level: Level.ERROR

