In [None]:
import sys
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

import torch
from finetune import DRIVE_Dataset
from torch.utils.data import DataLoader

from ryu_pytools import arr_info

In [None]:
# --- Configuration ---
# Checkpoint path handed to the model builder, and the key used to select that
# builder from `sam_model_registry`.
sam_checkpoint = "/share/home/liuy/project/SAM_finetune/checkpoints/finetune/best.pth"
model_type = "vit_b"
# Prefer the GPU, but fall back to the CPU so the later `.to(device)` calls do
# not raise on a machine where CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Select the builder registered under `model_type` and call it with the
# checkpoint path (presumably the builder loads those weights -- confirm
# against this project's build code).
build_sam = sam_model_registry[model_type]
sam_model = build_sam(sam_checkpoint)
# Move the model onto the configured device. This is the cell's last
# expression, so the notebook also displays whatever `.to()` returns.
sam_model.to(device=device)

In [None]:

# Image folder and the `1st_manual` folder passed as labels; both paths sit
# under the DRIVE `training` directory.
input_folder = '/share/home/liuy/project/SAM_finetune/data/DRIVE/training/images'
label_folder = '/share/home/liuy/project/SAM_finetune/data/DRIVE/training/1st_manual'

ds = DRIVE_Dataset(
    input_folder=input_folder,
    label_folder=label_folder,
)
# One sample per batch. Shuffling is on and no seed is set in this cell, so the
# sample order is not fixed between runs.
dl = DataLoader(ds, batch_size=1, shuffle=True)

In [None]:
# Pull a single batch from the loader. The batch unpacks into three items: the
# image tensor, an element that is ignored here, and a (points, labels) pair
# that is later passed to the prompt encoder.
dl_iter = iter(dl)
batch = next(dl_iter)
inputs, _, (points, labels) = batch

In [None]:
# Build the resize helper from the image encoder's own `img_size`, so the
# transform targets whatever resolution the loaded model is configured with.
# (`ResizeLongestSide` is imported in the notebook's import cell.)
resize_transform = ResizeLongestSide(sam_model.image_encoder.img_size)

In [None]:
# NOTE: `inputs` and `points` are overwritten below, so this cell is meant to
# run once per fetched batch; re-running it would push already-processed
# tensors through the same steps again.

# Put the image and its point prompts on the model's device.
inputs, points, labels = (t.to(device) for t in (inputs, points, labels))

# Capture the last two dims of `inputs` before resizing (assumed to be the
# image height/width -- TODO confirm the dataset's tensor layout); the
# coordinate transform takes this size as its second argument.
original_size = inputs.shape[-2:]
inputs = resize_transform.apply_image_torch(inputs)
points = resize_transform.apply_coords_torch(points, original_size)

# Run the model's own `preprocess` on each item along dim 0 and re-stack the
# results into a batch.
inputs = torch.stack(tuple(map(sam_model.preprocess, inputs)), dim=0)

In [None]:
# Inference only: run both encoders inside a single no-grad context.
with torch.no_grad():
    image_embedding = sam_model.image_encoder(inputs)
    # Only point prompts are supplied; box and mask prompts are explicitly off.
    prompt_embeddings = sam_model.prompt_encoder(
        points=(points, labels), boxes=None, masks=None
    )
sparse_embeddings, dense_embeddings = prompt_embeddings

In [None]:
# Inspect the three encoder outputs with `arr_info` (helper imported from
# `ryu_pytools`). The second argument is presumably a display label for the
# tensor -- confirm against the helper's implementation.
arr_info(image_embedding, 'image_embedding')
arr_info(sparse_embeddings, 'sparse_embeddings')
arr_info(dense_embeddings, 'dense_embeddings')

In [None]:
# Decode from the embeddings computed earlier. The call is wrapped in
# `torch.no_grad()` to match the encoder cell: this notebook only inspects the
# outputs, so there is no reason to build an autograd graph for the decoder.
# The decoder's result is unpacked into three outputs here: masks, IoU
# predictions and point predictions.
with torch.no_grad():
    low_res_masks, iou_predictions, point_predictions = sam_model.mask_decoder(
        image_embeddings=image_embedding,
        # Positional encoding supplied by the prompt encoder.
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

In [None]:
# Inspect the three decoder outputs with `arr_info` (helper imported from
# `ryu_pytools`). The second argument is presumably a display label for the
# tensor -- confirm against the helper's implementation.
arr_info(point_predictions, 'point_predictions')
arr_info(low_res_masks, 'low_res_masks')
arr_info(iou_predictions, 'iou_predictions')