### Retrieval performance
We compare the performance of the Xception and CLIP ViT models in terms of top-k accuracy.

In [None]:
from dataset import Street2ShopImageSimilarityTestDataset, evaluate_top_k_accuracies
from models.xception import XceptionModel
from models.clip_vit import CLIPViTModel

In [None]:
# Build the Xception embedding model and load its saved checkpoint.
# NOTE(review): this relies on `.load()` returning the model instance rather
# than None -- confirm in models/xception.
model = XceptionModel(embedding_dim=512).load('saved_models/xception_512.pth')
# Wrap the model in the street-to-shop test dataset.
# NOTE(review): `ratio=0.6` presumably selects a fraction of the data --
# confirm against Street2ShopImageSimilarityTestDataset.
test_dataset = Street2ShopImageSimilarityTestDataset(model, ratio=0.6)
# Sanity check: print the dataset size and its first item.
print(len(test_dataset))
print(test_dataset[0])

# Keep both results: the accuracies feed the bar chart and the visualization
# data feeds the example-retrieval plots in later cells.
xception_accuracies, xception_visualization_data = evaluate_top_k_accuracies(test_dataset)

In [None]:
# Build the CLIP ViT embedding model and load its saved checkpoint.
# This rebinds `model` and `test_dataset`, so any later cell that reads those
# names sees the CLIP ViT objects, not the Xception ones.
# NOTE(review): this relies on `.load()` returning the model instance rather
# than None -- confirm in models/clip_vit.
model = CLIPViTModel(embedding_dim=512).load('saved_models/clipvit_512.pth')
# Same dataset wrapper and `ratio` as the Xception run.
test_dataset = Street2ShopImageSimilarityTestDataset(model, ratio=0.6)
# Sanity check: print the dataset size and its first item.
print(len(test_dataset))
print(test_dataset[0])

# Keep both results: the accuracies feed the bar chart and the visualization
# data feeds the example-retrieval plots in later cells.
clipvit_accuracies, clipvit_visualization_data = evaluate_top_k_accuracies(test_dataset)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Accuracy keys to compare, listed in the order they appear on the x-axis
metrics = [
    'top_1_accuracy',
    'top_3_accuracy',
    'top_5_accuracy',
    'top_10_accuracy',
]
xception_values = [xception_accuracies[name] for name in metrics]
clipvit_values = [clipvit_accuracies[name] for name in metrics]

# One slot per metric; each slot holds two side-by-side bars
x = np.arange(len(metrics))
width = 0.35  # Width of a single bar
half_width = width / 2

fig, ax = plt.subplots(figsize=(10, 6))

# Shift each model's bars half a bar-width off the slot centre so the pair
# sits symmetrically around the tick
rects1 = ax.bar(
    x - half_width, xception_values, width,
    label='Xception', color='skyblue',
)
rects2 = ax.bar(
    x + half_width, clipvit_values, width,
    label='CLIP ViT', color='lightcoral',
)

# Axis text, ticks and legend
ax.set(ylabel='Accuracy', title='Model Performance Comparison')
ax.set_xticks(x)
ax.set_xticklabels(['Top-1', 'Top-3', 'Top-5', 'Top-10'])
ax.legend()

# Add value labels on top of bars
def autolabel(rects):
    """Annotate every bar in ``rects`` with its height as a percentage.

    Labels are drawn on the notebook-level ``ax``, centred horizontally on
    each bar and placed just above its top edge.

    Args:
        rects: Iterable of bar patches, e.g. the container returned by
            ``ax.bar``.
    """
    for bar in rects:
        bar_height = bar.get_height()
        bar_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f'{bar_height:.2%}',
            xy=(bar_center, bar_height),
            xytext=(0, 3),  # nudge the text 3 points above the bar top
            textcoords="offset points",
            ha='center',
            va='bottom',
            rotation=0,
        )

# Label both bar groups (Xception first, then CLIP ViT), then render
for bar_group in (rects1, rects2):
    autolabel(bar_group)

plt.tight_layout()
plt.show()

In [None]:
# Visualize example results
def plot_retrieval_results(dataset, query_idx, retrieved_indices, num_results=5):
    """
    Plot the query street photo and its top retrieved shop photos in one row.

    The figure gets one panel for the query plus one panel per retrieved
    photo actually shown, so a short ``retrieved_indices`` list leaves no
    empty panels. Figure width scales with the panel count (2.5 inches per
    panel, i.e. 15 inches for the default query + 5 results).

    Args:
        dataset: The test dataset instance. Its ``test_dataset`` attribute is
            indexed to fetch items exposing the ``'street_photo_image'`` and
            ``'shop_photo_image'`` keys.
        query_idx: Index of the query street photo
        retrieved_indices: Sequence of indices for retrieved shop photos.
            Panels are titled "Top-1", "Top-2", ... in the order given.
        num_results: Maximum number of retrieved photos to show
    """
    # Fetch the query first, then the retrieved photos (at most num_results)
    query_image = dataset.test_dataset[query_idx]['street_photo_image']
    retrieved_images = [
        dataset.test_dataset[idx]['shop_photo_image']
        for idx in retrieved_indices[:num_results]
    ]

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # panel, so the indexing below works for any number of results
    n_panels = len(retrieved_images) + 1
    fig, axes = plt.subplots(
        1, n_panels, figsize=(2.5 * n_panels, 3), squeeze=False
    )
    panels = axes[0]

    # Query image goes in the left-most panel
    panels[0].imshow(query_image)
    panels[0].set_title('Query\n(Street Photo)')
    panels[0].axis('off')

    # Retrieved images fill the remaining panels, in the order given
    for rank, (panel, image) in enumerate(zip(panels[1:], retrieved_images), 1):
        panel.imshow(image)
        panel.set_title(f'Top-{rank}\n(Shop Photo)')
        panel.axis('off')

    fig.tight_layout()
    plt.show()

# Show the first three example retrievals for each model, CLIP ViT first.
# NOTE(review): `test_dataset` is the CLIP ViT-backed instance at this point
# and is used for both models; that is only correct if the underlying
# `test_dataset.test_dataset` items do not depend on the model -- confirm.
showcases = (
    ("Example retrievals from CLIP ViT model:", clipvit_visualization_data),
    ("\nExample retrievals from Xception model:", xception_visualization_data),
)
for heading, visualization_data in showcases:
    print(heading)
    for query_idx, retrieved_indices in visualization_data[:3]:
        plot_retrieval_results(test_dataset, query_idx, retrieved_indices[0])
        print("-" * 80)

### Pipeline to enhance retrieval performance
Next, we build a pipeline to improve retrieval performance, inspired by Pinterest's image search pipeline.