Install & setup

Imports

Load & prepare dataset

Build corpus + queries + qrels

Write corpus to disk (RapidFire expects files)

Queries DataFrame

Define RAG search space (retrieval-focused)

Generator

Preprocess (retrieval-only focus)

Postprocess (attach ground truth)

Metrics (Precision / Recall / MRR / NDCG)

Grid + Experiment

Run hyperparallel evals

Results table

In [11]:
# from collections import Counter

# # This will give you a dictionary of {category: count}
# category_counts = Counter(raw_dataset["product_category"])

# # Print them nicely sorted by the most popular
# for category, count in category_counts.most_common():
#     print(f"{category}: {count} reviews")

home: 17679 reviews
apparel: 15951 reviews
wireless: 15717 reviews
other: 13418 reviews
beauty: 12091 reviews
drugstore: 11730 reviews
kitchen: 10382 reviews
toy: 8745 reviews
sports: 8277 reviews
automotive: 7506 reviews
lawn_and_garden: 7327 reviews
home_improvement: 7136 reviews
pet_products: 7082 reviews
digital_ebook_purchase: 6749 reviews
pc: 6401 reviews
electronics: 6186 reviews
office_product: 5521 reviews
shoes: 5197 reviews
grocery: 4730 reviews
book: 3756 reviews
baby_product: 3150 reviews
furniture: 2984 reviews
jewelry: 2747 reviews
camera: 2139 reviews
industrial_supplies: 1994 reviews
digital_video_download: 1364 reviews
luggage: 1328 reviews
musical_instruments: 1102 reviews
video_games: 775 reviews
watch: 761 reviews
personal_care_appliances: 75 reviews


gem new

In [1]:
try:
    import rapidfireai
    print("‚úÖ rapidfireai installed")
except ImportError:
    !pip install rapidfireai datasets==3.6.0 langchain sentence-transformers
    !rapidfireai init --evals

Collecting rapidfireai
  Downloading rapidfireai-0.12.9-py3-none-any.whl.metadata (24 kB)
Collecting datasets==3.6.0
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting flask-cors (from rapidfireai)
  Downloading flask_cors-6.0.2-py3-none-any.whl.metadata (5.3 kB)
Collecting waitress (from rapidfireai)
  Downloading waitress-3.0.2-py3-none-any.whl.metadata (5.8 kB)
Collecting jq (from rapidfireai)
  Downloading jq-1.11.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (7.0 kB)
Collecting jedi (from rapidfireai)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting uv (from rapidfireai)
  Downloading uv-0.9.27-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting trackio (from rapidfireai)
  Downloading trackio-0.15.0-py3-none-any.whl.metadata (9.9 kB)
Collecting mlflow (from rapidfireai)
  Downloading mlflow-3.8.1-py3-none-any.whl.metadata (31 kB)
