# Q&A pipeline with Retrieval Augmented Generation (RAG) 

### Bence Csabai - KPMG Machine Learning Engineer interview

This notebook is a demo of a small project completed as the take-home assignment for the KPMG Machine Learning Engineer position. The task was to create a chatbot that answers mock "business inquiries", with the following instructions:
- The solution should be based on a RAG (Retrieval Augmented Generation) model, which consists of
    - a Retriever layer
    - an LLM Layer
- The suggested [dataset](https://huggingface.co/datasets/wikipedia/viewer/20220301.en) for the Retriever is the English Wikipedia dataset from HuggingFace. Instead, I used its smaller sibling, the simplified Wikipedia [dataset](https://huggingface.co/datasets/wikipedia/viewer/20220301.simple)
- For this example, a "business inquiry" is any question a person might ask that can be answered from, e.g., the Wikipedia dataset

### RAG

LLMs have proven to be powerful and versatile tools for a broad range of applications. However, as their name suggests, they are very large, so training one from scratch for a new, specific task requires a tremendous amount of computation, time and resources. One way to overcome this is to start from a pretrained model and tailor it to our needs. One method for this is Retrieval Augmented Generation (RAG), which supplies the pretrained model with relevant external context at query time instead of retraining it. A RAG pipeline consists of two main blocks: the retriever and the LLM block.

![RAG Structure](F:/Users/bence/Documents/_PROJECTS/KPMG_LLM/large-language-models/rag/rag_s.png)

**Retriever block**

The aim of the retriever block is to create a richer prompt for the LLM block by adding context to the original query. To do this, it uses its own knowledge base, in this case the simplified Wikipedia dataset from HuggingFace. The text documents are converted into vector representations by an embedding network (here a pretrained transformer model that maps text into a 384-dimensional vector space). The original query is passed through the same embedding network, which ensures that the query and the knowledge base are represented in the same vector space.
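The key point, that documents and query go through one and the same embedding function, can be illustrated with a minimal plain-Python sketch. It assumes a toy hashed bag-of-words encoder (`embed`, hypothetical) as a stand-in for the pretrained transformer: it captures no real semantics, and only mimics the interface of text in, fixed-size normalized 384-dimensional vector out.

```python
import hashlib
import math
import re

DIM = 384  # same dimensionality as the notebook's embedding model


def embed(text: str, dim: int = DIM) -> list[float]:
    """Toy stand-in for a pretrained sentence encoder (hashed bag of words).

    Each lowercase word is hashed into one of `dim` buckets; the count vector
    is then L2-normalized so that texts of different lengths are comparable.
    """
    vec = [0.0] * dim
    for word in re.findall(r"[a-z0-9]+", text.lower()):
        bucket = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16) % dim
        vec[bucket] += 1.0
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm > 0 else vec


# Documents and query are embedded with the SAME function, so they share one space.
doc_vectors = [embed(d) for d in ["April Fools' Day is on 1 April.",
                                  "Paris is the capital of France."]]
query_vector = embed("When is April Fools day?")
print(len(query_vector))  # 384
```

Normalizing to unit length is a common choice: for unit vectors the squared L2 distance equals 2 minus twice the dot product, so nearest neighbors by L2 distance are also the most similar by cosine similarity.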

Afterwards, the *k* vectors closest to the query vector are found and selected as context. This is done with Facebook AI Similarity Search (FAISS), a library optimized for fast nearest-neighbor search. The corresponding texts from the document dataset are then inserted into a prompt template (e.g. "Consider the following context and then answer the question: *context*, *question*").
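The nearest-neighbor step can be sketched without FAISS as an exact brute-force search over squared L2 distances, which is what an exact FAISS index such as `faiss.IndexFlatL2` computes; FAISS adds optimized kernels and approximate indexes for large collections. The helper `knn`, the toy 2-dimensional vectors and the texts are hypothetical (the real embeddings have 384 dimensions), and which FAISS index type the notebook's `VectorDatabase` uses is not shown here.

```python
import heapq


def squared_l2(a: list[float], b: list[float]) -> float:
    """Squared Euclidean distance between two vectors of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn(query: list[float], vectors: list[list[float]], k: int) -> list[tuple[int, float]]:
    """Exact k-nearest-neighbor search: (index, squared distance) pairs, closest first."""
    scored = ((i, squared_l2(query, v)) for i, v in enumerate(vectors))
    return heapq.nsmallest(k, scored, key=lambda pair: pair[1])


# Toy 2-D "embeddings" of four documents (hypothetical values).
texts = ["doc A", "doc B", "doc C", "doc D"]
vectors = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
query = [0.9, 0.1]

hits = knn(query, vectors, k=2)
# The texts of the k closest documents become the context for the prompt.
context = "\n".join(texts[i] for i, _ in hits)
print([i for i, _ in hits])  # [1, 0]
```

With FAISS itself, the same exact search looks roughly like `index = faiss.IndexFlatL2(d)`, `index.add(xb)` and `distances, ids = index.search(xq, k)`, where `xb` and `xq` are float32 NumPy arrays of shape (n, d).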

**LLM block**

As its name suggests, the LLM block contains a pretrained LLM, which answers the enriched prompt. In this example, I used a pretrained Llama 2 model (the open-source LLM developed by Meta). With the added context, the RAG pipeline is expected to produce better-quality answers than the base LLM alone. This is tested at the end of the notebook.

### How to use? 

- Enter a query in the appropriate cell (the QUERY variable)
- Run all cells
- **NOTE**: the RAG model can take a few minutes to produce a result, as it receives a long prompt and the model runs locally

In [1]:
from langchain.docstore.document import Document
from langchain.document_loaders import HuggingFaceDatasetLoader

from encoder.encoder import Encoder
from generator.generator import Generator
from retriever.vector_db import VectorDatabase

The following cell contains the template for the enhanced prompt: {context} is replaced with the text retrieved from the Wikipedia dataset, and {question} with the original query.
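The substitution itself is plain string formatting. Below is a minimal sketch, assuming `str.format` is used; the `Generator` class may fill the template differently (e.g. through a LangChain prompt template), and `build_prompt` and the context snippet are hypothetical. The template text is copied from the next cell.

```python
TEMPLATE = """
Use the following pieces of context to answer the question at the end.
{context}
Question: {question}
Answer:
"""


def build_prompt(context: str, question: str, template: str = TEMPLATE) -> str:
    """Fill the {context} and {question} placeholders of the template."""
    return template.format(context=context, question=question)


question = "When is April Fools day?"
# RAG prompt: retrieved text fills the {context} slot (hypothetical snippet).
rag_prompt = build_prompt("April Fools' Day is celebrated on 1 April.", question)
# Base-LLM prompt: same template with an empty context, analogous to the
# notebook's later call generator.get_answer('', QUERY).
base_prompt = build_prompt("", question)
print(rag_prompt)
```

Braces inside the substituted context are safe here, because `str.format` only parses the template string, not the values inserted into it.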

In [2]:
TEMPLATE = """
Use the following pieces of context to answer the question at the end. 
{context}
Question: {question}
Answer:
"""

In [3]:
# load the simplified Wikipedia dataset
loader = HuggingFaceDatasetLoader("wikipedia", name="20220301.simple")
# keep only the first 100 articles (limited computing power and disk space)
docs = loader.load()[:100]


In [5]:
# report whether a CUDA GPU is available (note: `device` is not passed to the classes below)
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# instantiate the Encoder, Retriever (vector database) and Generator classes
encoder = Encoder()
faiss_db = VectorDatabase()
generator = Generator(TEMPLATE)


### Enter the query

In [8]:
QUERY = "When is April Fools day?"

In [9]:
# Alternatively, uncomment to enter the query interactively:
#QUERY_USER = input('What would you like to know?\n')
#QUERY = QUERY_USER

### Retrieve context
Embed the documents and the query into the same vector space, then find the query's nearest neighbors using the FAISS library.

In [11]:
# Split the documents into passages of equal length.
# This is meant to split long documents, but I do not think I got it to work correctly,
# so whole articles get represented in the vector space.

passages = faiss_db.create_passages(docs)
faiss_db.store_passages_db(passages, encoder.encoder)


# retrieve the k most similar documents to our query

context = faiss_db.retrieve_most_similar_document(QUERY, k=3)
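Since the passage splitting in `create_passages` does not seem to work as intended (see the comments in the cell above), here is a minimal sketch of fixed-length splitting with overlap. `split_into_passages` and its parameters (`passage_len` and `overlap`, both counted in words) are hypothetical and not part of the repo's `VectorDatabase`; each passage would then be embedded and stored in place of the whole article. LangChain's own text splitters (e.g. `RecursiveCharacterTextSplitter`), which work on `Document` objects, could be an alternative.

```python
def split_into_passages(text: str, passage_len: int = 100, overlap: int = 20) -> list[str]:
    """Split text into passages of `passage_len` words; consecutive passages share `overlap` words."""
    if not 0 <= overlap < passage_len:
        raise ValueError("overlap must satisfy 0 <= overlap < passage_len")
    words = text.split()
    step = passage_len - overlap
    passages = []
    for start in range(0, len(words), step):
        passages.append(" ".join(words[start:start + passage_len]))
        # Stop once a passage reaches the end of the text, so the tail is not repeated.
        if start + passage_len >= len(words):
            break
    return passages


article = " ".join(f"w{i}" for i in range(250))  # toy 250-word "article"
passages = split_into_passages(article, passage_len=100, overlap=20)
print(len(passages))  # 3
```

The overlap means that text near a passage boundary also appears at the start of the next passage, so it is less likely to be cut off from its surrounding context.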

### Answers - comparison

Get answers from both the RAG model and the base Llama model, and compare them.

In [12]:
# RAG response
print(generator.get_answer(context[:4096], QUERY))

April Fools day, also known as All Fools Day or Poisson d’Avril in French, occurs every year on the first of April. It is a lighthearted day when people play practical jokes and hoaxes on each other, often making fun of someone else's gullibility or naivety. The origins of this holiday are uncertain, but one theory suggests that it stems from the custom in France during the 18th century to celebrate the New Year with practical jokes and pranks on April 1st. Another theory is that it comes from the adoption of Christianity, where the first day of April was considered a sacred day and thus unlucky for starting new projects or initiatives. Regardless of its origins, many people around the world celebrate April Fools Day by playing pranks on each other, such as sending fake news stories to their friends and family or creating elaborate hoaxes that they hope will be revealed as such at the end of the day. Some people even plan elaborate prank events like an April Fools' Day festival or para

In [13]:
# Base LLM: same prompt template with an empty context (no retrieval)
print(generator.get_answer('', QUERY))


1) April Fools Day is a holiday in many countries, including the United States and Canada, where people play practical jokes on one another and perform pranks as a way of celebrating the beginning of spring.
2) It takes place annually on April 1st.


### Future possibilities

While the model produces acceptable results, and we can observe instances where the RAG model yields better answers than the base LLM, there are still a number of changes that could further improve the quality. Limitations of this solution included a lack of computing power and disk space (only the first 100 articles of the simplified dataset are used) and limited time for tweaking and testing. The following is a non-exhaustive list of ideas that could be explored given more time and resources:

- **Model experimenting**
    - Try different versions of different open-source LLMs (e.g. different Llama versions, Falcon, Yi, etc.) and compare speed and quality
    - Try different encoding (embedding) models
    - Try different ranking methods beyond plain nearest-neighbor search (possibly based on a smaller LLM?)
- **Hyperparameter tuning**
    - Tune non-learnable parameters, such as
        - The length of the passages: encode documents as a few sentences instead of whole articles? This would probably increase the time needed to build the representations, but would likely speed up answer generation and thereby improve the user experience. Does it affect quality?
        - How does the number of nearest neighbors affect speed and quality?
- **Scale and Speed**
    - Find ways to speed up answer generation
        - Provide shorter contexts for shorter overall prompts?
        - Create multiple smaller retrievers with different context groups?
        - Other approaches
    - Run the model on more powerful hardware, which would allow loading more data and generating responses faster

### Source

Most of the project was based on the following GitHub repository:
https://github.com/zaai-ai/large-language-models/blob/main/rag/

The base skeleton and structure of the code remain. I spent most of my time diving deeper into the theory behind the model rather than optimizing the code. However, many changes were needed to adapt the solution to the task at hand.