In [None]:
# import python libraries
import importlib
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from PIL import Image
from transformers import SamModel, SamConfig, SamProcessor

# import custom scripts 
import preprocess_images
import get_bounding_box

# reload and declare functions 
importlib.reload(preprocess_images)
importlib.reload(get_bounding_box)

from preprocess_images import preprocess_grayscale, preprocess_rgb, preprocess_rgbd
from get_bounding_box import get_bounding_box, get_bounding_box_circumscribed

# double check to make sure the right weights are called
# (architecture config and input processor come from the base SAM checkpoint)
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# path to the fine-tuned weights that are loaded into the SAM architecture below
folder_path_model = './models/SAM_rgb.pth'

# set the device to cuda if available, otherwise use cpu
# (decided before loading so the checkpoint can be mapped onto it)
device = "cuda" if torch.cuda.is_available() else "cpu"

my_mito_model = SamModel(config=model_config)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
my_mito_model.load_state_dict(torch.load(folder_path_model, map_location=device))

# assign the result so the cell does not print the full model repr as its output
my_mito_model = my_mito_model.to(device)



In [None]:
# Unpack the six values returned by preprocess_rgb() into image/mask pairs
# for the train, validation and test splits.
(
    train_images, train_masks,
    val_images, val_masks,
    test_images, test_masks,
) = preprocess_rgb()

In [None]:
# load image
# index of the test sample to visualise; defined here so the cell runs on a fresh kernel
# (fixed for reproducibility -- use random.randrange(len(test_images)) to browse samples)
idx = 0
test_image = test_images[idx]
ground_truth_mask = test_masks[idx]

# use only if using input_points
#pointer_prompt = [100,150]

# the code below has been adapted from the creators of the medsam model and has been adapted for our use case
# main difference: probability map

# box prompt derived from the ground-truth mask
prompt = get_bounding_box(ground_truth_mask)

# the processor expects one list of boxes per image, hence the double nesting
inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

my_mito_model.eval()

# inference only: no gradients needed
with torch.no_grad():
    outputs = my_mito_model(**inputs, multimask_output=False)

# logits -> probability map -> binary mask (threshold 0.5)
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# left panel: input image with the box prompt drawn on top
axes[0].imshow(np.array(test_image), cmap='gray')
# assumes get_bounding_box returns [x_min, y_min, x_max, y_max] (the format SAM's
# input_boxes expects) -- TODO confirm; Rectangle takes (x, y), width, height,
# so the far corner has to be converted into a size
x_min, y_min, x_max, y_max = prompt
rect = Rectangle(
    (x_min, y_min),
    x_max - x_min,
    y_max - y_min,
    linewidth=2,
    edgecolor='r',
    facecolor='none',
)
axes[0].add_patch(rect)
axes[0].set_title("Image")

# middle panel: predicted binary mask
axes[1].imshow(medsam_seg, cmap='gray')
axes[1].set_title("Mask")

# right panel: ground-truth mask for comparison
axes[2].imshow(ground_truth_mask, cmap="gray")
axes[2].set_title("Original Mask")

for ax in axes:
    ax.axis("off")

plt.show()