Collecting mlflow-skinny==3.8.1 (from mlflow->

In [2]:
import os
import math
import json
import random
import pandas as pd
from pathlib import Path
from typing import List as listtype, Dict, Any
from datasets import load_dataset

# Critical for Colab compatibility.
# Must stay above the `rapidfireai` imports below so the variable is already set
# in the process environment when those imports run.
# NOTE(review): presumably this selects protobuf's pure-Python backend — confirm.
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

from rapidfireai import Experiment
from rapidfireai.automl import List, RFLangChainRagSpec, RFvLLMModelConfig, RFPromptManager, RFGridSearch

In [8]:
# for i, row in enumerate(sampled_data):
#     print(row)
#     break

{'review_id': 'en_0760000', 'product_id': 'product_en_0049787', 'reviewer_id': 'reviewer_en_0204668', 'stars': 1, 'review_body': 'One ear bud lasted only 2 weeks. The other went out 3 months later. Not worth the money or frustration', 'review_title': 'One ear bud lasted only 2 weeks. The other ...', 'language': 'en', 'product_category': 'electronics'}


In [9]:
import json
import random
import pandas as pd
from pathlib import Path
from datasets import load_dataset
from collections import defaultdict

# Builds the retrieval benchmark used by the rest of the notebook:
#   * corpus  -> one JSONL record per review body, written to `corpus_file`
#   * queries -> `electronics_dataset` (query_id, query) built from review titles
#   * qrels   -> `qrels` (query_id, corpus_id, relevance): every sampled review of
#                the same product counts as relevant for a query.

# 1. Setup
dataset_dir = Path("./electronics_rag")
dataset_dir.mkdir(parents=True, exist_ok=True)

# 2. Load and Filter
raw_dataset = load_dataset("buruzaemon/amazon_reviews_multi", "en", split="train")
electronics_data = raw_dataset.filter(lambda x: "electronics" in x["product_category"].lower())

# 3. Downsample
# NOTE: the recorded run of this cell produced exactly one relevance pair per
# document (100 pairs for 100 documents), i.e. no two sampled reviews shared a
# product. Raise `sample_size` if overlapping products are required.
sample_size = 100
rseed = 42
random.seed(rseed)
# Clamp to the number of available rows so `select` cannot index past the end.
sampled_data = electronics_data.shuffle(seed=rseed).select(
    range(min(sample_size, len(electronics_data)))
)

# 4. Grouping Logic
# We need to know which documents belong to which product
product_to_docs = defaultdict(list)
corpus_list = []
queries_list = []
query_products = []  # (query_id, product_id) per row, reused to build the qrels

for i, row in enumerate(sampled_data):
    doc_id = f"doc_{i}"
    query_id = f"q_{i}"
    prod_id = str(row['product_id'])

    # Store the document
    corpus_list.append({"_id": doc_id, "text": row["review_body"]})

    # Store the query (using title)
    queries_list.append({"query_id": query_id, "query": row["review_title"]})

    # Map this document to its product group
    product_to_docs[prod_id].append(doc_id)
    query_products.append((query_id, prod_id))

# 5. Build Expanded QRELS
# Every document sharing the query's product_id is a "correct" answer. The pairs
# recorded above avoid iterating (and re-decoding) the dataset a second time.
qrels_rows = [
    {"query_id": query_id, "corpus_id": d_id, "relevance": 1}
    for query_id, prod_id in query_products
    for d_id in product_to_docs[prod_id]
]

# 6. Save and Finalize
corpus_file = dataset_dir / "corpus_sampled.jsonl"
# Explicit encoding so the written file does not depend on the platform default.
with open(corpus_file, "w", encoding="utf-8") as f:
    for doc in corpus_list:
        f.write(json.dumps(doc) + "\n")

electronics_dataset = pd.DataFrame(queries_list).astype(str)
qrels = pd.DataFrame(qrels_rows).astype(str)

# Sanity check for the "expanded" qrels: how many products actually contribute
# more than one review to the sample.
multi_review_products = sum(1 for docs in product_to_docs.values() if len(docs) > 1)

print(f"✅ Prepared {len(corpus_list)} documents.")
print(f"✅ Expanded QRELS: {len(qrels)} relevance pairs (Multiple reviews per product).")
print(f"Products with more than one sampled review: {multi_review_products}")

✅ Prepared 100 documents.
✅ Expanded QRELS: 100 relevance pairs (Multiple reviews per product).


In [10]:
from langchain_community.document_loaders import DirectoryLoader, JSONLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_classic.retrievers.document_compressors import CrossEncoderReranker

# Batch size for embedding model hardware efficiency
batch_size = 50


def _corpus_metadata(record, metadata):
    """Metadata kept per loaded record: its `_id`, cast to a string."""
    return {
        "corpus_id": str(record.get("_id"))  # CRITICAL: Must be string
    }


# Loader for the JSONL corpus file that lives in `dataset_dir`.
_corpus_loader = DirectoryLoader(
    path=str(dataset_dir),
    glob="corpus_sampled.jsonl",
    loader_cls=JSONLoader,
    loader_kwargs={
        "jq_schema": ".",
        "content_key": "text",
        "metadata_func": _corpus_metadata,
        "json_lines": True,
        "text_content": False,
    },
    sample_seed=42,
)

# Two chunking granularities to compare (256 and 128), same overlap for both.
_chunk_splitters = List(
    [
        RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=32)
        for chunk_size in (256, 128)
    ]
)

