In [43]:
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer
from datasets import load_dataset
import numpy as np
import chromadb
from tqdm import tqdm

In [2]:
dataset = load_dataset("microsoft/ms_marco", "v1.1")

In [3]:
dataset["train"][0]

{'answers': ['Results-Based Accountability is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole.'],
 'passages': {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
  'passage_text': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",
   "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the C

In [4]:
def prepare_data(examples):
    """Turn a batch of raw rows into (query, positive, negatives) training triples.

    Args:
        examples: batch mapping with a 'query' list and a 'passages' list; each
            passages entry is a dict with parallel 'passage_text' and
            'is_selected' lists.

    Returns:
        Dict with parallel lists 'queries', 'positives' and 'negatives'. A row is
        kept only when it has a non-empty passage labelled 1; its first such
        passage becomes the positive, and at most five passages labelled 0
        become the negatives.
    """
    kept_queries, kept_positives, kept_negatives = [], [], []

    for query_text, passage_info in zip(examples['query'], examples['passages']):
        labelled = list(zip(passage_info['passage_text'], passage_info['is_selected']))

        # Only the first selected passage is used as the positive.
        first_positive = next((text for text, flag in labelled if flag == 1), None)
        if not first_positive:
            continue

        unselected = [text for text, flag in labelled if flag == 0]

        kept_queries.append(query_text)
        kept_positives.append(first_positive)
        kept_negatives.append(unselected[:5])

    return dict(queries=kept_queries, positives=kept_positives, negatives=kept_negatives)

In [5]:
dataset_processed = dataset.map(prepare_data, batched=True, remove_columns=dataset['train'].column_names)

In [6]:
dataset_processed["train"][0]

{'queries': 'what is rba',
 'positives': 'Results-Based Accountability® (also known as RBA) is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole. RBA is also used by organizations to improve the performance of their programs. Creating Community Impact with RBA. Community impact focuses on conditions of well-being for children, families and the community as a whole that a group of leaders is working collectively to improve. For example: “Residents with good jobs,” “Children ready for school,” or “A safe and clean neighborhood”.',
 'negatives': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have 

In [7]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [8]:
MAX_TOKENS_LEN = 224

In [9]:
def tokenize_function(examples):
    """Tokenize a batch of query / positive / negatives rows with the shared tokenizer.

    Every text is padded or truncated to MAX_TOKENS_LEN tokens.

    Args:
        examples: batch mapping with parallel lists 'queries', 'positives'
            (one string per row) and 'negatives' (a list of strings per row).

    Returns:
        Dict with stacked tensors for the query and positive ids/masks, and
        per-row lists of tensors for the negatives (rows may hold different
        numbers of negatives, so those are not stacked).
    """

    def encode(text):
        # Tokenize a single string and take row 0 of the returned ids and mask.
        encoded = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=MAX_TOKENS_LEN,
            return_tensors="tf"
        )
        return encoded["input_ids"][0], encoded["attention_mask"][0]

    query_ids, query_masks = [], []
    positive_ids, positive_masks = [], []
    negative_ids, negative_masks = [], []

    for row, query_text in enumerate(examples["queries"]):
        ids, mask = encode(query_text)
        query_ids.append(ids)
        query_masks.append(mask)

        ids, mask = encode(examples["positives"][row])
        positive_ids.append(ids)
        positive_masks.append(mask)

        encoded_negatives = [encode(neg_text) for neg_text in examples["negatives"][row]]
        negative_ids.append([neg_ids for neg_ids, _ in encoded_negatives])
        negative_masks.append([neg_mask for _, neg_mask in encoded_negatives])

    return {
        "query_input_ids": tf.stack(query_ids),
        "query_attention_mask": tf.stack(query_masks),
        "positives_input_ids": tf.stack(positive_ids),
        "positives_attention_mask": tf.stack(positive_masks),
        "negatives_input_ids": negative_ids,
        "negatives_attention_mask": negative_masks
    }

In [10]:
tokenized_dataset = dataset_processed.map(tokenize_function, batched=True, remove_columns=dataset_processed['train'].column_names)

In [11]:
# Examples per batch; used by the tf.data pipelines and by the step-count computation.
BATCH_SIZE = 32
# Number of training epochs passed to fit().
EPOCHS = 1
# Negatives per query required by the dataset generators: rows with fewer are
# skipped and longer lists are cut to this length.
NUM_NEGATIVES = 5

In [13]:
def dataset_generator():
    """Yield training examples from ``tokenized_dataset["train"]`` as int32 arrays.

    Rows with fewer than NUM_NEGATIVES negatives are skipped, so the number of
    yielded examples can be smaller than the size of the split.

    Yields:
        Tuple ``(query_ids, query_mask, pos_ids, pos_mask, neg_ids, neg_mask)``;
        the negative arrays hold exactly NUM_NEGATIVES rows.
    """
    for item in tokenized_dataset["train"]:
        if len(item["negatives_input_ids"]) < NUM_NEGATIVES:
            continue

        query_ids = np.array(item["query_input_ids"], dtype=np.int32)
        query_mask = np.array(item["query_attention_mask"], dtype=np.int32)

        pos_ids = np.array(item["positives_input_ids"], dtype=np.int32)
        pos_mask = np.array(item["positives_attention_mask"], dtype=np.int32)

        neg_ids = np.array(item["negatives_input_ids"][:NUM_NEGATIVES], dtype=np.int32)
        neg_mask = np.array(item["negatives_attention_mask"][:NUM_NEGATIVES], dtype=np.int32)

        yield (query_ids, query_mask, pos_ids, pos_mask, neg_ids, neg_mask)

# Shapes come from the shared constants so the signature cannot drift from the
# tokenizer's max_length or from the generator's negative count.
output_signature = (
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(NUM_NEGATIVES, MAX_TOKENS_LEN), dtype=tf.int32),
    tf.TensorSpec(shape=(NUM_NEGATIVES, MAX_TOKENS_LEN), dtype=tf.int32)
)

# cache() comes BEFORE shuffle(): a cache replays exactly the elements it stored,
# so caching after shuffle/batch would freeze the first epoch's order and batch
# composition for every later epoch.
tf_train_dataset = (
    tf.data.Dataset.from_generator(
        dataset_generator,
        output_signature=output_signature
    )
    .cache()
    .shuffle(1000)
    .batch(BATCH_SIZE)
    .map(
        lambda q_ids, q_mask, p_ids, p_mask, n_ids, n_mask: (
            {"input_ids": q_ids, "attention_mask": q_mask},
            {"input_ids": p_ids, "attention_mask": p_mask},
            {"input_ids": n_ids, "attention_mask": n_mask}
        )
    )
    .prefetch(tf.data.AUTOTUNE)
)

In [14]:
# NOTE(review): this redefines `dataset_generator` from the training cell. The
# training pipeline already holds a reference to the earlier function object, but
# a split-parameterised factory would avoid the shadowing.
def dataset_generator():
    """Yield validation examples from ``tokenized_dataset["validation"]`` as int32 arrays.

    Rows with fewer than NUM_NEGATIVES negatives are skipped, so the number of
    yielded examples can be smaller than the size of the split.

    Yields:
        Tuple ``(query_ids, query_mask, pos_ids, pos_mask, neg_ids, neg_mask)``;
        the negative arrays hold exactly NUM_NEGATIVES rows.
    """
    for item in tokenized_dataset["validation"]:
        if len(item["negatives_input_ids"]) < NUM_NEGATIVES:
            continue

        query_ids = np.array(item["query_input_ids"], dtype=np.int32)
        query_mask = np.array(item["query_attention_mask"], dtype=np.int32)

        pos_ids = np.array(item["positives_input_ids"], dtype=np.int32)
        pos_mask = np.array(item["positives_attention_mask"], dtype=np.int32)

        neg_ids = np.array(item["negatives_input_ids"][:NUM_NEGATIVES], dtype=np.int32)
        neg_mask = np.array(item["negatives_attention_mask"][:NUM_NEGATIVES], dtype=np.int32)

        yield (query_ids, query_mask, pos_ids, pos_mask, neg_ids, neg_mask)

# Shapes come from the shared constants so the signature cannot drift from the
# tokenizer's max_length or from the generator's negative count.
output_signature = (
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(MAX_TOKENS_LEN,), dtype=tf.int32),
    tf.TensorSpec(shape=(NUM_NEGATIVES, MAX_TOKENS_LEN), dtype=tf.int32),
    tf.TensorSpec(shape=(NUM_NEGATIVES, MAX_TOKENS_LEN), dtype=tf.int32)
)

