In [92]:
import torch
# from transformers import CLIPModel, CLIPProcessor, ViTModel
from transformers import AutoModel, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

clip_model_name = "openai/clip-vit-base-patch16"
clip = AutoModel.from_pretrained(clip_model_name)
clip_processor = AutoProcessor.from_pretrained(clip_model_name)

vit = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k")

clip

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [93]:
vit

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

In [94]:
from PIL import Image
import torch

snippets = [
    "a photo of free movement",
    "a photo of a dog",
    "a photo of a cat",
  ]

img = Image.open("cat.jpeg")
inputs = clip_processor(
  text=snippets,
  images=img,
  return_tensors="pt",
  padding=True,
).to(device)

with torch.no_grad():
    outputs = clip(**inputs, output_hidden_states=True)

outputs

CLIPOutput(loss=None, logits_per_image=tensor([[20.8677, 22.0518, 28.0532]]), logits_per_text=tensor([[20.8677],
        [22.0518],
        [28.0532]]), text_embeds=tensor([[ 0.0050, -0.0213,  0.0039,  ...,  0.0360, -0.0059,  0.0128],
        [ 0.0282, -0.0117,  0.0112,  ..., -0.0110,  0.0240,  0.0283],
        [ 0.0413, -0.0037,  0.0096,  ...,  0.0059, -0.0040,  0.0131]]), image_embeds=tensor([[ 6.2535e-02, -3.7395e-02, -4.6840e-02,  5.6853e-03,  1.7240e-02,
          1.8315e-02, -3.3103e-03,  6.0352e-02,  3.6465e-02, -1.9146e-02,
          1.3047e-02,  2.3058e-03,  1.2887e-02,  1.3924e-02,  1.8829e-02,
         -5.2959e-03, -2.7521e-02, -1.1035e-02, -2.8547e-02, -2.0736e-02,
         -1.5262e-02,  1.2065e-02,  2.2423e-02, -4.2020e-02, -1.4927e-02,
          2.4974e-02, -1.7222e-02,  4.7236e-03,  1.7965e-02,  2.6333e-02,
          3.7878e-03,  2.7034e-02,  1.1925e-02,  7.9192e-03, -6.6373e-03,
         -1.7795e-03,  1.1701e-02,  1.0027e-02,  5.8458e-02,  1.7746e-02,
          3.8956e-

In [95]:
outputs.keys()

odict_keys(['logits_per_image', 'logits_per_text', 'text_embeds', 'image_embeds', 'text_model_output', 'vision_model_output'])

In [96]:
# image and text embeds should be in same dimension
outputs.image_embeds.shape, outputs.text_embeds.shape

(torch.Size([1, 512]), torch.Size([3, 512]))

In [97]:
# see last encoder layers
img_hidden = outputs.vision_model_output.hidden_states
text_hidden = outputs.text_model_output.hidden_states
img_hidden[-1].shape, text_hidden[-1].shape

(torch.Size([1, 197, 768]), torch.Size([3, 7, 512]))

In [98]:
len(img_hidden), len(text_hidden)

(13, 13)

In [99]:
# we expect these to match
img_last_hidden = outputs.vision_model_output.last_hidden_state
img_hidden[-1] == img_last_hidden

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [100]:
probs = outputs.logits_per_image.softmax(dim=-1)
probs

tensor([[7.5508e-04, 2.4673e-03, 9.9678e-01]])

In [101]:
pred_idx = probs.argmax().item()
pred = snippets[pred_idx]
pred

'a photo of a cat'