In [22]:
import json
import os
import re
from collections import defaultdict
from glob import glob

from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import OpenAI

# Helper functions

In [23]:
def _strip_references(texts, headings = ("\nReferences\n", "\nREFERENCES\n")):
    """
    Cut a list of chunk texts at the reference-section heading and return the part before it.
    """
    for heading in headings:
        # Walk backwards so the last chunk containing the heading is the one that gets cut.
        for pos in range(len(texts) - 1, -1, -1):
            if heading in texts[pos]:
                kept = texts[:pos]
                head = texts[pos][:texts[pos].rindex(heading)]
                # Skip the cut chunk when nothing but whitespace is left before the heading.
                if head.strip():
                    kept.append(head)
                return kept
    # No heading found: keep every chunk.
    return list(texts)


def get_chunk(path = "data", type = "pdf", chunk_size = 2048, chunk_overlap = 100, references = False):
    """
    Load every non-hidden pdf below `path` (recursively) and split it into text chunks.
    Only support pdf for now.

    Parameters:
        path: folder that is searched for pdf files.
        type: kind of document to load, must be "pdf".
        chunk_size, chunk_overlap: forwarded to RecursiveCharacterTextSplitter.
        references: True to keep the reference section, False to remove everything from the
            "References" / "REFERENCES" heading onwards.

    Returns:
        defaultdict(list) mapping each pdf source to the list of its chunk texts.

    Raises:
        TypeError: if `type` is not "pdf".
    """
    if type != "pdf":
        raise TypeError("Currently only support pdf files")

    splitter = RecursiveCharacterTextSplitter(chunk_size = chunk_size, chunk_overlap = chunk_overlap)
    chunks_dict = defaultdict(list)

    if references:
        loader = PyPDFDirectoryLoader(path, glob = "**/[!.]*.pdf")
        for chunk in loader.load_and_split(splitter):
            chunks_dict[chunk.metadata['source']].append(chunk.page_content)
        return chunks_dict

    # Files are loaded one by one so the reference section can be cut per document.
    for pdf_file in glob(os.path.join(path, "**/[!.]*.pdf"), recursive = True):
        texts = [chunk.page_content for chunk in PyPDFLoader(pdf_file).load_and_split(splitter)]
        chunks_dict[pdf_file].extend(_strip_references(texts))

    return chunks_dict

def generate_questions(chunk, num = 3, model = "llama3"):
    """
    Generates `num` questions / use cases for `chunk`. Used when the input document is of general types.

    The request is sent through the module-level `client`. The reply is read line by line:
    only lines that end with a question mark are kept, and a leading list marker
    ("1.", "2)", "*", "-") is removed from each of them.

    Returns a list of question strings (empty when the reply contains no question).
    """
    messages=[
                {"role": "system", "content": "You are a synthetic question-answer pair generator. Given a chunk of context about some topic(s), generate %s example questions a user could ask and that question could be able to answer using information from the chunk. For example, if the given context has information about supercomputer, an example question could be 'What is a supercomputer?'" % (num,)},
                {"role": "system", "content": "The questions should be able to be answered in a few words or less. Show the example questions in numbered list. Every questions MUST end with a question mark"},
                {"role": "user", "content": str(chunk)}
            ]

    response = client.chat.completions.create(
        model=model,
        messages=messages
    )
    # content may be None; treat that as an empty reply instead of crashing on .split
    content = response.choices[0].message.content or ""

    # Optional bullet, optional "N." / "N)" (possibly closed by markdown bold), then whitespace.
    # The number needs its own "." or ")" so a question starting with a bare number stays intact.
    list_marker = r'^(?:[*\-]+\s*)?(?:\d+[.)]\**)?\s+'

    questions = []
    for line in content.split('\n'):
        line = line.strip()  # trailing whitespace must not hide the final "?"
        # Only include questions, and skip an echoed system prompt
        if line.endswith("?") and not line.startswith("You are a synthetic"):
            questions.append(re.sub(list_marker, '', line))

    return questions

