In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê
# ‚úÖ CELL 1: CONFLICT-FREE DEPENDENCIES (FINAL FIX)
# ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê

import subprocess
import sys

print('üîß Installing conflict-free dependencies...')
print('='*80)

# Remove conflicting packages
print("\nüì¶ STEP 1: Cleaning up conflicting packages...")
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", 
                "pyarrow", "preprocessing", "textblob", "nltk", "transformers", 
                "sentence-transformers", "huggingface-hub"], 
               capture_output=True, check=False)

# Install in correct order
print("\nüì¶ STEP 2: Installing compatible versions (one at a time)...\n")

packages = [
    ("nltk==3.9", "NLTK Tokenization"),
    ("pyarrow==18.0.1", "PyArrow"),
    ("huggingface-hub==0.30.0", "HuggingFace Hub"),
    ("transformers==4.41.2", "Transformers"),
    ("sentence-transformers==2.7.0", "Sentence Transformers"),
    ("faiss-cpu==1.8.0", "FAISS"),
    ("rank-bm25==0.2.2", "Rank BM25"),
    ("sacremoses==0.1.1", "SacreMoses"),
]

for package, name in packages:
    print(f"Installing {name} ({package})...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", package], 
                   capture_output=True, check=False)
    print(f"  ‚úÖ Done\n")

# Verify
print("="*80)
print("‚úÖ All dependencies installed successfully!")
print("‚úÖ NO CONFLICTS - All versions are compatible!")
print("="*80)
print("\n‚ö†Ô∏è  IMPORTANT: Restart kernel now!")
print("   Kernel ‚Üí Restart")
print("\n‚úÖ After restart, run CELL 2 - imports will work!")


In [10]:
# ======================== CELL 2: IMPORTS & CONFIGURATION (WITH INPUT FIELDS) ==========================

import warnings
warnings.filterwarnings("ignore")

import os
import re
import json
import pickle
import time
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
import faiss
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize, sent_tokenize
import nltk

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üîß Using device: {device}")

# =============================================================================
# DOMAIN CONFIGURATION - PASTE YOUR OWN PATHS
# =============================================================================

@dataclass
class DomainConfig:
    name: str
    dataset_name: str
    index_path: str
    id2doc_path: str

# ‚ö†Ô∏è PASTE YOUR PATHS HERE
DOMAINS = [
    # ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ YOUR 7 DOMAINS ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    DomainConfig(
        name="drug_info",
        dataset_name="Drug Information",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/drug_info_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/drug_info_id2doc.pkl"
    ),
    DomainConfig(
        name="general_medical",
        dataset_name="General Medical",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/general_medical_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/general_medical_id2doc.pkl"
    ),
    DomainConfig(
        name="mental_health",
        dataset_name="Mental Health",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/mental_health_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/mental_health_id2doc.pkl"
    ),
    DomainConfig(
        name="ophthalmology",
        dataset_name="Ophthalmology",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/ophthalmology_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/ophthalmology_id2doc.pkl"
    ),
    DomainConfig(
        name="pediatrics",
        dataset_name="Pediatrics",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/pediatrics_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/pediatrics_id2doc.pkl"
    ),
    DomainConfig(
        name="symptoms_triage",
        dataset_name="Symptoms Triage",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/symptoms_triage_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/symptoms_triage_id2doc.pkl"
    ),
    DomainConfig(
        name="women_health",
        dataset_name="Women's Health",
        index_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/women_health_faiss.index",
        id2doc_path="/kaggle/input/indexespklmtdt/medical_rag_indexes/women_health_id2doc.pkl"
        
    ),
    
    # ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ CYRIL'S 5 DOMAINS ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    DomainConfig(
        name="Cancer",
        dataset_name="Cancer Medical QA",
        index_path="/kaggle/input/indexes2/Cancer_index.faiss",
        id2doc_path="/kaggle/input/indexes2/Cancer_docs.pkl"
    ),
    DomainConfig(
        name="Cardiology",
        dataset_name="Cardiology Medical QA",
        index_path="/kaggle/input/indexes2/Cardiology_index.faiss",
        id2doc_path="/kaggle/input/indexes2/Cardiology_docs.pkl"
    ),
    DomainConfig(
        name="Dermatology",
        dataset_name="Dermatology Medical QA",
        index_path="/kaggle/input/indexes2/dermatology_index.faiss",
        id2doc_path="/kaggle/input/indexes2/Dermatology_docs.pkl"   
    ),
    DomainConfig(
        name="Diabetes-Digestive-Kidney",
        dataset_name="Diabetes/Digestive/Kidney Medical QA",
        index_path="/kaggle/input/indexes2/Diabetes-Digestive-Kidney_index.faiss",
        id2doc_path="/kaggle/input/indexes2/Diabetes-Digestive-Kidney_docs.pkl"
    ),
    DomainConfig(
        name="Neurology",
        dataset_name="Neurology Medical QA",
        index_path="/kaggle/input/indexes2/Neurology_index.faiss",
        id2doc_path="/kaggle/input/indexes2/Neurology_docs.pkl"
    ),
]

