# RAGBooster

We detail how to improve the performance of LLMs for question answering with retrieval augmentation and learned data pruning. We showcase how these techniques allow a small open source LLM with 3 billion parameters to perform on par with a large commercial LLM which has 175 billion parameters.

Our library "RAGBooster" for learned data pruning is available as open source.

### Setup

In order to run this demo, we need access to GPT3.5, the Bing Web Search API and a deployed version of the RedPajama-INCITE-Instruct-3B-v1 model from Together. Note that we implement caching and can serve the vast majority of requests for this particular demo setup from our cache.

 1. **Access to GPT3.5 from OpenAI**: This demo notebook leverages GPT3.5 via the OpenAI API. It requires you to make your [OpenAI API key](https://platform.openai.com/account/api-keys) available as an environment variable via the following command:<br/><br/>`export OPENAI_API_KEY=your_secret_openai_key`<br/><br/>

 1. **Access to RedPajama-INCITE-Instruct-3B-v1**: This demo also uses the 3B-parameter language model [RedPajama-INCITE-Instruct-3B-v1](https://huggingface.co/togethercomputer/RedPajama-INCITE-Instruct-3B-v1). This model should be served via a REST API using the [manifest project](https://github.com/HazyResearch/manifest#local-huggingface-models) as follows:<br/> 
<br/> `python -m manifest.api.app \`<br/>
`   --model_type huggingface \`<br/>
`   --model_name_or_path togethercomputer/RedPajama-INCITE-Instruct-3B-v1 \`<br/>
`   --model_generation_type text-generation \`<br/>
`   --device 0`<br/><br/>
 
 1. **Access to the Bing Websearch API**: Furthermore, we will query the web via [Microsoft Bing's websearch API](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api). You need to make your Bing API key available as an environment variable via the following command:<br/><br/>`export BING_SUBSCRIPTION_KEY=your_secret_bing_key`<br/><br/>
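
Before running the notebook, it can help to check that both API keys from the steps above are actually exported. The following is a minimal sketch using only the standard library; `missing_env_vars` is a hypothetical helper, not part of RAGBooster, and it does not check whether the RedPajama server is reachable.

```python
import os

# The two keys exported in the setup steps above.
REQUIRED_ENV_VARS = ("OPENAI_API_KEY", "BING_SUBSCRIPTION_KEY")


def missing_env_vars(names=REQUIRED_ENV_VARS):
    """Return the names of the required variables that are unset or empty."""
    return [name for name in names if not os.environ.get(name)]


missing = missing_env_vars()
if missing:
    print("Please export: " + ", ".join(missing))
else:
    print("All API keys are set.")
```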
 


## Question Answering with Large Language Models

The scenario for this demo is question answering with Large Language Models (LLMs). We use a dataset of questions about the place of birth of various people from the Wikifact dataset in Stanford's [HELM benchmark](https://crfm.stanford.edu/helm/latest/). We use the first 500 questions as a validation set and the next 500 questions as the final test set.

In [1]:
from ragbooster import Generator, BingRetriever, RetrievalAugmentedModel, RAGBooster, score
from ragbooster.demo import load_wikifact_questions

questions = load_wikifact_questions('demo_data/wikifact_place_of_birth_helm.json')

validation_questions = questions[:500]
test_questions = questions[500:1000]

An example question is about the birthplace of the Slovak ice hockey player Martin Kulha:

In [2]:
example_question = questions[5]
example_question

Question(text='Martin Kulha was born in', correct_answers=['Poprad'], metadata={})
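
The printed representation suggests that a question is a simple record with three fields. A minimal stand-in, assuming a plain dataclass (the actual `ragbooster` class may be defined differently), reproduces the same representation:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Question:
    """Hypothetical stand-in mirroring the fields shown in the repr above."""
    text: str                    # prompt prefix the model should complete
    correct_answers: List[str]   # accepted gold answers
    metadata: dict = field(default_factory=dict)


example = Question(text="Martin Kulha was born in", correct_answers=["Poprad"])
print(example)
# Question(text='Martin Kulha was born in', correct_answers=['Poprad'], metadata={})
```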

### Question Answering with GPT3.5

Let's see how well OpenAI's `'text-davinci-003'` model from the [GPT3.5 family](https://platform.openai.com/docs/models/gpt-3-5) is doing on these questions. 

In [3]:
from manifest import Manifest 

gpt35_client = Manifest(client_name="openai", engine="text-davinci-003",
                        cache_name="sqlite", cache_connection="demo_data/gpt35-cache.sqlite")

We can leverage GPT3.5 by extending the `Generator` class. We write a couple of lines of Python to define how to create our prompt from the question and some few-shot examples, and how to parse the answer returned by GPT3.5.

In [4]:
import re 

class PlaceOfBirthGenerator(Generator):
    
    FEW_SHOT_PROMPT = "Brown was born in England\n\n"+\
        "Jerry Beck was born in New York City\n\n"+\
        "Werner Lorenz was born in Ludwigshafen\n\n"+\
        "Moritz Retzsch was born in Dresden\n\n"+\
        "Roni Rosadi was born in Bandar Lampung\n\n"      
    
    def __init__(self, llm):
        super().__init__(llm=llm, max_tokens=10)    
    
    def _create_prompt(self, question, params):        
        return f"{self.FEW_SHOT_PROMPT}\n\n{question.text}"            

    def _extract_answer(self, response):
        answer = response.get_response()          
        answer = re.sub(r'[0-9]+', '', answer)
        answer = answer.strip()   

        for sep in ['\n', ',', '.']:
            if sep in answer:
                answer = answer.split(sep)[0]

        return answer.strip()  
    
gpt35 = PlaceOfBirthGenerator(llm=gpt35_client)    
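
The answer parsing in `_extract_answer` can be tried out in isolation. The sketch below is a standalone copy of that post-processing (the name `extract_answer` is hypothetical): it removes all digits and keeps only the text before the first newline, comma or period.

```python
import re


def extract_answer(raw: str) -> str:
    """Standalone copy of the post-processing in PlaceOfBirthGenerator._extract_answer."""
    answer = re.sub(r"[0-9]+", "", raw)  # drop digits such as days and years
    answer = answer.strip()
    # Keep only the text before the first newline, comma or period.
    for sep in ["\n", ",", "."]:
        if sep in answer:
            answer = answer.split(sep)[0]
    return answer.strip()


print(extract_answer(" Prague, Czech Republic\n\nJohn Smith was born in"))  # Prague
print(extract_answer(" Poprad on August 07, 1976."))  # Poprad on August
```

Because digits are removed before splitting, a completion that continues with a date yields fragments such as "Poprad on August" or "August th", which the retrieval-augmented runs in this notebook also produce.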

Unfortunately, GPT3.5 gives us the wrong answer to the example question!

In [5]:
print(f'GPT3.5 answers "{example_question.text}" with "{gpt35.generate(example_question)}"')
print(f'The correct answer is: {example_question.correct_answers[0]}')

GPT3.5 answers "Martin Kulha was born in" with "Prague"
The correct answer is: Poprad


We can also evaluate GPT3.5 on all our 500 test questions and find that it only answers 14% of the questions correctly.

In [6]:
accuracy = score(test_questions, gpt35)

f'The accuracy of GPT3.5 on the test set is {accuracy}'

'The accuracy of GPT3.5 on the test set is 0.14'
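
Conceptually, the accuracy reported by `score` is the fraction of questions that the generator answers correctly. The sketch below assumes an exact string match against `correct_answers`; the matching rule used by `ragbooster.score` may differ (for example, it may normalize answers), and `accuracy` and `Q` are hypothetical names.

```python
from typing import Callable, List, NamedTuple, Sequence


class Q(NamedTuple):
    """Hypothetical minimal question record."""
    text: str
    correct_answers: List[str]


def accuracy(questions: Sequence[Q], generate: Callable[[Q], str]) -> float:
    """Fraction of questions whose generated answer exactly matches a gold answer."""
    if not questions:
        return 0.0
    hits = sum(generate(q) in q.correct_answers for q in questions)
    return hits / len(questions)


# Toy example: a "model" that always answers "Prague".
toy = [Q("Martin Kulha was born in", ["Poprad"]),
       Q("Franz Kafka was born in", ["Prague"])]
print(accuracy(toy, lambda q: "Prague"))  # 0.5
```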

## Question Answering with RedPajama-INCITE-Instruct-3B-v1

Let's see how the smaller model RedPajama-INCITE-Instruct-3B-v1 performs on this task. We connect to our local instance of this model as follows (please adjust `redpajama_port` if your server listens on a different port).

In [7]:
redpajama_port=5291
redpajama_client = Manifest(client_name = "huggingface", client_connection = f"http://127.0.0.1:{redpajama_port}",
                            cache_name='sqlite', cache_connection="demo_data/rp3b-cache.sqlite")

redpajama = PlaceOfBirthGenerator(llm=redpajama_client)

_Due to a [caching bug](https://github.com/HazyResearch/manifest/issues/103) in manifest, we need to apply the following hack to get performant caching. This code will become unnecessary once the bug is fixed._

In [8]:
from manifest.clients.huggingface import HuggingFaceClient
import types

client = redpajama_client.client_pool.get_current_client()
redpajama_model_params = client.get_model_params()

def cached_params(self):
    return redpajama_model_params

client.get_model_params = types.MethodType(cached_params, client)
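
The hack fetches the model parameters once and then replaces `get_model_params` on that single client instance with a bound function returning the stored value. The toy example below (with a hypothetical `ToyClient` class standing in for the manifest client) shows the same `types.MethodType` pattern in isolation.

```python
import types


class ToyClient:
    """Hypothetical client whose get_model_params stands in for an expensive call."""

    def __init__(self):
        self.calls = 0

    def get_model_params(self):
        self.calls += 1  # count how often the "expensive" original method runs
        return {"model_name": "demo"}


toy_client = ToyClient()
stored_params = toy_client.get_model_params()  # call the original once


def return_stored_params(self):
    return stored_params  # ignore self and return the stored result


# Bind the function to this one instance; other ToyClient objects are unaffected.
toy_client.get_model_params = types.MethodType(return_stored_params, toy_client)

for _ in range(3):
    toy_client.get_model_params()
print(toy_client.calls)  # 1
```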

RedPajama-INCITE-Instruct-3B-v1 also gives us the wrong answer to the example question!

In [9]:
print(f'RedPajama3B answers "{example_question.text}" with "{redpajama.generate(example_question)}"')
print(f'The correct answer is: {example_question.correct_answers[0]}')

RedPajama3B answers "Martin Kulha was born in" with "Prague"
The correct answer is: Poprad


Our task is also difficult for RedPajama-INCITE-Instruct-3B-v1: it answers only 10% of the test questions correctly.

In [10]:
accuracy = score(test_questions, redpajama)

f'The accuracy of RedPajama3B on the test set is {accuracy}'

'The accuracy of RedPajama3B on the test set is 0.1'

## Retrieval Augmentation with Bing Websearch

We can improve the performance of our LLMs by providing them with some external data to answer the questions, for example from the web. This is called retrieval augmentation, and we use Microsoft's Bing websearch API for that by extending the `BingRetriever` class and defining how to create a query from the question text. In our case, we can just use the question text as the query.

In [11]:
class MyBingWebsearch(BingRetriever):
    
    def __init__(self, cache_path):
        super().__init__(cache_path)
    
    def create_query(self, question):
        return question.text
    
bing_websearch = MyBingWebsearch('demo_data/bing-cache.pkl')    
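
The retriever is constructed with a cache file, which is what lets the demo serve most requests without calling Bing. A minimal sketch of such a query-level cache could look as follows, assuming results are `(snippet, url)` pairs (as they are unpacked in this notebook) and a pickled dictionary keyed by query. `CachedRetriever` and `fake_search` are hypothetical, this is not the actual `BingRetriever` implementation, and for simplicity it takes the question text rather than a question object.

```python
import os
import pickle
import tempfile
from typing import Callable, Dict, List, Tuple

# (snippet, url) pairs, matching how retrieved results are unpacked in this notebook.
Result = Tuple[str, str]


class CachedRetriever:
    """Hypothetical query-level cache; not the actual BingRetriever implementation."""

    def __init__(self, cache_path: str, search: Callable[[str], List[Result]]):
        self.cache_path = cache_path
        self.search = search  # e.g. a function that calls a web search API
        self.cache: Dict[str, List[Result]] = {}
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as f:
                self.cache = pickle.load(f)

    def create_query(self, question_text: str) -> str:
        # Same choice as MyBingWebsearch: the question text is the query.
        return question_text

    def retrieve(self, question_text: str) -> List[Result]:
        query = self.create_query(question_text)
        if query not in self.cache:
            # Cache miss: call the backend once and persist the updated cache.
            self.cache[query] = self.search(query)
            with open(self.cache_path, "wb") as f:
                pickle.dump(self.cache, f)
        return self.cache[query]


# Usage with a fake backend that records every call (the result is made up).
calls: List[str] = []


def fake_search(query: str) -> List[Result]:
    calls.append(query)
    return [("Martin Kulha was born in Poprad.", "https://example.org/kulha")]


cache_file = os.path.join(tempfile.mkdtemp(), "search-cache.pkl")
retriever = CachedRetriever(cache_file, fake_search)
retriever.retrieve("Martin Kulha was born in")
retriever.retrieve("Martin Kulha was born in")  # served from the in-memory cache
print(len(calls))  # 1
```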

Here is the information that we find via Bing for our example question about Martin Kulha. Note that the top results already contain the correct answer 'Poprad' in the text!

In [12]:
retrieved = bing_websearch.retrieve(example_question)
for snippet, url in retrieved[:3]:
    print(url, '-', snippet, '\n')


https://www.wikilogy.com/biography/martin-kulha/ - Martin Kulha is an Ice Hockey Player. He was born in Poprad on August 07, 1976. Want to more about Him? In this article, we covered Martin Kulha's net worth, wiki, bio, career, height, weight, pics, family, affairs, car, salary, age, facts, and other details in 2023. Continue reading to discover who is Martin Kulha. 

https://www.celebsagewiki.com/martin-kulha - Martin Kulha was born on 7 August, 1976 in Poprad, Slovakia. Discover Martin Kulha's Biography, Age, Height, Physical Stats, Dating/Affairs, Family and career updates. Learn How rich is She in this year and how She spends money? Also learn how She earned most of networth at the age of 44 years old? 

https://icehockey.fandom.com/wiki/Martin_Kulha - Martin Kulha (born August 7, 1976) is a Slovak professional ice hockey player who formerly played with Sangliers Arvernes de Clermont in the FFHG Division 1. He is now a member of the Lyon Club in the French Division 3. Kulha had pre

Next, we write a new `Generator` which uses a different prompt tailored for retrievals and the retrieved text from Bing to generate answers.

In [13]:
class PlaceOfBirthGeneratorWithContext(Generator):
    
    RETRIEVAL_PROMPT = "\nJerry Beck (born February 9, 1955, in New York City) is an American animation historian," +\
        " author, blogger, and video producer.Beck wrote or edited several books on classic" +\
        " American animation and classic characters.\nJerry Beck was born in New York\n\n" +\
        "Ettore Maria Fizzarotti (1916–1985) was an Italian film director and screenwriter." +\
        " Born in Naples, the son of the director Armando, he debuted as assistant director" +\
        " in the films of his father.\nEttore Maria Fizzarotti was born in Naples\n"       
    
    def __init__(self, llm):
        super().__init__(llm=llm, max_tokens=10)    
    
    def _create_prompt(self, question, params):        
        retrieved_text = params['retrieved_context']
        return f"{self.RETRIEVAL_PROMPT}\n\n{retrieved_text}\n\n{question.text}"               

    def _extract_answer(self, response):
        answer = response.get_response()          

        answer = re.sub(r'[0-9]+', '', answer)
        answer = answer.strip()   

        for sep in ['\n', ',', '.']:
            if sep in answer:
                answer = answer.split(sep)[0]

        return answer.strip()  
    
gpt35_ctx = PlaceOfBirthGeneratorWithContext(llm=gpt35_client)   
redpajama_ctx = PlaceOfBirthGeneratorWithContext(llm=redpajama_client)



If we provide GPT3.5 with the retrieved extra information from Bing, it generates the correct answer in the majority of cases:


In [14]:
for snippet, url in retrieved[:10]:
    answer = gpt35_ctx.generate(example_question, {'retrieved_context': snippet})
    print(f'GPT3.5 gives the answer "{answer}" based on {url}')    

GPT3.5 gives the answer "Poprad on August" based on https://www.wikilogy.com/biography/martin-kulha/
GPT3.5 gives the answer "Poprad" based on https://www.celebsagewiki.com/martin-kulha
GPT3.5 gives the answer "Poprad" based on https://icehockey.fandom.com/wiki/Martin_Kulha
GPT3.5 gives the answer "Poprad" based on https://biogossipy.com/martin-kulha/
GPT3.5 gives the answer "Poprad" based on https://popularbio.com/martin-kulha/
GPT3.5 gives the answer "Poprad" based on https://www.hockeydb.com/ihdb/stats/pdisplay.php?pid=57405
GPT3.5 gives the answer "Pohoří" based on https://www.myheritage.com/names/martin_kulha
GPT3.5 gives the answer "Slovakia" based on http://www.vipfaq.com/Martin%20Kulha.html
GPT3.5 gives the answer "Poprad" based on https://networthmask.com/martin-kulha/
GPT3.5 gives the answer "August th" based on https://en.wikipedia.org/wiki/Martin_Kulha


In order to leverage this finding, we implement a `RetrievalAugmentedModel`, which generates the final answer via a majority vote over the answers that GPT3.5 generates for each of the top-10 results from Bing.

This model gives us the correct answer:


In [15]:
rag = RetrievalAugmentedModel(bing_websearch, gpt35_ctx, k=10)

print(f'GPT3.5 with retrieval augmentation gives the correct answer "{rag.generate(example_question)}"')   

GPT3.5 with retrieval augmentation gives the correct answer "Poprad"
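
The voting step can be sketched with `collections.Counter`. This is a minimal illustration, assuming a plain unweighted vote over exact answer strings; how `RetrievalAugmentedModel` breaks ties or normalizes answers is not stated here, and `majority_vote` is a hypothetical name.

```python
from collections import Counter
from typing import List


def majority_vote(answers: List[str]) -> str:
    """Return the most frequent answer; ties go to the answer seen first."""
    # Counter.most_common orders equal counts by first occurrence (Python 3.7+).
    return Counter(answers).most_common(1)[0][0]


# The ten per-snippet answers GPT3.5 produced for the example question above.
answers = ["Poprad on August", "Poprad", "Poprad", "Poprad", "Poprad",
           "Poprad", "Pohoří", "Slovakia", "Poprad", "August th"]
print(majority_vote(answers))  # Poprad
```

With `k=10`, six of the ten answers agree on "Poprad", so the noisy answers derived from the other pages are outvoted.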


Retrieval augmentation is a powerful technique: even a single retrieved webpage (`k=1`) improves the accuracy of GPT3.5 by a factor of about 2.4 and the accuracy of RedPajama3B by a factor of about 4:

In [16]:
gpt35_rag1 = RetrievalAugmentedModel(bing_websearch, gpt35_ctx, k=1)
redpajama_rag1 = RetrievalAugmentedModel(bing_websearch, redpajama_ctx, k=1)

accuracy_gpt35_rag1 = score(test_questions, gpt35_rag1)
accuracy_redpajama_rag1 = score(test_questions, redpajama_rag1)

print(f'The accuracy of GPT3.5 with retrieval augmentation and k=1 on the test set is {accuracy_gpt35_rag1}\n'+\
f'The accuracy of RedPajama3B with retrieval augmentation and k=1 on the test set is {accuracy_redpajama_rag1}')

The accuracy of GPT3.5 with retrieval augmentation and k=1 on the test set is 0.336
The accuracy of RedPajama3B with retrieval augmentation and k=1 on the test set is 0.41
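
The improvement factors follow directly from the accuracies reported in this notebook (no retrieval vs. `k=1`):

```python
# Test-set accuracies reported above: without retrieval vs. with k=1 retrieval.
baseline = {"GPT3.5": 0.14, "RedPajama3B": 0.10}
rag_k1 = {"GPT3.5": 0.336, "RedPajama3B": 0.41}

factors = {model: rag_k1[model] / baseline[model] for model in baseline}
for model, factor in factors.items():
    print(f"{model}: {factor:.1f}x")
# GPT3.5: 2.4x
# RedPajama3B: 4.1x
```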


Using `k=10` improves the performance further and puts the small 3B model on par with the 175B-parameter model from OpenAI. Both models now answer about half of the test questions correctly!

In [17]:
gpt35_rag10 = RetrievalAugmentedModel(bing_websearch, gpt35_ctx, k=10)
redpajama_rag10 = RetrievalAugmentedModel(bing_websearch, redpajama_ctx, k=10)

accuracy_gpt35_rag10 = score(test_questions, gpt35_rag10)
accuracy_redpajama_rag10 = score(test_questions, redpajama_rag10)

print(f'The accuracy of GPT3.5 with retrieval augmentation and k=10 on the test set is {accuracy_gpt35_rag10}\n'+\
f'The accuracy of RedPajama3B with retrieval augmentation and k=10 on the test set is {accuracy_redpajama_rag10}')

The accuracy of GPT3.5 with retrieval augmentation and k=10 on the test set is 0.498
The accuracy of RedPajama3B with retrieval augmentation and k=10 on the test set is 0.496


## Improving the performance further with RAGBooster

We can further improve the performance of our retrieval-augmented models by learning the importance of each retrieval source (web domains in our case) and pruning the retrieval corpus accordingly. Check out our recent paper **Improving Retrieval-Augmented Large Language Models with Data-Centric Refinement** (TODO need arxiv version) for details on the algorithm behind this.

We can "boost" the performance of our models via the `RAGBooster` class and an additional set of validation questions as follows:

In [38]:
gpt35_rag_boosted = RAGBooster(gpt35_rag10, validation_questions, 
                               learning_rate=10, num_epochs=100, n_jobs=-1)

redpajama_rag_boosted = RAGBooster(redpajama_rag10, validation_questions, 
                                   learning_rate=10, num_epochs=100, n_jobs=-1)

We find that **RAGBooster improves the accuracy of both our LLMs by approximately 3 percentage points**, so that both answer about 53% of the test questions correctly:

In [19]:
accuracy_gpt35_rag10_boosted = score(test_questions, gpt35_rag_boosted)

f'RAGBooster boosted the accuracy of GPT3.5 with retrieval augmentation'+\
f' from {accuracy_gpt35_rag10} to {accuracy_gpt35_rag10_boosted}!'

'RAGBooster boosted the accuracy of GPT3.5 with retrieval augmentation from 0.498 to 0.532!'

In [20]:
accuracy_redpajama_rag10_boosted = score(test_questions, redpajama_rag_boosted)

f'RAGBooster boosted the accuracy of RedPajama3B with retrieval augmentation'+\
f' from {accuracy_redpajama_rag10} to {accuracy_redpajama_rag10_boosted}!'

'RAGBooster boosted the accuracy of RedPajama3B with retrieval augmentation from 0.496 to 0.528!'
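
Expressed as absolute gains, the accuracies reported in the two cells above correspond to the following improvements in percentage points:

```python
# Test-set accuracies reported above, before and after boosting (k=10).
before = {"GPT3.5": 0.498, "RedPajama3B": 0.496}
after = {"GPT3.5": 0.532, "RedPajama3B": 0.528}

gains_pp = {m: round(100 * (after[m] - before[m]), 1) for m in before}
print(gains_pp)  # {'GPT3.5': 3.4, 'RedPajama3B': 3.2}
```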

### Important retrieval sources

Internally, RAGBooster learns an importance weight for each data source (web domain in our case) as well as a pruning threshold. We can inspect these importances via the `weights` attribute of RAGBooster:

In [39]:
domains_and_weights = redpajama_rag_boosted.weights
domains_and_weights_sorted = sorted(domains_and_weights.items(), key=lambda x:x[1], reverse=True)

domains_and_weights_sorted[:25]

[('allengelhard.com', 0.7000000000000002),
 ('ancient-origins.net', 0.7000000000000002),
 ('badmintonbites.com', 0.7000000000000002),
 ('britishmuseum.org', 0.7000000000000002),
 ('elanka.com.au', 0.7000000000000002),
 ('jewage.org', 0.7000000000000002),
 ('jukebugs.com', 0.7000000000000002),
 ('masterbond.com', 0.7000000000000002),
 ('mormonwiki.com', 0.7000000000000002),
 ('nwasianweekly.com', 0.7000000000000002),
 ('playmakerstats.com', 0.7000000000000002),
 ('thedailygardener.org', 0.7000000000000002),
 ('thestar.co.uk', 0.7000000000000002),
 ('mathrubhumi.com', 0.6997466764155842),
 ('tribuneindia.com', 0.6997466764155842),
 ('ceeol.com', 0.6997357590579518),
 ('cartoonia.ru', 0.6997044999828725),
 ('ww2db.com', 0.6997010228277173),
 ('apumone.com', 0.6996936785098945),
 ('aussiecelebs.com.au', 0.6996936785098945),
 ('namecensus.com', 0.6996899787712634),
 ('getsol.app', 0.6996804444452616),
 ('ed.ac.uk', 0.6996703510006601),
 ('lindahall.org', 0.6996703510006601),
 ('raynatours.c

We find that RAGBooster identifies some very interesting data sources, for example:

 * [ancient-origins.net](https://www.ancient-origins.net), a website dedicated to archaeology and ancient history
 * [mormonwiki.com](https://www.mormonwiki.com), an encyclopedia about Mormons
 * [britishmuseum.org](https://www.britishmuseum.org), the website of the British Museum
 * [badmintonbites.com](https://badmintonbites.com), a website about important badminton players

We can also compute how strongly RAGBooster prunes the retrieval corpus. It turns out that it keeps only about 20% of the domains it encountered on the validation questions.

In [50]:
num_data_sources = len(redpajama_rag_boosted.weights)
threshold = redpajama_rag_boosted.tuning_result.best_threshold

after_pruning = [domain for domain, weight in redpajama_rag_boosted.weights.items() if weight >= threshold]

num_data_sources_after_pruning = len(after_pruning)

f'Pruning retrieval corpus from {num_data_sources} to {num_data_sources_after_pruning} sources, '+\
f'based on learned weight threshold of {threshold:.4f}'

'Pruning retrieval corpus from 2374 to 475 sources, based on learned weight threshold of 0.5476'
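
To illustrate what pruning could mean at query time, the sketch below drops retrieved `(snippet, url)` pairs whose domain weight falls below the threshold, using the same `weight >= threshold` rule as the cell above; the remaining snippets would then go into the majority vote. This is a hypothetical illustration, not RAGBooster's implementation: the domain normalization (lowercased host without a leading `www.`), the `default_weight` for unseen domains, the helper names, and the toy weights are all assumptions; only the threshold value is taken from the output above.

```python
from typing import Dict, List, Tuple
from urllib.parse import urlparse


def domain_of(url: str) -> str:
    """Map a URL to a bare host name, e.g. 'wikilogy.com' (drops a leading 'www.')."""
    host = urlparse(url).netloc.lower()
    return host[len("www."):] if host.startswith("www.") else host


def prune(results: List[Tuple[str, str]], weights: Dict[str, float],
          threshold: float, default_weight: float = 0.0) -> List[Tuple[str, str]]:
    """Keep (snippet, url) pairs whose domain weight is at least the threshold.

    default_weight is an assumption for domains never seen during tuning.
    """
    return [(snippet, url) for snippet, url in results
            if weights.get(domain_of(url), default_weight) >= threshold]


# Made-up weights for illustration; the threshold is the one reported above.
toy_weights = {"wikilogy.com": 0.70, "icehockey.fandom.com": 0.65, "myheritage.com": 0.20}
toy_results = [
    ("... born in Poprad ...", "https://www.wikilogy.com/biography/martin-kulha/"),
    ("... Kulha ...", "https://www.myheritage.com/names/martin_kulha"),
    ("... Slovak ice hockey player ...", "https://icehockey.fandom.com/wiki/Martin_Kulha"),
]
kept = prune(toy_results, toy_weights, threshold=0.5476)
print([domain_of(url) for _, url in kept])  # ['wikilogy.com', 'icehockey.fandom.com']
```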

### Disclaimer

TODO: explain not always the case that small LLM performs so well, check paper for detailed evaluation; retrieval augmentation costly for large k