def encode_question(question, chunk):
    """
    Build the chat messages used to obtain the correct answer for `question`.
    The user prompt asks for chain-of-thought reasoning with quoted context, followed by
    a final answer introduced by the "<ANSWER>:" tag.

    Returns a two-element list: the system message and the user message.
    """
    user_prompt = """
        Question: {question}\nContext: {context}\n
        Answer this question using the information given in the context above. Here is things to pay attention to: 
        - First provide step-by-step reasoning on how to answer the question. 
        - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context. 
        - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
        You MUST begin your final answer with the tag "<ANSWER>:".
    """.format(question=question, context=str(chunk))

    system_message = {"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]

def encode_question_incorrect(question, chunk, num_answer=4):
    """
    Build the chat messages used to obtain `num_answer` incorrect answers for `question`.

    Returns a two-element list: the system message and the user message.
    """
    user_prompt = """
        Question: {question}\nContext: {context}\n
        Answer this question incorrectly in {num_answer} ways. 
        The incorrect answers should be succinct.        
    """.format(num_answer=str(num_answer), context=str(chunk), question=question)

    return [
        {"role": "system", "content": "You are a bad question answerer who can provide incorrect answers given a question and relevant context."},
        {"role": "user", "content": user_prompt},
    ]

def generate_label(question, chunk, model = "llama3"):
    """
    Generates the correct answer to `question` using `chunk` as context.

    The prompt comes from `encode_question` and the request is sent through the
    module-level `client` with temperature 0.

    Returns the text after the last "<ANSWER>:" tag with surrounding whitespace removed.
    If the reply has no tag, the whole reply is returned unchanged (best effort).
    """
    messages = encode_question(question, chunk)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        n=1,
        temperature=0
    )
    # content may be None; fall back to an empty reply
    reply = response.choices[0].message.content or ""

    # Only include Answer
    tag = "<ANSWER>:"
    try:
        start = reply.rindex(tag)
    except ValueError:
        # Tag missing: hand back the full reply rather than nothing.
        return reply

    # Slice right after the tag and strip, so the answer survives whether or not
    # the model put a space after the colon.
    return reply[start + len(tag):].strip()


def generate_incorrect_answer(question, chunk, num_answer = 4, model = "llama3"):
    """
    Generates `num_answer` incorrect answers to `question`.

    The prompt comes from `encode_question_incorrect` and the request is sent through the
    module-level `client`. Only reply lines that start with a digit (the numbered list)
    are kept; the leading "N." / "N)" and an optional "**Answer:**" label are removed.

    Returns a list of incorrect answers (as many as the reply actually contains).
    """
    messages = encode_question_incorrect(question, chunk, num_answer)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        n=1,
        temperature=0.8 # increase temperature for LLM be more creative with incorrect answer
    )
    # content may be None; fall back to an empty reply
    reply = response.choices[0].message.content or ""

    # List number, whitespace, optional "**Answer:**" label. Written as a sequence rather
    # than a character class, so it cannot consume the first words of the answer itself.
    numbering = r'^\d+[.)]\s+(?:\*\*Answer:\*\*\s*)?'

    answers = []
    for line in reply.split('\n'):
        line = line.strip()  # tolerate indented list items
        if line and line[0].isdigit():
            answers.append(re.sub(numbering, '', line))

    return answers # list of incorrect answers

def run(i, chunk, num = 3, num_answer = 4, model = "llama3"):
    """
    Given a chunk, create one record per generated question:
    {id, context, question, correct_answer, incorrect_answers}.

    `i` is the index of the chunk; it is only used to build the record id.
    Returns the list of records.
    """
    return [
        {
            "id": f"chunk_{i}_question_{q_idx}", # id of chunk_question for easier tracking
            "context": chunk.split("\n"),
            "question": query,
            # correct answer first, then the incorrect ones
            "correct_answer": generate_label(query, chunk, model),
            "incorrect_answers": generate_incorrect_answer(query, chunk, num_answer, model),
        }
        for q_idx, query in enumerate(generate_questions(chunk, num, model))
    ]

In [24]:
# Settings used by the generation cell below
MODEL = "llama3" # local LLM downloadable from Ollama or gpt
NUM_QUESTION = 3 # questions requested per chunk
NUM_INCORRECT_ANSWERS = 4 # incorrect answers requested per question
CHUNK_SIZE = 2000
CHUNK_OVERLAP = 100

# init OpenAI client (read as a global by the helper functions)
client = OpenAI(
    api_key = 'ollama', # [ollama, OPENAI_API_KEY] local LLM or using gpt
    base_url = 'http://localhost:11434/v1', # remove this line if using gpt
)

In [25]:
# Load the pdfs and split them into chunks (reference sections removed)
path = "data/pdf/RAG_papers" # path to folder
doc_type = "pdf" # only support pdf for now; not named `type` so the builtin stays usable in later cells
chunks_dict = get_chunk(path=path, type=doc_type, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, references = False)

# generate question, correct answer, incorrect answer
# data maps each pdf source to the list of records produced from its chunks
data = defaultdict(list)
for (source, chunks) in chunks_dict.items():
    for i, chunk in enumerate(chunks):
        data[source].extend(run(i, chunk, num = NUM_QUESTION, num_answer = NUM_INCORRECT_ANSWERS, model = MODEL))

In [26]:
# save data to a json file
out_path = 'output/QA_RAG_AI4S_2000.json'
# Create the output folder first, so a missing folder cannot throw away the generated data.
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
with open(out_path, 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=4)