<a href="https://colab.research.google.com/github/abkimc/Medical-RAG./blob/main/Medical_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
 !pip install gradio transformers torch faiss-cpu sentence-transformers accelerate bitsandbytes

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import gradio as gr
import re
import warnings
from typing import List, Dict, Tuple
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from huggingface_hub import login
import os
import gc

warnings.filterwarnings('ignore')

# ============================================================================
# GPU MEMORY MANAGEMENT AND DEVICE SETUP
# ============================================================================

def setup_device():
    """Select the torch device to run on and prepare the GPU if one exists.

    When CUDA is available, garbage is collected and the CUDA caching
    allocator is emptied, the GPU name and total memory are printed, and this
    process is capped at 80% of the GPU memory. Otherwise CPU is selected.

    Returns:
        The string "cuda" when CUDA is available, "cpu" otherwise.
    """
    if not torch.cuda.is_available():
        print("⚠️ CUDA not available, using CPU")
        return "cpu"

    # Run the garbage collector *before* emptying the cache: tensors that are
    # only kept alive by unreachable Python objects must be released first,
    # otherwise their blocks are still occupied when empty_cache() runs.
    gc.collect()
    torch.cuda.empty_cache()

    # Report the GPU that will be used (device index 0).
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"🔧 GPU detected: {torch.cuda.get_device_name(0)} ({gpu_memory:.1f}GB)")

    # Set memory fraction to prevent OOM
    torch.cuda.set_per_process_memory_fraction(0.8)

    print("✅ Using CUDA device")
    return "cuda"

def clear_gpu_memory():
    """Release unused GPU memory; does nothing when CUDA is unavailable.

    The garbage collector runs first so that tensors held only by unreachable
    Python objects are freed, and only then is the CUDA caching allocator
    emptied. In the opposite order the blocks freed by the collector would
    stay in the cache.
    """
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()

# ============================================================================
# DATA LOADING AND PREPROCESSING
# ============================================================================

def load_and_preprocess_data(
    url: str = "https://raw.githubusercontent.com/abkimc/Scraping-Clalit-Webpages/refs/heads/main/Q%26A_Clalit.csv",
) -> pd.DataFrame:
    """Load the medical Q&A CSV and return a cleaned DataFrame.

    The CSV must contain `question` and `answer` columns, which are renamed
    to `question_he` and `answer_he`. Rows with a missing or blank question
    or answer are dropped. A `subject` column is guaranteed in the result;
    missing or blank subjects are set to 'כללי'.

    Args:
        url: Location of the CSV, passed straight to `pandas.read_csv`.
            Defaults to the Clalit Q&A dataset.

    Returns:
        DataFrame with at least `question_he`, `answer_he` and `subject`.

    Raises:
        KeyError: If the required question/answer columns are absent.
        Exception: Any error raised by `pandas.read_csv` is logged and
            re-raised unchanged.
    """
    try:
        print(f"🔄 Attempting to load data from URL: {url}")
        df = pd.read_csv(url, encoding='utf-8')
    except Exception as e:
        print(f"❌ Fatal error loading data from URL: {e}")
        raise

    print(f"✅ Successfully loaded data. Original columns: {df.columns.tolist()}")
    df = df.rename(columns={'question': 'question_he', 'answer': 'answer_he'})

    if 'question_he' not in df.columns or 'answer_he' not in df.columns:
        raise KeyError("Could not find required columns after renaming. Check CSV headers.")

    # Drop missing values BEFORE casting to str: astype(str) turns NaN into
    # the literal text 'nan', which dropna() can no longer detect.
    df = df.dropna(subset=['question_he', 'answer_he']).copy()
    for column in ('question_he', 'answer_he'):
        df[column] = df[column].astype(str).str.strip()
    df = df[(df['question_he'] != '') & (df['answer_he'] != '')].copy()

    if 'subject' not in df.columns:
        df['subject'] = 'כללי'
    else:
        # Fill missing subjects before the str cast, for the same reason.
        subject = df['subject'].astype(object)
        subject = subject.where(subject.notna(), 'כללי').astype(str).str.strip()
        df['subject'] = subject.replace('', 'כללי')

    print(f"📊 Processed {len(df)} medical Q&A pairs.")
    return df

def preprocess_hebrew_text(text: str) -> str:
    """Normalize free text down to Hebrew letters, digits and basic punctuation.

    Runs of whitespace are collapsed to a single space first; then every
    character outside the Hebrew block (U+0590-U+05FF), whitespace, digits
    and the punctuation `.,!?;:-()` is deleted. Deleting characters can leave
    adjacent spaces behind; only leading/trailing whitespace is trimmed.
    """
    single_spaced = re.sub(r'\s+', ' ', text)
    hebrew_only = re.sub(r'[^\u0590-\u05FF\s\d\.,!?;:\-()]', '', single_spaced)
    return hebrew_only.strip()

