In [None]:
from IPython.display import display, HTML
# Render an "Open In Colab" badge that links to the upstream SAM example notebook.
colab_badge_html = """
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
"""
display(HTML(colab_badge_html))

## Environment Set-up

In [None]:
# Gate for the Colab-only setup cell below (dependency install and checkpoint download).
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
    
    !mkdir images
    
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

PyTorch version: 2.0.0+cu118
Torchvision version: 0.15.1+cu118
CUDA is available: True
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-ddhf6m4k
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-ddhf6m4k
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment-anything
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment-anything: filename=segment_anything-1.0-py3-none-any.whl size=36610 sha256=785

## Set-up

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [None]:
def show_anns(anns):
    """Overlay mask annotations on the current matplotlib axes.

    Each element of ``anns`` is read through its ``'area'`` key (used for
    ordering) and its ``'segmentation'`` key (used as a boolean index into the
    overlay image). Masks are painted from largest to smallest area, each with
    a random RGB colour at alpha 0.35, so smaller masks end up on top.

    Does nothing and returns ``None`` when ``anns`` is empty.
    """
    if len(anns) < 1:
        return None

    largest_first = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    # Keep the axes limits fixed so the overlay does not rescale the plot.
    axes.set_autoscale_on(False)

    # RGBA canvas sized like the first mask; alpha 0 leaves uncovered pixels transparent.
    mask_shape = largest_first[0]['segmentation'].shape
    overlay = np.ones((mask_shape[0], mask_shape[1], 4))
    overlay[..., 3] = 0
    for ann in largest_first:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = rgba
    axes.imshow(overlay)

In [None]:
import sys
# Make the parent directory importable.
sys.path.append("..")
from tqdm import tqdm

# Mount Google Drive at /content/drive; later cells read pickles from and
# write figures to paths below this mount point.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Drive folder used for data files. Drive is mounted at '/content/drive', so
# the folder lives under that mount point (the previous '/drive/MyDrive/' did not).
path = '/content/drive/MyDrive/'
# Checkpoint file name, resolved relative to the current working directory.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU so the cell still runs on a
# runtime without CUDA instead of failing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the SAM model from the checkpoint and move it to the chosen device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic mask generator with its default settings; `generate(image)` is
# called on each image in a later cell.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
import os
# Switch the working directory to the Drive folder; relative paths used after
# this point resolve against it.
# NOTE(review): the recorded output of this cell shows a FileNotFoundError —
# confirm Drive is mounted before running it.
os.chdir("/content/drive/MyDrive")
# List the folder contents as a quick sanity check.
!ls

FileNotFoundError: ignored

To generate masks, just run `generate` on an image.

In [None]:
import pickle
# Drive folder prefix (with trailing slash) that file names are concatenated onto below.
path = "/content/drive/MyDrive/"
def load_pickle(filename):
    """Return the object deserialized from the pickle file at ``filename``.

    The file is opened in binary mode and closed again before returning.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    from a trusted source.
    """
    with open(filename, "rb") as handle:
        return pickle.load(handle)
# Load the two prompt pickles (captions and images) from the Drive folder.
dict_prompt_captions = load_pickle(f"{path}dict_prompt_captions.pickle")
dict_prompt_images = load_pickle(f"{path}dict_prompt_images.pickle")


      

In [None]:
# Install Hugging Face transformers, then import the CLIP model and its
# processor, which a later cell uses to score image crops against captions.
!pip install transformers
from transformers import CLIPProcessor, CLIPModel

In [None]:
# For every sampled key, segment both example images with SAM, score each
# mask crop against the example's caption with CLIP, and keep the masks whose
# crop matches the caption better than a generic "background" prompt.
import os

keys = load_pickle(path+'keys_rand_order.pickle')
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.to(device)

dict_model_input = load_pickle(path+'dict_model_input_int.pickle')

