In [1]:
!pip3 install thirdai --upgrade

In [None]:
import os

from thirdai import neural_db as ndb, licensing
import pandas as pd
import fitz  # PyMuPDF
from langchain.text_splitter import CharacterTextSplitter

# Activate ThirdAI. Supply your own key through the THIRDAI_LICENSE_KEY
# environment variable rather than editing (and committing) it here; when the
# variable is unset or empty, the key bundled with this notebook is used.
licensing.activate(
    os.environ.get("THIRDAI_LICENSE_KEY") or "D0F869-B61466-6A28F0-14B8C6-0AC6C6-V3"
)

### Process PDF files into CSV

In [2]:
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        pdf_path: Path to the PDF file; passed directly to ``fitz.open``.

    Returns:
        One string made of each page's ``get_text()`` output joined with no
        separator.
    """
    # The context manager closes the document even when reading a page raises;
    # an explicit close() after the loop would be skipped in that case.
    with fitz.open(pdf_path) as doc:
        # join() consumes every page before the with-block exits.
        return "".join(page.get_text() for page in doc)

def save_chunks_to_csv(chunks, csv_path):
    """Write text chunks to ``csv_path`` as a single-column CSV headed ``Text``.

    Each chunk becomes one row; the DataFrame index is not written.
    """
    chunk_frame = pd.DataFrame(data=chunks, columns=['Text'])
    chunk_frame.to_csv(path_or_buf=csv_path, index=False)

In [3]:
import os


def csv_path_for_pdf(pdf_path):
    """Return ``pdf_path`` with its extension replaced by ``.csv``.

    ``os.path.splitext`` only strips the final extension, so dots elsewhere in
    the path survive: ``"./data/a.pdf"`` -> ``"./data/a.csv"`` and
    ``"data/v1.2/a.pdf"`` -> ``"data/v1.2/a.csv"``. Splitting on the first
    ``"."`` would turn those into ``".csv"`` and ``"data/v1.csv"``.
    """
    return os.path.splitext(pdf_path)[0] + ".csv"


# Specify the PDF paths here; each one is written to a CSV next to it.
pdf_paths = []

# CharacterTextSplitter settings: target chunk size and overlap between chunks.
chunk_size = 1000
chunk_overlap = 100

csv_files = []
for pdf_path in pdf_paths:
    csv_out_path = csv_path_for_pdf(pdf_path)
    csv_files.append(csv_out_path)
    text = extract_text_from_pdf(pdf_path)
    splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator='\n')

    chunks = [document.page_content for document in splitter.create_documents([text])]
    print(len(chunks))
    save_chunks_to_csv(chunks, csv_out_path)
print(csv_files)

[]


### Train NeuralDB on the generated CSV files

In [4]:
# Preprocessed CSV files, each with a single 'Text' column of chunks. This list
# replaces the `csv_files` produced by the PDF-processing cell above.
csv_files = [
    'data/pfizer-20221231.csv',
    'data/tsla-20231231.csv',
    'data/msft-10-Q.csv',
    'data/walmart-10k.csv',
    'data/samsung-2022-10k.csv',
    'data/apple-10k.csv',
    'data/nvda-10k.csv',
    'data/meta-10k.csv',
]

# Wrap each file as a NeuralDB CSV document; 'Text' is both the searchable
# (strong) column and the column returned as the reference.
csv_docs = [
    ndb.CSV(
        path=path,
        strong_columns=['Text'],
        weak_columns=[],
        reference_columns=['Text'],
    )
    for path in csv_files
]

In [None]:
# Path to a saved NeuralDB checkpoint (e.g. "lti_finetuned.ndb", as written by
# the save cell further down). Leave as None to build a fresh model instead.
PRETRAINED_CHECKPOINT = None

if PRETRAINED_CHECKPOINT:
    db = ndb.NeuralDB.from_checkpoint(PRETRAINED_CHECKPOINT)
else:
    # Train from scratch by inserting every CSV document built above.
    db = ndb.NeuralDB()
    db.insert(csv_docs)

### Fine-tune the model on (question, paragraph) pairs

In [None]:
question_df = pd.read_csv("questions_large.csv")
import tqdm

# For each source CSV, associate every question tagged with that file in
# question_df['source'] with the paragraph selected by its 'para_id' (a
# positional index into the file's 'Text' column).
for csv_file in csv_files:
    paragraphs = pd.read_csv(csv_file)['Text'].to_list()
    print(csv_file)
    file_questions = question_df.loc[question_df['source'] == csv_file]
    question_para_pairs = zip(file_questions['question'], file_questions['para_id'])
    for question_text, para_index in tqdm.tqdm(question_para_pairs, total=len(file_questions)):
        db.associate(question_text, paragraphs[int(para_index)])

In [None]:
# Persist the fine-tuned NeuralDB; it can be reloaded later with
# ndb.NeuralDB.from_checkpoint("lti_finetuned.ndb").
db.save("lti_finetuned.ndb")

### Inference 

In [7]:
import os
from getpass import getpass

# Reuse OPENAI_API_KEY when it is already set in the environment; otherwise
# prompt for it, so the key is never hard-coded in (or saved with) the notebook.
# An unconditional assignment here would clobber a key exported in the shell.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

from langchain.chat_models import ChatOpenAI
from paperqa.prompts import qa_prompt
from paperqa.chains import make_chain
from langchain.prompts import PromptTemplate

# Low temperature for near-deterministic answers.
llm = ChatOpenAI(
    model_name='gpt-3.5-turbo',
    temperature=0.1,
)

In [8]:
def get_references(query, radius=None, top_k=3):
    """Search the module-level NeuralDB ``db`` and return the hits as strings.

    Args:
        query: Natural-language query passed to ``db.search``.
        radius: If not None (0 included), each hit is rendered with
            ``result.context(radius=radius)`` (presumably the hit plus
            neighbouring chunks; see the NeuralDB docs). If None (default),
            each hit's ``result.text`` is returned.
        top_k: Number of results requested from ``db.search`` (default 3).

    Returns:
        A list of strings, one per search result, in the order ``db.search``
        returned them.
    """
    search_results = db.search(query, top_k=top_k)
    # Compare with None explicitly so radius=0 is honored rather than being
    # treated the same as "no radius".
    if radius is None:
        return [result.text for result in search_results]
    return [result.context(radius=radius) for result in search_results]

def get_answer(query, references, answer_length="abt 50 words", max_references=5):
    """Answer ``query`` with paperqa's default ``qa_prompt`` and the module-level ``llm``.

    Args:
        query: The question, passed as the prompt's ``question`` variable.
        references: Passage strings (e.g. from ``get_references``); the first
            ``max_references`` are joined with blank lines into ``context``.
        answer_length: Value for the prompt's ``answer_length`` variable.
        max_references: Maximum number of passages placed in the context.

    Returns:
        Whatever ``qa_chain.run`` returns (presumably the answer text).
    """
    qa_chain = make_chain(prompt=qa_prompt, llm=llm)
    context = '\n\n'.join(references[:max_references])
    return qa_chain.run(question=query, context=context, answer_length=answer_length)

def get_answer_manual_prompt(query, references, prompt, max_references=5, **prompt_vars):
    """Answer ``query`` with a caller-supplied prompt and the module-level ``llm``.

    Args:
        query: The question, passed as the prompt's ``question`` variable.
        references: Passage strings; the first ``max_references`` are joined
            with blank lines into the prompt's ``context`` variable.
        prompt: A prompt template that uses ``{question}`` and ``{context}``.
        max_references: Maximum number of passages placed in the context.
        **prompt_vars: Values for any additional variables ``prompt``
            declares; forwarded unchanged to ``qa_chain.run``.

    Returns:
        Whatever ``qa_chain.run`` returns (presumably the answer text).
    """
    qa_chain = make_chain(prompt=prompt, llm=llm)
    context = '\n\n'.join(references[:max_references])
    return qa_chain.run(question=query, context=context, **prompt_vars)


In [None]:
# Example: answer with paperqa's default prompt. radius=1 makes get_references
# return result.context(radius=1) for each hit instead of the bare hit text.
# Inspect the retrieved passages with: print(references)
query = "What is the revenue of apple in year 2022"
references = get_references(query, radius=1)
answer = get_answer(query, references)
print(answer)

In [None]:
query = "What is the revenue of apple in year 2022"

references = get_references(query)
# print(references)

# Design your own prompt here. It must use the two variables {question} and
# {context}, which get_answer_manual_prompt fills in. Each part ends with a
# newline so the rendered prompt does not run "Question: ..." straight into
# "Context: ...".
manual_prompt = (
    PromptTemplate.from_template("Answer the question based on the following context\n")
    + "Question: {question}\n"
    + "Context: {context}"
)
print(get_answer_manual_prompt(query, references, manual_prompt))