In [1]:
# Install dependencies

!pip install pymilvus
!pip install ragas
!pip install pymilvus[milvus_lite]
!pip install transformers torch
!pip install evaluate

Collecting pymilvus
  Downloading pymilvus-2.6.2-py3-none-any.whl.metadata (6.5 kB)
Collecting ujson>=2.0.0 (from pymilvus)
  Downloading ujson-5.11.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (9.4 kB)
Downloading pymilvus-2.6.2-py3-none-any.whl (258 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.8/258.8 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ujson-5.11.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (57 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.4/57.4 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ujson, pymilvus
Successfully installed pymilvus-2.6.2 ujson-5.11.0
Collecting ragas
  Downloading ragas-0.3.6-py3-none-any.whl.metadata (21 kB)
Collecting appdirs (from ragas)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting diskcache>=5.6.3 (from ragas)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 

In [2]:
# Imports

import pandas as pd
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from datasets import load_dataset
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
import torch
import evaluate

In [3]:
# ============================================================================
# 1. Load and Analyze Dataset
# ============================================================================

print("Loading dataset from parquet...")
wiki_passages = pd.read_parquet(
    "hf://datasets/rag-datasets/rag-mini-wikipedia/data/passages.parquet/part.0.parquet"
)

# Basic statistics
print("\n" + "="*80)
print("DATASET STATISTICS")
print("="*80)
text_lengths = wiki_passages['passage'].str.len()
print(f"Passage length statistics:")
print(f"  Min: {text_lengths.min()} characters")
print(f"  Max: {text_lengths.max()} characters")
print(f"  Mean: {text_lengths.mean():.1f} characters")
print(f"  Median: {text_lengths.median():.1f} characters")

# Check for nulls
missing_values = wiki_passages.isnull().sum()
print(f"\nNull values per column:")
print(missing_values)

# Drop nulls and clean
print(f"\nShape before dropping nulls: {wiki_passages.shape}")
wiki_passages = wiki_passages.dropna()
print(f"Shape after dropping nulls: {wiki_passages.shape}")

# Limit dataset size for processing
MAX_PASSAGES = 1000
wiki_passages = wiki_passages.head(MAX_PASSAGES)
print(f"\nUsing {len(wiki_passages)} passages for RAG system")

# Load QA dataset
print("\nLoading QA dataset...")
qa_dataset = load_dataset("rag-datasets/rag-mini-wikipedia", "question-answer")
test_questions = qa_dataset["test"]
print(f"Loaded {len(test_questions)} Q&A pairs")

# Display sample
print("\n" + "="*80)
print("SAMPLE DATA")
print("="*80)
print(f"Sample passage:\n{wiki_passages.iloc[0]['passage'][:300]}...\n")
print(f"Sample question: {test_questions[0]['question']}")
print(f"Sample answer: {test_questions[0]['answer']}")

Loading dataset from parquet...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



DATASET STATISTICS
Passage length statistics:
  Min: 1 characters
  Max: 2515 characters
  Mean: 389.8 characters
  Median: 299.0 characters

Null values per column:
passage    0
dtype: int64

Shape before dropping nulls: (3200, 1)
Shape after dropping nulls: (3200, 1)

Using 1000 passages for RAG system

Loading QA dataset...


README.md:   0%|          | 0.00/719 [00:00<?, ?B/s]

data/test.parquet/part.0.parquet:   0%|          | 0.00/54.4k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/918 [00:00<?, ? examples/s]

Loaded 918 Q&A pairs

SAMPLE DATA
Sample passage:
Uruguay (official full name in  ; pron.  , Eastern Republic of  Uruguay) is a country located in the southeastern part of South America.  It is home to 3.3 million people, of which 1.7 million live in the capital Montevideo and its metropolitan area....

Sample question: Was Abraham Lincoln the sixteenth President of the United States?
Sample answer: yes


In [4]:
# ============================================================================
# 2. Chunk Documents
# ============================================================================

def create_text_chunks(text_content, chunk_size=600):
    """Split text into chunks of chunk_size with no overlap."""
    if not text_content or pd.isna(text_content):
        return []
    text_content = str(text_content)  # Ensure string type
    return [text_content[i:i+chunk_size] for i in range(0, len(text_content), chunk_size)]

print("\n" + "="*80)
print("CREATING CHUNKS")
print("="*80)

document_chunks = []
for passage_id, row in wiki_passages.iterrows():
    passage_content = row['passage']
    text_segments = create_text_chunks(passage_content)
    for segment_idx, segment_text in enumerate(text_segments):
        document_chunks.append({
            "chunk_id": f"{passage_id}-{segment_idx}",
            "content": segment_text,
            "source_passage_id": passage_id
        })

print(f"Total chunks created: {len(document_chunks)}")
print(f"Average chunks per passage: {len(document_chunks)/len(wiki_passages):.2f}")
print(f"Sample chunk: {document_chunks[0]['content'][:300]}...")


CREATING CHUNKS
Total chunks created: 1289
Average chunks per passage: 1.29
Sample chunk: Uruguay (official full name in  ; pron.  , Eastern Republic of  Uruguay) is a country located in the southeastern part of South America.  It is home to 3.3 million people, of which 1.7 million live in the capital Montevideo and its metropolitan area....


In [5]:
# ============================================================================
# 3. Generate Embeddings
# ============================================================================

print("\n" + "="*80)
print("GENERATING EMBEDDINGS")
print("="*80)

embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
chunk_texts = [chunk["content"] for chunk in document_chunks]

# Encode in batches to reduce memory usage
chunk_embeddings = embedding_model.encode(
    chunk_texts,
    batch_size=64,
    show_progress_bar=True,
    normalize_embeddings=True,
    convert_to_numpy=True
)
chunk_embeddings = chunk_embeddings.astype("float32")

embedding_dim = chunk_embeddings.shape[1]
print(f"\nEmbeddings shape: {chunk_embeddings.shape}")
print(f"Embedding dimension: {embedding_dim}")
assert embedding_dim == 384, f"Unexpected dimension {embedding_dim}"


GENERATING EMBEDDINGS


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/21 [00:00<?, ?it/s]


Embeddings shape: (1289, 384)
Embedding dimension: 384


In [6]:
# ============================================================================
# 4. Setup Milvus Vector Database
# ============================================================================

print("\n" + "="*80)
print("SETTING UP MILVUS")
print("="*80)

# Define schema
chunk_id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False)
passage_field = FieldSchema(name="passage", dtype=DataType.VARCHAR, max_length=3000)
embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=384)
collection_schema = CollectionSchema(fields=[chunk_id_field, passage_field, embedding_field])