TOP_K_OBJ = 80                 # only the TOP_K_OBJ largest masks per image are considered
MIN_PREDICTED_IOU = 0.9        # masks below this SAM "predicted_iou" are skipped
MIN_STABILITY_SCORE = 0.84     # masks below this SAM "stability_score" are skipped
CAPTION_PROB_THRESHOLD = 0.85  # minimum caption-vs-background probability to keep a crop
MAX_SAVED_FIGURES = 30         # figures are saved only while pic < MAX_SAVED_FIGURES
k = 200                        # number of keys to process
i = 0                          # not used in this cell; kept in case other cells read it
pic = 0                        # running image counter, also used in figure file names
save_best_crops = {}

# Figures are written below `path`; create the folder so savefig cannot fail
# on a missing directory.
os.makedirs(path+'images', exist_ok=True)

for rand_key in keys[:k]:
    # Only the two example images and captions are used here; the query
    # image and query prompt are unpacked but not processed.
    img_ex1, text_ex1, img_ex2, text_ex2, img_q, text_q_prompt = dict_model_input[rand_key]
    for img, caption in zip([img_ex1, img_ex2], [text_ex1, text_ex2]):
        masks = mask_generator.generate(img)
        cur_masks_info = []
        crops = []
        pic = pic + 1
        # Ascending sort by area, then take the tail: the TOP_K_OBJ largest masks.
        for mask in sorted(masks, key=lambda mask: mask["area"])[-TOP_K_OBJ:]:
            if mask["predicted_iou"] < MIN_PREDICTED_IOU or mask["stability_score"] < MIN_STABILITY_SCORE:
                continue

            # Cast to int so the values are always valid slice bounds.
            x, y, w, h = (int(v) for v in mask["bbox"])
            if w <= 0 or h <= 0:
                # An empty crop cannot be preprocessed for CLIP.
                continue
            # Zero out everything outside the mask. The mask gets a trailing
            # axis so it broadcasts over the channels.
            # Assumes img is an (H, W, C) array — TODO confirm.
            masked = img * mask["segmentation"][..., np.newaxis]
            crop = masked[y : y + h, x : x + w]
            cur_masks_info.append(mask)
            crops.append(crop)

        # TODO: decide whether crops should be padded to a square before CLIP
        # (an earlier draft padded the shorter side with black borders).

        if crops:
            # Two competing prompts: the caption and a generic background description.
            # The list is passed flat so each string is scored as its own prompt.
            text_clip = [caption, 'This is a background']
            inputs = processor(text=text_clip, images=crops, return_tensors="pt", padding=True)
            # The model lives on `device`, so its inputs must be moved there too.
            inputs = inputs.to(device)
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                outputs = model(**inputs)
            logits_per_image = outputs.logits_per_image  # image-text similarity, one row per crop
            # Softmax over the prompts; omit it if only the raw similarity is wanted.
            probs = logits_per_image.softmax(dim=-1)
            print(probs.shape)
            # Column 0 is the caption. torch.where returns a 1-tuple here, so
            # take its only element and convert to plain Python ints.
            best_ind = torch.where(probs[:, 0] > CAPTION_PROB_THRESHOLD)[0].tolist()
        else:
            # No mask passed the quality filters, so there is nothing to score.
            best_ind = []
        print(best_ind)
        # NOTE(review): both example images share `rand_key`, so the entry
        # written for the first example is overwritten by the second one.
        # Confirm whether the results should be kept per image instead.
        save_best_crops[rand_key] = [cur_masks_info[idx] for idx in best_ind]

        # If both models do not fit in GPU memory together: delete sam and
        # empty the torch CUDA cache before running CLIP.

        if pic < MAX_SAVED_FIGURES:
            fig = plt.figure()
            # 2x2 grid. All four panels currently show the same image; they
            # look like placeholders for the prompt/query images — TODO confirm.
            ax11, ax12, ax21, ax22 = (fig.add_subplot(2, 2, pos) for pos in range(1, 5))
            for ax in (ax11, ax12, ax21, ax22):
                ax.axis('off')
                ax.imshow(img)
            fig.savefig(path+'images/full_figure_this'+str(pic)+'.png')
            # Show the figure inline, then release it so figures do not pile
            # up in memory across loop iterations.
            plt.show()
            plt.close(fig)