def remove_hebrew_stop_words(text: str, stop_words: set) -> str:
    """Return `text` without the whitespace-separated tokens found in `stop_words`.

    Tokens are compared exactly (no normalization), and the surviving tokens
    are re-joined with single spaces.
    """
    return ' '.join(token for token in text.split() if token not in stop_words)

# ============================================================================
# RAG COMPONENTS
# ============================================================================

class HebrewEmbeddingModel:
    """Sentence-embedding wrapper that falls back from GPU to CPU on failure."""

    def __init__(self, model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
        """Load `model_name` on the device chosen by `setup_device()`.

        If loading raises, GPU memory is cleared and the model is loaded on
        CPU instead. Sets `self.model`, `self.device` and `self.dimension`
        (the sentence embedding size reported by the model).
        """
        print(f"🔄 Loading embedding model: {model_name}")
        self.device = setup_device()

        try:
            # Load model with explicit device handling
            self.model = SentenceTransformer(model_name, device=self.device)
            self.dimension = self.model.get_sentence_embedding_dimension()
            print(f"✅ Embedding model loaded (dimension: {self.dimension}) on {self.device}")
        except Exception as e:
            print(f"⚠️ GPU loading failed, falling back to CPU: {e}")
            clear_gpu_memory()
            self.device = "cpu"
            self.model = SentenceTransformer(model_name, device="cpu")
            self.dimension = self.model.get_sentence_embedding_dimension()
            print(f"✅ Embedding model loaded (dimension: {self.dimension}) on CPU")

    def encode(self, texts: List[str], show_progress: bool = True) -> np.ndarray:
        """Embed `texts` and return the model's output as a NumPy array.

        If encoding fails while the model is on a non-CPU device, GPU memory
        is cleared, the model is moved to CPU (permanently for this
        instance) and encoding is retried once. A failure that happens while
        already on CPU is re-raised: repeating the identical call would not
        help.

        Args:
            texts: Sentences to embed.
            show_progress: Forwarded as `show_progress_bar`.

        Raises:
            Exception: Whatever the underlying model raised, when no CPU
                fallback is possible or the CPU retry also fails.
        """
        try:
            return self.model.encode(texts, show_progress_bar=show_progress, convert_to_numpy=True)
        except Exception as e:
            if self.device == "cpu":
                # Nothing to fall back to; surface the original error.
                raise
            print(f"⚠️ Encoding error on {self.device}, retrying on CPU: {e}")
            clear_gpu_memory()
            self.model = self.model.to("cpu")
            self.device = "cpu"
            return self.model.encode(texts, show_progress_bar=show_progress, convert_to_numpy=True)

class FAISSVectorStore:
    """Flat inner-product FAISS index over L2-normalized vectors.

    Because every stored vector and every query is normalized to unit
    length, the inner product returned by `search` is the cosine similarity.
    """

    def __init__(self, dimension: int):
        """Create an empty `IndexFlatIP` for vectors of size `dimension`."""
        self.dimension = dimension
        self.index = faiss.IndexFlatIP(self.dimension)

    @staticmethod
    def _prepare(vectors: np.ndarray) -> np.ndarray:
        """Return a normalized float32, C-contiguous, 2-D copy of `vectors`.

        `faiss.normalize_L2` works in place, so it is applied to a private
        copy: the caller's array is left untouched, and the dtype conversion
        happens before normalization rather than after it.
        """
        prepared = np.array(vectors, dtype='float32', order='C', ndmin=2)
        faiss.normalize_L2(prepared)
        return prepared

    def add_embeddings(self, embeddings: np.ndarray):
        """Normalize `embeddings` (one vector per row) and add them to the index."""
        prepared = self._prepare(embeddings)
        self.index.add(prepared)
        print(f"✅ Added {prepared.shape[0]} embeddings to FAISS index")

    def search(self, query_embedding: np.ndarray, top_k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """Return `(scores, indices)` of the `top_k` nearest stored vectors.

        Args:
            query_embedding: One query vector per row; a single 1-D vector is
                treated as one query.
            top_k: Number of neighbours requested per query.
        """
        return self.index.search(self._prepare(query_embedding), top_k)

class HebrewLLM:
    """Causal language model wrapper used to generate Hebrew answers."""

    def __init__(self, model_name: str = "google/medgemma-4b-it", use_quantization: bool = True):
        """Load the tokenizer and model for `model_name`.

        The model is loaded with 4-bit quantization when `use_quantization`
        is true and CUDA is available, otherwise without quantization. If
        that load raises, GPU memory is cleared and the model is loaded on
        CPU in float32.

        The tokenizer is loaded first, outside the fallback path: it does not
        depend on the device, and a tokenizer failure must not be mistaken
        for a GPU problem (the CPU fallback only reloads the model, which
        would leave the instance without a tokenizer).

        Args:
            model_name: Hugging Face model identifier.
            use_quantization: Request 4-bit quantization (ignored without CUDA).
        """
        print(f"🔄 Loading language model: {model_name}")
        self.model_name = model_name
        self.device = setup_device()

        # Determine if quantization should be used
        use_quant = use_quantization and torch.cuda.is_available()

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True
        ) if use_quant else None

        # Load tokenizer (errors propagate to the caller)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Load model with error handling
            if use_quant:
                print("🔄 Loading model with 4-bit quantization...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quant_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16
                )
            else:
                print("🔄 Loading model without quantization...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    device_map="auto" if torch.cuda.is_available() else None,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
                )

            self.model.eval()
            print(f"✅ Language model loaded on {self.device}")

        except Exception as e:
            print(f"❌ Error loading model with GPU/quantization: {e}")
            print("🔄 Falling back to CPU without quantization...")
            clear_gpu_memory()

            # Fallback to CPU
            self.device = "cpu"
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map=None,
                torch_dtype=torch.float32,
                trust_remote_code=True
            )
            self.model.to("cpu")
            self.model.eval()
            print(f"✅ Language model loaded on CPU")

    def generate_response(self, prompt: str, max_length: int = 500) -> str:
        """Generate a reply to `prompt` and return it as stripped text.

        The prompt is wrapped in a system+user chat; if the tokenizer cannot
        apply a chat template, a plain "System/User/Assistant" text prompt is
        used instead. The input is truncated to 2048 tokens and sampling is
        used (temperature 0.7, top-p 0.9).

        Args:
            prompt: User-turn content.
            max_length: Maximum number of NEW tokens to generate.

        Returns:
            The decoded continuation only (prompt tokens are sliced off). On
            any error a fixed Hebrew apology message is returned instead of
            raising.
        """
        try:
            messages = [
                {"role": "system", "content": "You are a helpful medical assistant. You must answer in Hebrew."},
                {"role": "user", "content": prompt}
            ]

            # Handle chat template safely
            try:
                full_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            except Exception:
                # Fallback if chat template not supported. `Exception` rather
                # than a bare except, so KeyboardInterrupt/SystemExit are not
                # swallowed.
                full_prompt = f"System: You are a helpful medical assistant. You must answer in Hebrew.\nUser: {prompt}\nAssistant:"

            # Tokenize with proper device handling
            inputs = self.tokenizer.encode(
                full_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            )

            # Move to correct device
            if hasattr(self.model, 'device'):
                inputs = inputs.to(self.model.device)
            elif self.device == "cuda" and torch.cuda.is_available():
                inputs = inputs.to("cuda")
            else:
                inputs = inputs.to("cpu")

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    # Single unpadded sequence: every position is a real token.
                    attention_mask=torch.ones_like(inputs)
                )

            # Keep only the newly generated tokens (drop the echoed prompt).
            response = self.tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
            return response.strip()

        except Exception as e:
            print(f"⚠️ Generation error: {e}")
            return "מצטער, אירעה שגיאה ביצירת התשובה. אנא נסה שוב."

