In [2]:
# ============================================================================
# CELL 1: FIX DEPENDENCIES - RUN FIRST AFTER RESTART
# ============================================================================

import subprocess
import sys

print('🔧 Fixing package dependencies...')

# Fix PyArrow compatibility - install version 15.0.2 for bigframes
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "pyarrow"], 
                capture_output=True, check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "pyarrow==15.0.2", "--no-cache-dir", "-q"], 
                check=True)

# Fix rich version for bigframes compatibility
subprocess.run([sys.executable, "-m", "pip", "install", "rich==13.7.1", "--no-cache-dir", "-q"], 
                check=True)

# Install google-cloud-bigquery-storage (missing dependency)
subprocess.run([sys.executable, "-m", "pip", "install", "google-cloud-bigquery-storage>=2.30.0", "--no-cache-dir", "-q"], 
                check=True)

# Upgrade google-cloud-bigquery for bigframes
subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "google-cloud-bigquery>=3.31.0", "--no-cache-dir", "-q"], 
                check=True)

# Upgrade google-api-core for pandas-gbq
subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "google-api-core>=2.10.2", "--no-cache-dir", "-q"], 
                check=True)

# Install missing packages including faiss-cpu and fix protobuf
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "keybert", "rank-bm25", "evaluate", "faiss-cpu", "protobuf<5.0.0"], 
                check=True)

print('✅ Dependencies fixed and installed')

🔧 Fixing package dependencies...
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 38.3/38.3 MB 252.4 MB/s eta 0:00:00


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 4.1.1 requires pyarrow>=21.0.0, but you have pyarrow 15.0.2 which is incompatible.
cudf-polars-cu12 25.6.0 requires pylibcudf-cu12==25.6.*, but you have pylibcudf-cu12 25.2.2 which is incompatible.


   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 31.4/31.4 MB 57.8 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 294.9/294.9 kB 20.1 MB/s eta 0:00:00
✅ Dependencies fixed and installed


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
grpcio-status 1.76.0 requires protobuf<7.0.0,>=6.31.1, but you have protobuf 4.25.8 which is incompatible.
pylibcudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
cudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
google-cloud-automl 1.0.1 requires google-api-core[grpc]<2.0.0dev,>=1.14.0, but you have google-api-core 2.27.0 which is incompatible.
cudf-polars-cu12 25.6.0 requires pylibcudf-cu12==25.6.*, but you have pylibcudf-cu12 25.2.2 which is incompatible.
pydrive2 1.21.3 requires cryptography<44, but you have cryptography 46.0.1 which is incompatible.
pydrive2 1.21.3 requires pyOpenSSL<=24.2.1,>=19.1.0, but you have pyopenssl 25.3.0 which is incom

In [3]:
# ============================================================================
# CELL 2: VERIFY ALL IMPORTS - RUN SECOND
# ============================================================================

import warnings
warnings.filterwarnings("ignore")

print("🔍 Testing imports...")

try:
    from datasets import load_dataset
    print("✅ datasets")
    from sentence_transformers import SentenceTransformer
    print("✅ sentence-transformers")
    from transformers import AutoTokenizer, AutoModelForCausalLM
    print("✅ transformers")
    import faiss
    print("✅ faiss")
    from keybert import KeyBERT
    print("✅ keybert")
    from rank_bm25 import BM25Okapi
    print("✅ rank-bm25")
    import torch
    print(f"✅ torch (device: {'cuda' if torch.cuda.is_available() else 'cpu'})")
    print("\n🎉 ALL IMPORTS SUCCESSFUL!")
except Exception as e:
    print(f"❌ Import failed: {e}\nPlease restart kernel and try again.")
    raise


🔍 Testing imports...
✅ datasets


2025-10-24 18:21:35.362382: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761330095.384307     194 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761330095.391090     194 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


✅ sentence-transformers
✅ transformers
✅ faiss
✅ keybert
✅ rank-bm25
✅ torch (device: cuda)

🎉 ALL IMPORTS SUCCESSFUL!


In [4]:
# ============================================================================
# MULTI-DOMAIN RAG PIPELINE FOR MEDICAL QA
# Complete production-ready implementation
# Datasets: Women's Health + General Medical QA
# Models: all-MiniLM (embedder), BGE-reranker, BioGPT (HyDE), Flan-T5 (generator)
# ============================================================================

