In [27]:
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
import torch
import torch.nn as nn
from PIL import Image

processor = AutoProcessor.from_pretrained("geolocal/StreetCLIP")

model = AutoModelForZeroShotImageClassification.from_pretrained("geolocal/StreetCLIP")

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [28]:
model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05,

In [29]:
processor

CLIPProcessor:
- image_processor: CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "processor_class": "CLIPProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}

- tokenizer: CLIPTokenizerFast(name_or_path='geolocal/StreetCLIP', vocab_size=49408, model_max_length=77, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=

In [30]:
img = Image.open("./japan_001.jpg")
inputs = processor(
    text=["Japan", "China", "United States", "Canada", "Germany", "United Kingdom"],
    images=img,
    return_tensors="pt",
    padding=True
)
inputs

{'input_ids': tensor([[49406,  3400, 49407, 49407],
        [49406,  2817, 49407, 49407],
        [49406,  2690,  4218, 49407],
        [49406,  2698, 49407, 49407],
        [49406,  4464, 49407, 49407],
        [49406,  2690,  7364, 49407]]), 'attention_mask': tensor([[1, 1, 1, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1]]), 'pixel_values': tensor([[[[ 0.2661,  0.2661,  0.2661,  ..., -0.1572, -0.1572, -0.1572],
          [ 0.2807,  0.2661,  0.2661,  ..., -0.1426, -0.1426, -0.1426],
          [ 0.2807,  0.2807,  0.2807,  ..., -0.1426, -0.1426, -0.1426],
          ...,
          [-0.3324, -1.4565, -1.4273,  ..., -1.4419, -1.6463, -1.6317],
          [ 0.1931, -0.4492, -1.0039,  ..., -1.3543, -1.4419, -1.5149],
          [ 0.6749,  0.5435,  0.1347,  ..., -1.2959, -1.4419, -1.5295]],

         [[ 0.9343,  0.9343,  0.9343,  ...,  0.5441,  0.5441,  0.5441],
          [ 0.9493,  0.9343,  0.9343,  ...,  0.5591,  0.5591,  0.55

In [31]:
outputs = model(**inputs)

In [32]:
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1) 

In [33]:
probs

tensor([[9.9925e-01, 1.5090e-04, 3.2259e-04, 1.7586e-04, 3.6761e-05, 6.3483e-05]],
       grad_fn=<SoftmaxBackward0>)