In [1]:
import warnings
import numpy as np
import matplotlib.pyplot as plt
import requests
from PIL import Image
from io import BytesIO
from lang_sam import LangSAM

# Path of the input image that main() opens with PIL.
image = "img1.jpg"
# Text prompt that main() passes to LangSAM.predict together with the image.
text_prompt = "clothes"

def save_mask(mask_np, filename):
    """Write a mask array to ``filename`` as an 8-bit image.

    The mask values are multiplied by 255 and cast to ``uint8`` before being
    handed to PIL (presumably the mask holds 0/1 or False/True values, so the
    result is a black-and-white image -- confirm against the caller).

    Args:
        mask_np: NumPy array holding the mask.
        filename: Destination path, passed straight to ``Image.save``.
    """
    scaled = mask_np * 255
    pixels = scaled.astype(np.uint8)
    Image.fromarray(pixels).save(filename)

def display_image_with_masks(image, masks):
    """Show the original image and every mask side by side in one figure row.

    Args:
        image: Image drawn in the first panel; passed to ``Axes.imshow``.
        masks: Sequence of mask arrays, each drawn in its own panel with a
            grayscale colormap. May be empty, in which case only the
            original image is shown.
    """
    num_panels = len(masks) + 1

    # squeeze=False keeps `axes` a 2-D array even when there is only one
    # panel. With the default squeeze=True an empty `masks` would produce a
    # bare Axes object and the indexing below would fail.
    _fig, axes = plt.subplots(1, num_panels, figsize=(15, 5), squeeze=False)
    panels = axes[0]

    panels[0].imshow(image)
    panels[0].set_title("Original Image")
    panels[0].axis('off')

    for position, mask_np in enumerate(masks, start=1):
        panel = panels[position]
        panel.imshow(mask_np, cmap='gray')
        panel.set_title(f"Mask {position}")
        panel.axis('off')

    plt.tight_layout()
    plt.show()

def display_image_with_boxes(image, boxes, logits):
    """Draw the image with a red outline and confidence label per detection.

    Args:
        image: Image to display; passed to ``Axes.imshow``.
        boxes: Iterable of boxes, each unpacking to
            ``(x_min, y_min, x_max, y_max)``.
        logits: Iterable of per-box scores, paired with ``boxes`` by position.
            Each element must provide ``.item()``.
    """
    _, axis = plt.subplots()
    axis.imshow(image)
    axis.set_title("Image with Bounding Boxes")
    axis.axis('off')

    for (left, top, right, bottom), score in zip(boxes, logits):
        # .item() yields a plain scalar so the score can be rounded for display.
        label = f"Confidence: {round(score.item(), 2)}"

        # Rectangle takes an anchor corner plus width/height, not two corners.
        outline = plt.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            fill=False,
            edgecolor='red',
            linewidth=2,
        )
        axis.add_patch(outline)

        # Anchor the label at the box's (x_min, y_min) corner.
        axis.text(left, top, label, fontsize=8, color='red', verticalalignment='top')

    plt.show()

def print_bounding_boxes(boxes):
    """Print a header followed by one numbered line per bounding box.

    Args:
        boxes: Iterable of boxes; each is printed using its string form.
    """
    print("Bounding Boxes:")
    for number, current_box in enumerate(boxes, start=1):
        print(f"Box {number}: {current_box}")

def print_detected_phrases(phrases):
    """Print a header (preceded by a blank line) and each phrase, numbered from 1.

    Args:
        phrases: Iterable of detected phrases; each is printed using its
            string form.
    """
    print("\nDetected Phrases:")
    for number, text in enumerate(phrases, start=1):
        print(f"Phrase {number}: {text}")

def print_logits(logits):
    """Print a "Confidence:" header (preceded by a blank line) and each logit.

    Args:
        logits: Iterable of scores; each is printed using its string form,
            numbered from 1.
    """
    print("\nConfidence:")
    for number, value in enumerate(logits, start=1):
        print(f"Logit {number}: {value}")

def main():
    """Segment the configured image with LangSAM, then display and save masks.

    Uses the module-level ``image`` path and ``text_prompt``. When the model
    returns no masks a message is printed and nothing is shown or saved.
    Otherwise the masks and the boxes with confidence scores are displayed,
    and each mask is written to ``image_mask_<n>.png`` (numbered from 1).

    Request and I/O errors raised anywhere in the pipeline are caught and
    reported with a printed message instead of propagating.
    """
    # Silences every warning for the rest of the process.
    warnings.filterwarnings("ignore")

    try:
        source_image = Image.open(image).convert("RGB")
        segmenter = LangSAM()
        masks, boxes, _phrases, logits = segmenter.predict(source_image, text_prompt)

        # Nothing detected: report and stop before any display or save work.
        if len(masks) == 0:
            print(f"No objects of the '{text_prompt}' prompt detected in the image.")
            return

        # Drop singleton dimensions and move each mask to a CPU NumPy array.
        mask_arrays = [tensor.squeeze().cpu().numpy() for tensor in masks]

        # Original image next to its masks, then boxes with confidence scores.
        display_image_with_masks(source_image, mask_arrays)
        display_image_with_boxes(source_image, boxes, logits)

        # Persist each mask under a 1-based file name.
        for number, mask_array in enumerate(mask_arrays, start=1):
            save_mask(mask_array, f"image_mask_{number}.png")

    except (requests.exceptions.RequestException, IOError) as e:
        print(f"Error: {e}")
    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()




  from .autonotebook import tqdm as notebook_tqdm
Downloading (…)ingDINO_SwinB.cfg.py: 100%|██████████| 1.01k/1.01k [00:00<00:00, 1.94MB/s]


final text_encoder_type: bert-base-uncased


Downloading tokenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 54.5kB/s]
Downloading config.json: 100%|██████████| 570/570 [00:00<00:00, 1.11MB/s]
Downloading vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 450kB/s]
Downloading tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 592kB/s]
Downloading model.safetensors: 100%|██████████| 440M/440M [00:28<00:00, 15.6MB/s] 
Downloading (…)no_swinb_cogcoor.pth: 100%|██████████| 938M/938M [00:59<00:00, 15.7MB/s] 


Model loaded from C:\Users\yeeqi\.cache\huggingface\hub\models--ShilongLiu--GroundingDINO\snapshots\a94c9b567a2a374598f05c584e96798a170c56fb\groundingdino_swinb_cogcoor.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])


Downloading: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" to C:\Users\yeeqi/.cache\torch\hub\checkpoints\sam_vit_h_4b8939.pth
100%|██████████| 2.39G/2.39G [01:33<00:00, 27.4MB/s]


: 