import os, sys, time, json, pickle, re, warnings, random
warnings.filterwarnings("ignore")
import numpy as np
import torch
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Any, Tuple
import logging
from datetime import datetime

from datasets import load_dataset
from sklearn.model_selection import train_test_split
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from keybert import KeyBERT
import faiss
from rank_bm25 import BM25Okapi

# ============================================================================
# REPRODUCIBILITY & SETUP
# ============================================================================

def set_all_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_all_seeds(42)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Download NLTK data
for resource in ['punkt', 'stopwords']:
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource, quiet=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"🚀 Device: {device}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}, Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class DomainConfig:
    name: str
    dataset_name: str
    dataset_split: str = "train"
    index_path: str = None
    id2doc_path: str = None
    metadata_path: str = None
    
    def __post_init__(self):
        if self.index_path is None:
            self.index_path = f"{self.name}_faiss.index"
        if self.id2doc_path is None:
            self.id2doc_path = f"{self.name}_id2doc.pkl"
        if self.metadata_path is None:
            self.metadata_path = f"{self.name}_metadata.json"

@dataclass
class RAGConfig:
    embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    reranker_model: str = "BAAI/bge-reranker-large"
    hyde_model: str = "microsoft/BioGPT-Large"
    generator_model: str = "google/flan-t5-large"
    
    chunk_window: int = 3
    chunk_stride: int = 1
    retrieve_k: int = 30
    rerank_topk: int = 8
    context_chunks: int = 4
    hyde_weight: float = 0.4
    faiss_alpha: float = 0.6
    
    max_new_tokens: int = 200
    hyde_max_tokens: int = 60
    
    completeness_threshold: float = 0.65
    faithfulness_threshold: float = 0.55
    
    retrieval_weight: float = 0.4
    completeness_weight: float = 0.3
    faithfulness_weight: float = 0.3
    
    prompts_log: str = "prompts_outputs.pkl"
    random_seed: int = 42
    test_size: float = 0.15

DOMAINS = [
    DomainConfig(name="women_health", dataset_name="altaidevorg/women-health-mini"),
    DomainConfig(name="medical_qa", dataset_name="Malikeh1375/medical-question-answering-datasets")
]

config = RAGConfig()

# ============================================================================
# UTILITIES
# ============================================================================

def clean_text_artifacts(text: str) -> str:
    text = re.sub(r"^(Answer:|Final answer:|Response:)\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"<\/?[^>]+>|</s>|▃|\[INST\]|\[/INST\]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.strip(" \n\r\t\"'")

def monitor_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"💾 GPU: {allocated:.2f}GB / {total:.2f}GB ({allocated/total*100:.1f}%)")
        if allocated/total > 0.85:
            torch.cuda.empty_cache()

# ============================================================================
# DATA LOADING
# ============================================================================

class DatasetLoader:
    @staticmethod
    def extract_qa_pairs(dataset, domain_name: str) -> List[Dict[str, Any]]:
        qa_data = []
        for idx, row in enumerate(dataset):
            try:
                conv = None
                if isinstance(row, dict):
                    for field in ["conversations", "conversation", "dialog", "dialogue", "messages", "turns"]:
                        if field in row:
                            conv = row[field]
                            break
                
                if not conv:
                    continue
                
                user_msgs, assistant_msgs = [], []
                
                if isinstance(conv, list) and len(conv) > 0:
                    if isinstance(conv[0], dict):
                        if "from" in conv[0] and "value" in conv[0]:
                            user_msgs = [m["value"] for m in conv if m.get("from") in ("human", "user")]
                            assistant_msgs = [m["value"] for m in conv if m.get("from") in ("assistant", "bot", "system")]
                        elif "role" in conv[0] and "content" in conv[0]:
                            user_msgs = [m["content"] for m in conv if m.get("role") in ("user", "human")]
                            assistant_msgs = [m["content"] for m in conv if m.get("role") in ("assistant", "bot")]
                    else:
                        if len(conv) >= 2:
                            user_msgs, assistant_msgs = [conv[0]], conv[1:]
                
                if user_msgs and assistant_msgs:
                    question = " ".join(user_msgs).strip()
                    answer = " ".join(assistant_msgs).strip()
                    if question and answer and len(question) > 10 and len(answer) > 10:
                        qa_data.append({
                            "question": question,
                            "answer": answer,
                            "domain": domain_name,
                            "source_id": idx
                        })
            except Exception:
                continue
        return qa_data
    
    @staticmethod
    def load_domain_data(domain_config: DomainConfig) -> Tuple[List[Dict], List[Dict]]:
        logger.info(f"📥 Loading {domain_config.name}...")
        try:
            dataset = load_dataset(domain_config.dataset_name, split=domain_config.dataset_split)
            qa_data = DatasetLoader.extract_qa_pairs(dataset, domain_config.name)
            
            if not qa_data:
                raise ValueError(f"No QA pairs extracted")
            
            train_data, test_data = train_test_split(
                qa_data, test_size=config.test_size, random_state=config.random_seed
            )
            logger.info(f"✅ {domain_config.name}: {len(train_data)} train, {len(test_data)} test")
            return train_data, test_data
        except Exception as e:
            logger.error(f"❌ Failed to load {domain_config.name}: {e}")
            raise

