# Retrieval Pipeline Orchestration + Reranking

**Goal**
Build a hybrid retrieval system that:

1. Pulls candidate results from your Qdrant vector store (semantic retrieval).

2. Optionally augments those with graph-based or metadata filters (hierarchical or relational layer).

3. Applies cross-encoder reranking to refine relevance before passing context to the LLM.

# 1: Setup & Connections

* Import required libraries.
* Load/connect to Qdrant (vector DB) and Memgraph (graph DB).
* Load the embedding model (for encoding queries) and the cross-encoder reranker (for re-scoring candidate documents).
* Define some configuration variables (collection name, top_k defaults).
* Print short status messages so you know everything connected/loaded correctly.

In [3]:
# Before running this notebook, ensure Notebooks 01–02 are done
# Qdrant + Memgraph must be running

import os
import json
from typing import List, Dict, Any
from pprint import pprint
import numpy as np
from tqdm import tqdm

# Qdrant
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

# Memgraph
from gqlalchemy import Memgraph

# NEW: BGE-M3 embedder
from FlagEmbedding import BGEM3FlagModel

# Cross-encoder for reranking
from sentence_transformers import CrossEncoder


# ----- Configuration -----
# Every connection setting can be overridden through an environment variable.

QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
MEMGRAPH_HOST = os.environ.get("MEMGRAPH_HOST", "localhost")
MEMGRAPH_PORT = int(os.environ.get("MEMGRAPH_PORT", "7687"))

# Qdrant collection that holds the embedded chunks.
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION", "enterprise_docs")

# Retrieval funnel: candidates pulled initially, results kept after reranking.
DEFAULT_TOP_K, FINAL_TOP_K = 50, 5


# ----- Connect to Qdrant -----
# Failures are reported, not raised: the client is set to None so later
# cells can detect the missing service and skip it.

try:
    qdrant_client = QdrantClient(url=QDRANT_URL)
    # get_collections() queries the server, so an unreachable instance fails here.
    collections = qdrant_client.get_collections()
    print(
        f"Connected to Qdrant at {QDRANT_URL}. "
        f"Collections: {[c.name for c in collections.collections]}"
    )
except Exception as err:
    qdrant_client = None
    print("Could not connect to Qdrant:", err)


# ----- Connect to Memgraph -----

try:
    memgraph = Memgraph(host=MEMGRAPH_HOST, port=MEMGRAPH_PORT)
    # Trivial round-trip query that confirms the connection actually works.
    test = list(memgraph.execute_and_fetch("RETURN 1 AS ok"))
    print(f"Connected to Memgraph at {MEMGRAPH_HOST}:{MEMGRAPH_PORT}")
except Exception as err:
    memgraph = None
    print("Could not connect to Memgraph:", err)


# ----- Load BGE-M3 Embedder -----

EMBEDDING_MODEL_NAME = "BAAI/bge-m3"

try:
    # use_fp16=False: load the model in full precision.
    embedder = BGEM3FlagModel(EMBEDDING_MODEL_NAME, use_fp16=False)
    print(f"Loaded BGE-M3 embedder: {EMBEDDING_MODEL_NAME}")
except Exception as err:
    embedder = None
    print("Could not load BGE-M3 embedder:", err)


# ----- Query embedding helper (BGE-M3) -----

def embed_query(text: str) -> List[float]:
    """Encode *text* with the global BGE-M3 ``embedder`` and return a unit-length dense vector.

    Args:
        text: Query string to embed.

    Returns:
        The dense embedding as a list of Python floats, scaled to unit L2 norm
        (returned unscaled if its norm is zero).

    Raises:
        RuntimeError: if the embedder failed to load (``embedder`` is None).
    """
    if embedder is None:
        # Fail with an actionable message instead of "'NoneType' has no attribute 'encode'".
        raise RuntimeError("BGE-M3 embedder is not loaded; cannot embed query.")

    out = embedder.encode(
        text,
        max_length=8192,
        return_dense=True,
        return_sparse=False,
        return_colbert_vecs=False,
    )
    # Coerce to an ndarray so the division below works even if a list is returned.
    vec = np.asarray(out["dense_vecs"], dtype=np.float32)
    norm = np.linalg.norm(vec)
    if norm > 0:
        vec = vec / norm
    return vec.tolist()


# ----- Load reranker -----
# The model name can be overridden with the RERANKER_MODEL environment variable.

RERANKER_MODEL_NAME = os.environ.get("RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
try:
    reranker = CrossEncoder(RERANKER_MODEL_NAME)
except Exception as err:
    reranker = None
    print("Could not load reranker:", err)
else:
    print(f"Loaded reranker: {RERANKER_MODEL_NAME}")


# ----- Summary -----
# One status line per component; a falsy (None) handle is reported as MISSING.

print(
    "\n--- Summary ---",
    f"Qdrant client: {'OK' if qdrant_client else 'MISSING'}",
    f"Memgraph client: {'OK' if memgraph else 'MISSING'}",
    f"Embedder: {'OK' if embedder else 'MISSING'} (BGE-M3)",
    f"Reranker: {'OK' if reranker else 'MISSING'} ({RERANKER_MODEL_NAME})",
    f"Collection in use: {COLLECTION_NAME}",
    f"Retrieval settings: DEFAULT_TOP_K={DEFAULT_TOP_K}, FINAL_TOP_K={FINAL_TOP_K}",
    sep="\n",
)


Connected to Qdrant at http://localhost:6333. Collections: ['enterprise_docs']
Connected to Memgraph at localhost:7687


Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

Loaded BGE-M3 embedder: BAAI/bge-m3
Loaded reranker: cross-encoder/ms-marco-MiniLM-L-6-v2

--- Summary ---
Qdrant client: OK
Memgraph client: OK
Embedder: OK (BGE-M3)
Reranker: OK (cross-encoder/ms-marco-MiniLM-L-6-v2)
Collection in use: enterprise_docs
Retrieval settings: DEFAULT_TOP_K=50, FINAL_TOP_K=5


# 2 Semantic Retrieval from Qdrant

Here we will:
1. Take a natural language query.
2. Encode it into a vector using your embedding model.
3. Search for the most semantically similar chunks (documents) in Qdrant.
4. Return and display the top K results, with metadata for inspection.

In [5]:
def semantic_retrieval(query: str, top_k: int = 5, collection_name: str = COLLECTION_NAME):
    """
    Retrieve the top-k chunks most similar to `query` from Qdrant.

    The query is embedded with `embed_query` (BGE-M3, unit-normalised) and
    searched with `qdrant_client.query_points`.

    Args:
        query: Natural-language query.
        top_k: Maximum number of points to return.
        collection_name: Qdrant collection to search.

    Returns:
        List of dicts with keys score, text, source, chunk_id and metadata
        (the full point payload). Empty when the Qdrant client is not connected.
    """
    if qdrant_client is None:
        print("⚠️ Qdrant client not connected — skipping semantic retrieval.")
        return []

    # --- 1) Embed query using BGE-M3 ---
    query_vector = embed_query(query)

    # --- 2) Qdrant vector search using the query_points API ---
    results = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=top_k,
        with_payload=True,
        with_vectors=False,
    )

    # --- 3) Format results ---
    formatted = []
    for point in results.points:
        # Qdrant allows points without a payload; treat that as an empty payload.
        payload = point.payload or {}
        formatted.append({
            "score": point.score,
            "text": payload.get("text", ""),
            "source": payload.get("source", ""),
            "chunk_id": payload.get("chunk_id", ""),
            "metadata": payload,
        })

    return formatted

