In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive so that files
# stored on Drive can be read through ordinary paths under that directory.
# NOTE(review): this only works inside a Google Colab runtime and may prompt
# for authorization interactively.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
!pip install open_clip_torch

Collecting open_clip_torch
  Downloading open_clip_torch-2.20.0-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Collecting ftfy (from open_clip_torch)
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from open_clip_torch)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece (from open_clip_torch)
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
Collecting timm (from open_clip_torch)
  Downloading timm-0.9.5-py3-none-

In [None]:
import torch
from PIL import Image
import open_clip
import os

# Folder (on the mounted Drive) holding the candidate object images.
directory = '/content/drive/MyDrive/obj-images'

# Run on the GPU when the runtime has one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model = model.to(device)
model.eval()  # open_clip hands the model back in train mode; switch to inference mode
tokenizer = open_clip.get_tokenizer('ViT-B-32')

files = []
image_tensors = []

# sorted() makes the index -> file mapping (and therefore obj_list) reproducible;
# os.listdir() alone returns entries in arbitrary order.
for filename in sorted(os.listdir(directory)):
  f = os.path.join(directory, filename)
  if not os.path.isfile(f):
    continue  # skip sub-directories; Image.open would fail on them
  # The context manager closes the file handle once preprocess() has
  # materialised the pixel data into a tensor.
  with Image.open(f) as img:
    image_tensors.append(preprocess(img))
  files.append(f)

if not files:
  raise FileNotFoundError(f"No image files found in {directory}")

# Build the batch once instead of re-concatenating inside the loop.
images = torch.stack(image_tensors).to(device)

text = tokenizer(["a chicken flying over the cuckoo's nest", "a dog", "a cat"]).to(device)

# Mixed precision only makes sense on CUDA; on CPU the context is a no-op.
with torch.no_grad(), torch.autocast(device_type=device, enabled=(device == 'cuda')):
  image_features = model.encode_image(images)
  text_features = model.encode_text(text)
  # L2-normalise so the dot product below is a cosine similarity.
  image_features /= image_features.norm(dim=-1, keepdim=True)
  text_features /= text_features.norm(dim=-1, keepdim=True)

  # One row per text prompt, one column per image; the softmax runs over the
  # images, so each row is a distribution over candidate images for that prompt.
  obj_img_probs = (100.0 * text_features @ image_features.T).softmax(dim=-1)
  # Get object indices: for each prompt, the position in `files` of the best image.
  obj_list = torch.argmax(obj_img_probs, dim=1)

print("List of object images: ", files)
print("Label probs:", obj_img_probs)
print("List of primitive objects:", obj_list)

tensor([[[[ 1.9303,  1.9303,  1.9303,  ...,  1.9303,  1.9303,  1.9303],
          [ 1.9303,  1.9303,  1.9303,  ...,  1.9303,  1.9303,  1.9303],
          [ 1.9303,  1.9303,  1.9303,  ...,  1.9303,  1.9303,  1.9303],
          ...,
          [ 1.7990,  1.7990,  1.7990,  ...,  1.9303,  1.9303,  1.9303],
          [ 1.7990,  1.7698,  1.7698,  ...,  1.9303,  1.9303,  1.9303],
          [ 1.8135,  1.8135,  1.7844,  ...,  1.9303,  1.9303,  1.9303]],

         [[ 2.0749,  2.0749,  2.0749,  ...,  2.0749,  2.0749,  2.0749],
          [ 2.0749,  2.0749,  2.0749,  ...,  2.0749,  2.0749,  2.0749],
          [ 2.0749,  2.0749,  2.0749,  ...,  2.0749,  2.0749,  2.0749],
          ...,
          [ 1.9098,  1.8948,  1.9098,  ...,  2.0749,  2.0749,  2.0749],
          [ 1.8948,  1.8798,  1.8798,  ...,  2.0749,  2.0749,  2.0749],
          [ 1.8948,  1.8948,  1.8798,  ...,  2.0749,  2.0749,  2.0749]],

         [[ 2.1459,  2.1459,  2.1459,  ...,  2.1459,  2.1459,  2.1459],
          [ 2.1459,  2.1459,  