## Retrieval Augmented Generation
Set up a GitHub personal access token ([instructions](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens)).

Based on the Hugging Face RAG tutorials at:
 - [zephyr + langchain](https://huggingface.co/learn/cookbook/rag_zephyr_langchain)

Additional resources:
 - [rag with milvus](https://huggingface.co/learn/cookbook/rag_with_hf_and_milvus)

In [31]:
from utils.github import get_github_token
# Token passed as `access_token` to the GitHub loaders that fetch the
# repository's issues and files.
GITHUB_TOKEN = get_github_token()

### Set up FAISS with documents

In [69]:
from typing import Callable
from langchain.document_loaders import GithubFileLoader, GitHubIssuesLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

def has_extension(ends: list[str]) -> Callable[[str], bool]:
    """Build a predicate that accepts paths whose file extension is in ``ends``.

    Args:
        ends: Extensions to accept, written without the leading dot
            (e.g. ``["rs", "md"]``).

    Returns:
        A function taking a path string and returning ``True`` when the text
        after the path's final ``.`` is one of ``ends``. A path containing no
        ``.`` has no extension and is always rejected.
    """
    def check(path: str) -> bool:
        # rpartition reports whether a dot was found at all, so an
        # extensionless file literally named "md" or "toml" is not mistaken
        # for a file carrying that extension.
        _, dot, extension = path.rpartition(".")
        return bool(dot) and extension in ends
    return check

# Issues and pull requests in every state, split into small chunks.
issues = GitHubIssuesLoader(
    repo="oliverkillane/emDB",
    access_token=GITHUB_TOKEN,
    include_prs=True,
    state="all",
).load()
chunked_issues = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=30).split_documents(issues)

# Repository files with an rs, md, or toml extension, split into larger chunks.
docs = GithubFileLoader(
    repo="oliverkillane/emDB",
    access_token=GITHUB_TOKEN,
    file_filter=has_extension(["rs", "md", "toml"]),
).load()
chunked_code = RecursiveCharacterTextSplitter(chunk_size=4096, chunk_overlap=30).split_documents(docs)

# Index the issue chunks alongside the file chunks; embedding only
# chunked_code would leave the fetched issues impossible to retrieve.
db = FAISS.from_documents(
    chunked_code + chunked_issues,
    HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"),
)
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

### Set up the LLM

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Model identifier shared by the tokenizer and the model weights.
model_name = "HuggingFaceH4/zephyr-7b-beta"

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
)

## Set up the chains

In [None]:
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from transformers import pipeline
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Text-generation pipeline over the loaded model and tokenizer, wrapped so it
# can be used as a LangChain LLM.
text_generation_pipeline = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    # Sampling settings.
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.1,
    # Output settings. NOTE(review): ask() splits the generated text on the
    # prompt's marker strings, so it presumably relies on return_full_text=True
    # echoing the prompt back — confirm before changing.
    max_new_tokens=400,
    return_full_text=True,
)

llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# Marker strings embedded in the prompt; ask() splits the generated text on
# them to recover the context and the answer.
ASSISTANT_SPLIT = "<|assistant|>"
CONTEXT_SPLIT = "<|context|>"
USER_SPLIT = "<|user|>"
ANSWER_SPLIT = "<|answer|>"

# The markers are interpolated now by the f-string; the doubled braces leave
# {context} and {question} as placeholders for PromptTemplate to fill later.
prompt_template = f"""
<|system|>
Answer the question based on your knowledge. Use the following context to help:
{CONTEXT_SPLIT}
{{context}}
{USER_SPLIT}
{{question}}
{ASSISTANT_SPLIT}
"""

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

# Prompt -> model -> plain string. Usable on its own with an explicit context.
llm_chain = prompt | llm | StrOutputParser()

# Reuse the retriever configured alongside the FAISS index instead of
# rebuilding one here with default arguments, so the search settings live in
# one place and are not silently overridden.
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | llm_chain

def ask(question: str) -> None:
    """Print the model's answer to ``question`` with and without retrieval.

    Invokes ``rag_chain`` on the question, then ``llm_chain`` with an empty
    context, and prints the plain answer, the context section of the RAG
    output, and the retrieval-augmented answer. Each piece is cut out of the
    generated text using the prompt's marker strings.

    Args:
        question: The question passed to both chains.
    """
    augmented_output = rag_chain.invoke(question)
    # Answer: the text between the first and second assistant markers (or to
    # the end of the output when the marker appears only once).
    rag_answer = augmented_output.split(ASSISTANT_SPLIT, 2)[1]
    # Context: the text after the first context marker, up to the user marker.
    after_context_marker = augmented_output.split(CONTEXT_SPLIT, 2)[1]
    rag_context = after_context_marker.split(USER_SPLIT, 1)[0]

    # Baseline: same prompt and model, but with nothing retrieved.
    plain_output = llm_chain.invoke({"context": "", "question": question})
    llm_answer = plain_output.split(ASSISTANT_SPLIT, 2)[1]

    print(f"""
    LLM: {llm_answer}
    RAG CONTEXT: {rag_context}
    LLM + RAG: {rag_answer}
    """)

In [None]:
ask("Who works on emDB?")

In [None]:
ask("What data structures does emdb support for implementing tables?")

In [None]:
ask("Could you give me some basic code to create an emql table with one column (i32) called 'cool', and to then query for all elements in order.")

In [None]:
ask("How can I build emdb, how do I run tests? How about benchmarks?")

In [None]:
ask("What is combi? And what is pulpit?")

In [None]:
ask("What is the window pattern in emDB, why is it necessary?")