In [2]:
import os
import csv
from PIL import Image
import torch
from model import load_models, extract_features, find_most_similar_image, load_dataset_features

# Folder path of the dataset images and per-model feature file paths.
# The feature paths use Windows-style backslashes, so every one of them is a
# raw string: in a normal literal "\f" would be parsed as a form-feed
# character and silently corrupt the path.
IMAGE_FOLDER = r"dataset/image_matching/training_dataset"
SIMCLR_FEATURES_FILE = r'image-matching\features\simclr features.pkl'
DEIT_FEATURES_FILE = r'image-matching\features\deit features.pkl'
CLIP_FEATURES_FILE = r'image-matching\features\clip features.pkl'
CNN_FEATURES_FILE = r'image-matching\features\cnn features.pkl'

# Load all four models once at module level; load_models() is unpacked
# positionally in the order SimCLR, DeiT, CLIP, base CNN.
(
    simclr_model,
    deit_model,
    clip_model,
    base_cnn_model,
) = load_models()

# Function to perform search and save results to a CSV
def search_and_save_results(query_image_path, model_choice, output_csv):
    """Match a query image against the dataset and save the top hits to a CSV.

    The query image is converted to RGB, its features are extracted with the
    selected model, and the first five entries returned by
    ``find_most_similar_image`` are written as a single CSV row (paths and
    scores each joined with ", ").

    Args:
        query_image_path: Path of the query image to open with PIL.
        model_choice: One of 'simclr', 'deit', 'clip' or 'base_cnn'.
        output_csv: Path of the CSV file to write; it is opened in 'w' mode,
            so an existing file is overwritten.

    Raises:
        ValueError: If ``model_choice`` is not a supported name. This is
            checked before the image is opened, so a bad choice fails fast
            without doing any image or feature work.
    """
    # One entry per supported model: (model object, features file path).
    backends = {
        'simclr': (simclr_model, SIMCLR_FEATURES_FILE),
        'deit': (deit_model, DEIT_FEATURES_FILE),
        'clip': (clip_model, CLIP_FEATURES_FILE),
        'base_cnn': (base_cnn_model, CNN_FEATURES_FILE),
    }
    try:
        model, features_file = backends[model_choice]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which the original if/elif
        # chain also reported as an invalid choice.
        raise ValueError(
            "Invalid model choice. Choose from 'simclr', 'deit', 'clip', 'base_cnn'."
        ) from None

    # Load the query image; convert() returns a new image, so the underlying
    # file can be closed as soon as the conversion is done.
    with Image.open(query_image_path) as raw_image:
        query_image = raw_image.convert("RGB")

    # Extract the query features and obtain the dataset features for the
    # selected model.
    uploaded_features = extract_features(model, query_image, model_choice)
    dataset_features = load_dataset_features(
        IMAGE_FOLDER, model, model_choice, features_file
    )

    # Keep the first five matches (presumably ordered best-first by
    # find_most_similar_image -- verify against model.py).
    most_similar_images = find_most_similar_image(uploaded_features, dataset_features)[:5]

    # Each match is indexed as (image path, similarity score).
    similar_image_paths = [img[0] for img in most_similar_images]
    similarity_scores = [f"{img[1]:.4f}" for img in most_similar_images]

    # Write results to CSV; newline='' lets the csv module control line endings.
    with open(output_csv, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["query_image", "similar_images", "similarity_scores"])
        writer.writerow([
            query_image_path,
            ", ".join(similar_image_paths),
            ", ".join(similarity_scores),
        ])

    print(f"Results saved to {output_csv}")

# Example usage: runs a single query only when this file is executed as a script.
if __name__ == "__main__":
    # Raw string keeps the Windows-style backslashes literal; in a normal
    # literal "\i", "\q" and "\l" are invalid escape sequences that trigger a
    # SyntaxWarning on recent Python versions.
    query_image_path = r"evaluation\image-matching\query\listing_1_image_1.png_20240927_122438.png"  # Replace with the actual query image path
    model_choice = "clip"  # Choose from 'simclr', 'deit', 'clip', 'base_cnn'
    output_csv = "output_results.csv"  # Output CSV file
    search_and_save_results(query_image_path, model_choice, output_csv)


ModuleNotFoundError: No module named 'numpy'