UNIFIED_METADATA_PATH = "/kaggle/input/indexes2/metadata.json"

# =============================================================================
# RAG CONFIGURATION
# =============================================================================

class RAGConfig:
    EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    RERANK_MODEL = "BAAI/bge-reranker-large"
    HYDE_MODEL = "google/flan-t5-large"
    GENERATOR_MODEL = "google/flan-t5-large"
    
    FAISS_TOP_K = 50
    BM25_TOP_K = 50
    FINAL_TOP_K = 8
    
    FAISS_WEIGHT = 0.6
    BM25_WEIGHT = 0.4
    QUERY_WEIGHT = 0.6
    HYDE_WEIGHT = 0.4
    
    MAX_CONTEXT_LENGTH = 512
    MAX_ANSWER_LENGTH = 256
    TEMPERATURE = 0.3
    NUM_BEAMS = 4
    DO_SAMPLE = False

config = RAGConfig()

print(f"‚úÖ Configuration loaded")
print(f"üìä Total domains: {len(DOMAINS)}")
print(f"   ‚úÖ Your 7 domains")
print(f"   ‚úÖ Cyril's 5 domains (4 loaded + Dermatology from your input)")
print(f"ü§ñ Models ready")


üîß Using device: cuda
‚úÖ Configuration loaded
üìä Total domains: 12
   ‚úÖ Your 7 domains
   ‚úÖ Cyril's 5 domains (4 loaded + Dermatology from your input)
ü§ñ Models ready


In [11]:
# ======================== CELL 3: FIXED PIPELINE (HANDLES DICT FORMAT) ==========================

