<a href="https://colab.research.google.com/github/alpacaYiChun/ML/blob/master/Clip2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

This example demonstrates how to build a dual encoder (also known as two-tower) neural network
model to search for images using natural language. The model is inspired by
the [CLIP](https://openai.com/blog/clip/)
approach, introduced by Alec Radford et al. The idea is to train a vision encoder and a text
encoder jointly to project the representation of images and their captions into the same embedding
space, such that the caption embeddings are located near the embeddings of the images they describe.

This example was written for TensorFlow 2.4 or higher; the Setup cell below pins TensorFlow 2.15.1.
In addition, [TensorFlow Hub](https://www.tensorflow.org/hub)
and [TensorFlow Text](https://www.tensorflow.org/tutorials/tensorflow_text/intro)
are required for the BERT model. [TensorFlow Addons](https://www.tensorflow.org/addons)
is only needed if you switch to its AdamW optimizer; the training cell in this notebook uses
`keras.optimizers.Adam` with `weight_decay` instead. These libraries can be installed using the
following command:

```python
!pip install -q -U tensorflow-hub tensorflow-text tensorflow-addons
```

In [1]:
from google.colab import drive

drive.mount('/content/gdrive')

bgPath = '/content/gdrive/My Drive/CLIP'

Mounted at /content/gdrive


## Setup

In [2]:
!pip install tensorflow==2.15.1
!pip install tensorflow_text==2.15.0
!pip install tensorflow_hub==0.15.0
#!pip install -q -U tensorflow tensorflow-hub tensorflow-text tqdm
#pip install --upgrade keras-nlp
#pip list | grep keras

import os
import collections
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras
!pip install transformers tensorflow
from transformers import TFAutoModel, AutoImageProcessor
from tensorflow.keras import layers
from tensorflow.keras import regularizers
#import keras_nlp
import tensorflow_hub as hub
import tensorflow_text as text
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm

#from tensorflow.python.framework.ops import disable_eager_execution
#disable_eager_execution()

# Suppressing tf.hub warnings
tf.get_logger().setLevel("ERROR")

#tf.config.experimental_run_functions_eagerly(True)




In [3]:
W = 224
H = 224
C = 3

## Prepare the data

We will use the [MS-COCO](https://cocodataset.org/#home) dataset to train our
dual encoder model. MS-COCO contains over 82,000 images, each of which has at least
5 different caption annotations. The dataset is usually used for
[image captioning](https://www.tensorflow.org/tutorials/text/image_captioning)
tasks, but we can repurpose the image-caption pairs to train our dual encoder
model for image search.

### Download and extract the data

First, let's download the dataset, which consists of two compressed folders:
one with the images and the other with the associated image captions.
Note that the compressed images folder is 13GB in size.

In [4]:
root_dir = "datasets"
annotations_dir = os.path.join(root_dir, "annotations")
images_dir = os.path.join(root_dir, "train2014")
tfrecords_dir = os.path.join(root_dir, "tfrecords")
annotation_file = os.path.join(annotations_dir, "captions_train2014.json")

# Download caption annotation files
if not os.path.exists(annotations_dir):
    annotation_zip = tf.keras.utils.get_file(
        "captions.zip",
        cache_dir=os.path.abspath("."),
        origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
        extract=True,
    )
    os.remove(annotation_zip)

# Download image files
if not os.path.exists(images_dir):
    image_zip = tf.keras.utils.get_file(
        "train2014.zip",
        cache_dir=os.path.abspath("."),
        origin="http://images.cocodataset.org/zips/train2014.zip",
        extract=True,
    )
    os.remove(image_zip)

print("Dataset is downloaded and extracted successfully.")

with open(annotation_file, "r") as f:
    annotations = json.load(f)["annotations"]

image_path_to_caption = collections.defaultdict(list)
for element in annotations:
    caption = f"{element['caption'].lower().rstrip('.')}"
    image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
    image_path_to_caption[image_path].append(caption)

image_paths = list(image_path_to_caption.keys())
print(f"Number of images: {len(image_paths)}")

Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip
Downloading data from http://images.cocodataset.org/zips/train2014.zip
Dataset is downloaded and extracted successfully.
Number of images: 82783


### Process and save the data to TFRecord files

You can change the `train_size` parameter to control how many image-caption pairs
will be used for training the dual encoder model.
In this example we set `train_size` to 20,000 images,
which is about 24% of the dataset. We use 3 captions for each
image, thus producing 60,000 image-caption pairs. The size of the training set
affects the quality of the produced encoders, but more examples would lead to
longer training time.
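
As a quick sanity check on these sizes, the sketch below recomputes the number of image-caption pairs and TFRecord files from the values set in the code cell that follows (`train_size = 20000`, `valid_size = 5000`, `captions_per_image = 3`, `images_per_file = 2000`). The helper name `dataset_plan` is made up for this illustration, and the arithmetic assumes every image has at least `captions_per_image` captions.

```python
import math


def dataset_plan(num_images, captions_per_image, images_per_file):
    """Return (number of image-caption pairs, number of TFRecord files)."""
    num_pairs = num_images * captions_per_image
    num_files = math.ceil(num_images / images_per_file)
    return num_pairs, num_files


train_pairs, train_files = dataset_plan(20000, 3, 2000)
valid_pairs, valid_files = dataset_plan(5000, 3, 2000)
print(train_pairs, train_files)  # 60000 10
print(valid_pairs, valid_files)  # 15000 3
```

These numbers agree with the progress bars and example counts printed by the cell.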

In [5]:
train_size = 20000
valid_size = 5000
captions_per_image = 3
images_per_file = 2000

train_image_paths = image_paths[:train_size]
num_train_files = int(np.ceil(train_size / images_per_file))
train_files_prefix = os.path.join(tfrecords_dir, "train")

valid_image_paths = image_paths[-valid_size:]
num_valid_files = int(np.ceil(valid_size / images_per_file))
valid_files_prefix = os.path.join(tfrecords_dir, "valid")

tf.io.gfile.makedirs(tfrecords_dir)


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def create_example(image_path, caption):
    feature = {
        "caption": bytes_feature(caption.encode()),
        "raw_image": bytes_feature(tf.io.read_file(image_path).numpy()),
        "image_path": bytes_feature(image_path.encode()),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def write_tfrecords(file_name, image_paths):
    caption_list = []
    image_path_list = []
    for image_path in image_paths:
        captions = image_path_to_caption[image_path][:captions_per_image]
        caption_list.extend(captions)
        image_path_list.extend([image_path] * len(captions))

    with tf.io.TFRecordWriter(file_name) as writer:
        for example_idx in range(len(image_path_list)):
            example = create_example(
                image_path_list[example_idx], caption_list[example_idx]
            )
            writer.write(example.SerializeToString())
    # Number of examples written (also well defined when the list is empty).
    return len(image_path_list)


def write_data(image_paths, num_files, files_prefix):
    example_counter = 0
    for file_idx in tqdm(range(num_files)):
        file_name = files_prefix + "-%02d.tfrecord" % (file_idx)
        start_idx = images_per_file * file_idx
        end_idx = start_idx + images_per_file
        example_counter += write_tfrecords(file_name, image_paths[start_idx:end_idx])
    return example_counter


train_example_count = write_data(train_image_paths, num_train_files, train_files_prefix)
print(f"{train_example_count} training examples were written to tfrecord files.")

valid_example_count = write_data(valid_image_paths, num_valid_files, valid_files_prefix)
print(f"{valid_example_count} evaluation examples were written to tfrecord files.")

100%|██████████| 10/10 [01:10<00:00,  7.06s/it]


60000 training examples were written to tfrecord files.


100%|██████████| 3/3 [00:35<00:00, 11.79s/it]

15000 evaluation examples were written to tfrecord files.





### Create `tf.data.Dataset` for training and evaluation

In [None]:

feature_description = {
    "caption": tf.io.FixedLenFeature([], tf.string),
    "raw_image": tf.io.FixedLenFeature([], tf.string),
}


def read_example(example):
    features = tf.io.parse_single_example(example, feature_description)
    raw_image = features.pop("raw_image")
    features["image"] = tf.image.resize(
        tf.image.decode_jpeg(raw_image, channels=3), size=(H, W)  # (height, width)
    )
    return features


def get_dataset(file_pattern, batch_size):

    return (
        tf.data.TFRecordDataset(tf.data.Dataset.list_files(file_pattern))
        .map(
            read_example,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        .shuffle(batch_size * 10)
        .batch(batch_size)
        # Prefetch last so that whole batches are prepared ahead of the training step.
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

def parse_text_fn(example):
    example = tf.io.parse_single_example(example, feature_description)
    return example['caption']

def get_text_dataset(pattern, batch_size):
    files = tf.data.Dataset.list_files(pattern)
    text_dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x).map(parse_text_fn, num_parallel_calls=tf.data.AUTOTUNE),
        cycle_length=tf.data.AUTOTUNE,
        block_length=1,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    return text_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

## Implement the projection head

The projection head is used to transform the image and the text embeddings to
the same embedding space with the same dimensionality.

In [7]:

def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = tf.nn.gelu(projected_embeddings)
        x = layers.Dense(projection_dims)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.Add()([projected_embeddings, x])
        projected_embeddings = layers.LayerNormalization()(x)
    return projected_embeddings
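
To make the data flow of `project_embeddings` concrete, here is a NumPy sketch of the same structure: a linear map into the shared space followed by residual blocks with layer normalization. It is an illustration under assumptions, not the Keras code above: weights are random, biases, dropout, and the learned layer-norm scale and offset are omitted, and GELU uses the tanh approximation. The names `gelu`, `layer_norm`, and `project` are made up for this sketch.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of GELU (an assumption made for this sketch).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def layer_norm(x, eps=1e-3):
    # Normalize each row to zero mean and (almost) unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def project(embeddings, w_in, w_blocks):
    # Linear map into the shared space, then one residual block per weight matrix.
    projected = embeddings @ w_in
    for w in w_blocks:
        x = gelu(projected) @ w
        projected = layer_norm(projected + x)
    return projected


rng = np.random.default_rng(0)
batch, in_dim, projection_dims = 4, 2048, 256
embeddings = rng.normal(size=(batch, in_dim))
w_in = rng.normal(scale=0.02, size=(in_dim, projection_dims))
w_blocks = [rng.normal(scale=0.02, size=(projection_dims, projection_dims))]

out = project(embeddings, w_in, w_blocks)
print(out.shape)  # (4, 256)
```

Whatever the input dimensionality, the output has `projection_dims` columns, which is what lets image and text embeddings be compared with a dot product.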


## Implement the vision encoder

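In this example, we use [Xception](https://keras.io/api/applications/xception/)
from [Keras Applications](https://keras.io/api/applications/) as the default base for the
vision encoder (`alg="xception"`). The `create_vision_encoder` function below also accepts
`alg="custom"`, which builds a small stack of separable-convolution blocks, and `alg="vit"`,
which relies on the ViT classes that are currently disabled inside the string literal at the
top of the cell.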

In [8]:
# Load the Vision Transformer (ViT) model and image processor
'''
vit_model = TFAutoModel.from_pretrained("google/vit-base-patch16-224-in21k")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

class ImagePreprocessingLayer(layers.Layer):
    def __init__(self, **kwargs):
        super(ImagePreprocessingLayer, self).__init__(**kwargs)

    def call(self, inputs):
        # Ensure correct dtype
        inputs = tf.cast(inputs, tf.float32)

        # Normalize images to [0, 1] (ViT expects pixel values between 0-1)
        inputs = inputs / 255.0

        # Ensure correct shape for ViT (batch, channels, height, width)
        inputs = tf.image.resize(inputs, (224, 224))  # Resize images
        inputs = tf.transpose(inputs, perm=[0, 3, 1, 2])  # ViT expects channels-first

        return inputs

class ViTEmbeddingLayer(layers.Layer):
    def __init__(self, vit_model, **kwargs):
        super(ViTEmbeddingLayer, self).__init__(**kwargs)
        self.vit_model = vit_model  # Store ViT model
        self.vit_model.trainable = True  # Keep ViT weights trainable; set to False to freeze them

    def call(self, inputs):
        vit_outputs = self.vit_model(inputs)  # Get ViT model output
        cls_embeddings = vit_outputs.last_hidden_state[:, 0, :]  # Extract CLS token
        return cls_embeddings
'''

def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False, go_back=0, project=True, alg="xception"
):
    base_vision_model = None

    # Receive the images as inputs.
    inputs = layers.Input(shape=(W, H, C), name="image_input")

    embeddings = None

    if alg == "custom":
      def block(x, i):
        org = x
        x = layers.SeparableConv2D(32, (3, 3), padding='same', name=f'sepconv1{i}')(x)
        x = layers.BatchNormalization(name=f'bn1{i}')(x)
        x = layers.Activation('relu', name=f'relu1{i}')(x)

        x = layers.SeparableConv2D(64, (3, 3), padding='same', name=f'sepconv2{i}')(x)
        x = layers.BatchNormalization(name=f'bn2{i}')(x)
        x = layers.Activation('relu', name=f'relu2{i}')(x)

        shortcut = layers.SeparableConv2D(64, (1, 1), padding='same', name=f'shortcut{i}')(org)
        shortcut = layers.BatchNormalization(name=f'shortcut_bn{i}')(shortcut)
        shortcut = layers.Activation('relu', name=f'shortcut_relu{i}')(shortcut)

        x = layers.add([x, shortcut], name=f'add{i}')
        x = layers.Activation('relu', name=f'relu3{i}')(x)

        return x

      x = inputs
      for i in range(6):
        x = block(x, i)

      embeddings = layers.GlobalAveragePooling2D()(x)
    elif alg == "vit":
      processed_inputs = ImagePreprocessingLayer()(inputs)
      embeddings = ViTEmbeddingLayer(vit_model)(processed_inputs)
      base_vision_model = vit_model
    else:
      xception = keras.applications.Xception(
          include_top=False, weights="imagenet", pooling="avg"
      )
      base_vision_model = xception
      for layer in xception.layers:
          layer.trainable = False
      if go_back > 0:
          for layer in xception.layers[-go_back:]:
              layer.trainable = True
      xception_input = tf.keras.applications.xception.preprocess_input(inputs)
      embeddings = xception(xception_input)

    outputs = embeddings
    # Project the embeddings produced by the model.
    if project:
      outputs = project_embeddings(
          embeddings, num_projection_layers, projection_dims, dropout_rate
      )
    # Create the vision encoder model.
    return base_vision_model, keras.Model(inputs, outputs, name="vision_encoder")


In [9]:
text_dataset = get_text_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), 256)

## Implement the text encoder

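We use [BERT](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1)
from [TensorFlow Hub](https://tfhub.dev) as the default text encoder (`alg="bert"`).
`create_text_encoder` also supports `alg="lstm"`, `alg="bert_lstm"`, and `alg="transformer"`;
the last option uses the `TransformerBlock` and `TokenAndPositionEmbedding` layers defined below.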

In [10]:
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(0, maxlen, 1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
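
The following NumPy sketch illustrates what `TokenAndPositionEmbedding` computes: each token id is looked up in a token table, and the vector for its position is added. The lookup tables are random stand-ins for the two `Embedding` layers and the token ids are made up. The final averaging step mirrors the `GlobalAveragePooling1D` call used later in the Transformer text encoder, but the Transformer blocks in between are skipped here.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, maxlen, embed_dim = 10000, 100, 128

# Hypothetical, randomly initialised tables standing in for the two Embedding layers.
token_table = rng.normal(size=(vocab_size, embed_dim))
position_table = rng.normal(size=(maxlen, embed_dim))

# A batch of 2 already-tokenised captions, padded with 0 to length 6.
token_ids = np.array([[12, 7, 431, 0, 0, 0],
                      [5, 98, 3, 77, 2, 0]])

seq_len = token_ids.shape[-1]
token_vectors = token_table[token_ids]        # (2, 6, 128)
position_vectors = position_table[:seq_len]   # (6, 128), broadcast over the batch
embedded = token_vectors + position_vectors   # (2, 6, 128)

# Averaging over the sequence axis yields one vector per caption.
sentence_vectors = embedded.mean(axis=1)
print(embedded.shape, sentence_vectors.shape)  # (2, 6, 128) (2, 128)
```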

In [11]:

def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False, project=True, adapt=0, alg="bert"
):
    # Load the BERT preprocessing module.
    preprocess = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2",
        name="text_preprocessing",
    )
    # Load the pre-trained BERT model to be used as the base encoder.
    bert = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
        name="bert",
    )
    # Set the trainability of the base encoder.
    bert.trainable = trainable
    text_base_model = bert

    # Receive the text as inputs.
    inputs = layers.Input(shape=(), dtype=tf.string, name="text_input")

    def bert_process(x):
      # Preprocess the text.
      bert_inputs = preprocess(x)
      # Generate embeddings for the preprocessed text using the BERT model.
      embeddings = bert(bert_inputs)["pooled_output"]
      return embeddings

    def bert_lstm(x):
      # Preprocess the text.
      bert_inputs = preprocess(x)
      # Generate embeddings for the preprocessed text using the BERT model.
      bert_outputs = bert(bert_inputs)
      hidden_states = bert_outputs["sequence_output"]
      embeddings = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(hidden_states)
      embeddings = layers.Bidirectional(layers.LSTM(128, return_sequences=False, dropout=0.1, recurrent_dropout=0.1))(embeddings)
      return embeddings

    def lstm_process(x):
      max_tokens = 10000
      embedding_dim = 256
      lstm_units = 256
      e = x
      # Step 1: Tokenize the text
      tv = layers.TextVectorization(max_tokens=max_tokens, name="tv", standardize="lower_and_strip_punctuation")
      tv.adapt(text_dataset.map(lambda x: x))
      e = tv(e)
      # Step 2: Embed the tokenized text
      e = layers.Embedding(max_tokens, embedding_dim, name="embedding")(e)
      # Step 3: Use an LSTM layer to get a sentence embedding
      e = layers.Bidirectional(layers.LSTM(lstm_units, name="lstm1", return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(e)
      e = layers.LayerNormalization(name="layer_norm1")(e)
      e = layers.Bidirectional(layers.LSTM(lstm_units, name="lstm2", return_sequences=True, dropout=0.2, recurrent_dropout=0.2))(e)
      e = layers.LayerNormalization(name="layer_norm2")(e)
      e = layers.GlobalAveragePooling1D()(e)
      #e = layers.LSTM(lstm_units, name="lstm3", return_sequences=False, dropout=0.2, recurrent_dropout=0.2)(e)
      #e = layers.LayerNormalization(name="layer_norm3")(e)
      return e

    def transformer_process(x):
      e = x
      max_tokens = 10000
      output_sequence_length = 100

      tv = layers.TextVectorization(
        max_tokens=max_tokens,
        output_mode='int',
        output_sequence_length=output_sequence_length,
        standardize="lower_and_strip_punctuation", name="tv")
      tv.adapt(text_dataset.map(lambda x: x))
      e = tv(e)

      embedding_layer = TokenAndPositionEmbedding(maxlen=output_sequence_length, vocab_size=max_tokens, embed_dim=128)
      e = embedding_layer(e)

      for i in range(3):
        e = TransformerBlock(128, 8, 128, rate=0.1)(e)

      e = layers.GlobalAveragePooling1D()(e)

      return e


    # Generate the embeddings for the text using the selected algorithm.
    embeddings = None
    if alg == "bert":
      embeddings = bert_process(inputs)
    elif alg == "lstm":
      embeddings = lstm_process(inputs)
    elif alg == "bert_lstm":
      embeddings = bert_lstm(inputs)
    elif alg == "transformer":
      embeddings = transformer_process(inputs)
    else:
      raise ValueError("Invalid algorithm")

    # Project the embeddings produced by the model.
    outputs = embeddings
    if project:
      outputs = project_embeddings(
          embeddings, num_projection_layers, projection_dims, dropout_rate
      )
    # Adapt dims if needed.
    if adapt > 0:
      outputs = layers.Dense(adapt)(outputs)
    # Create the text encoder model.
    return text_base_model, keras.Model(inputs, outputs, name="text_encoder")


## Implement the dual encoder

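To calculate the loss, we compute the pairwise dot-product similarity between
each `caption_i` and `image_j` in the batch as the predictions.
The target similarity between `caption_i` and `image_j` is computed as
the average of (the dot-product similarity between `caption_i` and `caption_j`)
and (the dot-product similarity between `image_i` and `image_j`).
Then, we use crossentropy to compute the loss between the targets and the predictions.
The `mature_size` argument selects which similarities form the targets: `'both'` uses the
average described above, `'caption'` uses only the caption-caption similarities, and any
other value uses only the image-image similarities.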
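
The computation can be sketched in plain NumPy. The function below follows the `'both'` branch of `compute_loss` in the cell that follows (averaged similarities as soft targets, cross-entropy in both directions), using random embeddings. It is meant as a reading aid rather than a replacement for the Keras code; `soft_target_loss` and the helper names are made up for this sketch.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def soft_target_loss(caption_emb, image_emb, temperature=1.0):
    caption_emb = l2_normalize(caption_emb)
    image_emb = l2_normalize(image_emb)
    # logits[i, j] = similarity(caption_i, image_j) / temperature
    logits = caption_emb @ image_emb.T / temperature
    # Targets: softmax of the averaged caption-caption and image-image similarities.
    targets = softmax(
        (caption_emb @ caption_emb.T + image_emb @ image_emb.T) / (2 * temperature)
    )
    # Cross-entropy between the soft targets and the softmax of the logits, per row.
    captions_loss = -(targets * np.log(softmax(logits))).sum(axis=1)
    images_loss = -(targets * np.log(softmax(logits.T))).sum(axis=1)
    return (captions_loss + images_loss) / 2, targets


rng = np.random.default_rng(0)
captions = rng.normal(size=(4, 8))
images = rng.normal(size=(4, 8))
loss, targets = soft_target_loss(captions, images, temperature=0.5)
print(loss.shape)  # (4,)
```

Because the targets are a softmax over similarities rather than a one-hot identity matrix, near-duplicate captions or images in a batch are not penalized as hard negatives.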

In [12]:

class DualEncoder(keras.Model):
    def __init__(self, text_encoder, image_encoder, temperature=1.0, mature_size='both', **kwargs):
        super().__init__(**kwargs)
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.temperature = temperature
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mature_size = mature_size

    @property
    def metrics(self):
        return [self.loss_tracker]

    def call(self, features, training=False):
        # Place each encoder on a separate GPU (if available).
        # TF will fallback on available devices if there are fewer than 2 GPUs.
        with tf.device("/gpu:0"):
            # Get the embeddings for the captions.
            caption_embeddings = self.text_encoder(features["caption"], training=training)
        with tf.device("/gpu:1"):
            # Get the embeddings for the images.
            image_embeddings = self.image_encoder(features["image"], training=training)
        return caption_embeddings, image_embeddings

    def compute_loss(self, caption_embeddings, image_embeddings):
        caption_embeddings = tf.math.l2_normalize(caption_embeddings, axis=1)
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)

        # logits[i][j] is the dot_similarity(caption_i, image_j).
        logits = (
            tf.matmul(caption_embeddings, image_embeddings, transpose_b=True)
            / self.temperature
        )

        cap_to_img = logits
        img_to_cap = tf.transpose(logits)

        # images_similarity[i][j] is the dot_similarity(image_i, image_j).
        images_similarity = tf.matmul(
            image_embeddings, image_embeddings, transpose_b=True
        )
        # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
        captions_similarity = tf.matmul(
            caption_embeddings, caption_embeddings, transpose_b=True
        )
        # Get the targets
        if self.mature_size == 'both':
          targets = (captions_similarity + images_similarity) / (2 * self.temperature)
        elif self.mature_size == 'caption':
          targets = captions_similarity / self.temperature
        else:
          targets = images_similarity / self.temperature

        targets = keras.activations.softmax(targets)

        # Compute the loss for the captions using crossentropy
        captions_loss = keras.losses.categorical_crossentropy(
            y_true=targets, y_pred=cap_to_img, from_logits=True
        )
        # Compute the loss for the images using crossentropy
        images_loss = keras.losses.categorical_crossentropy(
            y_true=targets, y_pred=img_to_cap, from_logits=True
        )

        if self.mature_size == 'both':
          # Return the per-example average of the caption and image losses.
          return (captions_loss + images_loss) / 2
        elif self.mature_size == 'caption':
          return images_loss
        else:
          return captions_loss

    def train_step(self, features):
        with tf.GradientTape() as tape:
            # Forward pass
            caption_embeddings, image_embeddings = self(features, training=True)
            loss = self.compute_loss(caption_embeddings, image_embeddings)
        # Backward pass
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Monitor loss
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, features):
        caption_embeddings, image_embeddings = self(features, training=False)
        loss = self.compute_loss(caption_embeddings, image_embeddings)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


## Train the dual encoder model

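In this experiment, we keep the pre-trained Xception vision encoder frozen and use its
2048-dimensional pooled output directly (`project=False`). The text encoder is the small
Transformer (`alg="transformer"`), which is trained from scratch; a final `Dense` layer
(`adapt=2048`) maps its output to the same dimensionality as the image embeddings.
Since `mature_size="image"`, the targets are built from the image-image similarities only.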

In [13]:
num_epochs = 20  # In practice, train for at least 30 epochs
batch_size = 256

base_vision_model, vision_encoder = create_vision_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1, trainable=False, go_back=0, project=False, alg="xception"
)
base_text_model, text_encoder = create_text_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1, trainable=False, project=False, adapt=2048, alg="transformer"
)

dual_encoder = DualEncoder(text_encoder, vision_encoder, mature_size="image", temperature=0.02)
dual_encoder.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001, weight_decay=0.001)
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5


Note that training the model with 60,000 image-caption pairs, with a batch size of 256,
takes around 12 minutes per epoch using a V100 GPU accelerator. If 2 GPUs are available,
the epoch takes around 8 minutes.
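
For reference, the number of optimizer steps per epoch follows directly from these two numbers. The short sketch below recomputes it; the variable names are chosen for this illustration only.

```python
import math

num_pairs = 60000   # image-caption pairs written to the training TFRecords
batch_size = 256

steps_per_epoch = math.ceil(num_pairs / batch_size)
# The final batch is smaller because 60000 is not a multiple of 256.
last_batch_size = num_pairs - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch_size)  # 235 96
```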

In [None]:
vision_encoder.summary()
text_encoder.summary()

In [15]:
vision_save_path = f"{bgPath}/vision_encoder_custom_bert"
text_save_path = f"{bgPath}/text_encoder_custom_bert"
vision_save_path_lstm = f"{bgPath}/vision_encoder_custom_lstm"
text_save_path_lstm = f"{bgPath}/text_encoder_custom_lstm"
vision_save_path_bert_lstm = f"{bgPath}/vision_encoder_bert_lstm"
text_save_path_bert_lstm = f"{bgPath}/text_encoder_bert_lstm"
vision_save_path_vit = f"{bgPath}/vision_encoder_vit"
text_save_path_vit = f"{bgPath}/text_encoder_vit"
vision_save_path_transformer = f"{bgPath}/vision_encoder_transformer"
text_save_path_transformer = f"{bgPath}/text_encoder_transformer"

In [16]:
train_dataset = get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), batch_size)
valid_dataset = get_dataset(os.path.join(tfrecords_dir, "valid-*.tfrecord"), batch_size)

In [17]:
def read_image(image_path):
    image_array = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
    return tf.image.resize(image_array, (H, W))  # size is (height, width)

def generate_image_embeddings(image_paths, batch_size=256):
  #print(f"Generating embeddings for {len(image_paths)} images...")
  image_embeddings = vision_encoder.predict(
      tf.data.Dataset.from_tensor_slices(image_paths).map(read_image).batch(batch_size),
      verbose=1,
  )
  #print(f"Image embeddings shape: {image_embeddings.shape}.")
  return image_embeddings

def generate_text_embeddings(queries):
  #print(f"Generating embeddings for {len(queries)} queries...")
  query_embedding = text_encoder(tf.convert_to_tensor(queries))
  #print(f"Query embeddings shape: {query_embedding.shape}.")
  return query_embedding

def find_matches(image_paths, queries, ieb=None, qeb=None, k=9, normalize=True):
    image_embeddings = ieb
    query_embedding = qeb
    if query_embedding is None:
      query_embedding = generate_text_embeddings(queries)
    if image_embeddings is None:
      image_embeddings = generate_image_embeddings(image_paths)
    print(image_embeddings.shape)
    print(query_embedding.shape)
    if normalize:
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)

    dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)

    results = tf.math.top_k(dot_similarity, k).indices.numpy()

    return [[image_paths[idx] for idx in indices] for indices in results]

def compute_top_k_accuracy(image_paths, tpk_image_embeddings=None, tpk=100):
    hits = 0
    num_batches = int(np.ceil(len(image_paths) / batch_size))
    for idx in tqdm(range(num_batches)):
        start_idx = idx * batch_size
        end_idx = start_idx + batch_size
        current_image_paths = image_paths[start_idx:end_idx]
        queries = [
            image_path_to_caption[image_path][0] for image_path in current_image_paths
        ]
        result = find_matches(image_paths, queries, ieb=tpk_image_embeddings, k=tpk)
        hits += sum(
            [
                image_path in matches
                for (image_path, matches) in list(zip(current_image_paths, result))
            ]
        )

    return hits / len(image_paths)

def get_performance(train_image_embeddings=None, test_image_embeddings=None):
  print("Scoring training data...")
  train_accuracy = compute_top_k_accuracy(train_image_paths, tpk_image_embeddings=train_image_embeddings, tpk=100)
  print(f"Train accuracy: {round(train_accuracy * 100, 3)}%")

  print("Scoring evaluation data...")
  eval_accuracy = compute_top_k_accuracy(image_paths[train_size:], tpk_image_embeddings=test_image_embeddings, tpk=100)
  print(f"Eval accuracy: {round(eval_accuracy * 100, 3)}%")

In [None]:
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.2, patience=2
)
# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)

class PerformanceReview(tf.keras.callbacks.Callback):
    def __init__(self, fixed_image=True):
        super(PerformanceReview, self).__init__()
        if fixed_image:
          self.train_embeddings = generate_image_embeddings(train_image_paths)
          self.test_embeddings = generate_image_embeddings(image_paths[train_size:])
        else:
          self.train_embeddings = None
          self.test_embeddings = None

    def on_epoch_end(self, epoch, logs=None):
        get_performance(self.train_embeddings, self.test_embeddings)

pr = PerformanceReview(fixed_image=True)



In [None]:
print(f"Number of GPUs: {len(tf.config.list_physical_devices('GPU'))}")
print(f"Number of examples (caption-image pairs): {train_example_count}")
print(f"Batch size: {batch_size}")
print(f"Steps per epoch: {int(np.ceil(train_example_count / batch_size))}")

history = dual_encoder.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=valid_dataset,
    callbacks=[pr, reduce_lr, early_stopping],
)

Number of GPUs: 0
Number of examples (caption-image pairs): 60000
Batch size: 256
Steps per epoch: 235
Epoch 1/20
    235/Unknown - 1511s 6s/step - loss: 5.5670

In [None]:
print("Training completed. Saving vision and text encoders...")
vision_encoder.save(vision_save_path_transformer)
text_encoder.save(text_save_path_transformer)
#vision_encoder.save_weights(f"{bgPath}/vision_encoder_vit/vit_vision_weights.h5")
#text_encoder.save_weights(f"{bgPath}/vision_encoder_vit/vit_text_weights.h5")
print("Models are saved.")

Plotting the training loss:

In [None]:
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend(["train", "valid"], loc="upper right")
plt.show()

### Retrieve relevant images

In this example, we use exact matching by computing the dot product similarity
between the input query embedding and the image embeddings, and retrieve the top k
matches. However, *approximate* similarity matching, using frameworks like
[ScaNN](https://github.com/google-research/google-research/tree/master/scann),
[Annoy](https://github.com/spotify/annoy), or [Faiss](https://github.com/facebookresearch/faiss)
is preferred in real-time use cases to scale with a large number of images.

Set the query strings (`q1` to `q8` in the cell below) to the type of images you want to search for.
Try things like: 'a plate of healthy food',
'a woman wearing a hat is walking down a sidewalk',
'a bird sits near to the water', or 'wild animals are standing in a field'.
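
The exact-matching step itself is small. As a minimal NumPy sketch (toy two-dimensional embeddings with made-up values, and a hypothetical helper name `top_k_matches`), normalizing both sides and taking a dot product gives cosine similarities, and sorting each row yields the top k image indices:

```python
import numpy as np


def top_k_matches(query_emb, image_emb, k):
    """Return, for each query, the indices of the k most similar images."""
    query_emb = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    similarity = query_emb @ image_emb.T  # (num_queries, num_images)
    # Sort each row by descending similarity and keep the first k columns.
    return np.argsort(-similarity, axis=1)[:, :k]


# Four toy "image" embeddings and one toy "query" embedding (made-up numbers).
image_emb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
query_emb = np.array([[2.0, 0.1]])

print(top_k_matches(query_emb, image_emb, k=2))  # [[0 2]]
```

The full sort costs O(n log n) per query, which is fine for a notebook but is one reason the approximate libraries mentioned above are preferred at scale.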

In [None]:
q1 = "a family standing next to the ocean on a sandy beach with a surf board"
q2 = "handsome men are walking on the street"
q3 = "people are working hard in the office"
q4 = "there are mountains and lakes under blue sky"
q5 = "the houses are very colorful and the streets are neat"
q6 = "there are shoes"
q7 = "this is a listing on a commerce website"
q8 = "a lot of children are in the classroom and the teacher is really strict"
image_embeddings = generate_image_embeddings(image_paths)
g = 5
for q in [q1, q2, q3, q4, q5, q6, q7, q8]:
    matches = find_matches(image_paths, [q], ieb=image_embeddings, normalize=True, k=g*g)[0]
    print(q)
    plt.figure(figsize=(20, 20))
    for i in range(g * g):
        ax = plt.subplot(g, g, i + 1)
        plt.imshow(mpimg.imread(matches[i]))
        plt.axis("on")
    print("======================================================")

## Train the model with the base encoders unfrozen

In [None]:
'''
back_vision = 11
back_text = 11
if base_vision_model is not None and base_text_model is not None:
  for i in range(-back_vision, 0):
    base_vision_model.layers[i].trainable = True
  for i in range(-back_text, 0):
    base_text_model.layers[i].trainable = True
'''

num_epochs2 = 10

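if base_vision_model is not None:
  base_vision_model.trainable = True
if base_text_model is not None:
  base_text_model.trainable = True

# Recompile so that the changed `trainable` flags take effect
# (note that this also resets the optimizer state).
dual_encoder.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001, weight_decay=0.001)
)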

text_encoder.summary()
vision_encoder.summary()

history = dual_encoder.fit(
    train_dataset,
    epochs=num_epochs2,
    validation_data=valid_dataset,
    callbacks=[reduce_lr, early_stopping],
)

#vision_encoder.save_weights(f"{bgPath}/vision_encoder_vit/vit_vision_weights.h5")
#text_encoder.save_weights(f"{bgPath}/vision_encoder_vit/vit_text_weights.h5")

print("Scoring training data...")
train_accuracy = compute_top_k_accuracy(train_image_paths)
print(f"Train accuracy: {round(train_accuracy * 100, 3)}%")

print("Scoring evaluation data...")
eval_accuracy = compute_top_k_accuracy(image_paths[train_size:], tpk=100)
print(f"Eval accuracy: {round(eval_accuracy * 100, 3)}%")

## Final remarks

You can obtain better results by increasing the size of the training sample,
training for more epochs, exploring other base encoders for images and text,
setting the base encoders to be trainable, and tuning the hyperparameters,
especially the `temperature` for the softmax in the loss computation.
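
To see why `temperature` matters, the sketch below applies a softmax to the same made-up similarity scores at three temperatures, including the 0.02 used in the training cell. Lower temperatures concentrate almost all probability on the best match, while higher ones spread it out. The scores and the helper are illustrative assumptions only.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()


# Made-up cosine similarities between one caption and four images.
similarities = np.array([0.9, 0.7, 0.2, -0.1])

for temperature in (1.0, 0.1, 0.02):
    probs = softmax(similarities / temperature)
    print(temperature, np.round(probs, 3))
```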

A trained example is available on Hugging Face:

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-nl%20image%20search-black.svg)](https://huggingface.co/keras-io/dual-encoder-image-search) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-nl%20image%20search-black.svg)](https://huggingface.co/spaces/keras-io/dual-encoder-image-search) |
