In [1]:
import warnings

import torch

warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.pyplot as plt
import requests
from PIL import Image
from io import BytesIO
from lang_sam import LangSAM
import os
import re

In [2]:
def split(prompt):
    """Turn a text prompt into an underscore-joined name.

    The prompt is split on commas, spaces and periods. Empty pieces (produced
    by leading, consecutive or trailing delimiters) are dropped and the
    remaining words are joined with ``_``.

    Args:
        prompt: Prompt text such as ``"train tracks."``.

    Returns:
        The joined name, e.g. ``"train_tracks"``; an empty string when the
        prompt contains no words.
    """
    # Dropping empty tokens means a trailing "." does not leave a dangling "_".
    words = [word for word in re.split('[, .]', prompt) if word]
    return '_'.join(words)

def download_image(url, timeout=30):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled host.

    Returns:
        The downloaded image converted to ``"RGB"`` mode.

    Raises:
        requests.exceptions.RequestException: On connection problems, timeouts,
            or a non-success HTTP status (raised by ``raise_for_status``).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

def save_mask(mask_np, filename):
    """Write a mask array to ``filename`` as an 8-bit image.

    The mask values are multiplied by 255 and cast to ``uint8`` before being
    handed to PIL, so a 0/1 (or boolean) mask becomes a black/white image.

    Args:
        mask_np: Mask array to save.
        filename: Destination path passed to ``Image.save``.
    """
    scaled = (mask_np * 255).astype(np.uint8)
    Image.fromarray(scaled).save(filename)

def display_image_with_masks(image, masks):
    """Show the original image followed by each mask in a single row.

    Args:
        image: Image accepted by ``Axes.imshow``; drawn in the first panel.
        masks: Sequence of 2-D mask arrays; each is drawn in grayscale in its
            own panel. May be empty, in which case only the image is shown.
    """
    num_masks = len(masks)

    # squeeze=False keeps the result a 2-D array of Axes even when there is
    # only one panel (no masks); otherwise a bare Axes comes back and
    # indexing it with [0] fails.
    fig, axes_grid = plt.subplots(1, num_masks + 1, figsize=(15, 5), squeeze=False)
    axes = axes_grid[0]

    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('on')

    for panel, (index, mask_np) in zip(axes[1:], enumerate(masks, start=1)):
        panel.imshow(mask_np, cmap='gray')
        panel.set_title(f"Mask {index}")
        panel.axis('off')

    plt.tight_layout()
    plt.show()

