Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
bnorthan committed Apr 28, 2024
2 parents 98a10f0 + 9188e8f commit 859c6e4
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 11 deletions.
133 changes: 128 additions & 5 deletions src/napari_segment_everything/_tests/test_mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,38 @@
from napari_segment_everything.sam_helper import (
get_mobileSAMv2,
get_bounding_boxes,
add_properties_to_label_image,
SAM_WEIGHTS_URL,
get_weights_path,
get_device,
get_sam_automatic_mask_generator,
)
from napari_segment_everything.minimal_detection.prompt_generator import (
RcnnDetector,
YoloDetector,
)
import os
import requests
from gdown.parse_url import parse_url


def test_urls():
"""
Tests whether all the urls for the model weights exist.
"""
for url in SAM_WEIGHTS_URL.values():
if url.startswith("https://drive.google.com/"):
_, path_exists = parse_url(url)
assert path_exists
else:
req = requests.head(url)
assert req.status_code == 200


def test_mobile_sam():
"""
Tests the mobileSAMv2 process pipeline
"""
# load a color examp
image = data.coffee()

Expand All @@ -24,15 +52,110 @@ def test_mobile_sam():


def test_bbox():
# load a color examp
"""
Test whether bboxes can be generated
"""
image = data.coffee()
bounding_boxes = get_bounding_boxes(
image, detector_model="Finetuned", device="cuda", conf=0.01, iou=0.99
)
print(f"Length of bounding boxes: {len(bounding_boxes)}")
segmentations = get_mobileSAMv2(image, bounding_boxes)
return segmentations
assert len(bounding_boxes) > 0


def test_RCNN():
"""
Test RCNN object detection on CPU and CUDA devices.
"""
image = data.coffee()
model_path = str(get_weights_path("ObjectAwareModel_Cell_FT"))
assert os.path.exists(model_path)
rcnn_cpu = RcnnDetector(model_path, device="cpu")
rcnn_cuda = RcnnDetector(model_path, device="cuda")
bbox_cpu = rcnn_cpu.get_bounding_boxes(image, conf=0.5, iou=0.2)
bbox_cuda = rcnn_cuda.get_bounding_boxes(image, conf=0.5, iou=0.2)
assert len(bbox_cpu) == 6
assert len(bbox_cuda) == 6


def test_YOLO():
"""
Test YOLO object detection on CPU and CUDA devices.
"""
image = data.coffee()
model_path = str(get_weights_path("ObjectAwareModel"))
assert os.path.exists(model_path)
yolo_cpu = YoloDetector(model_path, device="cpu")
yolo_cuda = YoloDetector(model_path, device="cuda")
bbox_cpu = yolo_cpu.get_bounding_boxes(
image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
)
bbox_cuda = yolo_cuda.get_bounding_boxes(
image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
)
assert len(bbox_cpu) == 8
assert len(bbox_cuda) == 8


def test_weights_path():
"""
Tests whether the weights directory existing on the operating system
"""
weights_path = get_weights_path("default")
assert os.path.exists(os.path.dirname(weights_path))


def test_labels():
"""
Tests whether region properties can be generated for segmentations for different models
"""
image = data.coffee()
device = get_device()

bbox_yolo = get_bounding_boxes(
image,
detector_model="YOLOv8",
imgsz=1024,
device=device,
conf=0.4,
iou=0.9,
)
bbox_rcnn = get_bounding_boxes(
image,
detector_model="Finetuned",
imgsz=1024,
device=device,
conf=0.4,
iou=0.9,
)

segmentations_rcnn = get_mobileSAMv2(image, bbox_rcnn)
segmentations_yolo = get_mobileSAMv2(image, bbox_yolo)
segmentations_vit_b = get_sam_automatic_mask_generator(
"vit_b",
points_per_side=4,
pred_iou_thresh=0.2,
stability_score_thresh=0.5,
box_nms_thresh=0.1,
crop_n_layers=0,
).generate(image)

add_properties_to_label_image(image, segmentations_rcnn)
add_properties_to_label_image(image, segmentations_yolo)
add_properties_to_label_image(image, segmentations_vit_b)

props_rcnn = segmentations_rcnn[0].keys()
assert len(props_rcnn) == 10
props_yolo = segmentations_yolo[0].keys()
assert len(props_yolo) == 10
props_vit_b = segmentations_vit_b[0].keys()
assert len(props_vit_b) == 13


# seg = test_bbox()
seg = test_mobile_sam()
test_urls()
test_bbox()
test_mobile_sam()
test_RCNN()
test_YOLO()
test_weights_path()
test_labels()
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ class RcnnDetector(BaseDetector):
def __init__(self, model_path, device, trainable=True):
super().__init__(model_path, trainable)
self.model_type = "FasterRCNN"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = fasterrcnn_mobilenet_v3_large_fpn().to(device)
self.device = device
self.model = fasterrcnn_mobilenet_v3_large_fpn(
box_detections_per_img=500,
).to(device)
self.model.load_state_dict(torch.load(model_path))

def train(self, training_data):
Expand Down
8 changes: 4 additions & 4 deletions src/napari_segment_everything/sam_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,17 @@ def get_bounding_boxes(
):
if detector_model == "YOLOv8":
model = YoloDetector(
str(get_weights_path("ObjectAwareModel")), device="cuda"
str(get_weights_path("ObjectAwareModel")), device=device
)
bounding_boxes = model.get_bounding_boxes(
image, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det
)
elif detector_model == "Finetuned":
model = RcnnDetector(
str(get_weights_path("ObjectAwareModel_Cell_FT")),
device="cuda",
str(get_weights_path("ObjectAwareModel_Cell_FT")), device=device
)
bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou)
print(bounding_boxes)
# print(bounding_boxes)
return bounding_boxes


Expand Down Expand Up @@ -287,6 +286,7 @@ def filter_labels_3d_multi(
def add_properties_to_label_image(orig_image, sorted_results):

hsv_image = color.rgb2hsv(orig_image)
# switch to this? https://forum.image.sc/t/looking-for-a-faster-version-of-rgb2hsv/95214/12

hue = 255 * hsv_image[:, :, 0]
saturation = 255 * hsv_image[:, :, 1]
Expand Down

0 comments on commit 859c6e4

Please sign in to comment.