In [1]:
import functools

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch

from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from segment_anything.modeling import Sam
from segment_anything.utils.onnx import SamOnnxModel

In [2]:
# Configuration: use the GPU when present, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA (a hard-coded 'cuda' fails later at .to(device)).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1. Load one slice of the CT volume as the 8-bit, 3-channel image passed to SAM.
NIFTI_PATH = r"E:\Mine\SAM_editor\static\nifti_files\img_CT_new.nii.gz"  # TODO: make relative/configurable
SLICE_INDEX = 481  # slice index along the volume's last axis

volume = nib.load(NIFTI_PATH).get_fdata()
# Distinct names so the builtins min()/max() are not shadowed for the rest of the notebook.
vol_min = volume.min()
vol_max = volume.max()
value_range = vol_max - vol_min
if value_range == 0:
    # Constant volume: avoid 0/0 (NaN -> undefined uint8 cast); the slice then maps to all zeros.
    value_range = 1.0
# Scaled with the whole-volume range, so a single slice may occupy only part of 0-255
# (13..103 for this scan); clip to a CT window first if more contrast is needed.
slice_scaled = (volume[..., SLICE_INDEX, None] - vol_min) * 255 / value_range
# Repeat the grey channel into 3 channels and swap the first two axes (rows <-> columns).
img = np.transpose(np.repeat(slice_scaled, 3, axis=-1).astype(np.uint8), (1, 0, 2))

In [3]:
# Sanity-check the prepared slice: intensity range, dtype and (H, W, C) shape.
np.min(img), np.max(img), img.dtype, np.shape(img)

(13, 103, dtype('uint8'), (512, 512, 3))

In [4]:
# Shape of the image handed to SAM.
np.shape(img)

(512, 512, 3)

In [None]:
# 2. Load the Segment Anything model: ViT-B architecture with MedSAM weights,
#    moved to the device chosen in the configuration cell.
MEDSAM_CHECKPOINT = r"E:\Mine\SAM_editor\SAM_testing_only\SAM_models\MedSAM\medsam_vit_b.pth"
sam = sam_model_registry["vit_b"](checkpoint=MEDSAM_CHECKPOINT).to(device)
# Inference mode (disables training-only behaviour); the module repr is the cell output.
sam.eval()

In [None]:
# Confirm which device the model ended up on after .to(device).
sam.device

In [None]:
# 3. Wrap the model twice:
#    - SamPredictor answers prompted queries (points / boxes) against one image;
#    - SamAutomaticMaskGenerator segments the whole image without prompts.
predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)

# 4. Encode the slice once; every later predictor.predict(...) call reuses this embedding.
predictor.set_image(img, image_format="RGB")

## Prompting with points

In [None]:
# 5. Prompt: two clicks given as (x, y) pixel coordinates, both labelled 1.
input_point = np.array([
    [150, 250],
    [350, 250],
])
input_label = np.array([1, 1])

# 6. Decode. The returned tuple is kept whole in `masks`; multimask_output=True
#    asks the decoder for several candidate masks.
masks = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

In [None]:
# (number of points, 2) - one (x, y) pair per click.
np.shape(input_point)

In [None]:
# Show the signature and docstring of SamPredictor.predict (prompt formats, return values).
# NOTE(review): exploratory helper - consider removing it from the finished notebook.
help(predictor.predict)

In [None]:
# predict() returned a 3-tuple: candidate masks, their scores, and low-resolution logits.
(np.shape(masks[0]), masks[1], np.shape(masks[2]))

In [None]:
# Overlay the first candidate mask from the point prompt and mark every prompt point.
fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')  # img is 3-channel, so matplotlib ignores the colormap
ax.imshow(masks[0][0], cmap='jet', alpha=0.5)  # 'jet' colormap and 50% transparency
# One scatter over all points instead of one hard-coded line per point, so the
# figure stays correct when the prompt gains or loses points.
ax.scatter(input_point[:, 0], input_point[:, 1], color='red', s=100, marker='x')
ax.set_title(f"Point prompt - candidate 0 (score {masks[1][0]:.3f})")
plt.show()

## Prompting with a bounding box

In [None]:
# Box prompt as (x0, y0, x1, y1) pixel coordinates.
bbox = np.array([80, 150, 420, 390])

