In [122]:
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
import torch
import torch.nn as nn
from PIL import Image
import pathlib
from typing import *
import pandas as pd
import shutil
import os

processor = AutoProcessor.from_pretrained("geolocal/StreetCLIP")

model = AutoModelForZeroShotImageClassification.from_pretrained("geolocal/StreetCLIP")

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [123]:
model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05,

In [124]:
processor

CLIPProcessor:
- image_processor: CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "processor_class": "CLIPProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}

- tokenizer: CLIPTokenizerFast(name_or_path='geolocal/StreetCLIP', vocab_size=49408, model_max_length=77, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=

In [125]:
town_name = "Minowacho"
predict_town = True
town = pathlib.Path(f"data/addrs/{town_name}")

maps = {
    "1-chome": "〒223-0061 神奈川県横浜市港北区日吉1丁目",
    "2-chome": "〒223-0061 神奈川県横浜市港北区日吉2丁目",
    "3-chome": "〒223-0061 神奈川県横浜市港北区日吉3丁目",
    "4-chome": "〒223-0061 神奈川県横浜市港北区日吉4丁目",
    "5-chome": "〒223-0061 神奈川県横浜市港北区日吉5丁目",
    "6-chome": "〒223-0061 神奈川県横浜市港北区日吉6丁目",
    "7-chome": "〒223-0061 神奈川県横浜市港北区日吉7丁目",
}

labels = []

if predict_town:
    labels = [
        "Hiyoshi",
        "Hiyoshihoncho",
        "Minowacho"
    ]
else:
    for folder in sorted(list(town.iterdir())):
        name = folder.name
        # if maps.get(name) is not None:
            # labels.append(maps[name])
        if name not in maps:
            continue
        labels.append(name)
print(labels)

['日吉', '日吉本町', '箕輪町']


In [126]:
def predict(labels: List[str], img: torch.Tensor):
    inputs = processor(
        text=labels,
        images=img,
        return_tensors="pt",
        padding=True
    )
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1)
    print("probs:", probs)
    index = torch.argmax(probs, dim=1).item()
    return labels[index]

In [127]:

path = pathlib.Path(f"data/addrs/{town_name}/3-chome")
err_img = Image.open(pathlib.Path("data/error_img.png").as_posix()).convert("RGB")

for img_path in path.iterdir():
    img = Image.open(img_path.as_posix()).convert("RGB")
    # if err_img == img:
    #     os.remove(img_path)
    #     print("removed:", img_path.as_posix())
    #     continue
    print("img:", img_path.name)
    ret = predict(labels, img=img)
    print("result", ret)

img: 35.54571017-139.64016186.png
probs: tensor([[0.0653, 0.6180, 0.3167]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54520193-139.64016186.png
probs: tensor([[0.0173, 0.6211, 0.3616]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54520193-139.64168843.png
probs: tensor([[0.0232, 0.7503, 0.2266]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54774315-139.64321500.png
probs: tensor([[0.0323, 0.6986, 0.2691]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54723490-139.64321500.png
probs: tensor([[0.0184, 0.3875, 0.5941]], grad_fn=<SoftmaxBackward0>)
result 箕輪町
img: 35.54571017-139.64168843.png
probs: tensor([[0.0302, 0.5802, 0.3896]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54977612-139.64474157.png
probs: tensor([[0.0221, 0.6759, 0.3020]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54977612-139.64525043.png
probs: tensor([[0.0204, 0.5316, 0.4480]], grad_fn=<SoftmaxBackward0>)
result 日吉本町
img: 35.54825139-139.64168843.png
probs: tensor([[0.0546, 0.1661, 0.7793]

KeyboardInterrupt: 