In [None]:
!pip install transformers sentence-transformers qdrant-client torch accelerate

# Simple RAG


Step 1 (offline): add data to the vector DB (Qdrant).

Define the generation model, the embedding model, and the vector DB.

Create a collection, prepare the documents, and insert each document's embedding, emb(doc), into the collection.

Create three functions: one each for retrieval, prompt construction, and generation.

In [2]:
## ============================= offline ================================

from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import VectorParams, Distance
from sentence_transformers import SentenceTransformer



# Local Qdrant client opened with the ":memory:" location.
qdrant = QdrantClient(":memory:")

# Vector layout of the "documents" collection: 384 dimensions, cosine distance.
# NOTE(review): 384 must equal the embedding size of the encoder used to fill
# the collection -- confirm if the embedding model is ever changed.
vector_layout = VectorParams(size=384, distance=Distance.COSINE)
qdrant.create_collection(collection_name="documents", vectors_config=vector_layout)

print("[TEST] collection created")



# Index the corpus: embed every document, wrap each embedding in a point whose
# payload carries the original text, then upsert all points into the collection.
documents = [
    "RAG stands for Retrieval Augmented Generation.",
    "Qdrant is a vector database optimized for similarity search.",
    "Transformers library provides open source language models.",
    "Sentence Transformers generate embeddings for text."
]

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Encode the whole corpus in one batched call instead of one call per document.
doc_embeddings = embedder.encode(documents)

points = [
    PointStruct(
        id=idx,
        # Hand Qdrant a plain list of Python floats rather than a numpy row.
        vector=embedding.tolist(),
        payload={"text": doc},
    )
    for idx, (doc, embedding) in enumerate(zip(documents, doc_embeddings))
]

qdrant.upsert(
    collection_name="documents",
    points=points,
)

## ============================= Core ================================
def retrieve(query, top_k=3):
    """Return the texts of the stored documents most similar to `query`.

    The query is embedded with the module-level `embedder`, the "documents"
    collection of the module-level `qdrant` client is searched, and the
    ``"text"`` payload of every hit is collected.

    Args:
        query: Question or search phrase to embed.
        top_k: Maximum number of hits to request from Qdrant.

    Returns:
        A list of document texts (strings), in the order Qdrant returned the
        hits. Hits without a ``"text"`` payload entry are skipped.
    """
    # 1. Encode the query; pass a plain list of floats to the client.
    query_vector = embedder.encode(query).tolist()

    # 2. Search Qdrant.
    response = qdrant.query_points(
        collection_name="documents",
        query=query_vector,
        limit=top_k,
    )

    # 3. Extract the text from each hit's payload.
    # `query_points` returns a response object whose hits live in `.points`;
    # iterating the response itself yields (field_name, value) pairs, which
    # previously put a stringified dump of the whole response into the context.
    return [
        point.payload["text"]
        for point in response.points
        if point.payload and "text" in point.payload
    ]

def build_prompt(context, question):
    """Assemble the generation prompt from retrieved context and the question.

    Args:
        context: Iterable of strings; they are joined with newlines.
        question: The user question, inserted under the "Question:" header.

    Returns:
        The prompt string. Every template line keeps its four-space indent and
        the prompt starts with a newline and ends with an indented empty line.
    """
    joined_context = "\n".join(context)

    # Each entry is one line of the final prompt; the four-space prefixes and
    # the empty first/last entries reproduce the template layout exactly.
    template_lines = [
        "",
        "    Use the following context to answer the question.",
        "",
        "    Context:",
        f"    {joined_context}",
        "",
        "    Question:",
        f"    {question}",
        "",
        "    Answer:",
        "    ",
    ]
    return "\n".join(template_lines)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Tokenizer and seq2seq model are loaded from the same checkpoint so they stay in sync.
generator_checkpoint = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(generator_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(generator_checkpoint)

def generate_answer(question, top_k=3, max_new_tokens=100):
    """Answer `question` with the retrieve -> prompt -> generate pipeline.

    Calls `retrieve` for context, `build_prompt` to build the prompt, tokenizes
    the prompt, passes it to the module-level `model`, and decodes the first
    generated sequence.

    Args:
        question: The user question.
        top_k: Maximum number of context documents to retrieve (forwarded to
            `retrieve`).
        max_new_tokens: Upper bound on generated tokens (forwarded to
            `model.generate`).

    Returns:
        The decoded answer string with special tokens removed.
    """
    # 1. Retrieve text context
    context = retrieve(question, top_k=top_k)

    # 2. Build prompt
    prompt = build_prompt(context, question)

    # 3. Tokenize (truncation keeps an over-long prompt within the tokenizer's limit)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # 4. Generate
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # 5. Decode the first (only) generated sequence
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


## ============================= Full RAG Test ================================

# End-to-end smoke test: run the full pipeline on one question and show the answer.
question = "What is RAG?"
answer = generate_answer(question)
print(f"\n{question}: {answer}")

[TEST] collection created

What is RAG?: Retrieval Augmented Generation
