In [1]:
!pip install torch torchvision
!pip install git+https://github.com/openai/CLIP.git


Collecting torch
  Downloading torch-2.2.0-cp39-cp39-manylinux1_x86_64.whl.metadata (25 kB)
Collecting torchvision
  Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting filelock (from torch)
  Using cached filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting networkx (from torch)
  Using cached networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting jinja2 (from torch)
  Downloading Jinja2-3.1.3-py3-none-any.whl.metadata (3.3 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2023.12.2-py3-none-any.whl.metadata (6.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 1.7 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.1

In [3]:
import os
import csv
import clip
from PIL import Image
import torch

def load_clip_model(model_name="ViT-B/32", device=None):
    """Load a CLIP model together with its image preprocessing transform.

    Args:
        model_name: Model identifier forwarded to ``clip.load``. Defaults to
            ``"ViT-B/32"``, the variant this notebook originally hard-coded.
        device: Device to load the model onto. When ``None`` (the default),
            ``"cuda"`` is used if ``torch.cuda.is_available()`` and ``"cpu"``
            otherwise.

    Returns:
        A ``(model, preprocess, device)`` tuple: the two values returned by
        ``clip.load`` followed by the device that was used.
    """
    if device is None:
        # Prefer the GPU when one is visible to torch; fall back to CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load(model_name, device=device)
    return model, preprocess, device

def predict_choice(model, preprocess, device, image_path, text_descriptions):
    """Score candidate text descriptions against one image with CLIP.

    Args:
        model: CLIP model exposing ``encode_image``, ``encode_text`` and the
            learned ``logit_scale`` parameter.
        preprocess: Callable turning a PIL image into a tensor for the model.
        device: Device the image and token tensors are moved to.
        image_path: Path of the image file to open.
        text_descriptions: Sequence of candidate captions to tokenize.

    Returns:
        A 1-D NumPy array with one probability per entry of
        ``text_descriptions`` (softmax over the candidates, so it sums to 1).
    """
    # Open the image in a context manager so the file handle is released as
    # soon as preprocessing has produced the tensor.
    with Image.open(image_path) as raw_image:
        image = preprocess(raw_image).unsqueeze(0).to(device)
    text_tokens = clip.tokenize(text_descriptions).to(device)

    with torch.no_grad():
        # Encode image and text with the CLIP model
        image_features = model.encode_image(image)
        text_features = model.encode_text(text_tokens)

        # CLIP similarity is the *cosine* similarity scaled by the learned
        # temperature. Without L2 normalization the ranking depends on the
        # embedding norms, and without the scale the softmax is nearly flat.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits_per_image = model.logit_scale.exp() * image_features @ text_features.t()

        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # Only one image was encoded, so return its row of probabilities.
    return probs[0]




def _build_text_descriptions(text_description, choices):
    """Build one candidate caption per choice for a counting question.

    When ``text_description`` contains the standalone word "many", the words
    between it and the final two tokens are used as the item phrase (e.g.
    "How many cats are there?" -> "cats"). If "many" is absent, or the
    extracted phrase is empty, the generic phrase "items" is used.
    """
    words = text_description.split(" ")
    items_phrase = ""
    # Match "many" as a whole token: a substring test would also fire on
    # words that merely contain it (e.g. "Germany") and then fail the lookup.
    if "many" in words:
        start = words.index("many") + 1
        # Drop the last two tokens, presumably the trailing "are there?" of
        # the question -- TODO confirm against the dataset's phrasing.
        items_phrase = " ".join(words[start:-2]).strip()
    if not items_phrase:
        # Avoid malformed prompts such as "There are 3 ." when nothing usable
        # follows "many"; fall back to the default phrase.
        items_phrase = "items"
    return [f"There are {choice} {items_phrase}." for choice in choices]


def main():
    """Run CLIP over every row of ``clip_dataset/dataset.csv`` and print the best choice.

    Each CSV row is expected to provide ``image_filename``,
    ``text_description`` and ``button_texts`` (choices separated by ``|``).
    For every row the candidate captions are scored against the image and the
    highest-probability choice is printed together with its probability.
    """
    model, preprocess, device = load_clip_model()
    base_dir = "clip_dataset"
    images_dir = os.path.join(base_dir, "images")
    csv_file_path = os.path.join(base_dir, "dataset.csv")

    with open(csv_file_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.DictReader(file)
        for row in reader:
            image_path = os.path.join(images_dir, row["image_filename"])
            choices = row["button_texts"].split('|')

            # Generate one descriptive sentence per choice.
            text_descriptions = _build_text_descriptions(row["text_description"], choices)

            # Prediction with the CLIP model.
            probs = predict_choice(model, preprocess, device, image_path, text_descriptions)
            best_choice_index = int(probs.argmax())
            best_choice = choices[best_choice_index]
            # Treat empty or whitespace-only button texts as "no valid choice".
            selected_choice = best_choice if best_choice.strip() else 'No valid choice'
            print(f"Image: {row['image_filename']}, Best choice: {selected_choice}, Probability: {probs[best_choice_index]:.4f}")

if __name__ == "__main__":
    main()