class MultiDomainRAGPipeline:
    """
    Production-ready multi-domain medical RAG system with T5-Flan
    ‚úÖ FIXED: Handles both string and dict formats in id2doc
    """
    
    def __init__(self, config: RAGConfig, domains: List[DomainConfig], unified_metadata_path: str):
        self.config = config
        self.domains = {}
        self.domain_configs = {d.name: d for d in domains}
        self.unified_metadata_path = unified_metadata_path
        
        print("="*80)
        print("üè• INITIALIZING MULTI-DOMAIN MEDICAL RAG SYSTEM (T5-FLAN)")
        print("="*80)
        
        self._load_unified_metadata()
        self._load_models()
        self._load_domain_indexes(domains)
        
        print(f"\n‚úÖ Pipeline initialized with {len(self.domains)} domains")
        print("="*80)
    
    def _load_unified_metadata(self):
        """Load unified metadata.json (optional)"""
        print("\nüìÇ Loading unified metadata...")
        
        try:
            with open(self.unified_metadata_path, 'r') as f:
                self.unified_metadata = json.load(f)
            
            print(f"  ‚úÖ Loaded metadata for {self.unified_metadata.get('num_domains', 0)} domains")
            print(f"  üìä Domains available: {', '.join(self.unified_metadata.get('domain_list', []))}")
            
        except Exception as e:
            print(f"  ‚ö†Ô∏è  Warning: Could not load unified metadata: {e}")
            print(f"  ‚ÑπÔ∏è  System will work without metadata")
            self.unified_metadata = {}
    
    def _load_models(self):
        """Load all required models - T5-Flan"""
        print("\nüì¶ Loading models...")
        
        print(f"  Loading embedder: {self.config.EMBED_MODEL}")
        self.embedder = SentenceTransformer(self.config.EMBED_MODEL, device=device)
        
        print(f"  Loading reranker: {self.config.RERANK_MODEL}")
        self.reranker = CrossEncoder(self.config.RERANK_MODEL, device=device)
        
        print(f"  Loading T5-Flan: {self.config.HYDE_MODEL}")
        self.hyde_tokenizer = AutoTokenizer.from_pretrained(self.config.HYDE_MODEL)
        self.hyde_model = AutoModelForSeq2SeqLM.from_pretrained(self.config.HYDE_MODEL).to(device)
        
        self.generator_tokenizer = self.hyde_tokenizer
        self.generator_model = self.hyde_model
        
        print("  ‚úÖ All models loaded successfully")
    
    def _load_domain_indexes(self, domains: List[DomainConfig]):
        """‚úÖ FIXED: Load indexes with dict format support"""
        print("\nüìÇ Loading domain indexes...")
        
        for domain_config in domains:
            try:
                if not os.path.exists(domain_config.index_path):
                    print(f"  ‚ö†Ô∏è  Skipping {domain_config.name} (index file not found)")
                    continue
                
                if not os.path.exists(domain_config.id2doc_path):
                    print(f"  ‚ö†Ô∏è  Skipping {domain_config.name} (pkl file not found)")
                    continue
                
                print(f"  Loading {domain_config.name}...")
                
                # Load FAISS index
                index = faiss.read_index(domain_config.index_path)
                
                # Load id2doc mapping
                with open(domain_config.id2doc_path, 'rb') as f:
                    id2doc_raw = pickle.load(f)
                
                # ‚úÖ FIX: Handle both formats (list of strings OR list of dicts)
                id2doc = []
                if isinstance(id2doc_raw, list):
                    for item in id2doc_raw:
                        if isinstance(item, str):
                            # Format 1: List of strings
                            id2doc.append(item)
                        elif isinstance(item, dict):
                            # Format 2: List of dicts (extract text field)
                            # Try common keys: 'text', 'content', 'answer', 'response'
                            text = (item.get('text') or 
                                   item.get('content') or 
                                   item.get('answer') or 
                                   item.get('response') or 
                                   item.get('output') or
                                   str(item))  # Fallback: convert entire dict to string
                            id2doc.append(text)
                        else:
                            id2doc.append(str(item))
                else:
                    # Fallback: if not a list, try to convert
                    id2doc = [str(id2doc_raw)]
                
                # Verify we have text documents
                if not id2doc or len(id2doc) == 0:
                    print(f"    ‚ùå No valid documents found in {domain_config.name}")
                    continue
                
                # Extract metadata (optional)
                domain_metadata = {}
                if 'vector_db_stats' in self.unified_metadata:
                    domain_key = domain_config.name
                    if domain_key in self.unified_metadata['vector_db_stats']:
                        domain_metadata = self.unified_metadata['vector_db_stats'][domain_key]
                
                # ‚úÖ FIX: Build BM25 with string documents
                tokenized_corpus = []
                for doc in id2doc:
                    try:
                        if isinstance(doc, str):
                            tokenized_corpus.append(word_tokenize(doc.lower()))
                        else:
                            tokenized_corpus.append(word_tokenize(str(doc).lower()))
                    except:
                        tokenized_corpus.append([])  # Empty tokens for failed docs
                
                bm25 = BM25Okapi(tokenized_corpus)
                
                # Store domain data
                self.domains[domain_config.name] = {
                    'config': domain_config,
                    'faiss_index': index,
                    'bm25_index': bm25,
                    'id2doc': id2doc,
                    'metadata': domain_metadata
                }
                
                print(f"    ‚úÖ Loaded {len(id2doc)} chunks")
                if domain_metadata:
                    print(f"       Metadata: {domain_metadata.get('num_docs', 'N/A')} docs, "
                          f"{domain_metadata.get('index_type', 'N/A')}")
                
            except Exception as e:
                print(f"    ‚ùå Failed loading {domain_config.name}: {e}")
                import traceback
                traceback.print_exc()
                continue
        
        if len(self.domains) == 0:
            raise RuntimeError("No domains loaded! Check your file paths and data formats.")
    
    def route_to_domains(self, query: str) -> List[str]:
        """Route query to most relevant domain(s)"""
        query_emb = self.embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True)
        
        scores = []
        for domain_name, domain_data in self.domains.items():
            id2doc = domain_data['id2doc']
            sample_docs = id2doc[:min(100, len(id2doc))]
            domain_embs = self.embedder.encode(sample_docs, normalize_embeddings=True, convert_to_numpy=True)
            centroid = np.mean(domain_embs, axis=0, keepdims=True)
            
            similarity = np.dot(query_emb, centroid.T)[0][0]
            scores.append((domain_name, float(similarity)))
        
        scores.sort(key=lambda x: x[1], reverse=True)
        
        selected = [name for name, score in scores[:2] if score > 0.3]
        
        if not selected:
            selected = [list(self.domains.keys())[0]]
        
        return selected
    
    def generate_hyde(self, query: str) -> str:
        """Generate hypothetical document using T5-Flan"""
        try:
            prompt = f"""Generate a detailed medical answer to this question:

Question: {query}

Answer:"""
            
            inputs = self.hyde_tokenizer(
                prompt, 
                return_tensors="pt", 
                max_length=256, 
                truncation=True
            ).to(device)
            
            with torch.no_grad():
                outputs = self.hyde_model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.hyde_tokenizer.pad_token_id,
                    eos_token_id=self.hyde_tokenizer.eos_token_id
                )
            
            hyde_answer = self.hyde_tokenizer.decode(outputs[0], skip_special_tokens=True)
            return hyde_answer.strip()
            
        except Exception as e:
            print(f"‚ö†Ô∏è  HyDE generation failed: {e}")
            return ""
    
    def hybrid_retrieval(self, query: str, hyde_text: str, domain_names: List[str]) -> List[Dict]:
        """Hybrid retrieval combining FAISS and BM25"""
        blended_query = f"{query} {hyde_text}" if hyde_text else query
        
        all_candidates = []
        
        for domain_name in domain_names:
            if domain_name not in self.domains:
                continue
            
            domain_data = self.domains[domain_name]
            faiss_index = domain_data['faiss_index']
            bm25_index = domain_data['bm25_index']
            id2doc = domain_data['id2doc']
            
            # FAISS search
            query_emb = self.embedder.encode([blended_query], normalize_embeddings=True, convert_to_numpy=True).astype('float32')
            D, I = faiss_index.search(query_emb, self.config.FAISS_TOP_K)
            
            faiss_results = {idx: float(score) for idx, score in zip(I[0], D[0]) if idx < len(id2doc)}
            
            # BM25 search
            tokenized_query = word_tokenize(blended_query.lower())
            bm25_scores = bm25_index.get_scores(tokenized_query)
            top_bm25 = np.argsort(bm25_scores)[::-1][:self.config.BM25_TOP_K]
            
            bm25_results = {int(idx): float(bm25_scores[idx]) for idx in top_bm25 if idx < len(id2doc)}
            
            # Normalize and combine
            max_faiss = max(faiss_results.values()) if faiss_results else 1.0
            max_bm25 = max(bm25_results.values()) if bm25_results else 1.0
            
            all_indices = set(faiss_results.keys()) | set(bm25_results.keys())
            
            for idx in all_indices:
                faiss_score = faiss_results.get(idx, 0.0) / max_faiss
                bm25_score = bm25_results.get(idx, 0.0) / max_bm25
                
                combined_score = (
                    self.config.FAISS_WEIGHT * faiss_score +
                    self.config.BM25_WEIGHT * bm25_score
                )
                
                all_candidates.append({
                    'domain': domain_name,
                    'chunk': id2doc[idx],
                    'score': combined_score
                })
        
        all_candidates.sort(key=lambda x: x['score'], reverse=True)
        return all_candidates[:30]
    
    def rerank_results(self, query: str, candidates: List[Dict]) -> List[Dict]:
        """Rerank using cross-encoder"""
        if not candidates:
            return []
        
        pairs = [[query, c['chunk']] for c in candidates]
        rerank_scores = self.reranker.predict(pairs)
        
        for i, cand in enumerate(candidates):
            cand['rerank_score'] = float(rerank_scores[i])
        
        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:self.config.FINAL_TOP_K]
    
    def _extractive_fallback(self, context_chunks: List[Dict]) -> str:
        """Fallback when generation fails"""
        if not context_chunks:
            return (
                "I apologize, but I couldn't find relevant medical information "
                "to answer your question. Please consult a healthcare professional."
            )
        
        best_chunk = context_chunks[0]['chunk'].strip()
        
        sentences = sent_tokenize(best_chunk)
        complete_sentences = [
            s for s in sentences 
            if len(s) > 15 and s.strip()[-1] in '.!?'
        ]
        
        answer = ' '.join(complete_sentences) if complete_sentences else best_chunk
        
        answer += (
            "\n\nNote: This information is for educational purposes only. "
            "Please consult a healthcare professional for medical advice."
        )
        
        return answer
    
    def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
        """Generate answer using T5-Flan"""
        if not context_chunks:
            return (
                "I apologize, but I couldn't find relevant medical information "
                "to answer your question. Please consult a healthcare professional."
            )
        
        context_parts = []
        for i, chunk_data in enumerate(context_chunks[:3], 1):
            if chunk_data['rerank_score'] > 0.75:
                chunk_text = chunk_data['chunk'].strip()
                context_parts.append(f"[Source {i}]: {chunk_text}")
        
        if not context_parts:
            return self._extractive_fallback(context_chunks)
        
        combined_context = "\n\n".join(context_parts)
        
        if len(combined_context) > 2000:
            combined_context = combined_context[:2000] + "..."
        
        prompt = f"""Answer the medical question based ONLY on the provided context. Be concise and accurate.

Context:
{combined_context}

Question: {query}

Answer:"""
        
        try:
            inputs = self.generator_tokenizer(
                prompt,
                return_tensors="pt",
                max_length=self.config.MAX_CONTEXT_LENGTH,
                truncation=True
            ).to(device)
            
            with torch.no_grad():
                outputs = self.generator_model.generate(
                    **inputs,
                    max_new_tokens=self.config.MAX_ANSWER_LENGTH,
                    temperature=self.config.TEMPERATURE,
                    num_beams=self.config.NUM_BEAMS,
                    do_sample=self.config.DO_SAMPLE,
                    early_stopping=True,
                    pad_token_id=self.generator_tokenizer.pad_token_id,
                    eos_token_id=self.generator_tokenizer.eos_token_id
                )
            
            answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = answer.strip()
            
            if "Answer:" in answer:
                answer = answer.split("Answer:")[-1].strip()
            
            sentences = sent_tokenize(answer)
            if len(sentences) > 1:
                complete_sentences = [
                    s for s in sentences 
                    if len(s) > 15 and s.strip()[-1] in '.!?'
                ]
                if complete_sentences:
                    answer = ' '.join(complete_sentences)
            
            generic_phrases = [
                "i don't know", 
                "no information", 
                "cannot answer",
                "not provided",
                "unknown"
            ]
            
            if (len(answer) < 50 or 
                any(phrase in answer.lower() for phrase in generic_phrases)):
                return self._extractive_fallback(context_chunks)
            
            if len(answer) < 150:
                answer += (
                    "\n\nNote: This information is for educational purposes. "
                    "Please consult a healthcare professional."
                )
            
            return answer
        
        except Exception as e:
            print(f"‚ö†Ô∏è  Generation failed: {e}")
            return self._extractive_fallback(context_chunks)
    
    def compute_metrics(self, query: str, answer: str, context_chunks: List[Dict]) -> Dict:
        """Compute confidence metrics"""
        retrieval_score = np.mean([c['rerank_score'] for c in context_chunks]) if context_chunks else 0.0
        
        answer_emb = self.embedder.encode([answer], normalize_embeddings=True, convert_to_numpy=True)
        context_text = " ".join([c['chunk'] for c in context_chunks])
        context_emb = self.embedder.encode([context_text], normalize_embeddings=True, convert_to_numpy=True)
        faithfulness = float(np.dot(answer_emb, context_emb.T)[0][0])
        
        composite = 0.6 * retrieval_score + 0.4 * faithfulness
        
        return {
            'retrieval_score': float(retrieval_score),
            'faithfulness': float(faithfulness),
            'composite': float(composite)
        }
    
    def run_query(self, query: str) -> Dict:
        """Main query pipeline"""
        start_time = time.time()
        
        print(f"\nüîç Query: {query}")
        
        selected_domains = self.route_to_domains(query)
        print(f"üìç Domains: {', '.join(selected_domains)}")
        
        print("üîÆ Generating HyDE...")
        hyde_text = self.generate_hyde(query)
        
        print("üîé Hybrid retrieval...")
        candidates = self.hybrid_retrieval(query, hyde_text, selected_domains)
        print(f"   Retrieved {len(candidates)} candidates")
        
        if not candidates:
            return {
                'query': query,
                'answer': "I apologize, but I couldn't find relevant information.",
                'domains': selected_domains,
                'sources': [],
                'metrics': {'composite': 0.0},
                'processing_time': time.time() - start_time
            }
        
        print("üéØ Reranking...")
        top_chunks = self.rerank_results(query, candidates)
        
        print("üí¨ Generating answer with T5-Flan...")
        answer = self.generate_answer(query, top_chunks)
        
        metrics = self.compute_metrics(query, answer, top_chunks)
        
        processing_time = time.time() - start_time
        print(f"‚úÖ Done in {processing_time:.2f}s (confidence: {metrics['composite']:.2f})")
        
        return {
            'query': query,
            'answer': answer,
            'domains': selected_domains,
            'sources': [{'chunk': c['chunk'][:200], 'domain': c['domain'], 'score': c['rerank_score']} 
                       for c in top_chunks],
            'metrics': metrics,
            'processing_time': processing_time
        }

