In [1]:
import os
import torch
from PIL import Image
from transformers import Owlv2Processor, Owlv2Model
from tqdm import tqdm

In [8]:
DAM_DIR = "./data/DAM"
TEST_DIR = "./data/test_image_headmind"
OUTPUT_FILE = "matching_items.csv"

In [3]:
processor = Owlv2Processor.from_pretrained("google/owlvit-base-patch32")
model = Owlv2Model.from_pretrained("google/owlvit-base-patch32")
model.eval()


preprocessor_config.json:   0%|          | 0.00/392 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


tokenizer_config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.42k [00:00<?, ?B/s]

You are using a model of type owlvit to instantiate a model of type owlv2. This is not supported for all configurations of models and can yield errors.


model.safetensors:   0%|          | 0.00/613M [00:00<?, ?B/s]

Some weights of Owlv2Model were not initialized from the model checkpoint at google/owlvit-base-patch32 and are newly initialized: ['logit_scale', 'text_model.embeddings.position_embedding.weight', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.v_proj.bias

Owlv2Model(
  (text_model): Owlv2TextTransformer(
    (embeddings): Owlv2TextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(16, 512)
    )
    (encoder): Owlv2Encoder(
      (layers): ModuleList(
        (0-11): 12 x Owlv2EncoderLayer(
          (self_attn): Owlv2Attention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): Owlv2MLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps

In [4]:
def extract_features(image_path):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs).squeeze(0)
    return features
# Step 1: Extract features for DAM images
print("Extracting features for DAM images...")
dam_embeddings = {}
for file in tqdm(os.listdir(DAM_DIR)):
    if file.endswith(".jpeg"):
        file_path = os.path.join(DAM_DIR, file)
        dam_embeddings[file] = extract_features(file_path)

Extracting features for DAM images...


100%|██████████| 2766/2766 [1:10:36<00:00,  1.53s/it]


In [9]:
# Match each test image to a DAM image
print("Matching test images to DAM references...")
results = []
for file in tqdm(os.listdir(TEST_DIR)):
    if file.endswith((".jpeg", ".jpg")):  # Check for both extensions
        query_path = os.path.join(TEST_DIR, file)
        query_features = extract_features(query_path)
        
        # Compare with all DAM embeddings
        best_match = None
        best_score = -float("inf")
        for dam_file, dam_features in dam_embeddings.items():
            score = torch.nn.functional.cosine_similarity(query_features, dam_features, dim=0).item()
            if score > best_score:
                best_score = score
                best_match = dam_file
        
        results.append((file, best_match, best_score))

# Save results to a CSV
print("Saving results...")
with open(OUTPUT_FILE, "w") as f:
    f.write("Query Image,Matched Reference,Score\n")
    for query, match, score in results:
        f.write(f"{query},{match},{score:.4f}\n")

print(f"Results saved to {OUTPUT_FILE}.")

Matching test images to DAM references...


100%|██████████| 80/80 [02:44<00:00,  2.06s/it]

Saving results...
Results saved to matching_items.csv.





In [12]:
from PIL import Image

def visualize_match(query_image_path, reference_image_path):
    query_image = Image.open(query_image_path)
    reference_image = Image.open(reference_image_path)
    
    # Combine images side by side
    combined = Image.new("RGB", (query_image.width + reference_image.width, max(query_image.height, reference_image.height)))
    combined.paste(query_image, (0, 0))
    combined.paste(reference_image, (query_image.width, 0))
    combined.show()

# Example usage
visualize_match("./data/test_image_headmind/IMG_6901.jpg", "./data/DAM/124S55AM128X0872.jpeg")
