In [1]:
import json
import os
import random
import re

from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from openai import OpenAI
# import datasets
# from datasets import Dataset, load_dataset
# import logging

In [2]:
# Chunking: passed to RecursiveCharacterTextSplitter via get_chunk.
CHUNK_SIZE = 512
CHUNK_OVERLAP = 20

# Dataset construction.
NUM_QUESTION = 3        # questions requested from the model per chunk
NUM_DISTRACT_DOCS = 2   # distractor chunks added to each record's context
P = 0.8                 # probability that the oracle chunk is kept in the context
MODEL = "llama3"        # chat model name sent to the OpenAI-compatible endpoint

# Seed the stdlib RNG so distractor sampling and oracle dropping are repeatable
# (given identical model outputs) on Restart & Run All.
SEED = 42
random.seed(SEED)

# Helper functions

In [3]:
def get_chunk(path = "data", type = "md", chunk_size = CHUNK_SIZE, chunk_overlap = CHUNK_OVERLAP):
    """
    Load every document of one kind under `path` and split it into text chunks.

    Args:
        path: directory to read documents from.
        type: "md" loads every ``**/*.md`` file with UnstructuredMarkdownLoader;
            "pdf" loads the directory with PyPDFDirectoryLoader.
        chunk_size: `chunk_size` handed to RecursiveCharacterTextSplitter.
        chunk_overlap: `chunk_overlap` handed to RecursiveCharacterTextSplitter.

    Returns:
        A list with the ``page_content`` string of every split document.

    Raises:
        TypeError: if `type` is neither "md" nor "pdf".
    """
    if type == "pdf":
        loader = PyPDFDirectoryLoader(path)
    elif type == "md":
        # Recurse into sub-directories and pick up markdown files only.
        loader = DirectoryLoader(path, glob="**/*.md", loader_cls=UnstructuredMarkdownLoader)
    else:
        raise TypeError("Only accept pdf and md")

    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return [piece.page_content for piece in loader.load_and_split(splitter)]


def generate_questions(chunk, num = NUM_QUESTION, model = MODEL):
    """
    Ask the chat model for `num` questions that can be answered from `chunk`.

    Used when the input document is of a general type (markdown / pdf text).
    Relies on the module-level OpenAI-compatible `client`.

    Args:
        chunk: the context passage; sent to the model as ``str(chunk)``.
        num: how many questions to request, and the maximum number returned.
        model: model name passed to ``client.chat.completions.create``.

    Returns:
        A list of at most `num` non-empty question strings, with any leading
        list marker such as "1. " or "2) " removed.
    """
    messages = [
        {"role": "system", "content": "You are a synthetic question-answer pair generator. Given a chunk of context about some topic(s), generate %s example questions a user could ask and would be answered using information from the chunk. For example, if the given context was a manual instruction paragraph about the Polaris, an example question could be 'How to login to Polaris?'" % (num)},
        {"role": "system", "content": "The questions should be able to be answered in a few words or less. Include only the questions in your response."},
        {"role": "user", "content": str(chunk)},
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
    )
    # The client types `content` as optional; treat a missing value as no text.
    lines = (response.choices[0].message.content or "").split('\n')

    # Models tend to answer "Here are N questions:\n\n1. ...". Start at the
    # first "1. " item; if no numbered list is found, keep the original
    # heuristic of skipping the first two lines (preamble + blank line).
    start = next(
        (idx for idx, line in enumerate(lines) if line.lstrip().startswith("1. ")),
        2,
    )

    questions = []
    for line in lines[start:]:
        # Drop list markers so only the question text ends up in the dataset.
        text = re.sub(r"^\d+[.)]\s*", "", line.strip())
        if text:  # blank separator lines are not questions
            questions.append(text)
    return questions[:num]

def encode_question(question, chunk):
    """
    Build the chat messages asking the model to answer `question` from `chunk`.

    Returns:
        A two-element list of message dicts: a system message setting the
        answerer role, then a user message containing the question, the
        context (``str(chunk)``) and the answer-format instructions
        (quoted reasoning plus a final ``<ANSWER>:`` tag).
    """
    user_prompt = """
        Question: {question}\nContext: {context}\n
        Answer this question using the information given in the context above. Here is things to pay attention to: 
        - First provide step-by-step reasoning on how to answer the question. 
        - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context. 
        - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
        You MUST begin your final answer with the tag "<ANSWER>:".
    """.format(question=question, context=str(chunk))
    system_message = {"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]

def generate_label(question, chunk, model = MODEL):
    """
    Ask the chat model for a chain-of-thought answer to `question` grounded in `chunk`.

    Relies on the module-level OpenAI-compatible `client`.

    Args:
        question: the question text.
        chunk: the oracle context, formatted into the prompt by `encode_question`.
        model: chat model name; defaults to the notebook-wide `MODEL` so that
            changing the config cell changes every LLM call consistently.

    Returns:
        The raw assistant message text. The prompt asks for a final
        "<ANSWER>: ..." line, but that format is not enforced here.
    """
    messages = encode_question(question, chunk)
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        n=1,
        temperature=0,  # keep labels as repeatable as the backend allows
    )
    return completion.choices[0].message.content

