# SBERT and Semantic Search

In this notebook, we demonstrate how SBERT (the `sentence-transformers` library) and semantic search can be used to find a GIF by describing it in a simple sentence.

We begin by installing and importing the relevant libraries and downloading the TGIF dataset.

In [None]:
%%capture
# Install the dependencies quietly and download the TGIF (Tumblr GIF) dataset
!pip install -U pandas pinecone-client sentence-transformers tqdm -q
!wget https://github.com/raingo/TGIF-Release/archive/master.zip
!unzip master.zip


import pandas as pd
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import glob
import torch
import pickle
import zipfile
from IPython.display import display
from IPython.display import Image as IPImage  # renders a GIF from its URL
import os
from tqdm.autonotebook import tqdm
torch.set_num_threads(4)

from IPython.display import HTML
from IPython.core.interactiveshell import InteractiveShell
# Show the output of every expression in a cell, not just the last one
InteractiveShell.ast_node_interactivity = "all"




We then load the dataset into a pandas DataFrame. The file is tab-separated and has no header row, so we supply the column names ourselves. To keep both loading and encoding fast, we only read the first 10,000 data points.

In [None]:
df = pd.read_csv(
    "./TGIF-Release-master/data/tgif-v1.0.tsv",
    delimiter="\t",
    names=['url', 'description'],  # the file has no header row
    nrows=10000,                    # only parse the first 10,000 data points
)
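Two `read_csv` parameters are useful for this step: passing `names` tells pandas that the file has no header row (so the first line is kept as data), and `nrows` limits how many rows are parsed at all. The minimal sketch below shows both on a few hypothetical in-memory rows in the same tab-separated layout; the example.com URLs are placeholders, so it needs neither the download nor the real file.

```python
import io

import pandas as pd

# Hypothetical rows in the same tab-separated, header-less layout as tgif-v1.0.tsv
sample_tsv = (
    "https://example.com/a.gif\ta cat tries to catch a mouse\n"
    "https://example.com/b.gif\ta man dressed in red is dancing.\n"
    "https://example.com/c.gif\tan animal comes close to another\n"
)

# Because `names` is given, pandas treats the first line as data, not as a header.
# `nrows` stops parsing after two rows, so the third line is never read.
sample = pd.read_csv(
    io.StringIO(sample_tsv),
    delimiter="\t",
    names=["url", "description"],
    nrows=2,
)
print(sample.shape)  # (2, 2)
```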

In [None]:
df.head()

|   | url | description |
|---|-----|-------------|
| 0 | https://38.media.tumblr.com/9f6c25cc350f12aa74... | a man is glaring, and someone with sunglasses ... |
| 1 | https://38.media.tumblr.com/9ead028ef62004ef6a... | a cat tries to catch a mouse on a tablet |
| 2 | https://38.media.tumblr.com/9f43dc410be85b1159... | a man dressed in red is dancing. |
| 3 | https://38.media.tumblr.com/9f659499c8754e40cf... | an animal comes close to another in the jungle |
| 4 | https://38.media.tumblr.com/9ed1c99afa7d714118... | a man in a hat adjusts his tie and makes a wei... |


Then we load the CLIP model.

The 'clip-ViT-B-32' model is a version of CLIP (Contrastive Language-Image Pre-training), a model developed by OpenAI that learns visual concepts from natural language descriptions. Its image encoder is a Vision Transformer (ViT) of the Base size ("B") that splits each image into 32×32-pixel patches ("32"); the name does not refer to a training batch size. CLIP has already been trained on a large dataset of image-text pairs to map images and text into a shared embedding space, in which a sentence and an image with related content end up close together. This is what lets it produce meaningful embeddings for sentences and phrases that can be compared with each other, or with images.
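The "/32" in the model name can be made concrete with a little arithmetic. The sketch below assumes CLIP's standard 224×224 input resolution for ViT-B/32: the image is cut into non-overlapping 32×32 patches, each patch becomes one token for the transformer, and one extra class token is prepended.

```python
# Sketch of how ViT-B/32 turns an image into tokens,
# assuming a 224x224 input (CLIP's default resolution for this model)
image_size = 224   # input height and width in pixels
patch_size = 32    # the "32" in ViT-B/32: side length of each square patch

patches_per_side = image_size // patch_size  # 224 / 32 = 7
num_patches = patches_per_side ** 2          # 7 x 7 = 49 patch tokens
sequence_length = num_patches + 1            # plus one class token

print(patches_per_side, num_patches, sequence_length)  # 7 49 50
```

A smaller patch size (as in ViT-B/16) would produce more tokens per image, which is slower but captures finer detail.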

In [None]:
%%capture
model = SentenceTransformer('clip-ViT-B-32')

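In the next step, we encode the text in the "description" column of the DataFrame. Note that we embed the GIF descriptions, not the GIF frames themselves: the variable is called `img_emb`, but it holds text embeddings.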

In [None]:
img_emb = model.encode(df['description'].tolist(), convert_to_tensor=True, show_progress_bar=True)

Batches:   0%|          | 0/313 [00:00<?, ?it/s]

Setting `convert_to_tensor=True` makes `encode` return a PyTorch tensor instead of a NumPy array, with one row per description. The tensor is printed below.

In [None]:
img_emb

tensor([[-0.1049,  0.1396, -0.0901,  ...,  0.0105,  0.2349, -0.0315],
        [ 0.1592, -0.0597, -0.5004,  ...,  0.4274, -0.2427,  0.0760],
        [ 0.2001,  0.0212, -0.1524,  ..., -0.0444, -0.0652, -0.0686],
        ...,
        [ 0.0620,  0.0174, -0.2394,  ..., -0.2104,  0.0838, -0.3832],
        [ 0.0455,  0.0706, -0.1184,  ...,  0.2117, -0.4585,  0.1548],
        [-0.1170,  0.3385,  0.1898,  ..., -0.0747,  0.1751, -0.2460]])
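The progress bar above reports 313 batches. Assuming `encode` runs with its default `batch_size` of 32, that number follows directly from the 10,000 descriptions, and since clip-ViT-B-32 produces 512-dimensional embeddings, the tensor should have one 512-value row per description. A small sketch of that arithmetic:

```python
import math

num_descriptions = 10_000  # rows kept from the TGIF TSV
batch_size = 32            # assumed: the default batch_size of SentenceTransformer.encode
embedding_dim = 512        # output dimension of clip-ViT-B-32

# The last batch is only partially full, hence the ceiling
num_batches = math.ceil(num_descriptions / batch_size)
expected_shape = (num_descriptions, embedding_dim)

print(num_batches, expected_shape)  # 313 (10000, 512)
```

In the notebook itself, `img_emb.shape` should report `torch.Size([10000, 512])`.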

In the next code cell, we define the `search` function.

It performs a similarity search for a given query, which can be either a text string or an image, since CLIP can embed both.

First, the query is encoded into an embedding vector, `query_emb`.

Then the `util.semantic_search` function computes the cosine similarity between `query_emb` and every description embedding in `img_emb`, and returns the `top_k` best matches, sorted by decreasing score, as dictionaries with a `corpus_id` and a `score`. Because `semantic_search` accepts a batch of queries, it returns one list of hits per query; we take the first (and only) list with `[0]`. We set `k=3` by default, so the function displays the three most similar GIFs.
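To make the ranking step concrete, here is a minimal pure-Python sketch of cosine-similarity top-k retrieval. It is an illustration of the technique, not the library's implementation (which works on batched tensors); the helper names `cosine_similarity` and `top_k_search` and the 3-dimensional vectors are hypothetical stand-ins for the 512-dimensional CLIP embeddings. It returns hits in the same `{'corpus_id', 'score'}` shape that `util.semantic_search` uses.

```python
import heapq
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_search(query_emb, corpus_embs, top_k=3):
    """Return the top_k corpus entries most similar to the query, best first."""
    scores = [
        {"corpus_id": idx, "score": cosine_similarity(query_emb, emb)}
        for idx, emb in enumerate(corpus_embs)
    ]
    # Keep only the k highest-scoring entries, sorted by decreasing score
    return heapq.nlargest(top_k, scores, key=lambda hit: hit["score"])


# Hypothetical toy embeddings (one row per GIF description)
corpus = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.0],
    [0.0, 0.0, 1.0],
]
query = [1.0, 0.1, 0.0]

hits = top_k_search(query, corpus, top_k=2)
print([hit["corpus_id"] for hit in hits])  # [0, 2]
```

Cosine similarity ignores vector length and only compares direction, which is why it is the usual choice for comparing sentence embeddings.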

In [None]:
def search(query, k=3):
    # Encode the query (a text string or a PIL image) into an embedding:
    query_emb = model.encode([query], convert_to_tensor=True, show_progress_bar=False)

    # util.semantic_search computes the cosine similarity between the query embedding
    # and all description embeddings, and returns the top_k highest-ranked GIFs:
    hits = util.semantic_search(query_emb, img_emb, top_k=k)[0]

    print("Query:")
    display(query)

    # Display the top-k GIFs. IPImage (IPython's Image) renders a GIF from its URL;
    # the PIL `Image` module imported above cannot be called this way.
    for entry in hits:
        idx = entry['corpus_id']
        pic = IPImage(url=df.iloc[idx]['url'])
        display(pic)


In [None]:
search("happy")

Query:


'happy'