# 中文CLIP图片搜索演示

In [1]:
import os
# Restrict this process to GPU index 6. The variable is set before `import torch`
# below so that it is already in place when CUDA gets initialised.
# NOTE(review): the index is machine-specific; adjust or remove on other hosts.
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
import sys
# Put the parent directory first on the import path so that the `clip` package
# imported below can be resolved (presumably it lives in `..` — confirm).
sys.path.insert(0, '..')
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from clip.clip import build_model
from clip.simple_tokenizer import SimpleCharTokenizer
from clip.data import tokenize

In [2]:
# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"device: {device}")

device: cuda


In [3]:
# Deserialize the checkpoint onto the CPU first. Without `map_location`,
# torch.load tries to restore each tensor onto the device it was saved from,
# which fails when that device is unavailable (CPU-only host, or a CUDA index
# hidden by CUDA_VISIBLE_DEVICES).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('../data/clip_zh.mse.pt', map_location='cpu')

# Build the model from the loaded weights and move it to the selected device.
model = build_model(state_dict).to(device)

embed_dim: 512
vocab_size: 8000
text_projection shape: torch.Size([512, 512])


In [4]:
# Load the photo IDs
photo_ids = pd.read_csv("../data/unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the features vectors
photo_features = np.load("../data/unsplash-dataset/features.npy")

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")

Photos loaded: 1981161


In [5]:
# Tokenizer built from the vocabulary file; it is passed to `tokenize` whenever
# a search query is encoded.
zh_tokenizer = SimpleCharTokenizer('../data/zh_vocab.txt')

Loading vocab from ../data/zh_vocab.txt


In [6]:
def encode_search_query(search_query):
    """Encode a text query with the model's text encoder and L2-normalize it.

    Relies on the notebook-level `model`, `zh_tokenizer` and `device`.

    Args:
        search_query: Query text handed to `tokenize` together with `zh_tokenizer`.

    Returns:
        A NumPy array holding the encoded query, scaled so that its norm along
        the last axis is 1.
    """
    # Tokenize and move the tokens to the same device as the model.
    token_ids = tokenize(zh_tokenizer, search_query).to(device)

    # Inference only: gradients are not needed.
    with torch.no_grad():
        embedding = model.encode_text(token_ids)
        unit_embedding = embedding / embedding.norm(dim=-1, keepdim=True)

    # Hand the result back as a CPU NumPy array.
    return unit_embedding.cpu().numpy()

In [7]:
def find_best_matches(text_features, photo_features, photo_ids, results_count=3):
    """Return the IDs of the photos most similar to a query vector.

    Similarity is the dot product between each row of `photo_features` and the
    query vector; this equals cosine similarity when both sides are
    L2-normalized (assumed — the function does not normalize).

    Args:
        text_features: Query features as a 2-D array with a single row, (1, dim).
        photo_features: 2-D array (n_photos, dim), one feature row per photo.
        photo_ids: Sequence of photo IDs aligned with the rows of `photo_features`.
        results_count: Maximum number of IDs to return. `None` returns every
            photo; values larger than the number of photos are clamped, and
            non-positive values yield an empty list.

    Returns:
        List of photo IDs ordered from most to least similar.
    """
    # One similarity score per photo: (n_photos, 1) -> (n_photos,)
    similarities = (photo_features @ text_features.T).squeeze(1)

    total = similarities.shape[0]
    top_k = total if results_count is None else max(0, min(results_count, total))
    if top_k == 0:
        return []

    # Select the top_k candidates in linear time instead of sorting every
    # score, then order only those few candidates by descending similarity.
    negated = -similarities
    candidates = np.argpartition(negated, top_k - 1)[:top_k]
    ranked = candidates[np.argsort(negated[candidates])]

    # Return the photo IDs of the best matches
    return [photo_ids[i] for i in ranked]

In [8]:
from IPython.display import Image
from IPython.core.display import HTML

def display_photo(photo_id):
    """Show an Unsplash photo followed by an attribution link.

    Args:
        photo_id: Unsplash photo ID interpolated into the image and page URLs.
    """
    # Image fetched from the Unsplash download URL, requested with w=320.
    thumbnail = Image(url=f"https://unsplash.com/photos/{photo_id}/download?w=320")

    # Attribution link pointing at the photo's Unsplash page.
    credit = HTML(f'Photo on <a target="_blank" href="https://unsplash.com/photos/{photo_id}">Unsplash</a> ')

    for rich_output in (thumbnail, credit):
        display(rich_output)

    # Blank line to separate consecutive photos.
    print()

## 以文搜图

In [9]:
def search_unslash(search_query, photo_features, photo_ids, results_count=3):
    """Search the photo collection with a text query and display the best hits.

    Args:
        search_query: Query text to encode.
        photo_features: Photo feature matrix passed to `find_best_matches`.
        photo_ids: Photo IDs passed to `find_best_matches`.
        results_count: Number of photos to display.
    """
    # Encode the query and rank the photos against it in one step.
    ranked_ids = find_best_matches(
        encode_search_query(search_query),
        photo_features,
        photo_ids,
        results_count,
    )

    # Show every hit, best first.
    for position in range(len(ranked_ids)):
        display_photo(ranked_ids[position])

In [10]:
# Query text: "two dogs playing in the snow"
search_query = "两只狗在雪地里玩耍"
search_unslash(search_query, photo_features=photo_features, photo_ids=photo_ids, results_count=3)










In [11]:
# Query text: "love written on a wall"
search_query = "写在墙上的爱"
search_unslash(search_query, photo_features=photo_features, photo_ids=photo_ids, results_count=3)










In [12]:
# Query text: "the feeling when your code finally works"
search_query = "当你的代码终于成功时的感受"
search_unslash(search_query, photo_features=photo_features, photo_ids=photo_ids, results_count=3)










In [13]:
# Query text: "the Sydney Opera House and the Harbour Bridge at night"
search_query = "晚上的悉尼歌剧院和哈勃桥"
search_unslash(search_query, photo_features=photo_features, photo_ids=photo_ids, results_count=3)










In [14]:
# Query text: "the Sydney Opera House and a blue, blue sky"
search_query = "悉尼歌剧院和蓝蓝的天空"
search_unslash(search_query, photo_features=photo_features, photo_ids=photo_ids, results_count=3)










## 以图搜图

In [15]:
def search_by_photo(query_photo_id, photo_ids, query_num=3):
    """Display the photos most similar to an existing photo in the collection.

    Reads the notebook-level `photo_features` array for both the query vector
    and the candidate set.

    Args:
        query_photo_id: ID of the photo to use as the query; must be in `photo_ids`.
        photo_ids: List of photo IDs aligned with the rows of `photo_features`.
        query_num: Maximum number of similar photos to display.

    Raises:
        ValueError: If `query_photo_id` is not present in `photo_ids`.
    """
    # Find the feature vector for the specified photo ID
    query_photo_index = photo_ids.index(query_photo_id)
    # Keep the query 2-D, (1, dim): find_best_matches transposes it and squeezes axis 1.
    query_photo_features = np.expand_dims(photo_features[query_photo_index], axis=0)
    print(query_photo_features.shape)

    # Ask for one extra match because the query photo is normally its own best match.
    candidate_ids = find_best_matches(query_photo_features, photo_features, photo_ids, query_num + 1)

    # Drop the query itself, then cap at query_num. Without the cap, one photo
    # too many was shown whenever the query was not among the returned matches.
    similar_ids = [pid for pid in candidate_ids if pid != query_photo_id][:query_num]

    for photo_id in similar_ids:
        display_photo(photo_id)

In [16]:
def show_img_by_id(photo_id):
    """Display the image for `photo_id`, fetched from its Unsplash download URL (w=320)."""
    image_url = f"https://unsplash.com/photos/{photo_id}/download?w=320"
    display(Image(url=image_url))

In [17]:
# Show the photo that is used as the query in the next cell.
show_img_by_id(photo_id='QP-l1uE19iI')

In [18]:
# Look up the three photos most similar to the query photo.
search_by_photo(query_photo_id='QP-l1uE19iI', photo_ids=photo_ids, query_num=3)

(1, 512)











## 图文混合搜图

In [19]:
def search_by_text_and_photo(query_text, query_photo_id, photo_weight=0.2):
    """Search with a text query blended with the features of an existing photo.

    Reads the notebook-level `photo_features` and `photo_ids`. Displays, in
    order: the best text-only match, the query photo, and the best match for
    the combined query.

    Args:
        query_text: Text query to encode.
        query_photo_id: ID of the photo whose features are mixed into the query;
            must be present in `photo_ids`.
        photo_weight: Factor applied to the photo features before they are added
            to the text features.

    Raises:
        ValueError: If `query_photo_id` is not present in `photo_ids`.
    """
    # Encode the search query once; the vector is reused for the text-only search below.
    text_features = encode_search_query(query_text)

    # Find the feature vector for the specified photo ID
    query_photo_index = photo_ids.index(query_photo_id)
    query_photo_features = photo_features[query_photo_index]

    # Combine the text and photo queries and normalize again
    search_features = text_features + query_photo_features * photo_weight
    search_features = search_features / np.linalg.norm(search_features, axis=-1, keepdims=True)

    # Find the best match for the text-only and for the combined query
    text_only_ids = find_best_matches(text_features, photo_features, photo_ids, 1)
    best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 1)

    # Display the results
    # NOTE(review): "Test" below is probably a typo for "Text"; the string is
    # left unchanged so that previously recorded outputs still match.
    print("Test search result")
    for photo_id in text_only_ids:
        display_photo(photo_id)

    print("Photo query")
    show_img_by_id(query_photo_id)

    print("Result for text query + photo query")
    display_photo(best_photo_ids[0])

In [20]:
# Query text: "Sydney Opera House", blended with a photo at the default photo weight.
search_by_text_and_photo(query_text="悉尼歌剧院", query_photo_id="MaerUPAjPbs")

Test search result



Photo query


Result for text query + photo query





In [21]:
# Query text: "Sydney Opera House", blended with a photo weighted at 0.3.
search_by_text_and_photo(query_text="悉尼歌剧院", query_photo_id="1pNBJ2zUfn4", photo_weight=0.3)

Test search result



Photo query


Result for text query + photo query





In [22]:
# Query text: "Sydney Opera House", blended with a photo weighted at 0.5.
search_by_text_and_photo(query_text="悉尼歌剧院", query_photo_id="jnBDclcdZ7A", photo_weight=0.5)

Test search result



Photo query


Result for text query + photo query



