### Load chunks

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.docstore.document import Document
import os


# Gather every regular file that sits directly in the sources directory.
# Sorted so the chunk order is reproducible: os.listdir() returns entries in
# arbitrary order.
sources_dir = "../sources"
files = sorted(
    path
    for path in (os.path.join(sources_dir, fname) for fname in os.listdir(sources_dir))
    if os.path.isfile(path)
)

chunk_size = 2_000
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    # Floor division keeps the overlap an int ("/" would pass 200.0).
    chunk_overlap=chunk_size // 10,
    add_start_index=True,
)

# Merge the pages of each file into one "giant" document so that chunks are
# cut from the whole text and may span page boundaries.
giant_docs = []
for file in files:
    loader = PyMuPDFLoader(file, extract_images=False)
    metadata = {}
    page_texts = []
    # Plain synchronous iteration: pages are consumed one at a time anyway,
    # and a top-level `async for` only runs inside IPython/Jupyter.
    for page_number, page in enumerate(loader.lazy_load()):
        if page_number == 0:
            # File-level metadata comes from the first page; the per-page
            # "page" key is meaningless for the merged document.
            metadata = {k: v for k, v in page.metadata.items() if k != "page"}
        page_texts.append(page.page_content)
    giant_docs.append({"page_content": "".join(page_texts), "metadata": metadata})

# create_documents() splits the text and wraps each chunk in a Document with
# its own copy of the metadata. Unlike split_text(), it honours
# add_start_index and records each chunk's offset under "start_index".
docs = []
for gdoc in giant_docs:
    docs += text_splitter.create_documents(
        [gdoc["page_content"]],
        metadatas=[gdoc["metadata"]],
    )

### Context cleaning for QA generation

In [None]:
import re


def decapitalize_content(pages: "list[Document]") -> None:
    """Lower-case the text of every document, in place.

    Args:
        pages: Documents exposing a writable ``page_content`` string. Each
            one is modified in place; nothing is returned.
    """
    for page in pages:
        page.page_content = page.page_content.lower()


def remove_non_ASCII(pages: "list[Document]") -> None:
    """Strip non-ASCII characters from documents, in place.

    Not suitable for many non-English languages, which rely on non-ASCII
    characters. A document is therefore left untouched when its ``keywords``
    metadata contains the marker ``"non-en"``.

    Args:
        pages: Documents exposing ``page_content`` (str) and ``metadata``
            (dict). Each one is modified in place; nothing is returned.
    """
    for page in pages:
        # A missing or None "keywords" entry means the document carries no
        # marker, so it is cleaned like any other instead of raising.
        keywords = page.metadata.get("keywords") or ""
        if "non-en" not in keywords:
            page.page_content = re.sub(r"[^\x00-\x7F]+", "", page.page_content)


def remove_bullets(pages: "list[Document]") -> None:
    """Remove list markers (bullet symbols and ``1.``-style numbers), in place.

    Args:
        pages: Documents exposing a writable ``page_content`` string. Each
            one is modified in place; nothing is returned.
    """
    for page in pages:
        # Symbol bullets at the very start of a line.
        text = re.sub(r"^[→•▪\-*✔➢●✗]\s*", "", page.page_content, flags=re.MULTILINE)
        # Numbered items such as "1. foo". Anchored to the start of a line
        # (optionally indented) so a number ending a sentence, as in
        # "in 1998. it grew", is kept. The lookahead skips decimals like "1.5".
        text = re.sub(r"^([ \t]*)\d+\.(?=\s*[a-zA-Z])", r"\1", text, flags=re.MULTILINE)
        page.page_content = text


def remove_escape(pages: "list[Document]") -> None:
    """Collapse every run of whitespace into a single space, in place.

    Newlines, tabs and repeated spaces all count as whitespace; leading and
    trailing whitespace is dropped entirely.

    Args:
        pages: Documents exposing a writable ``page_content`` string. Each
            one is modified in place; nothing is returned.
    """
    for page in pages:
        page.page_content = " ".join(page.page_content.split())


# Apply the cleaning passes in this fixed order; each one mutates `docs` in place.
for _cleaning_step in (
    remove_non_ASCII,
    decapitalize_content,
    remove_bullets,
    remove_escape,
):
    _cleaning_step(docs)

### Set up agent for question generation

In [None]:
from huggingface_hub import InferenceClient
import json


# Model used to write the question/answer pairs.
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# No client-side timeout is configured (timeout=None).
llm_client = InferenceClient(model=repo_id, timeout=None)


def call_llm(inference_client: InferenceClient, prompt: str):
    """Send `prompt` as a text-generation request and return the generated text.

    The raw response bytes are decoded and parsed as JSON; the
    "generated_text" field of the first element is returned.
    """
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 1000},
        "task": "text-generation",
    }
    raw_response = inference_client.post(json=payload)
    first_candidate = json.loads(raw_response.decode())[0]
    return first_candidate["generated_text"]

# Prompt template for QA generation. `{context}` is filled in with str.format;
# the prompt asks the model to reply after the trailing "Output:::" marker
# using "Factoid question: ..." and "Answer: ..." lines.
QA_generation_prompt = """
Your task is to write a factoid question and an answer given a context.
Your factoid question should be answerable with a specific, concise piece of factual information from the context.
Your factoid question should be formulated in the same style as questions users could ask in a search engine.
This means that your factoid question MUST NOT mention something like "according to the passage" or "context".

Provide your answer as follows:

Output:::
Factoid question: (your factoid question)
Answer: (your answer to the factoid question)

Now here is the context.

Context: {context}\n
Output:::"""

### Generate questions

In [None]:
import random
import json
from tqdm.auto import tqdm


N_GENERATIONS = 10
print(f"Generating {N_GENERATIONS} QA couples...")

# random.sample raises ValueError when asked for more items than exist, so
# cap the sample size at the number of available chunks.
n_samples = min(N_GENERATIONS, len(docs))

outputs = []
for sampled_context in tqdm(random.sample(docs, n_samples)):
    # Generate QA couple
    output_QA_couple = call_llm(llm_client, QA_generation_prompt.format(context=sampled_context.page_content))

    # Keep the text after the LAST occurrence of each marker, so a reply that
    # echoes the prompt (which contains the same markers) is still parsed.
    question = output_QA_couple.split("Factoid question: ")[-1].split("Answer: ")[0].strip()
    answer = output_QA_couple.split("Answer: ")[-1].strip()

    # Skip over-long answers: they are not concise factoid answers.
    if len(answer) >= 300:
        continue

    # Skip chunks whose origin is unknown rather than aborting the whole run.
    try:
        source_doc = sampled_context.metadata["source"]
    except KeyError:
        continue

    outputs.append(
        {
            "context": sampled_context.page_content,
            "question": question,
            "answer": answer,
            "source_doc": source_doc,
        }
    )

# Create the output directory first so the generated pairs are not lost to a
# FileNotFoundError at the very end.
os.makedirs("dataset", exist_ok=True)
with open("dataset/all_QA.json", "w", encoding="utf-8") as f:
    json.dump(outputs, f)