# ============================================================================
# TEXT CHUNKING
# ============================================================================

class TextChunker:
    @staticmethod
    def create_chunks(data: List[Dict], window: int = 3, stride: int = 1, min_chars: int = 50) -> List[Dict]:
        chunks = []
        for item in data:
            text = item.get("answer", "")
            if not text or len(text) < min_chars:
                continue
            
            sentences = sent_tokenize(text)
            if not sentences:
                continue
            
            if len(sentences) <= window:
                chunks.append({
                    "chunk": " ".join(sentences),
                    "source_idx": item.get("source_id", -1),
                    "domain": item.get("domain", "unknown"),
                    "chunk_id": len(chunks)
                })
                continue
            
            for i in range(0, max(1, len(sentences) - window + 1), stride):
                chunks.append({
                    "chunk": " ".join(sentences[i:i + window]),
                    "source_idx": item.get("source_id", -1),
                    "domain": item.get("domain", "unknown"),
                    "chunk_id": len(chunks),
                    "window": (i, i + window)
                })
        return chunks

# ============================================================================
# MODEL MANAGEMENT
# ============================================================================

class ModelManager:
    def __init__(self, config: RAGConfig, device: torch.device):
        self.config = config
        self.device = device
        self.models = {}
    
    def load_embedder(self):
        logger.info(f"📦 Loading embedder...")
        embedder = SentenceTransformer(self.config.embed_model, device=self.device)
        self.models['embedder'] = embedder
        logger.info(f"✅ Embedder loaded")
        return embedder
    
    def load_reranker(self):
        logger.info(f"📦 Loading reranker...")
        tokenizer = AutoTokenizer.from_pretrained(self.config.reranker_model)
        model = AutoModelForSequenceClassification.from_pretrained(
            self.config.reranker_model,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)
        model.eval()
        self.models['reranker_tokenizer'] = tokenizer
        self.models['reranker_model'] = model
        logger.info(f"✅ Reranker loaded")
        return tokenizer, model
    
    def load_hyde_model(self):
        logger.info(f"📦 Loading HyDE model...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.config.hyde_model)
            model = AutoModelForCausalLM.from_pretrained(
                self.config.hyde_model,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            model.eval()
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            self.models['hyde_tokenizer'] = tokenizer
            self.models['hyde_model'] = model
            logger.info(f"✅ HyDE model loaded")
            return tokenizer, model
        except Exception as e:
            logger.warning(f"⚠️ HyDE load failed, using query expansion: {e}")
            return None, None
    
    def load_generator(self):
        logger.info(f"📦 Loading generator...")
        tokenizer = AutoTokenizer.from_pretrained(self.config.generator_model)
        model = AutoModelForCausalLM.from_pretrained(
            self.config.generator_model,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(self.device)
        model.eval()
        self.models['gen_tokenizer'] = tokenizer
        self.models['gen_model'] = model
        logger.info(f"✅ Generator loaded")
        return tokenizer, model
    
    def load_keyword_extractor(self):
        try:
            kw_model = KeyBERT(model=self.models.get('embedder'))
            self.models['keyword_extractor'] = kw_model
            logger.info(f"✅ KeyBERT loaded")
            return kw_model
        except Exception as e:
            logger.warning(f"⚠️ KeyBERT load failed: {e}")
            return None
    
    def load_all(self):
        logger.info("🔧 Loading all models...")
        self.load_embedder()
        self.load_reranker()
        self.load_hyde_model()
        self.load_generator()
        self.load_keyword_extractor()
        monitor_memory()
        logger.info("✅ All models loaded")
        return self.models

# ============================================================================
# INDEX MANAGEMENT
# ============================================================================

class MultiDomainIndexManager:
    def __init__(self, config: RAGConfig, embedder: SentenceTransformer):
        self.config = config
        self.embedder = embedder
        self.domain_indices = {}
    
    def build_or_load_domain_index(self, domain_config: DomainConfig, chunks: List[Dict]) -> Tuple[faiss.Index, List[str], BM25Okapi]:
        if Path(domain_config.index_path).exists() and Path(domain_config.id2doc_path).exists():
            try:
                return self._load_existing_index(domain_config)
            except:
                pass
        return self._build_new_index(domain_config, chunks)
    
    def _load_existing_index(self, domain_config: DomainConfig) -> Tuple[faiss.Index, List[str], BM25Okapi]:
        logger.info(f"📂 Loading existing {domain_config.name} index...")
        index = faiss.read_index(domain_config.index_path)
        with open(domain_config.id2doc_path, "rb") as f:
            id2doc = pickle.load(f)
        bm25_corpus = [word_tokenize(doc.lower()) for doc in id2doc]
        bm25 = BM25Okapi(bm25_corpus)
        logger.info(f"✅ Loaded {domain_config.name}: {index.ntotal} vectors")
        return index, id2doc, bm25
    
    def _build_new_index(self, domain_config: DomainConfig, chunks: List[Dict]) -> Tuple[faiss.Index, List[str], BM25Okapi]:
        logger.info(f"🔨 Building {domain_config.name} index...")
        id2doc = [chunk["chunk"] for chunk in chunks]
        
        embeddings = self.embedder.encode(
            id2doc, normalize_embeddings=True, show_progress_bar=True,
            batch_size=64, convert_to_numpy=True
        ).astype('float32')
        
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        
        bm25_corpus = [word_tokenize(doc.lower()) for doc in id2doc]
        bm25 = BM25Okapi(bm25_corpus)
        
        faiss.write_index(index, domain_config.index_path)
        with open(domain_config.id2doc_path, "wb") as f:
            pickle.dump(id2doc, f)
        
        metadata = {"created_at": time.time(), "n_vectors": int(index.ntotal), "embedding_dim": dim, "domain": domain_config.name}
        with open(domain_config.metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)
        
        logger.info(f"✅ Built {domain_config.name}: {index.ntotal} vectors")
        return index, id2doc, bm25
    
    def load_all_domains(self, domain_chunks: Dict[str, List[Dict]]):
        for domain in DOMAINS:
            index, id2doc, bm25 = self.build_or_load_domain_index(domain, domain_chunks.get(domain.name, []))
            self.domain_indices[domain.name] = {
                'index': index, 'id2doc': id2doc, 'bm25': bm25, 'config': domain
            }
        logger.info(f"✅ Loaded {len(self.domain_indices)} domain indices")

# ============================================================================
# QUERY ROUTER
# ============================================================================

class QueryRouter:
    def __init__(self, embedder: SentenceTransformer, domain_indices: Dict):
        self.embedder = embedder
        self.domain_indices = domain_indices
        self.domain_centroids = self._compute_centroids()
    
    def _compute_centroids(self) -> Dict[str, np.ndarray]:
        centroids = {}
        logger.info("🎯 Computing domain centroids...")
        for domain_name, domain_data in self.domain_indices.items():
            id2doc = domain_data['id2doc']
            sample_docs = random.sample(id2doc, min(500, len(id2doc)))
            embeddings = self.embedder.encode(sample_docs, normalize_embeddings=True, convert_to_numpy=True)
            centroids[domain_name] = embeddings.mean(axis=0)
        return centroids
    
    def route_query(self, query: str, top_k: int = 2) -> List[str]:
        query_emb = self.embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0]
        similarities = {domain: float(np.dot(query_emb, centroid)) for domain, centroid in self.domain_centroids.items()}
        sorted_domains = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        selected = [d[0] for d in sorted_domains[:top_k]]
        logger.info(f"🧭 Routed to: {selected}")
        return selected

# ============================================================================
# RAG PIPELINE
# ============================================================================

class MultiDomainRAGPipeline:
    def __init__(self, config: RAGConfig, domains: List[DomainConfig]):
        self.config = config
        self.domains = domains
        self.device = device
        
        self.model_manager = ModelManager(config, device)
        self.models = self.model_manager.load_all()
        
        self.data = {}
        self.test_data = {}
        domain_chunks = {}
        
        for domain in domains:
            train_data, test_data = DatasetLoader.load_domain_data(domain)
            self.data[domain.name] = train_data
            self.test_data[domain.name] = test_data
            chunks = TextChunker.create_chunks(train_data, window=config.chunk_window, stride=config.chunk_stride)
            domain_chunks[domain.name] = chunks
        
        self.index_manager = MultiDomainIndexManager(config, self.models['embedder'])
        self.index_manager.load_all_domains(domain_chunks)
        
        self.router = QueryRouter(self.models['embedder'], self.index_manager.domain_indices)
        self.prompts_log = []
        
        logger.info("✅ Multi-domain RAG pipeline initialized")
    
    def generate_hyde_answer(self, query: str) -> str:
        if self.models['hyde_model'] is None:
            return query
        
        prompt = f"Question: {query}\nAnswer:"
        try:
            inputs = self.models['hyde_tokenizer'](prompt, return_tensors="pt", truncation=True, max_length=256).to(self.device)
            with torch.no_grad():
                outputs = self.models['hyde_model'].generate(
                    **inputs, max_new_tokens=self.config.hyde_max_tokens,
                    do_sample=False, pad_token_id=self.models['hyde_tokenizer'].eos_token_id,
                    repetition_penalty=1.15
                )
            text = self.models['hyde_tokenizer'].decode(outputs[0], skip_special_tokens=True)
            hyde = clean_text_artifacts(text.split("Answer:")[-1])
            return hyde if hyde else query
        except:
            return query
    
    def retrieve_from_domain(self, query: str, domain_name: str, k: int) -> List[Tuple[int, float, str]]:
        domain_data = self.index_manager.domain_indices[domain_name]
        index = domain_data['index']
        id2doc = domain_data['id2doc']
        bm25 = domain_data['bm25']
        
        hyde_text = self.generate_hyde_answer(query)
        q_emb = self.models['embedder'].encode([query], normalize_embeddings=True, convert_to_numpy=True).astype('float32')
        h_emb = self.models['embedder'].encode([hyde_text], normalize_embeddings=True, convert_to_numpy=True).astype('float32')
        merged_emb = (1 - self.config.hyde_weight) * q_emb + self.config.hyde_weight * h_emb
        
        D, I = index.search(merged_emb, k)
        faiss_scores = D[0]
        if faiss_scores.max() > faiss_scores.min():
            faiss_norm = (faiss_scores - faiss_scores.min()) / (faiss_scores.max() - faiss_scores.min())
        else:
            faiss_norm = np.ones_like(faiss_scores)
        faiss_map = {int(idx): float(score) for idx, score in zip(I[0], faiss_norm)}
        
        bm25_scores = bm25.get_scores(word_tokenize(query.lower()))
        if bm25_scores.max() > bm25_scores.min():
            bm25_norm = (bm25_scores - bm25_scores.min()) / (bm25_scores.max() - bm25_scores.min())
        else:
            bm25_norm = np.zeros_like(bm25_scores)
        
        candidates = set(I[0].tolist()) | set(np.argsort(bm25_scores)[::-1][:k].tolist())
        merged_scores = []
        for idx in candidates:
            f = faiss_map.get(int(idx), 0.0)
            b = float(bm25_norm[int(idx)]) if int(idx) < len(bm25_norm) else 0.0
            score = self.config.faiss_alpha * f + (1 - self.config.faiss_alpha) * b
            merged_scores.append((int(idx), score, domain_name))
        
        merged_scores.sort(key=lambda x: x[1], reverse=True)
        return merged_scores[:k]
    
    def rerank_candidates(self, query: str, candidates: List[Tuple[int, float, str]]) -> List[Tuple[str, float, str]]:
        texts, metadata = [], []
        for idx, score, domain_name in candidates:
            domain_data = self.index_manager.domain_indices[domain_name]
            text = domain_data['id2doc'][idx]
            texts.append(text)
            metadata.append((idx, domain_name))
        
        reranker_scores = []
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            inputs = self.models['reranker_tokenizer'](
                [query] * len(batch_texts), batch_texts,
                padding=True, truncation=True, max_length=512, return_tensors="pt"
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.models['reranker_model'](**inputs)
                logits = outputs.logits.cpu().numpy()
            
            for lg in logits:
                if lg.shape == ():
                    score = float(lg)
                elif len(lg.shape) == 1 and lg.shape[0] == 1:
                    score = float(lg[0])
                elif len(lg.shape) == 1 and lg.shape[0] == 2:
                    score = float(lg[1])
                else:
                    score = float(np.max(lg))
                reranker_scores.append(score)
        
        reranked = [(texts[i], reranker_scores[i], metadata[i][1]) for i in range(len(texts))]
        reranked.sort(key=lambda x: x[1], reverse=True)
        return reranked[:self.config.rerank_topk]
    
    def generate_answer(self, query: str, contexts: List[Tuple[str, float, str]]) -> str:
        context_parts = [f"[Source {i+1} from {domain}]:\n{text}" 
                        for i, (text, score, domain) in enumerate(contexts[:self.config.context_chunks])]
        context_block = "\n\n".join(context_parts)
        
        prompt = f"""Based on the following medical information, answer the question concisely and accurately.

{context_block}

Question: {query}

Answer:"""
        
        try:
            inputs = self.models['gen_tokenizer'](prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
            with torch.no_grad():
                outputs = self.models['gen_model'].generate(
                    **inputs, max_new_tokens=self.config.max_new_tokens,
                    do_sample=False, pad_token_id=self.models['gen_tokenizer'].eos_token_id,
                    repetition_penalty=1.1
                )
            raw = self.models['gen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
            answer = clean_text_artifacts(raw.split("Answer:")[-1])
            
            self.prompts_log.append({
                "type": "generate", "query": query,
                "contexts": [(t, d) for t, _, d in contexts[:self.config.context_chunks]],
                "prompt": prompt, "raw": raw, "answer": answer, "timestamp": time.time()
            })
            
            return answer if answer else "Insufficient information."
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return "Error generating answer."
    
    def compute_metrics(self, query: str, answer: str, contexts: List[Tuple[str, float, str]]) -> Dict[str, float]:
        metrics = {}
        
        if contexts:
            retrieval_score = np.mean([score for _, score, _ in contexts[:self.config.context_chunks]])
            metrics['retrieval'] = float(retrieval_score)
        else:
            metrics['retrieval'] = 0.0
        
        try:
            context_texts = [text for text, _, _ in contexts[:self.config.context_chunks]]
            all_keywords = []
            
            if self.models['keyword_extractor']:
                for ctx_text in context_texts:
                    keywords = self.models['keyword_extractor'].extract_keywords(
                        ctx_text, keyphrase_ngram_range=(1, 2), stop_words='english', top_n=5
                    )
                    all_keywords.extend([kw for kw, _ in keywords])
            
            unique_keywords = list(dict.fromkeys([kw.lower() for kw in all_keywords if kw]))
            
            if unique_keywords and answer:
                answer_emb = self.models['embedder'].encode([answer], normalize_embeddings=True, convert_to_tensor=True)
                keyword_embs = self.models['embedder'].encode(unique_keywords, normalize_embeddings=True, convert_to_tensor=True)
                similarities = util.cos_sim(answer_emb, keyword_embs).cpu().numpy()[0]
                covered = (similarities >= self.config.completeness_threshold).sum()
                metrics['completeness'] = float(covered / len(unique_keywords))
            else:
                metrics['completeness'] = 0.0
        except:
            metrics['completeness'] = 0.0
        
        try:
            if answer and contexts:
                answer_sentences = sent_tokenize(answer)
                context_sentences = []
                for text, _, _ in contexts[:self.config.context_chunks]:
                    context_sentences.extend(sent_tokenize(text))
                
                if answer_sentences and context_sentences:
                    ans_embs = self.models['embedder'].encode(answer_sentences, normalize_embeddings=True, convert_to_tensor=True)
                    ctx_embs = self.models['embedder'].encode(context_sentences, normalize_embeddings=True, convert_to_tensor=True)
                    sim_matrix = util.cos_sim(ans_embs, ctx_embs).cpu().numpy()
                    max_sims = np.max(sim_matrix, axis=1)
                    faithful = (max_sims >= self.config.faithfulness_threshold).sum()
                    metrics['faithfulness'] = float(faithful / len(answer_sentences))
                else:
                    metrics['faithfulness'] = 0.0
            else:
                metrics['faithfulness'] = 0.0
        except:
            metrics['faithfulness'] = 0.0
        
        metrics['composite'] = (
            self.config.retrieval_weight * metrics['retrieval'] +
            self.config.completeness_weight * metrics['completeness'] +
            self.config.faithfulness_weight * metrics['faithfulness']
        )
        
        return metrics
    
    def run_query(self, query: str, top_domains: int = 2, log_diagnostics: bool = False) -> Dict[str, Any]:
        logger.info(f"🔍 Processing: {query[:100]}...")
        
        selected_domains = self.router.route_query(query, top_k=top_domains)
        
        all_candidates = []
        for domain_name in selected_domains:
            candidates = self.retrieve_from_domain(query, domain_name, k=self.config.retrieve_k)
            all_candidates.extend(candidates)
        
        if log_diagnostics:
            logger.info(f"Retrieved {len(all_candidates)} candidates from {len(selected_domains)} domains")
        
        reranked = self.rerank_candidates(query, all_candidates)
        
        if log_diagnostics:
            logger.info("Top reranked contexts:")
            for i, (text, score, domain) in enumerate(reranked[:3]):
                logger.info(f"  {i+1}. [{domain}] (score={score:.3f}): {text[:150]}...")
        
        answer = self.generate_answer(query, reranked)
        metrics = self.compute_metrics(query, answer, reranked)
        
        result = {
            "query": query,
            "routed_domains": selected_domains,
            "answer": answer,
            "contexts": [(text, domain) for text, _, domain in reranked[:self.config.context_chunks]],
            "metrics": metrics
        }
        
        return result
    
    def evaluate_batch(self, queries: List[str], log_diagnostics: bool = False) -> Dict[str, Any]:
        logger.info(f"📊 Evaluating {len(queries)} queries...")
        
        results = []
        failed = []
        
        for i, query in enumerate(queries):
            try:
                result = self.run_query(query, log_diagnostics=log_diagnostics)
                results.append(result)
                
                if (i + 1) % 3 == 0:
                    logger.info(f"Progress: {i+1}/{len(queries)}")
                    monitor_memory()
            except Exception as e:
                logger.error(f"Failed query {i}: {e}")
                failed.append((i, query, str(e)))
        
        if not results:
            return {"error": "No successful queries"}
        
        avg_metrics = {
            "retrieval": np.mean([r["metrics"]["retrieval"] for r in results]),
            "completeness": np.mean([r["metrics"]["completeness"] for r in results]),
            "faithfulness": np.mean([r["metrics"]["faithfulness"] for r in results]),
            "composite": np.mean([r["metrics"]["composite"] for r in results])
        }
        
        summary = {
            "total_queries": len(queries),
            "successful": len(results),
            "failed": len(failed),
            "success_rate": len(results) / len(queries),
            "average_metrics": avg_metrics,
            "failed_queries": failed,
            "individual_results": results
        }
        
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"evaluation_{timestamp}.json"
        try:
            with open(results_file, "w") as f:
                json.dump(summary, f, indent=2, default=str)
            logger.info(f"💾 Results saved to {results_file}")
        except:
            pass
        
        return summary

# ============================================================================
# PIPELINE EXECUTION
# ============================================================================

logger.info("="*80)
logger.info("🚀 INITIALIZING MULTI-DOMAIN RAG PIPELINE")
logger.info("="*80)

rag_pipeline = MultiDomainRAGPipeline(config, DOMAINS)

test_queries = [
    "What are the recommended health screenings for women in their 40s?",
    "Explain the symptoms and management of preeclampsia.",
    "What are the early warning signs of Parkinson's disease?",
    "How is PCOS diagnosed and treated?",
    "What are the differences between Type 1 and Type 2 diabetes?"
]

logger.info("\n" + "="*80)
logger.info("🧪 RUNNING SINGLE QUERY TEST WITH DIAGNOSTICS")
logger.info("="*80)

test_query = test_queries[0]
result = rag_pipeline.run_query(test_query, top_domains=2, log_diagnostics=True)

logger.info(f"\n{'='*80}")
logger.info(f"📋 QUERY: {result['query']}")
logger.info(f"{'='*80}")
logger.info(f"🎯 Routed to: {result['routed_domains']}")
logger.info(f"\n✅ ANSWER:\n{result['answer']}")
logger.info(f"\n📊 METRICS:")
for metric_name, value in result['metrics'].items():
    logger.info(f"   {metric_name}: {value:.3f}")

logger.info("\n" + "="*80)
logger.info("📊 RUNNING BATCH EVALUATION")
logger.info("="*80)

batch_results = rag_pipeline.evaluate_batch(test_queries[:3], log_diagnostics=False)

logger.info(f"\n{'='*80}")
logger.info("📈 BATCH EVALUATION SUMMARY")
logger.info(f"{'='*80}")
logger.info(f"Success Rate: {batch_results['success_rate']:.1%}")
logger.info(f"Average Retrieval: {batch_results['average_metrics']['retrieval']:.3f}")
logger.info(f"Average Completeness: {batch_results['average_metrics']['completeness']:.3f}")
logger.info(f"Average Faithfulness: {batch_results['average_metrics']['faithfulness']:.3f}")
logger.info(f"Average Composite: {batch_results['average_metrics']['composite']:.3f}")

logger.info("\n" + "="*80)
logger.info("✅ MULTI-DOMAIN RAG PIPELINE COMPLETE")
logger.info("="*80)

try:
    with open(config.prompts_log, "wb") as f:
        pickle.dump(rag_pipeline.prompts_log, f)
    logger.info(f"📝 Prompt logs saved to {config.prompts_log}")
except:
    pass

monitor_memory()


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/443 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/279 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/801 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/256 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

ValueError: Unrecognized configuration class <class 'transformers.models.t5.configuration_t5.T5Config'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of ArceeConfig, AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitNetConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, DeepseekV3Config, DiffLlamaConfig, Dots1Config, ElectraConfig, Emu3Config, ErnieConfig, FalconConfig, FalconH1Config, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, Gemma3Config, Gemma3TextConfig, Gemma3nConfig, Gemma3nTextConfig, GitConfig, GlmConfig, Glm4Config, GotOcr2Config, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig, GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, GraniteConfig, GraniteMoeConfig, GraniteMoeHybridConfig, GraniteMoeSharedConfig, HeliumConfig, JambaConfig, JetMoeConfig, LlamaConfig, Llama4Config, Llama4TextConfig, MambaConfig, Mamba2Config, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig, MiniMaxConfig, MistralConfig, MixtralConfig, MllamaConfig, MoshiConfig, MptConfig, MusicgenConfig, MusicgenMelodyConfig, MvpConfig, NemotronConfig, OlmoConfig, Olmo2Config, OlmoeConfig, OpenLlamaConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PersimmonConfig, PhiConfig, Phi3Config, Phi4MultimodalConfig, PhimoeConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, Qwen2Config, Qwen2MoeConfig, Qwen3Config, Qwen3MoeConfig, RecurrentGemmaConfig, ReformerConfig, RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, RwkvConfig, SmolLM3Config, Speech2Text2Config, StableLmConfig, Starcoder2Config, TransfoXLConfig, TrOCRConfig, WhisperConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig, ZambaConfig, Zamba2Config.