In [10]:
from datasets import load_dataset
import pandas as pd

# The first time this cell runs, it downloads a JSON file for the dataset;
# cache_dir points the download at ../data.
dataset = load_dataset(
    "taddeusb90/finbro-v0.1.0",
    cache_dir="../data",
)

In [11]:
# Wrap the loaded dataset object in a DataFrame so its 'train' column can be
# handed to json_normalize below.
splits_frame = pd.DataFrame(dataset)

# Flatten the 'train' entries into columns, keep the first 10000 rows, then
# drop any row that is missing one of the three fields used downstream.
dataset = (
    pd.json_normalize(splits_frame['train'])
    .head(10000)
    .dropna(subset=['input', 'instruction', 'output'])
)
print(dataset)

                                                  input  \
0     Residual Standard Deviation: Definition, Formu...   
1           How to Use a Home Equity Loan for a Remodel   
2              Microsoft (MSFT) Unveils Space Ambitions   
3                              Crime of 1873 Definition   
4     What Is a Capital Asset? How It Works, With Ex...   
...                                                 ...   
9995  Student Loan Payments Are Restarting Sunday. T...   
9996     Prenuptial Agreement: What it is, How it Works   
9997  Nvidia Introduces Slower Gaming Chip in China ...   
9998  5 Takeaways From the SEC's Complaint Against Musk   
9999  Resilient Labor Market Bounced Back in April A...   

                                            instruction  \
0      In the given passage, how are residual values...   
1      What are some advantages of using a home equi...   
2      Which companies has AWS partnered with to bro...   
3      When was the accessed date for the given text...

In [12]:
import minsearch

# One dict per DataFrame row, in the shape passed to index.fit below.
documents = dataset.to_dict(orient="records")

# All three columns are registered as text fields; no keyword fields are used.
index = minsearch.Index(
    text_fields=["input", "instruction", "output"],
    keyword_fields=[],
)
index.fit(documents)

<minsearch.Index at 0x273942b41c0>

In [13]:
from transformers import pipeline

# Question-answering pipeline used by rag(); it is called with a question and
# a context and its result is read via the 'answer' key.
llm = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2",
)
# Alternative tried earlier: pipeline("text-generation", model="gpt2")


def search(query, num_results=10, boost=None):
    """Search the module-level ``index`` for ``query``.

    Args:
        query: Search string forwarded to ``index.search``.
        num_results: Value forwarded as ``num_results``. Defaults to 10,
            the value that used to be hard-coded in this function.
        boost: Optional dict forwarded as ``boost_dict``. ``None`` (the
            default) sends an empty dict, which matches the previous
            behaviour.

    Returns:
        Whatever ``index.search`` returns; ``build_prompt`` iterates over
        it and expands each item with ``**doc``.
    """
    # Copy the caller's dict so the index never receives an object the caller
    # might mutate afterwards.
    boost_dict = {} if boost is None else dict(boost)

    return index.search(
        query=query,
        filter_dict={},
        boost_dict=boost_dict,
        num_results=num_results,
    )

# Skeleton of the full prompt. build_prompt() fills {question} with the user
# query and {context} with the concatenated, formatted search results.
# .strip() removes the leading/trailing newlines introduced by the
# triple-quoted layout.
prompt_template = """
You're a financial analyst. Answer the QUESTION based on the CONTEXT from our finance database, and the output is for your reference.
Use only the facts from the CONTEXT when answering the QUESTION.
QUESTION: {question}
CONTEXT:{context}
""".strip()

# Layout of a single retrieved document inside the context. build_prompt()
# formats it with each search result's 'input', 'instruction' and 'output'
# fields.
entry_template = """
input: {input}
instruction: {instruction}
output: {output}
""".strip()

def build_prompt(query, search_results):
    """Render the final prompt for ``query`` from the retrieved documents.

    Each item of ``search_results`` is expanded into ``entry_template``
    (so it must supply ``input``, ``instruction`` and ``output`` keys) and
    followed by a blank line. The joined entries become the ``context``
    slot of ``prompt_template`` and ``query`` becomes the ``question`` slot.

    Returns:
        The formatted prompt with surrounding whitespace stripped.
    """
    rendered_entries = (
        entry_template.format(**doc) + "\n\n" for doc in search_results
    )
    context = "".join(rendered_entries)

    filled = prompt_template.format(question=query, context=context)
    return filled.strip()

def rag(query):
    """Answer ``query`` using documents retrieved from the search index.

    Documents are fetched with ``search``. When nothing comes back, a fixed
    fallback message is returned. Otherwise the prompt produced by
    ``build_prompt`` is passed to ``llm`` as the context, with ``query`` as
    the question.

    Returns:
        The ``'answer'`` field of the ``llm`` response, or the string
        ``"No relevant documents found."`` when the search is empty.
    """
    hits = search(query)
    if not hits:
        return "No relevant documents found."

    # NOTE(review): the whole prompt (instruction text + question + entries)
    # is handed over as the context -- confirm this is intended for the
    # question-answering pipeline.
    qa_context = build_prompt(query, hits)
    return llm(question=query, context=qa_context)['answer']


# Smoke test: run one sample question through the full retrieval + QA path.
question = 'What is APR?'
answer = rag(question)
print(f"Answer: {answer}")

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFRobertaForQuestionAnswering: ['roberta.embeddings.position_ids']
- This IS expected if you are initializing TFRobertaForQuestionAnswering from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFRobertaForQuestionAnswering from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFRobertaForQuestionAnswering were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaForQuestionAnswering for predictions without further training.


Answer: annual percentage rate