print("Schema defined with fields: id, passage, embedding")


SETTING UP MILVUS
Schema defined with fields: id, passage, embedding


In [7]:
# ============================================================================
# 5. Create Collection and Insert Data
# ============================================================================

print("\n" + "="*80)
print("CREATING COLLECTION AND INSERTING DATA")
print("="*80)

from pymilvus import MilvusClient

milvus_client = MilvusClient("rag_wikipedia_mini.db")
milvus_client.create_collection(collection_name="rag_mini", schema=collection_schema)
print("Collection 'rag_mini' created")

# Prepare data for insertion
rag_data = [
    {"id": i, "passage": chunk["content"], "embedding": chunk_embeddings[i].tolist()}
    for i, chunk in enumerate(document_chunks)
]

insert_response = milvus_client.insert(collection_name="rag_mini", data=rag_data)
print(f"Insert response: {insert_response}")

# Verify insertion
entity_count = milvus_client.get_collection_stats("rag_mini")["row_count"]
print(f"Entity count: {entity_count}")
print(f"Collection schema: {milvus_client.describe_collection('rag_mini')}")


CREATING COLLECTION AND INSERTING DATA
Collection 'rag_mini' created
Insert response: {'insert_count': 1289, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 19

In [8]:
# ============================================================================
# 6. Create Index and Load Collection
# ============================================================================

print("\n" + "="*80)
print("CREATING INDEX AND LOADING COLLECTION")
print("="*80)

index_parameters = milvus_client.prepare_index_params()
index_parameters.add_index(field_name="embedding", metric_type="COSINE")

try:
    milvus_client.create_index(collection_name="rag_mini", index_params=index_parameters)
    print("Index created successfully")
except Exception as index_error:
    print(f"Index creation result: {index_error}")

milvus_client.load_collection("rag_mini")
print("Collection loaded into memory")



CREATING INDEX AND LOADING COLLECTION
Index created successfully
Collection loaded into memory


In [9]:
# ============================================================================
# 7. Initialize FLAN-T5 Model for Text Generation
# ============================================================================

flan_model_name = "google/flan-t5-small"

try:
    tokenizer = AutoTokenizer.from_pretrained(flan_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(flan_model_name)
    flan_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1
    )
    print(f"FLAN-T5 model '{flan_model_name}' loaded successfully")
except Exception as e:
    print(f"Error loading FLAN-T5: {e}")
    flan_pipeline = None

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Device set to use cuda:0


FLAN-T5 model 'google/flan-t5-small' loaded successfully


In [10]:
# ============================================================================
# 8. Retrieval and Generation Functions
# ============================================================================

def search_similar_chunks(user_query, num_results=5):
    """Retrieve top-k most similar chunks for a query."""
    query_vector = embedding_model.encode([user_query], normalize_embeddings=True).astype("float32")[0].tolist()

    search_results = milvus_client.search(
        collection_name="rag_mini",
        data=[query_vector],
        limit=num_results,
        output_fields=["id", "passage"]
    )

    retrieved_chunks = []
    for result in search_results[0]:
        result_id = result["id"]
        result_content = result["entity"]["passage"]
        similarity_score = result["distance"]
        retrieved_chunks.append((result_id, result_content, similarity_score))

    return retrieved_chunks


def generate_rag_answer(user_query, num_results=5, max_context_length=2000):
    """Retrieve relevant context and generate answer using FLAN-T5."""
    retrieved_chunks = search_similar_chunks(user_query, num_results=num_results)
    combined_context = "\n\n".join([chunk[1] for chunk in retrieved_chunks])[:max_context_length]

    if not flan_pipeline:
        generated_answer = (
            "[No FLAN model available] Please check model loading.\n\n"
            f"Top retrieved context:\n{combined_context[:500]}..."
        )
        return generated_answer, retrieved_chunks

    instruction_prompt = ("You are a helpful assistant. Answer strictly using the provided context. ""If the context is insufficient, answer 'I don't know.'\n\n"
        f"Context:\n{combined_context}\n\n"
        f"Question: {user_query}\n\n"
        "Answer:"
    )

    try:
        response = flan_pipeline(
            instruction_prompt,
            max_length=512,
            temperature=0.2,
            num_return_sequences=1
        )
        generated_answer = response[0]["generated_text"]
    except Exception as error:
        generated_answer = f"[Error calling FLAN pipeline: {str(error)}]\n\nTop context:\n{combined_context[:500]}..."

    return generated_answer, retrieved_chunks

In [11]:
# ============================================================================
# 9. Test the System
# ============================================================================

print("\n" + "="*80)
print("TESTING RETRIEVAL")
print("="*80)

demo_query = "What are the three sections of a beetle?"
print(f"Query: {demo_query}\n")
demo_results = search_similar_chunks(demo_query, num_results=3)
for result_id, result_text, score in demo_results:
    print(f"ID: {result_id} | Score: {score:.4f}")
    print(f"Text: {result_text[:150]}...\n")

# Full RAG test
print("="*80)
print("TESTING FULL RAG PIPELINE")
print("="*80)

first_question = test_questions[0]["question"]
print(f"Question: {first_question}\n")

final_answer, top_chunks = generate_rag_answer(first_question, num_results=5)
print(f"=== GENERATED ANSWER ===")
print(final_answer)
print(f"\n=== TOP {len(top_chunks)} RETRIEVED CHUNKS ===")
for rank, (chunk_id, chunk_text, score) in enumerate(top_chunks, 1):
    print(f"\n[{rank}] ID: {chunk_id} | Score: {score:.4f}")
    print(f"Text: {chunk_text[:120]}...")

print("\n" + "="*80)
print("RAG SYSTEM READY")

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



TESTING RETRIEVAL
Query: What are the three sections of a beetle?

ID: 1281 | Score: 0.3709
Text: s as generally assumed, which would necessitate splitting the traditional Pelecaniformes in three....

ID: 1274 | Score: 0.3135
Text: The Megadyptes - Eudyptes clade occurs at similar latitudes (though not as far north as the Galapagos Penguin), has its highest diversity in the New Z...

ID: 1269 | Score: 0.2790
Text: Pygoscelis contains species with a fairly simple black-and-white head pattern; their distribution is intermediate, centered on Antarctic coasts but ex...

TESTING FULL RAG PIPELINE
Question: Was Abraham Lincoln the sixteenth President of the United States?



Both `max_new_tokens` (=256) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


=== GENERATED ANSWER ===
Yes

=== TOP 5 RETRIEVED CHUNKS ===

[1] ID: 339 | Score: 0.7095
Text: Young Abraham Lincoln...

[2] ID: 320 | Score: 0.6434
Text: Abraham Lincoln (February 12, 1809 â April 15, 1865) was the sixteenth President of the United States, serving from Ma...

[3] ID: 381 | Score: 0.5896
Text: On November 6, 1860, Lincoln was elected as the 16th President of the United States, beating Democrat Stephen A. Douglas...

[4] ID: 882 | Score: 0.5569
Text: Sixteen months before his death, his son, John Quincy Adams, became the sixth President of the United States (1825 1829)...

[5] ID: 480 | Score: 0.5484
Text: * American School, Lincoln's economic views....

RAG SYSTEM READY


In [23]:
# ============================================================================
# 10. Prompting Strategy Evaluation
# ============================================================================

print("\n" + "="*80)
print("STEP 3: PROMPTING STRATEGY EVALUATION")
print("="*80)

class PromptingStrategies:

    @staticmethod
    def chain_of_thought_prompt(context, question):
        """Chain-of-thought: step-by-step reasoning"""
        return f"""Think step by step. Use only the context.
If the answer is not in the context, say 'I don't know'.

Context: {context}
Question: {question}
Answer:"""

    @staticmethod
    def persona_prompt(context, question):
        """Persona-based: Subject matter expert"""
        return f"""You are a subject matter expert. Use only the context.
If the answer is not in the context, say 'I don't know'. Be direct.

Context: {context}
Question: {question}
Answer:"""

    @staticmethod
    def instruction_prompt(context, question):
        """Instruction-based prompt"""
        return f"""Answer using only the context.
If the answer is not there, say 'I don't know'.

Context: {context}
Question: {question}
Answer:"""

print("✓ Defined 3 prompting strategies")


STEP 3: PROMPTING STRATEGY EVALUATION
✓ Defined 3 prompting strategies


In [24]:
# ============================================================================
# 11. Prompting Strategy: Generation & Evaluation Functions
# ============================================================================

PROMPT_STRATEGIES = {
    "chain_of_thought": PromptingStrategies.chain_of_thought_prompt,
    "persona": PromptingStrategies.persona_prompt,
    "instruction": PromptingStrategies.instruction_prompt,
}

def retrieve_top1_context(user_query):
    """Get top-1 chunk from Milvus."""
    results = search_similar_chunks(user_query, num_results=1)
    return results[0][1] if results else ""

def evaluate_prompting_strategy(strategy_name, num_samples=100):
    prompt_builder = PROMPT_STRATEGIES[strategy_name]
    predictions_list = []
    references_list = []

    print(f"\nEvaluating '{strategy_name}' on {num_samples} samples...")

    for idx in range(num_samples):
        question = test_questions[idx]["question"]
        gold_answer = test_questions[idx].get("answer") or test_questions[idx].get("answers")
        gold_text = gold_answer if isinstance(gold_answer, str) else gold_answer[0]

        context = retrieve_top1_context(question)
        prompt = prompt_builder(context, question)

        output = flan_pipeline(prompt, max_new_tokens=128, temperature=0.2)
        predicted_answer = output[0]["generated_text"].strip()

        predictions_list.append({"id": str(idx), "prediction_text": predicted_answer})
        references_list.append({"id": str(idx), "answers": {"text": [gold_text], "answer_start": [0]}})

    squad_metric = evaluate.load("squad")
    return squad_metric.compute(predictions=predictions_list, references=references_list)

In [25]:
# ============================================================================
# 12. Run Evaluation on 100 Samples
# ============================================================================

import datetime

N_SAMPLES = 100
results = {}

for name in PROMPT_STRATEGIES:
    start_time = datetime.datetime.now()
    print(f"\n→ Evaluating {name} on {N_SAMPLES} samples (GPU) - started at {start_time.strftime('%H:%M:%S')}")

    metrics = evaluate_prompting_strategy(name, num_samples=N_SAMPLES)
    results[name] = metrics

    end_time = datetime.datetime.now()
    print(f"✓ Completed {name} at {end_time.strftime('%H:%M:%S')}")

print("\n=== Results ===")
for k, v in results.items():
    print(k, v)


→ Evaluating chain_of_thought on 100 samples (GPU) - started at 02:35:45

Evaluating 'chain_of_thought' on 100 samples...
✓ Completed chain_of_thought at 02:36:40

→ Evaluating persona on 100 samples (GPU) - started at 02:36:40

Evaluating 'persona' on 100 samples...
✓ Completed persona at 02:37:06

→ Evaluating instruction on 100 samples (GPU) - started at 02:37:06

Evaluating 'instruction' on 100 samples...
✓ Completed instruction at 02:37:30

=== Results ===
chain_of_thought {'exact_match': 8.0, 'f1': 12.332352903527536}
persona {'exact_match': 28.0, 'f1': 31.380048266706773}
instruction {'exact_match': 23.0, 'f1': 26.121037102432453}


In [26]:
# ============================================================================
# 13. Identify Best Performing Strategies (Friend-Style)
# ============================================================================

print("\n" + "="*60)
print("BEST PERFORMING STRATEGIES")
print("="*60)

best_f1_strategy = max(results.items(), key=lambda x: x[1]["f1"])
best_em_strategy = max(results.items(), key=lambda x: x[1]["exact_match"])

print(f"\n→ Best by F1 Score:")
print(f"  Strategy: {best_f1_strategy[0]}")
print(f"  F1: {best_f1_strategy[1]['f1']:.4f}")
print(f"  EM: {best_f1_strategy[1]['exact_match']:.4f}")

print(f"\n→ Best by Exact Match:")
print(f"  Strategy: {best_em_strategy[0]}")
print(f"  F1: {best_em_strategy[1]['f1']:.4f}")
print(f"  EM: {best_em_strategy[1]['exact_match']:.4f}")

print("\n=== STEP 3 EVALUATION COMPLETE ===")


BEST PERFORMING STRATEGIES

→ Best by F1 Score:
  Strategy: persona
  F1: 31.3800
  EM: 28.0000

→ Best by Exact Match:
  Strategy: persona
  F1: 31.3800
  EM: 28.0000

=== STEP 3 EVALUATION COMPLETE ===
