This Colab generates language and image features to be used with pretrained image--language transformers.  It then allows you to use our released models to determine if an image-text pair match!

This replicates our retrieval results in our TACL 2021 paper:

[Decoupling the Role of Data, Attention, and Losses in Multimodal Transformers](https://arxiv.org/abs/2102.00529)

Paper Authors:  Lisa Anne Hendricks, John Mellor, Rosalia Schneider, Jean-Baptiste Alayrac, and Aida Nematzadeh

We also thank Sebastian Borgeaud and Cyprien de Masson d'Autume for their text preprocessing code.

# Preproccessing Language and Images

First, we use a detector to extract image features and SentencePiece to extract language tokens.

In [3]:
# Make sure to follow the Setup\Instructions.md steps first.

import os
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
from io import BytesIO as StringIO
from PIL import Image

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import matplotlib.image as mpimg
import unicodedata

In [4]:
# Replace with the path where you setup the models repo, when you followed the Setup\Instructions.md steps
%cd ..\tf-models\research


D:\projects.py\tf-models\research


In [5]:
from object_detection.utils import visualization_utils as vis_util
from object_detection.utils import label_map_util
from object_detection.core import standard_fields as fields

In [6]:
!wget  https://storage.googleapis.com/dm-mmt-models/spiece.model -P '/tmp'

Attempting to download
    from: 'https://storage.googleapis.com/dm-mmt-models/spiece.model'
      to: 'C:\Users\Mikey\AppData\Local\Temp\spiece.model'


In [7]:
features = {} # input to our model 

## Language Preprocessing

### Helper Functions for Preprocessing Text

In [8]:
SPIECE_UNDERLINE = '▁'  # pylint: disable=invalid-encoded-data

special_symbols = {
    '<cls>': 3,
    '<sep>': 4,
    '<pad>': 5,
    '<mask>': 6,
}
CLS_ID = special_symbols['<cls>']
SEP_ID = special_symbols['<sep>']
PAD_ID = special_symbols['<pad>']
MASK_ID = special_symbols['<mask>']

def is_start_piece(piece):
  """Returns True if the piece is a start piece for a word/symbol."""
  special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
  if piece.startswith(SPIECE_UNDERLINE):
    return True
  if piece.startswith('<'):
    return True
  if piece in special_pieces:
    return True
  return False


def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
  """Preprocess the inputs."""
  if remove_space:
    outputs = ' '.join(inputs.strip().split())
  else:
    outputs = inputs
  outputs = outputs.replace('``', '"').replace('\'\'', '"')

  if not keep_accents:
    outputs = unicodedata.normalize('NFKD', outputs)
    outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
  if lower:
    outputs = outputs.lower()

  return outputs


def encode_pieces(sp_model, text, sample=False):
  """Encode the text to pieces using the given SentencePiece model sp_model."""
  if not sample:
    pieces = sp_model.EncodeAsPieces(text)
  else:
    pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  new_pieces = []
  for piece in pieces:
    if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
      cur_pieces = sp_model.EncodeAsPieces(
          piece[:-1].replace(SPIECE_UNDERLINE, ''))
      if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
        if len(cur_pieces[0]) == 1:
          cur_pieces = cur_pieces[1:]
        else:
          cur_pieces[0] = cur_pieces[0][1:]
      cur_pieces.append(piece[-1])
      new_pieces.extend(cur_pieces)
    else:
      new_pieces.append(piece)

  return new_pieces


def encode_ids(sp_model, text, sample=False):
  pieces = encode_pieces(sp_model, text, sample=sample)
  ids = [sp_model.PieceToId(piece) for piece in pieces]
  return ids


def tokens_to_word_indices(sp_model, tokens, offset=0):
    """Compute the word ids for the tokens.

    The word indices start at offset, each time a new word is encountered, the
    word id is increased by 1.

    Args:
      tokens: `list` of `int` SentencePiece tokens
      offset: `int` start index

    Returns:
      A `list` of increasing integers. If element i and j are identical, then
      tokens[i] and tokens[j] are part of the same word.
    """
    word_indices = []
    current_index = offset
    for i, token in enumerate(tokens):
      token_piece = sp_model.IdToPiece(token)
      if i > 0 and is_start_piece(token_piece):
        current_index += 1
      word_indices.append(current_index)

    return word_indices

### Load the SentencePiece Model

In [10]:
import sentencepiece as sp
spm_path = os.environ['temp'] + '\spiece.model'
spm = sp.SentencePieceProcessor()
spm.Load(spm_path)

True

### Preprocessing Captions 

In [11]:
def create_sentence_features(seq_len, spm, captions, max_sentence_number=1):
  def _add_sentence_pad():
    pad_number = max_sentence_number - len(captions)
    for _ in range(pad_number):
      all_sents['tokens'] += [MASK_ID] * seq_len
      all_sents['segment_ids'] += [0] * seq_len
      all_sents['padding_mask'] += [1] * seq_len
      all_sents['word_ids'] += [-2] * seq_len

  # Limit the sentence length to seq_len
  # We concatenate all sentences after checking the seq len and adding
  # padding
  all_sents = {}
  for k in ['tokens', 'segment_ids', 'padding_mask', 'word_ids']:
    all_sents[k] = []

  for sentence in captions:
    sentence = preprocess_text(sentence, remove_space=True, lower=True, keep_accents=False)

    tokens = encode_ids(spm, sentence)
    if len(tokens) >= seq_len - 2:
      tokens = tokens[:seq_len - 2]  # since we add two symbols

    word_ids = tokens_to_word_indices(spm, tokens)
    word_ids = ([-1] + word_ids + [-1])
    # Need to create segment ids before adding special symbols to tokens
    segment_ids = ([0] +  # SEP
                    [0] * len(tokens) + [2]  # CLS
                  )
    tokens = ([SEP_ID] + tokens + [CLS_ID])
    padding_mask = [0] * len(tokens)
    # Note, we add padding at the start so that the last token is always [CLS]

    if len(tokens) < seq_len:
      padding_len = seq_len - len(tokens)
      tokens = [MASK_ID] * padding_len + tokens
      
      segment_ids = [0] * padding_len + segment_ids
      padding_mask = [1] * padding_len + padding_mask
      word_ids = [-2] * padding_len + word_ids


    assert len(tokens) == seq_len
    assert len(segment_ids) == seq_len
    assert len(padding_mask) == seq_len
    assert len(word_ids) == seq_len

    all_sents['tokens'] += tokens
    all_sents['segment_ids'] += segment_ids
    all_sents['padding_mask'] += padding_mask
    all_sents['word_ids'] += word_ids

  # Add padding sentences to the end so that each example has
  # max_sentence_number
  if len(captions) < max_sentence_number:
    _add_sentence_pad()

  return {
      'text/token_ids': np.array(all_sents['tokens'], dtype=np.int32),
      'text/segment_ids': np.array(all_sents['segment_ids'], dtype=np.int32),
      'text/padding_mask': np.array(all_sents['padding_mask'], dtype=np.int32),
      'text/word_ids': np.array(all_sents['word_ids'], dtype=np.int32),
      'text/sentence_num': len(captions),
  }
         


Get features for an example caption.

In [12]:
features = create_sentence_features(seq_len=25, spm=spm, captions=['A man with a backpack holding a kitten.'])
print(features)

{'text/token_ids': array([    6,     6,     6,     6,     6,     6,     6,     6,     6,
           6,     6,     6,     6,     6,     4,    24,   326,    33,
          24, 14559,  1757,    24, 22968,     9,     3]), 'text/segment_ids': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 2]), 'text/padding_mask': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0]), 'text/word_ids': array([-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1,  0,  1,
        2,  3,  4,  5,  6,  7,  8, -1]), 'text/sentence_num': 1}