# ===========================================
# Example Query — try one from your dataset
# Replace this with a question relevant to your data.
sample_query = "What does 6G offer?"

# Run the retrieval function
retrieved_docs = semantic_retrieval(sample_query, top_k=5)

# Display retrieved documents.
# semantic_retrieval() does not return a top-level "page" key, so fall back to
# the raw payload kept under "metadata" (populated only if ingestion stored a page).
print(f"\nTop {len(retrieved_docs)} results for query: '{sample_query}'\n")
for i, doc in enumerate(retrieved_docs, start=1):
    page = doc.get("page", (doc.get("metadata") or {}).get("page"))
    print(f"Result {i}: (score={doc['score']:.4f})")
    print(f"Source: {doc['source']} | Page: {page}")
    print(f"Snippet: {doc['text'][:300]}...\n")


Top 5 results for query: 'What does 6G offer?'

Result 1: (score=0.6629)
Source: ./data/test.pdf | Page: None
Snippet: . The 6 G vision is to create a seamless reality where the physical and digital worlds, so far separated, are converged. This will enable seamless movement in a cyberphysical continuum of a connected physical world of senses, actions, and experiences, and its programmable digital representation. Wit...

Result 2: (score=0.6524)
Source: ./data/test.pdf | Page: None
Snippet: . II and overviewing the use cases (UC) that are expected to drive a digital and societal revolution in Sec. III. This is followed by introducing the paradigm shifts that formulate an evolved network architecture in Sec. IV. In Sec. V, we highlight the main 6 G technologies needed to realize the vis...

Result 3: (score=0.6511)
Source: ./data/test.pdf | Page: None
Snippet: . In addition, the vision of 6 G is to also create more humanfriendly, sustainable, and efficient communities. This requires net

# 3 Hybrid Retrieval (Vector + Graph)
This next cell implements a hybrid retrieval layer that merges: 
1. **Semantic relevance**
2. **Graph-based context**

This hybrid approach improves enterprise knowledge retrieval by connecting related entities instead of relying exclusively on text similarity.

In [9]:
# Fixed semantic_retrieval + hybrid_retrieval_pipeline
# Uses: BGE-M3 embedder and new Qdrant client API (query_points)
# -------------------------

from typing import List, Dict, Any
import numpy as np

def semantic_retrieval(query: str, top_k: int = 20, collection_name: str = COLLECTION_NAME) -> List[Dict[str, Any]]:
    """
    Retrieve top-k vector matches from Qdrant using a BGE-M3 query embedding.

    Uses `qdrant_client.query_points` (newer client API) and falls back to
    `qdrant_client.search` when the client has no `query_points` method.

    Args:
        query: Natural-language query.
        top_k: Maximum number of points to return.
        collection_name: Qdrant collection to search.

    Returns:
        List of dicts with keys: score, text, source, chunk_id, metadata (the
        full payload) and source_type (always "vector"). Empty when the Qdrant
        client or the embedder is unavailable.
    """
    if qdrant_client is None:
        print("⚠️ Qdrant client not connected — skipping semantic retrieval.")
        return []

    if embedder is None:
        print("⚠️ Embedder not loaded — cannot run semantic retrieval.")
        return []

    # 1) Embed the query with the shared helper (BGE-M3 dense vector, unit-normalised)
    #    so every vector query in the notebook is built the same way.
    qvec = embed_query(query)

    # 2) Query Qdrant; older client versions only expose search().
    try:
        resp = qdrant_client.query_points(
            collection_name=collection_name,
            query=qvec,
            limit=top_k,
            with_payload=True,
            with_vectors=False,
        )
    except AttributeError:
        resp = qdrant_client.search(
            collection_name=collection_name,
            query_vector=qvec,
            limit=top_k,
            with_payload=True,
        )

    # 3) query_points returns an object with .points; search() returns a list directly.
    points = getattr(resp, "points", None)
    if points is None:
        points = resp

    formatted: List[Dict[str, Any]] = []
    for p in points:
        # Points may be client objects or plain dicts depending on the client version;
        # read score AND payload from the same shape.
        if isinstance(p, dict):
            score = p.get("score")
            payload = p.get("payload")
        else:
            score = getattr(p, "score", None)
            payload = getattr(p, "payload", None)
        payload = payload or {}

        formatted.append({
            "score": float(score) if score is not None else 0.0,
            "text": payload.get("text", ""),
            "source": payload.get("source", ""),
            "chunk_id": payload.get("chunk_id", ""),
            "metadata": payload,
            "source_type": "vector",
        })

    return formatted


