In [1]:
import numpy as np
from sklearn.model_selection import train_test_split

from ragbooster import BingRetriever, HuggingfaceQAGenerator, Generator, RetrievalAugmentedModel, RAGBooster, score
from ragbooster.demo import load_imputation_dataset

In [2]:
# Seed NumPy's global RNG. train_test_split below is called without a
# random_state, so it draws from this global generator; the seed is what makes
# the validation/test split repeatable on a fresh kernel.
np.random.seed(42)

# Build one question per restaurant record: the 'city' column is the value to
# impute, and the listed columns make up the question text (see the example
# displayed below).
questions = load_imputation_dataset('demo_data/restaurant.csv', 
                                    impute='city', 
                                    based_on=['name', 'address', 'phone'])

# 50/50 split into a validation half and a held-out test half.
# NOTE(review): later cells hard-code examples taken from validation_questions[:5],
# so changing the seed or the split also requires updating those cells.
validation_questions, test_questions = train_test_split(questions, test_size=0.5)
# Display the first validation question as a sanity check.
validation_questions[0]

Question(text='name: border grill; address: 4th st.; phone: 310/451-1655', correct_answers=['los angeles'], metadata={})

In [3]:
class QAGenerator(HuggingfaceQAGenerator):
    """Extractive question-answering generator that sees only the record text.

    Every restaurant record is posed to the underlying QA model with one fixed
    question; the record's own text is the context the answer is extracted from.
    """

    def __init__(self, model_name, cache_path):
        """Pass the model name and cache path straight through to the base class."""
        super().__init__(model_name, cache_path)

    def _create_prompt(self, question, params):
        """Build the question/context pair for one record.

        `params` is accepted for interface compatibility and is not used here.
        """
        city_question = "What is the name of the city in which this restaurant is located?"
        record_text = question.text
        return {'question': city_question, 'context': record_text}

    def _extract_answer(self, response):
        """Return the model's 'answer' field, lower-cased."""
        predicted_city = response['answer']
        return predicted_city.lower()

In [4]:
# Baseline without retrieval: the QA model only sees the record text itself.
# The second argument is the path handed to the generator as its cache file.
minilm = QAGenerator('deepset/minilm-uncased-squad2', 'demo_data/qa-cache.pkl')

# Accuracy of the baseline on the held-out test questions (displayed as cell output).
score(test_questions, minilm)

  0%|          | 0/432 [00:00<?, ?it/s]

0.05555555555555555

In [5]:
# Render the first five validation questions as "<record>; city? <answer>" lines,
# one per line, to serve as few-shot demonstrations.
few_shot = ''.join(
    f"{validation_question.text}; city? {validation_question.correct_answers[0]}\n"
    for validation_question in validation_questions[:5]
)

# print() rather than a bare expression so the newlines render as line breaks.
print(few_shot)

name: border grill; address: 4th st.; phone: 310/451-1655; city? los angeles
name: le soleil; address: 133 clement st.; phone: 415/668-4848; city? san francisco
name: cypress club; address: 500 jackson st.; phone: 415/296-8555; city? san francisco
name: west; address: 63rd street steakhouse 44 w. 63rd st.; phone: 212/246-6363; city? new york
name: schatzi on main; address: 3110 main st.; phone: 310/399-4800; city? los angeles



In [6]:
import re 