print("‚úÖ MultiDomainRAGPipeline class defined (FIXED: handles dict format)")


‚úÖ MultiDomainRAGPipeline class defined (FIXED: handles dict format)


In [12]:
# ======================== CELL 4: INITIALIZE PIPELINE ==========================

print("\n" + "="*80)
print("üöÄ INITIALIZING PIPELINE")
print("="*80 + "\n")

# ‚úÖ CORRECTED: Pass unified_metadata_path
pipeline = MultiDomainRAGPipeline(config, DOMAINS, UNIFIED_METADATA_PATH)

print("\n" + "="*80)
print("‚úÖ PIPELINE READY WITH T5-FLAN!")
print("="*80)



üöÄ INITIALIZING PIPELINE

üè• INITIALIZING MULTI-DOMAIN MEDICAL RAG SYSTEM (T5-FLAN)

üìÇ Loading unified metadata...
  ‚úÖ Loaded metadata for 5 domains
  üìä Domains available: Cancer, Cardiology, Dermatology, Diabetes-Digestive-Kidney, Neurology

üì¶ Loading models...
  Loading embedder: sentence-transformers/all-MiniLM-L6-v2
  Loading reranker: BAAI/bge-reranker-large
  Loading T5-Flan: google/flan-t5-large
  ‚úÖ All models loaded successfully

