# Multilingual Joint Image & Text Embeddings 

This example shows how [SentenceTransformers](https://www.sbert.net) can be used to map images and texts to the same vector space. 

As the model, we use the [OpenAI CLIP Model](https://github.com/openai/CLIP), which was trained on a large set of images paired with their alt texts.

The original CLIP model only works for English, so we used [Multilingual Knowledge Distillation](https://arxiv.org/abs/2004.09813) to make it work with 50+ languages.
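In Multilingual Knowledge Distillation (per the linked paper), a student text encoder is trained so that both an English sentence and its translation are mapped onto the teacher's embedding of the English sentence, using a mean squared error loss. A minimal sketch of that objective for a single sentence pair, with plain Python lists standing in for embeddings; the vectors below are made-up numbers, not outputs of the real CLIP models:

```python
def mse(u, v):
    """Mean squared error between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)


def distillation_loss(teacher_src, student_src, student_tgt):
    """Loss for one (English source sentence, translation) pair.

    teacher_src: teacher embedding of the English sentence (kept fixed during training)
    student_src: student embedding of the same English sentence
    student_tgt: student embedding of the translated sentence
    Both student embeddings are pulled towards the teacher's embedding of the source.
    """
    return mse(student_src, teacher_src) + mse(student_tgt, teacher_src)


# Toy 3-dimensional "embeddings" (illustrative values only)
teacher_en = [0.5, -0.25, 1.0]
student_en = [0.5, -0.25, 1.0]  # already matches the teacher
student_de = [0.0, -0.25, 1.0]  # translation still differs in the first dimension
print(distillation_loss(teacher_en, student_en, student_de))
```

Because the teacher's embedding of the English sentence is the target for every language, translations end up in the same vector space as the original CLIP text encoder, which is what lets the multilingual text model be paired with the original image model.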

As the source of photos, we use the [Unsplash Dataset Lite](https://unsplash.com/data), which contains about 25k images. See the Unsplash [License](https://unsplash.com/license) for the terms that apply to the images.

Note: 25k images is rather small. If you search for very specific terms, chances are high that no matching photo exists in the collection.

In [1]:
!which python

/home/ec2-user/anaconda3/envs/amazonei_pytorch_latest_p36/bin/python


In [3]:
!pip install sentence_transformers

Collecting sentence_transformers
  Using cached sentence-transformers-2.1.0.tar.gz (78 kB)
Collecting torch>=1.6.0
  Downloading torch-1.10.0-cp36-cp36m-manylinux1_x86_64.whl (881.9 MB)
     |████████████████████████████████| 881.9 MB 624 bytes/s eta 0:00:01
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
     |████████████████████████████████| 1.2 MB 74.4 MB/s eta 0:00:01
Collecting torchvision
  Downloading torchvision-0.11.1-cp36-cp36m-manylinux1_x86_64.whl (23.3 MB)
     |████████████████████████████████| 23.3 MB 59.6 MB/s eta 0:00:01
Building wheels for collected packages: sentence-transformers
  Building wheel for sentence-transformers (setup.py) ... done
  Created wheel for sentence-transformers: filename=sentence_transformers-2.1.0-py3-none-any.whl size=121580 sha256=8fa1233010ca355aadf0159cbfc79744a53cca87b95e8c94fd7a202e1e5926a0
  Stored in directory: /home/ec2-user/.cache/pip/wheels/4e/6f/20/06e0c1e209742a37ce7a5a9aa4e420a3abd5081c65b4b34d0a
Successfully built sentence-transformers
Installing collected packages: torch, torchvision, sentencepiece, sentenc

In [42]:
!pip install ftfy

Collecting ftfy
  Downloading ftfy-6.0.3.tar.gz (64 kB)
     |████████████████████████████████| 64 kB 4.7 MB/s  eta 0:00:01
Building wheels for collected packages: ftfy
  Building wheel for ftfy (setup.py) ... done
  Created wheel for ftfy: filename=ftfy-6.0.3-py3-none-any.whl size=42256 sha256=f2be7caff5432f051777e4ce0b2a42cb6902b3c4aab1cee4c177867cafdcba1a
  Stored in directory: /home/ec2-user/.cache/pip/wheels/ff/2a/24/75041425faf3347ab146a4a3d0484f723b2c44a7966a06e3f0
Successfully built ftfy
Installing collected packages: ftfy
Successfully installed ftfy-6.0.3


In [123]:
from sentence_transformers import SentenceTransformer, util
from PIL import Image, PngImagePlugin
import glob
import torch
import pickle
import numpy as np
import zipfile
import pandas as pd
from IPython.display import display
from IPython.display import Image as IPImage
import os
import tqdm
from tqdm import tqdm as tqdm_n
import hashlib
#from tqdm.autonotebook import tqdm

In [168]:
import warnings
warnings.filterwarnings("ignore")

In [4]:
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True  # Load truncated image files instead of raising an error
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)  # Otherwise we get ValueError: Decompressed data too large

# Paths

In [6]:
# Paths
PATH_SGM = '/home/ec2-user/SageMaker/'
PATH_DATASET = os.path.join(PATH_SGM, 'anonymized-dataset')
PATH_IMGS = os.path.join(PATH_DATASET, 'anonymized_subset')
PATH_SAMPLES = os.path.join(PATH_DATASET, 'anonymized_subset.csv')
PATH_CLIP = os.path.join(PATH_DATASET, "benchmarks", "clip")

# Path to all samples
PATH_ALL_SAMPLES = os.path.join(PATH_SGM, 'dataset', 'samples')

# Samples

In [7]:
# Reading all samples
samples_all = pd.read_parquet(PATH_ALL_SAMPLES)

In [118]:
# Load valid answers
path_valid_answers = os.path.join(PATH_CLIP, "valid_answers_anonymized_subset.pickle")
with open(path_valid_answers, 'rb') as handle:
    valid_answers = pickle.load(handle)
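The notebook loads `valid_answers` from a pickle here and also recomputes it further down. A minimal sketch of a cache-or-compute pattern that would do either as needed; `load_or_compute` is a hypothetical helper, not part of the original notebook:

```python
import os
import pickle


def load_or_compute(path, compute_fn):
    """Return the object pickled at `path`; if the file does not exist yet,
    call `compute_fn()`, pickle its result to `path` and return it."""
    if os.path.exists(path):
        with open(path, "rb") as handle:
            return pickle.load(handle)
    result = compute_fn()
    with open(path, "wb") as handle:
        pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return result
```

In this notebook it could be called as `load_or_compute(path_valid_answers, lambda: _compute_valid_answers(samples))`. Keep in mind that `pickle.load` should only be used on files you trust.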

In [127]:
# (Anonymized SUBSET) Read samples csv
samples = pd.read_csv(PATH_SAMPLES)
print("Images:", samples.shape[0])

Images: 10000


# Processing

In [128]:
# Rewrite the original dataset root in s3_path to the local image folder (literal match, not a regex)
replace_s3_path = '/home/ec2-user/SageMaker/dataset/dataset'
samples["s3_path"] = samples["s3_path"].str.replace(replace_s3_path, PATH_IMGS, regex=False)
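A substring replace rewrites the root wherever it occurs, including inside a sibling folder such as `.../dataset/dataset2/`. A minimal sketch of a stricter alternative that only swaps a leading directory prefix; `rebase_path` is a hypothetical helper and assumes POSIX-style paths:

```python
import posixpath


def rebase_path(path, old_root, new_root):
    """Swap the leading `old_root` directory of `path` for `new_root`.

    Only a prefix that ends on a directory boundary is rewritten;
    paths outside `old_root` are returned unchanged.
    """
    old_root = old_root.rstrip("/")
    if path == old_root:
        return new_root
    if path.startswith(old_root + "/"):
        return posixpath.join(new_root, path[len(old_root) + 1:])
    return path
```

Applied to the dataframe, this would read `samples["s3_path"].map(lambda p: rebase_path(p, replace_s3_path, PATH_IMGS))`.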

# Valid answers

In [129]:
def _compute_valid_answers(data: pd.DataFrame):
    """Generates the valid answers taking into account multiple images and captions.
    Two samples are valid answers for each other if they share the same caption (compared through
    its MD5 hash) or the same value in the `hash` column. Values are positional indices, the same
    indexing used by corpus_id in semantic_search. For the following dataset we will create the
    dictionary with valid_answers:
    id    caption    hash    |    valid_answers
    0     ABC        X       |    0,1,2,4
    1     EFG        X       |    0,1,4
    2     ABC        Y       |    0,2
    3     HIJ        Z       |    3
    4     KLM        X       |    0,1,4
    """
    data["cap_hash"] = data["caption"].apply(lambda x: hashlib.md5(x.encode()).hexdigest())
    valid_answers = {}

    for _, row in tqdm.tqdm(data.iterrows()):
        # Boolean mask of the rows that share this caption or this hash
        is_duplicate = (data["cap_hash"] == row["cap_hash"]) | (data["hash"] == row["hash"])
        # Positional indices where the mask is True
        valid_answers[row["img_id"]] = list(np.flatnonzero(is_duplicate.to_numpy()))
    return valid_answers
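The function above compares every row against every other row (10,000 x 10,000 comparisons here). A minimal sketch of the same rule built from two lookup tables instead; `compute_valid_answers_grouped` is a hypothetical alternative that groups on the caption string directly, since the MD5 in the original only serves as a comparison key:

```python
from collections import defaultdict


def compute_valid_answers_grouped(ids, captions, hashes):
    """Map each sample id to the positional indices of all samples sharing its caption or its hash.

    ids, captions, hashes: equal-length sequences, one entry per sample, in corpus order.
    """
    by_caption = defaultdict(list)
    by_hash = defaultdict(list)
    for pos, (caption, img_hash) in enumerate(zip(captions, hashes)):
        by_caption[caption].append(pos)
        by_hash[img_hash].append(pos)

    valid_answers = {}
    for sample_id, caption, img_hash in zip(ids, captions, hashes):
        # Union of both groups, sorted to match the ascending order produced by np.flatnonzero
        valid_answers[sample_id] = sorted(set(by_caption[caption]) | set(by_hash[img_hash]))
    return valid_answers


# The example from the docstring above
print(compute_valid_answers_grouped(
    [0, 1, 2, 3, 4], ["ABC", "EFG", "ABC", "HIJ", "KLM"], ["X", "X", "Y", "Z", "X"]))
```

With the notebook's dataframe this would be called as `compute_valid_answers_grouped(samples["img_id"], samples["caption"], samples["hash"])`.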

In [130]:
valid_answers = _compute_valid_answers(samples)

10000it [00:15, 656.37it/s]


# CLIP

In [172]:
EMBEDDING_SIZE = 512
batch_size = 1000
batch_t2i = 100
batch_i2t = 100

# Paths of embeddings
path_img_emb = os.path.join(PATH_CLIP, 'img_emb.pt')
path_txt_emb = os.path.join(PATH_CLIP, 'txt_emb.pt')

## Text embeddings

In [18]:
# Multilingual CLIP model
txt_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1')

In [26]:
# Initialize text embedding matrix
txt_inputs = samples["caption"].values
txt_emb = torch.zeros((len(txt_inputs),EMBEDDING_SIZE))

In [27]:
# Populate the matrix batch by batch
for start_index in tqdm.tqdm(range(0, len(txt_inputs), batch_size)):
    txt_batch = txt_inputs[start_index:start_index+batch_size]
    txt_emb[start_index:start_index+batch_size] = txt_model.encode(txt_batch, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False)

100%|██████████| 10/10 [00:31<00:00,  3.16s/it]
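Both encoding loops write each batch into the slice `[start_index:start_index+batch_size]` of a preallocated matrix; slicing clips at the end of the data, so a shorter final batch still lands in the right rows. A minimal sketch of that iteration pattern as a generic helper; `iter_batches` is hypothetical and not part of sentence-transformers:

```python
def iter_batches(items, batch_size):
    """Yield (start_index, batch) pairs covering `items` in order, including a shorter final batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield start, items[start:start + batch_size]


for start, batch in iter_batches(list(range(5)), 2):
    # Each batch would be encoded and written to rows start .. start + len(batch) - 1
    print(start, batch)
```

Using `len(batch)` rather than `batch_size` for the end of the target slice keeps the write correct even when the dataset size is not a multiple of the batch size.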


In [29]:
# Saving resulting text embeddings
path_txt_emb = os.path.join(PATH_CLIP, 'txt_emb.pt')
torch.save(txt_emb, path_txt_emb)

## Image embeddings

In [98]:
# To embed images we need the original (English) CLIP model: the multilingual model only encodes text, mapped into this model's vector space
img_model = SentenceTransformer('clip-ViT-B-32')

In [78]:
# Initialize image embedding matrix
imgs_input = np.array(samples["s3_path"])
img_emb = torch.zeros((len(imgs_input),EMBEDDING_SIZE))

In [79]:
# Populate the matrix batch by batch
# Processing 1,000 images per batch used about 5 GB of GPU memory; RAM usage stays low because only one batch of images is opened at a time

for start_index in tqdm.tqdm(range(0, len(imgs_input), batch_size)):
    imgs_batch = imgs_input[start_index:start_index+batch_size]
    img_list = [Image.open(filepath).convert("RGB") for filepath in imgs_batch]
    img_emb[start_index:start_index+batch_size] = img_model.encode(img_list, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False)

  "Palette images with Transparency expressed in bytes should be "
100%|██████████| 10/10 [06:58<00:00, 41.82s/it]


In [82]:
# Saving resulting image embeddings
path_img_emb = os.path.join(PATH_CLIP, 'img_emb.pt')
torch.save(img_emb, path_img_emb)

## I2T and T2I

In [174]:
def hits_2_df(hits, index):
    """
    From the list of lists returned by semantic_search, transforms the data into a dataframe
    with 3 columns:
    - corpus_id: index of the hit over all the embeddings provided to semantic_search
    - score: similarity score between the hit and the query embedding (cosine similarity by default)
    - query_index: index of the sample used as a query, counting from `index`, the index of the
        first query in the batch. With perfect matching, the highest-scoring hit of a given
        query_index has corpus_id == query_index.
    """
    df_list = []
    for offset, hit in enumerate(hits):
        df_hit = pd.DataFrame(hit).sort_values('score', ascending=False)
        df_hit["query_index"] = index + offset
        df_list.append(df_hit)
    # Concatenate once at the end instead of growing the dataframe inside the loop
    return pd.concat(df_list) if df_list else pd.DataFrame()
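`hits_2_df` relies on the shape of what `util.semantic_search` returns: one list per query, each holding dicts with `corpus_id` and `score`, sorted by decreasing score (cosine similarity by default). The toy re-implementation below only makes that structure concrete; it is an illustrative stand-in named `toy_semantic_search`, not the library's code, which works on batched tensors:

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def toy_semantic_search(query_embs, corpus_embs, top_k=10):
    """For each query, return the top_k corpus entries as {'corpus_id', 'score'} dicts,
    sorted by decreasing cosine similarity."""
    results = []
    for q in query_embs:
        scored = [{"corpus_id": i, "score": cosine(q, c)} for i, c in enumerate(corpus_embs)]
        scored.sort(key=lambda hit: hit["score"], reverse=True)
        results.append(scored[:top_k])
    return results


queries = [[1.0, 0.0], [0.0, 1.0]]
corpus = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
for query_id, query_hits in enumerate(toy_semantic_search(queries, corpus, top_k=2)):
    print(query_id, [hit["corpus_id"] for hit in query_hits])
```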

In [202]:
def update_recall(index_hit, df_hits, samples, valid_answers, recalls_dict):
    """
    Marks, for every K in recalls_dict, whether query index_hit has a valid answer among its top K hits.
    index_hit: index of the queried sample
    df_hits: dataframe resulting from hits_2_df
    samples: dataframe in which we can map the index to the img_id
    valid_answers: dict mapping each img_id to the list of indices that count as correct answers
    recalls_dict: dictionary with the K of the top K hits as keys and, as values, a binary np.array
        of size samples.shape[0]; a 1 at position index_hit means that at least one valid answer
        for that query appears among the top K hits predicted by semantic_search
    """
    df_hit = df_hits[df_hits['query_index'] == index_hit].copy()
    img_id = samples.iloc[index_hit]["img_id"].item()
    valid_answers_query = valid_answers[img_id]
    
    # Iterate over each recall dict
    for k in recalls_dict:
        
        # Only get the top K hits
        df_k = df_hit.head(k)
        
        # Get the predicted K top hits sorted by score, and get its corpus_id (index)
        predicted_hits_query = df_k['corpus_id'].values
        
        # Check whether the indices of the top K hits intersect with the indices of the valid answers
        intersect_hits_answers = np.intersect1d(valid_answers_query, predicted_hits_query)
        
        # if they do, update the recalls_dict in that index position for that top K hits
        if len(intersect_hits_answers) > 0:
            recalls_dict[k][index_hit] = 1
    return recalls_dict
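`init_recalls`, `update_recall` and `report` together compute Recall@K: the share of queries with at least one valid answer among their top K hits. A minimal sketch of the same metric computed directly on the list-of-lists output of semantic_search, without the intermediate dataframe; `recall_at_k` is a hypothetical helper, and it assumes the valid answers are passed in the same order as the queries:

```python
def recall_at_k(hits, valid_answers_per_query, k_list=(1, 5, 10)):
    """Recall@K in percent for each K in k_list.

    hits: one list per query of {'corpus_id', 'score'} dicts, sorted by decreasing score
    valid_answers_per_query: one collection of valid corpus indices per query, same order as hits
    """
    if not hits or len(hits) != len(valid_answers_per_query):
        raise ValueError("need a non-empty list of hits and one list of valid answers per query")
    recalls = {}
    for k in k_list:
        # Count queries whose top K corpus ids share at least one index with their valid answers
        found = sum(
            1
            for query_hits, valid in zip(hits, valid_answers_per_query)
            if set(valid) & {hit["corpus_id"] for hit in query_hits[:k]}
        )
        recalls[k] = 100.0 * found / len(hits)
    return recalls
```

For a batch starting at `index`, the second argument could be built as `[valid_answers[i] for i in samples["img_id"].iloc[index:index + batch_t2i]]`.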

In [184]:
def init_recalls(k_list, length):
    """
    Initializes the binary arrays for each top K recall that we want to assess
    k_list: list of the top K positions of a given set of ordered hits (e.g. [1, 5, 10])
    length: total number of queries that we will make; for each query we will have a 0 or 1 in that position
        of the array, indicating whether a valid answer was found in the top K hits (=1) or not (=0)
    """
    r_at_dict = {}
    for k in k_list:
        r_at_dict[k] = np.zeros(length)
    return r_at_dict

In [223]:
def report(task, recall_dict):
    """Prints Recall@K (in %) for each K in recall_dict and returns the values as {K: recall}."""
    report_dict = {}
    for k in recall_dict:
        report_dict[k] = 100.0 * np.round((np.sum(recall_dict[k]) / len(recall_dict[k])),4)
        print(f"{task}: Recall at {k}: ", np.round(report_dict[k],2), "%")
    return report_dict
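The dict returned for I2T further down contains `22.189999999999998`: `report` rounds the fraction to four decimals and only then multiplies by 100, which can reintroduce binary floating-point error. A minimal sketch of the opposite order (scale, then round), using a hypothetical `recall_percent` helper that takes the number of successful queries:

```python
def recall_percent(hits_found, n_queries, decimals=2):
    """Recall as a percentage, rounded after scaling by 100 so the rounded value is the final one."""
    if n_queries <= 0:
        raise ValueError("n_queries must be positive")
    return round(100.0 * hits_found / n_queries, decimals)


print(recall_percent(2219, 10000))
```

Inside `report`, this would replace the computation with `report_dict[k] = recall_percent(np.sum(recall_dict[k]), len(recall_dict[k]))`.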

### (Retrieval) Text 2 Image

In [211]:
# Load embeddings of images
img_emb = torch.load(path_img_emb)

In [212]:
# Initialize metrics
r_at_t2i = init_recalls([1,5,10], samples.shape[0])

In [213]:
# Iterate over dataset in batch mode
for index in tqdm.tqdm(range(0, samples.shape[0], batch_t2i)):
    
    # Get the rows of the sample batch
    sample_query = samples.iloc[index:index + batch_t2i]
    
    # Get the captions of the batch
    query = sample_query['caption'].tolist()
    
    # Forward it to the model to get the embeddings (one 512-d vector per caption, a batch_t2i x 512 torch tensor)
    query_emb = txt_model.encode(query, convert_to_tensor=True, show_progress_bar=False)
    
    # Get the top 10 hits
    hits = util.semantic_search(query_emb, img_emb, top_k=10)
    
    # Convert the hits to a dataframe, with the index of each query as a column
    df_hits = hits_2_df(hits, index)
    
    # len(sample_query) also handles a final batch shorter than batch_t2i
    for index_hit in range(index, index + len(sample_query)):
        r_at_t2i = update_recall(index_hit, df_hits, samples, valid_answers, r_at_t2i)

100%|██████████| 100/100 [00:48<00:00,  2.06it/s]


In [224]:
report("T2I", r_at_t2i)

T2I: Recall at 1:  6.05 %
T2I: Recall at 5:  16.42 %
T2I: Recall at 10:  23.73 %


{1: 6.05, 5: 16.42, 10: 23.73}

### (Annotation) Image 2 Text

In [218]:
# Load embeddings of captions
txt_emb = torch.load(path_txt_emb)

In [219]:
# Initialize metrics
r_at_i2t = init_recalls([1,5,10], samples.shape[0])

In [220]:
# Iterate over dataset in batch mode
for index in tqdm.tqdm(range(0, samples.shape[0], batch_i2t)):
    
    # Get the rows of the sample batch
    sample_query = samples.iloc[index:index + batch_i2t]
    batch_img_paths = sample_query['s3_path'].tolist()
    
    # Get the images of the batch
    query = [Image.open(filepath).convert("RGB") for filepath in batch_img_paths]
    
    # Forward it to the model to get the embeddings (one 512-d vector per image, a batch_i2t x 512 torch tensor)
    query_emb = img_model.encode(query, convert_to_tensor=True, show_progress_bar=False)
    
    # Get the top 10 hits
    hits = util.semantic_search(query_emb, txt_emb, top_k=10)
    
    # Convert the hits to a dataframe, with the index of each query as a column
    df_hits = hits_2_df(hits, index)
    
    # len(sample_query) also handles a final batch shorter than batch_i2t
    for index_hit in range(index, index + len(sample_query)):
        r_at_i2t = update_recall(index_hit, df_hits, samples, valid_answers, r_at_i2t)

100%|██████████| 100/100 [06:06<00:00,  3.67s/it]


In [225]:
report("I2T", r_at_i2t)

I2T: Recall at 1:  5.42 %
I2T: Recall at 5:  16.63 %
I2T: Recall at 10:  22.19 %


{1: 5.42, 5: 16.63, 10: 22.189999999999998}