In [2]:
!pip install -qqq datasets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
!pip install -qqq transformers==4.45.2 sentence-transformers==3.1.1

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m58.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m245.3/245.3 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
import datasets
import pandas as pd

In [None]:
# Base checkpoint to fine-tune; SentenceTransformer resolves it from this model id.
model_id = "BAAI/bge-small-en"
model = SentenceTransformer(model_name_or_path=model_id)

In [None]:
# Fetch the ObliQA question/passage dataset; it is indexed by split name ('train', 'validation') below.
obliqa_dataset = datasets.load_dataset(path="DrishtiSharma/obliqa")

In [None]:
# Constants
BATCH_SIZE = 10
EPOCHS = 2
WARMUP_RATIO = 0.1  # fraction of all training steps spent in LR warmup
EVAL_STEPS = 50  # run the validation IR evaluator every N training steps
OUTPUT_PATH = "fine_tuned_obliqa_model"


def build_ir_split(split):
    """Flatten one ObliQA split into retrieval dictionaries.

    Each row is read for ``QuestionID``, ``Question`` and ``Passages``; each
    entry of ``Passages`` is read for ``PassageID`` and ``Passage``.

    Returns:
        A tuple ``(queries, corpus, relevant_docs)``:
        - queries: question id -> question text (a repeated id keeps the last text);
        - corpus: passage id -> passage text, for every passage in the split;
        - relevant_docs: question id -> list of that question's passage ids in
          first-seen order, each id listed once. Questions with no passages get
          no entry.
    """
    queries, corpus, relevant_docs = {}, {}, {}
    for row in split:
        query_id = row['QuestionID']
        queries[query_id] = row['Question']
        for passage in row['Passages']:
            passage_id = passage['PassageID']
            corpus[passage_id] = passage['Passage']
            doc_ids = relevant_docs.setdefault(query_id, [])
            # Keep each relevant passage id once per question.
            if passage_id not in doc_ids:
                doc_ids.append(passage_id)
    return queries, corpus, relevant_docs


# Prepare the training data: one (question, first relevant passage) pair per question.
train_dataset = obliqa_dataset['train']
queries, corpus, relevant_docs = build_ir_split(train_dataset)
examples = [
    InputExample(texts=[query, corpus[relevant_docs[query_id][0]]])
    for query_id, query in queries.items()
    if relevant_docs.get(query_id)
]
assert examples, "no (question, passage) training pairs were built from the train split"

# Set up training DataLoader and loss.
# MultipleNegativesRankingLoss uses the other passages in a batch as negatives.
# NOTE(review): if two questions in one batch share a positive passage, that passage
# becomes a false negative for one of them — consider de-duplicating passages per batch.
train_loader = DataLoader(examples, batch_size=BATCH_SIZE, shuffle=True)
loss = losses.MultipleNegativesRankingLoss(model)

# Prepare validation data with the same flattening as the train split.
val_dataset = obliqa_dataset['validation']
val_queries, val_corpus, val_relevant_docs = build_ir_split(val_dataset)

# Define evaluator.
# NOTE(review): the retrieval corpus here is only the validation split's passages,
# which is smaller than a full deployment corpus, so scores may be optimistic.
evaluator = InformationRetrievalEvaluator(
    queries=val_queries,
    corpus=val_corpus,
    # relevant_docs is documented as {query_id: set(doc_ids)}; pass sets rather than lists.
    relevant_docs={qid: set(doc_ids) for qid, doc_ids in val_relevant_docs.items()},
    show_progress_bar=True,
)

# Bypass `num_items_in_batch` issue (presumably newer transformers releases pass this
# keyword to compute_loss and older implementations reject it).
# NOTE(review): transformers is pinned to 4.45.2 above, and the sentence-transformers
# trainer appears to define its own compute_loss, so this patch may be a no-op — confirm.
import transformers
from transformers import Trainer

