This repository has been archived by the owner on Nov 21, 2023. It is now read-only.

Commit

tensor on different device fix
artyaltanzaya committed Jul 5, 2023
2 parents 7a88a41 + fdedeb9 commit 5c83aa6
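
For context, the "tensor on different device" in the title refers to PyTorch's device-mismatch failure: clip.tokenize always returns CPU tensors, so feeding them to a model loaded on a GPU raises a RuntimeError. A minimal sketch of how the bug arises, as an illustrative snippet rather than code from this repository:

    import clip
    import torch

    # With CUDA available, the model's weights live on cuda:0 ...
    model, _ = clip.load("ViT-B/32", device="cuda")

    # ... but clip.tokenize returns CPU tensors.
    tokens = clip.tokenize(["a dog", "a cat"])

    # The embedding lookup inside encode_text then fails with:
    # RuntimeError: Expected all tensors to be on the same device
    text_features = model.encode_text(tokens)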
Showing 2 changed files with 14 additions and 20 deletions.
.gitignore (3 changes: 2 additions & 1 deletion)
@@ -133,4 +133,5 @@ dmypy.json
.pyre/

#Images
*.png
*.png
*.jpg
autodistill_sam_clip/sam_clip.py (31 changes: 12 additions & 19 deletions)
@@ -24,24 +24,23 @@
HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class SAMCLIP(DetectionBaseModel):
ontology: CaptionOntology
sam_predictor: SamAutomaticMaskGenerator
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def __init__(self, ontology: CaptionOntology):
self.ontology = ontology
self.sam_predictor = load_SAM()

if not os.path.exists(f"{HOME}/.cache/autodistill/clip"):
os.makedirs(f"{HOME}/.cache/autodistill/clip")

# os.system("pip install ftfy regex tqdm")
# os.system(
# f"cd {HOME}/.cache/autodistill/clip && pip install git+https://github.com/openai/CLIP.git"
# )
os.system("pip install ftfy regex tqdm")
os.system(
f"cd {HOME}/.cache/autodistill/clip"
f"cd {HOME}/.cache/autodistill/clip && pip install git+https://github.com/openai/CLIP.git"
)

# add clip path to path
@@ -61,7 +60,6 @@ def predict(self, input: str, confidence: int = 0.5) -> sv.Detections:

# SAM Predictions
sam_result = self.sam_predictor.generate(image_rgb)
print("sam result", len(sam_result))

# mask_annotator = sv.MaskAnnotator()

@@ -75,18 +73,17 @@ def predict(self, input: str, confidence: int = 0.5) -> sv.Detections:
valid_detections = []

labels = self.ontology.prompts()
print("labels", labels)

nms_data = []

if len(sam_result) == 0:
return sv.Detections.empty()

labels.append("background")
count = 0

for mask in sam_result:
mask_item = mask["segmentation"]
count += 1

image = image_rgb.copy()

# image[mask_item == 0] = 0
@@ -110,30 +107,26 @@ def predict(self, input: str, confidence: int = 0.5) -> sv.Detections:
continue

# show me the bbox
image = Image.fromarray(image, 'RGB')
image.save(f'{count}_img.png')
image = Image.fromarray(image)

image = self.clip_preprocess(image).unsqueeze(0).to(DEVICE)

cosime_sims = []

with torch.no_grad():
image_features = self.clip_model.encode_image(image).to(DEVICE)
image_features /= image_features.norm(dim=-1, keepdim=True)
print("image_features", image_features.shape)

text = self.tokenize(labels).to(DEVICE)
print(" devcie b", DEVICE)
text_features = self.clip_model.encode_text(text).to(DEVICE)
text_features = text_features.to(DEVICE)

text_features /= text_features.norm(dim=-1, keepdim=True)
# get cosine similarity between image and text features
print("text_features", text_features.shape)

similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(similarity)

cosime_sims.append(similarity[0][0].item())
print("after item sim", similarity[0][0].item())

max_prob = None
max_idx = None
@@ -142,7 +135,7 @@ def predict(self, input: str, confidence: int = 0.5) -> sv.Detections:

max_prob = values[0].item()
max_idx = indices[0].item()
print("max+props", max_prob, max_idx)

if max_prob > confidence:
valid_detections.append(
sv.Detections(
@@ -177,4 +170,4 @@ def predict(self, input: str, confidence: int = 0.5) -> sv.Detections:

return combine_detections(
final_detections, overwrite_class_ids=[0] * len(final_detections)
)
)
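
Taken together, the sam_clip.py changes remove the debug prints and per-mask image dumps and keep every CLIP tensor on the module-level DEVICE. Below is a minimal, self-contained sketch of the resulting scoring pattern. It assumes OpenAI's clip package; the image path, label list, and 0.5 threshold are illustrative, and the topk call mirrors how max_prob and max_idx are read in the diff:

    import clip
    import torch
    from PIL import Image

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # clip.load returns the model already on the requested device,
    # along with its preprocessing transform.
    model, preprocess = clip.load("ViT-B/32", device=DEVICE)

    labels = ["dog", "cat", "background"]
    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # tokenize() returns CPU tensors; moving them to DEVICE is the actual fix.
        text = clip.tokenize(labels).to(DEVICE)
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

        # L2-normalize, then softmax over scaled cosine similarities,
        # matching the 100.0 * image_features @ text_features.T line in the diff.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    values, indices = similarity[0].topk(1)
    max_prob, max_idx = values[0].item(), indices[0].item()
    if max_prob > 0.5 and labels[max_idx] != "background":
        print(f"{labels[max_idx]}: {max_prob:.2f}")

Because DEVICE falls back to "cpu" when CUDA is unavailable, the same code runs unchanged on CPU-only and GPU machines, which is the point of pinning everything to a single module-level device.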
