In [None]:


import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import coremltools as ct
import math
from repvit_sam.utils.transforms import ResizeLongestSide
import torch.nn.functional as F


def show_mask(mask, ax):
    """Draw ``mask`` on ``ax`` as a semi-transparent blue overlay.

    Args:
        mask: Array-like exposing ``.shape`` and ``.reshape``; its last two
            dimensions are taken as (height, width) and it must hold exactly
            height * width elements (e.g. shape (1, H, W)).
        ax: Object with an ``imshow`` method (a matplotlib Axes in this notebook).
    """
    # Four components (R, G, B, alpha); zero mask entries yield fully
    # transparent pixels because every component is multiplied by the mask.
    rgba = np.array([30/255, 144/255, 255/255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to obtain an (H, W, 4) image.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array indexable as ``coords[mask]`` and ``[:, 0]`` / ``[:, 1]``;
            column 0 is plotted on the x axis and column 1 on the y axis.
        labels: Array compared element-wise against 1 and 0 to select rows.
        ax: Object with a ``scatter`` method (a matplotlib Axes in this notebook).
        marker_size: Value forwarded as the ``s`` argument of ``scatter``.
    """
    # Select both groups up front, then draw positives before negatives.
    groups = [
        (coords[labels == wanted], colour)
        for wanted, colour in ((1, 'green'), (0, 'red'))
    ]
    for selected, colour in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def preprocess(x, img_size=1024):
    """Resize an image, normalize its channels and zero-pad it to a square.

    Args:
        x: Image passed to ``ResizeLongestSide.apply_image``; the result is
            permuted with (2, 0, 1), so it is expected to be channels-last
            with three channels (an RGB array from OpenCV in this notebook).
        img_size: Side length of the square output. Also handed to
            ``ResizeLongestSide`` (presumably as the target longest side —
            confirm against repvit_sam).

    Returns:
        Tuple ``(tensor, transform)``: the normalized, padded channels-first
        tensor and the ``ResizeLongestSide`` instance used, so callers can
        apply the same transform to prompt coordinates.
    """
    # Resize first; normalization and padding operate on the resized image.
    transform = ResizeLongestSide(img_size)
    resized = transform.apply_image(x)
    # Channels-last -> channels-first.
    chw = torch.as_tensor(resized).permute(2, 0, 1).contiguous()

    # Per-channel mean / std, shaped (3, 1, 1) to broadcast over H and W.
    mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    normalized = (chw - mean) / std

    # F.pad takes (left, right, top, bottom) for the last two dimensions, so
    # zeros are appended on the right and bottom only.
    height, width = normalized.shape[-2:]
    padded = F.pad(normalized, (0, img_size - width, 0, img_size - height))
    return padded, transform

def postprocess(raw_image, masks):
    """Crop padded mask logits, resize them to the raw image size and binarize.

    Args:
        raw_image: Object whose ``shape[:2]`` is (height, width) of the
            original image; only the shape is read.
        masks: Mask logits as a NumPy array or ``torch.Tensor``. Must be 4-D
            (bilinear ``F.interpolate`` requires it) with the spatial axes
            last; ``masks.shape[2]`` is used as the longest side of the
            padded model input.

    Returns:
        Boolean ``torch.Tensor`` with spatial size (height, width); True where
        the interpolated logit is greater than 0.
    """
    def resize_longest_image_size(
            input_image_size, longest_side: int
        ):
            """Scale ``input_image_size`` so its largest entry equals ``longest_side``."""
            scale = longest_side / max(input_image_size)
            # floor(v + 0.5) rounds halves up, unlike the built-in round().
            return [int(math.floor(scale * each + 0.5)) for each in input_image_size]

    h, w = raw_image.shape[:2]

    # Keep only the top-left region covered by the resized image; this assumes
    # padding was appended on the bottom/right, as ``preprocess`` does.
    prepadded_h, prepadded_w = resize_longest_image_size((h, w), masks.shape[2])
    masks = masks[..., :prepadded_h, :prepadded_w]

    # as_tensor accepts both ndarrays and Tensors; torch.tensor() would copy
    # unconditionally and emit a copy-construct UserWarning for Tensor input.
    masks = F.interpolate(torch.as_tensor(masks), size=(h, w), mode="bilinear", align_corners=False)
    return masks > 0

In [None]:
# Run the repository's Core ML export scripts for the encoder (at resolution
# 1024) and the decoder, both from the ../weights/repvit_sam.pt checkpoint.
# NOTE(review): these presumably write the .mlpackage files under coreml/ that
# the next cells load — confirm the scripts' output paths.
!python3 ../scripts/export_coreml_encoder.py --resolution 1024 --model repvit --samckpt ../weights/repvit_sam.pt
!python3 ../scripts/export_coreml_decoder.py --checkpoint ../weights/repvit_sam.pt --model-type repvit

In [None]:
# Load the Core ML encoder package; the path is relative to the notebook's
# working directory. It is used below to compute the image embedding.
encoder = ct.models.MLModel('coreml/repvit_1024.mlpackage')

In [None]:
# Load the Core ML decoder package; the path is relative to the notebook's
# working directory. It is used below to predict masks from the embedding and
# the point prompt.
decoder = ct.models.MLModel('coreml/sam_decoder.mlpackage')

In [None]:
# Read the demo image. cv2.imread signals a missing or unreadable file by
# returning None instead of raising, so fail here with a clear message rather
# than with an opaque error inside cvtColor.
image_path = '../../app/assets/picture3.jpg'
raw_image = cv2.imread(image_path)
if raw_image is None:
    raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
# OpenCV loads images as BGR; convert to RGB before preprocessing/plotting.
raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)

# `transform` is kept so the prompt coordinates can be transformed the same way.
image, transform = preprocess(raw_image)
# Add a leading batch axis and run the encoder; its first output value is
# taken as the image embedding.
image_embedding = list(encoder.predict({'x_1': image.numpy()[None, ...]}).values())[0]

In [None]:
# One prompt point given as (x, y) in raw-image pixels, with label 1
# (show_points draws label 1 in green and label 0 in red).
input_point = np.array([[553, 808]])
input_label = np.array([1])

# Decoder inputs are float32 with a leading batch axis.
coreml_label = input_label.astype(np.float32)[np.newaxis, ...]
# NOTE(review): apply_coords presumably maps raw-image coordinates into the
# resized image frame produced by `preprocess` — confirm against repvit_sam.
coreml_coord = transform.apply_coords(
    input_point.astype(np.float32)[np.newaxis, ...],
    raw_image.shape[:2],
).astype(np.float32)

# No prior mask is supplied: an all-zero mask and a zero "has mask" flag.
coreml_mask_input = np.zeros(shape=(1, 1, 256, 256), dtype=np.float32)
coreml_has_mask_input = np.zeros(shape=1, dtype=np.float32)

In [None]:
# Gather the decoder inputs under the input names passed to decoder.predict.
# The variable is called `ort_inputs` but it is fed to the Core ML decoder.
ort_inputs = dict(
    image_embeddings=image_embedding,
    point_coords=coreml_coord,
    point_labels=coreml_label,
    mask_input=coreml_mask_input,
    has_mask_input=coreml_has_mask_input,
)

In [None]:
# The three decoder outputs are unpacked by position.
# NOTE(review): this relies on the order of the returned dict's values —
# confirm it matches (low_res_logits, score, masks) for the exported model.
low_res_logits, score, masks = decoder.predict(ort_inputs).values()

plt.figure(figsize=(10, 10))
plt.imshow(raw_image)
# Overlay the binarized full-resolution mask and the prompt point on the
# axes that hold the image.
current_axes = plt.gca()
show_mask(postprocess(raw_image, masks), current_axes)
show_points(input_point, input_label, current_axes)
plt.axis('off')
plt.show()