In [None]:
# SAM provides `segment_anything`; `import clip` in the imports cell needs
# OpenAI's CLIP, which is installed from its GitHub repository.
# %pip (not !pip) installs into the environment of the running kernel.
%pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
%pip install -q 'git+https://github.com/openai/CLIP.git'

In [None]:
# HOME was never defined in this notebook, so IPython left `{HOME}` unexpanded.
# It must match the directory used by SAM_CHECKPOINT in the configuration cell.
HOME = "/content"
! mkdir -p {HOME}/weights
# -nc (no-clobber): skip the download when the checkpoint already exists, so
# re-running the notebook neither re-fetches it nor saves a duplicate *.pth.1.
! wget -q -nc https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

In [None]:
import torch
import numpy as np
from PIL import Image, ImageDraw
import clip
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator



In [None]:
# Configuration: every value a reader might want to tune lives here.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# SAM (Segment Anything) backbone and checkpoint.
# Weights: https://github.com/facebookresearch/segment-anything
MODEL_TYPE = "vit_h"
SAM_CHECKPOINT = "/content/weights/sam_vit_h_4b8939.pth"

# CLIP backbone used to embed the reference image and every candidate region.
CLIP_MODEL_NAME = "ViT-B/32"  # faster than ViT-L/14

# Regions whose similarity to the reference is strictly greater than this
# value are reported; tune it for your use case.
SIMILARITY_THRESHOLD = 0.9

def initialize_models():
    """Build the SAM mask generator and load CLIP, both on ``DEVICE``.

    Reads the module-level configuration ``MODEL_TYPE``, ``SAM_CHECKPOINT``,
    ``CLIP_MODEL_NAME`` and ``DEVICE``.

    Returns:
        tuple: ``(mask_generator, clip_model, clip_preprocess)``, where
        ``mask_generator`` is a ``SamAutomaticMaskGenerator`` wrapping the SAM
        backbone, and ``clip_model`` / ``clip_preprocess`` are the pair
        returned by ``clip.load``.
    """
    # SAM backbone, moved to the target device before it is wrapped.
    sam_backbone = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
    sam_backbone.to(device=DEVICE)

    generator_options = dict(
        points_per_side=32,           # point-prompt grid density; reduce for speed
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
    )
    mask_generator = SamAutomaticMaskGenerator(sam_backbone, **generator_options)

    # CLIP model plus its matching image preprocessing transform.
    clip_model, clip_preprocess = clip.load(CLIP_MODEL_NAME, device=DEVICE)
    return mask_generator, clip_model, clip_preprocess

