# Advanced RAG for Scientific Papers

In [1]:
# import tensorflow as tf

# print("TF:", tf.__version__)
# print("Devices:", tf.config.list_physical_devices())
# print("GPUs:", tf.config.list_physical_devices("GPU"))


In [2]:
#uv pip install chromadb sentence-transformers

In [3]:
# --- 1. Load scientific papers from JSON ---
import json
import os

papers_dir = "../papers_json_3"

corpus = []
files = sorted([f for f in os.listdir(papers_dir) if f.endswith('.json')])[:200]

for filename in files:
    with open(os.path.join(papers_dir, filename), 'r', encoding='utf-8') as f:
        paper = json.load(f)
    corpus.append({
        "article_id": paper.get("article_id", filename.replace(".json", "")),
        "text": paper.get("abstract", "") + "\n\n" + paper.get("article", "")
    })

print(f"Loaded {len(corpus)} papers")

Loaded 200 papers


In [4]:
# --- 2. Improved Chunking Strategy (larger chunks, more overlap) ---
'''
Advanced chunking: chunk_size=450 words, overlap=100 words
Larger chunks preserve more context for scientific papers
'''
def chunk_text(text, chunk_size=450, overlap=100):
    words = text.split()
    if len(words) <= 100:
        return [text] if words else []
    chunks = []
    i = 0
    while i < len(words):
        chunk = " ".join(words[i:i + chunk_size])
        chunks.append(chunk)
        i += chunk_size - overlap
    return chunks

# Build chunks with metadata
chunk_texts, metadatas, ids = [], [], []

for paper in corpus:
    chunks = chunk_text(paper["text"])
    for idx, ch in enumerate(chunks):
        chunk_texts.append(ch)
        metadatas.append({"article_id": paper["article_id"], "chunk_idx": idx})
        ids.append(f'{paper["article_id"]}_chunk_{idx}')

print(f"Total chunks: {len(chunk_texts)}")

Total chunks: 3830


In [5]:
# --- 3. Initialize ChromaDB ---
import chromadb

client = chromadb.PersistentClient(path="scientific_rag_db")
collection = client.get_or_create_collection("scientific_papers")

In [6]:
# --- 4. Embed and index chunks with BAAI model ---
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Using BAAI/bge-base-en-v1.5 - state-of-the-art for retrieval
embedder = SentenceTransformer("BAAI/bge-base-en-v1.5")

if collection.count() == 0:
    print(f"Embedding and indexing {len(chunk_texts)} chunks...")
    
    # Embed and add in batches (ChromaDB has max batch size ~5000)
    batch_size = 500
    for i in tqdm(range(0, len(chunk_texts), batch_size), desc="Indexing"):
        batch_texts = chunk_texts[i:i + batch_size]
        batch_metas = metadatas[i:i + batch_size]
        batch_ids = ids[i:i + batch_size]
        
        batch_embs = embedder.encode(batch_texts, show_progress_bar=False).tolist()
        
        collection.add(
            documents=batch_texts,
            embeddings=batch_embs,
            metadatas=batch_metas,
            ids=batch_ids
        )
    
    print(f"Indexed {collection.count()} chunks")
else:
    print(f"Collection already has {collection.count()} chunks")

  from .autonotebook import tqdm as notebook_tqdm


Embedding and indexing 3830 chunks...


Indexing: 100%|██████████| 8/8 [01:10<00:00,  8.85s/it]

Indexed 3830 chunks





In [7]:
# --- 5. Basic Retrieval function ---
def retrieve(query, k=3):
    q_emb = embedder.encode([query]).tolist()[0]
    results = collection.query(query_embeddings=[q_emb], n_results=k)
    return results["documents"][0], results["metadatas"][0], results["ids"][0]

# Test retrieval
docs, metas, doc_ids = retrieve("random walk on networks")
for i, (doc, meta) in enumerate(zip(docs, metas)):
    print(f"[{i+1}] {meta['article_id']} (chunk {meta['chunk_idx']})")
    print(f"    {doc[:150]}...\n")

[1] article_1 (chunk 5)
    not be directly mapped on an equivalent walk on the aggregated graph . _ stationary probability distribution . _ starting from the one - step transiti...

[2] article_1 (chunk 0)
    efficient techniques to navigate networks with local information are fundamental to sample large - scale online social systems and to retrieve resourc...

[3] article_102 (chunk 42)
    and sons , 2008 , xv+352 pages . s. janson , t. uczak , t. turova , and t. vallier . bootstrap percolation on the random graph @xmath858 . , 22(5):198...



In [8]:
# --- 6. Load LLM ---
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Use GPU if available, otherwise CPU
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
    
print(f"Using device: {device}")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype="auto"
)
print(f"Model loaded: {model_name}")

`torch_dtype` is deprecated! Use `dtype` instead!


Using device: mps
Model loaded: Qwen/Qwen2.5-1.5B-Instruct