def graph_retrieval(query: str, memgraph_conn: Memgraph, limit: int = 10) -> List[Dict[str, Any]]:
    """
    Fetch Chunk nodes from Memgraph whose text contains `query` (case-insensitive).

    The user's query is bound as a Cypher parameter ($query) instead of being
    formatted into the statement, so quotes or Cypher syntax in the query can
    neither break nor alter it.

    Note: CONTAINS tests the *entire* query string as a substring, so a full
    natural-language question only matches chunks that contain it verbatim.

    Args:
        query: Text to search for.
        memgraph_conn: Open Memgraph connection, or None to skip graph retrieval.
        limit: Maximum number of chunks to return.

    Returns:
        Dicts with the same keys as semantic_retrieval (score, text, source,
        chunk_id, metadata, source_type="graph"). Every hit gets a fixed score
        of 0.50. Empty when there is no connection or the query fails.
    """
    if memgraph_conn is None:
        return []

    # Only the int-coerced LIMIT is formatted in; user text travels as $query.
    cypher_query = f"""
    MATCH (c:Chunk)
    WHERE toLower(c.text) CONTAINS toLower($query)
    RETURN c.text AS text, c.source AS source, c.chunk_id AS chunk_id
    LIMIT {int(limit)};
    """

    try:
        rows = list(memgraph_conn.execute_and_fetch(cypher_query, parameters={"query": query}))
    except Exception as e:
        print("Graph retrieval failed:", e)
        return []

    formatted = []
    for r in rows:
        formatted.append({
            "score": 0.50,
            # `or ""` also covers properties that exist but are null.
            "text": r.get("text") or "",
            "source": r.get("source") or "",
            "chunk_id": r.get("chunk_id") or "",
            "metadata": dict(r),
            "source_type": "graph",
        })
    return formatted


def hybrid_retrieval_pipeline(query: str, top_k_semantic: int = DEFAULT_TOP_K, top_k_final: int = FINAL_TOP_K) -> List[Dict[str, Any]]:
    """
    Combined hybrid retrieval:
      1) Vector retrieval (BGE-M3 / Qdrant), top `top_k_semantic` hits
      2) Graph retrieval (Memgraph), up to `top_k_semantic` hits
      3) Merge, deduplicate by chunk_id (keeping the higher-scoring copy),
         and return the `top_k_final` best items by score

    Items without a chunk_id get a synthetic "<source>_<n>" key.
    """
    sem = semantic_retrieval(query, top_k=top_k_semantic)
    gr = graph_retrieval(query, memgraph_conn=memgraph, limit=top_k_semantic)

    merged: Dict[str, Dict[str, Any]] = {}
    for item in sem + gr:
        cid = item.get("chunk_id") or (item.get("metadata") or {}).get("chunk_id")
        if not cid:
            # `or ""` guards against a null source (e.g. a Chunk node without one),
            # which previously raised TypeError on string concatenation.
            cid = f"{item.get('source') or ''}_{len(merged)}"

        best = merged.get(cid)
        if best is None or item.get("score", 0.0) > best.get("score", 0.0):
            merged[cid] = item

    ranked = sorted(merged.values(), key=lambda x: x.get("score", 0.0), reverse=True)

    # Trim to final top_k_final
    return ranked[:top_k_final]


# -------------------------
# Quick smoke test of the hybrid pipeline — swap sample_query for your own question.
# -------------------------
sample_query = "What does 6G offer?"
hybrid_results = hybrid_retrieval_pipeline(
    sample_query,
    top_k_semantic=20,  # candidates requested from each retriever
    top_k_final=10,     # merged results kept
)

print(f"Retrieved {len(hybrid_results)} hybrid candidates for: {sample_query}\n")
for i, doc in enumerate(hybrid_results, start=1):
    print(
        f"Result {i}: [{doc.get('source_type')}] "
        f"score={doc.get('score'):.4f} source={doc.get('source')}"
    )
    # Flatten newlines so each snippet prints on one line, followed by a divider.
    print(doc.get("text", "")[:300].replace("\n", " "), "\n" + "-" * 80)

Retrieved 10 hybrid candidates for: What does 6G offer?

Result 1: [vector] score=0.6629 source=./data/test.pdf
. The 6 G vision is to create a seamless reality where the physical and digital worlds, so far separated, are converged. This will enable seamless movement in a cyberphysical continuum of a connected physical world of senses, actions, and experiences, and its programmable digital representation. Wit 
--------------------------------------------------------------------------------
Result 2: [vector] score=0.6524 source=./data/test.pdf
. II and overviewing the use cases (UC) that are expected to drive a digital and societal revolution in Sec. III. This is followed by introducing the paradigm shifts that formulate an evolved network architecture in Sec. IV. In Sec. V, we highlight the main 6 G technologies needed to realize the vis 
--------------------------------------------------------------------------------
Result 3: [vector] score=0.6511 source=./data/test.pdf
. In additio

# 4 Semantic Reranking

(after hybrid retrieval)

This cell takes the candidates retrieved in Cell 3 and reorders them with a cross-encoder, which scores each (query, chunk) pair jointly to estimate its relevance rather than comparing independently computed embeddings.

In [10]:
# Robust Cross-Encoder Reranking cell
# Inputs:
#   - query (str)
#   - hybrid_results (list of dicts, each must have at least 'text' and ideally 'score' and 'chunk_id')
# Outputs:
#   - reranked_results (list of dicts) with fields: rerank_score, final_score, original metadata

from sentence_transformers import CrossEncoder
from typing import List, Dict, Any
import numpy as np
from tqdm import tqdm

# ---------- Config ----------
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
RERANKER_BATCH_SIZE = 16     # lower this (e.g., 4) on small-CPU machines
MAX_CHARS_FOR_RERANKER = 2000  # truncate long passages to this many chars
ALPHA = 0.85                 # final_score = ALPHA * rerank_score + (1-ALPHA) * original_score
# ----------------------------

# Load reranker if not loaded already
try:
    # If you already have 'reranker' in the namespace, reuse it
    reranker
except NameError:
    try:
        print(f"Loading reranker model: {RERANKER_MODEL} ...")
        reranker = CrossEncoder(RERANKER_MODEL)
    except Exception as e:
        reranker = None
        print("Could not load reranker model:", e)

