This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Commit

work on sam-clip code
capjamesg committed Jun 9, 2023
1 parent 7e44970 commit 4a49282
Showing 2 changed files with 151 additions and 25 deletions.
60 changes: 58 additions & 2 deletions autodistill_sam_clip/helpers.py
@@ -1,11 +1,69 @@
import os
import urllib.request

import numpy as np
import supervision as sv
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def combine_detections(detections_list, overwrite_class_ids):
if len(detections_list) == 0:
return sv.Detections.empty()

if overwrite_class_ids is not None and len(overwrite_class_ids) != len(
detections_list
):
raise ValueError(
"Length of overwrite_class_ids must match the length of detections_list."
)

xyxy = []
mask = []
confidence = []
class_id = []
tracker_id = []

for idx, detection in enumerate(detections_list):
xyxy.append(detection.xyxy)

if detection.mask is not None:
mask.append(detection.mask)

if detection.confidence is not None:
confidence.append(detection.confidence)

if detection.class_id is not None:
if overwrite_class_ids is not None:
# Overwrite the class IDs for the current Detections object
class_id.append(
np.full_like(
detection.class_id, overwrite_class_ids[idx], dtype=np.int64
)
)
else:
class_id.append(detection.class_id)

if detection.tracker_id is not None:
tracker_id.append(detection.tracker_id)

xyxy = np.vstack(xyxy)
mask = np.vstack(mask) if mask else None
confidence = np.hstack(confidence) if confidence else None
class_id = np.hstack(class_id) if class_id else None
tracker_id = np.hstack(tracker_id) if tracker_id else None

return sv.Detections(
xyxy=xyxy,
mask=mask,
confidence=confidence,
class_id=class_id,
tracker_id=tracker_id,
)


def load_SAM():
# Check if segment-anything library is already installed

@@ -29,6 +87,4 @@ def load_SAM():
)
sam_predictor = SamAutomaticMaskGenerator(sam)



return sam_predictor
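
Aside: a minimal sketch of how the new combine_detections helper could be exercised on its own, importing it the way sam_clip.py does. The boxes, scores, and class IDs below are made-up values for illustration, not output from this repository.

import numpy as np
import supervision as sv

from helpers import combine_detections

# two hypothetical single-box detections, as predict() accumulates them
a = sv.Detections(
    xyxy=np.array([[10, 20, 110, 220]]),
    confidence=np.array([0.91]),
    class_id=np.array([0]),
)
b = sv.Detections(
    xyxy=np.array([[50, 60, 150, 260]]),
    confidence=np.array([0.78]),
    class_id=np.array([2]),
)

# merge into one Detections object; when overwrite_class_ids is given it
# must have one entry per input Detections
merged = combine_detections([a, b], overwrite_class_ids=None)
print(len(merged))  # 2
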
116 changes: 93 additions & 23 deletions autodistill_sam_clip/sam_clip.py
@@ -10,6 +10,7 @@

import cv2
import torch
from sklearn.metrics.pairwise import cosine_similarity

torch.use_deterministic_algorithms(False)

@@ -18,7 +19,7 @@
from autodistill.detection import CaptionOntology, DetectionBaseModel
from segment_anything import SamAutomaticMaskGenerator

from helpers import load_SAM
from helpers import combine_detections, load_SAM

HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -53,38 +54,107 @@ def __init__(self, ontology: CaptionOntology):
self.tokenize = clip.tokenize

def predict(self, input: str, confidence: float = 0.5) -> sv.Detections:
image = cv2.imread(input)
image_bgr = cv2.imread(input)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# SAM Predictions
sam_result = self.sam_predictor.generate(image)
detections = sv.Detections.from_sam(sam_result=sam_result)
sam_result = self.sam_predictor.generate(image_rgb)

# CLIP Predictions
for i, _ in enumerate(detections):
labels = self.ontology.prompts()
valid_detections = []

indices = list(range(len(labels)))
labels = self.ontology.prompts()

# image is mask
mask = detections[i].mask
if len(sam_result) == 0:
return sv.Detections.empty()

for mask in sam_result:
mask_item = mask["segmentation"]
# cut out binary mask from image

image = image_rgb.copy()
# image[mask_item == 0] = 0

# pad the bbox by 20%; SAM bboxes are in XYWH format, so convert to XYXY here
x, y, w, h = mask["bbox"]
x0 = max(0, int(x - 0.2 * w))
y0 = max(0, int(y - 0.2 * h))
x1 = min(image.shape[1], int(x + 1.2 * w))
y1 = min(image.shape[0], int(y + 1.2 * h))

# cut the padded bbox out of the image
image = image[y0:y1, x0:x1]

# prepare for CLIP

# fix ValueError: tile cannot extend outside image
if image.shape[0] == 0 or image.shape[1] == 0:
continue

# the crop is already RGB; zero out pixels outside the segmentation mask
mask_crop = mask_item[y0:y1, x0:x1]
image[mask_crop == False] = 0
image = Image.fromarray(image)

image = self.clip_preprocess(image).unsqueeze(0).to(DEVICE)
text = self.tokenize(labels).to(DEVICE)

cosine_sims = []

with torch.no_grad():
logits_per_image, _ = self.clip_model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
image_features = self.clip_model.encode_image(image)

image_features /= image_features.norm(dim=-1, keepdim=True)

for label in labels:
text = self.tokenize([label]).to(DEVICE)

text_features = self.clip_model.encode_text(text)

text_features /= text_features.norm(dim=-1, keepdim=True)
# get cosine similarity between image and text features
cosine_sim = cosine_similarity(
image_features.cpu().numpy(), text_features.cpu().numpy()
)

cosine_sims.append(cosine_sim[0][0])

# take the best-matching prompt for this mask
values, indices = torch.topk(torch.tensor(cosine_sims), 1)

max_prob = values[0].item()
max_idx = indices[0].item()

if max_prob > confidence:
valid_detections.append(
sv.Detections(
xyxy=np.array([[x0, y0, x1, y1]]),
confidence=np.array([max_prob]),
mask=np.array([mask_item]),
class_id=np.array([max_idx]),
)
)

# keep the per-mask class IDs assigned by CLIP above
return combine_detections(valid_detections, overwrite_class_ids=None)


model = SAMCLIP(ontology=CaptionOntology({"box": "box"}))

detections = model.predict("./trash.jpg", confidence=0.2)

# show predictions

image_bgr = cv2.imread("./trash.jpg")

mask_annotator = sv.MaskAnnotator()

if probs[0][indices[0]] < confidence:
detections.mask[i] = None
detections.confidence[i] = None
detections.class_id[i] = None
annotated_image = mask_annotator.annotate(scene=image_bgr, detections=detections)

return detections
cv2.imshow("image", annotated_image)
cv2.waitKey(0)
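
Aside: the scoring step above boils down to embedding a crop with CLIP and comparing it against each prompt by cosine similarity. A minimal standalone sketch of that idea follows; the model name, file path, and prompt list are assumptions for illustration, and the plain dot product of normalised features is equivalent to the sklearn cosine_similarity call used in the commit.

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

crop = Image.open("./crop.jpg")  # hypothetical cropped region
prompts = ["box", "bottle", "background"]

with torch.no_grad():
    image_features = model.encode_image(preprocess(crop).unsqueeze(0).to(device))
    text_features = model.encode_text(clip.tokenize(prompts).to(device))

# normalise, then cosine similarity is a dot product
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
sims = (image_features @ text_features.T).squeeze(0)

best = int(sims.argmax())
print(prompts[best], float(sims[best]))
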
