<a href="https://colab.research.google.com/github/AxesAccess/Implementing-a-Local-Retrieval-Augmented-Generation-System/blob/main/Implementing_a_Local_Retrieval_Augmented_Generation_System.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing a Local Retrieval-Augmented Generation System

This is the source code for the article [Implementing a Local Retrieval-Augmented Generation System](https://axesaccess.github.io/Blog/posts/20250321-rag/index.html).

In [1]:
!pip install bitsandbytes faiss-cpu faiss-gpu-cu12 langchain langchain_community langchain_huggingface sentence-transformers --quiet

In [None]:
!pip install --upgrade numpy
!pip install --upgrade transformers

In [None]:
import os
import requests
import pickle
from bs4 import BeautifulSoup, SoupStrainer
from transformers import pipeline
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

We will parse the Wikipedia category page to extract links to articles.

In [4]:
from urllib.parse import urlparse


def fetch_links(url):
    links = []
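    # Fail fast on network problems or HTTP errors instead of parsing an error page
    response = requests.get(url, timeout=30)
    response.raise_for_status()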
    soup = BeautifulSoup(response.text, "html.parser")
    domain = urlparse(url).netloc

    for ul in soup.find_all("ul"):
        for li in ul.find_all("li"):
            link = li.find("a")
            if link and "href" in link.attrs:
                href = link.attrs["href"]
                if href.startswith("/wiki"):
                    links.append(f"https://{domain}{href}")

    return links

Set the `url` variable and get links.

In [5]:
url = "https://en.wikipedia.org/wiki/Category:Machine_learning_algorithms"
links = fetch_links(url)
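
The list returned by `fetch_links` also contains pages that are not articles, which is why the loading cell further down skips the leading entries by index. As an alternative, here is a minimal sketch that filters links by their title instead. `is_article_link` is a hypothetical helper, and the rule it applies (a title without a namespace prefix such as `Category:` or `Special:`) is only a heuristic: some real article titles contain a colon and would be dropped too.

```python
from urllib.parse import unquote, urlparse


def is_article_link(url):
    """Heuristic: keep /wiki/ pages whose title has no namespace prefix."""
    path = unquote(urlparse(url).path)
    if not path.startswith("/wiki/"):
        return False
    title = path[len("/wiki/"):]
    # Titles such as "Category:..." or "Special:..." belong to other namespaces.
    return bool(title) and ":" not in title and title != "Main_Page"


# Made-up sample that mimics the kind of links found on a category page
sample_links = [
    "https://en.wikipedia.org/wiki/Main_Page",
    "https://en.wikipedia.org/wiki/Category:Machine_learning",
    "https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm",
    "https://en.wikipedia.org/wiki/Special:Random",
]
article_links = [link for link in sample_links if is_article_link(link)]
print(article_links)
```

Filtering by title does not depend on how many navigation links the page happens to contain, so it should keep working if the page layout changes.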

Next, we will download the articles as documents.

The Colab session sometimes restarts, so we'll save the results to a pickle file.

In [6]:
# Restore previously saved docs
try:
    with open("docs.pickle", "rb") as f:
        docs = pickle.load(f)
except FileNotFoundError:
    os.environ["USER_AGENT"] = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.0"
    )
    loader = WebBaseLoader(
        # the first 20 links are not relevant
        links[20:],
        bs_kwargs={
            "parse_only": SoupStrainer("div", {"class": "mw-body-content"}),
        },
        bs_get_text_kwargs={"separator": " ", "strip": True},
    )
    docs = loader.load()
    with open("docs.pickle", "wb") as f:
        pickle.dump(docs, f)

Here we break the documents into shorter, overlapping chunks that will be provided to the LLM as context.

In [7]:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
split_docs = text_splitter.split_documents(docs)
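
To see what `chunk_size` and `chunk_overlap` mean, here is a simplified sketch that cuts a string into fixed-size character windows. It is not how `RecursiveCharacterTextSplitter` works internally (that class first tries to split on separators such as paragraph and line breaks); `sliding_chunks` is a hypothetical helper used only to illustrate the overlap.

```python
def sliding_chunks(text, chunk_size, chunk_overlap):
    """Split text into fixed-size character windows that overlap."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once the current window has reached the end of the text
        if start + chunk_size >= len(text):
            break
    return chunks


chunks = sliding_chunks("abcdefghijklmnopqrstuvwxyz", chunk_size=10, chunk_overlap=4)
print(chunks)
# ['abcdefghij', 'ghijklmnop', 'mnopqrstuv', 'stuvwxyz']
```

Each chunk repeats the last four characters of the previous one, so a sentence that falls on a chunk boundary still appears whole in at least one chunk as long as it is shorter than the overlap.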

We need to search for relevant information quickly, so let's transform the texts into embeddings and load them into a vector database.

In [8]:
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model_kwargs = {"device": "cuda"}
encode_kwargs = {"normalize_embeddings": False}
embedding = HuggingFaceEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)

In [9]:
vector_store = FAISS.from_documents(split_docs, embedding=embedding)
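
Conceptually, a similarity search compares the query embedding with every stored embedding and returns the closest ones. The sketch below does this by brute force on toy two-dimensional vectors, assuming Euclidean (L2) distance, which to our knowledge is the default for this vector store. The helper names and the vectors are made up for illustration; FAISS itself performs this kind of search far more efficiently.

```python
import math


def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest(query_vec, vectors, k=2):
    """Return indices of the k vectors closest to query_vec, nearest first."""
    order = sorted(range(len(vectors)), key=lambda i: l2_distance(query_vec, vectors[i]))
    return order[:k]


# Toy "embeddings": the first two point in almost the same direction
toy_vectors = [
    [1.0, 0.0],
    [0.9, 0.1],
    [0.0, 1.0],
]
nearest_ids = nearest([1.0, 0.05], toy_vectors, k=2)
print(nearest_ids)
# [0, 1]
```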

Here we define a function for retrieving relevant documents from the database.

In [10]:
def retrieve(query, top_k=2):
    # Request exactly top_k chunks; slicing the default result would cap top_k at 4
    return vector_store.similarity_search(query, k=top_k)

We will use a local LLM. Let's authenticate with Hugging Face, which is required for downloading certain models.

In [11]:
from huggingface_hub import login
from google.colab import userdata

login(token=userdata.get("HF_TOKEN"))

Download the model and define the tokenizer and generation config.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

MODEL_NAME = "Qwen/Qwen2.5-7B"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_8bit=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)

Let's put the context retrieval and generation pipeline into a function.

In [13]:
gen_pipeline = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, return_full_text=False
)


def generate_response(query):
    relevant_texts = retrieve(query)
    context = " ".join(t.page_content for t in relevant_texts)
    prompt = f"""Answer question using only information provided in the context.
    If the context contains no relevant information, say "I couldn't find the information".
    Context: '''{context}'''
    Question: {query}
    Answer:
    """
    response = gen_pipeline(prompt)
    return response[0]["generated_text"]

Device set to use cuda
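
Note that the triple-quoted f-string in `generate_response` keeps the leading spaces of every indented line, so they become part of the prompt. Here is a minimal sketch of a prompt builder that produces the same wording without the indentation and without the trailing whitespace after `Answer:`. `build_prompt` is a hypothetical helper, and we have not tested whether the cleaner prompt changes the model's answers.

```python
from textwrap import dedent


def build_prompt(context, query):
    """Hypothetical helper: same wording as the prompt above, without indentation."""
    template = dedent(
        """\
        Answer question using only information provided in the context.
        If the context contains no relevant information, say "I couldn't find the information".
        Context: '''{context}'''
        Question: {query}
        Answer:"""
    )
    return template.format(context=context, query=query)


prompt = build_prompt("FAISS is a library for similarity search.", "What is FAISS?")
print(prompt)
```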


In [14]:
import warnings

warnings.filterwarnings("ignore")

Our Q&A session starts here.

In [15]:
query = "What is the Actor-critic algorithm in reinforcement learning?"

answer = generate_response(query)
print(answer)

 The Actor-critic algorithm (AC) is a family of reinforcement learning (RL) algorithms that combine policy-based RL algorithms such as policy gradient methods, and value-based RL algorithms such as value iteration, Q-learning, SARSA, and TD learning. An AC algorithm consists of two main components: an "actor" that determines which actions to take according to a policy function, and a "critic" that evaluates those actions according to a value function. Some AC algorithms are on-policy, some are off-policy. Some apply to either continuous or discrete action spaces. Some work in both cases.


In [16]:
query = "What is the purpose of backpropagation in neural networks?"

answer = generate_response(query)
print(answer)

 The purpose of backpropagation in neural networks is to adjust the weights of the connections between neurons in order to minimize the error between the predicted output and the actual output. This is done by propagating the error backwards through the network, starting from the output layer and moving towards the input layer, hence the name "backpropagation". The goal is to iteratively update the weights so that the network can learn from its mistakes and improve its performance over time.


In [17]:
query = "Explain the concept of Curriculum learning in machine learning."

answer = generate_response(query)
print(answer)

 Curriculum learning in machine learning is a technique that involves gradually introducing more complex concepts or data to a model as it learns. This approach is inspired by the way humans learn, starting with simple concepts and building upon them. In the context provided, it is mentioned that Bengio et al. showed good results for problems in image classification and language modeling using curriculum learning strategies. They concluded that this technique can be beneficial for the model's performance on the test.


In [18]:
query = "How does K-nearest neighbors (K-NN) algorithm classify data?"

answer = generate_response(query)
print(answer)

 The K-nearest neighbors (K-NN) algorithm classifies data by a plurality vote of its neighbors, with the object being assigned to the class most common among its K nearest neighbors (K is a positive integer, typically small). If K = 1, then the object is simply assigned to the class of that single nearest neighbor.


In [19]:
query = "What is Federated Learning of Cohorts and how does it improve data privacy?"

answer = generate_response(query)
print(answer)

 Federated Learning of Cohorts (FLoC) is a type of web tracking that groups people into "cohorts" based on their browsing history for the purpose of interest-based advertising. It was being developed as a part of Google's Privacy Sandbox initiative, which includes several other advertising-related technologies with bird-themed names. FLoC was being tested in Chrome 89 as a replacement for third-party cookies. Despite "federated learning" in the name, FLoC does not utilize any federated learning. FLoC improves data privacy by grouping people into cohorts based on their browsing history, rather than tracking individual users. This means that advertisers can still target users based on their interests, but without the need for individual user data.


Looks good. Let's ask something that is not in the context. For instance, there were no articles on the Transformer architecture among the downloaded Wikipedia articles.

In [20]:
query = (
    "How does the Transformer architecture improve upon traditional RNNs and LSTMs in NLP tasks?"
)

answer = generate_response(query)
print(answer)

 The Transformer architecture improves upon traditional RNNs and LSTMs in NLP tasks by using self-attention mechanisms to capture long-range dependencies between words in a sentence. This allows the model to process entire sentences at once, rather than one word at a time, which can lead to better performance on tasks such as machine translation and language modeling. Additionally, the Transformer architecture is more efficient than traditional RNNs and LSTMs, as it does not require sequential processing of input data.


That's interesting. To be sure that there's no information on this topic, let's check the context.

In [21]:
retrieve(query)

[Document(id='8dfaf504-12a7-40ab-9681-caf72cb90e32', metadata={'source': 'https://en.wikipedia.org/wiki/Loss_functions_for_classification'}, page_content='Fei-Fei Li Alex Krizhevsky Ilya Sutskever Demis Hassabis David Silver Ian Goodfellow Andrej Karpathy Architectures Neural Turing machine Differentiable neural computer Transformer Vision transformer (ViT) Recurrent neural network (RNN) Long short-term memory (LSTM) Gated recurrent unit (GRU) Echo state network Multilayer perceptron (MLP) Convolutional neural network (CNN) Residual neural network (RNN) Highway network Mamba Autoencoder Variational autoencoder (VAE) Generative adversarial network (GAN) Graph neural network (GNN) Portals Technology Category Artificial neural networks Machine learning List Companies Projects Retrieved from " https://en.wikipedia.org/w/index.php?title=Loss_functions_for_classification&oldid=1261562183 "'),
 Document(id='c4a4a1c2-b912-41c8-8368-5e857c5da84a', metadata={'source': 'https://en.wikipedia.org/w

It appears that the query retrieved random parts of pages mentioning transformers. However, as they contained no valuable information, the answer was fully generated by the LLM. Although the response was accurate, we may want to enhance the retrieval function by setting a threshold for relevancy to minimize the risk of hallucinations.
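
A relevancy threshold could look roughly like the sketch below. It assumes that the scored documents come from `vector_store.similarity_search_with_score(query, k=2)`, which returns `(Document, score)` pairs where, with the default L2 distance, a lower score means a closer match. The sample texts, the scores, and the cutoff value are made up; `filter_by_distance` is a hypothetical helper.

```python
def filter_by_distance(scored_docs, max_distance):
    """Keep documents from (doc, distance) pairs whose distance is at most max_distance."""
    return [doc for doc, distance in scored_docs if distance <= max_distance]


# Stand-in for the output of vector_store.similarity_search_with_score(query, k=2);
# the texts and scores here are invented for illustration.
scored_docs = [
    ("chunk about actor-critic methods", 0.42),
    ("navigation box listing architectures", 1.35),
]
MAX_DISTANCE = 1.0  # hypothetical cutoff; tune it on queries with known answers
relevant = filter_by_distance(scored_docs, MAX_DISTANCE)
print(relevant)
# ['chunk about actor-critic methods']
```

If the filtered list is empty, `generate_response` could return "I couldn't find the information" right away without calling the LLM. Because the embeddings in this notebook are not normalized, the scale of the distances depends on the embedding model, so a suitable cutoff has to be found empirically.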

Let's ask a question from a completely different domain.

In [22]:
query = "How does the process of photosynthesis work in plants?"

answer = generate_response(query)
print(answer)

 I couldn't find the information.


This question was left unanswered. What about another one?

In [23]:
query = "How does blockchain technology ensure security and decentralization?"

answer = generate_response(query)
print(answer)

 Blockchain technology ensures security and decentralization through its decentralized nature, where data is stored across a network of computers rather than a single central authority. This makes it more difficult for hackers to compromise the system, as they would need to hack multiple nodes simultaneously. Additionally, blockchain uses cryptographic algorithms to secure data and transactions, making it nearly impossible to alter or manipulate the data without detection.


Unexpectedly, one of the documents contained information on this topic.

In [24]:
retrieve(query)

[Document(id='aca1234d-db32-450f-a279-8a9792b8535f', metadata={'source': 'https://en.wikipedia.org/wiki/Augmented_Analytics'}, page_content='to democratising data: Data Parameterisation and Characterisation. Data Decentralisation using an OS of blockchain and DLT technologies, as well as an independently governed secure data exchange to enable trust. Consent Market-driven Data Monetisation. When it comes to connecting assets, there are two features that will accelerate the adoption and usage of data democratisation: decentralized identity management and business data object monetization of data ownership. It enables multiple individuals and organizations to identify, authenticate, and authorize participants and organizations, enabling them to access services, data or systems across multiple networks, organizations, environments, and use cases. It empowers users and enables a personalized, self-service digital onboarding system so that users can self-authenticate without relying on a ce

One more question from another domain.

In [25]:
query = "What are the fundamental principles of classical mechanics?"

answer = generate_response(query)
print(answer)

 I couldn't find the information.
