In [None]:
import clip
import torch
from PIL import Image
import numpy as np
import pandas as pd

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-B/32 checkpoint onto that device; clip.load returns two
# objects, bound here as the model and its preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=device)

100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 186MiB/s]


In [None]:
#configure paths
from pathlib import Path

# Locations of the curated Unsplash dataset and its precomputed feature files
unsplash_dataset_path = Path("/content/drive/My Drive/unsplash-dataset/curated-data")
features_path = unsplash_dataset_path / "features"

# Photo metadata table: tab-separated, column names taken from the first row
photos = pd.read_csv(
    unsplash_dataset_path / "unsplash-custom-metadata.tsv",
    sep='\t',
    header=0,
)

# Precomputed feature matrix and the photo IDs stored next to it.
# NOTE(review): later cells index photo_ids with row numbers of photo_features,
# so the two are presumably aligned row-for-row -- confirm when regenerating.
photo_features = np.load(features_path / "features.npy")
photo_ids = list(pd.read_csv(features_path / "photo_ids.csv")['photo_id'])

In [None]:
def encode_search_query_with_text(search_query):
    """Embed a text query with CLIP and return the L2-normalized result.

    Args:
        search_query: Text passed to ``clip.tokenize``.

    Returns:
        The normalized text embedding, moved to the CPU and converted to a
        numpy array.
    """
    # Tokenize first, then move the tokens to the device the model runs on
    tokens = clip.tokenize(search_query).to(device)

    # Inference only: no gradients are needed for the embedding
    with torch.no_grad():
        embedding = model.encode_text(tokens)
        # Divide by the L2 norm taken along the last axis
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)

    return embedding.cpu().numpy()


def encode_search_query_with_img(source_image):
    """Embed a query image with CLIP and return the L2-normalized result.

    Args:
        source_image: Anything accepted by ``PIL.Image.open`` (the callers
            visible here would pass a file path).

    Returns:
        The normalized image embedding, moved to the CPU and converted to a
        numpy array.
    """
    # Imported locally under an alias: the result-display cells run
    # `from IPython.display import HTML, display, Image`, which rebinds the
    # global name `Image` and would make `Image.open` fail here.
    from PIL import Image as PILImage

    # Image.open is lazy, so the preprocessing (which reads the pixels) must
    # happen inside the `with`; the context manager then releases the file.
    with PILImage.open(source_image) as img:
        image_input = preprocess(img).unsqueeze(0).to(device)

    # Inference only: no gradients are needed for the embedding
    with torch.no_grad():
        image_feature = model.encode_image(image_input)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

    return image_feature.cpu().numpy()

def _rank_unique_photos(query_features, photo_features, photo_ids, results_count):
    """Rank photos by dot-product similarity to a single query vector.

    Shared implementation behind ``find_best_matches_for_txt`` and
    ``find_best_matches_for_img``.
    """
    if results_count is not None and results_count <= 0:
        return []

    # One score per row of photo_features; squeeze(0) drops the leading query
    # axis, so exactly one query vector is expected.
    # NOTE(review): the score is a true cosine similarity only if both the
    # query and the stored features are unit-normalized -- confirm upstream.
    similarities = (query_features @ photo_features.T).squeeze(0)

    # Highest score first; sorted() is stable, so ties keep ascending row order
    ranked_rows = sorted(
        range(photo_features.shape[0]),
        key=lambda row: similarities[row],
        reverse=True,
    )

    best_photos = []
    seen_ids = set()
    for row in ranked_rows:
        # The feature matrix can hold the same photo more than once; keep only
        # the best-scoring row per photo ID so `results_count` counts photos.
        photo_key = photo_ids[row] if photo_ids is not None else row
        if photo_key in seen_ids:
            continue
        seen_ids.add(photo_key)
        best_photos.append((similarities[row], row))
        if results_count is not None and len(best_photos) >= results_count:
            break

    return best_photos


