In [None]:
from IPython.display import HTML

In [None]:
from IPython.core.interactiveshell import InteractiveShell

In [None]:
# Ask IPython to echo the value of every expression statement in a cell,
# not only the last one. Cells below that contain bare expressions
# (e.g. an HTML(...) object that is never passed to display()) rely on this.
InteractiveShell.ast_node_interactivity = "all"

In [13]:
import pandas as pd

# Read the TGIF release file into a DataFrame. The file is tab-separated
# and the two column names are supplied here via `names`.
df = pd.read_csv(
    "./TGIF-Release-master/data/tgif-v1.0.tsv",
    sep="\t",
    names=["url", "description"],
)
# Preview the first rows as the cell's output.
df.head()

Unnamed: 0,url,description
0,https://38.media.tumblr.com/9f6c25cc350f12aa74...,"a man is glaring, and someone with sunglasses ..."
1,https://38.media.tumblr.com/9ead028ef62004ef6a...,a cat tries to catch a mouse on a tablet
2,https://38.media.tumblr.com/9f43dc410be85b1159...,a man dressed in red is dancing.
3,https://38.media.tumblr.com/9f659499c8754e40cf...,an animal comes close to another in the jungle
4,https://38.media.tumblr.com/9ed1c99afa7d714118...,a man in a hat adjusts his tie and makes a wei...


In [14]:
# Preview the first five GIFs together with their descriptions.
# display() is called explicitly: a bare HTML(...) expression inside a loop
# body only renders when ast_node_interactivity is set to "all", so the
# preview would silently vanish on a kernel where that cell was not run.
from IPython.display import display

for _, gif in df[:5].iterrows():
    url = gif["url"]
    # Quote the src attribute so the URL cannot terminate the attribute early.
    display(HTML(f'<img src="{url}" style="width:120px; height:90px">'))
    print(gif["description"])
  

a man is glaring, and someone with sunglasses appears.


a cat tries to catch a mouse on a tablet


a man dressed in red is dancing.


an animal comes close to another in the jungle


a man in a hat adjusts his tie and makes a weird face.


In [None]:
from sentence_transformers import SentenceTransformer

# Initialize the retriever: a SentenceTransformer loaded from the
# 'sentence-transformers/all-MiniLM-L6-v2' checkpoint. It is used below to
# turn GIF descriptions and search queries into embedding vectors.
retriever = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Echo the model so its module layout is shown as the cell output.
retriever

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [31]:

import os
from getpass import getpass

import pinecone

# Connect to Pinecone. The API key is deliberately NOT written into the
# notebook: it is taken from the PINECONE_API_KEY environment variable and,
# if that is unset or empty, requested interactively without echoing.
# The environment can be overridden with PINECONE_ENVIRONMENT.
pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY") or getpass("Pinecone API key: "),
    environment=os.environ.get("PINECONE_ENVIRONMENT", "us-west1-gcp")
)

index_name = 'gif-search'

# Create the index only on first run so re-executing the cell is harmless.
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        index_name,
        # Must equal the retriever's embedding size (the model summary
        # printed above reports word_embedding_dimension = 384).
        dimension=384,
        metric="cosine"
    )

# Handle used by the cells below for upserts and queries.
index = pinecone.Index(index_name)

In [None]:
from tqdm.auto import tqdm

# Embed the descriptions and upload them to the index, 64 rows per request.
batch_size = 64

for start in tqdm(range(0, len(df), batch_size)):
    # Clamp the upper bound so the final, shorter slice is handled too.
    stop = min(start + batch_size, len(df))
    rows = df.iloc[start:stop]
    # Encode the slice's descriptions and convert the result to plain lists.
    vectors = retriever.encode(rows['description'].tolist()).tolist()
    # One metadata dict per row, holding that row's column values.
    records = rows.to_dict(orient='records')
    # Vector ids are the row positions rendered as strings.
    row_ids = [str(position) for position in range(start, stop)]
    # Assign the response so nothing is echoed for each batch.
    _ = index.upsert(vectors=list(zip(row_ids, vectors, records)))


  0%|          | 0/1966 [00:00<?, ?it/s]

{'dimension': 384,
 'index_fullness': 0.1,
 'namespaces': {'': {'vector_count': 125782}}}

In [11]:
def search_gif(query, top_k=10):
    """Return the URLs of the GIFs whose descriptions best match `query`.

    The query is embedded with the module-level `retriever` and looked up in
    the module-level Pinecone `index`.

    Args:
        query: Free-text description of the GIF to look for.
        top_k: Maximum number of matches to request from the index
            (defaults to 10, the value that used to be hard-coded).

    Returns:
        A list with the `url` metadata field of every returned match, in the
        order the index returned them.
    """
    xq = retriever.encode(query).tolist()
    xc = index.query(xq, top_k=top_k,
                     include_metadata=True)
    return [match['metadata']['url'] for match in xc['matches']]
def display_gif(urls):
    """Build an IPython `HTML` object showing the given GIFs in a flex row.

    Args:
        urls: Iterable of image URLs.

    Returns:
        An `HTML` display object containing one 120x90 `<img>` per URL,
        wrapped in a flex container that wraps onto several rows.
    """
    from html import escape

    # Escape each URL before placing it in the src attribute so that quotes
    # or angle brackets in the data cannot break out of the attribute.
    figures = [
        f'''
            <figure style="margin: 5px !important;">
              <img src="{escape(str(url), quote=True)}" style="width: 120px; height: 90px" >
            </figure>
        '''
        for url in urls
    ]
    gallery = ''.join(figures)
    return HTML(data=f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {gallery}
        </div>
    ''')

In [16]:
# Smoke test: run one sample query and render the GIFs it returns.
display_gif(search_gif("cat walking"))

In [19]:
import gradio

In [33]:
import gradio as gr

def gif_search(query):
    """Gradio handler: return an HTML gallery of GIFs matching `query`.

    Args:
        query: Free-text description typed into the search box.

    Returns:
        A string of HTML with one 120x90 `<img>` per URL returned by
        `search_gif(query)`, wrapped in a flex container.
    """
    from html import escape

    urls = search_gif(query)
    # This markup is served on the Gradio page, so every URL is escaped
    # before it is placed in the src attribute.
    figures = [
        f'''
            <figure style="margin: 5px !important;">
              <img src="{escape(str(url), quote=True)}" style="width: 120px; height: 90px" >
            </figure>
        '''
        for url in urls
    ]
    gallery = ''.join(figures)
    return f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {gallery}
        </div>
    '''


# Wire gif_search into a minimal Gradio UI: a single-line text box as the
# input and an HTML pane that renders the markup gif_search returns.
interface = gr.Interface(
    gif_search,
    inputs = gr.Textbox(lines=1, placeholder="Enter description of GIF to search for!"),
    outputs = gr.HTML(label="Related Gifs")
)


# Start the app. With debug=True the cell keeps running so errors and logs
# stay visible (see the notice in the cell output); interrupt the kernel to
# stop the server.
interface.launch(debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://15947.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.


(<gradio.routes.App at 0x7f382107a250>,
 'http://127.0.0.1:7861/',
 'https://15947.gradio.app')