## Retrieval model training

In this notebook, we'll train a custom retrieval model for our RAG framework using the pretrained ColBERT model. 

The pretrained ColBERT model generalizes well to most use cases. However, when working with proprietary data it can be worth training a custom retrieval model, and this notebook shows how to do so.

The data used here comes from Wikipedia, but the same process can be replicated for proprietary datasets.

In [1]:
# importing relevant packages
from rich import print
import random
import requests
from ragatouille import RAGTrainer
from ragatouille.data import CorpusProcessor, llama_index_sentence_splitter

In [2]:
# creating the trainer
trainer = RAGTrainer(
    model_name="EPLColBERT",
    pretrained_model_name="colbert-ir/colbertv2.0",
    language_code="en",
)

**Note**: Proprietary datasets work just as well. For safety reasons, I will use publicly available data from Wikipedia instead.

To use a proprietary dataset, all you need to do is load it and preprocess it into a list of plain-text strings, one per document, which is the shape the Wikipedia corpus takes in this notebook.
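
As a rough illustration of that loading step, the sketch below reads a folder of text files into a list of strings. The helper name `load_text_corpus`, the folder layout, and the `*.txt` pattern are assumptions made for this example; they are not part of RAGatouille.

```python
from pathlib import Path
from typing import List


def load_text_corpus(directory: str, pattern: str = "*.txt") -> List[str]:
    """Read every matching file in `directory` into a list of raw strings.

    Hypothetical helper: adapt the pattern and decoding to your own data.
    """
    corpus = []
    for path in sorted(Path(directory).glob(pattern)):
        text = path.read_text(encoding="utf-8").strip()
        if text:  # skip files that are empty or whitespace-only
            corpus.append(text)
    return corpus
```

The resulting list could then be used wherever `epl_corpus` is used later in the notebook.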

In [3]:
# getting data from wikipedia using the API
def get_wikipedia_page(title: str):
    """
    Retrieve the full text content of a Wikipedia page.
    
    :param title: str - Title of the Wikipedia page.
    :return: str - Full text content of the page as raw string.
    """
    # Wikipedia API endpoint
    URL = "https://en.wikipedia.org/w/api.php"

    # Parameters for the API request
    params = {
        "action": "query",
        "format": "json",
        "titles": title,
        "prop": "extracts",
        "explaintext": True,
    }

    # Custom User-Agent header to comply with Wikipedia's best practices
    headers = {
        "User-Agent": "RAGatouille_tutorial/0.0.1 (ben@clavie.eu)"
    }

    response = requests.get(URL, params=params, headers=headers)
    data = response.json()

    # Extracting page content
    page = next(iter(data['query']['pages'].values()))
    return page['extract'] if 'extract' in page else None

In [4]:
# getting the data
epl_titles = [
    'Manchester United F.C.',
    'Manchester City F.C.',
    'Arsenal F.C.',
    'Chelsea F.C.',
    'Tottenham Hotspur F.C.',
    'Liverpool F.C.',
    'Premier League',
]
epl_corpus = [get_wikipedia_page(title) for title in epl_titles]

In [5]:
# when building RAG pipelines, it's advisable to split the data into smaller chunks
corpus_processor = CorpusProcessor(
    document_splitter_fn=llama_index_sentence_splitter)
documents = corpus_processor.process_corpus(epl_corpus, chunk_size=256)
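
To make the idea of chunking concrete, here is a minimal sketch of a sentence-based splitter. It is not what `llama_index_sentence_splitter` does internally. It is a simplified stand-in that counts whitespace-separated words rather than tokens, and the name `naive_sentence_chunks` is made up for this example.

```python
import re
from typing import List


def naive_sentence_chunks(text: str, max_words: int = 256) -> List[str]:
    """Greedily pack whole sentences into chunks of at most `max_words` words.

    A single sentence longer than `max_words` is kept as its own oversized chunk.
    """
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current, count = [], [], 0
    for sentence in sentences:
        if not sentence:
            continue
        n_words = len(sentence.split())
        # Start a new chunk when adding this sentence would exceed the budget
        if current and count + n_words > max_words:
            chunks.append(" ".join(current))
            current, count = [], 0
        current.append(sentence)
        count += n_words
    if current:
        chunks.append(" ".join(current))
    return chunks
```

Keeping sentences intact means each chunk stays readable on its own, which matters because the retriever scores chunks independently.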

In [6]:
# generating random query-document pairs (for this demo it doesn't matter whether the pairs are actually relevant)
queries = [
    "When was Manchester United formed?",
    "How many english league titles have Liverpool won?",
    "Which club is the oldest in England?",
    "Which club has the record for the longest unbeaten run in the premier league?",
    "Who is the most successful manager in the premier league?"
] * 3

pairs = []

for query in queries:
    fake_docs = random.sample(documents, 10)
    for doc in fake_docs:
        pairs.append((query, doc))

Next, we prepare the training pairs with hard negative mining, which searches the full set of documents for passages that are semantically similar to the query but aren't actually relevant.
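
The toy sketch below illustrates the idea behind hard negative mining. It is an assumption-laden simplification: it ranks passages by bag-of-words cosine similarity, whereas the library's miner uses a dense embedding model and its own selection logic. The function names are hypothetical.

```python
import math
import re
from collections import Counter
from typing import List


def _bow(text: str) -> Counter:
    """Lowercased bag-of-words counts (a crude stand-in for a dense embedding)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def _cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def mine_hard_negatives(query: str, positives: List[str], corpus: List[str], k: int = 2) -> List[str]:
    """Return the k passages most similar to the query that are not labelled positive."""
    query_vec = _bow(query)
    candidates = [doc for doc in corpus if doc not in positives]
    ranked = sorted(candidates, key=lambda doc: _cosine(query_vec, _bow(doc)), reverse=True)
    return ranked[:k]
```

Negatives chosen this way look like plausible answers, so the model has to learn finer distinctions than it would from randomly sampled passages.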

In [7]:
trainer.prepare_training_data(
    raw_data=pairs,
    data_out_path="../data/",
    all_documents=epl_corpus,
    num_new_negatives=10,
    mine_hard_negatives=True,
)

Loading Hard Negative SimpleMiner dense embedding model BAAI/bge-small-en-v1.5...




Building hard negative index for 136 documents...
All documents embedded, now adding to index...
save_index set to False, skipping saving hard negative index
Hard negative index generated


'../data/'

In [8]:
# training the retrieval model
trainer.train(batch_size=32,
              nbits=4, # How many bits will the trained model use when compressing indexes
              maxsteps=500000, # Maximum steps hard stop
              use_ib_negatives=True, # Use in-batch negative to calculate loss
              dim=128, # How many dimensions per embedding. 128 is the default and works well.
              learning_rate=5e-6, # Learning rate; small values in [3e-6, 3e-5] work best if the base model is BERT-like, and 5e-6 is often the sweet spot
              doc_maxlen=128, # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well.
              use_relu=False, # Disable ReLU -- doesn't improve performance
              warmup_steps="auto", # Defaults to 10%
             )

#> Starting...
nranks = 1 	 num_gpus = 1 	 device=0
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "index_bsize": 64,
    "nbits": 4,
    "kmeans_niters": 20,
    "resume": false,
    "similarity": "cosine",
    "bsize": 32,
    "accumsteps": 1,
    "lr": 5e-6,
    "maxsteps": 500000,
    "save_every": 0,
    "warmup": 0,
    "warmup_bert": null,
    "relu": false,
    "nway": 2,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": "EPLColBERT",
    "query_maxlen": 32,
    "attend_to_mask_tokens": false,
    "interaction": "colbert",
    "dim": 128,
    "doc_maxlen": 128,
    "mask_punctuation": true,
    "checkpoint": "colbert-ir\/colbertv2.0",
    "triples": "..\/data\/triples.train.colbert.jsonl",




[May 13, 01:27:38] #> Loading the queries from ../data/queries.train.colbert.tsv ...
[May 13, 01:27:38] #> Got 5 queries. All QIDs are unique.

[May 13, 01:27:38] #> Loading collection...
0M 




#> LR will use 0 warmup steps and linear decay over 500000 steps.

#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: . Who is the most successful manager in the premier league?, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([ 101,    1, 2040, 2003, 1996, 2087, 3144, 3208, 1999, 1996, 4239, 2223,
        1029,  102,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,
         103,  103,  103,  103,  103,  103,  103,  103], device='cuda:0')
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

				 6.6913957595825195 11.028946876525879
#>>>    10.19 16.7 		|		 -6.51




[May 13, 01:28:06] 0 17.7203426361084
				 3.913029432296753 7.930639266967773
#>>>    13.98 17.03 		|		 -3.0500000000000007
[May 13, 01:28:33] 1 17.714465962409975
				 6.548213005065918 11.229314804077148
#>>>    9.65 16.06 		|		 -6.409999999999998
[May 13, 01:28:59] 2 17.714529023303033
				 7.281060218811035 12.035734176635742
#>>>    10.4 17.65 		|		 -7.249999999999998
[May 13, 01:29:27] 3 17.71613128962885
[May 13, 01:29:27] #> Done with all triples!
#> Saving a checkpoint to .ragatouille/colbert/none/2024-05/13/01.27.16/checkpoints/colbert ..
#> Joined...
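
The log above reports "0 warmup steps and linear decay over 500000 steps". As a rough sketch of what such a schedule looks like, the function below computes a standard linear warmup/decay multiplier. It is assumed, not confirmed, to match the scheduler the trainer uses, and the name `linear_lr_multiplier` is made up for this example.

```python
def linear_lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Factor applied to the base learning rate at a given step.

    Rises linearly from 0 to 1 during warmup, then falls linearly to 0 at total_steps.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 5e-6
# The run above logged steps 0 to 3 before reporting "Done with all triples!"
for step in range(4):
    lr = base_lr * linear_lr_multiplier(step, warmup_steps=0, total_steps=500_000)
    print(step, lr)
```

Under this assumption, a run that ends after a handful of steps barely moves along the decay line, so the learning rate stays essentially at its base value of 5e-6 throughout.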