def find_best_matches_for_txt(text_features, photo_features, photo_ids, results_count=12):
    """Return the best-matching photos for an encoded text query.

    Args:
        text_features: Query embedding with a leading axis of size 1.
        photo_features: Feature matrix with one row per stored photo.
        photo_ids: Photo IDs indexed by row number of ``photo_features``; used
            to drop duplicate photos. ``None`` disables de-duplication.
        results_count: Maximum number of results; ``None`` returns them all.

    Returns:
        A list of ``(similarity, row_index)`` tuples, best match first, with
        at most ``results_count`` entries and at most one entry per photo ID.
    """
    return _rank_unique_photos(text_features, photo_features, photo_ids, results_count)


def find_best_matches_for_img(image_features, photo_features, photo_ids, results_count=12):
    """Return the best-matching photos for an encoded image query.

    Args:
        image_features: Query embedding with a leading axis of size 1.
        photo_features: Feature matrix with one row per stored photo.
        photo_ids: Photo IDs indexed by row number of ``photo_features``; used
            to drop duplicate photos. ``None`` disables de-duplication.
        results_count: Maximum number of results; ``None`` returns them all.

    Returns:
        A list of ``(similarity, row_index)`` tuples, best match first, with
        at most ``results_count`` entries and at most one entry per photo ID.
    """
    return _rank_unique_photos(image_features, photo_features, photo_ids, results_count)

**SAMPLE QUERY 1: image search by combined text tags**

In [None]:
from PIL import Image
# Tags to combine into one search query
search_tag1 = "tree"
search_tag2 = "sky"
search_tag3 = "animal"

# Encode each tag separately (each embedding is normalized by the encoder)
text_features_fortag1 = encode_search_query_with_text(search_tag1)
text_features_fortag2 = encode_search_query_with_text(search_tag2)
text_features_fortag3 = encode_search_query_with_text(search_tag3)

# A sum of unit vectors is not unit length, so rescale the mixed query. This
# leaves the ranking unchanged but keeps the reported scores on the same scale
# no matter how many tags are mixed.
# NOTE(review): scores are cosine similarities only if photo_features is also
# unit-normalized -- confirm against the feature-extraction notebook.
mixed_features = text_features_fortag1 + text_features_fortag2 + text_features_fortag3
mixed_features = mixed_features / np.linalg.norm(mixed_features, axis=-1, keepdims=True)

results = find_best_matches_for_txt(mixed_features, photo_features, photo_ids, 12)
# Show only the head of the ranking instead of dumping every (score, index) pair
print(results[:12])

[(0.75822824, 813), (0.75822824, 1813), (0.7521702, 223), (0.7521702, 1223), (0.7446105, 82), (0.7446105, 1082), (0.7442293, 239), (0.7442293, 1239), (0.74319625, 316), (0.74319625, 1316), (0.74309325, 129), (0.74309325, 1129), (0.74011827, 321), (0.74011827, 1321), (0.7395566, 616), (0.7395566, 1616), (0.73562235, 829), (0.73562235, 1829), (0.7354013, 649), (0.7354013, 1649), (0.7331871, 520), (0.7331871, 1520), (0.73086655, 411), (0.73086655, 1411), (0.7305293, 155), (0.7305293, 1155), (0.727176, 887), (0.727176, 1887), (0.7263311, 377), (0.7263311, 1377), (0.7257677, 798), (0.7257677, 1798), (0.72053456, 530), (0.72053456, 1530), (0.72028786, 130), (0.72028786, 1130), (0.7196358, 859), (0.7196358, 1859), (0.7189603, 296), (0.7189603, 1296), (0.7149491, 247), (0.7149491, 1247), (0.7143646, 102), (0.7143646, 1102), (0.7138322, 58), (0.7138322, 1058), (0.71247476, 443), (0.71247476, 1443), (0.71120936, 752), (0.71120936, 1752), (0.70980215, 325), (0.70980215, 1325), (0.70911163, 706), 

In [None]:
import html

# `Image` is deliberately not imported from IPython.display: it is unused in
# this cell and would shadow the `Image` class imported from PIL at the top of
# the notebook.
from IPython.display import HTML, display

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()

# Walk the ranking best-first and stop once 12 distinct photos are shown
for similarity, idx in results:
    if len(displayed_photos) >= 12:
        break

    photo_id = photo_ids[idx]
    if photo_id in displayed_photos:
        # The same photo can appear under several feature rows; show it once
        continue

    # Get all metadata for this photo
    photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

    # Metadata comes from a data file, so escape it before embedding in HTML
    image_url = html.escape(photo_data["photo_image_url"] + "?w=640", quote=True)

    # Skip missing name parts so a NaN does not render as the text "nan"
    name_parts = (
        photo_data["photographer_first_name"],
        photo_data["photographer_last_name"],
    )
    photographer = html.escape(
        " ".join(str(part) for part in name_parts if pd.notna(part))
    )

    image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

    # Add the displayed photo ID to the set
    displayed_photos.add(photo_id)

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))