def _safe_text(txt: str, max_chars: int = MAX_CHARS_FOR_RERANKER) -> str:
    """Make sure text is a string and truncate it preserving head+tail if too long."""
    if not isinstance(txt, str):
        return ""
    txt = txt.strip()
    if len(txt) <= max_chars:
        return txt
    half = max_chars // 2
    return txt[:half] + "\n\n[...] \n\n" + txt[-half:]

def rerank_results(query: str, results: List[Dict[str, Any]], top_n: int = 5) -> List[Dict[str, Any]]:
    """
    Rerank `results` against `query` with the global CrossEncoder `reranker`.

    Each returned item is a shallow copy of an input dict with two keys added:
      - 'rerank_score': raw cross-encoder score (None when no reranker is loaded)
      - 'final_score':  ALPHA * rerank_score + (1 - ALPHA) * score, or just the
                        original 'score' when no reranker is loaded

    Args:
        query: The user question.
        results: Candidate dicts; 'text' is scored, and 'score' (missing or
            None counts as 0.0) is blended in.
        top_n: Maximum number of items to return.

    Returns:
        Up to `top_n` items sorted by 'final_score', highest first.

    Raises:
        RuntimeError: if the reranker returns a different number of scores
            than there were candidates.
    """
    if not results:
        return []

    if reranker is None:
        # Degrade gracefully, but emit the same keys as the normal path so
        # callers can always read 'final_score'.
        print("Reranker unavailable — returning results sorted by original score.")
        fallback = []
        for r in results:
            item = r.copy()
            item["rerank_score"] = None
            item["final_score"] = float(item.get("score") or 0.0)
            fallback.append(item)
        fallback.sort(key=lambda x: x["final_score"], reverse=True)
        return fallback[:top_n]

    # (query, passage) pairs; passages are truncated to bound the model input.
    pairs = [(query, _safe_text(r.get("text", ""))) for r in results]

    # Predict in batches to bound peak memory.
    rerank_scores: List[float] = []
    for start in tqdm(range(0, len(pairs), RERANKER_BATCH_SIZE), desc="Reranker batches"):
        batch = pairs[start : start + RERANKER_BATCH_SIZE]
        try:
            batch_scores = reranker.predict(batch, show_progress_bar=False)
        except TypeError:
            # Last-resort fallback for reranker objects that reject tuple pairs:
            # score "query<TAB>text" strings instead.
            # NOTE(review): a standard CrossEncoder should not reach this path — confirm it is still needed.
            batch_joined = [q + "\t" + t for q, t in batch]
            batch_scores = reranker.predict(batch_joined, show_progress_bar=False)
        rerank_scores.extend(np.asarray(batch_scores, dtype=float).tolist())

    # Scores are matched to candidates by position, so the counts must agree.
    if len(rerank_scores) != len(results):
        raise RuntimeError(
            f"Reranker returned {len(rerank_scores)} scores for {len(results)} candidates."
        )

    reranked = []
    for r, score in zip(results, rerank_scores):
        item = r.copy()
        item["rerank_score"] = float(score)
        orig_score = float(item.get("score") or 0.0)
        item["final_score"] = float(ALPHA * item["rerank_score"] + (1.0 - ALPHA) * orig_score)
        reranked.append(item)

    reranked.sort(key=lambda x: x["final_score"], reverse=True)
    return reranked[:top_n]


# ---------------------------
# Example usage: rerank the candidates produced by hybrid_retrieval_pipeline
# ---------------------------
query = "What does 6G offer?"
try:
    reranked_results = rerank_results(query, hybrid_results, top_n=5)
    print(f"Top {len(reranked_results)} reranked results:")
    for i, r in enumerate(reranked_results, start=1):
        print(
            f"\n{i}. chunk_id={r.get('chunk_id', 'n/a')} "
            f"source_type={r.get('source_type', '?')}"
        )
        print(
            f"   final_score={r['final_score']:.4f} "
            f"rerank_score={r['rerank_score']:.4f} "
            f"orig_score={r.get('score', 0.0):.4f}"
        )
        print("   snippet:", r.get("text", "")[:300].replace("\n", " "), "...")
except Exception as exc:
    # Any failure (e.g. hybrid_results not defined yet) is reported, not raised.
    print("Reranking failed:", exc)

Reranker batches: 100%|██████████| 1/1 [00:00<00:00,  3.14it/s]

Top 5 reranked results:

1. chunk_id=test_chunk8 source_type=vector
   final_score=3.2346 rerank_score=3.6906 orig_score=0.6511
   snippet: . In addition, the vision of 6 G is to also create more humanfriendly, sustainable, and efficient communities. This requires networks that guarantee worldwide digital inclusion to support a wide range of elements, end-to-end (E 2 E) life-cycle tracking to reduce waste and automate recycling, resourc ...

2. chunk_id=test_chunk7 source_type=vector
   final_score=2.5895 rerank_score=2.9295 orig_score=0.6629
   snippet: . The 6 G vision is to create a seamless reality where the physical and digital worlds, so far separated, are converged. This will enable seamless movement in a cyberphysical continuum of a connected physical world of senses, actions, and experiences, and its programmable digital representation. Wit ...

3. chunk_id=test_chunk19 source_type=vector
   final_score=1.6258 rerank_score=1.8077 orig_score=0.5954
   snippet: . To this end, th




In [None]:
# Mini evaluation dataset
# -----------------------
# Each entry pairs a query with lowercase "gold" keywords that a relevant
# chunk is expected to mention:
#   {"query": "...", "gold_keywords": ["word1", "word2", ...]}

evaluation_set = [
    {
        "query": "What does 6G offer?",
        "gold_keywords": [
            "sub-thz", "terahertz", "ai-native", "ai native",
            "extreme bandwidth", "ultra-reliable", "low-latency",
            "massive connectivity", "holographic", "ris", "b5g",
        ],
    },
    {
        "query": "What is RIS technology?",
        "gold_keywords": [
            "reconfigurable intelligent surface", "ris",
            "metasurface", "reflective element", "beamforming",
        ],
    },
    {
        "query": "What challenges does 5G face?",
        "gold_keywords": [
            "latency", "energy efficiency", "spectrum",
            "massive iot", "ultra reliable", "coverage",
        ],
    },
]