# Sentence-embedding model settings; embeddings are normalized and batched.
_embedding_kwargs = {
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "model_kwargs": {"device": "cuda:0"},
    "encode_kwargs": {"normalize_embeddings": True, "batch_size": batch_size},
}

# Cross-encoder reranker settings; two `top_n` cut-offs are compared.
_reranker_kwargs = {
    "model_name": "cross-encoder/ms-marco-MiniLM-L6-v2",
    "model_kwargs": {"device": "cpu"},
    # INCREASED TOP_N: Allow more product-relevant evidence through to the LLM
    "top_n": List([5, 10]),
}

rag_gpu = RFLangChainRagSpec(
    document_loader=_corpus_loader,
    text_splitter=_chunk_splitters,
    embedding_cls=HuggingFaceEmbeddings,
    embedding_kwargs=_embedding_kwargs,
    vector_store=None,  # Defaults to FAISS
    search_type="similarity",
    # INCREASED K: Fetch more candidates because many reviews are now 'correct'
    search_kwargs={"k": 20},
    reranker_cls=CrossEncoderReranker,
    reranker_kwargs=_reranker_kwargs,
    enable_gpu_search=True,  # Stability fix for Colab environment
)

In [11]:
def sample_preprocess_fn(batch: Dict[str, listtype], rag: RFLangChainRagSpec, prompt_manager: RFPromptManager) -> Dict[str, listtype]:
    """Retrieve context for a batch of queries and build the chat prompts.

    Args:
        batch: Column-oriented batch; must contain a "query" column.
        rag: RAG spec used for retrieval (`get_context`) and for turning the
            retrieved documents into prompt text (`serialize_documents`).
        prompt_manager: Accepted for interface compatibility; not used here.

    Returns:
        A dict with "prompts" (one system + user message pair per query),
        "retrieved_documents" (the `corpus_id` metadata of every retrieved
        document, as stripped strings), followed by every original batch
        column converted to a list.
    """
    system_message = "Utilize your knowledge of electronics to answer the following question based on the provided reviews."

    # Queries are forced to stripped strings before they reach the retriever.
    questions = []
    for raw_query in batch["query"]:
        questions.append(str(raw_query).strip())

    # Retrieval returns the documents themselves (serialize=False) so that
    # their ids can be read before they are flattened into text.
    contexts = rag.get_context(batch_queries=questions, serialize=False)

    # Ids are cast to stripped strings so they compare equal to the QRELS ids.
    retrieved_ids = []
    for docs in contexts:
        retrieved_ids.append(
            [str(doc.metadata.get("corpus_id", "")).strip() for doc in docs]
        )

    context_texts = rag.serialize_documents(contexts)

    prompts = []
    for question, context in zip(questions, context_texts):
        prompts.append(
            [
                {"role": "system", "content": system_message},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
            ]
        )

    output = {"prompts": prompts, "retrieved_documents": retrieved_ids}
    # Pass every original column through as a list (a same-named batch column
    # replaces the value computed above, exactly as a trailing ** unpack would).
    for column, values in batch.items():
        output[column] = list(values)
    return output

def sample_postprocess_fn(batch: Dict[str, listtype]) -> Dict[str, listtype]:
    """Attach the ground-truth corpus ids of every query to the batch.

    Reads the module-level `qrels` DataFrame ("query_id" / "corpus_id"
    columns). For each entry of `batch["query_id"]` the matching "corpus_id"
    values are collected, in qrels row order, into
    `batch["ground_truth_documents"]`; a query without qrels gets an empty
    list. The batch is modified in place and returned.
    """
    # Compare string-to-string so ids never fail to match on dtype alone.
    qrels_query_ids = qrels["query_id"].astype(str)

    ground_truth = []
    for qid in batch["query_id"]:
        wanted = str(qid).strip()
        ground_truth.append(
            qrels.loc[qrels_query_ids == wanted, "corpus_id"].tolist()
        )

    batch["ground_truth_documents"] = ground_truth
    return batch

