In [1]:
import numpy as np
import torch
import cv2

import sys
sys.path.append("../..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

def gt_to_anns(mask_gt):
    labels = np.unique(mask_gt)
    anns = []
    for label in labels:
        # skip background
        if label == 0:
            continue
        mask = np.all(mask_gt == label, axis=-1)
        # 1 ramdon point from mask
        num_point = 1
        indices = np.argwhere(mask)
        # swap x y
        indices[:,[1,0]] = indices[:,[0,1]]
        # sample on random point
        point = np.asarray(indices[np.random.randint(indices.shape[0], size=num_point)])
        # all point for test
        points = np.asarray(indices)
        anns.append({
            'area': np.sum(mask),
            'segmentation': mask,
            'label': label,
            'random_point_1':point,
            'points':points,
        })
    return anns

sam_checkpoint = "../../sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda:0"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
predictor = SamPredictor(sam)

import os
from tqdm import tqdm
root = "../../datasets/people_poses/"
data_name = '100034_483681'
image = cv2.imread(root + 'val_images/' + data_name + '.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask_gt = cv2.imread(root + 'val_segmentations/' + data_name + '.png')
anns = gt_to_anns(mask_gt)
predictor.set_image(image)

target = anns[2]
assert target['label'] == 6

masks, _, _ = predictor.predict(
    point_coords=np.array([[107, 252]], dtype=np.int64),
    point_labels=np.array([1]),
    multimask_output=False,
)
print(masks.shape)

  from .autonotebook import tqdm as notebook_tqdm


(1, 423, 187)


In [3]:
predictor.device, predictor.features.shape, predictor.input_h, predictor.input_w, predictor.input_size, predictor.is_image_set, predictor.orig_h, predictor.orig_w, predictor.original_size, predictor.transform

(device(type='cuda', index=0),
 torch.Size([1, 256, 64, 64]),
 None,
 None,
 (1024, 453),
 True,
 None,
 None,
 (423, 187),
 <segment_anything.utils.transforms.ResizeLongestSide at 0x1628c38fee0>)

In [10]:
embedding = predictor.features.detach().cpu().numpy()

In [3]:
target['segmentation'].shape

(423, 187)

In [9]:
(masks[0] == target['segmentation']).mean()

0.991390753593507

In [2]:
data = np.load('../../datasets/people_poses/val_embeds/' + data_name + '.npz')
embed = data['embed']

In [4]:
def set_embedding(predictor, embed, label):
    predictor.is_image_set = True
    predictor.features = torch.as_tensor(embed[None,]).to(device)
    # image_batch, mask_batch, height, width
    assert len(predictor.features.shape) == 4
    predictor.original_size = label.shape
    # TODO(jiahang): fix magic numbers
    predictor.input_size = (1024, 1024)
set_embedding(predictor, embed, target['segmentation'])
masks2, _, _ = predictor.predict(
    point_coords=np.array([[107, 252]], dtype=np.int64),
    point_labels=np.array([1]),
    multimask_output=False,
)
print(masks2.shape)

(1, 423, 187)


In [10]:
(masks2[0] == target['segmentation']).mean()

0.9115813959368403

In [6]:
predictor.transform.get_preprocess_shape(423, 187, 1024)

(1024, 453)