print(f"Loaded {len(evaluation_set)} evaluation samples.")

Loaded 3 evaluation samples.


In [None]:
def compute_keyword_f1(text: str, keywords: list) -> float:
    """
    Score how many gold `keywords` occur in `text`, as an F1-style value in [0, 1].

    A keyword counts as a hit only when it appears as a whole word or phrase
    (case-insensitive): it must not be directly preceded or followed by a
    letter, digit or underscore. This stops short keywords such as "ris" from
    matching inside unrelated words like "characteristics".

    With a single retrieved chunk there is no meaningful count of retrieved
    non-gold keywords, so precision is 1.0 whenever at least one keyword hits
    (0.0 otherwise) and the score reduces to 2 * recall / (1 + recall).

    Args:
        text: Retrieved chunk text.
        keywords: Gold keywords or phrases.

    Returns:
        0.0 for empty text, no keywords or no hits; otherwise the F1 value.
    """
    import re  # local import: the notebook only imports `re` in a later cell

    if not text or not keywords:
        return 0.0

    text_lower = text.lower()

    # Whole-word/phrase hits; re.escape keeps "-" and other symbols literal.
    hits = sum(
        1
        for k in keywords
        if re.search(r"(?<!\w)" + re.escape(k.lower()) + r"(?!\w)", text_lower)
    )
    if hits == 0:
        return 0.0

    precision = 1.0
    recall = hits / len(keywords)
    return 2 * precision * recall / (precision + recall)

def evaluate_pipeline(evaluation_set, top_n=5):
    """
    Benchmark three retrieval variants on every sample of `evaluation_set`.

    For each sample, the top-ranked chunk of each variant is scored with
    `compute_keyword_f1` against the sample's `gold_keywords`:
      1) semantic retrieval only
      2) hybrid (vector + graph) retrieval
      3) hybrid retrieval followed by cross-encoder reranking

    Returns:
        One dict per sample with keys query, semantic_f1, hybrid_f1 and
        reranked_f1.
    """

    def _top_text(hits):
        # Text of the best-ranked hit, or "" when nothing was retrieved.
        return hits[0]["text"] if hits else ""

    scores = []
    for sample in evaluation_set:
        query, gold = sample["query"], sample["gold_keywords"]
        print(f"\n=== Evaluating: {query} ===")

        f1_sem = compute_keyword_f1(_top_text(semantic_retrieval(query, top_k=top_n)), gold)
        print(f"Semantic F1: {f1_sem:.3f}")

        hybrid = hybrid_retrieval_pipeline(query, top_k_semantic=top_n)
        f1_hybrid = compute_keyword_f1(_top_text(hybrid), gold)
        print(f"Hybrid F1:   {f1_hybrid:.3f}")

        f1_rerank = compute_keyword_f1(_top_text(rerank_results(query, hybrid, top_n=top_n)), gold)
        print(f"Reranked F1: {f1_rerank:.3f}")

        scores.append({
            "query": query,
            "semantic_f1": f1_sem,
            "hybrid_f1": f1_hybrid,
            "reranked_f1": f1_rerank,
        })

    return scores

In [18]:
# Run the three-way benchmark; the bare `results` expression displays it in the notebook.
results = evaluate_pipeline(evaluation_set=evaluation_set, top_n=5)
results


=== Evaluating: What does 6G offer? ===
Semantic F1: 0.167
Hybrid F1:   0.167


Reranker batches: 100%|██████████| 1/1 [00:00<00:00,  9.99it/s]

Reranked F1: 0.000

=== Evaluating: What is RIS technology? ===





Semantic F1: 0.000
Hybrid F1:   0.000


Reranker batches: 100%|██████████| 1/1 [00:00<00:00, 10.92it/s]

Reranked F1: 0.000

=== Evaluating: What challenges does 5G face? ===





Semantic F1: 0.500
Hybrid F1:   0.500


Reranker batches: 100%|██████████| 1/1 [00:00<00:00, 11.98it/s]

Reranked F1: 0.500





[{'query': 'What does 6G offer?',
  'semantic_f1': 0.16666666666666669,
  'hybrid_f1': 0.16666666666666669,
  'reranked_f1': 0.0},
 {'query': 'What is RIS technology?',
  'semantic_f1': 0.0,
  'hybrid_f1': 0.0,
  'reranked_f1': 0.0},
 {'query': 'What challenges does 5G face?',
  'semantic_f1': 0.5,
  'hybrid_f1': 0.5,
  'reranked_f1': 0.5}]

In [21]:
def debug_keywords(text, keywords):
    """Print, for each keyword, whether it occurs in `text` (case-insensitive substring test)."""
    haystack = text.lower()
    for keyword in keywords:
        status = "FOUND" if keyword.lower() in haystack else "not found"
        print(f"{keyword:30} → {status}")

# Inspect which gold keywords of the first query appear in the top hybrid hit.
debug_keywords(text=hybrid_results[0]["text"], keywords=evaluation_set[0]["gold_keywords"])

sub-thz                        → not found
terahertz                      → not found
ai-native                      → not found
ai native                      → not found
extreme bandwidth              → not found
ultra reliable                 → not found
low latency                    → not found
massive connectivity           → not found
holographic                    → not found
ris                            → FOUND
b5g                            → not found


# 5. Final Context + Citation Packaging
In this last cell we take the relevance-boosted (reranked) results and package them into a context dictionary for the LLM: the merged context text plus its citations.

In [24]:
# -------------------------------------------------------------
# Cell 5 (Fixed): Final Context + Citations Packaging (BGE-M3)
# -------------------------------------------------------------

import re

def clean_and_deduplicate_structured(items: list):
    """
    Drop items whose text duplicates an earlier item's text, keeping the first.

    Texts are compared after stripping, lower-casing and collapsing runs of
    whitespace. Whole items are kept, so text, source, rerank score and other
    metadata stay aligned. A missing or None text counts as "".
    """
    seen = set()
    unique = []

    for item in items:
        key = re.sub(r"\s+", " ", (item.get("text") or "").strip().lower())
        if key in seen:
            continue
        seen.add(key)
        unique.append(item)

    return unique


