# CLIP + SAM

In [None]:
import cv2
from segment_anything import build_sam, SamAutomaticMaskGenerator
from PIL import Image, ImageDraw
import clip
import torch
import numpy as np

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint file must already be downloaded next to the notebook.
sam = build_sam(checkpoint="sam_vit_h_4b8939.pth").to(device)
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
# Images this notebook has been tried on; the last entry is the one in use.
_candidate_image_paths = (
    "assets/example-image.jpg",
    "assets/new_GD_DOM_RGB_LONGLAT.png",
    "segmentResult/67.png",
    "isaid_segm/val/images/images/P0003.png",
)
image_path = _candidate_image_paths[-1]

In [None]:
# Read the image with OpenCV and convert it from BGR to RGB before handing it
# to the SAM mask generator.
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising, which would otherwise show up as an opaque cvtColor error.
    raise FileNotFoundError("Could not read image: {}".format(image_path))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
print(type(image))
masks = mask_generator.generate(image)

In [None]:
def convert_box_xywh_to_xyxy(box):
    """Convert an ``[x, y, width, height]`` box to ``[x1, y1, x2, y2]``.

    Only the first four items of ``box`` are read. The result is always a new
    list whose far corner is the origin plus the width/height.
    """
    left, top = box[0], box[1]
    right = left + box[2]
    bottom = top + box[3]
    return [left, top, right, bottom]

In [None]:
def segment_image(image, segmentation_mask):
    """Return a copy of ``image`` with everything outside the mask blacked out.

    ``image`` is a PIL image and ``segmentation_mask`` is an array used to
    index its pixel array (assumed to be a boolean H x W mask -- confirm
    against the caller). The result is a new RGB image of the same size.
    """
    source_pixels = np.array(image)

    # Keep only the pixels selected by the mask; everything else stays zero.
    kept_pixels = np.zeros_like(source_pixels)
    kept_pixels[segmentation_mask] = source_pixels[segmentation_mask]

    # 8-bit paste mask: 255 where the segment is, 0 elsewhere.
    alpha = np.zeros_like(segmentation_mask, dtype=np.uint8)
    alpha[segmentation_mask] = 255

    canvas = Image.new("RGB", image.size, (0, 0, 0))
    canvas.paste(Image.fromarray(kept_pixels), mask=Image.fromarray(alpha, mode='L'))
    return canvas

In [None]:
# Cut out every mask: black out the pixels outside the segment, then crop the
# result to the mask's bounding box.
image = Image.open(image_path)
image_array = np.array(image)
print(image_array.shape)
print(image_array.dtype)
display(image)

cropped_boxes = [
    segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"]))
    for mask in masks
]

# To inspect the crops, call display(box) for each box in cropped_boxes.

In [None]:
# Load the CLIP "ViT-B/32" checkpoint onto `device`. The call yields the model
# and the image transform that `retriev` applies to each crop before encoding.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
@torch.no_grad()
def retriev(elements: list[Image.Image], search_text: list[str]) -> torch.Tensor:
    """Score every image crop against every text prompt with CLIP.

    Args:
        elements: Images to embed. Must be non-empty, because ``torch.stack``
            cannot stack an empty list.
        search_text: Prompts to embed.

    Returns:
        A 2-D tensor with one row per element and one column per prompt. Each
        entry is the cosine similarity of the two embeddings multiplied by
        100. No softmax is applied; callers normalise if they need to.
    """
    # Stack on the host first so the whole batch is moved to the device once.
    image_batch = torch.stack([preprocess(image) for image in elements]).to(device)
    tokenized_text = clip.tokenize(search_text).to(device)

    image_features = model.encode_image(image_batch)
    text_features = model.encode_text(tokenized_text)

    # L2-normalise so the matrix product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    return 100. * image_features @ text_features.T

