# Image Search with Retrieve & Re-Rank
This example shows how to use our image search model to search through the 5k images of the MSCOCO test set,
using our retrieve and re-rank approach for efficient and accurate results.

You will need to download ~1.2 GB of data for this example (model and image features).

Note: 5k images is a rather small collection. If you search for very specific terms, chances are high that no matching photo exists in the collection.
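
To make the retrieve and re-rank idea concrete, here is a minimal, library-free sketch. It is not the `mmt_retrieval` implementation: the helper `retrieve_and_rerank`, the toy 2-d embeddings, and the made-up re-ranker scores are all assumptions for illustration. A cheap embedding similarity first narrows the collection to `retrieve_k` candidates, and a slower, more accurate scorer then reorders only those candidates.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def retrieve_and_rerank(query_vec, image_vecs, rerank_score, retrieve_k=10, display_k=3):
    """Hypothetical two-stage search.

    image_vecs: dict mapping image id -> embedding (stand-in for precomputed embeddings)
    rerank_score: callable(image_id) -> float (stand-in for a cross-encoder score)
    """
    # Stage 1: rank every image by cosine similarity and keep the top retrieve_k.
    candidates = sorted(
        image_vecs, key=lambda i: cosine(query_vec, image_vecs[i]), reverse=True
    )[:retrieve_k]
    # Stage 2: re-score only those candidates with the slower scorer.
    reranked = sorted(candidates, key=rerank_score, reverse=True)
    return reranked[:display_k]

# Toy data, purely illustrative.
image_vecs = {"a": [1.0, 0.0], "b": [0.9, 0.1], "c": [0.0, 1.0], "d": [0.7, 0.7]}
rerank_scores = {"a": 0.2, "b": 0.9, "c": 0.99, "d": 0.5}
top = retrieve_and_rerank([1.0, 0.0], image_vecs, rerank_scores.get, retrieve_k=3, display_k=2)
print(top)
```

In this toy run, image `"c"` has the highest re-ranker score but never reaches stage 2 because stage 1 filters it out. That is the trade-off of the approach: a larger `retrieve_k` costs more re-ranking time but lowers the chance of losing a good match early.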

In [1]:
from mmt_retrieval import MultimodalTransformer, ImageTextRetrieval
from sentence_transformers import util
import glob
import torch
import pickle
import zipfile
from IPython.display import display
from IPython.display import Image as IPImage
import os
import requests

model_path = "https://public.ukp.informatik.tu-darmstadt.de/reimers/mmt-retrieval/models/v1/oscar_join_mscoco.zip"
model = MultimodalTransformer(model_name_or_path=model_path)

2021-03-19 15:14:28 - Load pretrained SentenceTransformer: https://public.ukp.informatik.tu-darmstadt.de/reimers/mmt-retrieval/models/v1/oscar_join_mscoco.zip
2021-03-19 15:14:28 - Load SentenceTransformer from folder: C:\Users\Gregor/.cache\torch\sentence_transformers\public.ukp.informatik.tu-darmstadt.de_reimers_mmt-retrieval_models_v1_oscar_join_mscoco
2021-03-19 15:14:29 - BertImgModel Image Dimension: 2054
2021-03-19 15:14:29 - Use pytorch device: cuda


In [2]:

feature_file = 'test_img_frcnn_feats.pt' # from the MSCOCO dataset downloaded from OSCAR
# Download and unpack the image features only if they are not already present
if not os.path.exists(feature_file):
    zip_save_path = 'mscoco_test_img_frcnn_feats.zip'
    util.http_get('https://public.ukp.informatik.tu-darmstadt.de/reimers/mmt-retrieval/datasets/mscoco_test_img_frcnn_feats.zip', zip_save_path)
    with zipfile.ZipFile(zip_save_path, 'r') as zip_file:  # avoid shadowing the built-in zip()
        zip_file.extractall(".")
    os.remove(zip_save_path)


In [3]:
model.image_dict.load_oscar_format_image_features(feature_file)
embedding_storage_file = None
image_searcher = ImageTextRetrieval(images=list(model.image_dict.keys()), embedder=model, cross_encoder=model,
                                    embedding_batchsize=32, cross_encoder_batchsize=32)
# The first call to image_searcher.search() also creates the embeddings; this example does it explicitly.
embeddings = image_searcher.create_embeddings()

2021-03-19 15:14:31 - Creating embeddings for the images


In [6]:
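def display_imageid(imageid):
    # COCO file names zero-pad the image id to 12 digits; try the train2014 URL first
    url = f"http://images.cocodataset.org/train2014/COCO_train2014_{int(imageid):012d}.jpg"
    r = requests.get(url)
    if r.status_code == 404:
        # Not found under train2014, so fall back to the val2014 URL
        url = f"http://images.cocodataset.org/val2014/COCO_val2014_{int(imageid):012d}.jpg"
    display(IPImage(url=url, width=200))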

def search(query, display_k=3, retrieve_k=10):
    hits = image_searcher.search(text_queries=query, topk=retrieve_k)["images"][0][:display_k]
    print("Query:")
    display(query)
    for hit in hits:
        print(hit)
        display_imageid(hit)
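
The URL pattern used in `display_imageid` depends on the `:012d` format specifier, which zero-pads the integer image id to 12 digits. The small helper below, `coco_url`, is a hypothetical name that is not part of the notebook or of any library. It only illustrates how an id returned by `search` maps to a file URL for a given split.

```python
def coco_url(imageid, split):
    """Build the image URL for a COCO 2014 split ("train2014" or "val2014")."""
    # int() accepts both string and integer ids; :012d pads with leading zeros to 12 digits.
    return f"http://images.cocodataset.org/{split}/COCO_{split}_{int(imageid):012d}.jpg"

print(coco_url("310757", "val2014"))
print(coco_url(71004, "train2014"))
```

Building the URL in one place keeps the train/val fallback logic separate from the string formatting.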
        

In [7]:
search("a golden dog")

Queries: 100%|██████████| 1/1 [00:00<00:00,  7.56it/s]
Query:


'a golden dog'

310757


104421


205854


In [10]:
search("people in a forest")

Queries: 100%|██████████| 1/1 [00:00<00:00,  7.10it/s]
Query:


'people in a forest'

527625


122161


566634


In [17]:
search("a sunset in the city")

Queries: 100%|██████████| 1/1 [00:00<00:00,  5.46it/s]
Query:


'a sunset in the city'

283277


565012


192594


In [16]:
search("a dog in a park")

Queries: 100%|██████████| 1/1 [00:00<00:00,  6.57it/s]
Query:


'a dog in a park'

497014


515668


71004


In [14]:
search("a cat outside")

Queries: 100%|██████████| 1/1 [00:00<00:00,  7.05it/s]
Query:


'a cat outside'

360943


115070


147425


In [20]:
search("people playing baseball")

Queries: 100%|██████████| 1/1 [00:00<00:00,  6.29it/s]
Query:


'people playing baseball'

466422


515241


78707