class HebrewMedicalRAG:
    """Retrieval-augmented question answering over the medical Q&A dataset.

    Construction loads the data, embeds every stored question, indexes the
    embeddings in a FAISS store and loads the language model.
    """

    def __init__(self):
        """Build every component of the pipeline; any failure propagates."""
        print("🚀 Initializing Hebrew Medical RAG System...")

        # Start from a clean GPU before the heavy loads below.
        clear_gpu_memory()

        self.df = load_and_preprocess_data()

        print("🔄 Initializing embedding model...")
        self.embedding_model = HebrewEmbeddingModel()

        print("🔄 Creating embeddings for retrieval...")
        stored_questions = self.df['question_he'].tolist()
        self.embeddings = self.embedding_model.encode(stored_questions)

        print("🔄 Building vector store...")
        self.vector_store = FAISSVectorStore(self.embedding_model.dimension)
        self.vector_store.add_embeddings(self.embeddings)

        print("🔄 Initializing language model...")
        self.llm = HebrewLLM()

        # Function words ignored by the lexical-overlap part of re-ranking.
        self.stop_words = {
            'של', 'עם', 'את', 'על', 'כן', 'לא', 'אם', 'אני', 'אתה', 'הוא',
            'היא', 'אנחנו', 'אתם', 'הם', 'הן', 'מה', 'מי', 'איך', 'כמה', 'למה',
            'מתי', 'איפה', 'אבל', 'או', 'גם', 'ו', 'ב', 'ל', 'מ', 'ש',
        }

        print("✅ RAG System initialized successfully!")

    def custom_distance_metric(self, query: str, retrieved_docs: List[Dict]) -> List[Dict]:
        """Re-rank retrieved documents by a blended relevance score.

        Each document dict gets a `final_score` key (written in place):
        0.6 * semantic score + 0.3 * Jaccard overlap of the stop-word-free
        word sets of the query and the document question + 0.1 * subject
        score. The query subject is fixed to 'כללי', so the subject score is
        1.0 for documents of that subject and 0.5 for all others.

        Returns:
            The documents sorted by `final_score`, highest first.
        """
        # The query side of the comparison is the same for every document.
        query_terms = set(remove_hebrew_stop_words(query, self.stop_words).split())
        query_subject = "כללי"

        for candidate in retrieved_docs:
            semantic = candidate['score']

            candidate_terms = set(
                remove_hebrew_stop_words(candidate['question'], self.stop_words).split()
            )
            shared = len(query_terms & candidate_terms)
            combined = len(query_terms | candidate_terms)
            lexical = shared / combined if combined > 0 else 0

            candidate_subject = candidate.get('subject', 'כללי')
            if query_subject.lower() == candidate_subject.lower():
                topical = 1.0
            else:
                topical = 0.5

            candidate['final_score'] = (0.6 * semantic) + (0.3 * lexical) + (0.1 * topical)

        return sorted(retrieved_docs, key=lambda d: d['final_score'], reverse=True)

    def answer_question(self, query: str) -> Tuple[str, List[Dict]]:
        """Answer `query` from the dataset and return `(answer, contexts)`.

        The query is cleaned, embedded and matched against the index (top
        10); hits scoring above 0.25 are re-ranked and the best three are
        placed in the LLM prompt. `contexts` holds those (up to three)
        documents. It is empty when the cleaned query is empty, when nothing
        relevant was found, or when an error occurred — in each of those
        cases `answer` is a fixed Hebrew message. If the LLM returns an empty
        string, the top context's stored answer is returned instead.
        """
        try:
            query = preprocess_hebrew_text(query)
            if not query:
                return "אנא הקלד שאלה.", []

            query_embedding = self.embedding_model.encode([query], show_progress=False)
            scores, indices = self.vector_store.search(query_embedding, top_k=10)

            # Keep only hits above the similarity threshold.
            candidates = []
            for rank, row_idx in enumerate(indices[0]):
                similarity = scores[0][rank]
                if similarity > 0.25:
                    row = self.df.iloc[row_idx]
                    candidates.append({
                        'question': row['question_he'],
                        'answer': row['answer_he'],
                        'subject': row['subject'],
                        'score': float(similarity),
                    })

            if not candidates:
                return "מצטער, לא מצאתי מידע רלוונטי במאגר לשאלתך. אנא נסה לנסח אותה מחדש או פנה לייעוץ רפואי.", []

            contexts = self.custom_distance_metric(query, candidates)[:3]

            sections = []
            for i, ctx in enumerate(contexts):
                sections.append(
                    f"מידע רלוונטי {i+1}:\n- נושא: {ctx['subject']}\n- שאלה דומה: {ctx['question']}\n- תשובה: {ctx['answer']}"
                )
            context_text = "\n\n".join(sections)

            prompt = f"""אתה עוזר רפואי וירטואלי, בעל ידע רב וגישה אמפתית. תפקידך הוא לענות על שאלות רפואיות בעברית.
השתמש במידע הרלוונטי שסופק לך למטה כדי לבנות תשובה מקיפה, ברורה וקלה להבנה.
אל תעתיק את התשובות מהמידע שסופק, אלא סנתז אותן לתשובה חדשה וקוהרנטית.
פנה ישירות למשתמש והתייחס לשאלתו הספציפית.
---
[מידע רפואי זמין]
{context_text}
---
[הנחיות]
בהתבסס על המידע שלעיל, ענה על השאלה הבאה בצורה הטובה ביותר. היה ברור, אמפתי, והסבר את הדברים בפשטות.
השאלה: "{query}"
תשובה מנוסחת היטב:"""

            answer = self.llm.generate_response(prompt)
            # Fall back to the best stored answer if the model returned nothing.
            return answer or contexts[0]['answer'], contexts

        except Exception as e:
            print(f"⚠️ Error in answer_question: {e}")
            return f"מצטער, אירעה שגיאה בעיבוד השאלה: {str(e)}", []