## Image Preprocessing

###Load the Pretrained Object Detector 

In [13]:
def LoadInferenceGraph(inference_graph_path):
  """Loads inference graph into tensorflow Graph object.

  Args:
    inference_graph_path: Path to inference graph.

  Returns:
    a tf.Graph object.
  """
  od_graph = tf.Graph()
  with od_graph.as_default():
    od_graph_def = tf.GraphDef()
    with open(inference_graph_path, 'rb') as fid:
      serialized_graph = fid.read()
      od_graph_def.ParseFromString(serialized_graph)
      tf.import_graph_def(od_graph_def, name='')
  return od_graph

Download the pretrained object detector.

In [14]:
!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/frozen_inference_graph.pb -P '/tmp'

Attempting to download
    from: 'https://storage.googleapis.com/dm-mmt-models/frozen_inference_graph.pb'
      to: 'C:\Users\Mikey\AppData\Local\Temp\frozen_inference_graph.pb'


In [15]:
detection_graph = LoadInferenceGraph(os.environ['temp'] + '\frozen_inference_graph.pb')
print ('Successfully loaded frozen model from {}'.format('https://storage.googleapis.com/dm-mmt-models/frozen_inference_graph.pb'))

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\Mikey\\AppData\\Local\\Tempfrozen_inference_graph.pb'