def display_image_with_boxes(image, boxes, logits):
    """Draw the image with a red rectangle and confidence label per box.

    Args:
        image: Image accepted by ``Axes.imshow``.
        boxes: Iterable of 4-element boxes unpacked as
            ``(x_min, y_min, x_max, y_max)``.
        logits: Iterable of scalar-like scores exposing ``.item()``; paired
            with ``boxes`` by position.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title("Image with Bounding Boxes")
    ax.axis('off')

    for (left, top, right, bottom), score in zip(boxes, logits):
        # .item() turns the score into a plain Python number so it can be rounded.
        label = f"Confidence: {round(score.item(), 2)}"

        # Outline of the detection.
        outline = plt.Rectangle(
            (left, top),
            right - left,
            bottom - top,
            fill=False,
            edgecolor='red',
            linewidth=2,
        )
        ax.add_patch(outline)

        # Score label anchored at the box's top-left corner.
        ax.text(left, top, label, fontsize=8, color='red', verticalalignment='top')

    plt.show()

def print_bounding_boxes(boxes):
    """Print a header followed by one numbered line per bounding box."""
    print("Bounding Boxes:")
    for number, box in enumerate(boxes, start=1):
        print(f"Box {number}: {box}")

def print_detected_phrases(phrases):
    """Print a header (preceded by a blank line) and each phrase, numbered from 1."""
    print("\nDetected Phrases:")
    for number, phrase in enumerate(phrases, start=1):
        print(f"Phrase {number}: {phrase}")

def print_logits(logits):
    """Print a "Confidence:" header (preceded by a blank line) and each logit, numbered from 1."""
    print("\nConfidence:")
    for number, logit in enumerate(logits, start=1):
        print(f"Logit {number}: {logit}")

def show_mask(mask, ax, random_color=False):
    """Overlay one or more masks on ``ax`` as semi-transparent RGBA images.

    Args:
        mask: A sequence of masks (list of 2-D arrays, or an array whose
            leading axis indexes masks), or a single 2-D mask array.
        ax: Object with an ``imshow`` method (e.g. a matplotlib Axes) that
            receives one RGBA overlay per mask.
        random_color: When true, each mask gets its own random RGB colour;
            otherwise a fixed blue is used. Alpha is 0.6 in both cases.
    """
    # A lone 2-D mask is treated as a one-element collection so that the loop
    # below iterates over masks, not over the rows of that mask.
    if isinstance(mask, np.ndarray) and mask.ndim == 2:
        mask = [mask]

    for mask_i in mask:
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30/255, 144/255, 255/255, 0.6])

        # Work on the current mask only (not the whole collection).
        mask_i = np.asarray(mask_i)
        h, w = mask_i.shape[-2:]
        # Broadcasting (h, w, 1) * (1, 1, 4) gives an RGBA image that is fully
        # transparent wherever the mask is zero/False.
        mask_image = mask_i.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

In [3]:
# Input image: a local file path, or an http(s) URL (main() downloads URLs).
image = "pics/test.jpg"

# Prompts fed to the model one at a time. main() looks them up by name as
# text_prompt1 .. text_prompt{len_prompts}, so len_prompts must match the
# number of text_promptN variables defined here.
text_prompt1, text_prompt2, text_prompt3 = (
    "train tracks.",
    "backgrounds.",
    "person.animal.",
)
len_prompts = 3

In [None]:
def main():
    """Run LangSAM on ``image`` for every configured prompt and show the results.

    Reads the module-level ``image``, ``len_prompts`` and
    ``text_prompt1 .. text_prompt{len_prompts}`` variables. The combined
    detections are published in the module-level ``masks``, ``boxes``,
    ``phrases`` and ``logits`` so later cells can inspect them; they are reset
    on every call, so results from an earlier run never leak into this one.

    Network and file errors while loading the image are reported with a
    printed message instead of being raised.
    """
    global boxes, masks, phrases, logits
    try:
        if image.startswith("http"):
            image_pil = download_image(image)
        else:
            image_pil = Image.open(image).convert("RGB")

        width, height = image_pil.size
        model = LangSAM()

        prompts = [globals()[f"text_prompt{i}"] for i in range(1, len_prompts + 1)]

        # Collect per-prompt results in lists and concatenate once at the end.
        # This works no matter which prompt is the first to produce a
        # detection, including when the first prompt finds nothing.
        mask_parts, box_parts, logit_parts = [], [], []
        phrases = []
        for text_prompt in prompts:
            masksi, boxesi, phrasesi, logitsi = model.predict(image_pil, text_prompt)
            print(phrasesi)
            print(np.array(phrasesi).shape)
            if len(masksi) == 0:
                print(f"No objects of the '{text_prompt}' prompt detected in the image.")
                continue
            mask_parts.append(masksi)
            box_parts.append(boxesi)
            logit_parts.append(logitsi)
            phrases.extend(phrasesi)

        if not mask_parts:
            # Nothing detected for any prompt: publish empty results and stop.
            masks = torch.empty(0)
            boxes = torch.empty(0)
            logits = torch.empty(0)
            text_prompt = "".join(prompts)
            print(f"No objects of the '{text_prompt}' prompt detected in the image.")
            return

        masks = torch.cat(mask_parts, dim=0)
        boxes = torch.cat(box_parts, dim=0)
        logits = torch.cat(logit_parts, dim=0)

        # Convert masks to numpy arrays
        masks_np = [mask.squeeze().cpu().numpy() for mask in masks]

        # Display the original image and masks side by side
        display_image_with_masks(image_pil, masks_np)

        # Union of all masks. PIL reports size as (width, height); numpy
        # arrays are indexed (rows, cols), hence (height, width) here.
        union_mask = np.zeros((height, width), dtype=bool)
        for mask_np in masks_np:
            union_mask = np.logical_or(union_mask, mask_np)

        plt.imshow(image_pil)
        plt.axis('off')
        # show_mask walks the leading axis of its input, so the union is
        # passed as a (1, H, W) array.
        show_mask(union_mask[np.newaxis], plt.gca())
        plt.show()

        # To persist the masks, call save_mask(mask_np, path) for each entry
        # of masks_np.

        # Display the image with bounding boxes and confidence scores
        display_image_with_boxes(image_pil, boxes, logits)

        # Print the bounding boxes, phrases, and logits
        print_bounding_boxes(boxes)
        print_detected_phrases(phrases)
        print_logits(logits)

    except (requests.exceptions.RequestException, IOError) as e:
        print(f"Error: {e}")

if __name__ == "__main__":
    main()

final text_encoder_type: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded from /home/chenxu/.cache/huggingface/hub/models--ShilongLiu--GroundingDINO/snapshots/ckpt/groundingdino_swinb_cogcoor.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight'])
['train tracks']
(1,)