class FewShotGenerator(Generator):
    """Few-shot prompted generator that asks an LLM for a restaurant's city.

    The prompt is a fixed block of five labelled demonstrations followed by the
    record to complete. The raw completion is then reduced to a bare,
    lower-cased city name.
    """

    # Five labelled demonstrations, separated by blank lines. They are the first
    # five validation questions (validation_questions[:5]) printed earlier in the
    # notebook; the text is kept byte-for-byte because it is sent to the LLM.
    FEW_SHOT_PROMPT = (
        "name: border grill; address: 4th st.; phone: 310/451-1655; city? los angeles\n\n"
        "name: le soleil; address: 133 clement st.; phone: 415/668-4848; city? san francisco\n\n"
        "name: cypress club; address: 500 jackson st.; phone: 415/296-8555; city? san francisco\n\n"
        "name: west; address: 63rd street steakhouse 44 w. 63rd st.; phone: 212/246-6363; city? new york\n\n"
        "name: schatzi on main; address: 3110 main st.; phone: 310/399-4800; city? los angeles\n\n"
    )

    def __init__(self, llm):
        """Wrap `llm`, requesting at most 10 tokens per completion."""
        super().__init__(llm=llm, max_tokens=10)

    def _create_prompt(self, question, params):
        """Return the demonstrations followed by the record and a trailing 'city?' cue.

        `params` is accepted for interface compatibility and is not used here.
        """
        return f"{self.FEW_SHOT_PROMPT}{question.text}; city?"

    def _extract_answer(self, response):
        """Reduce the raw LLM completion to a bare, lower-cased city name.

        Steps: drop all digits, trim whitespace, keep only the text before the
        first newline, comma or period, then trim and lower-case the result.
        """
        answer = response.get_response()

        # Remove digits (e.g. zip codes or phone fragments) before trimming.
        answer = re.sub(r'[0-9]+', '', answer).strip()

        # Keep the text before the earliest separator. A single split on the
        # character class gives the same result as splitting on '\n', ',' and
        # '.' one after the other and keeping the first piece each time.
        answer = re.split(r'[\n,.]', answer, maxsplit=1)[0]

        # Lower-case for consistency with the QA generators in this notebook,
        # which lower-case their answers; the ground-truth answers shown in the
        # notebook are lower-case as well.
        return answer.strip().lower()

In [7]:
from manifest import Manifest 

# Manifest client for OpenAI's text-davinci-003 engine, configured with a
# SQLite cache at demo_data/gpt35-cache.sqlite.
# NOTE(review): presumably prompts already in the cache are answered without a
# live API call; confirm the engine is still available for uncached prompts.
gpt35_client = Manifest(client_name="openai", engine="text-davinci-003",
                        cache_name="sqlite", cache_connection="demo_data/gpt35-cache.sqlite")

# Few-shot generator backed by the client above.
gpt35 = FewShotGenerator(llm=gpt35_client)

In [8]:
# TODO: Add a note that GPT-3.5 has almost certainly seen this dataset at training time, so its score below may be inflated.

In [9]:
# Accuracy of the few-shot GPT-3.5 generator on the held-out test questions.
score(test_questions, gpt35)

  0%|          | 0/432 [00:00<?, ?it/s]

0.8981481481481481

In [10]:
class MyBingWebsearch(BingRetriever):
    """Bing-backed retriever that uses the question text as its query."""
    
    def __init__(self, cache_path):
        """Pass the cache path straight through to the base class."""
        super().__init__(cache_path)
    
    def create_query(self, question):
        """Return the question text verbatim (presumably used as the search query)."""
        return question.text
    
# Retriever instance, given demo_data/bing-cache.pkl as its cache path.
bing_websearch = MyBingWebsearch('demo_data/bing-cache.pkl')  

In [11]:
# Pick one validation question to illustrate what retrieval returns for it.
example_question = validation_questions[11]
example_question

Question(text="name: scala's bistro; address: 432 powell st.; phone: 415/395-8555", correct_answers=['san francisco'], metadata={})

In [12]:
# Retrieve results for the example question and show the first three.
# Each result unpacks as (snippet, url); the URL is printed first for readability.
retrieved = bing_websearch.retrieve(example_question)
for snippet, url in retrieved[:3]:
    # The trailing '\n' argument leaves a blank line between results.
    print(url, '-', snippet, '\n')