def _sample_context_docs(chunks, oracle, candidates, num_distract, p):
    """Return a shuffled list of distinct context docs: the oracle (kept with probability `p`) plus distractors."""
    # With nothing to replace it by, the oracle is always kept.
    include_oracle = random.uniform(0, 1) < p or not candidates
    # When the oracle is dropped, one extra *distinct* distractor takes its
    # place, so the context never contains the same distractor twice.
    wanted = num_distract if include_oracle else num_distract + 1
    # Clamp so small corpora (fewer chunks than requested) don't raise ValueError.
    picked = random.sample(candidates, min(wanted, len(candidates)))
    docs = [chunks[k] for k in picked]
    if include_oracle:
        docs.append(oracle)
    random.shuffle(docs)
    return docs


def add_chunk_to_dataset(i, chunks, chunk, num = NUM_QUESTION, num_distract = NUM_DISTRACT_DOCS, p = P, model = MODEL):
    """
    Build {question, context documents, answer} records for one chunk.

    For every question generated from `chunk`, the context is `chunk` (the
    oracle document) plus up to `num_distract` distinct distractor chunks from
    the rest of `chunks`. With probability 1 - `p` the oracle is replaced by one
    more distractor. The chain-of-thought answer is always generated from the
    oracle chunk.

    Args:
        i: index of `chunk` in `chunks`; that index is excluded from distractors.
        chunks: the full list of chunks to draw distractors from.
        chunk: the oracle chunk (presumably ``chunks[i]``; not checked here).
        num: number of questions to request for this chunk.
        num_distract: number of distractor documents per record.
        p: probability of keeping the oracle document in the context.
        model: chat model used for question and answer generation.

    Returns:
        A list of dicts with keys "id", "question", "context", "oracle_context",
        "cot_answer" and "instruction".
    """
    records = []
    # Distractor candidates do not depend on the question; compute them once.
    candidates = [k for k in range(len(chunks)) if k != i]
    for j, q in enumerate(generate_questions(chunk, num, model)):
        if not q.strip():
            # Blank lines from the model are not questions; don't spend an LLM call on them.
            continue
        docs = _sample_context_docs(chunks, chunk, candidates, num_distract, p)
        records.append({
            "id": f"seed_task_{i}_{j}",
            "question": q,
            "context": docs,
            "oracle_context": chunk,
            "cot_answer": generate_label(q, chunk, model=model),
            # Model instruction: every context doc wrapped in <DOCUMENT> tags, then the question.
            "instruction": "".join("<DOCUMENT>" + str(doc) + "</DOCUMENT>\n" for doc in docs) + q,
        })
    return records
    
def save_checkpoint(state, filename):
    """
    Overwrite `filename` with the text form (``str(state)``) of `state`.

    `load_checkpoint` parses the file back with ``int()``, so `state` is
    presumably an integer progress marker.
    """
    with open(filename, 'w') as out_file:
        payload = str(state)
        out_file.write(payload)

def load_checkpoint(filename):
    """
    Read `filename` and return its whole contents parsed as an ``int``.

    Raises whatever ``open`` raises for a missing file, and ``ValueError`` if
    the contents are not an integer literal.
    """
    with open(filename, 'r') as in_file:
        contents = in_file.read()
    return int(contents)

In [4]:
client = OpenAI(
    base_url = 'http://localhost:11434/v1',
    api_key='ollama', # required, but unused
)

# Only the first 11 chunks are processed (the original loop broke after
# i == 10); set to None to process every chunk.
MAX_MD_CHUNKS = 11

data = []
chunks = get_chunk("data/polaris", "md", CHUNK_SIZE, CHUNK_OVERLAP)

for i, chunk in enumerate(chunks[:MAX_MD_CHUNKS]):
    # Pass the full `chunks` list so distractors are drawn from the whole corpus.
    data.extend(add_chunk_to_dataset(i, chunks, chunk))

# open() does not create missing parent directories.
os.makedirs('output', exist_ok=True)
with open('output/data_md.json', 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=4)

100%|██████████| 49/49 [00:03<00:00, 13.27it/s]


In [5]:
# Same pipeline for the PDF corpus; uses the `client` created in the previous cell.
# Only the first 11 chunks are processed (the original loop broke after
# i == 10); set to None to process every chunk.
MAX_PDF_CHUNKS = 11

data = []
chunks = get_chunk("data/pdf", "pdf", CHUNK_SIZE, CHUNK_OVERLAP)

for i, chunk in enumerate(chunks[:MAX_PDF_CHUNKS]):
    # Pass the full `chunks` list so distractors are drawn from the whole corpus.
    data.extend(add_chunk_to_dataset(i, chunks, chunk))

# open() does not create missing parent directories.
os.makedirs('output', exist_ok=True)
with open('output/data_pdf.json', 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=4)