In [47]:
import clip
import torch
from PIL import Image
import numpy as np
import pandas as pd

# Load the open CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [48]:
#configure paths
from pathlib import Path

unsplash_dataset_path = Path("/content/drive/My Drive/unsplash-dataset/curated-data")
features_path = unsplash_dataset_path / "features"

# Read the photos table
photos = pd.read_csv(unsplash_dataset_path / "unsplash-custom-metadata.tsv", sep='\t', header=0)

# Load the features and the corresponding IDs
photo_features = np.load(features_path / "features.npy")
photo_ids = pd.read_csv(features_path / "photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

In [49]:
def encode_search_query_with_text(search_query):
    with torch.no_grad():
        # Encode and normalize the search query using CLIP
        text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

        # Retrieve the feature vector from the GPU and convert it to a numpy array
        return text_encoded.cpu().numpy()


def encode_search_query_with_img(source_image):
    with torch.no_grad():
      image_feature = model.encode_image(preprocess(Image.open(source_image)).unsqueeze(0).to(device))
      image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

      return image_feature.cpu().numpy()

def find_best_matches_for_txt(text_features, photo_features, photo_ids, results_count=12):
  #compute similarity using Cosine similarity
  similarities = list((text_features @ photo_features.T).squeeze(0))
  #sort photos
  # best_photo_idx = (-similarities).argsort()
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

  # Return the photo IDs of the best matches
  # return [photo_ids[i] for i in best_photo_idx[:results_count]]
  return best_photos


def find_best_matches_for_img(image_features, photo_features, photo_ids, results_count=12):
  #compute similarity using Cosine similarity
  similarities = list((image_features @ photo_features.T).squeeze(0))

  #sort photos
  # best_photo_idx = (-similarities).argsort()
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

  # Return the photo IDs of the best matches
  # return [photo_ids[i] for i in best_photo_idx[:results_count]]
  return best_photos

**SAMPLE SEARCH QUERY 1 :**

In [50]:
search_query = "sky and tree"
text_features = encode_search_query_with_text(search_query)
# print(embedding_result)
results = find_best_matches_for_txt(text_features, photo_features, photo_ids, 12)
print(results)

[(0.2785635, 257), (0.2785635, 1257), (0.27304146, 616), (0.27304146, 1616), (0.26252246, 239), (0.26252246, 1239), (0.2624064, 813), (0.2624064, 1813), (0.26218143, 177), (0.26218143, 1177), (0.26068452, 411), (0.26068452, 1411), (0.26046553, 247), (0.26046553, 1247), (0.2583352, 130), (0.2583352, 1130), (0.25722945, 791), (0.25722945, 1791), (0.25622356, 930), (0.25622356, 1930), (0.2530767, 954), (0.2530767, 1954), (0.25166693, 223), (0.25166693, 1223), (0.25136182, 19), (0.25136182, 1019), (0.25105318, 887), (0.25105318, 1887), (0.24977273, 520), (0.24977273, 1520), (0.24838474, 321), (0.24838474, 1321), (0.2470679, 934), (0.2470679, 1934), (0.24680445, 82), (0.24680445, 1082), (0.2466847, 719), (0.2466847, 1719), (0.24667928, 668), (0.24667928, 1668), (0.2461365, 980), (0.2461365, 1980), (0.2443912, 706), (0.2443912, 1706), (0.24419133, 39), (0.24419133, 1039), (0.24389856, 876), (0.24389856, 1876), (0.24360614, 111), (0.24360614, 1111), (0.2434469, 176), (0.2434469, 1176), (0.241

In [54]:
from IPython.display import HTML, display, Image

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()
i = 0
while i < len(results) and len(displayed_photos) < 12:
    idx = results[i][1]
    photo_id = photo_ids[idx]
    if photo_id not in displayed_photos:
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        image_url = photo_data["photo_image_url"] + "?w=640"
        photographer = f'{photo_data["photographer_first_name"]} {photo_data["photographer_last_name"]}'

        image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

        # Add the displayed photo ID to the set
        displayed_photos.add(photo_id)

    i += 1

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))

**SAMPLE SEARCH QUERY 2 :**

In [63]:
search_query = "snowscape mountains"
text_features = encode_search_query_with_text(search_query)
# print(embedding_result)
results = find_best_matches_for_txt(text_features, photo_features, photo_ids, 12)
print(results)

[(0.2959162, 483), (0.2959162, 1483), (0.29492772, 148), (0.29492772, 1148), (0.2905812, 513), (0.2905812, 1513), (0.29055712, 443), (0.29055712, 1443), (0.28861347, 553), (0.28861347, 1553), (0.28831023, 786), (0.28831023, 1786), (0.2877366, 310), (0.2877366, 1310), (0.28675038, 982), (0.28675038, 1982), (0.28650936, 74), (0.28650936, 1074), (0.2845649, 226), (0.2845649, 1226), (0.2817307, 894), (0.2817307, 1894), (0.28087232, 546), (0.28087232, 1546), (0.2808476, 735), (0.2808476, 1735), (0.28003037, 581), (0.28003037, 1581), (0.27863485, 519), (0.27863485, 1519), (0.2765683, 461), (0.2765683, 1461), (0.27628586, 61), (0.27628586, 1061), (0.27615833, 806), (0.27615833, 1806), (0.27552134, 339), (0.27552134, 1339), (0.274112, 797), (0.274112, 1797), (0.27300602, 955), (0.27300602, 1955), (0.27211016, 952), (0.27211016, 1952), (0.27192348, 236), (0.27192348, 1236), (0.2717014, 336), (0.2717014, 1336), (0.2707968, 496), (0.2707968, 1496), (0.26961896, 527), (0.26961896, 1527), (0.268269

In [64]:
from IPython.display import HTML, display, Image

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()
i = 0
while i < len(results) and len(displayed_photos) < 12:
    idx = results[i][1]
    photo_id = photo_ids[idx]
    if photo_id not in displayed_photos:
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        image_url = photo_data["photo_image_url"] + "?w=640"
        photographer = f'{photo_data["photographer_first_name"]} {photo_data["photographer_last_name"]}'

        image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

        # Add the displayed photo ID to the set
        displayed_photos.add(photo_id)

    i += 1

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))