# Patch only once so re-running this cell does not stack wrapper on wrapper.
if not getattr(Trainer.compute_loss, "_ignores_extra_kwargs", False):
    original_compute_loss = Trainer.compute_loss

    def patched_compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Call the original ``Trainer.compute_loss``, discarding extra keyword arguments."""
        # Any extra kwargs (e.g. num_items_in_batch) are deliberately not forwarded.
        return original_compute_loss(self, model, inputs, return_outputs=return_outputs)

    patched_compute_loss._ignores_extra_kwargs = True
    Trainer.compute_loss = patched_compute_loss

# Train the model; warmup spans the first WARMUP_RATIO of all training steps.
warmup_steps = int(len(train_loader) * EPOCHS * WARMUP_RATIO)

model.fit(
    train_objectives=[(train_loader, loss)],
    evaluator=evaluator,
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path=OUTPUT_PATH,
    evaluation_steps=EVAL_STEPS,
    show_progress_bar=True,
)

print("Training complete.")




Step,Training Loss,Validation Loss,Cosine Accuracy@1,Cosine Accuracy@3,Cosine Accuracy@5,Cosine Accuracy@10,Cosine Precision@1,Cosine Precision@3,Cosine Precision@5,Cosine Precision@10,Cosine Recall@1,Cosine Recall@3,Cosine Recall@5,Cosine Recall@10,Cosine Ndcg@10,Cosine Mrr@10,Cosine Map@100,Dot Accuracy@1,Dot Accuracy@3,Dot Accuracy@5,Dot Accuracy@10,Dot Precision@1,Dot Precision@3,Dot Precision@5,Dot Precision@10,Dot Recall@1,Dot Recall@3,Dot Recall@5,Dot Recall@10,Dot Ndcg@10,Dot Mrr@10,Dot Map@100
50,No log,No log,0.566714,0.686155,0.719154,0.752869,0.566714,0.238044,0.152224,0.081528,0.502959,0.61134,0.64395,0.679663,0.613534,0.6329,0.578715,0.566714,0.686155,0.719154,0.752869,0.566714,0.238044,0.152224,0.081528,0.502959,0.61134,0.64395,0.679663,0.613534,0.6329,0.578715
100,No log,No log,0.590029,0.704089,0.732425,0.76901,0.590029,0.245337,0.155882,0.084039,0.522089,0.628372,0.65813,0.697669,0.632607,0.653098,0.597716,0.590029,0.704089,0.732425,0.76901,0.590029,0.245337,0.155882,0.084039,0.522089,0.628372,0.65813,0.697669,0.632607,0.653098,0.597716
150,No log,No log,0.598278,0.711263,0.74462,0.781205,0.598278,0.248446,0.158752,0.085402,0.529095,0.636579,0.670762,0.708064,0.64168,0.661827,0.606189,0.598278,0.711263,0.74462,0.781205,0.598278,0.248446,0.158752,0.085402,0.529095,0.636579,0.670762,0.708064,0.64168,0.661827,0.606189
200,No log,No log,0.596485,0.715208,0.75,0.787661,0.596485,0.25,0.160187,0.086406,0.525747,0.640686,0.675526,0.714311,0.643713,0.662877,0.606285,0.596485,0.715208,0.75,0.787661,0.596485,0.25,0.160187,0.086406,0.525747,0.640686,0.675526,0.714311,0.643713,0.662877,0.606285
250,No log,No log,0.598637,0.720588,0.753228,0.789455,0.598637,0.251913,0.160689,0.086549,0.527302,0.645068,0.678043,0.716081,0.646405,0.666348,0.609073,0.598637,0.720588,0.753228,0.789455,0.598637,0.251913,0.160689,0.086549,0.527302,0.645068,0.678043,0.716081,0.646405,0.666348,0.609073
300,No log,No log,0.601865,0.72274,0.75538,0.789096,0.601865,0.252152,0.161406,0.086263,0.530697,0.645475,0.679747,0.714497,0.647017,0.66775,0.610531,0.601865,0.72274,0.75538,0.789096,0.601865,0.252152,0.161406,0.086263,0.530697,0.645475,0.679747,0.714497,0.647017,0.66775,0.610531
350,No log,No log,0.596485,0.724892,0.754663,0.791966,0.596485,0.253587,0.161191,0.086872,0.525359,0.648446,0.679561,0.718574,0.645934,0.664772,0.607648,0.596485,0.724892,0.754663,0.791966,0.596485,0.253587,0.161191,0.086872,0.525359,0.648446,0.679561,0.718574,0.645934,0.664772,0.607648
400,No log,No log,0.602224,0.727044,0.758608,0.794476,0.602224,0.253706,0.162339,0.087052,0.530811,0.648517,0.683148,0.720086,0.649532,0.669552,0.611787,0.602224,0.727044,0.758608,0.794476,0.602224,0.253706,0.162339,0.087052,0.530811,0.648517,0.683148,0.720086,0.649532,0.669552,0.611787
450,No log,No log,0.606528,0.729555,0.765423,0.795552,0.606528,0.254543,0.163558,0.087733,0.534403,0.651483,0.688289,0.723326,0.654045,0.674296,0.616579,0.606528,0.729555,0.765423,0.795552,0.606528,0.254543,0.163558,0.087733,0.534403,0.651483,0.688289,0.723326,0.654045,0.674296,0.616579
500,0.831300,No log,0.609756,0.731707,0.764347,0.797346,0.609756,0.255141,0.163343,0.087805,0.537691,0.653156,0.686394,0.722967,0.655522,0.676487,0.618483,0.609756,0.731707,0.764347,0.797346,0.609756,0.255141,0.163343,0.087805,0.537691,0.653156,0.686394,0.722967,0.655522,0.676487,0.618483


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.65s/it]


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.05s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.73s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.72s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:05<00:00,  5.03s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.38s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.71s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.06s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.80s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.69s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.71s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.04s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.79s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.70s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.71s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.70s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:04<00:00,  4.08s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.77s/it]


Batches:   0%|          | 0/88 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:03<00:00,  3.69s/it]