# ============================================================================
# GRADIO INTERFACE
# ============================================================================

def create_gradio_interface(rag_system: HebrewMedicalRAG):
    """Build the Gradio UI that forwards questions to `rag_system`.

    Returns a `gr.Interface` with a two-line text box as input and a
    Markdown panel as output. The interface is not launched here.
    """

    def chat_response(message: str, history: List = None) -> str:
        """Return the Markdown reply for `message`.

        The reply is the RAG answer, followed by the list of retrieved
        source questions with their relevance (when any were found), followed
        by a disclaimer. Errors are reported as a Hebrew message rather than
        raised. `history` is accepted but unused.
        """
        if not message.strip():
            return "אנא הקלד שאלה."

        try:
            answer, contexts = rag_system.answer_question(message)
            disclaimer = "\n\n---\n*⚠️ **הערה חשובה:** המידע כאן הוא למטרות לימוד בלבד ואינו מהווה תחליף לייעוץ רפואי מקצועי.*"

            if contexts:
                # Prefer the re-ranked score; fall back to the raw similarity.
                source_lines = [
                    f"- *{ctx['question']} (רלוונטיות: {ctx.get('final_score', ctx['score']):.0%})*"
                    for ctx in contexts
                ]
                answer += "\n\n*מקורות מידע דומים שנמצאו (לפי רלוונטיות):*\n" + "\n".join(source_lines)

            return answer + disclaimer

        except Exception as e:
            return f"מצטער, אירעה שגיאה: {str(e)}"

    question_box = gr.Textbox(
        label="שאל שאלה רפואית בעברית",
        placeholder="לדוגמה: מה הטיפול בפטרת ציפורניים?",
        lines=2
    )
    answer_panel = gr.Markdown(label="תשובה")
    sample_questions = [
        ["מה הטיפול בפטרת ציפורניים?"],
        ["מהי אנמיה?"],
        ["מתי יש צורך בבדיקת סיטי?"]
    ]

    return gr.Interface(
        fn=chat_response,
        inputs=question_box,
        outputs=answer_panel,
        title="🏥 עוזר רפואי דיגיטלי (מבוסס על נתוני כללית)",
        description="שאל אותי שאלה רפואית בעברית, ואנסה לענות בהתבסס על מאגר המידע. **אזהרה:** איני רופא. המידע הוא למטרות לימוד בלבד.",
        examples=sample_questions,
        theme='soft',
        allow_flagging="never"
    )

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Log in to Hugging Face, build the RAG system and launch the web UI.

    Any exception raised along the way is printed with its traceback, after
    which GPU memory is cleared; the exception is not re-raised.
    """
    try:
        # --- Authentication -------------------------------------------------
        print("🔑 Logging into Hugging Face Hub...")
        token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
        if not token:
            print("⚠️ No token found in environment. You may need to log in interactively.")
            try:
                login()
            except Exception as e:
                # Best effort: carry on without a login.
                print(f"⚠️ Login failed: {e}. Continuing without login (may affect model access).")
        else:
            print("✅ Found Hugging Face token in environment variables.")
            login(token=token)

        print("✅ Authentication complete.")

        # --- Environment report ---------------------------------------------
        # Setup environment variables for CUDA debugging
        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

        cuda_ready = torch.cuda.is_available()
        print(f"🔧 Gradio version: {gr.__version__}")
        print(f"🔧 PyTorch version: {torch.__version__}")
        print(f"🔧 CUDA available: {cuda_ready}")

        if cuda_ready:
            print(f"🔧 CUDA version: {torch.version.cuda}")
            print(f"🔧 GPU: {torch.cuda.get_device_name(0)}")

        # --- Build and serve ------------------------------------------------
        rag_system = HebrewMedicalRAG()
        interface = create_gradio_interface(rag_system)

        print("\n🎉 Launching Hebrew Medical RAG Chatbot! A public URL will be generated shortly.")
        interface.launch(share=True, debug=False, show_error=True, quiet=False)

    except Exception as e:
        import traceback

        print(f"\n❌ A critical error occurred: {e}")
        traceback.print_exc()

        # Try to free GPU memory
        clear_gpu_memory()

if __name__ == "__main__":
    main()