In [None]:
import os
import sys
sys.path.append('..')
import cv2
import numpy as np
import matplotlib.pyplot as plt

from testing.test_utils.model import create_model
from testing.test_utils.utils_vis import show_mask, show_neg_points, show_pos_points
from testing.test_utils.utils import p2sam_medical


class Args:
    """Minimal attribute container: every keyword argument becomes an instance attribute.

    Presumably mimics an ``argparse.Namespace`` so helpers expecting an ``args``
    object can be called from the notebook — TODO confirm against callers.
    """

    def __init__(self, **kwargs):
        # vars(self) is the instance __dict__; store all options verbatim.
        vars(self).update(kwargs)

In [None]:
# Build the SAM model ('vit_l' variant) from a local checkpoint and move it to the GPU.
# Requires a CUDA device because of the hard-coded .to('cuda').
# NOTE(review): lora=False here, so r / enable_lora are presumably ignored by
# create_model — TODO confirm against testing.test_utils.model.
sam_type = 'vit_l'
sam_ckpt = 'pretrained_weights/nsclc_full_large/checkpoint.pth'  # relative to the notebook's working dir
sam = create_model(
    sam_type,
    sam_ckpt,
    encoder_type='timm',
    lora=False,
    r=1,
    enable_lora=[True, True, True],
).to('cuda')

In [None]:
# Load the query (test) image and, if present, its mask, then display them overlaid.
TEST_IMAGE_PATH = 'images/test_image.png'
TEST_MASK_PATH = 'images/test_mask.png'

test_image = cv2.imread(TEST_IMAGE_PATH)
# cv2.imread signals failure by returning None instead of raising.
if test_image is None:
    raise FileNotFoundError(f"Could not read test image: {TEST_IMAGE_PATH}")
# OpenCV loads channels as BGR; matplotlib expects RGB.
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)

if os.path.exists(TEST_MASK_PATH):
    mask_gray = cv2.imread(TEST_MASK_PATH, cv2.IMREAD_GRAYSCALE)
    if mask_gray is None:
        raise FileNotFoundError(f"Could not read test mask: {TEST_MASK_PATH}")
    # Cast before any arithmetic: max() + min() on uint8 scalars wraps at 256,
    # which would give a wrong threshold for masks whose values sum past 255.
    mask_gray = mask_gray.astype(np.float64)
    # Binarize at the midpoint between the darkest and brightest pixel values.
    threshold = (mask_gray.max() + mask_gray.min()) / 2.0
    # np.float was removed in NumPy 1.24; float64 is the dtype it aliased.
    test_mask = (mask_gray > threshold).astype(np.float64)
else:
    test_mask = None

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(test_image)
if test_mask is not None:
    show_mask(test_mask, ax, -1, linewidth=1)
ax.axis('off')
plt.show()

In [None]:
# Run P2SAM on the test image, using the reference image/mask pair as the prompt source.
# NOTE(review): the Args fields are consumed inside p2sam_medical; the pos/neg names
# suggest bounds on the number of positive/negative point prompts — TODO confirm.
sam.eval()  # put the model in eval mode before inference
args = Args(
    min_num_pos=3,
    max_num_pos=3,
    min_num_neg=50,
    max_num_neg=50,
    reg_patch_weight=False,
    medsam=False,
)
pred_mask, point_coords, point_labels, _, _ = p2sam_medical(
    args,
    sam,
    ref_image_path='images/ref_image.png',
    ref_mask_path='images/ref_mask.png',
    test_image_path='images/test_image.png',
    test_mask_path=None,  # no test mask is supplied to the predictor here
    output_path='images/',
    slice_name='p2sam_result',
)

In [None]:
# Overlay the predicted mask and the returned point prompts on the test image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(test_image)
show_mask(pred_mask, ax, -1, linewidth=1)
show_pos_points(point_coords, point_labels, ax, -1, None, marker_size=100, linewidth=1.5)
show_neg_points(point_coords, point_labels, ax, 'red', marker_size=100, linewidth=1.5)
ax.axis('off')
plt.show()