In [11]:
import clip
import torch
from PIL import Image
import numpy as np
import pandas as pd

# Load the open CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [12]:
#configure paths
from pathlib import Path

unsplash_dataset_path = Path("/content/drive/My Drive/unsplash-dataset/curated-data")
features_path = unsplash_dataset_path / "features"

# Read the photos table
photos = pd.read_csv(unsplash_dataset_path / "unsplash-custom-metadata.tsv", sep='\t', header=0)

# Load the features and the corresponding IDs
photo_features = np.load(features_path / "features.npy")
photo_ids = pd.read_csv(features_path / "photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

In [13]:
def encode_search_query_with_text(search_query):
    with torch.no_grad():
        # Encode and normalize the search query using CLIP
        text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

        # Retrieve the feature vector from the GPU and convert it to a numpy array
        return text_encoded.cpu().numpy()


def encode_search_query_with_img(source_image):
    with torch.no_grad():
      image_feature = model.encode_image(preprocess(Image.open(source_image)).unsqueeze(0).to(device))
      image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

      return image_feature.cpu().numpy()

def find_best_matches_for_txt(text_features, photo_features, photo_ids, results_count=12):
  #compute similarity using Cosine similarity
  similarities = list((text_features @ photo_features.T).squeeze(0))
  #sort photos
  # best_photo_idx = (-similarities).argsort()
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

  # Return the photo IDs of the best matches
  # return [photo_ids[i] for i in best_photo_idx[:results_count]]
  return best_photos


def find_best_matches_for_img(image_features, photo_features, photo_ids, results_count=12):
  #compute similarity using Cosine similarity
  similarities = list((image_features @ photo_features.T).squeeze(0))

  #sort photos
  # best_photo_idx = (-similarities).argsort()
  best_photos = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)

  # Return the photo IDs of the best matches
  # return [photo_ids[i] for i in best_photo_idx[:results_count]]
  return best_photos

**SAMPLE TEXT and text-to-IMAGE QUERY 1:**

In [27]:
from PIL import Image
source_image = "/content/drive/My Drive/unsplash-dataset/curated-data/photos/0Hr2m3V_w1Q.jpg"
search_text = "no snow only green"

text_features= encode_search_query_with_text(search_text)
image_features= encode_search_query_with_img(source_image)

mixed_features = text_features + image_features
# print(embedding_result)
results = find_best_matches_for_txt(mixed_features, photo_features, photo_ids, 12)
print(results)

[(1.2406509, 782), (1.2406509, 1782), (1.1637863, 582), (1.1637863, 1582), (1.1346834, 964), (1.1346834, 1964), (1.1286073, 516), (1.1286073, 1516), (1.1260595, 332), (1.1260595, 1332), (1.1217403, 328), (1.1217403, 1328), (1.1120178, 639), (1.1120178, 1639), (1.1028626, 186), (1.1028626, 1186), (1.1005418, 752), (1.1005418, 1752), (1.1003553, 325), (1.1003553, 1325), (1.0931964, 823), (1.0931964, 1823), (1.0879605, 977), (1.0879605, 1977), (1.0873147, 615), (1.0873147, 1615), (1.0835937, 131), (1.0835937, 1131), (1.0804856, 424), (1.0804856, 1424), (1.0779033, 460), (1.0779033, 1460), (1.0738331, 459), (1.0738331, 1459), (1.0662386, 323), (1.0662386, 1323), (1.0602095, 760), (1.0602095, 1760), (1.0564735, 698), (1.0564735, 1698), (1.0536212, 620), (1.0536212, 1620), (1.0519043, 464), (1.0519043, 1464), (1.051744, 869), (1.051744, 1869), (1.0486367, 388), (1.0486367, 1388), (1.0480113, 642), (1.0480113, 1642), (1.0474973, 280), (1.0474973, 1280), (1.0461307, 129), (1.0461307, 1129), (1

In [28]:
from IPython.display import HTML, display, Image

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()
i = 0
while i < len(results) and len(displayed_photos) < 12:
    idx = results[i][1]
    photo_id = photo_ids[idx]
    if photo_id not in displayed_photos:
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        image_url = photo_data["photo_image_url"] + "?w=640"
        photographer = f'{photo_data["photographer_first_name"]} {photo_data["photographer_last_name"]}'

        image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

        # Add the displayed photo ID to the set
        displayed_photos.add(photo_id)

    i += 1

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))

**SAMPLE TEXT and text-to-IMAGE QUERY 2:**

In [29]:
from PIL import Image
source_image = "/content/drive/My Drive/unsplash-dataset/curated-data/photos/2Fs6MBBeayk.jpg"
search_text = "rocky mountains"

text_features= encode_search_query_with_text(search_text)
image_features= encode_search_query_with_img(source_image)

mixed_features =  image_features + text_features
# print(embedding_result)
results = find_best_matches_for_txt(mixed_features, photo_features, photo_ids, 12)
print(results)

[(1.2193798, 49), (1.2193798, 1049), (1.1464863, 171), (1.1464863, 1171), (1.1076661, 66), (1.1076661, 1066), (1.0993869, 480), (1.0993869, 1480), (1.0982395, 334), (1.0982395, 1334), (1.0846033, 98), (1.0846033, 1098), (1.084339, 933), (1.084339, 1933), (1.0731936, 360), (1.0731936, 1360), (1.0655466, 539), (1.0655466, 1539), (1.0635585, 358), (1.0635585, 1358), (1.0626488, 210), (1.0626488, 1210), (1.061558, 34), (1.061558, 1034), (1.0591525, 92), (1.0591525, 1092), (1.0521789, 3), (1.0521789, 1003), (1.0515772, 287), (1.0515772, 1287), (1.0508496, 546), (1.0508496, 1546), (1.0399197, 406), (1.0399197, 1406), (1.0396962, 341), (1.0396962, 1341), (1.038203, 9), (1.038203, 1009), (1.0355928, 238), (1.0355928, 1238), (1.0352267, 70), (1.0352267, 1070), (1.0330048, 421), (1.0330048, 1421), (1.0319163, 626), (1.0319163, 1626), (1.0290236, 141), (1.0290236, 1141), (1.0256445, 150), (1.0256445, 1150), (1.0254991, 369), (1.0254991, 1369), (1.0240079, 76), (1.0240079, 1076), (1.0221628, 671),

In [30]:
from IPython.display import HTML, display, Image

image_grid = '<div style="display: flex; flex-wrap: wrap;">'
displayed_photos = set()
i = 0
while i < len(results) and len(displayed_photos) < 12:
    idx = results[i][1]
    photo_id = photo_ids[idx]
    if photo_id not in displayed_photos:
        # Get all metadata for this photo
        photo_data = photos[photos["photo_id"] == photo_id].iloc[0]

        image_url = photo_data["photo_image_url"] + "?w=640"
        photographer = f'{photo_data["photographer_first_name"]} {photo_data["photographer_last_name"]}'

        image_grid += f'''
        <div style="margin: 10px; text-align: center;">
            <img src="{image_url}" style="width: 200px; height: 200px; object-fit: cover; border: 1px solid #ddd;">
            <p>Photo by <a href="https://unsplash.com/?">{photographer}</a></p>
        </div>
        '''

        # Add the displayed photo ID to the set
        displayed_photos.add(photo_id)

    i += 1

image_grid += '</div>'

# Display the image grid
display(HTML(image_grid))