def get_reference_embedding(image_path, clip_model, clip_preprocess):
    """Return the L2-normalised CLIP embedding of the reference image.

    Args:
        image_path: Path of the image showing the object to look for.
        clip_model: CLIP model exposing ``encode_image``.
        clip_preprocess: CLIP preprocessing transform (PIL image -> tensor).

    Returns:
        The embedding tensor on ``DEVICE``, divided by its L2 norm along the
        last dimension. Presumably shaped ``(1, embed_dim)``, since a single
        image is batched with ``unsqueeze(0)``; confirm for your CLIP variant.
    """
    # The context manager releases the file handle deterministically;
    # convert() returns an independent copy that stays valid after closing.
    with Image.open(image_path) as opened:
        ref_image = opened.convert("RGB")
    preprocessed = clip_preprocess(ref_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        embedding = clip_model.encode_image(preprocessed)
    return embedding / embedding.norm(dim=-1, keepdim=True)

def detect_objects(target_path, mask_generator, clip_model, clip_preprocess, ref_embedding,
                   threshold=None, padding=3):
    """Segment a target image with SAM and keep regions that resemble the reference.

    Each mask from ``mask_generator`` is cropped to its bounding box plus
    ``padding`` pixels of context, embedded with CLIP, L2-normalised, and
    scored against ``ref_embedding`` by a dot product. When ``ref_embedding``
    is normalised too (as ``get_reference_embedding`` returns it), that score
    is the cosine similarity.

    Args:
        target_path: Path of the image to search.
        mask_generator: SAM automatic mask generator. ``generate`` must return
            dicts with a ``"segmentation"`` entry (presumably an HxW boolean
            array; confirm against the generator's ``output_mode``).
        clip_model: CLIP model exposing ``encode_image``.
        clip_preprocess: CLIP preprocessing transform (PIL image -> tensor).
        ref_embedding: Normalised reference embedding.
        threshold: Keep regions whose similarity is strictly greater than this.
            ``None`` (default) uses the module-level ``SIMILARITY_THRESHOLD``,
            read at call time.
        padding: Pixels of context added on each side of every crop
            (clamped to the image bounds).

    Returns:
        ``(target_image, detected_regions)``: the loaded RGB image and a list
        of dicts with ``"bbox"`` (inclusive ``x_min, y_min, x_max, y_max``),
        ``"similarity"`` (float) and ``"mask"`` (the SAM segmentation).
    """
    if threshold is None:
        threshold = SIMILARITY_THRESHOLD

    with Image.open(target_path) as opened:
        target_image = opened.convert("RGB")
    target_np = np.array(target_image)
    width, height = target_image.size

    masks = mask_generator.generate(target_np)

    detected_regions = []
    for mask_info in masks:
        mask = mask_info["segmentation"]
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            continue  # empty mask: nothing to crop

        # Inclusive pixel extents of the mask.
        x_min, x_max = xs.min(), xs.max()
        y_min, y_max = ys.min(), ys.max()

        # PIL crop boxes are (left, upper, right, lower) with right/lower
        # EXCLUSIVE, so add 1 to the inclusive max before padding. Without
        # it the right/bottom margin is one pixel short, and with padding=0
        # the object's last column/row would be cut off.
        crop = target_image.crop((
            max(0, x_min - padding),
            max(0, y_min - padding),
            min(width, x_max + 1 + padding),
            min(height, y_max + 1 + padding),
        ))

        preprocessed = clip_preprocess(crop).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            region_embedding = clip_model.encode_image(preprocessed)
        region_embedding = region_embedding / region_embedding.norm(dim=-1, keepdim=True)

        similarity = torch.matmul(ref_embedding, region_embedding.T).item()

        if similarity > threshold:
            detected_regions.append({
                "bbox": (x_min, y_min, x_max, y_max),
                "similarity": similarity,
                "mask": mask,
            })

    return target_image, detected_regions

def visualize_results(image, regions, inplace=False):
    """Draw a red box and the similarity score for every detected region.

    Args:
        image: PIL image to annotate.
        regions: Dicts with ``"bbox"`` (``x_min, y_min, x_max, y_max``) and
            ``"similarity"``, as produced by ``detect_objects``.
        inplace: If False (default), annotate a copy and leave ``image``
            untouched, so re-running the cell does not stack annotations.
            If True, draw directly on ``image``.

    Returns:
        The annotated image (``image`` itself when ``inplace`` is True).
    """
    canvas = image if inplace else image.copy()
    draw = ImageDraw.Draw(canvas)
    # Ascending order: the strongest match is drawn last, so its box and
    # score stay visible where boxes overlap.
    for region in sorted(regions, key=lambda r: r["similarity"]):
        x_min, y_min, x_max, y_max = region["bbox"]
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        draw.text((x_min, y_min), f"{region['similarity']:.2f}", fill="white")
    return canvas



In [None]:
# Main execution
if __name__ == "__main__":
    # The reference image shows the object to look for; the target image is
    # the scene searched for matching regions. The paths follow the file names.
    # NOTE(review): swap them if sample1.png is actually the exemplar crop.
    REFERENCE_IMAGE_PATH = "/content/reference_img.png"
    TARGET_IMAGE_PATH = "/content/sample1.png"
    OUTPUT_PATH = "detection_results.jpg"

    # Load SAM (mask generator) and CLIP once; both are reused below.
    sam_mask_generator, clip_model, clip_preprocess = initialize_models()

    # Embed the reference exemplar.
    ref_embedding = get_reference_embedding(
        REFERENCE_IMAGE_PATH,
        clip_model,
        clip_preprocess,
    )

    # Segment the target and keep regions similar to the reference.
    target_image, regions = detect_objects(
        TARGET_IMAGE_PATH,
        sam_mask_generator,
        clip_model,
        clip_preprocess,
        ref_embedding,
    )

    # Annotate and save.
    result_image = visualize_results(target_image, regions)
    result_image.save(OUTPUT_PATH)
    print(f"Found {len(regions)} matching regions")
    
   