### Load an Example Image

In [None]:
def LoadImageIntoNumpyArray(path):

  with open(path, 'rb') as img_file: 
    img = mpimg.imread(img_file)
    (im_width, im_height) = img.shape[:2]
    return img[:,:,:3].astype(np.uint8)

In [None]:
# Download the image
!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/COCO_val2014_000000570107.jpeg -P '/tmp/' 
image_np = LoadImageIntoNumpyArray('/tmp/COCO_val2014_000000570107.jpeg')

print('image type: %s' % str(image_np.dtype))
print('image shape: %s' % str(image_np.shape))


###Preprocessing Images

Loading the object-label mappings for the dectector.

In [None]:
!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/objatt_labelmap.txt -P '/tmp/' 
label_map_path = '/tmp/objatt_labelmap.txt'
categories = label_map_util.create_categories_from_labelmap(label_map_path, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

Running inference on the object detector for a single image.

In [None]:
def RunInferenceSingleImage(image, graph):
  """Run single image through tensorflow object detection graph.

  This function runs an inference graph (frozen using the functions provided
  in this file) on a (single) provided image and returns inference results in
  numpy arrays.

  Args:
    image: uint8 numpy array with shape (img_height, img_width, 3)
    graph: tensorflow graph object holding loaded model.  This graph can be
      obtained by running the LoadInferenceGraph function above.

  Returns:
    output_dict: a dictionary holding the following entries:
      `num_detections`: an integer
      `detection_boxes`: a numpy (float32) array of shape [N, 4]
      `detection_classes`: a numpy (uint8) array of shape [N]
      `detection_scores`: a numpy (float32) array of shape [N]
      `detection_masks`: a numpy (uint8) array of shape
         [N, image_height, image_width] with values in {0, 1}
      `detection_keypoints`: a numpy (float32) array of shape
         [N, num_keypoints, 2]
  """
  with graph.as_default():
    with tf.Session() as sess:
      # Get handles to input and output tensors
      ops = tf.get_default_graph().get_operations()
      all_tensor_names = {output.name for op in ops for output in op.outputs}
      tensor_dict = {}
      detection_fields = fields.DetectionResultFields
      for key in [
          v for k, v in vars(detection_fields).items()
          if not k.startswith('__')
      ]:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
          tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
              tensor_name)
      image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

      # Run inference
      output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: np.expand_dims(image, 0)})

      # all outputs are float32 numpy arrays, so convert types as appropriate
      output_dict['num_detections'] = int(output_dict['num_detections'][0])
      output_dict['detection_classes'] = output_dict[
          'detection_classes'][0].astype(np.uint8)
      output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
      output_dict['detection_scores'] = output_dict['detection_scores'][0]
      if 'detection_masks' in output_dict:
        output_dict['detection_masks'] = output_dict['detection_masks'][0]
      if 'detection_keypoints' in output_dict:
        output_dict['detection_keypoints'] = output_dict['detection_keypoints'][
            0]
  return output_dict

Pass an Image Through the Detector.

In [None]:
# Run inference
output_dict = RunInferenceSingleImage(image_np, detection_graph)
output_dict['detection_features'].shape

In [None]:
output_dict.keys()

Preprocessing the output of the detector to be readable by our models.

In [None]:
image_seq_num = 100 
image_feat  = {}
image_feat['height'] = image_np.shape[0]
image_feat['width'] = image_np.shape[1]

raw_feats = np.mean(np.mean(output_dict['detection_features'], axis=-2), axis=-2).squeeze() # image_feat['detection_features']
num_detections = output_dict['num_detections']

raw_scores = output_dict['detection_multiclass_scores'][:, :num_detections, ...].squeeze()

# Find regions with highest class scores
sorted_score_idxs = np.argsort(np.max(raw_scores[:, 1:], axis=-1))[::-1]

# Collect features, boxes, and scores for highest scoring regions
detection_feats = np.zeros((image_seq_num + 1, raw_feats.shape[-1]))
detection_scores = np.zeros((image_seq_num + 1, raw_scores.shape[-1]))
bbox_feats = np.zeros((image_seq_num + 1, 5))
image_padding = np.ones((image_seq_num + 1,))
padding_offset = max(image_seq_num + 1 - sorted_score_idxs.shape[0], 0)