https://tableagent.com/san-francisco/scalas-bistro/ - Reservations Scala's Bistro Reservations Date Time Party Size Business Info + − Leaflet | © OpenStreetMap Address: 432 Powell Street, San Francisco CA 94102 Cross Street: Post Street Location: San Francisco | Union Square Cuisine: French | Italian | Pasta | Cost: | Moderate Category: Fine Dining Star Rating: Reservations: Unknown 

https://www.yellowpages.com/san-francisco-ca/mip/scalas-bistro-4887204 - ﻿ $$$ Italian Restaurants, Bars, Continental Restaurants (2) (2076) 7.1 OPEN NOW Today: 8:00 am - 11:00 pm 21 YEARS IN BUSINESS Amenities: (415) 395-8555 Map & Directions 432 Powell StSan Francisco, CA 94102 Write a Review Is this your business? Customize this page. Claim This Business Hours Regular Hours Scala's Bistro 432 Powell St, San Francisco 

https://www.chamberofcommerce.com/united-states/california/san-francisco/italian-restaurant/2006879304-scala-s-bistro - Scala's Bistro at 432 Powell St, San Francisco, CA 94102. Get Scal

In [13]:
class QAGeneratorWithContext(HuggingfaceQAGenerator):
    """Extractive question-answering generator that also sees retrieved context.

    Asks the same fixed question as the retrieval-free baseline, but the context
    given to the QA model is the retrieved text followed by the record text.
    """

    def __init__(self, model_name, cache_path):
        """Pass the model name and cache path straight through to the base class."""
        super().__init__(model_name, cache_path)

    def _create_prompt(self, question, params):
        """Build the question/context pair, with the retrieved context placed first."""
        retrieved_context = params['retrieved_context']
        prompt = {'question': "What is the name of the city in which this restaurant is located?"}
        # Retrieved text and record text are joined with a single ';'.
        prompt['context'] = f'{retrieved_context};{question.text}'
        return prompt

    def _extract_answer(self, response):
        """Return the model's 'answer' field, lower-cased."""
        predicted_city = response['answer']
        return predicted_city.lower()


# Same QA model as the baseline, with a separate cache file for the context-augmented prompts.
minilm_ctx = QAGeneratorWithContext('deepset/minilm-uncased-squad2', 'demo_data/qa_ctx-cache.pkl')

In [14]:
# Combine the Bing retriever with the context-aware QA generator.
# NOTE(review): k=10 is presumably the number of retrieved results used per question; confirm.
rag10 = RetrievalAugmentedModel(bing_websearch, minilm_ctx, k=10)

# Test-set accuracy of the retrieval-augmented model; kept for comparison with later results.
accuracy_rag_10 = score(test_questions, rag10)

f'The accuracy with retrieval augmentation and k=10 on the test set is {accuracy_rag_10}'

  0%|          | 0/432 [00:00<?, ?it/s]

'The accuracy with retrieval augmentation and k=10 on the test set is 0.8009259259259259'

In [15]:
# Refine the retrieval-augmented model using the validation questions.
# The first five are left out; those are the ones used as few-shot examples
# (validation_questions[:5]) earlier in the notebook.
refined_rag_model = RAGBooster(rag10, validation_questions[5:])

Computing validation corpus...


  0%|          | 0/427 [00:00<?, ?it/s]

Learning importance weights for data sources...
Tuning threshold for corpus pruning...
Achieved accuracy of 0.874 with a pruning threshold of 0.57806 on the validation set.


In [16]:
# Test-set accuracy of the refined model, and its absolute gain over the
# unrefined k=10 retrieval-augmented model.
accuracy_refined = score(test_questions, refined_rag_model)
improvement = accuracy_refined - accuracy_rag_10

# The gain is shown to three decimals; the refined accuracy is shown unrounded.
f'RAGBooster improved the accuracy with retrieval augmentation by {improvement:.3f}'+\
f' to {accuracy_refined}!'

  0%|          | 0/432 [00:00<?, ?it/s]

'RAGBooster improved the accuracy with retrieval augmentation by 0.044 to 0.8449074074074074!'