# No shuffling for validation, so caching the batched and mapped elements is fine.
tf_valid_dataset = tf.data.Dataset.from_generator(
    dataset_generator,
    output_signature=output_signature
).batch(BATCH_SIZE).map(
    lambda q_ids, q_mask, p_ids, p_mask, n_ids, n_mask: (
        {"input_ids": q_ids, "attention_mask": q_mask},
        {"input_ids": p_ids, "attention_mask": p_mask},
        {"input_ids": n_ids, "attention_mask": n_mask}
    )
).cache().prefetch(tf.data.AUTOTUNE)

In [15]:
# Row counts of the tokenized splits.
# NOTE(review): the dataset generators skip rows with fewer than NUM_NEGATIVES
# negatives, so the step counts below can exceed the number of batches the
# tf.data pipelines actually produce.
train_size, valid_size = (len(tokenized_dataset[split]) for split in ("train", "validation"))

# Whole batches per pass; a trailing partial batch is not counted.
steps_per_epoch, validation_steps = train_size // BATCH_SIZE, valid_size // BATCH_SIZE

print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")

Steps per epoch: 2490
Validation steps: 303


In [16]:
@tf.keras.utils.register_keras_serializable(package="MyModels")
class BertWrapper(tf.keras.layers.Layer):
    """Keras layer wrapping a pretrained ``TFBertModel`` loaded by checkpoint name.

    Only the checkpoint name is stored in the layer config; the weights are
    loaded again through ``TFBertModel.from_pretrained`` on construction.
    """

    def __init__(self, model_name="bert-base-uncased", **kwargs):
        """Load the checkpoint ``model_name``; extra kwargs go to the base Layer."""
        super().__init__(**kwargs)
        self.model_name = model_name
        self.model = TFBertModel.from_pretrained(model_name)

    def call(self, inputs, training=False):
        """Forward ``inputs`` to the wrapped model and return its output unchanged."""
        outputs = self.model(inputs, training=training)
        return outputs

    def get_config(self):
        """Return the base layer config extended with ``model_name``."""
        return {**super().get_config(), "model_name": self.model_name}

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer by passing the config entries as constructor kwargs."""
        layer = cls(**config)
        return layer

In [17]:
@tf.keras.saving.register_keras_serializable(package="MyModels")
class BiEncoder(tf.keras.Model):
    """Bi-encoder producing L2-normalized text embeddings on top of a frozen BERT layer.

    The same weights embed queries, positive passages and negative passages.
    Training uses a softmax cross-entropy over [positive, negatives] similarities
    and reports the running loss and mean reciprocal rank (MRR).

    Args:
        bert_model: layer whose output exposes ``pooler_output``; it is marked
            non-trainable and always called with ``training=False``.
        projection_size: units of the intermediate GELU dense layer.
        embedding_size: dimension of the returned embeddings.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, bert_model, projection_size=256, embedding_size=128, **kwargs):
        super().__init__(**kwargs)
        self.bert = bert_model
        self.bert.trainable = False
        self.projection_size = projection_size
        self.embedding_size = embedding_size

        self.projection = tf.keras.layers.Dense(projection_size, activation='gelu', kernel_regularizer=tf.keras.regularizers.l2(1e-5))
        self.dropout = tf.keras.layers.Dropout(0.1)
        self.output_layer = tf.keras.layers.Dense(embedding_size)

        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.mrr_tracker = tf.keras.metrics.Mean(name="mrr")

    def call(self, inputs, training=False):
        """Embed tokenized text.

        Args:
            inputs: dict with "input_ids" and "attention_mask". Rank-3 inputs
                (e.g. several negatives per query) are flattened to rank 2 for
                the encoder and the leading two axes are restored afterwards.
            training: enables dropout in the trainable head; BERT itself is
                always run in inference mode.

        Returns:
            L2-normalized embeddings whose last axis has size ``embedding_size``.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        is_3d = len(input_ids.shape) == 3
        original_shape = tf.shape(input_ids)

        if is_3d:
            input_ids = tf.reshape(input_ids, [-1, original_shape[-1]])
            attention_mask = tf.reshape(attention_mask, [-1, original_shape[-1]])

        bert_outputs = self.bert({"input_ids": input_ids, "attention_mask": attention_mask}, training=False)

        x = self.projection(bert_outputs.pooler_output, training=training)
        x = self.dropout(x, training=training)

        embeddings = self.output_layer(x, training=training)
        embeddings = tf.nn.l2_normalize(embeddings, axis=1)

        if is_3d:
            embeddings = tf.reshape(embeddings, [original_shape[0], original_shape[1], self.embedding_size])

        return embeddings

    # Multiple Negatives Ranking Loss
    def compute_loss(self, query_emb, pos_emb, neg_emb):
        """Per-example cross-entropy over [positive, negatives] similarities.

        Args:
            query_emb: rank-2 query embeddings (batch, dim).
            pos_emb: rank-2 positive embeddings, aligned with ``query_emb``.
            neg_emb: rank-3 negative embeddings (batch, num_negatives, dim).

        Returns:
            ``(loss, similarities)``: ``loss`` has one value per example and
            ``similarities`` has the positive score in column 0 followed by the
            negative scores.
        """
        batch_size = tf.shape(query_emb)[0]

        pos_sim = tf.reduce_sum(query_emb * pos_emb, axis=1)
        neg_sim = tf.reduce_sum(query_emb[:, None, :] * neg_emb, axis=-1)

        similarities = tf.concat([pos_sim[:, None], neg_sim], axis=1)
        # The positive always sits in column 0, so the target class is 0 for every row.
        labels = tf.zeros(batch_size, dtype=tf.int32)
        loss = tf.keras.losses.sparse_categorical_crossentropy(labels, similarities, from_logits=True)

        return loss, similarities

    # MRR for evaluating the ranking quality
    def compute_mrr(self, similarities):
        """Mean reciprocal rank of the positive (column 0) within each row."""
        batch_size = tf.shape(similarities)[0]
        labels = tf.zeros(batch_size, dtype=tf.int32)

        # Candidate indices sorted by descending score; the position at which
        # index 0 appears is the zero-based rank of the positive.
        ranks = tf.argsort(similarities, axis=-1, direction='DESCENDING')
        positive_ranks = tf.where(tf.equal(ranks, labels[:, None]))[:, 1]
        reciprocal_ranks = 1.0 / (tf.cast(positive_ranks, tf.float32) + 1.0)

        return tf.reduce_mean(reciprocal_ranks)

    def train_step(self, data):
        """Run one optimisation step on a (query, positive, negatives) batch.

        The optimised objective is the mean ranking loss plus the layers'
        regularization penalties; the reported "loss" is the ranking loss only,
        so it stays comparable with the value reported by ``test_step``.
        """
        query_inputs, positive_inputs, negative_inputs = data

        with tf.GradientTape() as tape:
            query_embeddings = self(query_inputs, training=True)
            positive_embeddings = self(positive_inputs, training=True)
            negative_embeddings = self(negative_inputs, training=True)
            loss, similarities = self.compute_loss(query_embeddings, positive_embeddings, negative_embeddings)

            # A custom train_step must add layer losses itself; without this the
            # projection layer's kernel_regularizer would never reach the gradients.
            # Read self.losses inside the tape so the penalties are differentiated.
            objective = tf.reduce_mean(loss)
            regularization_losses = self.losses
            if regularization_losses:
                objective += sum(regularization_losses)

        gradients = tape.gradient(objective, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        mrr = self.compute_mrr(similarities)
        self.loss_tracker.update_state(loss)
        self.mrr_tracker.update_state(mrr)

        return {"loss": self.loss_tracker.result(), "mrr": self.mrr_tracker.result()}

    def test_step(self, data):
        """Evaluate one (query, positive, negatives) batch without updating weights."""
        query_inputs, positive_inputs, negative_inputs = data

        query_embeddings = self(query_inputs, training=False)
        positive_embeddings = self(positive_inputs, training=False)
        negative_embeddings = self(negative_inputs, training=False)
        loss, similarities = self.compute_loss(query_embeddings, positive_embeddings, negative_embeddings)

        mrr = self.compute_mrr(similarities)
        self.loss_tracker.update_state(loss)
        self.mrr_tracker.update_state(mrr)

        return {"loss": self.loss_tracker.result(), "mrr": self.mrr_tracker.result()}

    def get_text_embedding(self, inputs):
        """Embed ``inputs`` in inference mode (dropout disabled)."""
        return self(inputs, training=False)

    # Exposing the trackers here lets model.fit() report them automatically
    @property
    def metrics(self):
        return [self.loss_tracker, self.mrr_tracker]

    def get_config(self):
        """Return the base config plus the serialized BERT layer and head sizes."""
        config = super().get_config()
        config.update({
            "bert_model": tf.keras.saving.serialize_keras_object(self.bert),
            "projection_size": self.projection_size,
            "embedding_size": self.embedding_size
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from ``get_config`` output.

        Only "bert_model", "projection_size" and "embedding_size" are read;
        any other config entries are ignored.
        """
        bert_model = tf.keras.saving.deserialize_keras_object(config["bert_model"])
        return cls(
            bert_model=bert_model,
            projection_size=config["projection_size"],
            embedding_size=config["embedding_size"]
        )