for i, index in enumerate(sorted_score_idxs[:image_seq_num]):
  padded_index = i + padding_offset
  detection_feats[padded_index, :] = raw_feats[index, :]
  detection_scores[padded_index, :] = raw_scores[index, :]
  # index 0 is 'background'
  bbox_feats[padded_index, :4] = output_dict['detection_boxes'][index, :]
  bbox_w = (output_dict['detection_boxes'][index, 3] -
            output_dict['detection_boxes'][index, 1]) * image_feat['width']
  bbox_h = (output_dict['detection_boxes'][index, 2] -
            output_dict['detection_boxes'][index, 0]) * image_feat['height']
  bbox_area = (bbox_w * bbox_h) / (image_feat['height'] * image_feat['width'])
  bbox_feats[padded_index, -1] = bbox_area
  image_padding[padded_index] = 0

# Add in global image feature
detection_feats[-1, :]= np.mean(detection_feats[padding_offset:-1, ...], axis=0).squeeze()
bbox_feats[-1, :] = [0, 0, 1, 1, 1]
image_padding[-1] = 0

features.update(
    {'image/bboxes': bbox_feats.astype(np.float32),
     'image/padding_mask':  image_padding.astype(np.int32), 
     'image/detection_features': detection_feats.astype(np.float32),
     'image/detection_scores': detection_scores.astype(np.float32)})               


In [None]:
print(features['image/bboxes'].shape)
print(features['image/detection_features'].shape)

### Visualizing the Detector Regions 

In [None]:
%matplotlib inline

detection_classes = []
detection_scores = []
tuplet_index = {}

for i in range(100):
  raw_detection_scores_obj = output_dict['detection_multiclass_scores'][:,i,1:1600][0,:]
  raw_detection_scores_att = output_dict['detection_multiclass_scores'][:,i,1600:][0,:]
  max_obj = np.argmax(raw_detection_scores_obj)
  max_att = np.argmax(raw_detection_scores_att)
  tuplet_index[i] = {}
  tuplet_index[i]['name'] = '%s %s' %(category_index[max_att+1600]['name'],
                              category_index[max_obj+1]['name'])
  detection_classes.append(i)
  detection_scores.append(raw_detection_scores_obj[max_obj] +
                          raw_detection_scores_att[max_att])

# Create detections visualization
bboxes = vis_util.visualize_boxes_and_labels_on_image_array(
    image_np.copy(),
    output_dict['detection_boxes'],
    np.array(detection_classes),
    detection_scores,
    tuplet_index,
    instance_masks=None,
    use_normalized_coordinates=True,
    max_boxes_to_draw=15,
    min_score_thresh=.05,
    agnostic_mode=False)

fig = plt.gcf()
fig.set_size_inches(18.5, 10.5)
_ = plt.imshow(bboxes)
plt.axis('off')

# Running Image-Text Pairs through the MMT

Now that we have extracted our image and text features we can run them through our MMT model.

## Use features extracted in colab

In [None]:
# Select a model

#@title Category-conditional sampling { display-mode: "form", run: "auto" }

tags = ['architecture-ft_image-q-12',
        'architecture-ft_image-q-24',
        'architecture-ft_language-q-12',
        'architecture-ft_language-q-24',
        'architecture-ft_single-modality',
        'architecture-ft_single-stream',
        'architecture_heads1-768',
        'architecture_heads18-64',
        'architecture_heads3-256',
        'architecture_heads6-64',
        'architecture_image-q-12',
        'architecture_image-q-24',
        'architecture_language-q-12',
        'architecture_language-q-24',
        'architecture_mixed-modality',
        'architecture_single-modality',
        'architecture_single-modality-hloss',
        'architecture_single-stream',
        'architecture_vilbert-12block',
        'architecture_vilbert-1block',
        'architecture_vilbert-2block',
        'architecture_vilbert-4block',
        'baseline-ft_baseline',
        'baseline-ft_baseline-cls',
        'baseline-ft_baseline-no-bert-transfer',
        'baseline_baseline',
        'baseline_baseline-cls',
        'baseline_baseline-no-bert-transfer',
        'data-ft_cc',
        'data-ft_combined-dataset',
        'data-ft_combined-instance',
        'data-ft_mscoco',
        'data-ft_mscoco-narratives',
        'data-ft_oi-narratives',
        'data-ft_sbu',
        'data-ft_uniter-dataset',
        'data-ft_uniter-instance',
        'data-ft_vg',
        'data_cc',
        'data_cc-with-bert',
        'data_combined-dataset',
        'data_combined-instance',
        'data_mscoco',
        'data_mscoco-narratives',
        'data_oi-narratives',
        'data_sbu',
        'data_uniter-dataset',
        'data_uniter-instance',
        'data_vg',
        'loss_itm+mrm',
        'loss_itm_mrm',
        'loss_single-modality-contrastive1024',
        'loss_single-modality-contrastive32',
        'loss_v1-contrastive32',
        'pixel_vilbert_cc-full-image']