**SAMPLE QUERY 2: image search by combined text tags**

In [None]:
from PIL import Image
# Tags to combine into one search query
search_tag1 = "tower"
search_tag2 = "night"

# Encode each tag separately (each embedding is normalized by the encoder)
text_features_fortag1 = encode_search_query_with_text(search_tag1)
text_features_fortag2 = encode_search_query_with_text(search_tag2)

# A sum of unit vectors is not unit length, so rescale the mixed query. This
# leaves the ranking unchanged but keeps the reported scores on the same scale
# no matter how many tags are mixed.
# NOTE(review): scores are cosine similarities only if photo_features is also
# unit-normalized -- confirm against the feature-extraction notebook.
mixed_features = text_features_fortag1 + text_features_fortag2
mixed_features = mixed_features / np.linalg.norm(mixed_features, axis=-1, keepdims=True)

results = find_best_matches_for_txt(mixed_features, photo_features, photo_ids, 12)
# Show only the head of the ranking instead of dumping every (score, index) pair
print(results[:12])

[(0.49887043, 86), (0.49887043, 1086), (0.48654407, 16), (0.48654407, 1016), (0.4798858, 176), (0.4798858, 1176), (0.47382596, 530), (0.47382596, 1530), (0.47292075, 470), (0.47292075, 1470), (0.47080895, 649), (0.47080895, 1649), (0.4695251, 239), (0.4695251, 1239), (0.4687327, 26), (0.4687327, 1026), (0.468498, 256), (0.468498, 1256), (0.46806657, 151), (0.46806657, 1151), (0.46747676, 377), (0.46747676, 1377), (0.46705887, 129), (0.46705887, 1129), (0.4655376, 914), (0.4655376, 1914), (0.46536845, 864), (0.46536845, 1864), (0.46400258, 859), (0.46400258, 1859), (0.46337834, 296), (0.46337834, 1296), (0.46088654, 520), (0.46088654, 1520), (0.46071944, 948), (0.46071944, 1948), (0.459691, 673), (0.459691, 1673), (0.4582977, 153), (0.4582977, 1153), (0.4573453, 136), (0.4573453, 1136), (0.45680386, 972), (0.45680386, 1972), (0.45635116, 551), (0.45635116, 1551), (0.45608312, 411), (0.45608312, 1411), (0.45554495, 798), (0.45554495, 1798), (0.45540208, 899), (0.45540208, 1899), (0.45526

In [None]:
import html

# `Image` is deliberately not imported from IPython.display: it is unused in
# this cell and would shadow the `Image` class imported from PIL at the top of
# the notebook.
from IPython.display import HTML, display

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()

# Walk the ranking best-first and stop once 12 distinct photos are shown
for similarity, idx in results:
    if len(displayed_photos) >= 12:
        break

    photo_id = photo_ids[idx]
    if photo_id in displayed_photos:
        # The same photo can appear under several feature rows; show it once
        continue

    # Get all metadata for this photo
    photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

    # Metadata comes from a data file, so escape it before embedding in HTML
    image_url = html.escape(photo_data["photo_image_url"] + "?w=640", quote=True)

    # Skip missing name parts so a NaN does not render as the text "nan"
    name_parts = (
        photo_data["photographer_first_name"],
        photo_data["photographer_last_name"],
    )
    photographer = html.escape(
        " ".join(str(part) for part in name_parts if pd.notna(part))
    )

    image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

    # Add the displayed photo ID to the set
    displayed_photos.add(photo_id)

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))