def compute_ndcg_at_k(retrieved_docs, expected_docs, k=5):
    """Binary-relevance NDCG@k of a ranked list of ids.

    Args:
        retrieved_docs: Ranked iterable of (hashable) document ids, best first.
            The same id may occur several times, e.g. when several chunks of
            one document are retrieved.
        expected_docs: Iterable of relevant document ids.
        k: Rank cut-off.

    Returns:
        DCG@k divided by the ideal DCG@k, a float in [0, 1]; 0.0 when there is
        no relevant document or `k` leaves nothing to score.
    """
    relevant = set(expected_docs)
    # Keep only the first (best-ranked) occurrence of each id before cutting at
    # k. Crediting the same relevant id more than once would let DCG exceed the
    # ideal DCG and push the score above 1.
    ranking = list(dict.fromkeys(retrieved_docs))[:k]
    dcg = sum(
        1 / math.log2(rank + 2)
        for rank, doc in enumerate(ranking)
        if doc in relevant
    )
    # The ideal ranking places every distinct relevant id at the top.
    ideal_length = min(k, len(relevant))
    idcg = sum(1 / math.log2(rank + 2) for rank in range(ideal_length))
    return dcg / idcg if idcg > 0 else 0.0

def sample_compute_metrics_fn(batch: Dict[str, listtype]) -> Dict[str, Dict[str, Any]]:
    """Compute retrieval metrics for one batch.

    Per query, the retrieved ids and ground-truth ids are compared as stripped
    strings:
        * Precision - relevant retrieved ids / distinct retrieved ids
        * Recall    - relevant retrieved ids / distinct ground-truth ids
        * Hit Rate  - 1 if at least one relevant id was retrieved, else 0
        * MRR       - reciprocal rank of the first relevant retrieved id
        * NDCG@5    - `compute_ndcg_at_k(..., k=5)`

    Args:
        batch: Must contain "query", "retrieved_documents" and
            "ground_truth_documents" columns.

    Returns:
        {"Total": {"value": n}, "<metric>": {"value": mean over the batch}, ...}
        where n is `len(batch["query"])`. An empty batch yields 0.0 for every
        metric instead of raising ZeroDivisionError.
    """
    total_queries = len(batch["query"])
    precisions, recalls, ndcgs, rrs, hits = [], [], [], [], []

    for pred, gt in zip(batch["retrieved_documents"], batch["ground_truth_documents"]):
        # Normalize once, so every metric (NDCG included) sees the same ids.
        ranked = [str(p).strip() for p in pred]
        actual = set(ranked)
        expected = {str(g).strip() for g in gt}

        tp = len(actual & expected)

        precisions.append(tp / len(actual) if actual else 0)
        recalls.append(tp / len(expected) if expected else 0)
        ndcgs.append(compute_ndcg_at_k(ranked, expected, k=5))

        # Hit Rate: Did we get at least one review for the right product?
        hits.append(1 if tp > 0 else 0)

        # Reciprocal rank of the first relevant id; 0 when none was retrieved.
        rrs.append(
            next(
                (1 / (rank + 1) for rank, doc in enumerate(ranked) if doc in expected),
                0,
            )
        )

    def _batch_mean(values):
        # Guard the empty batch: there is nothing to average.
        return sum(values) / total_queries if total_queries else 0.0

    return {
        "Total": {"value": total_queries},
        "Hit Rate": {"value": _batch_mean(hits)},
        "Precision": {"value": _batch_mean(precisions)},
        "Recall": {"value": _batch_mean(recalls)},
        "NDCG@5": {"value": _batch_mean(ndcgs)},
        "MRR": {"value": _batch_mean(rrs)},
    }