model = "data_cc" #@param ["architecture-ft_image-q-12", "architecture-ft_image-q-24", "architecture-ft_language-q-12", "architecture-ft_language-q-24", "architecture-ft_single-modality", "architecture-ft_single-stream", "architecture_heads1-768", "architecture_heads18-64", "architecture_heads3-256", "architecture_heads6-64", "architecture_image-q-12", "architecture_image-q-24", "architecture_language-q-12", "architecture_language-q-24", "architecture_mixed-modality", "architecture_single-modality", "architecture_single-modality-hloss", "architecture_single-stream", "architecture_vilbert-12block", "architecture_vilbert-1block", "architecture_vilbert-2block", "architecture_vilbert-4block", "baseline-ft_baseline", "baseline-ft_baseline-cls", "baseline-ft_baseline-no-bert-transfer", "baseline_baseline", "baseline_baseline-cls", "baseline_baseline-no-bert-transfer", "data-ft_cc", "data-ft_combined-dataset", "data-ft_combined-instance", "data-ft_mscoco", "data-ft_mscoco-narratives", "data-ft_oi-narratives", "data-ft_sbu", "data-ft_uniter-dataset", "data-ft_uniter-instance", "data-ft_vg", "data_cc", "data_cc-with-bert", "data_combined-dataset", "data_combined-instance", "data_mscoco", "data_mscoco-narratives", "data_oi-narratives", "data_sbu", "data_uniter-dataset", "data_uniter-instance", "data_vg", "loss_itm+mrm", "loss_itm_mrm", "loss_single-modality-contrastive1024", "loss_single-modality-contrastive32", "loss_v1-contrastive32"]

tfhub_link = "https://tfhub.dev/deepmind/mmt/%s/1" %model

In [None]:
model = hub.load(tfhub_link)

In [None]:
inputs={'image/bboxes': tf.expand_dims(features['image/bboxes'], 0),
        'text/padding_mask': tf.expand_dims(features['text/padding_mask'], 0),
        'image/padding_mask': tf.expand_dims(features['image/padding_mask'], 0),
        'masked_tokens': tf.expand_dims(features['text/token_ids'], 0),
        'text/segment_ids': tf.expand_dims(features['text/segment_ids'], 0),
        'image/detection_features': tf.expand_dims(features['image/detection_features'], 0),
        'text/token_ids': tf.expand_dims(features['text/token_ids'], 0)
          }

output = model.signatures['default'](**inputs)
score = tf.nn.softmax(output['output']).numpy()[0]

if score > 0.5:
  print('The text and image match!  (score: %0.03f)' %score)
else: 
  print('The text and image do not match :( (score: %0.03f)' %score) 

# Running with Pre-Extracted Features

We have pre-extracted MSCOCO and Flickr image features.  You can uset these pre-extracted features to do retrieval.

## Use Precomputed features

In [None]:
import pickle as pkl

!wget --no-check-certificate https://storage.googleapis.com/dm-mmt-models/features/coco_test/570107.pkl -P '/tmp/' 
with open('/tmp/570107.pkl', 'rb') as f:
  im_feats = pkl.load(f)

In [None]:
inputs={'image/bboxes': tf.expand_dims(features['image/bboxes'], 0),
        'text/padding_mask': tf.expand_dims(features['text/padding_mask'], 0),
        'image/padding_mask': tf.expand_dims(im_feats['image/padding_mask'], 0),
        'masked_tokens': tf.expand_dims(features['text/token_ids'], 0),
        'text/segment_ids': tf.expand_dims(features['text/segment_ids'], 0),
        'image/detection_features': tf.expand_dims(im_feats['image/detection_features'], 0),
        'text/token_ids': tf.expand_dims(features['text/token_ids'], 0)
          }

output = model.signatures['default'](**inputs)
score = tf.nn.softmax(output['output']).numpy()[0]

if score > 0.5:
  print('The text and image match!  (score: %0.03f)' %score)
else: 
  print('The text and image do not match :( (score: %0.03f)' %score) 