def assemble_context(reranked_results: list, max_chars: int = 8000):
    """
    Merge reranked chunks into a single LLM context block with citations.

    Chunks are deduplicated with `clean_and_deduplicate_structured`, then
    appended in input order (empty chunks skipped), separated by
    "\\n\\n---\\n\\n", until the next chunk would push `merged_context` past
    `max_chars`; assembly stops at that point.

    Returns:
        dict with
          - "merged_context": the joined chunk texts (no trailing separator,
            at most `max_chars` characters)
          - "citations": one entry per chunk included in "merged_context",
            with "source", "chunk_preview" and "rerank_score"
          - "raw_chunks": texts of *all* deduplicated items, including any
            that did not fit the budget
    """
    separator = "\n\n---\n\n"
    unique_items = clean_and_deduplicate_structured(reranked_results)

    parts = []
    citations = []
    used = 0

    for item in unique_items:
        text = (item.get("text") or "").strip()
        if not text:
            continue

        # Length added to the final context: the separator only goes BETWEEN chunks.
        added = len(text) + (len(separator) if parts else 0)
        if used + added > max_chars:
            break

        parts.append(text)
        used += added
        citations.append({
            "source": item.get("source", ""),
            # Only mark the preview as elided when something was actually cut.
            "chunk_preview": text if len(text) <= 120 else text[:120] + "...",
            "rerank_score": item.get("rerank_score"),
        })

    return {
        "merged_context": separator.join(parts),
        "citations": citations,
        "raw_chunks": [item.get("text") or "" for item in unique_items],
    }


# ---------------------------
# Smoke-test context assembly on the reranked results from the previous cell
# ---------------------------
context_for_llm = assemble_context(reranked_results, max_chars=8000)

print("Final Context Block Ready\n")
# Show only the first 800 characters to keep the notebook output short.
print(context_for_llm["merged_context"][:800])

print("\nCitations:")
for c in context_for_llm["citations"]:
    print("-", c["source"], "|", c["chunk_preview"])

Final Context Block Ready

. In addition, the vision of 6 G is to also create more humanfriendly, sustainable, and efficient communities. This requires networks that guarantee worldwide digital inclusion to support a wide range of elements, end-to-end (E 2 E) life-cycle tracking to reduce waste and automate recycling, resource-efficient connected agriculture, universal access to digital healthcare, etc. This requires embedded autonomous sensors and actuators, worldwide coverage with outstanding energy-, material-, and cost-efficiency, as well as a network platform with high availability and security [3]. III

---

. The 6 G vision is to create a seamless reality where the physical and digital worlds, so far separated, are converged. This will enable seamless movement in a cyberphysical continuum of a connected physi

Citations:
- ./data/test.pdf | . In addition, the vision of 6 G is to also create more humanfriendly, sustainable, and efficient communities. This requ...
- ./data/test.pd

In [None]:
# Sanity check: confirm the context cell ran before moving on to the LLM section.
print("context_for_llm exists:", "context_for_llm" in globals())

if "context_for_llm" in globals():
    print(context_for_llm.keys())

# Notebook 4
From this point on, this is what used to be Notebook 04. Because of kernel and .venv issues it had to be moved here; it will probably be moved back once a solution is found.

# LLM Answering Pipeline
This section is responsible for:

* Running a local, open-source LLM
* Constructing the final prompt (query + context)
* Generating an answer
* Attaching citations
* Returning a polished response

To keep the stack fully open source, we use:

Ollama (a simple, free runner for local LLMs)

Model: llama3.1 by default (set the OLLAMA_MODEL environment variable to swap in a different model, e.g. to improve answer quality)

# NB04.1 Imports + Global Settings

In [None]:
# ---------------------------------------------------------------
# Notebook 04 — Cell 1
# Imports + Global Settings
# ---------------------------------------------------------------

import json
import os
from datetime import datetime
from typing import Dict

# Optional: pretty printing
import textwrap

# Load Ollama local LLM client (try the official package first, fall back to HTTP)
# Preferred: `pip install ollama`


class SimpleOllamaClient:
    """Minimal HTTP client for Ollama's `/api/generate` endpoint.

    Mirrors the `chat(model=..., messages=...)` call shape of the official
    `ollama` package and returns `{"message": {"content": <str>}}`.
    `requests` is imported lazily inside `chat`, so the class can be defined
    even when the package is missing.
    """

    def __init__(self, host: str = "http://localhost:11434", timeout: float = 60):
        self.host = host.rstrip("/")
        # Seconds to wait for a complete (non-streamed) generation.
        self.timeout = timeout

    @staticmethod
    def _choice_text(first):
        """Extract text from the first element of an OpenAI-style `choices` list."""
        if first.get("text"):
            return first["text"]
        message = first.get("message")
        if isinstance(message, str):
            return message
        if isinstance(message, dict):
            return message.get("content")
        return None

    def _normalize(self, data):
        """Map known response shapes onto `{"message": {"content": ...}}`.

        Checked in order: Ollama's native "response" key, an already-normalised
        "message.content", a bare "text" key, and an OpenAI-like "choices"
        list. Anything else is stringified.
        """
        if isinstance(data, dict):
            if "response" in data:
                return {"message": {"content": data["response"]}}
            message = data.get("message")
            if isinstance(message, dict) and "content" in message:
                return data
            if "text" in data:
                return {"message": {"content": data["text"]}}
            choices = data.get("choices")
            if isinstance(choices, list) and choices:
                first = choices[0]
                if isinstance(first, dict) and ("text" in first or "message" in first):
                    return {"message": {"content": self._choice_text(first)}}
        return {"message": {"content": str(data)}}

    def chat(self, model: str, messages: list):
        """Send the user messages as one prompt to `/api/generate` (non-streaming).

        Only messages with role "user" are forwarded (joined by newlines);
        other roles are dropped because `/api/generate` takes a single prompt.

        Returns:
            `{"message": {"content": <str>}}`.

        Raises:
            RuntimeError: on HTTP 404 or when the server cannot be reached.
            requests.HTTPError: for other HTTP error statuses.
        """
        import requests
        from requests.exceptions import HTTPError, RequestException

        prompt_text = "\n".join(
            m.get("content", "") for m in messages if m.get("role") == "user"
        ).strip()

        url = f"{self.host}/api/generate"
        payload = {"model": model, "prompt": prompt_text, "stream": False}

        try:
            resp = requests.post(url, json=payload, timeout=self.timeout)
            resp.raise_for_status()
        except HTTPError as http_err:
            if getattr(http_err.response, "status_code", None) == 404:
                # Ollama answers 404 when the requested model is not available locally.
                raise RuntimeError(
                    f"Ollama returned 404 for {url}. The model '{model}' may not be pulled "
                    f"(try `ollama pull {model}`), or {self.host} is not an Ollama server. "
                    f"Response: {getattr(http_err.response, 'text', '')}"
                ) from http_err
            raise
        except RequestException as req_err:  # named req_err, not `re`, to avoid shadowing the regex module
            raise RuntimeError(
                f"Failed to reach Ollama at {self.host}: {req_err}. Ensure the daemon is running and reachable."
            ) from req_err

        try:
            data = resp.json()
        except ValueError:
            # Non-JSON response
            return {"message": {"content": resp.text}}

        return self._normalize(data)