In [9]:
# --- 6. Multi-Query Rewriting ---
def generate_alternative_queries(question, n_queries=3):
    prompt = (
        f"Rewrite the following question into {n_queries} alternative search queries, "
        f"each on a new line:\n\nQuestion: {question}"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=150)
    text = tokenizer.decode(out[0], skip_special_tokens=True)

    lines = [l.strip("-• ").strip() for l in text.split("\n") if l.strip()]
    lines = [l for l in lines if question.lower() not in l.lower()]

    return lines[:n_queries]

In [10]:
# --- 7. Multi-Query Retrieval with Deduplication ---
def retrieve_multi_query(question, k_per_query=3, n_queries=3):
    mq = generate_alternative_queries(question, n_queries=n_queries)
    all_qs = [question] + mq

    seen = set()
    results_list = []

    for q in all_qs:
        q_emb = embedder.encode([q]).tolist()[0]
        result = collection.query(
            query_embeddings=[q_emb],
            n_results=k_per_query
        )

        for doc_id, doc, meta in zip(result["ids"][0], result["documents"][0], result["metadatas"][0]):
            if doc_id not in seen:
                seen.add(doc_id)
                results_list.append({"id": doc_id, "text": doc, "metadata": meta})

    return results_list

In [11]:
# --- 8. Context Building with Character Limits ---
def build_context_with_limits(docs, limit_chars=3500):
    context = []
    total = 0
    for i, d in enumerate(docs, 1):
        header = f"[{i}] Source: {d['metadata']['article_id']} (chunk {d['metadata']['chunk_idx']})"
        block = f"{header}\n{d['text'][:900]}\n"
        if total + len(block) > limit_chars:
            break
        context.append(block)
        total += len(block)
    return "\n".join(context)

In [12]:
# --- 9. Advanced RAG Answer Generation ---
def rag_answer(query, use_multi_query=True):
    if use_multi_query:
        retrieved = retrieve_multi_query(query)
        context = build_context_with_limits(retrieved)
        sources = [d['metadata']['article_id'] for d in retrieved]
    else:
        docs, metas, _ = retrieve(query)
        context_parts = []
        for i, (doc, meta) in enumerate(zip(docs, metas), 1):
            context_parts.append(f"[{i}] Source: {meta['article_id']}\n{doc[:900]}")
        context = "\n\n".join(context_parts)
        sources = [meta['article_id'] for meta in metas]

    prompt = f"""You are a helpful assistant answering based ONLY on the following scientific paper excerpts.
Use citations like [1], [2] referring to the sources provided.

==================== SOURCES ====================
{context}
=================================================

Question: {query}

Answer with citations:
"""
    tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**tokens, max_new_tokens=250, do_sample=False)
    answer = tokenizer.decode(output[0], skip_special_tokens=True)
    
    # Extract just the answer part
    if "Answer with citations:" in answer:
        answer = answer.split("Answer with citations:")[-1].strip()
    
    return answer, sources

In [13]:
# --- 10. Test Advanced RAG ---
test_queries = [
" What is a random walk in the context of a network? ",
" What is meant by a multiplex (multi-layer) network? ",
" What does the term stationary probability distribution refer to? ",
" What are scalar perturbations in cosmology? ",
" What physical system do the airline transportation networks represent in the corpus? ",

" How does a biased random walk differ from an unbiased random walk? ",
" What is the role of entropy rate in characterizing a random walk? ",
" Why are multi-layer networks considered more realistic than single-layer networks? ",
" What does gauge invariance ensure in cosmological perturbation theory? ",
" How does edge overlap affect diffusion on multiplex networks? ",

" How are extensive and intensive bias functions defined, and why do they differ in parameter scaling? ",
" In what way does the overlapping adjacency matrix differ from a simple aggregated network representation? ",
" How does inter-layer degree correlation influence the dispersiveness of biased random walks? ",
" Why does the extended electromagnetic vector field introduce additional scalar modes in cosmology? ",
" How do real-world multiplex airline networks demonstrate a trade-off between diffusion efficiency and robustness? "
]

for q in test_queries:
    print(f"QUESTION: {q}")
    answer, sources = rag_answer(q, use_multi_query=True)
    print(f"\nANSWER: {answer}")
    print(f"\nSOURCES: {', '.join(set(sources))}")
    print("=" * 80)

QUESTION:  What is a random walk in the context of a network? 


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



ANSWER: A random walk in the context of a network refers to a process where a walker moves from one node to another, typically according to some set of rules or probabilities, without any specific direction or goal. The movement is stochastic, meaning it follows a probabilistic path rather than a deterministic one. In the context of the provided excerpt, a random walk involves a walker moving from one node to another based on certain transition probabilities, which can be influenced by various factors such as the topology of the network and the properties of neighboring nodes. This type of movement is often used to model diffusion processes, exploratory behavior, or other phenomena within complex networks. [Source: article_1 (chunk 5)]

SOURCES: article_63, article_93, article_36, article_1
QUESTION:  What is meant by a multiplex (multi-layer) network? 

ANSWER: A multiplex (multi-layer) network refers to a network composed of multiple interconnected layers, each representing differen