In [None]:
def get_indices_of_values_above_threshold(values, threshold):
    """Return the positions of all items in ``values`` strictly above ``threshold``.

    Args:
        values: Iterable of items comparable with ``threshold`` via ``>``.
        threshold: Exclusive lower bound.

    Returns:
        A list of zero-based indices in ascending order; empty if nothing
        exceeds the threshold.
    """
    # The former debug print compared against a hard-coded 20 instead of
    # `threshold`; it was removed so the function has no output side effect.
    return [index for index, value in enumerate(values) if value > threshold]

In [None]:
# Overlay colour (R, G, B) for each class name that is turned into a CLIP
# prompt. Insertion order matters: it defines the column order of the scores.
# The 'background': (0, 0, 0) entry is intentionally left out.
COLOR_MAP = {
    'ship': (0, 0, 63),
    'storage tank': (0, 191, 127),
    'baseball diamond': (0, 63, 0),
    'tennis court': (0, 63, 127),
    'basketball court': (0, 63, 191),
    'ground track field': (0, 63, 255),
    'bridge': (0, 127, 63),
    'large vehicle': (0, 127, 127),
    'small vehicle': (0, 0, 127),
    'helicopter': (0, 0, 191),
    'swimming pool': (0, 0, 255),
    'roundabout': (0, 63, 63),
    'soccer ball field': (0, 127, 191),
    'plane': (0, 127, 255),
    'harbor': (0, 100, 155),
}
# Full class list, including background:
# background,ship,storage tank,baseball diamond,tennis court,basketball court,ground track field,bridge,large vehicle,small vehicle,helicopter,swimming pool,roundabout,soccer ball field,plane,harbor

In [None]:

# Build three parallel lists from COLOR_MAP (same order as the dict): the bare
# class name, the CLIP prompt, and the RGBA overlay colour.
opacity = 255
hh_list = list(COLOR_MAP)
text_list = ["a photo of a {}".format(name) for name in hh_list]
color_list = [(rgb[0], rgb[1], rgb[2], opacity) for rgb in COLOR_MAP.values()]

print(hh_list)
print("".join("{},".format(name) for name in hh_list))
print(text_list)
print(color_list)

# One row per crop, one column per prompt.
scores = retriev(cropped_boxes, text_list)
indices = range(len(scores))
# Best-scoring prompt for every crop, computed once for the whole batch.
color_indices = scores.argmax(dim=1)

# NOTE(review): not used in the visible cells; kept in case another cell
# relies on the name.
segmentation_masks = []

original_image = Image.open(image_path)
# Size the overlay from the image it is composited onto, so this cell does not
# depend on the `image` variable of an earlier cell.
overlay_image = Image.new('RGBA', original_image.size, (0, 0, 0, 0))

draw = ImageDraw.Draw(overlay_image)
for i, class_index in zip(indices, color_indices.tolist()):
    # Paint the mask (0/255 bitmap) in the colour of its best-matching class.
    mask_bitmap = Image.fromarray(masks[i]["segmentation"].astype('uint8') * 255)
    draw.bitmap((0, 0), mask_bitmap, fill=color_list[class_index])

# Debug output for one sample crop. Guarded because images with 50 or fewer
# masks would otherwise raise IndexError before the result is saved.
debug_index = 50
if debug_index < len(scores):
    print(scores[debug_index].softmax(0))

result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
display(result_image)
result_image.save('P003_ans.png')

In [None]:
import os

# Write every crop to disk as 1.png, 2.png, ... in the order of cropped_boxes.
output_dir = './segmentResult3'
# Saving fails if the target folder is missing, so create it up front.
os.makedirs(output_dir, exist_ok=True)

cnt = 0  # stays 0 when there are no crops; otherwise ends as the crop count
for cnt, box in enumerate(cropped_boxes, start=1):
    img_path = os.path.join(output_dir, str(cnt) + '.png')
    box.save(img_path)

In [None]:
# Print the array shape of every generated segmentation mask.
for mask_record in masks:
    mask_shape = mask_record["segmentation"].shape
    print(mask_shape)