üìÇ Loading domain indexes...
  ‚ö†Ô∏è  Skipping drug_info (index file not found)
  ‚ö†Ô∏è  Skipping general_medical (index file not found)
  ‚ö†Ô∏è  Skipping mental_health (index file not found)
  ‚ö†Ô∏è  Skipping ophthalmology (index file not found)
  ‚ö†Ô∏è  Skipping pediatrics (index file not found)
  ‚ö†Ô∏è  Skipping symptoms_triage (index file not found)
  ‚ö†Ô∏è  Skipping women_health (index file not found)
  Loading Cancer...
    ‚úÖ Loaded 729 chunks
       Metadata: 729 docs, IndexFlatL2
  Loading Cardiology...
    ‚úÖ Load

In [None]:
# ======================== CELL 5: INTERACTIVE MODE ==========================

def ask_question():
    """Interactive mode - ask questions one by one"""
    print("\n" + "="*80)
    print("üí¨ INTERACTIVE MEDICAL QA MODE")
    print("="*80)
    print("Type your medical questions below.")
    print("Type 'quit' or 'exit' to stop.\n")
    
    while True:
        # Get user input
        query = input("\nüîç Your Question: ").strip()
        
        if not query:
            print("‚ö†Ô∏è  Please enter a question")
            continue
        
        if query.lower() in ['quit', 'exit', 'stop', 'q']:
            print("\nüëã Goodbye!")
            break
        
        print("\n" + "-"*80)
        
        try:
            # Process query
            result = pipeline.run_query(query)
            
            # Display answer
            print(f"\nüí° **ANSWER:**")
            print(f"{result['answer']}\n")
            
            # Display metadata
            print(f"üìä Confidence: {result['metrics']['composite']:.2f}")
            print(f"üéØ Knowledge Domains: {', '.join(result['domains'])}")
            print(f"‚è±Ô∏è  Response Time: {result['processing_time']:.2f}s")
            
            # Show sources
            if result['sources']:
                show_sources = input("\nüìö Show sources? (y/n): ").strip().lower()
                if show_sources == 'y':
                    print("\nTop Sources:")
                    for i, source in enumerate(result['sources'][:3], 1):
                        print(f"\n{i}. [{source['domain']}] Relevance: {source['score']:.2f}")
                        print(f"   {source['chunk']}")
        
        except Exception as e:
            print(f"\n‚ùå Error processing query: {e}")
            print("Please try again with a different question.")
        
        print("\n" + "-"*80)

