In [182]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import io
from tqdm import tqdm
import os
from IPython.display import display
import numpy as np
import torch

In [183]:
# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Single source of truth for the checkpoint so processor and model can't drift apart.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
# .eval() is explicit for readability; per the Transformers docs from_pretrained
# already returns the model in evaluation mode.
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device).eval()

In [184]:
def process_image(image_bytes):
    """Compute the CLIP image embedding for one encoded image.

    Args:
        image_bytes: Raw bytes of an encoded image file (any format PIL can open).

    Returns:
        torch.Tensor: Output of ``clip_model.get_image_features`` on ``device``,
        one row per input image (a single image here). It is computed under
        ``torch.no_grad()``, so it carries no autograd graph.
    """
    # convert("RGB") normalizes palette / grayscale / RGBA inputs to 3 channels;
    # the context manager closes the decoded source image once converted.
    with Image.open(io.BytesIO(image_bytes)) as img:
        image = img.convert("RGB")

    inputs = clip_processor(images=image, return_tensors="pt")
    # Only pixel_values is consumed by get_image_features, so only move that.
    pixel_values = inputs["pixel_values"].to(device)

    # Inference only: without no_grad every call would record an autograd graph,
    # wasting memory and time when embedding a whole folder.
    with torch.no_grad():
        image_features = clip_model.get_image_features(pixel_values=pixel_values)

    return image_features

In [185]:
def _find_image_files(folder_name, extensions):
    """Return paths of files under ``folder_name`` whose lowercased name ends with one of ``extensions``."""
    return [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(folder_name)
        for file in files
        if file.lower().endswith(extensions)
    ]


def process_all_images(folder_name, output_file, image_paths_file,
                       extensions=('.png', '.jpg', '.jpeg', '.bmp')):
    """Embed every image under ``folder_name`` and save embeddings plus their paths.

    Args:
        folder_name: Root directory, searched recursively with ``os.walk``.
        output_file: Destination for ``torch.save`` of the stacked embeddings.
            Row i corresponds to line i of ``image_paths_file``.
        image_paths_file: Text file written with one image path per line (UTF-8).
        extensions: Tuple of lowercase file suffixes to treat as images. It must
            be a tuple because it is passed to ``str.endswith``.

    Returns:
        list[str]: The embedded image paths, in the same order as the saved rows.
    """
    # Walk once, so the progress-bar total and the processed files always agree.
    image_paths = _find_image_files(folder_name, extensions)

    image_embeddings = []
    for item_path in tqdm(image_paths, desc="Processing images", leave=False):
        with open(item_path, "rb") as f:
            image_bytes = f.read()
        image_features = process_image(image_bytes)
        image_embeddings.append(image_features.detach().cpu().numpy())

    if image_embeddings:
        # One bulk stack -> (N, *per-image shape). torch.tensor(list of ndarrays)
        # produces the same result but is very slow and emits a UserWarning.
        image_embeddings_tensor = torch.from_numpy(np.stack(image_embeddings))
    else:
        # np.stack rejects an empty list; keep the previous empty-tensor result.
        image_embeddings_tensor = torch.tensor([])

    torch.save(image_embeddings_tensor, output_file)

    with open(image_paths_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{item_path}\n" for item_path in image_paths)

    return image_paths

In [186]:
# Bind to a real name: assigning to `_` shadows IPython's last-output variable.
image_paths = process_all_images(
    folder_name="data",
    output_file="image_embeddings.pt",
    image_paths_file="image_paths.txt",
)
len(image_paths)

                                                                                

In [None]:
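# TODO: text-to-image search
# Example pairing: image 122.png -> its embedding (122.pt)
#
# Plan: text query -> CLIP text embedding -> cosine similarity against each stored image embedding (.pt)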