In [None]:
# installing tensorflow extra due to incompatibility with conda and tensorflow-text https://github.com/tensorflow/text/issues/644
!pip install transformers[tf] -q --upgrade
!pip install sentence-transformers -q # needed for validating results


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.8/83.8 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m451.2/451.2 kB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.7/527.7 kB[0m [31m26.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from transformers import BertTokenizer, TFBertModel, TFAutoModel
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import Model
import numpy as np

In [None]:
class TFSTLayer(tf.keras.layers.Layer):
    """Keras layer that turns tokenized text into one embedding per sequence.

    The layer wraps a pretrained Hugging Face TF model, mean-pools its token
    embeddings over the positions selected by the attention mask and, by
    default, L2-normalizes the pooled vectors.
    """

    def __init__(self, model_name: str, **kwargs) -> None:
        """Load the pretrained model that produces the token embeddings.

        Args:
            model_name: Name or path handed to `TFAutoModel.from_pretrained`.
            **kwargs: Optional keyword arguments forwarded to the Keras
                `Layer` base class.
        """
        super().__init__(**kwargs)
        self.tf_model = TFAutoModel.from_pretrained(model_name)

    def call(self, input_ids, attention_mask, token_type_ids, normalize=True):
        """Embed a batch of tokenized sequences.

        Args:
            input_ids: Token ids, shape=(B, max_seq_length).
            attention_mask: Mask with non-zero entries on the tokens to pool,
                shape=(B, max_seq_length).
            token_type_ids: Segment ids, shape=(B, max_seq_length).
            normalize: When true, L2-normalize every embedding.

        Returns:
            Pooled embeddings, shape=(B, n_embd).
        """
        # Name the optional inputs so they cannot be bound to the wrong
        # parameter of the wrapped model.
        output = self.tf_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Per-token embeddings, shape=(B, max_seq_length, n_embd)
        token_embeddings = output.last_hidden_state

        # Masked mean over the sequence axis, shape=(B, n_embd)
        embedding = self.mean_pooling(token_embeddings, attention_mask)

        if normalize:
            embedding, _ = tf.linalg.normalize(embedding, 2, axis=1)  # shape=(B, n_embd)

        return embedding

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average the token embeddings over the positions kept by the mask.

        Args:
            token_embeddings: shape=(B, max_seq_length, n_embd).
            attention_mask: shape=(B, max_seq_length); zero entries are
                excluded from the average.

        Returns:
            Mean embeddings, shape=(B, n_embd). A row whose mask is all zeros
            yields a zero vector instead of NaN.
        """
        # shape=(B, max_seq_length, 1); broadcasts against the embeddings below.
        mask = tf.cast(tf.expand_dims(attention_mask, axis=-1), token_embeddings.dtype)

        # Sum of the unmasked token embeddings, shape=(B, n_embd)
        summed_embeddings = tf.reduce_sum(token_embeddings * mask, axis=1)
        # Number of unmasked tokens per sequence, shape=(B, 1)
        token_counts = tf.reduce_sum(mask, axis=1)

        # Dividing by the token count is what makes this a mean rather than a
        # sum; divide_no_nan returns 0 where the count is 0.
        # Note: a plain `tf.keras.layers.GlobalAveragePooling1D()` call is only
        # equivalent if it is given the same mask; without one it also
        # averages over padding positions.
        return tf.math.divide_no_nan(summed_embeddings, token_counts)

In [None]:

# NOTE(review): the tokenizer comes from 'bert-base-uncased' while the model
# loaded in a later cell comes from a different checkpoint; confirm that both
# share the same vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Toy sentence pairs: every row is [sentence A, sentence B].
input_data = [
    ['sentence A1', 'sentence B1'],
    ['sentence A2', 'sentence B2'],
    ['sentence A3', 'sentence B3']
]

# Encode both sentences of every pair separately. `truncation=True` caps every
# encoding at `max_length`, so together with `padding='max_length'` each tensor
# is exactly 512 tokens long; without it, longer sentences would come back
# longer than 512 tokens.
tokenized_data = [(tokenizer(s1, padding='max_length', truncation=True, max_length=512, return_tensors='tf'),
                   tokenizer(s2, padding='max_length', truncation=True, max_length=512, return_tensors='tf'))
                  for s1, s2 in input_data]


In [None]:
# Identifier of the pretrained checkpoint to load.
model_name = 'sentence-transformers/all-MiniLM-L6-v2'
# Load the checkpoint as a plain TFBertModel; the recorded output shows the
# config and TF weights being downloaded on first use.
# NOTE(review): `bert_model` is not used by any cell visible here — confirm it
# is still needed.
bert_model = TFBertModel.from_pretrained(model_name)


Downloading (…)lve/main/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading tf_model.h5:   0%|          | 0.00/91.0M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFBertModel.

All the layers of TFBertModel were initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [None]:
# NOTE(review): np.shape coerces `tokenized_data` to an array before measuring
# it. The recorded result (3, 2, 3) appears to count 3 pairs x 2 sentences x
# 3 encoding keys rather than the token tensors — confirm.
np.shape(tokenized_data)

(3, 2, 3)

In [None]:
# Display the tokenized sentence pairs.
# NOTE(review): the recorded output is a string array of encoding key names,
# not the list of tuples built by the tokenization cell; presumably it was
# converted in another cell or run — confirm on a fresh kernel.
tokenized_data

array([[['input_ids', 'token_type_ids', 'attention_mask'],
        ['input_ids', 'token_type_ids', 'attention_mask']],

       [['input_ids', 'token_type_ids', 'attention_mask'],
        ['input_ids', 'token_type_ids', 'attention_mask']],

       [['input_ids', 'token_type_ids', 'attention_mask'],
        ['input_ids', 'token_type_ids', 'attention_mask']]], dtype='<U14')