# Decode several candidate masks for the box; the whole return tuple is kept.
bbox_masks = predictor.predict(
    box=bbox,
    multimask_output=True,
)

In [None]:
# Same 3-tuple layout as the point prompt: masks, scores, low-resolution logits.
(np.shape(bbox_masks[0]), bbox_masks[1], np.shape(bbox_masks[2]))

In [None]:
# Overlay candidate mask 2 from the box prompt and outline the prompt box in green.
fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.imshow(bbox_masks[0][2], cmap='jet', alpha=0.5)  # 50% transparent overlay

# bbox is (x0, y0, x1, y1); Rectangle wants the top-left corner plus width/height.
x0, y0, x1, y1 = bbox
w, h = x1 - x0, y1 - y0
box_outline = patches.Rectangle(
    (x0, y0), w, h,
    edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
)
ax.add_patch(box_outline)
plt.show()

In [None]:
# Run the prompt-free automatic mask generator over the whole slice. Each returned
# record is a dict with a 'segmentation' array plus metadata (bbox, predicted_iou, ...).
output_mask = mask_generator.generate(img)
# Display only how many masks were found: echoing the list printed every segmentation
# array and buried the notebook output.
len(output_mask)

In [None]:
# Inspect one automatically generated record: its 'segmentation' array and metadata
# (bbox, predicted_iou, point_coords, stability_score, ...).
# Raises IndexError if the generator found no masks.
output_mask[0]

In [None]:
# Show every automatically generated mask, from lowest to highest predicted IoU,
# together with the point that produced it and its bounding box.
sorted_output_mask = sorted(output_mask, key=lambda record: record['predicted_iou'])
for record in sorted_output_mask:
    fig, ax = plt.subplots()
    ax.imshow(img)  # img is 3-channel, so no colormap applies
    ax.imshow(record['segmentation'], cmap='jet', alpha=0.5)  # 50% transparent overlay
    # Mark the first entry of point_coords.
    point_x, point_y = record['point_coords'][0]
    ax.scatter(int(point_x), int(point_y), color='red', s=100, marker='x')
    # bbox is unpacked as (x, y, width, height), matching Rectangle's arguments.
    x, y, width, height = record['bbox']
    ax.add_patch(patches.Rectangle((x, y), width, height, linewidth=2, edgecolor='r', facecolor='none'))
    plt.show()
    # Release the figure: this loop can create dozens of them.
    plt.close(fig)
    print("predicted_iou : ", record['predicted_iou'])
    print("stability_score : ", record['stability_score'])

In [None]:
# Maximum of the first segmentation after casting to uint8 (expected 1 for a binary mask).
np.max(output_mask[0]['segmentation'].astype(np.uint8))

In [None]:
# Merge all automatic masks into one integer label map: 0 = background, k = k-th mask.
# Masks are painted in ascending predicted-IoU order, so where masks overlap the most
# confident one ends up on top.
sorted_output_mask = sorted(output_mask, key=lambda record: record['predicted_iou'])
# Fall back to the image's (H, W) when the generator returned nothing, instead of
# crashing on output_mask[0].
mask_shape = output_mask[0]['segmentation'].shape if output_mask else img.shape[:2]
# int32 rather than uint8: the old accumulator wrapped once labels (or sums of
# overlapping labels) passed 255.
combinedmask = np.zeros(mask_shape, dtype=np.int32)
for label, record in enumerate(sorted_output_mask, start=1):
    # Assign instead of add, so overlapping pixels keep exactly one valid label.
    combinedmask[record['segmentation'].astype(bool)] = label
print("np.unique(combinedmask)", np.unique(combinedmask), np.unique(combinedmask).shape)

fig, ax = plt.subplots()
# Opaque label map; nearest-neighbour keeps label boundaries from being blended.
ax.imshow(combinedmask, cmap='jet', interpolation='nearest')
plt.show()

In [None]:
# First segmentation converted to a uint8 array of 0s and 1s.
np.asarray(output_mask[0]['segmentation'], dtype=np.uint8)

In [None]:
# Scratch check: element-wise bitwise AND of 1 and 1000 (0b1111101000) is 0.
np.bitwise_and(np.array([1]), np.array([1000]))