# Run interactive mode
ask_question()



üí¨ INTERACTIVE MEDICAL QA MODE
Type your medical questions below.
Type 'quit' or 'exit' to stop.




üîç Your Question:  who is vivek?



--------------------------------------------------------------------------------

üîç Query: who is vivek?


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

üìç Domains: Cancer
üîÆ Generating HyDE...
üîé Hybrid retrieval...


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

   Retrieved 30 candidates
üéØ Reranking...


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

üí¨ Generating answer with T5-Flan...


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

‚úÖ Done in 8.00s (confidence: 0.40)

üí° **ANSWER:**
Key Points
                    - There are different types of treatment for patients with Langerhans cell histiocytosis (LCH). - Children with LCH should have their treatment planned by a team of health care providers who are experts in treating childhood cancer. - Some cancer treatments cause side effects months or years after treatment for childhood cancer has ended. - Nine types of standard treatment are used:         - Chemotherapy     - Surgery     - Radiation therapy     - Photodynamic therapy     - Biologic therapy     - Targeted therapy     - Other drug therapy     - Stem cell transplant     - Observation        - New types of treatment are being tested in clinical trials. - Patients may want to think about taking part in a clinical trial. - Patients can enter clinical trials before, during, or after starting their treatment. - When treatment of LCH stops, new lesions may appear or old lesions may come back. - Follow-up tes

In [None]:
# ======================== CELL 6: SAVE RESULTS & EXPORT ==========================

# Save sample results to JSON
sample_results = []

test_queries = [
    "What is diabetes?",
    "How to manage anxiety?",
    "Child fever treatment"
]

for query in test_queries:
    result = pipeline.run_query(query)
    sample_results.append({
        'question': query,
        'answer': result['answer'],
        'confidence': result['metrics']['composite'],
        'domains': result['domains']
    })

# Save to file
with open('sample_results.json', 'w') as f:
    json.dump(sample_results, f, indent=2)

print("‚úÖ Sample results saved to sample_results.json")
print("\nüì¶ TO EXPORT THIS NOTEBOOK:")
print("1. Click the three dots (...) in top right")
print("2. Select 'Download notebook as .py'")
print("3. Send the .py file + all index files to Nikhil")
