## ThirdAI's Playground

In this notebook, we will show:

1. How to easily build a semantic QnA engine for all your documents with ThirdAI's BOLT engine.

2. (Optional) How to use your OpenAI key to get retrieval augmented answers from OpenAI.

3. How to teach your retrieval model with RLHF.

4. (Optional) How to save your models and export them to ThirdAI's Playground web app to do interactive QnA and teach your model with RLHF.

In [None]:
!pip3 install -r requirements.txt

In [3]:
# Activate ThirdAI's license (paste your license key between the quotes)

import thirdai
try:
    thirdai.licensing.activate("")
except Exception:
    print("You need a license key to use ThirdAI's library. Please request a trial license at https://www.thirdai.com/try-bolt/")

thirdai.set_seed(7)

In [4]:
from thirdai import bolt
import os
import nltk
nltk.data.path.append("./data/")
from pathlib import Path
import pickle
from doc_utils import documents

### Load your files

You can load a mix of csv, pdf and docx files. If you want to train on a CSV file, set the target_column_name variable below to the ID column of the CSV file. The ID column must contain consecutive integers from 0 to num_ids - 1.

Also, if you're loading from a pre-trained checkpoint, query_column_name and target_column_name must match the ones used to train that model. For the checkpoints that ThirdAI provides, we standardize query_column_name to "QUERY" and target_column_name to "DOC_ID".
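Before building a CSV document, it can help to confirm that the ID column really holds the consecutive integers 0 to num_ids - 1. The sketch below uses only the standard library; `check_id_column` and `remap_ids` are hypothetical helper names, not part of ThirdAI's API, and the tiny catalog is made-up data.

```python
import csv
import io

def check_id_column(ids) -> bool:
    """True if `ids` are exactly the integers 0..n-1, each appearing once (any order).

    Raises ValueError if a value is not an integer string.
    """
    values = sorted(int(v) for v in ids)
    return values == list(range(len(values)))

def remap_ids(ids):
    """Map arbitrary IDs to consecutive integers 0..n-1 in first-seen order.

    Returns (new_ids, mapping); repeated IDs map to the same integer.
    """
    mapping = {}
    for original in ids:
        if original not in mapping:
            mapping[original] = len(mapping)
    return [mapping[original] for original in ids], mapping

# A tiny in-memory catalog standing in for a real CSV file (hypothetical data).
sample_csv = "DOC_ID,TITLE\n0,Red shoe\n1,Blue hat\n2,Green scarf\n"
rows = list(csv.DictReader(io.StringIO(sample_csv)))
print(check_id_column(row["DOC_ID"] for row in rows))  # True
```

If the check fails, `remap_ids` gives a consecutive replacement column that you could write back to the CSV before loading it.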

In [5]:
query_column_name = "QUERY"
target_column_name = "DOC_ID"

combined_pdfs = None
combined_docxs = None
csv_doc = None

# This object does the bookkeeping for managing multiple documents
doclist = documents.DocList()

#### Option 1: CSV files

In [16]:
csv_file = "sample_catalog.csv"

# Visualize the dataframe and get the column names in csv_file.
# Your target column (id_col) must match the target column of the model defined below (we use target_column_name throughout the notebook).
# You will have to choose the strong_cols and weak_cols used when creating the CSV document in the next cell.
# Strong columns are usually the most important ones, like document titles, keywords, categories, etc.
# Weak columns are usually the long descriptions.

import pandas as pd
pd.options.display.max_colwidth = 700

df = pd.read_csv(csv_file)
print(df.iloc[0])

DOC_ID    0
TITLE
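The comments in the cell above describe strong columns as short, high-signal fields (titles, keywords, categories) and weak columns as long descriptions. As a hypothetical starting point, one could rank columns by average text length and treat the shorter ones as strong. `suggest_columns` and the 100-character threshold are assumptions for illustration only; the final choice should still come from someone who knows the data.

```python
from statistics import mean

def suggest_columns(rows, id_col, threshold=100):
    """Split non-ID columns into (strong, weak) by average text length.

    rows: list of dicts (e.g. from csv.DictReader). Columns whose average
    length is below `threshold` characters are suggested as strong.
    """
    strong, weak = [], []
    for col in rows[0]:
        if col == id_col:
            continue
        avg_len = mean(len(str(row.get(col) or "")) for row in rows)
        if avg_len < threshold:
            strong.append(col)
        else:
            weak.append(col)
    return strong, weak

# Made-up catalog rows with the same column names as the cell below.
rows = [
    {"DOC_ID": "0", "TITLE": "Trail running shoe", "BRAND": "Acme",
     "DESCRIPTION": "Lightweight shoe with a breathable mesh upper " * 5},
    {"DOC_ID": "1", "TITLE": "Wool beanie", "BRAND": "Acme",
     "DESCRIPTION": "Warm knit hat made from merino wool for cold days " * 4},
]
print(suggest_columns(rows, id_col="DOC_ID"))  # (['TITLE', 'BRAND'], ['DESCRIPTION'])
```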

In [17]:
csv_doc = documents.CSV(
    path=csv_file,
    id_col=target_column_name,
    strong_cols=['TITLE', 'BRAND'],
    weak_cols=['DESCRIPTION'],
    display_cols=['TITLE','DESCRIPTION'],
)

doclist.add_document(csv_doc)

#### Option 2: PDF or DOCX files

In [6]:
filenames = ['mutual_nda_teamplate_for_testing.pdf']

pdfs = [name for name in filenames if name.endswith(".pdf")]
docxs = [name for name in filenames if name.endswith(".docx")]

if len(pdfs)>0:
    combined_pdfs = documents.PDF(
        files=pdfs, 
        expected_id_col=target_column_name,
        hash_to_id_offset=doclist.get_source_hash_to_id_offset_map(),
        next_id_offset=doclist.get_n_new_ids(),
    )
    doclist.add_document(combined_pdfs)

if len(docxs)>0:
    combined_docxs = documents.DOCX(
        files=docxs, 
        expected_id_col=target_column_name,
        hash_to_id_offset=doclist.get_source_hash_to_id_offset_map(),
        next_id_offset=doclist.get_n_new_ids(),
    )
    doclist.add_document(combined_docxs)
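`DocList` hands each new document an ID offset (`next_id_offset`) so that passages from different files never collide, and `hash_to_id_offset` suggests that a previously seen source keeps its old offset. The toy class below sketches that bookkeeping idea under our own assumptions: `ToyDocList` and keying sources by a SHA-256 of their contents are hypothetical, not ThirdAI's implementation.

```python
import hashlib

class ToyDocList:
    """Assigns each source a contiguous, non-overlapping block of passage IDs."""

    def __init__(self):
        self.hash_to_offset = {}  # content hash -> first ID of that source
        self.next_offset = 0      # first unused ID

    def add(self, content: bytes, n_passages: int) -> range:
        key = hashlib.sha256(content).hexdigest()
        if key in self.hash_to_offset:
            # Same source seen before: reuse its IDs instead of allocating new ones.
            start = self.hash_to_offset[key]
        else:
            start = self.next_offset
            self.hash_to_offset[key] = start
            self.next_offset += n_passages
        return range(start, start + n_passages)

docs = ToyDocList()
print(docs.add(b"catalog csv", 3))  # range(0, 3)
print(docs.add(b"nda pdf", 5))      # range(3, 8)
print(docs.add(b"catalog csv", 3))  # range(0, 3), reused
print(docs.next_offset)             # 8
```

The total of allocated IDs plays the role that `doclist.get_n_new_ids()` plays for `n_target_classes` in the model definition.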

### Model Definition
#### Option 1: Define a model from scratch

In [41]:
model = bolt.UniversalDeepTransformer(
    data_types = {
        query_column_name: bolt.types.text(tokenizer="char-4"),
        target_column_name: bolt.types.categorical(delimiter=":"),
    },
    target=target_column_name,
    n_target_classes=doclist.get_n_new_ids(),
    integer_target=True,
    options={
        "fhr": 50000,
        "embedding_dimension": 2048,
        "extreme_classification": True,
        "extreme_output_dim": 50000,
        "rlhf":True,
    }
)

lr = 0.005
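The model definition tokenizes the query with `tokenizer="char-4"` and sets `"fhr": 50000`. A reasonable reading, which we have not confirmed against ThirdAI's source, is that the query is split into overlapping character 4-grams and each gram is hashed into one of `fhr` feature buckets. The sketch below illustrates that idea with `zlib.crc32` as a stand-in hash and lowercasing as an assumed preprocessing step; BOLT's actual hashing and preprocessing may differ.

```python
import zlib

def char_ngrams(text: str, n: int = 4) -> list[str]:
    """Overlapping character n-grams; strings shorter than n yield themselves."""
    text = text.lower()  # assumed normalization, for illustration only
    if len(text) < n:
        return [text] if text else []
    return [text[i:i + n] for i in range(len(text) - n + 1)]

def hash_features(grams: list[str], feature_hash_range: int = 50_000) -> list[int]:
    """Map each n-gram to a bucket in [0, feature_hash_range) with a stable hash."""
    return [zlib.crc32(g.encode("utf-8")) % feature_hash_range for g in grams]

grams = char_ngrams("Effective date")
print(grams[:3])   # ['effe', 'ffec', 'fect']
print(len(grams))  # 11
print(hash_features(grams)[:3])
```

Character n-grams make matching tolerant to typos and word-form changes, which is one plausible reason to prefer them over whole-word tokens for short queries.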

#### Option 2: Load from a checkpoint

In [None]:
import os

checkpoint = "msmarco.bolt"

if not os.path.exists(checkpoint):
    # Quote the URL so the shell does not interpret "?" as a glob character
    os.system(f'wget -O {checkpoint} "https://www.dropbox.com/s/sd1vxsg8v6d2u2r/msmarco_0_reindexes.bolt?dl=0"')

model = bolt.UniversalDeepTransformer.load(checkpoint)

model.clear_index()

for doc in [csv_doc, combined_pdfs, combined_docxs]:
    if doc:
        print(doc)
        doc_config = doc.get_config()
        model.introduce_documents(
            doc_config.introduction_dataset,
            strong_column_names=doc_config.strong_cols, 
            weak_column_names=[],
            num_buckets_to_sample=16,
        )

lr = 1e-3

### Train the model


In [None]:
for doc in [csv_doc, combined_pdfs, combined_docxs]:
    if doc:
        doc_config = doc.get_config()
        metrics = model.cold_start(
            filename=doc_config.unsupervised_dataset,
            strong_column_names=doc_config.strong_cols,
            weak_column_names=doc_config.weak_cols,
            epochs=3,
            learning_rate=lr,
            metrics=["precision@5"],
        )
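`cold_start` reports `precision@5`. Under the usual definition, precision@k is the fraction of the top k predictions that are relevant, averaged over queries. The sketch below computes it in plain Python to help read the metric; it is not BOLT's implementation, and details such as how queries with fewer than k predictions are handled may differ.

```python
def precision_at_k(predicted: list[list[int]], relevant: list[set[int]], k: int = 5) -> float:
    """Mean over queries of |top-k predictions that are relevant| / k.

    predicted[i] is the ranked list of IDs for query i; relevant[i] is its set of correct IDs.
    """
    if not predicted:
        return 0.0
    hits = 0
    for preds, gold in zip(predicted, relevant):
        hits += sum(1 for p in preds[:k] if p in gold)
    return hits / (k * len(predicted))

preds = [[3, 1, 4, 0, 2], [7, 8, 9, 5, 6]]
gold = [{3, 4}, {6}]
print(precision_at_k(preds, gold, k=5))  # 0.3
```

Note that a query with a single correct passage can score at most 1/k here, so precision@5 values well below 1.0 are normal for document retrieval.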

In [36]:
# how many search results do you want to retrieve from your files for every query
N_REFERENCES = 2

model.set_decode_params(min(doclist.get_n_new_ids(), N_REFERENCES), min(doclist.get_n_new_ids(), 100))
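The call above clamps the number of returned references with `min(doclist.get_n_new_ids(), N_REFERENCES)`, since the model cannot return more distinct passages than exist. The sketch below shows the same clamping on a plain dictionary of made-up scores using `heapq.nlargest`; it is not how BOLT decodes internally, and we do not model the second argument to `set_decode_params`.

```python
import heapq

def top_references(scores: dict[int, float], n_references: int) -> list[tuple[int, float]]:
    """Return up to n_references (passage_id, score) pairs, best first."""
    k = min(len(scores), n_references)  # clamp to the number of passages
    return heapq.nlargest(k, scores.items(), key=lambda item: item[1])

scores = {0: 0.91, 1: 0.15, 2: 0.47}
print(top_references(scores, 2))   # [(0, 0.91), (2, 0.47)]
print(top_references(scores, 10))  # all 3 passages, best first
```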

### Get Answers from OpenAI

In this section, we will show how to use LangChain to query OpenAI's chat model (gpt-3.5-turbo) and generate an answer from the references retrieved by the BOLT model above. You'll have to specify your own OpenAI key for this to work. You can replace this segment with any other model of your choice; for example, you can use an on-prem open-source model like MPT or Dolly for answer generation with the same prompt that you use with OpenAI.
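If you swap OpenAI for an on-prem model such as MPT or Dolly, you need to assemble the prompt yourself. The function below is a hypothetical template that mirrors the arguments this notebook passes to `qa_chain.run` (a question, the top references joined by blank lines, and a target answer length); it is not paperqa's actual `qa_prompt` text.

```python
def build_qa_prompt(question: str, references: list[str],
                    length: str = "about 50 words", max_refs: int = 3) -> str:
    """Assemble a retrieval-augmented QA prompt (hypothetical template)."""
    context = "\n\n".join(references[:max_refs])  # same joining as get_answer
    return (
        "Answer the question using only the context below. "
        "If the context is insufficient, say so.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n"
        f"Answer ({length}):"
    )

prompt = build_qa_prompt(
    "What is the effective date of this agreement?",
    ["... entered as of May 3 2023 (Effective Date).", "1. CONFIDENTIAL INFORMATION ..."],
)
print(prompt)
```

The resulting string can be fed to any text-generation model; only the model call changes, not the retrieval step.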


In [11]:
from langchain.chat_models import ChatOpenAI
from paperqa.qaprompts import qa_prompt, make_chain

your_openai_key = ""

llm = ChatOpenAI(
    model_name='gpt-3.5-turbo', 
    temperature=0.1, 
    openai_api_key=your_openai_key,
)

qa_chain = make_chain(prompt=qa_prompt, llm=llm)

In [12]:
def get_references(query):
    reference_ids = model.predict({query_column_name: query})
    # Keep only the first element of each prediction (the passage ID)
    reference_ids = [itm[0] for itm in reference_ids]
    display_items = doclist.get_new_display_items()
    references = [display_items.iloc[p] for p in reference_ids]
    return references

def get_answer(query, references):
    return qa_chain.run(question=query, context_str='\n\n'.join(references[:3]), length="abt 50 words")

In [45]:
query = "what is the effective date of this agreement?"

references = get_references(query)
print(references)

['CONFIDENTIALITY AGREEMENT This Confidentiality Agreement (the “Agreement”) is made by and between ACME. dba ToTheMoon Inc. with offices at 2025 Guadalupe St. Suite 260 Austin TX 78705 and StarWars dba ToTheMars with offices at the forest moon of Endor and entered as of May 3 2023 (“Effective Date”).', 'In consideration of the business discussions disclosure of Confidential Information and any future business relationship between the parties it is hereby agreed as follows: 1. CONFIDENTIAL INFORMATION. For purposes of this Agreement the term “Confidential Information” shall mean any information business plan concept idea know-how process technique program design formula algorithm or work-in-process Request for Proposal (RFP) or Request for Information (RFI) and any responses thereto engineering manufacturing marketing technical financial data or sales information or information regarding suppliers customers employees investors or business operations and other information or materials w

In [46]:
answer = get_answer(query, references)

print(answer)

The effective date of this Confidentiality Agreement is May 3, 2023 (ACME dba ToTheMoon Inc. and StarWars dba ToTheMars, 2023).


Now, let's ask a query that the model gets wrong. Then, let's teach the model to correct itself using our RLHF methods.

In [52]:
query = "who are the parties involved in this agreement?"

references = get_references(query)
answer = get_answer(query, references)
print(answer)

The context provides insufficient information to determine the parties involved in this agreement.


### How to teach your model (RLHF)

This is one of the marquee features that we provide. Thanks to our efficient training capabilities, you can teach the retrieval model to correct itself when it fails to retrieve the correct paragraphs from the document.

Also, RLHF teaching done on a model will generalize beyond the current documents: if we run *model.clear_index()* and introduce new documents, what the model learned still applies. This is because our engine has an elastic output space that adapts to the contents of new documents.

To do RLHF, we provide two functions:

1. Associate: Using this function, you can associate two phrases so that they give similar results. For example, assume you're in the contract review domain and you want to ask a question like "who are the parties involved in this contract?". However, most contracts use the phrase "made by and between" to introduce the parties (like "this agreement is made by and between company A and company B"). In this scenario, you can simply call *model.associate([({"QUERY": "parties involved"}, {"QUERY": "made by and between"})], 7)* and the model will learn the relation. In subsequent documents, you're more likely to retrieve the passage containing the correct information.

2. Upvote: Let's say you searched for the query "is there a limited liability clause?" and got 5 search results (along with their passage IDs). If you know that the correct result is actually the 2nd one rather than the 1st, you can simply call *model.upvote("is there a limited liability clause", passage_id_of_the_best_search_result)*.
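To make the two teaching operations concrete, here is a toy keyword retriever that mimics their effect: `associate` expands any query containing a source phrase with the words of a target phrase, and `upvote` adds a score bonus for one (query, passage) pair. This is purely illustrative. BOLT learns these signals by training the model, and the class, method signatures, and scoring below are our own assumptions, not ThirdAI's API.

```python
import re
from collections import defaultdict

def tokens(text: str) -> set[str]:
    """Lowercased alphanumeric words, ignoring punctuation."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

class ToyRetriever:
    """Keyword-overlap retriever with toy stand-ins for associate and upvote."""

    def __init__(self, passages: dict[int, str]):
        self.passages = {pid: tokens(text) for pid, text in passages.items()}
        self.associations = defaultdict(list)  # source phrase -> target phrases
        self.upvotes = defaultdict(float)      # (query, passage_id) -> bonus

    def associate(self, source: str, target: str) -> None:
        # Any query containing `source` is also scored with the words of `target`.
        self.associations[source.lower()].append(target.lower())

    def upvote(self, query: str, passage_id: int, weight: float = 1.0) -> None:
        # Reward one passage for one exact query string.
        self.upvotes[(query.lower(), passage_id)] += weight

    def search(self, query: str, k: int = 1) -> list[int]:
        q = query.lower()
        words = tokens(q)
        for source, targets in self.associations.items():
            if source in q:
                for target in targets:
                    words |= tokens(target)

        def score(pid: int) -> float:
            # Word overlap plus any upvote bonus for this exact query.
            return len(words & self.passages[pid]) + self.upvotes.get((q, pid), 0.0)

        return sorted(self.passages, key=score, reverse=True)[:k]

toy = ToyRetriever({
    0: "This agreement is made by and between ACME and StarWars",
    1: "The parties involved in any dispute are listed in Schedule B",
})
toy_query = "who are the parties involved in this agreement?"
print(toy.search(toy_query))   # [1]: plain keyword overlap picks the wrong passage
toy.associate("parties involved", "made by and between")
print(toy.search(toy_query))   # [0]: the association pulls in "made by and between"

toy_query2 = "is there a limited liability clause?"
print(toy.search(toy_query2))  # [0]
toy.upvote(toy_query2, 1, weight=2.0)
print(toy.search(toy_query2))  # [1]: the upvote overrides keyword overlap
```

The design difference worth noting: an association is phrase-level and carries over to new queries and documents, while an upvote here is tied to one query and one passage.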

We provide two interfaces to do the teaching.

1. You can teach the model directly in Python with model.associate() and model.upvote() calls, as shown in the "RLHF using function calls" section below.

2. You can save a checkpoint of your trained model and export it to our Playground web-app to do QnA and teaching on an intuitive UI.

### Option 1. RLHF using function calls 

In the above example, the model could not connect the phrase "parties involved" with the passage that names the parties. But if you are an expert in contracts, you know that the parties are usually introduced with the phrase "made by and between" (for example, "This Confidentiality Agreement is made by and between ACME ... and StarWars ..."). So, let's teach the model that these two phrases should retrieve similar passages.

In [53]:
rlhf_samples = [({"QUERY":"parties involved"},{"QUERY":"made by and between"})]

model.associate(rlhf_samples, 7)

Now, let's query the model again.

In [54]:
query = "who are the parties involved in this agreement?"

references = get_references(query)
answer = get_answer(query, references)
print(answer)

The parties involved in this agreement are ACME, dba ToTheMoon Inc. with offices at 2025 Guadalupe St. Suite 260 Austin TX 78705 and StarWars dba ToTheMars with offices at the forest moon of Endor (Confidentiality Agreement).


There you go!!
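`model.associate` receives a list of `(source, target)` pairs in which each side is a dict keyed by the query column. When you have many phrase pairs to teach, a small helper keeps that structure and the column name consistent. `make_association_samples` is a hypothetical helper; pass `column=query_column_name` in the notebook and keep the second argument to `model.associate` as in the cell above.

```python
def make_association_samples(phrase_pairs, column="QUERY"):
    """Turn [(source_phrase, target_phrase), ...] into the dict pairs used by the associate cell."""
    return [({column: source}, {column: target}) for source, target in phrase_pairs]

samples = make_association_samples([
    ("parties involved", "made by and between"),
    ("date of signing", "duly executed"),
])
print(samples[0])  # ({'QUERY': 'parties involved'}, {'QUERY': 'made by and between'})
```

In the notebook you would then call `model.associate(samples, 7)`, exactly like the single-pair call above.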

### Option 2: Export your model to Playground App

ThirdAI's Playground is a dockerized Gradio app that you can run on your laptop to load any model checkpoint, do QnA, and teach the model with the functions described above.

Before you save your checkpoint, please watch the following short video tutorial to install Docker Desktop, download our image, and run the web app in a container.

https://drive.google.com/file/d/16tI1OAm2Lu0OuUOCiJzGrTjiBZejJWs3/view

In [None]:
## Coming Soon

After you save the checkpoint, copy the .zip file to the folder from which you're running the Docker container. Then watch this short video tutorial to do QnA and teaching.

https://drive.google.com/file/d/1WIt2-EpYkQJpFgFiUXbc_iYU9uhOJdMn/view