# Semantic Search on WikiArt Dataset using CLIP

This notebook demonstrates how to implement semantic search functionality using the CLIP model on the WikiArt dataset. We'll implement both text-to-image and image-to-image search capabilities.

## Setup and Dependencies
First, let's install and import all necessary packages.

In [None]:
!pip install torch torchvision open_clip_torch Pillow numpy pandas matplotlib tqdm requests

In [None]:
import torch
import open_clip
from PIL import Image
import requests
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load CLIP Model
We'll use the OpenCLIP implementation and load a pre-trained model.

In [None]:
# Load the model, its inference-time image transform, and the matching tokenizer.
# The architecture name is defined once so the tokenizer cannot drift from the model.
MODEL_NAME = 'ViT-B-32'
PRETRAINED_TAG = 'laion2b_s34b_b79k'

# create_model_and_transforms returns (model, training transform, inference transform);
# only the inference transform is used in this notebook.
model, _train_preprocess, preprocess = open_clip.create_model_and_transforms(
    MODEL_NAME, pretrained=PRETRAINED_TAG
)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

model = model.to(device)
# Switch to evaluation mode; left as the last expression so the cell still
# displays the model summary.
model.eval()

## Download and Prepare WikiArt Dataset
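We'll set up a small metadata table and download the images it lists. The metadata below is only a placeholder: substitute real WikiArt filenames and image URLs before running the rest of the notebook.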

In [None]:
def download_image(url, save_path, timeout=10):
    """Download ``url`` to ``save_path`` and report whether it succeeded.

    Args:
        url: Address of the image to fetch.
        save_path: Destination file (``str`` or ``Path``).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``True`` if the server answered with HTTP 200 and the whole body was
        written to ``save_path``; ``False`` on any other status code, on a
        network error, or on a file-system error.
    """
    save_path = Path(save_path)
    # Stream into a sibling temp file and rename only on success, so an
    # interrupted transfer never leaves a truncated image at save_path.
    partial_path = save_path.with_name(save_path.name + '.part')
    try:
        # The context manager releases the connection even on early return.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                return False
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        partial_path.replace(save_path)
        return True
    except (requests.RequestException, OSError):
        # Best-effort download: report failure instead of raising, but only
        # for network / file-system errors (not KeyboardInterrupt or bugs).
        if partial_path.exists():
            partial_path.unlink()
        return False

# Directory that holds the downloaded images
data_dir = Path('wikiart_data')
data_dir.mkdir(parents=True, exist_ok=True)

# Placeholder metadata (nothing is fetched here): one entry per image with the
# local filename, the source URL, and descriptive fields.
# Note: In a real implementation, you would need to obtain the actual WikiArt dataset
sample_data = {
    'filename': ['sample1.jpg', 'sample2.jpg'],
    'url': ['https://example.com/sample1.jpg', 'https://example.com/sample2.jpg'],
    'artist': ['Artist1', 'Artist2'],
    'title': ['Artwork1', 'Artwork2']
}

df = pd.DataFrame(sample_data)

# Download images, skipping files that are already on disk and remembering
# which downloads failed so the problem is visible instead of silent.
failed_downloads = []
for row in tqdm(df.itertuples(index=False), total=len(df)):
    save_path = data_dir / row.filename
    if save_path.exists():
        continue
    if not download_image(row.url, save_path):
        failed_downloads.append(row.filename)

if failed_downloads:
    print(f'Warning: could not download {len(failed_downloads)} of {len(df)} images: {failed_downloads}')

## Feature Extraction
Extract and store CLIP features for all images in the dataset.

In [None]:
def extract_features(image_path):
    """Encode one image file with the CLIP image encoder.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The output of ``model.encode_image`` for a batch containing just this
        image, moved to the CPU and converted to a NumPy array.
    """
    # Image.open is lazy and keeps the file open; convert() decodes the pixels
    # into a new image, so the handle can be released right after it returns.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')

    # unsqueeze(0) adds the batch dimension in front of the preprocessed image.
    batch = preprocess(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
    return features.cpu().numpy()

# Build the search index: filename -> CLIP features, for every image found on disk
image_features = {}
for filename in tqdm(df['filename'], total=len(df)):
    image_path = data_dir / filename
    if not image_path.exists():
        # Image was never downloaded; leave it out of the index.
        continue
    image_features[filename] = extract_features(image_path)

## Implement Search Functions
Create functions for both text-to-image and image-to-image search.

In [None]:
def _rank_by_cosine_similarity(query_features, n_results):
    """Rank the images in ``image_features`` by cosine similarity to a query.

    Args:
        query_features: Query embedding; if 2-D, only the first row is used.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(filename, similarity)`` tuples, best match first. Empty
        when no image features have been indexed.
    """
    if not image_features:
        return []

    filenames = list(image_features)
    # Stored features are batch-of-one arrays; keep the first row of each.
    index = np.stack([
        np.atleast_2d(np.asarray(image_features[name], dtype=np.float64))[0]
        for name in filenames
    ])
    query = np.atleast_2d(np.asarray(query_features, dtype=np.float64))[0]

    # L2-normalise both sides so the dot product is a cosine similarity and the
    # ranking is not biased by embedding magnitude. The floor avoids 0/0.
    eps = 1e-12
    index = index / np.maximum(np.linalg.norm(index, axis=1, keepdims=True), eps)
    query = query / max(float(np.linalg.norm(query)), eps)

    scores = (index @ query).tolist()
    ranked = sorted(zip(filenames, scores), key=lambda item: item[1], reverse=True)
    return ranked[:n_results]


def text_to_image_search(text_query, n_results=5):
    """Find the indexed images that best match a text description.

    Args:
        text_query: Text to search for.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(filename, cosine_similarity)`` tuples, best match first.
    """
    # Encode text query
    with torch.no_grad():
        tokens = tokenizer(text_query).to(device)
        text_features = model.encode_text(tokens).cpu().numpy()

    return _rank_by_cosine_similarity(text_features, n_results)


def image_to_image_search(query_image_path, n_results=5):
    """Find the indexed images most similar to a query image.

    Args:
        query_image_path: Path of the image to use as the query.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(filename, cosine_similarity)`` tuples, best match first.
    """
    query_features = extract_features(query_image_path)
    return _rank_by_cosine_similarity(query_features, n_results)

## Visualization Functions
Create functions to display search results.

In [None]:
def display_results(results):
    """Show search results side by side, each titled with its name and score.

    Args:
        results: Sequence of ``(filename, similarity)`` tuples; each filename
            is resolved relative to ``data_dir``.
    """
    n = len(results)
    if n == 0:
        # plt.subplots cannot build a grid with zero columns.
        print('No results to display.')
        return

    # squeeze=False always yields a 2-D axes array, so a single result needs
    # no special case.
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)

    for ax, (filename, similarity) in zip(axes[0], results):
        # Close the file as soon as the pixels have been handed to matplotlib.
        with Image.open(data_dir / filename) as img:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title(f'{filename}\nSimilarity: {similarity:.3f}')

    plt.tight_layout()
    plt.show()

## Example Usage
Demonstrate how to use the semantic search functionality.

In [None]:
# Example text-to-image search
text_query = "a beautiful landscape painting with mountains"
results = text_to_image_search(text_query)
print(f'Search results for: "{text_query}"')
if results:
    display_results(results)
else:
    # Nothing was indexed (e.g. every download failed), so there is nothing to plot.
    print('No indexed images to search.')

# Example image-to-image search
query_image = data_dir / 'sample1.jpg'  # Replace with an actual image path
if query_image.exists():
    results = image_to_image_search(query_image)
    print(f'Similar images to: {query_image.name}')
    if results:
        display_results(results)
    else:
        print('No indexed images to search.')
else:
    # Opening a missing query image would raise; report it instead.
    print(f'Query image not found: {query_image}')