# RAG — A Basic Implementation

In this notebook, we implement a basic RAG system with the `fed-rag` framework,
which is an open-source library that facilitates both centralized and
federated fine-tuning of RAG systems.

While we won't fine-tune a RAG system in this notebook, we can still make use of
the framework to perform inference with RAG systems. Here, we'll use the HuggingFace
integration/extra to build a RAG system with a HuggingFace PeftModel as the LLM
Generator and a SentenceTransformer as the Retriever.


In [None]:
%pip install "fed-rag[huggingface]" -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


## Generator

In [None]:
from fed_rag.generators.hf_peft_model import HFPeftModelGenerator
from transformers.generation.utils import GenerationConfig
from transformers.utils.quantization_config import BitsAndBytesConfig

# LoRA adapter and the base model it is applied to.
PEFT_MODEL_NAME = "Styxxxx/llama2_7b_lora-quac"
BASE_MODEL_NAME = "meta-llama/Llama-2-7b-hf"

# Sampling settings used at generation time.
# The special-token ids must match the base model's tokenizer: Llama 2 uses
# <s> = 1 (BOS) and </s> = 2 (EOS) in a 32,000-token vocabulary. Ids such as
# 128000 / 128009 belong to Llama 3 and never match a Llama 2 token, which would
# leave `stop_strings` / `max_new_tokens` as the only ways for generation to end.
generation_cfg = GenerationConfig(
    do_sample=True,
    eos_token_id=2,
    bos_token_id=1,
    max_new_tokens=4096,
    top_p=0.9,
    temperature=0.6,
    cache_implementation="offloaded",
    stop_strings="</response>",
)

# Load the base model in 4-bit precision (requires the `bitsandbytes` package).
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# `load_model_at_init=False` means the weights are not loaded when this object is constructed.
generator = HFPeftModelGenerator(
    model_name=PEFT_MODEL_NAME,
    base_model_name=BASE_MODEL_NAME,
    generation_config=generation_cfg,
    load_model_at_init=False,
    load_model_kwargs={"is_trainable": True, "device_map": "auto"},
    load_base_model_kwargs={
        "device_map": "auto",
        "quantization_config": quantization_config,
    },
)

  from .autonotebook import tqdm as notebook_tqdm


PackageNotFoundError: No package metadata was found for bitsandbytes

## Retriever

In [None]:
from fed_rag.retrievers.hf_sentence_transformer import (
    HFSentenceTransformerRetriever,
)

# Dual-encoder retriever: one model embeds queries, a separate one embeds contexts.
QUERY_ENCODER_NAME = "nthakur/dragon-plus-query-encoder"
CONTEXT_ENCODER_NAME = "nthakur/dragon-plus-context-encoder"

retriever = HFSentenceTransformerRetriever(
    load_model_at_init=False,
    query_model_name=QUERY_ENCODER_NAME,
    context_model_name=CONTEXT_ENCODER_NAME,
)

## Knowledge Store

In [None]:
import json

from fed_rag.knowledge_stores.in_memory import InMemoryKnowledgeStore
from fed_rag.types.knowledge_node import KnowledgeNode, NodeType

# knowledge chunks: one JSON document per list entry.
# NOTE: each entry must end with a comma. Adjacent string literals are silently
# concatenated by Python, which would merge the documents into one invalid JSON string.
chunks_json_strs = [
    '{"id": "0", "title": "Orchid", "text": "Orchids are easily distinguished from other plants, as they share some very evident derived characteristics or synapomorphies. Among these are: bilateral symmetry of the flower (zygomorphism), many resupinate flowers, a nearly always highly modified petal (labellum), fused stamens and carpels, and extremely small seeds"}',
    '{"id": "1", "title": "Tulip", "text": "Tulips are easily distinguished from other plants, as they share some very evident derived characteristics or synapomorphies. Among these are: bilateral symmetry of the flower (zygomorphism), many resupinate flowers, a nearly always highly modified petal (labellum), fused stamens and carpels, and extremely small seeds"}',
]

# parse every JSON document into a dict with "id", "title" and "text" keys
chunks = [json.loads(chunk_str) for chunk_str in chunks_json_strs]


# start from an empty in-memory store
knowledge_store = InMemoryKnowledgeStore()

# build one text node per chunk; the embedding comes from the retriever's
# context encoder and is converted to a plain list via `.tolist()`
nodes = [
    KnowledgeNode(
        embedding=retriever.encode_context(chunk["text"]).tolist(),
        node_type=NodeType.TEXT,
        text_content=chunk["text"],
        metadata={"title": chunk["title"], "id": chunk["id"]},
    )
    for chunk in chunks
]

# hand all nodes to the store in a single call
knowledge_store.load_nodes(nodes=nodes)

## RAG System

In [None]:
from fed_rag.types.rag_system import RAGConfig, RAGSystem

# RAG settings; top_k=2 matches the two nodes loaded into the knowledge store above
rag_config = RAGConfig(top_k=2)

# wire the retriever, knowledge store and generator together into one system
rag_system = RAGSystem(
    retriever=retriever,
    knowledge_store=knowledge_store,
    generator=generator,
    rag_config=rag_config,
)

In [None]:
# ask the RAG system a question
question = "What is a Tulip?"
response = rag_system.query(question)

# show the response, preceded by a blank line
print(f"\n{response}")

# show the source nodes attached to the response
print(response.source_nodes)