In [18]:
@tf.keras.saving.register_keras_serializable(package="MySchedules")
class CustomLearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Exponential learning-rate decay with a lower bound.

    ``lr(step) = max(initial_lr * decay_rate ** (step / decay_steps), min_lr)``
    """

    def __init__(self, initial_lr=1e-3, decay_rate=0.9, decay_steps=2000, min_lr=1e-5):
        super().__init__()
        self.initial_lr = initial_lr
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        self.min_lr = min_lr

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        # The exponent is fractional, so the rate decays smoothly rather than in stairs.
        exponent = tf.cast(step, tf.float32) / self.decay_steps
        decayed_lr = self.initial_lr * (self.decay_rate ** exponent)
        return tf.maximum(decayed_lr, self.min_lr)

    def get_config(self):
        """Return the constructor arguments needed to recreate the schedule."""
        return dict(
            initial_lr=self.initial_lr,
            decay_rate=self.decay_rate,
            decay_steps=self.decay_steps,
            min_lr=self.min_lr,
        )

    @classmethod
    def from_config(cls, config):
        """Recreate the schedule from ``get_config`` output."""
        schedule = cls(**config)
        return schedule

In [19]:
# Frozen BERT backbone, the trainable bi-encoder head on top of it, and the
# learning-rate schedule (constructed with its default settings).
bert_wrapper = BertWrapper(model_name="bert-base-uncased")
bi_encoder = BiEncoder(bert_model=bert_wrapper)
lr_schedule = CustomLearningRateSchedule()

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [22]:
# Loss and MRR are tracked by BiEncoder's own metric trackers and returned from
# its custom train_step/test_step, so no compile-time metrics are needed.
# ("mrr" is not a built-in Keras metric identifier, and the custom steps never
# update compiled metrics, so the previous weighted_metrics=["mrr"] had no effect.)
bi_encoder.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule))

In [23]:
# Train for EPOCHS epochs, evaluating on the validation pipeline after each one.
history = bi_encoder.fit(
    tf_train_dataset,
    validation_data=tf_valid_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    verbose=1,
)



In [59]:
client = chromadb.Client()

In [60]:
collection = client.create_collection(name="passages")

In [61]:
def process_batch(passages_batch, query_id, start_idx):
    """Embed a batch of passages and package them as records for the vector store.

    Args:
        passages_batch: list of passage strings.
        query_id: identifier of the query the passages belong to.
        start_idx: offset of the first passage of this batch within the query's
            passage list; used to build the record ids.

    Returns:
        One dict per passage with keys "id" (``"{query_id}_{index}"``), "text"
        and "embedding" (a plain Python list).
    """
    encoded = tokenizer(
        passages_batch,
        padding="max_length",
        truncation=True,
        max_length=MAX_TOKENS_LEN,
        return_tensors="tf"
    )
    model_inputs = {key: encoded[key] for key in ("input_ids", "attention_mask")}
    embeddings_batch = bi_encoder(model_inputs, training=False)

    return [
        {
            "id": f"{query_id}_{start_idx + offset}",
            "text": passage_text,
            "embedding": np.squeeze(embeddings_batch[offset].numpy()).tolist()
        }
        for offset, passage_text in enumerate(passages_batch)
    ]

def process_and_store(dataset, batch_size=32):
    """Embed every passage of the test split and store it in the Chroma collection.

    Args:
        dataset: mapping with a "test" split whose rows carry "query_id" and
            "passages" -> "passage_text".
        batch_size: maximum number of passages embedded and inserted per call.
    """
    with tqdm(total=len(dataset["test"]), desc="Обробка запитів") as pbar:
        for item in dataset["test"]:
            passages = item["passages"]["passage_text"]
            query_id = item["query_id"]

            for start_idx in range(0, len(passages), batch_size):
                passages_batch = passages[start_idx:start_idx + batch_size]

                results_batch = process_batch(passages_batch, query_id, start_idx)

                # One insert per batch instead of one per passage: add() takes
                # parallel lists, which avoids a store round-trip for every record.
                collection.add(
                    ids=[result["id"] for result in results_batch],
                    documents=[result["text"] for result in results_batch],
                    embeddings=[result["embedding"] for result in results_batch]
                )

            pbar.update(1)

In [62]:
process_and_store(dataset)

Обробка запитів: 100%|██████████| 9650/9650 [24:54<00:00,  6.46it/s]


In [79]:
def select_top_5_queries_with_selected(dataset):
    """Return the first five test rows that have at least one selected passage.

    Args:
        dataset: mapping with a "test" split; each row has
            "passages" -> "is_selected" (a list of 0/1 labels).

    Returns:
        List of up to five rows, in dataset order.
    """
    picked = []
    for row in dataset["test"]:
        if 1 not in row["passages"]["is_selected"]:
            continue
        picked.append(row)
        if len(picked) >= 5:
            break
    return picked

def search(selected_queries, n_results=5):
    """Embed each query, look up its nearest passages in the collection, and print them.

    Args:
        selected_queries: iterable of rows with a "query" string.
        n_results: number of passages to retrieve per query.
    """
    for item in selected_queries:
        tokenized_query = tokenizer(
            item["query"],
            padding="max_length",
            truncation=True,
            max_length=MAX_TOKENS_LEN,
            return_tensors="tf"
        )

        query_embedding = bi_encoder({
            "input_ids": tokenized_query["input_ids"],
            "attention_mask": tokenized_query["attention_mask"]
        }, training=False).numpy()[0].tolist()

        search_results = collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results
        )
        print(f'query = "{item["query"]}"')
        # A single query embedding was sent, so its matches are the first inner
        # list. Indexing it directly also works when only one document comes back,
        # where np.squeeze would collapse the result to a non-iterable 0-d array.
        for document in search_results["documents"][0]:
            print("   " + str(document))
        print()

selected_queries = select_top_5_queries_with_selected(dataset)
search(selected_queries)

query = "does human hair stop squirrels"
   one acre equals 0 0015625 square miles 4840 square yards 43560 square feet or about 4047 square metres 0 405 hectares see below
   Karen Nicol is an embroidery and mixed media textile artist working in gallery, fashion and interiors with a London based design and production studio established for over twenty-five years. 
   Trap rock is a name used in the construction industry for any dark-colored igneous rock that is used to produce crushed stone.
   Pilocarpine is a drug used to treat dry mouth and glaucoma. It is a parasympathomimetic alkaloid obtained from the leaves of tropical South American shrubs from the genus Pilocarpus. Pilocarpine is used to stimulate sweat glands in a sweat test to measure the concentration of chloride and sodium that is excreted in sweat. It is used to diagnose cystic fibrosis.
   A castle (from Latin: castellum) is a type of fortified structure built in Europe and the Middle East during the Middle Ages by nobil