In [4]:
import clip
import torch
from PIL import Image
import numpy as np
import pandas as pd

# Load the open CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 141MiB/s]


In [6]:
!#configure paths
from pathlib import Path

unsplash_dataset_path = Path("/content/drive/My Drive/unsplash-dataset/curated-data")
features_path = unsplash_dataset_path / "features"

# Read the photos table
photos = pd.read_csv(unsplash_dataset_path / "unsplash-custom-metadata.tsv", sep='\t', header=0)

# Load the features and the corresponding IDs
photo_features = np.load(features_path / "features.npy")
photo_ids = pd.read_csv(features_path / "photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

In [7]:
def encode_search_query_with_text(search_query):
    with torch.no_grad():
        # Encode and normalize the search query using CLIP
        text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

        # Retrieve the feature vector from the GPU and convert it to a numpy array
        return text_encoded.cpu().numpy()


def encode_search_query_with_img(source_image):
    with torch.no_grad():
      image_feature = model.encode_image(preprocess(Image.open(source_image)).unsqueeze(0).to(device))
      image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

      return image_feature.cpu().numpy()

def find_best_matches_for_txt(text_features, photo_features, photo_ids, results_count=12):
  #compute similarity using Cosine similarity
  similarities = list((text_features @ photo_features.T).squeeze(0))
  #sort photos
  # best_photo_idx = (-similarities).argsort()
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

  # Return the photo IDs of the best matches
  # return [photo_ids[i] for i in best_photo_idx[:results_count]]
  return best_photos


def find_best_matches_for_img(image_features, photo_features, photo_ids, results_count=12):
  #compute similarity using Cosine similarity
  similarities = list((image_features @ photo_features.T).squeeze(0))

  #sort photos
  # best_photo_idx = (-similarities).argsort()
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

  # Return the photo IDs of the best matches
  # return [photo_ids[i] for i in best_photo_idx[:results_count]]
  return best_photos

**SAMPLE IMAGE QUERY 1:**

In [11]:
from PIL import Image
source_image = "/content/drive/My Drive/unsplash-dataset/lite/photos/01ZeHnK3F_4.jpg"

image_features = encode_search_query_with_img(source_image)
# print(embedding_result)
results = find_best_matches_for_txt(image_features, photo_features, photo_ids, 12)
print(results)

[(1.0000002, 80), (1.0000002, 1080), (0.89629686, 935), (0.89629686, 1935), (0.86283624, 851), (0.86283624, 1851), (0.86208534, 448), (0.86208534, 1448), (0.85777986, 254), (0.85777986, 1254), (0.85487974, 107), (0.85487974, 1107), (0.84704596, 856), (0.84704596, 1856), (0.8464372, 243), (0.8464372, 1243), (0.8454921, 235), (0.8454921, 1235), (0.84271, 846), (0.84271, 1846), (0.84268695, 38), (0.84268695, 1038), (0.8424051, 548), (0.8424051, 1548), (0.8423514, 111), (0.8423514, 1111), (0.84119785, 504), (0.84119785, 1504), (0.84107536, 220), (0.84107536, 1220), (0.8406843, 11), (0.8406843, 1011), (0.83986187, 455), (0.83986187, 1455), (0.83801126, 141), (0.83801126, 1141), (0.8373183, 238), (0.8373183, 1238), (0.83368486, 574), (0.83368486, 1574), (0.83234733, 303), (0.83234733, 1303), (0.83131874, 967), (0.83131874, 1967), (0.829472, 899), (0.829472, 1899), (0.8271506, 209), (0.8271506, 1209), (0.8260387, 271), (0.8260387, 1271), (0.8233384, 236), (0.8233384, 1236), (0.8217474, 61), (

In [12]:
from IPython.display import HTML, display, Image

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()
i = 0
while i < len(results) and len(displayed_photos) < 12:
    idx = results[i][1]
    photo_id = photo_ids[idx]
    if photo_id not in displayed_photos:
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        image_url = photo_data["photo_image_url"] + "?w=640"
        photographer = f'{photo_data["photographer_first_name"]} {photo_data["photographer_last_name"]}'

        image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

        # Add the displayed photo ID to the set
        displayed_photos.add(photo_id)

    i += 1

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))

**SAMPLE IMAGE QUERY 2:**

In [13]:
from PIL import Image
source_image = "/content/drive/My Drive/unsplash-dataset/lite/photos/0P5B9gn8ZQ8.jpg"

image_features = encode_search_query_with_img(source_image)
# print(embedding_result)
results = find_best_matches_for_txt(image_features, photo_features, photo_ids, 12)
print(results)

[(0.99999994, 193), (0.99999994, 1193), (0.93510246, 170), (0.93510246, 1170), (0.90809286, 430), (0.90809286, 1430), (0.9052464, 81), (0.9052464, 1081), (0.89020056, 813), (0.89020056, 1813), (0.87578225, 332), (0.87578225, 1332), (0.8706193, 792), (0.8706193, 1792), (0.869986, 977), (0.869986, 1977), (0.8617822, 325), (0.8617822, 1325), (0.86041206, 53), (0.86041206, 1053), (0.85918367, 280), (0.85918367, 1280), (0.85872746, 306), (0.85872746, 1306), (0.8577782, 928), (0.8577782, 1928), (0.8555966, 403), (0.8555966, 1403), (0.8546505, 752), (0.8546505, 1752), (0.85050726, 620), (0.85050726, 1620), (0.84982115, 247), (0.84982115, 1247), (0.84904605, 678), (0.84904605, 1678), (0.84846556, 82), (0.84846556, 1082), (0.8470636, 462), (0.8470636, 1462), (0.8455086, 930), (0.8455086, 1930), (0.84550804, 413), (0.84550804, 1413), (0.8448726, 474), (0.8448726, 1474), (0.84453976, 68), (0.84453976, 1068), (0.8425505, 173), (0.8425505, 1173), (0.83864063, 705), (0.83864063, 1705), (0.83812475, 

In [14]:
from IPython.display import HTML, display, Image

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()
i = 0
while i < len(results) and len(displayed_photos) < 12:
    idx = results[i][1]
    photo_id = photo_ids[idx]
    if photo_id not in displayed_photos:
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        image_url = photo_data["photo_image_url"] + "?w=640"
        photographer = f'{photo_data["photographer_first_name"]} {photo_data["photographer_last_name"]}'

        image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

        # Add the displayed photo ID to the set
        displayed_photos.add(photo_id)

    i += 1

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))