class _StubLLM:
    """Placeholder `llm` whose `chat` always raises; used when no client is available."""

    def chat(self, *args, **kwargs):
        raise RuntimeError("No Ollama client available. Install `ollama` or `requests` and start the Ollama daemon.")


llm = None
try:
    from ollama import Client as OllamaClient
    llm = OllamaClient(host="http://localhost:11434")
    print("Local LLM client ready (ollama package).")
except Exception:
    # Fallback: lightweight HTTP client, usable whenever `requests` is installed.
    try:
        import requests  # availability check only; SimpleOllamaClient imports it lazily
        llm = SimpleOllamaClient(host="http://localhost:11434")
        print("Using fallback HTTP Ollama client (requests).")
    except Exception:
        print("Could not import `ollama` package or use HTTP fallback.")
        print("To enable the local LLM client, either:")
        print("  1) pip install ollama")
        print("  2) ensure Ollama daemon is running at http://localhost:11434 and install requests (`pip install requests`)")
        print("Falling back to a stub llm object that raises if used.")
        llm = _StubLLM()


In [None]:
# ---------------------------------------------------------------
# Notebook 04 — Cell 2
# Prompt Template + Formatting Function
# ---------------------------------------------------------------

def build_rag_prompt(query: str, context_data: Dict):
    """
    Assemble the grounded-QA prompt that is sent to the local LLM.

    The prompt consists of, in order:
    - a preamble restricting the model to the supplied context,
    - ``context_data["merged_context"]`` framed by dashed rules,
    - answering rules (no hallucination, ``[source]`` citations, a fixed
      "not available" reply when the context lacks the answer),
    - the user's question followed by an answer cue.

    Lines are joined with "\\n"; the result has no leading or trailing
    whitespace. Raises KeyError if ``context_data`` lacks "merged_context".
    """
    preamble = [
        "You are a helpful assistant answering questions using ONLY the context provided.",
        "",
    ]
    context_section = [
        "CONTEXT:",
        "-------",
        f"{context_data['merged_context']}",
        "-------",
        "",
    ]
    rules_section = [
        "INSTRUCTIONS:",
        "- Use ONLY the information in the context.",
        "- Do NOT hallucinate.",
        "- Cite your sources using this format: [source].",
        '- If the answer is not in the context, say: "The answer is not available in the provided context."',
        "",
    ]
    question_section = [
        "USER QUESTION:",
        f"{query}",
        "",
        "FINAL ANSWER (with citations):",
    ]

    return "\n".join(preamble + context_section + rules_section + question_section)


In [None]:
# ---------------------------------------------------------------
# Notebook 04 — Cell 3
# Generate Answer using Local LLM (Ollama)
# ---------------------------------------------------------------

def _extract_llm_text(response: Any):
    """
    Pull the generated text out of an LLM client response, or return None.

    Recognised shapes, checked in this order:
    - non-dict objects exposing ``.message.content`` as attributes (newer
      versions of the ``ollama`` package return pydantic models, not dicts);
    - dicts with ``message.content`` (what the fallback HTTP client's
      ``_normalize`` is expected to produce);
    - dicts with a top-level ``response`` or ``text`` field;
    - OpenAI-style dicts with a non-empty ``choices`` list whose first entry
      carries ``text``, a string ``message``, or ``message.content``.
    """
    if not isinstance(response, dict):
        message = getattr(response, "message", None)
        text = getattr(message, "content", None)
        return text if isinstance(text, str) else None

    message = response.get("message")
    if isinstance(message, dict) and "content" in message:
        return message["content"]
    if "response" in response:
        return response["response"]
    if "text" in response:
        return response["text"]

    choices = response.get("choices")
    if isinstance(choices, list) and choices:
        first = choices[0]
        if isinstance(first, dict):
            text = first.get("text")
            if text:
                return text
            # No usable "text": fall back to the chat-style "message" field.
            first_message = first.get("message")
            if isinstance(first_message, str):
                return first_message
            if isinstance(first_message, dict):
                return first_message.get("content")
    return None