def sample_accumulate_metrics_fn(aggregated_metrics: Dict[str, listtype]) -> Dict[str, Dict[str, Any]]:
    """Combine per-batch metric dicts into overall values.

    Args:
        aggregated_metrics: Maps each metric name to the list of per-batch
            entries (`{"value": ...}`); "Total" holds the per-batch query
            counts, in the same batch order as the other metrics.

    Returns:
        {"Total": {"value": all queries}, "<metric>": {"value": ...,
        "is_algebraic": True}, ...}. Each metric is the query-weighted mean of
        its per-batch means, so a short final batch does not count as much as a
        full one. With no queries at all every metric is 0.0.
    """
    batch_sizes = [m["value"] for m in aggregated_metrics["Total"]]
    total_queries = sum(batch_sizes)
    metrics = ["Hit Rate", "Precision", "Recall", "NDCG@5", "MRR"]

    def _weighted_mean(entries):
        # Per-batch values are means over that batch; weighting by batch size
        # recovers the mean over all queries.
        if not total_queries:
            return 0.0
        weighted_sum = sum(
            entry["value"] * size for entry, size in zip(entries, batch_sizes)
        )
        return weighted_sum / total_queries

    return {
        "Total": {"value": total_queries},
        **{
            m: {
                "value": _weighted_mean(aggregated_metrics[m]),
                "is_algebraic": True,
            }
            for m in metrics
        },
    }

In [12]:
# Generator settings: engine options for the vLLM model.
_vllm_model_config = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "dtype": "half",  # Force half-precision for speed
    "gpu_memory_utilization": 0.25,
    "enforce_eager": True,
    "max_model_len": 2048,  # Limits KV cache to prevent OOM
    "disable_log_stats": True,
}

# Decoding settings used for every generated answer.
_vllm_sampling_params = {
    "temperature": 0.7,  # Added for more natural answers
    "top_p": 0.95,
    "max_tokens": 128,
}

# The generator is tied to the RAG spec, so each RAG variant gets evaluated.
vllm_config = RFvLLMModelConfig(
    model_config=_vllm_model_config,
    sampling_params=_vllm_sampling_params,
    rag=rag_gpu,
)

# Settings for the online (during-the-run) metric aggregation.
_online_strategy_kwargs = {
    "strategy_name": "normal",
    "confidence_level": 0.95,
    "use_fpc": True,
}

# Everything one evaluation run needs: generator, batch size and the four
# user-defined hooks declared earlier in the notebook.
config_set = {
    "vllm_config": vllm_config,
    "batch_size": 4,
    "preprocess_fn": sample_preprocess_fn,
    "postprocess_fn": sample_postprocess_fn,
    "compute_metrics_fn": sample_compute_metrics_fn,
    "accumulate_metrics_fn": sample_accumulate_metrics_fn,
    "online_strategy_kwargs": _online_strategy_kwargs,
}

In [13]:
# Expand the config set into a grid and evaluate every combination.
config_group = RFGridSearch(config_set)
experiment = Experiment(experiment_name="amazon-electronics-rag-v2", mode="evals")

try:
    results = experiment.run_evals(
        config_group=config_group,
        dataset=electronics_dataset,
        num_actors=1,
        num_shards=4,
        seed=42,
    )
finally:
    # Cleanup: end the experiment even when the run raises or is interrupted,
    # so a later execution of this cell does not find it still running.
    experiment.end()


Experiment amazon-electronics-rag-v2 is currently running. Returning the same experiment object.
Any running tasks have been cancelled.
🌐 Google Colab detected. Ray dashboard URL: https://8855-gpu-t4-hm-1lpb172kzqkzg-c.asia-southeast1-1.prod.colab.dev
🌐 Google Colab detected. Dispatcher URL: https://8851-gpu-t4-hm-1lpb172kzqkzg-c.asia-southeast1-1.prod.colab.dev


=== Preprocessing RAG Sources ===


RAG Source ID,Status,Duration,Details
1,Building,0.0s,"FAISS, GPU"
2,Building,0.0s,"FAISS, GPU"


KeyboardInterrupt: 

In [None]:
# Results table: one row per run. Each value of `results` is unpacked as a pair
# whose second item is that run's metrics mapping; metric entries stored as
# {"value": ...} dicts are unwrapped to the bare value, and the run id is added
# as a `run_id` column.
_result_rows = []
for _run_id, (_first_item, _run_metrics) in results.items():
    _merged = {**_run_metrics, 'run_id': _run_id}
    _flat_row = {}
    for _column, _entry in _merged.items():
        if isinstance(_entry, dict):
            _flat_row[_column] = _entry['value']
        else:
            _flat_row[_column] = _entry
    _result_rows.append(_flat_row)

pd.DataFrame(_result_rows)