# Multimodal Model
---

S.Yu. Papulin (papulin.study@yandex.ru)

### Contents

- [Loading Dataset](#Loading-Dataset)
- [Image Encoder](#Image-Encoder)
- [Text Encoder](#Text-Encoder)
- [Multimodal Model](#Multimodal-Model)
- [Pretrained `CLIP` model](#Pretrained-CLIP-model)
- [Sources](#Sources)

In [None]:
import tensorflow as tf
from tensorflow.keras import (
    layers, 
    models, 
    Model, 
    utils, 
    losses, 
    optimizers, 
    metrics
)

In [None]:
from tensorflow.keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

In [None]:
RANDOM_STATE = 100

## Loading Dataset

In [None]:
# Load dataset and show shape of data
(X_trainval, y_trainval), (X_test, y_test) = cifar10.load_data()
X_trainval.shape, y_trainval.shape, X_test.shape, y_test.shape

In [None]:
# Image value range
X_trainval.max(), X_trainval.min()

In [None]:
# Unique targets and their counts
np.unique(y_trainval, return_counts=True)

In [None]:
# First n targets
y_trainval[:5]

In [None]:
# Class labels
labels = np.array([
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
])

In [None]:
num_classes = len(labels)
num_classes

In [None]:
# Plot a reproducible random sample of images for every class (one row per class)
NUM_DISPLAY_IMAGES = 10
for target in range(num_classes):
    # row positions of all samples that belong to the current class
    indices = np.nonzero(y_trainval == target)[0]
    # re-seed for every class so that each row is reproducible on its own
    np.random.seed(RANDOM_STATE)
    indices_rnd = np.random.choice(indices, size=NUM_DISPLAY_IMAGES, replace=False)
    print(f'Class label: {labels[target]}')
    plt.figure(figsize=[10, 10])
    for i, image_index in enumerate(indices_rnd):
        plt.subplot(1, NUM_DISPLAY_IMAGES, i+1)
        # the title is the position of the image in X_trainval
        plt.title(image_index)
        plt.imshow(X_trainval[image_index])
        plt.axis("off")
    plt.show()

## Image Encoder

The image encoder represents an image as a vector in some multidimensional space. This vector should contain semantic information about the image. In our case, we start from an image classifier whose last layer is a classification layer with 10 neurons, one per class. We can get rid of this last layer and use the output of the preceding hidden layer as the image representation. Another option is to keep the last convolutional layer and average the output of its filters.

### Preparing dataset

In [None]:
# Compose train and validation subsets
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, 
    y_trainval, 
    test_size=0.1, 
    random_state=RANDOM_STATE
)

In [None]:
X_train.shape, X_val.shape

In [None]:
def convert_to_tf_dataset(X, y, batch_size=64, use_one_hot=False):
    """Wrap features and targets into a batched, prefetched `tf.data.Dataset`.

    Args:
        X: Array of images; cast to float32 and divided by 255.
        y: Array of targets; one-hot encoded when `use_one_hot` is set,
            flattened to one dimension otherwise.
        batch_size: Number of samples per batch.
        use_one_hot: Whether to one-hot encode the targets.

    Returns:
        Dataset yielding `(features, targets)` batches in the original order
        (no shuffling is applied).
    """
    features = X.astype('float32') / 255.0
    targets = utils.to_categorical(y) if use_one_hot else y.flatten()
    dataset = tf.data.Dataset.from_tensor_slices((features, targets))
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


def print_first_batch(ds):
    """Print the features and then the targets of the first batch of `ds`."""
    first_batch = ds.take(1)
    for features, targets in first_batch:
        print(features)
        print(targets)

In [None]:
imgrecog_train_ds = convert_to_tf_dataset(X_train, y_train)
imgrecog_val_ds = convert_to_tf_dataset(X_val, y_val)
imgrecog_test_ds = convert_to_tf_dataset(X_test, y_test)

In [None]:
# print_first_batch(imgrecog_train_ds)

### Image recognition

In [None]:
def load_pretrained_tiny_conv_model(model_name='tiny_conv_net_128@10.keras'):
    """
    Load a saved classifier from the `~/.keras/models` directory.

    Note: the model comes from the C5_NN_ImageRecognition notebook.
    If you use TinyConvModel, you need to have access to
    the implementation of this class.

    Args:
        model_name: File name of the saved model inside `~/.keras/models`.

    Returns:
        The model restored by `models.load_model`.
    """
    import os
    relative_path = os.path.join('~/.keras/models', model_name)
    # resolve "~" to the home directory of the current user
    return models.load_model(os.path.expanduser(relative_path))

In [None]:
def build_tiny_conv_model():
    """Build a small convolutional classifier for 32x32 RGB images.

    Returns:
        An uncompiled Sequential model named "ConvNet". Its last layer is a
        Dense layer with 10 units and no activation, i.e. it outputs logits.
    """
    stack = [
        layers.Input(shape=(32, 32, 3)),
        layers.Conv2D(16, (3, 3), activation="relu", padding="same", name="layer_1"),
        layers.MaxPooling2D((2, 2), name="transform_1"),
        layers.Conv2D(32, (3, 3), activation="relu", padding="same", name="layer_2"),
        layers.Dropout(0.1, name="dropout"),
        layers.Flatten(name="transform_2"),
        layers.Dense(128, activation="relu", name="layer_3"),
        layers.Dense(10, name="layer_4"),
    ]
    return models.Sequential(stack, name="ConvNet")

In [None]:
imgrecog_model = build_tiny_conv_model()

In [None]:
# Build a fresh classifier and configure its training.
# The loss uses from_logits=True because the last Dense layer has no activation.
imgrecog_model = build_tiny_conv_model()
imgrecog_model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3), 
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[metrics.SparseCategoricalAccuracy(),]
)
imgrecog_model.summary()

In [None]:
NUM_EPOCHS = 10

train_history = imgrecog_model.fit(
    imgrecog_train_ds,
    # validation_split=0.1,
    validation_data=imgrecog_val_ds,
    epochs=NUM_EPOCHS,
    verbose=1
)

In [None]:
# Evaluate the classifier on the test subset.
# NOTE(review): the model is compiled with SparseCategoricalAccuracy as its only
# metric, so `test_error` presumably holds the test accuracy rather than an
# error rate — confirm and consider renaming.
_, test_error = imgrecog_model.evaluate(imgrecog_test_ds)
test_error

### Building image encoder

In [None]:
# layers of classifier
imgrecog_model.layers

In [None]:
def build_image_encoder_on_tiny_conv():
    """Build an encoder from `imgrecog_model` without its last layer.

    Returns:
        A model that maps the classifier inputs to the output of the
        classifier's second-to-last layer, created with `trainable=False`.
    """
    penultimate_layer = imgrecog_model.layers[-2]
    encoder = Model(
        inputs=imgrecog_model.inputs,
        outputs=penultimate_layer.output,
        trainable=False
    )
    return encoder


def build_image_encoder_on_resnet():
    """Build a more advanced image encoder on top of the ResNet50 model."""
    resnet_options = {
        # drop the classification head
        "include_top": False,
        # average across all filters
        "pooling": "avg",
    }
    return tf.keras.applications.ResNet50(**resnet_options)

In [None]:
image_encoder = build_image_encoder_on_tiny_conv()

In [None]:
image_encoder.summary()

In [None]:
# A model can be frozen by making its weights non-trainable.
# In this example only the last layer is left trainable: the whole encoder is
# marked trainable first, then every layer except the last one is frozen.
image_encoder.trainable = True 

for layer in image_encoder.layers[:-1]:
    layer.trainable = False

# The summary shows the resulting trainable / non-trainable parameter counts
image_encoder.summary()

In [None]:
# Provide some images as input to check encoder output.
# The classifier behind the encoder was trained on pixels scaled to [0, 1]
# (see convert_to_tf_dataset), so the raw images are scaled the same way here.
image_embeddings = image_encoder(X_test[:10].astype('float32') / 255.0)
image_embeddings.shape

## Text Encoder

Similarly, the text encoder represents any text as a vector in some multidimensional space. For our text encoder, we take the pretrained embedding model that we discussed earlier (`glove`). Based on this model, we build a sequential model that averages the vectors of all words in the provided text.

### Loading embedding model

In [None]:
def load_vectors(path_to_file):
    """Load words and their embedding vectors from a GloVe-style text file.

    Every non-blank line is expected to hold a token followed by its
    whitespace-separated coefficients.

    Args:
        path_to_file: Path to the UTF-8 encoded text file.

    Returns:
        A tuple `(words, embeddings)` of NumPy arrays: the tokens in file
        order and the matching rows of float32 coefficients.
    """
    words = []
    embeddings = []
    # read as UTF-8 explicitly instead of relying on the platform default encoding
    with open(path_to_file, encoding="utf-8") as f:
        for line in f:
            # tolerate blank lines (e.g. an empty line at the end of the file)
            if not line.strip():
                continue
            word, coefs = line.split(maxsplit=1)
            words.append(word)
            embeddings.append(np.fromstring(coefs, dtype='f', sep=' '))
    return np.array(words), np.array(embeddings)


In [None]:
# Dimensionality of the embedding vectors; it also selects the file to load
EMBEDDING_DIM = 100
# NOTE(review): machine-specific absolute path — adjust it to the local GloVe files
FILEPATH = f'/media/sf_practice/data/debug_glove/glove.6B/glove.6B.{EMBEDDING_DIM}d.txt'

# Load words and their embeddings
words, embeddings = load_vectors(FILEPATH)
words[:5]

In [None]:
# Embeddings
embeddings.shape

### Building vectorizer layer

In [None]:
# Length of the token id sequence produced for every text
# (passed as output_sequence_length to the vectorizer layers below)
MAX_TEXT_LENGTH = 20
# Vocabulary size: all loaded words plus 2 extra ids
NUM_FEATURES = len(words) + 2


def build_vectorizer_layer():
    """Create a text vectorizer whose vocabulary is learned from `words`.

    Returns:
        A `TextVectorization` layer producing integer token ids, limited to
        `NUM_FEATURES` tokens and `MAX_TEXT_LENGTH` ids per text.
    """
    vectorizer = layers.TextVectorization(
        output_mode="int",
        output_sequence_length=MAX_TEXT_LENGTH,
        max_tokens=NUM_FEATURES,
    )
    # learn the vocabulary from the word list
    vectorizer.adapt(words)
    return vectorizer


def build_vectorizer_layer_alt():
    """Create a text vectorizer whose vocabulary is set directly to `words`.

    Returns:
        A `TextVectorization` layer producing integer token ids, limited to
        `NUM_FEATURES` tokens and `MAX_TEXT_LENGTH` ids per text.
    """
    vectorizer_config = dict(
        max_tokens=NUM_FEATURES,
        output_sequence_length=MAX_TEXT_LENGTH,
        output_mode="int",
    )
    vectorizer = layers.TextVectorization(**vectorizer_config)
    # assign the vocabulary explicitly instead of learning it
    vectorizer.set_vocabulary(words)
    return vectorizer

In [None]:
vectorizer_layer = build_vectorizer_layer_alt()
# Fetch the vocabulary once instead of rebuilding the full token list per print
vocabulary = vectorizer_layer.get_vocabulary()
print(f'Number of tokens:\t{len(vocabulary)}')
print(f'First few tokens:\t{vocabulary[:5]}')

In [None]:
# Convert text to token ids
token_ids = vectorizer_layer('A photo of a cat').numpy()
token_ids

### Building embedding layer

In [None]:
# Embedding matrix with one row per token id.
# There are 2 extra tokens: padding and [UNK]; rows 0 and 1 are left as zeros
# and the loaded vectors fill the remaining rows.
E = np.zeros((NUM_FEATURES, EMBEDDING_DIM))
# E[1] = np.random.normal(0, 0.1, EMBEDDING_DIM) # [UNK]
E[2:] = embeddings

In [None]:
def build_embedding_layer():
    """Create a frozen embedding layer whose weights are the matrix `E`.

    Returns:
        An `Embedding` layer with `NUM_FEATURES` rows of size `EMBEDDING_DIM`
        that is excluded from training.
    """
    frozen_embedding = layers.Embedding(
        output_dim=EMBEDDING_DIM,
        input_dim=NUM_FEATURES,
        trainable=False,  # keep the pretrained vectors fixed
    )
    # the weights must exist before they can be overwritten
    frozen_embedding.build((1, ))
    frozen_embedding.set_weights([E])
    return frozen_embedding

In [None]:
embedding_layer = build_embedding_layer()

In [None]:
# Pass token ids as example
embedding_layer(token_ids).shape

### Building text encoder

In [None]:
def build_text_encoder_on_glove():
    """Build a text encoder: vectorize, embed, then average over positions.

    Uses the module-level `vectorizer_layer` and `embedding_layer`.

    NOTE(review): `build_embedding_layer` creates the embedding without
    `mask_zero`, so padded positions presumably count towards the average —
    confirm this is intended.
    """
    return models.Sequential([
        vectorizer_layer,
        embedding_layer,
        layers.GlobalAveragePooling1D(),
        # optional head: layers.Dense(64, activation="relu")
    ])

In [None]:
text_encoder = build_text_encoder_on_glove()

In [None]:
labels

In [None]:
text_embeddings = text_encoder(labels)
text_embeddings.shape

## Multimodal Model

### Preparing dataset

In [None]:
def make_dataset(X, y, batch_size=64):
    """Build an image-caption dataset from images and class targets.

    Args:
        X: Array of images.
        y: Targets; the first element of every row is used as the class index.
        batch_size: Number of samples per batch.

    Returns:
        The dataset produced by `convert_to_tf_dataset`, where every target is
        replaced with a caption built from its class label.
    """
    class_ids = (row[0] for row in y)
    captions = np.array([f'A photo of a {labels[class_id]}' for class_id in class_ids])
    return convert_to_tf_dataset(X, captions, batch_size)

In [None]:
train_ds = make_dataset(X_train, y_train, batch_size=128)
val_ds = make_dataset(X_val, y_val, batch_size=128)
test_ds = make_dataset(X_test, y_test, batch_size=128)

In [None]:
# print_first_batch(train_ds)

### CLIP-like model

CLIP (Contrastive Language-Image Pretraining)

In [None]:
class ProjectionLayer(layers.Layer):
    """
    Layer to project representation of some modality

    The features go through a dense layer, dropout (active only during
    training) and layer normalization.

    Args:
        projection_dim: Size of the projected representation.
        dropout_rate: Fraction of the projected units dropped during training.
        **kwargs: Standard `layers.Layer` keyword arguments (e.g. `name`).
    """

    def __init__(self, projection_dim=64, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(projection_dim)
        # `dropout_rate` used to be accepted but silently ignored
        self.dropout = layers.Dropout(dropout_rate)
        self.layer_norm = layers.LayerNormalization()

    def call(self, inputs, training=False):
        x = self.dense(inputs)
        x = self.dropout(x, training=training)
        x = self.layer_norm(x)
        return x


In [None]:
class VanillaCLIPModel(Model):
    """
    Simple CLIP-like model that learns a joint image-text representation.

    Both modalities are encoded, projected into a shared space and
    L2-normalized. Training minimizes a symmetric contrastive loss in which
    the i-th image and the i-th text of a batch form the only positive pair.

    Args:
        image_encoder: Model mapping a batch of images to feature vectors.
        text_encoder: Model mapping a batch of texts to feature vectors.
        initial_temperature: Initial softmax temperature; the trainable
            `logit_scale` weight starts at log(1 / initial_temperature).
        projection_dim: Size of the shared multimodal space.
    """

    def __init__(self, image_encoder, text_encoder, initial_temperature=0.07, projection_dim=64):
        super().__init__()
        # encoders for modality representation
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # project to shared multimodal space
        self.image_projection = ProjectionLayer(projection_dim)
        self.text_projection = ProjectionLayer(projection_dim)
        # control logits range (stored in log space, see `temperature`)
        self.logit_scale = self.add_weight(
            initializer=tf.constant_initializer(np.log(1.0 / initial_temperature)),
            trainable=True,
            dtype=tf.float32,
            shape=(),
            name="logit_scale"
        )
        # metrics
        self.loss_tracker = metrics.Mean(name='loss')
        self.accuracy_tracker = metrics.Accuracy(name='accuracy')

    def call(self, inputs, training=False):
        """Return L2-normalized image and text projections for `(images, texts)`."""
        # unpack inputs
        I, T = inputs
        # encode inputs to get features
        I_features = self.image_encoder(I, training=training)
        T_features = self.text_encoder(T, training=training)
        # project to shared multimodal space (multimodal embeddings)
        I_projections = self.image_projection(I_features, training=training)
        T_projections = self.text_projection(T_features, training=training)
        # normalize projections
        I_projections = tf.math.l2_normalize(I_projections, axis=1)
        T_projections = tf.math.l2_normalize(T_projections, axis=1)
        return I_projections, T_projections

    def compute_contrastive_loss(self, I_projections, T_projections):
        """Compute the symmetric contrastive loss of a batch.

        Returns:
            A tuple `(loss, I_logits)`: the mean of the image-to-text and
            text-to-image cross entropies, and the scaled image-to-text
            similarity matrix.
        """
        # compute similarity matrix: image-to-text and text-to-image
        I_logits = tf.matmul(I_projections, T_projections, transpose_b=True)
        I_logits *= tf.exp(self.logit_scale)
        T_logits = tf.transpose(I_logits)
        # create targets (Y) in one-hot form: the matching pair is on the diagonal
        batch_size = tf.shape(I_projections)[0]
        Y = tf.eye(batch_size)
        # compute cross entropy losses
        I_loss = tf.keras.losses.categorical_crossentropy(
            y_true=Y,
            y_pred=I_logits,
            from_logits=True
        )
        T_loss = tf.keras.losses.categorical_crossentropy(
            y_true=Y,
            y_pred=T_logits,
            from_logits=True
        )
        total_loss = (I_loss + T_loss) / 2
        return tf.reduce_mean(total_loss), I_logits

    def _compute_loss_and_logits(self, data, training):
        """Run the forward pass on one `(images, texts)` batch and compute the loss."""
        I, T = data
        # go through `__call__` rather than `call` so that Keras applies its
        # own bookkeeping around the forward pass
        I_projections, T_projections = self((I, T), training=training)
        return self.compute_contrastive_loss(
            I_projections=I_projections,
            T_projections=T_projections
        )

    def _update_metrics(self, loss, logits):
        """Update the loss and accuracy trackers from the results of one batch."""
        self.loss_tracker.update_state(loss)
        # an image is classified correctly when its own text has the highest logit
        y_pred = tf.argmax(logits, axis=1)
        y_true = tf.range(tf.shape(logits)[0])
        self.accuracy_tracker.update_state(y_true, y_pred)

    # No explicit @tf.function: `fit()` already compiles the train step, and
    # leaving it undecorated keeps `run_eagerly=True` usable for debugging.
    def train_step(self, data):
        """Run one optimization step on a batch and return the training logs."""
        # record operation graph
        with tf.GradientTape() as tape:
            loss, logits = self._compute_loss_and_logits(data, training=True)
        # compute gradients and update weights
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # update metrics
        self._update_metrics(loss, logits)
        return {
            'loss': self.loss_tracker.result(),
            'accuracy': self.accuracy_tracker.result(),
            "temperature": self.temperature,
            "logit_scale": self.logit_scale
        }

    def test_step(self, data):
        """Evaluate one batch and return the validation logs."""
        loss, logits = self._compute_loss_and_logits(data, training=False)
        # update metrics
        self._update_metrics(loss, logits)
        return {
            "loss": self.loss_tracker.result(),
            "accuracy": self.accuracy_tracker.result()
        }

    @property
    def temperature(self):
        """Current softmax temperature, i.e. 1 / exp(logit_scale)."""
        return 1.0 / tf.exp(self.logit_scale)

In [None]:
# Build the model
clip_model = VanillaCLIPModel(
    image_encoder, 
    text_encoder, 
    initial_temperature=0.07
)
clip_model.summary()

In [None]:
clip_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4)
    # optimizer=tf.keras.optimizers.SGD(learning_rate=0.1)
)

In [None]:
NUM_EPOCHS = 5

clip_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    verbose=1
)

### Inference

In [None]:
from abc import ABC, abstractmethod

In [None]:
class CLIPInference(ABC):
    """Common ranking logic for CLIP-like models.

    Subclasses implement `project_image` and `project_text` and must provide
    a `logit_scale` attribute (log of the similarity scale), which
    `compute_similarity` reads.
    """

    @abstractmethod
    def project_image(self, image):
        """Return the normalized projection(s) of the given image(s)."""

    @abstractmethod
    def project_text(self, text):
        """Return the normalized projection(s) of the given text(s)."""

    def compute_similarity(self, image_projections, text_projections):
        """Return the pairwise dot products scaled by exp(logit_scale)."""
        return tf.matmul(image_projections, text_projections, transpose_b=True) * tf.exp(self.logit_scale)

    def rank_texts(self, image, texts):
        """Rank `texts` by their similarity to a single image."""
        I_projections = self.project_image(image)
        T_projections = self.project_text(texts)
        I_logits = self.compute_similarity(I_projections, T_projections)
        return self.format_rank_output(texts, I_logits)

    def rank_images(self, text, images):
        """Rank `images` by their similarity to a single text."""
        T_projections = self.project_text(text)
        # images are projected one at a time and stacked along the batch axis
        I_projections = tf.concat([self.project_image(image) for image in images], axis=0)
        T_logits = self.compute_similarity(T_projections, I_projections)
        return self.format_rank_output(images, T_logits)

    @staticmethod
    def format_rank_output(inputs, logits):
        """Sort the candidates of a single query by descending similarity.

        Args:
            inputs: Candidates (texts or images), one per column of `logits`.
            logits: Similarity matrix with one row per query. Only the first
                row is ranked, so that the three outputs always describe the
                same query.

        Returns:
            A tuple of NumPy arrays `(indices, values, probabilities)`:
            candidate indices from best to worst, the candidates in that
            order, and their softmax probabilities in that order.
        """
        query_logits = logits[0]
        indices = tf.argsort(query_logits, direction='DESCENDING')
        values = tf.gather(inputs, indices)
        probabilities = tf.nn.softmax(query_logits)
        probabilities_sorted = tf.gather(probabilities, indices)
        return indices.numpy(), values.numpy(), probabilities_sorted.numpy()


In [None]:
class VanillaCLIPInference(CLIPInference):
    """Inference helper that reuses the parts of a trained `VanillaCLIPModel`."""

    def __init__(self, model):
        self.image_encoder = model.image_encoder
        self.text_encoder = model.text_encoder
        self.image_projection = model.image_projection
        self.text_projection = model.text_projection
        self.logit_scale = model.logit_scale

    def project_image(self, image):
        """Project a single image or a batch of images to the shared space.

        Integer-typed images are treated as raw pixel values and scaled to
        [0, 1], matching the scaling applied to the training data in
        `convert_to_tf_dataset`. Floating-point images are used as given.
        """
        image = tf.convert_to_tensor(image)
        if image.dtype.is_integer:
            image = tf.cast(image, tf.float32) / 255.0
        # add the batch axis for a single image
        if len(image.shape) == 3:
            image = tf.expand_dims(image, 0)
        I_features = self.image_encoder(image, training=False)
        I_projections = self.image_projection(I_features, training=False)
        return tf.math.l2_normalize(I_projections, axis=1)

    def project_text(self, text):
        """Project a single string or a batch of texts to the shared space."""
        # wrap a single string into a batch of one
        if isinstance(text, str):
            text = tf.constant([text])
        T_features = self.text_encoder(text, training=False)
        T_projections = self.text_projection(T_features, training=False)
        return tf.math.l2_normalize(T_projections, axis=1)


In [None]:
inference = VanillaCLIPInference(clip_model)

**Projecting to multimodal space**

In [None]:
test_image = X_test[10]
test_text = "image of airplane"

plt.figure(figsize=(2,2))
plt.imshow(test_image)
plt.axis("off")
plt.show()

In [None]:
I_projections = inference.project_image(test_image)
I_projections

In [None]:
T_projections = inference.project_text(test_text)
T_projections

In [None]:
inference.compute_similarity(I_projections, T_projections)

**Ranking texts**

In [None]:
test_image = X_test[15]

plt.figure(figsize=(2,2))
plt.imshow(test_image)
plt.axis("off")
plt.show()

indices, values, probs = inference.rank_texts(test_image, labels)
indices, values, probs

**Ranking images**

In [None]:
test_text = "Image of an airplane"

In [None]:
indices, values, probs = inference.rank_images(test_text, X_test[:100])
indices, values.shape, probs

In [None]:
# Show the top-ranked images in a single row.
# The slice uses the same constant as the subplot grid so that the number of
# images can never exceed the number of grid columns.
plt.figure(figsize=[14, 4])
for index, image in enumerate(values[:NUM_DISPLAY_IMAGES]):
    plt.subplot(1, NUM_DISPLAY_IMAGES, index+1)
    plt.imshow(image)
    plt.axis("off")
plt.show()

## Pretrained `CLIP` model

In [None]:
from transformers import AutoProcessor, TFCLIPModel

### Loading model

In [None]:
CHECKPOINT = "openai/clip-vit-base-patch32"

In [None]:
processor = AutoProcessor.from_pretrained(CHECKPOINT)
model = TFCLIPModel.from_pretrained(CHECKPOINT)

In [None]:
# model.config

### Processor

In [None]:
IMAGE_INDEX = 1

In [None]:
# Single sample
image = X_test[IMAGE_INDEX]
# NOTE(review): rows of y_test look like length-1 arrays (make_dataset indexes
# target[0]), so `target` and `label` are presumably 1-element arrays rather
# than scalars — confirm
target = y_test[IMAGE_INDEX]
label = labels[target]
# Fixed caption for the processor demo; it does not depend on `label`
text = 'a photo of a cat'

image.shape, label

In [None]:
# Show image
plt.figure(figsize=(2,2))
plt.imshow(image)
plt.axis("off")
plt.show()

**Image processor**

In [None]:
image_input = processor.image_processor(image)
image_input.keys()

In [None]:
image_input.pixel_values[0].shape

**Text tokenizer**

In [None]:
text_input = processor.tokenizer(text)
text_input

**Processor**

In [None]:
inputs = processor(
    text=text, 
    images=image, 
    return_tensors="tf", 
    padding=True
)
inputs.keys()

In [None]:
I_pixels = inputs.pixel_values
T_ids = inputs.input_ids
T_mask = inputs.attention_mask

In [None]:
I_pixels.shape, T_ids, T_mask

### Inference

**Model outputs**

In [None]:
# Input single image and list of texts (class labels for simplicity)
images = image
texts = labels.tolist()

images.shape, texts

In [None]:
# Process model inputs
inputs = processor(
    text=texts, 
    images=images, 
    return_tensors="tf", 
    padding=True
)
inputs.keys()

In [None]:
# Inference
outputs = model(**inputs, training=False)
outputs.keys()

In [None]:
# Normalized projections
I_projections = outputs.image_embeds
T_projections = outputs.text_embeds

I_projections.shape, T_projections.shape

In [None]:
# Similarity matrix
I_logits = outputs.logits_per_image
T_logits = outputs.logits_per_text

I_logits.shape, T_logits.shape

In [None]:
# np.linalg.norm(I_projections.numpy(), axis=1)

In [None]:
# Unnormalized projections
I_projections_unnorm = model.get_image_features(
    pixel_values=inputs['pixel_values']
)
T_projections_unnorm = model.get_text_features(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask']
)

I_projections_unnorm.shape, T_projections_unnorm.shape

In [None]:
# Rank texts by image
CLIPInference.format_rank_output(texts, I_logits)

**Inference class implementation**

In [None]:
class TransformerCLIPInference(CLIPInference):
    """Inference helper built on a pretrained CLIP model and its processor.

    Args:
        model: Model exposing `get_image_features`, `get_text_features` and
            `trainable_variables`.
        processor: Callable that converts raw images / texts to model inputs.
    """

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.logit_scale = self._get_logit_scale(model)

    def _get_logit_scale(self, model):
        """Return the model's logit scale variable, or a default if not found."""
        for var in model.trainable_variables:
            if 'logit_scale' in var.name.lower():
                return var
        # compute_similarity applies exp() to this value, so the fallback is
        # log(1 / temperature) rather than the temperature itself
        return tf.Variable(np.log(1.0 / 0.07), dtype=tf.float32)

    def project_image(self, image):
        """Project a single image or a batch of images to the shared space."""
        if len(image.shape) == 3:
            image = tf.expand_dims(image, 0)
        # use the instance's processor and model, not the notebook globals
        inputs = self.processor(
            images=image,
            return_tensors="tf",
            padding=True
        )
        I_projections_unnorm = self.model.get_image_features(**inputs)
        return tf.math.l2_normalize(I_projections_unnorm, axis=1)

    def project_text(self, text):
        """Project a single string or a collection of texts to the shared space."""
        # the processor is given a plain Python list instead of a NumPy array
        if isinstance(text, np.ndarray):
            text = text.tolist()
        inputs = self.processor(
            text=text,
            return_tensors="tf",
            padding=True
        )
        T_projections_unnorm = self.model.get_text_features(**inputs)
        return tf.math.l2_normalize(T_projections_unnorm, axis=1)


In [None]:
inference = TransformerCLIPInference(model, processor)

**Ranking texts**

In [None]:
test_image = X_test[15]

plt.figure(figsize=(2,2))
plt.imshow(test_image)
plt.axis("off")
plt.show()

indices, values, probs = inference.rank_texts(test_image, labels)
indices, values, probs

**Ranking images**

In [None]:
test_text = "Image of a car"

In [None]:
indices, values, probs = inference.rank_images(test_text, X_test[:100])
indices, values.shape, probs

In [None]:
# Show the top-ranked images in a single row.
# The slice uses the same constant as the subplot grid so that the number of
# images can never exceed the number of grid columns.
plt.figure(figsize=[14, 4])
for index, image in enumerate(values[:NUM_DISPLAY_IMAGES]):
    plt.subplot(1, NUM_DISPLAY_IMAGES, index+1)
    plt.imshow(image)
    plt.axis("off")
plt.show()

## Sources

- [Radford, A., et al. Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
- [Transformers documentation: CLIP](https://huggingface.co/docs/transformers/model_doc/clip)