def generate_llm_answer(query: str, context_for_llm: Dict, model: str = None):
    """
    Send the RAG prompt for ``query`` to the local LLM and return its answer text.

    Args:
        query: The user question.
        context_for_llm: Retrieval output; must contain a "merged_context"
            key (read by ``build_rag_prompt``).
        model: Model name passed to ``llm.chat``. When None, uses the
            OLLAMA_MODEL env var, defaulting to "llama3.1:latest".

    Returns:
        The generated text. If no known text field is found in the client's
        response, the stringified response is returned instead.

    Raises:
        RuntimeError: re-raised unchanged when the client raises one (e.g. the
            fallback client's 404 hint or the stub client), or wrapping any
            other exception from the call.
    """
    if model is None:
        model = os.getenv("OLLAMA_MODEL", "llama3.1:latest")

    prompt = build_rag_prompt(query, context_for_llm)

    try:
        response = llm.chat(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
    except RuntimeError:
        # Client errors of this type already carry an actionable message.
        raise
    except Exception as e:
        raise RuntimeError(f"LLM call failed: {e}") from e

    content = _extract_llm_text(response)
    if content is None:
        # Unknown response shape: surface the raw response rather than fail.
        content = str(response)
    return content


# ---- Test the answering pipeline (only if variables are defined) ----
# Requires both `query` and `context_for_llm` from Notebook 03 in this kernel;
# checking only one of them would raise NameError when the other is missing.
if 'context_for_llm' in locals() and 'query' in locals():
    # Use the query and context from notebook 03
    sample_answer = generate_llm_answer(query, context_for_llm)

    print("\n===== LLM ANSWER =====\n")
    print(sample_answer)
else:
    print("Variable 'context_for_llm' and/or 'query' not found in kernel.")
    print("Please run Notebook 03 first to generate the retrieval context, then return here.")

In [None]:
# ---------------------------------------------------------------
# Notebook 04 — Cell 4
# Final Output Formatting (Answer + Citations)
# ---------------------------------------------------------------

def format_final_output(answer: str, context_for_llm: Dict):
    """
    Print the answer followed by its citations as a readable report.

    Output goes to stdout; the function returns None.

    Args:
        answer: Text produced by the LLM.
        context_for_llm: Retrieval output. Its "citations" entry, if present,
            is a list of dicts read for "source", "chunk_preview" and
            "rerank_score". A missing or empty list prints "(no citations)".
    """
    print("\n=======================")
    print("FINAL ANSWER")
    print("=======================\n")
    print(answer)

    print("\n=======================")
    print("CITATIONS")
    print("=======================\n")
    citations = context_for_llm.get("citations") or []
    if not citations:
        print("(no citations)\n")
    for c in citations:
        score = c.get("rerank_score")
        try:
            score_text = f"{score:.4f}"
        except (TypeError, ValueError):
            # Missing/non-numeric score: show it verbatim instead of aborting
            # the report halfway through a citation.
            score_text = str(score)
        print(f"- Source: {c.get('source')}")
        print(f"  Preview: {c.get('chunk_preview')}")
        print(f"  Score: {score_text}\n")


# Show the formatted answer + citations, but only once both the generated
# answer and the retrieval context exist in this kernel session.
if 'sample_answer' not in locals() or 'context_for_llm' not in locals():
    print("Variables 'sample_answer' and/or 'context_for_llm' not found.")
    print("Please run cells above first to generate these variables.")
else:
    format_final_output(sample_answer, context_for_llm)

In [None]:
# ---- Ollama HTTP diagnostic (run this cell to test your local Ollama docker)
import os
print("--- Ollama HTTP Diagnostic ---")
host = os.getenv('OLLAMA_HOST', 'http://localhost:11434')
# OLLAMA_HOST is also the daemon's own bind setting and is often written
# without a scheme (e.g. "127.0.0.1:11434"); requests needs one.
if not host.startswith(('http://', 'https://')):
    host = 'http://' + host
model = os.getenv('OLLAMA_MODEL', 'llama3.1:latest')
print(f'Using host={host} model={model}')

try:
    import requests
    print('requests version:', requests.__version__)
except ImportError as e:
    requests = None
    print('requests not available in this kernel:', e)

if requests is None:
    print('Skipping HTTP checks: install requests (`pip install requests`) and re-run this cell.')
else:
    base_url = host.rstrip('/')

    # Quick GET checks against read-only Ollama endpoints: the root URL,
    # /api/version (server version) and /api/tags (locally pulled models).
    for path in ['', '/api/version', '/api/tags']:
        label = path or '/'
        try:
            r = requests.get(base_url + path, timeout=4)
        except requests.RequestException as e:
            print(f'GET {label} error:', e)
            continue
        print(f'GET {label} ->', r.status_code)
        if path == '/api/tags' and r.ok:
            try:
                names = [m.get('name') for m in r.json().get('models', [])]
            except (ValueError, AttributeError) as e:
                print('Could not parse /api/tags response:', e)
            else:
                print('Local models:', names)
                # Accept both "name" and "name:latest" spellings of the model.
                if not {model, f'{model}:latest'}.intersection(names):
                    print(f"Model '{model}' is not in the local model list; try `ollama pull {model}`.")

    # Try a small non-streaming generate POST to /api/generate
    payload = {'model': model, 'prompt': 'Say hello and identify the model used.', 'stream': False}
    gen_url = base_url + '/api/generate'
    try:
        r = requests.post(gen_url, json=payload, timeout=20)
    except requests.RequestException as e:
        print('POST /api/generate error:', e)
    else:
        print('POST /api/generate ->', r.status_code)
        text = r.text
        print('Response (truncated):', text[:800])
        try:
            j = r.json()
        except ValueError:
            # Report instead of silently ignoring a non-JSON body.
            print('Response body is not valid JSON.')
        else:
            if isinstance(j, dict):
                print('JSON keys:', list(j.keys()))
            # If llm object exists with _normalize, try to normalize and show result
            if 'llm' in globals() and hasattr(llm, '_normalize'):
                try:
                    print('Normalized sample:', llm._normalize(j))
                except Exception as e:
                    print('Normalization error:', e)
            elif isinstance(j, dict):
                # Simple normalization heuristics; non-streaming /api/generate
                # returns its text under the 'response' key.
                if 'message' in j:
                    print('message:', j.get('message'))
                elif 'response' in j:
                    print('response:', j.get('response'))
                elif 'text' in j:
                    print('text:', j.get('text'))
                elif isinstance(j.get('choices'), list) and j['choices']:
                    print('choices[0]:', j['choices'][0])

print('Diagnostic complete. If you see 404 from POST, verify the Ollama daemon and model name (try OLLAMA_MODEL=llama3.1:latest).')
