Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
bnorthan committed Apr 30, 2024
2 parents 9104024 + f594abd commit 58495d6
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10']

steps:
- uses: actions/checkout@v3
Expand Down
7 changes: 5 additions & 2 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
Copyright (c) 2024, Brian Northan
All rights reserved.

This work re-uses some code from napari sam (see license here https://github.com/MIC-DKFZ/napari-sam/blob/main/LICENSE)
This work re-uses some code from the following sources:
- napari sam (License: https://github.com/MIC-DKFZ/napari-sam/blob/main/LICENSE)
- napari-segment-anything (License: https://github.com/royerlab/napari-segment-anything/blob/main/LICENSE)
- MobileSAMv2 (License: https://github.com/ChaoningZhang/MobileSAM/blob/master/LICENSE)

And also re-uses some code from napari-segment-anything (see license here https://github.com/royerlab/napari-segment-anything/blob/main/LICENSE)
Additionally, training data is from Cellpose's annotated dataset, which is licensed CC-by-NC (see: https://github.com/MouseLand/cellpose?tab=readme-ov-file)

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Expand Down
56 changes: 29 additions & 27 deletions src/napari_segment_everything/_tests/test_mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,30 @@
import requests
from gdown.parse_url import parse_url

device = get_device()
if device == "mps":
device = "cpu"

#%%
def test_urls():
"""
Tests whether all the urls for the model weights exist.
Tests whether all the urls for the model weights exist and are accessible.
"""
for url in SAM_WEIGHTS_URL.values():
if url.startswith("https://drive.google.com/"):
_, path_exists = parse_url(url)
assert path_exists
else:
req = requests.head(url)
assert req.status_code == 200
TIMEOUT = 1
for name, url in SAM_WEIGHTS_URL.items():
try:
if url.startswith("https://drive.google.com/"):
_, path_exists = parse_url(url)
assert (
path_exists
), f"Google Drive URL path wasn't parsed correctly: {url}"
else:
req = requests.head(url, timeout=TIMEOUT)
assert (
req.status_code == 200
), f"Failed to access URL: {url}, Status code: {req.status_code}"
except requests.exceptions.Timeout:
print(f"Request timed out for URL: {url}")


def test_mobile_sam():
Expand All @@ -42,7 +54,7 @@ def test_mobile_sam():
image,
detector_model="YOLOv8",
imgsz=1024,
device="cuda",
device=device,
conf=0.4,
iou=0.9,
)
Expand All @@ -57,7 +69,7 @@ def test_bbox():
"""
image = data.coffee()
bounding_boxes = get_bounding_boxes(
image, detector_model="Finetuned", device="cuda", conf=0.01, iou=0.99
image, detector_model="YOLOv8", device=device, conf=0.9, iou=0.90
)
print(f"Length of bounding boxes: {len(bounding_boxes)}")
assert len(bounding_boxes) > 0
Expand All @@ -70,12 +82,9 @@ def test_RCNN():
image = data.coffee()
model_path = str(get_weights_path("ObjectAwareModel_Cell_FT"))
assert os.path.exists(model_path)
rcnn_cpu = RcnnDetector(model_path, device="cpu")
rcnn_cuda = RcnnDetector(model_path, device="cuda")
bbox_cpu = rcnn_cpu.get_bounding_boxes(image, conf=0.5, iou=0.2)
bbox_cuda = rcnn_cuda.get_bounding_boxes(image, conf=0.5, iou=0.2)
assert len(bbox_cpu) == 6
assert len(bbox_cuda) == 6
rcnn = RcnnDetector(model_path, device=device)
bbox = rcnn.get_bounding_boxes(image, conf=0.5, iou=0.2)
assert len(bbox) == 6


def test_YOLO():
Expand All @@ -85,16 +94,11 @@ def test_YOLO():
image = data.coffee()
model_path = str(get_weights_path("ObjectAwareModel"))
assert os.path.exists(model_path)
yolo_cpu = YoloDetector(model_path, device="cpu")
yolo_cuda = YoloDetector(model_path, device="cuda")
bbox_cpu = yolo_cpu.get_bounding_boxes(
image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
)
bbox_cuda = yolo_cuda.get_bounding_boxes(
yolo = YoloDetector(model_path, device=device)
bbox = yolo.get_bounding_boxes(
image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
)
assert len(bbox_cpu) == 8
assert len(bbox_cuda) == 8
assert len(bbox) == 8


def test_weights_path():
Expand All @@ -110,7 +114,6 @@ def test_labels():
Tests whether region properties can be generated for segmentations for different models
"""
image = data.coffee()
device = get_device()

bbox_yolo = get_bounding_boxes(
image,
Expand Down Expand Up @@ -150,8 +153,7 @@ def test_labels():
assert len(props_yolo) == 10
props_vit_b = segmentations_vit_b[0].keys()
assert len(props_vit_b) == 13



test_urls()
test_bbox()
test_mobile_sam()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):
def segment_from_bbox(bounding_boxes, predictor, mobilesamv2, device):
"""
Segments everything given the bounding boxes of the objects and the mobileSAMv2 prediction model.
Code from mobileSAMv2
"""
input_boxes = predictor.transform.apply_boxes(
bounding_boxes, predictor.original_size
) # Does this need to be transformed?
input_boxes = torch.from_numpy(input_boxes).cuda()
if device == "cuda":
input_boxes = torch.from_numpy(input_boxes).cuda()
elif device == "cpu":
input_boxes = torch.from_numpy(input_boxes)
sam_mask = []

predicted_ious = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,13 @@ class RcnnDetector(BaseDetector):
def __init__(self, model_path, device, trainable=True):
super().__init__(model_path, trainable)
self.model_type = "FasterRCNN"
if device == "mps":
device = "cpu"
self.device = device
self.model = fasterrcnn_mobilenet_v3_large_fpn(
box_detections_per_img=500,
).to(device)
self.model.load_state_dict(torch.load(model_path))
).to(self.device)
self.model.load_state_dict(torch.load(model_path, map_location=self.device))

def train(self, training_data):
if self.trainable:
Expand Down
16 changes: 11 additions & 5 deletions src/napari_segment_everything/sam_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,11 @@ def get_sam_automatic_mask_generator(
crop_n_layers=1,
):

device = get_device()
if device == "mps":
device = "cpu"
sam = sam_model_registry[model_type](get_weights_path(model_type))
sam.to(get_device())
sam.to()
sam_anything_predictor = SamAutomaticMaskGenerator(
sam,
points_per_side=int(points_per_side),
Expand Down Expand Up @@ -178,7 +181,7 @@ def get_bounding_boxes(
return bounding_boxes


def get_mobileSAMv2(image=None, bounding_boxes=None):
def get_mobileSAMv2(image=None, bounding_boxes=None, device=get_device()):
"""
Uses a SAM model to make predictions from bounding boxes.
Expand All @@ -201,19 +204,22 @@ def get_mobileSAMv2(image=None, bounding_boxes=None):
return
if image.ndim < 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
weights_path_VIT = get_weights_path("efficientvit_l2")
samV2 = create_MS_model()

samV2.image_encoder = sam_model_registry["efficientvit_l2"](
weights_path_VIT
)
if device == "mps":
device="cpu"
samV2.to(device=device)
samV2.eval()
predictor = SamPredictorV2(samV2)
predictor.set_image(image)
sam_masks = segment_from_bbox(bounding_boxes, predictor, samV2)
sam_masks = segment_from_bbox(
bounding_boxes, predictor, samV2, device=device
)
del bounding_boxes

gc.collect()
Expand Down Expand Up @@ -304,7 +310,7 @@ def add_properties_to_label_image(orig_image, sorted_results):
# for small pixelated objects, circularity can be > 1 so we cap it
if result["circularity"] > 1:
result["circularity"] = 1

result["solidity"] = regions[0].solidity
intensity_pixels = intensity[coords]
result["mean_intensity"] = np.mean(intensity_pixels)
Expand Down
27 changes: 12 additions & 15 deletions src/napari_segment_everything/segment_everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from napari_segment_everything.sam_helper import (
get_bounding_boxes,
get_mobileSAMv2,
get_device,
)
import pickle

Expand All @@ -35,7 +36,7 @@
QMessageBox,
QTextBrowser,
QProgressBar,
QApplication
QApplication,
)


Expand Down Expand Up @@ -161,7 +162,6 @@ def on_index_changed(index):

self.stacked_algorithm_params_layout = QStackedWidget()


self.bbox_conf_spinner = LabeledSpinner(
"Bounding Box Confidence", 0, 1, 0.1, None, is_double=True
)
Expand All @@ -186,7 +186,6 @@ def on_index_changed(index):
self.yolo_params_layout.addWidget(self.bbbox_max_det_spinner)
self.widgetGroup1.setLayout(self.yolo_params_layout)


self.points_per_side_spinner = LabeledSpinner(
"Points per side", 4, 100, 32, None
)
Expand Down Expand Up @@ -392,30 +391,29 @@ def open_project(self):
image = project["image"]
self.load_project(image, results)



def load_project(self, image, results):
self.results = results
self.results = sorted(self.results, key=lambda x: x['area'], reverse=False)
self.results = sorted(
self.results, key=lambda x: x["area"], reverse=False
)
label_num = 1
for result in self.results:
result['keep'] = True
result['label_num'] = label_num
result["keep"] = True
result["label_num"] = label_num
label_num += 1


self.image = image
add_properties_to_label_image(self.image, self.results)
self.viewer.add_image(image)

self._3D_labels_layer.data = make_label_image_3d(self.results)
self.viewer.dims.ndisplay = 3
self._3D_labels_layer.translate = (-len(self.results), 0, 0)

self.add_points()
self.add_boxes()
self.update_slider_min_max()

def save_project(self):
options = QFileDialog.Options()
file_name, _ = QFileDialog.getSaveFileName(
Expand Down Expand Up @@ -485,7 +483,7 @@ def process(self):
bounding_boxes = get_bounding_boxes(
self.image,
detector_model="Finetuned",
device="cuda",
device=get_device(),
conf=bbox_conf,
iou=bbox_iou,
)
Expand All @@ -511,11 +509,10 @@ def process(self):
)
self.textBrowser_log.repaint()
QApplication.processEvents()

bounding_boxes = get_bounding_boxes(
self.image,
detector_model="YOLOv8",
device="cuda",
device=get_device(),
conf=bbox_conf,
iou=bbox_iou,
imgsz=bbox_imgsz,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ isolated_build=true

[gh-actions]
python =
3.8: py38
# 3.8: py38
3.9: py39
3.10: py310

Expand Down

0 comments on commit 58495d6

Please sign in to comment.