In [1]:
!pip install torch torchvision
!pip install git+https://github.com/openai/CLIP.git


Collecting torch
  Downloading torch-2.2.0-cp39-cp39-manylinux1_x86_64.whl.metadata (25 kB)
Collecting torchvision
  Downloading torchvision-0.17.0-cp39-cp39-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting filelock (from torch)
  Using cached filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting networkx (from torch)
  Using cached networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting jinja2 (from torch)
  Downloading Jinja2-3.1.3-py3-none-any.whl.metadata (3.3 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2023.12.2-py3-none-any.whl.metadata (6.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.1

In [1]:
import os
import csv
import clip
import torch
from PIL import Image

def load_clip_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    return model, preprocess, device

def predict_choice(model, preprocess, device, image_path, text_descriptions):
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    text = clip.tokenize(text_descriptions).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

        # 類似度スコアの計算（修正部分）
        logits_per_image = (image_features @ text_features.T).softmax(dim=-1)
        probs = logits_per_image.cpu().numpy()

    return probs[0]


def main():
    model, preprocess, device = load_clip_model()
    base_dir = "clip_dataset"
    images_dir = os.path.join(base_dir, "images")
    csv_file_path = os.path.join(base_dir, "dataset.csv")

    with open(csv_file_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.DictReader(file)
        for row in reader:
            image_path = os.path.join(images_dir, row["image_filename"])
            text_description = row["text_description"]
            button_texts = row["button_texts"].split('|')
            text_descriptions = [f"{text_description} {btn_text}" for btn_text in button_texts]

            probs = predict_choice(model, preprocess, device, image_path, text_descriptions)
            best_choice_index = probs.argmax()
            print(f"Image: {row['image_filename']}, Best choice: {button_texts[best_choice_index]}, Probability: {probs[best_choice_index]:.4f}")

if __name__ == "__main__":
    main()


Image: image_0001.png, Best choice: 1, Probability: 0.2991
Image: image_0002.png, Best choice:  , Probability: 0.1423
Image: image_0003.png, Best choice: 9, Probability: 0.1183
Image: image_0004.png, Best choice: 0, Probability: 0.2292
Image: image_0005.png, Best choice: 2, Probability: 0.1296
Image: image_0006.png, Best choice: 9, Probability: 0.1404
Image: image_0007.png, Best choice: 9, Probability: 0.1422
Image: image_0008.png, Best choice: 8, Probability: 0.1743
Image: image_0009.png, Best choice: 9, Probability: 0.1554
Image: image_0010.png, Best choice: 8, Probability: 0.2101
Image: image_0011.png, Best choice: 1, Probability: 0.1572
Image: image_0012.png, Best choice: 0, Probability: 0.2242
Image: image_0013.png, Best choice: 1, Probability: 0.2257
Image: image_0014.png, Best choice:  , Probability: 0.1493
Image: image_0015.png, Best choice: 1, Probability: 0.1587
Image: image_0016.png, Best choice: 5, Probability: 0.1846
Image: image_0017.png, Best choice: 1, Probability: 0.17

In [2]:
import os
import csv
import clip
from PIL import Image
import torch

def load_clip_model():
    # Load the CLIP model from OpenAI
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    return model, preprocess, device

def predict_choice(model, preprocess, device, image_path, text_descriptions):
    # Preprocess the image and tokenize the text descriptions
    image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
    text_tokens = clip.tokenize(text_descriptions).to(device)

    with torch.no_grad():
        # Encode image and text with the CLIP model
        image_features = model.encode_image(image)
        text_features = model.encode_text(text_tokens)

        # Calculate the similarity scores and convert to probabilities
        logits_per_image = (image_features @ text_features.T).softmax(dim=-1)
        probs = logits_per_image.cpu().numpy()

    return probs[0]




def main():
    model, preprocess, device = load_clip_model()
    base_dir = "clip_dataset"
    images_dir = os.path.join(base_dir, "images")
    csv_file_path = os.path.join(base_dir, "dataset.csv")

    with open(csv_file_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.DictReader(file)
        for row in reader:
            image_path = os.path.join(images_dir, row["image_filename"])
            text_description = row["text_description"]
            choices = row["button_texts"].split('|')

            # 'many' が text_description に含まれているかをチェック
            if "many" in text_description:
                parts = text_description.split(" ")
                try:
                    items_index = parts.index("many") + 1
                    items_phrase = " ".join(parts[items_index:-2]) 
                except ValueError:
                    # 'many' の後に 'are there' が見つからない場合はスキップ
                    print(f"Error in text description for image: {row['image_filename']}")
                    continue

                # 各選択肢に対する説明文を生成
                text_descriptions = [f"There are {choice} {items_phrase}." for choice in choices]
            else:
                # 'many' がなければ、デフォルトのフレーズを使用
                text_descriptions = [f"There are {choice} items." for choice in choices]

            # CLIPモデルによる予測
            probs = predict_choice(model, preprocess, device, image_path, text_descriptions)
            best_choice_index = probs.argmax()
            selected_choice = choices[best_choice_index] if choices[best_choice_index] != ' ' else 'No valid choice'
            print(f"Image: {row['image_filename']}, Best choice: {selected_choice}, Probability: {probs[best_choice_index]:.4f}")

if __name__ == "__main__":
    main()


Image: image_0001.png, Best choice: 1, Probability: 0.2666
Image: image_0002.png, Best choice: 2, Probability: 0.3506
Image: image_0003.png, Best choice: 2, Probability: 0.4395
Image: image_0004.png, Best choice: 2, Probability: 0.2939
Image: image_0005.png, Best choice: 3, Probability: 0.2769
Image: image_0006.png, Best choice: 3, Probability: 0.1641
Image: image_0007.png, Best choice: 3, Probability: 0.5615
Image: image_0008.png, Best choice: 4, Probability: 0.2568
Image: image_0009.png, Best choice: 2, Probability: 0.3362
Image: image_0010.png, Best choice: 2, Probability: 0.3718
Image: image_0011.png, Best choice: 6, Probability: 0.2817
Image: image_0012.png, Best choice: 0, Probability: 0.2903
Image: image_0013.png, Best choice: 8, Probability: 0.2285
Image: image_0014.png, Best choice: No valid choice, Probability: 0.1887
Image: image_0015.png, Best choice: No valid choice, Probability: 0.1416
Image: image_0016.png, Best choice: 2, Probability: 0.3442
Image: image_0017.png, Best 