<a href="https://colab.research.google.com/github/MuhammadAttaUrRehman/RAG-Medical-Q-A-System/blob/main/RAG_Medical_Q%26A.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🔬 RAG Medical Q&A System with Transfer Learning - FIXED VERSION
# Implementation using Real Kaggle Medical Dataset
# Dataset: Comprehensive Medical Q&A Dataset by TheDevastator

# ===== STEP 1: INSTALLATION AND SETUP =====

!pip install transformers datasets torch sentence-transformers faiss-cpu pandas numpy scikit-learn evaluate rouge-score nltk kaggle wandb

import os
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration,
    TrainingArguments, Trainer
)
from sentence_transformers import SentenceTransformer
import faiss
import torch
from sklearn.model_selection import train_test_split
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
import json
import re
import warnings
import pickle
import glob
import wandb
import zipfile
import shutil
warnings.filterwarnings('ignore')

import nltk
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception:
    print("NLTK data already available or download failed")

# Authenticate with Weights & Biases. Set WANDB_API_KEY as an environment variable
# beforehand (e.g. `export WANDB_API_KEY=your_key`); never hard-code a real key in a
# shared notebook. In interactive sessions wandb.login() prompts for a key if it is unset.
wandb.login()

print("🚀 All dependencies installed successfully!")

# ===== STEP 2: KAGGLE DATASET DOWNLOAD =====
def setup_kaggle_and_download():
    """Setup Kaggle API and download the medical dataset"""
    print("📥 Setting up Kaggle API...")

    # Check if running in Colab for file upload
    try:
        from google.colab import files
        print("""
        📋 KAGGLE SETUP INSTRUCTIONS:
        1. Go to Kaggle.com → Account → API → Create New API Token
        2. Download kaggle.json file
        3. Upload it using the file upload prompt
        """)
        uploaded = files.upload()

        # Set up Kaggle credentials (kaggle.json restricted to the owner, as the Kaggle CLI recommends)
        kaggle_dir = os.path.expanduser('~/.kaggle')
        os.makedirs(kaggle_dir, exist_ok=True)
        for filename in uploaded.keys():
            if filename.startswith('kaggle') and filename.endswith('.json'):
                target = os.path.join(kaggle_dir, 'kaggle.json')
                shutil.copy(filename, target)
                os.chmod(target, 0o600)
                print("✅ Kaggle API setup complete!")
                break
        else:
            raise FileNotFoundError("No kaggle.json file uploaded")

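    except ImportError:
        print("⚠️ Not in Colab environment. Ensure kaggle.json is in ~/.kaggle/ or set KAGGLE_USERNAME and KAGGLE_KEY environment variables.")
        # Accept either credential source, as the message above promises
        has_env_credentials = bool(os.environ.get('KAGGLE_USERNAME') and os.environ.get('KAGGLE_KEY'))
        if not has_env_credentials and not os.path.exists(os.path.expanduser('~/.kaggle/kaggle.json')):
            raise FileNotFoundError("No Kaggle credentials found (kaggle.json or KAGGLE_USERNAME/KAGGLE_KEY)")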

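    # Download the medical Q&A dataset; os.system returns a non-zero status on failure,
    # so check it explicitly to trigger the backup-dataset fallback
    print("📊 Downloading Comprehensive Medical Q&A Dataset...")
    status = os.system('kaggle datasets download -d thedevastator/comprehensive-medical-q-a-dataset')
    if status != 0:
        raise RuntimeError(f"Kaggle download failed (status {status})")

    # Extract the dataset with the standard-library zipfile module
    print("📂 Extracting dataset...")
    with zipfile.ZipFile('comprehensive-medical-q-a-dataset.zip') as archive:
        archive.extractall('.')

    print("✅ Dataset downloaded and extracted!")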

def load_kaggle_medical_dataset():
    """Load and explore the Kaggle medical dataset"""
    csv_files = glob.glob("*.csv")

    if not csv_files:
        print("⚠️ CSV files not found. Available files:")
        os.system('ls -la')
        return None

    print(f"📂 Found CSV files: {csv_files}")
    # Treat the largest CSV as the main data file
    main_file = max(csv_files, key=os.path.getsize)
    print(f"📖 Loading main dataset: {main_file}")

    try:
        df = pd.read_csv(main_file)
        print(f"✅ Dataset loaded successfully!")
        print(f"📊 Dataset shape: {df.shape}")
        print(f"📋 Columns: {list(df.columns)}")
        print("\n🔍 Sample data:")
        print(df.head())
        return df
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return None

def load_backup_medical_dataset():
    """Load backup medical dataset from alternative source"""
    print("🔄 Using backup comprehensive medical dataset...")
    medical_qa_data = [
        {"question": "What is hypertension?", "answer": "Hypertension, also known as high blood pressure, is a condition in which the blood vessels have persistently raised pressure. Blood pressure is created by the force of blood pushing against the walls of blood vessels as it is pumped by the heart. The higher the pressure, the harder the heart has to pump. Normal blood pressure is 120/80 mmHg. Hypertension is defined as blood pressure above 140/90 mmHg.", "category": "Cardiovascular"},
        {"question": "What are the symptoms of diabetes?", "answer": "The main symptoms of diabetes include increased thirst (polydipsia), frequent urination (polyuria), unexplained weight loss, extreme fatigue, blurred vision, slow-healing cuts and wounds, and tingling or numbness in hands and feet. In type 1 diabetes, symptoms often develop quickly over weeks or months. In type 2 diabetes, symptoms develop more gradually and may be mild initially.", "category": "Endocrine"},
        {"question": "How is pneumonia treated?", "answer": "Pneumonia treatment depends on the type, severity, and patient's overall health. Bacterial pneumonia is typically treated with antibiotics such as amoxicillin, azithromycin, or fluoroquinolones. Viral pneumonia may be treated with antiviral medications. Supportive care includes rest, adequate fluid intake, fever reducers like acetaminophen or ibuprofen, and oxygen therapy if needed. Severe cases may require hospitalization and IV antibiotics.", "category": "Respiratory"},
        {"question": "What causes migraine headaches?", "answer": "Migraine headaches are caused by changes in the brain and surrounding blood vessels. Common triggers include hormonal changes (especially in women), certain foods (aged cheeses, processed meats, alcohol), stress, lack of sleep, bright lights, strong smells, weather changes, and certain medications. Genetics also play a role, as migraines often run in families. The exact mechanism involves abnormal brain activity affecting nerve signals, chemicals, and blood vessels.", "category": "Neurological"},
        {"question": "What is the treatment for depression?", "answer": "Depression treatment typically involves a combination of psychotherapy and medication. Common antidepressants include SSRIs (like sertraline, fluoxetine), SNRIs (like venlafaxine), and tricyclics. Psychotherapy options include cognitive-behavioral therapy (CBT), interpersonal therapy, and psychodynamic therapy. Lifestyle changes such as regular exercise, healthy diet, adequate sleep, stress management, and social support are also important. Severe cases may require hospitalization or electroconvulsive therapy (ECT).", "category": "Mental Health"},
        {"question": "What are the risk factors for heart disease?", "answer": "Major risk factors for heart disease include high blood pressure, high cholesterol, smoking, diabetes, obesity, physical inactivity, unhealthy diet, excessive alcohol consumption, and chronic stress. Non-modifiable risk factors include age (risk increases with age), gender (men have higher risk at younger ages), family history, and ethnicity. Other factors include sleep apnea, chronic kidney disease, and certain autoimmune conditions.", "category": "Cardiovascular"},
        {"question": "How is asthma managed?", "answer": "Asthma management involves both long-term control and quick-relief treatments. Long-term control medications include inhaled corticosteroids (like fluticasone), long-acting beta-agonists, leukotriene modifiers, and combination inhalers. Quick-relief medications include short-acting beta-agonists (like albuterol). Non-medication strategies include identifying and avoiding triggers, regular exercise, maintaining healthy weight, and having an asthma action plan. Regular monitoring with peak flow meters and routine medical check-ups are essential.", "category": "Respiratory"},
        {"question": "What are the symptoms of COVID-19?", "answer": "COVID-19 symptoms can range from mild to severe and may appear 2-14 days after exposure. Common symptoms include fever, cough, shortness of breath, fatigue, muscle aches, headache, loss of taste or smell, sore throat, congestion, nausea, vomiting, and diarrhea. Severe symptoms include difficulty breathing, persistent chest pain, confusion, inability to wake up, and bluish lips or face. Symptoms can vary greatly between individuals, and some people may be asymptomatic.", "category": "Infectious Disease"},
        {"question": "What causes kidney stones?", "answer": "Kidney stones form when certain substances in urine become highly concentrated and crystallize. Common causes include dehydration, high sodium diet, excessive animal protein consumption, high oxalate foods (spinach, nuts, chocolate), low calcium intake, certain medications, family history, obesity, and certain medical conditions like hyperparathyroidism or inflammatory bowel disease. Different types of stones (calcium oxalate, uric acid, struvite, cystine) have different risk factors.", "category": "Urology"},
        {"question": "How is arthritis treated?", "answer": "Arthritis treatment varies by type but generally includes medications, physical therapy, and lifestyle modifications. For osteoarthritis: pain relievers (acetaminophen, NSAIDs), topical creams, joint injections, and physical therapy. For rheumatoid arthritis: disease-modifying antirheumatic drugs (DMARDs), biologics, corticosteroids, and NSAIDs. General measures include regular exercise, weight management, hot/cold therapy, occupational therapy, and in severe cases, joint replacement surgery.", "category": "Rheumatology"}
    ]

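    # Add simple paraphrased variants of symptom and treatment questions
    expanded_data = []
    for item in medical_qa_data:
        expanded_data.append(item)
        question = item["question"]
        if "symptoms of " in question:
            condition = question.split("symptoms of ")[-1].rstrip("?")
            expanded_data.append({
                "question": f"What are the signs of {condition}?",
                "answer": item["answer"],
                "category": item["category"]
            })
        # Handle both "How is X treated?" and "What is the treatment for X?"
        if "treatment for " in question:
            condition = question.split("treatment for ")[-1].rstrip("?")
        elif " treated" in question:
            condition = question.split("is ")[-1].split(" treated")[0]
        else:
            condition = None
        if condition:
            expanded_data.append({
                "question": f"How do you manage {condition}?",
                "answer": item["answer"],
                "category": item["category"]
            })

    # Repetition has no effect after prepare_dataset() drops duplicate questions
    final_data = expanded_data * 200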
    return pd.DataFrame(final_data)
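Because `prepare_dataset` later calls `drop_duplicates(subset=['question'])`, repeating the backup rows does not enlarge what reaches training. A small sketch with two toy rows (hypothetical data, shortened from the list above) makes the collapse concrete:

```python
import pandas as pd

# Two toy Q&A rows (hypothetical), repeated 200 times like the backup loader does
rows = [
    {"question": "What is hypertension?", "answer": "High blood pressure.", "category": "Cardiovascular"},
    {"question": "How is asthma managed?", "answer": "Inhalers and trigger avoidance.", "category": "Respiratory"},
]
repeated = pd.DataFrame(rows * 200)

# Deduplicating on the question column collapses the copies back to the originals
deduplicated = repeated.drop_duplicates(subset=["question"])

print(len(repeated))      # 400
print(len(deduplicated))  # 2
```

To genuinely enlarge the backup set, new questions (not copies) are needed.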

# Load dataset
try:
    setup_kaggle_and_download()
    df = load_kaggle_medical_dataset()
    if df is None:
        raise RuntimeError("Failed to load Kaggle dataset")
except Exception as e:
    print(f"⚠️ Kaggle download failed: {e}")
    df = load_backup_medical_dataset()  # prints its own "Using backup..." message

print(f"📊 Final dataset shape: {df.shape}")
print("📋 Category distribution before preprocessing:")
print(df.get('qtype', df.get('category', pd.Series(['Unknown']*len(df)))).value_counts())

# ===== STEP 3: DATA PREPROCESSING =====
class MedicalDataPreprocessor:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained('t5-small')
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def clean_text(self, text):
        if pd.isna(text):
            return ""
        text = str(text)
        text = re.sub(r'\s+', ' ', text)
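        # Keep word characters, whitespace and common punctuation (including ? % ' +);
        # strip other symbols
        text = re.sub(r"[^\w\s\-.,;:()\[\]/?%'+]", '', text)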
        return text.strip()

    def standardize_columns(self, df):
        question_cols = ['question', 'Question', 'query', 'Query', 'input', 'Input']
        answer_cols = ['answer', 'Answer', 'response', 'Response', 'output', 'Output', 'text', 'Text']
        category_cols = ['category', 'Category', 'type', 'Type', 'class', 'Class', 'label', 'Label', 'qtype']

        question_col = None
        answer_col = None
        category_col = None

        for col in df.columns:
            if col in question_cols:
                question_col = col
            elif col in answer_cols:
                answer_col = col
            elif col in category_cols:
                category_col = col

        if question_col is None:
            question_col = df.columns[0]
        if answer_col is None:
            answer_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
        if category_col is None and len(df.columns) > 2:
            category_col = df.columns[2]

        standardized_df = pd.DataFrame()
        standardized_df['question'] = df[question_col].apply(self.clean_text)
        standardized_df['answer'] = df[answer_col].apply(self.clean_text)

        if category_col and category_col in df.columns:
            standardized_df['category'] = df[category_col].fillna('General')
        else:
            standardized_df['category'] = 'Medical'

        standardized_df = standardized_df[
            (standardized_df['question'].str.len() > 10) &
            (standardized_df['answer'].str.len() > 20)
        ]

        return standardized_df

    def prepare_dataset(self, df):
        print("🔧 Preprocessing dataset...")

        df = self.standardize_columns(df)
        print(f"📊 Shape after standardization: {df.shape}")

        # Deduplicate based on questions only
        df = df.drop_duplicates(subset=['question'])
        print(f"📊 Shape after deduplication: {df.shape}")

        df = df[
            (df['question'].str.len().between(10, 500)) &
            (df['answer'].str.len().between(20, 1000))
        ]
        print(f"📊 Shape after length filtering: {df.shape}")

        category_counts = df['category'].value_counts()
        valid_categories = category_counts[category_counts >= 2].index
        df = df[df['category'].isin(valid_categories)]
        print(f"📊 Shape after category filtering: {df.shape}")
        print(f"📋 Category counts after filtering:\n{category_counts[valid_categories]}")

        if df.empty:
            raise ValueError("No valid categories with sufficient samples after filtering.")

        print(f"✅ Preprocessed dataset shape: {df.shape}")
        print(f"📊 Valid categories: {list(valid_categories)}")

        try:
            train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42, stratify=df['category'])
            val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['category'])
        except ValueError as e:
            print(f"⚠️ Stratified split failed: {e}. Falling back to non-stratified split...")
            train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
            val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

        print(f"📊 Training samples: {len(train_df)}")
        print(f"📊 Validation samples: {len(val_df)}")
        print(f"📊 Test samples: {len(test_df)}")

        print("\n📈 Category distribution in training set:")
        print(train_df['category'].value_counts())

        return train_df, val_df, test_df
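`standardize_columns` matches column names against fixed lists that spell out each capitalization. A case-insensitive lookup is a more compact alternative; the sketch below uses a hypothetical helper `find_column` and an illustrative column list, and is not part of the original pipeline. Unlike the original loop, which keeps the last match, it returns the first match.

```python
def find_column(columns, candidates):
    """Return the first column whose lower-cased name is in `candidates`, else None.

    Hypothetical helper: `candidates` are expected to be lower-case names.
    """
    lookup = {c.lower() for c in candidates}
    for col in columns:
        if str(col).lower() in lookup:
            return col
    return None


columns = ["qtype", "Question", "Answer"]  # illustrative column names
question_col = find_column(columns, ["question", "query", "input"])
answer_col = find_column(columns, ["answer", "response", "output", "text"])
category_col = find_column(columns, ["category", "type", "class", "label", "qtype"])
print(question_col, answer_col, category_col)  # Question Answer qtype
```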

# Preprocess the data
preprocessor = MedicalDataPreprocessor()
train_df, val_df, test_df = preprocessor.prepare_dataset(df)
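`prepare_dataset` splits in two stages: 30% of the rows are held out first, and that holdout is then halved, giving roughly 70/15/15 train/validation/test. A quick check of the arithmetic on 100 dummy rows (stand-ins, not real data):

```python
from sklearn.model_selection import train_test_split

rows = list(range(100))  # dummy stand-in for 100 preprocessed Q&A rows

# Stage 1: hold out 30% of the rows
train, temp = train_test_split(rows, test_size=0.3, random_state=42)
# Stage 2: split the holdout evenly into validation and test
val, test = train_test_split(temp, test_size=0.5, random_state=42)

print(len(train), len(val), len(test))  # 70 15 15
```

With `stratify=` set, scikit-learn additionally requires each held-out split to have at least as many rows as there are categories, which is why the notebook falls back to a non-stratified split on small datasets.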

# ===== STEP 4: RAG SYSTEM IMPLEMENTATION =====
class MedicalRAGSystem:
    def __init__(self):
        self.embedding_model_name = 'all-MiniLM-L6-v2'
        self.generator_model_name = 't5-small'

        print(f"🔧 Initializing RAG system...")
        print(f"📥 Loading embedding model: {self.embedding_model_name}")
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        print(f"📥 Loading generator model: {self.generator_model_name}")
        self.generator_tokenizer = AutoTokenizer.from_pretrained(self.generator_model_name)
        if self.generator_tokenizer.pad_token is None:
            self.generator_tokenizer.pad_token = self.generator_tokenizer.eos_token
        self.generator_model = T5ForConditionalGeneration.from_pretrained(self.generator_model_name)

        self.knowledge_base = None
        self.faiss_index = None

    def build_knowledge_base(self, df):
        print("🔍 Building knowledge base with FAISS indexing...")

        self.knowledge_base = df[['question', 'answer', 'category']].to_dict('records')

        combined_texts = []
        for item in self.knowledge_base:
            combined_text = f"Question: {item['question']} Answer: {item['answer']}"
            combined_texts.append(combined_text)

        print("🧮 Generating embeddings...")
        embeddings = self.embedding_model.encode(combined_texts, show_progress_bar=True)

        dimension = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        self.faiss_index.add(embeddings.astype('float32'))

        print(f"✅ Knowledge base built with {len(self.knowledge_base)} entries")
        print(f"📐 Embedding dimension: {dimension}")

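    def retrieve_relevant_context(self, query, top_k=5):
        # Normalise whitespace/case, embed the query and L2-normalise it so the
        # inner-product index returns cosine similarities
        query_embedding = self.embedding_model.encode([re.sub(r'\s+', ' ', query.strip().lower())])
        faiss.normalize_L2(query_embedding)

        scores, indices = self.faiss_index.search(query_embedding.astype('float32'), top_k)

        retrieved_contexts = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with index -1; also drop weak matches (cosine <= 0.3)
            if 0 <= idx < len(self.knowledge_base) and score > 0.3: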
                context = self.knowledge_base[idx]
                retrieved_contexts.append({
                    'answer': context['answer'],
                    'question': context['question'],
                    'category': context['category'],
                    'score': float(score)
                })

        return retrieved_contexts

    def generate_answer(self, question, max_length=300):
        contexts = self.retrieve_relevant_context(question)

        if not contexts:
            return "Unable to generate answer. Please consult a medical professional.", []

        context_parts = []
        for ctx in contexts[:3]:
            context_parts.append(ctx['answer'])
        context_text = " ".join(context_parts)

        input_text = f"question: {question} context: {context_text}"

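        inputs = self.generator_tokenizer(
            input_text,
            max_length=512,
            truncation=True,
            padding=True,
            return_tensors='pt'
        )
        # Move inputs to the model's device (the Trainer moves the model to GPU when one is available)
        inputs = {k: v.to(self.generator_model.device) for k, v in inputs.items()}

        try:
            # Make sure dropout is off; the model may still be in training mode after fine-tuning
            self.generator_model.eval()
            with torch.no_grad():
                # Deterministic beam search so repeated evaluations give the same answer
                outputs = self.generator_model.generate(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_length=max_length,
                    num_beams=4,
                    early_stopping=True,
                    pad_token_id=self.generator_tokenizer.pad_token_id
                )

            answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
            return answer if answer.strip() else "Unable to generate answer. Please consult a medical professional.", contexts
        except Exception as e:
            print(f"⚠️ Generation failed: {e}")
            return "Unable to generate answer. Please consult a medical professional.", contexts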

# Initialize and test RAG system
rag_system = MedicalRAGSystem()
rag_system.build_knowledge_base(train_df)
test_question = "What are the symptoms of diabetes?"
answer, contexts = rag_system.generate_answer(test_question)
print(f"\n🧪 SYSTEM TEST:")
print(f"Question: {test_question}")
print(f"Generated Answer: {answer}")
print(f"Retrieved {len(contexts)} relevant contexts")
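`build_knowledge_base` stores L2-normalized embeddings in a `faiss.IndexFlatIP`, so the inner product it ranks by equals cosine similarity, and `retrieve_relevant_context` then keeps only hits scoring above 0.3. The same logic in plain NumPy, with made-up 3-dimensional vectors standing in for sentence embeddings (`top_k_cosine` is a hypothetical helper, not a FAISS API):

```python
import numpy as np


def top_k_cosine(query_vec, doc_vecs, k=5, threshold=0.3):
    """Return (index, score) pairs for the k most similar docs scoring above `threshold`."""
    # L2-normalise so that the inner product equals cosine similarity
    docs = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    query = query_vec / np.linalg.norm(query_vec)
    scores = docs @ query
    # Highest scores first, keep at most k, then apply the similarity cut-off
    order = np.argsort(-scores)[:k]
    return [(int(i), float(scores[i])) for i in order if scores[i] > threshold]


docs = np.array([[1.0, 0.0, 0.0],
                 [0.7, 0.7, 0.0],
                 [0.0, 0.0, 1.0]])
query = np.array([1.0, 0.1, 0.0])
print(top_k_cosine(query, docs, k=3))  # docs 0 and 1 pass; doc 2 (cosine 0) is dropped
```

An exhaustive index like this is exact but scans every vector per query, which is fine for a knowledge base of this size.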

# ===== STEP 5: FINE-TUNING PROCESS =====
class MedicalRAGTrainer:
    def __init__(self, rag_system):
        self.rag_system = rag_system
        self.model = rag_system.generator_model
        self.tokenizer = rag_system.generator_tokenizer

    def create_training_examples(self, df, num_samples=None):
        if num_samples:
            df = df.sample(n=min(num_samples, len(df)), random_state=42)

        training_examples = []
        print(f"🔄 Creating training examples from {len(df)} samples...")

        for idx, row in df.iterrows():
            question = row['question']
            true_answer = row['answer']

            contexts = self.rag_system.retrieve_relevant_context(question, top_k=2)
            filtered_contexts = [
                ctx for ctx in contexts
                if ctx['question'].lower().strip() != question.lower().strip()
            ]

            context_text = " ".join([ctx['answer'] for ctx in filtered_contexts[:1]]) if filtered_contexts else ""

            input_text = f"question: {question} context: {context_text}"
            training_examples.append({
                'input': input_text,
                'target': true_answer
            })

        return training_examples

    def tokenize_examples(self, examples):
        inputs = [ex['input'] for ex in examples]
        targets = [ex['target'] for ex in examples]

        input_encodings = self.tokenizer(
            inputs,
            max_length=512,
            truncation=True,
            padding=True
        )

        target_encodings = self.tokenizer(
            targets,
            max_length=200,
            truncation=True,
            padding=True
        )

        # Replace padding in the labels with -100 so the loss ignores pad positions
        pad_id = self.tokenizer.pad_token_id
        labels = [
            [token if token != pad_id else -100 for token in sequence]
            for sequence in target_encodings['input_ids']
        ]

        return {
            'input_ids': input_encodings['input_ids'],
            'attention_mask': input_encodings['attention_mask'],
            'labels': labels
        }

    def fine_tune(self, train_df, val_df, num_epochs=5, batch_size=2):
        print("🎯 Starting fine-tuning process...")

        # Clear GPU memory if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        train_examples = self.create_training_examples(train_df, num_samples=min(2000, len(train_df)))
        val_examples = self.create_training_examples(val_df, num_samples=min(400, len(val_df)))

        train_dataset = Dataset.from_dict(self.tokenize_examples(train_examples))
        val_dataset = Dataset.from_dict(self.tokenize_examples(val_examples))

        training_args = TrainingArguments(
            output_dir='./medical-rag-model',
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=4,
            warmup_steps=200,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=100,
            eval_strategy="steps",
            eval_steps=200,
            save_steps=200,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            remove_unused_columns=False,
            dataloader_pin_memory=False,
            fp16=False,  # T5 checkpoints can overflow in fp16 (NaN loss), so train in full precision
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=self.tokenizer,
        )

        print("🏃 Training started...")
        trainer.train()

        trainer.save_model('./medical-rag-finetuned')
        self.tokenizer.save_pretrained('./medical-rag-finetuned')

        print("✅ Fine-tuning completed!")
        return trainer
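With `per_device_train_batch_size=2` and `gradient_accumulation_steps=4`, each optimizer step consumes 8 examples, so 2,000 sampled training examples give roughly 250 optimizer steps per epoch on one device. The helper below is a hypothetical approximation (it ignores version-specific rounding inside `transformers`) that puts `warmup_steps=200` and `eval_steps=200` in context:

```python
import math


def optimizer_steps(num_examples, batch_size, grad_accum, epochs, num_devices=1):
    """Approximate optimizer updates for a Trainer-style loop (an assumption, not the library's exact formula)."""
    batches_per_epoch = math.ceil(num_examples / (batch_size * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch, steps_per_epoch * epochs


per_epoch, total = optimizer_steps(2000, batch_size=2, grad_accum=4, epochs=5)
print(per_epoch, total)  # 250 1250
```

Under this estimate, warmup covers the first 200 of about 1,250 steps, and evaluation/checkpointing runs about six times. On the small backup dataset (a handful of training rows) the total drops to a few steps, so step 200 is never reached and no intermediate evaluation happens.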

# Fine-tune the model
trainer = MedicalRAGTrainer(rag_system)
trained_model = trainer.fine_tune(train_df, val_df)
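Hugging Face seq2seq models compute cross-entropy with `ignore_index=-100`, so label positions set to -100 (typically padding) do not contribute to the loss; `DataCollatorForSeq2Seq` achieves this by padding labels with -100. A NumPy sketch of that masked mean, using a toy 4-token vocabulary and made-up probabilities (`masked_token_nll` is hypothetical):

```python
import numpy as np

IGNORE_INDEX = -100  # label value that the loss skips


def masked_token_nll(log_probs, labels):
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    log_probs: array of shape (seq_len, vocab_size); labels: array of shape (seq_len,).
    """
    labels = np.asarray(labels)
    keep = labels != IGNORE_INDEX
    positions = np.arange(len(labels))[keep]
    return float(-log_probs[positions, labels[keep]].mean())


# Toy vocabulary where id 0 is padding (as it is for T5)
pad_id = 0
log_probs = np.log(np.array([
    [0.05, 0.15, 0.70, 0.10],   # position 0, true token 2
    [0.05, 0.05, 0.10, 0.80],   # position 1, true token 3
    [0.90, 0.05, 0.03, 0.02],   # position 2, padding
]))
labels = np.array([2, 3, pad_id])
labels = np.where(labels == pad_id, IGNORE_INDEX, labels)  # [2, 3, -100]

print(masked_token_nll(log_probs, labels))  # mean of -log(0.7) and -log(0.8)
```

Without the mask, the padding position would also be scored, rewarding the model for predicting pad tokens.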

# ===== STEP 6: EVALUATION =====
from nltk.translate.bleu_score import SmoothingFunction

class MedicalRAGEvaluator:
    def __init__(self, rag_system):
        self.rag_system = rag_system
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        # Smoothing avoids zero BLEU (and NLTK warnings) when higher-order n-grams have no overlap
        self.bleu_smoothing = SmoothingFunction().method1

    def evaluate_model(self, test_df, num_samples=10):
        print(f"📊 Evaluating model on {min(num_samples, len(test_df))} samples...")

        test_sample = test_df.sample(n=min(num_samples, len(test_df)), random_state=42)
        fallback_answer = "Unable to generate answer. Please consult a medical professional."

        predictions = []
        references = []
        retrieval_scores = []
        bleu_scores = []
        evaluated_index = []  # rows actually scored, so predictions stay aligned with test_sample

        for idx, row in test_sample.iterrows():
            question = row['question']
            true_answer = row['answer']

            try:
                predicted_answer, contexts = self.rag_system.generate_answer(question)
                if predicted_answer.strip() and predicted_answer != fallback_answer:
                    predictions.append(predicted_answer)
                    references.append(true_answer)
                    evaluated_index.append(idx)
                    # Proxy only: share of the top-5 neighbours that passed the similarity threshold
                    retrieval_scores.append(min(len(contexts) / 5, 1.0))
                    bleu_scores.append(sentence_bleu(
                        [true_answer.split()], predicted_answer.split(),
                        weights=(0.25, 0.25, 0.25, 0.25),
                        smoothing_function=self.bleu_smoothing
                    ))
                else:
                    print(f"Skipped empty/invalid prediction for question: {question[:50]}...")
            except Exception as e:
                print(f"Error processing question: {question[:50]}... Error: {e}")
                predictions.append("Error generating answer")
                references.append(true_answer)
                evaluated_index.append(idx)
                retrieval_scores.append(0)
                bleu_scores.append(0)

        rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}

        for pred, ref in zip(predictions, references):
            if pred.strip() and pred != "Error generating answer":
                scores = self.rouge_scorer.score(ref, pred)
                for key in rouge_scores:
                    rouge_scores[key].append(scores[key].fmeasure)
            else:
                for key in rouge_scores:
                    rouge_scores[key].append(0)

        # np.mean of an empty list is NaN, so report 0.0 when nothing was scored
        def safe_mean(values):
            return float(np.mean(values)) if values else 0.0

        evaluation_results = {
            'ROUGE-1': safe_mean(rouge_scores['rouge1']),
            'ROUGE-2': safe_mean(rouge_scores['rouge2']),
            'ROUGE-L': safe_mean(rouge_scores['rougeL']),
            'BLEU': safe_mean(bleu_scores),
            'Retrieval Score': safe_mean(retrieval_scores),
            'Total Samples': len(predictions)
        }

        # Return only the evaluated rows, in the same order as predictions/references
        return evaluation_results, predictions, references, test_sample.loc[evaluated_index]

    def display_evaluation_results(self, results, predictions, references, test_sample):
        print("\n" + "="*60)
        print("🎯 MEDICAL RAG EVALUATION RESULTS")
        print("="*60)

        print("\n📊 OVERALL PERFORMANCE METRICS:")
        for metric, score in results.items():
            if isinstance(score, float):
                print(f"{metric}: {score:.4f}")
            else:
                print(f"{metric}: {score}")

        print("\n📝 SAMPLE PREDICTIONS:")
        print("-" * 60)

        for i in range(min(3, len(predictions))):
            row = test_sample.iloc[i]
            print(f"\n🔍 Example {i+1}:")
            print(f"Category: {row['category']}")
            print(f"Question: {row['question']}")
            print(f"True Answer: {references[i][:150]}...")
            print(f"Predicted Answer: {predictions[i][:150]}...")
            print("-" * 40)
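ROUGE-1 F-measure, the headline metric here, is the harmonic mean of unigram precision and recall between the generated and reference answers. A simplified sketch (lower-casing and splitting on non-alphanumerics, with no stemming, so it can score lower than `rouge_scorer` with `use_stemmer=True`); `rouge1_f` is a hypothetical helper:

```python
import re
from collections import Counter


def _tokens(text):
    """Lower-case and split on anything that is not a letter or digit."""
    return re.findall(r"[a-z0-9]+", text.lower())


def rouge1_f(reference, prediction):
    """Unigram-overlap F1 between two strings (simplified ROUGE-1, no stemming)."""
    ref_counts = Counter(_tokens(reference))
    pred_counts = Counter(_tokens(prediction))
    # Clipped overlap: each unigram counts at most as often as it appears in both
    overlap = sum((ref_counts & pred_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


print(rouge1_f("Asthma is treated with inhalers", "Inhalers treat asthma"))  # 0.5
```

Here only "asthma" and "inhalers" overlap (precision 2/3, recall 2/5). With stemming, "treat" and "treated" would also match and raise the score.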

# Evaluate the model
evaluator = MedicalRAGEvaluator(rag_system)
results, predictions, references, test_sample = evaluator.evaluate_model(test_df)
evaluator.display_evaluation_results(results, predictions, references, test_sample)
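The `Retrieval Score` above is `min(len(contexts) / 5, 1.0)`: the share of the top-5 neighbours that clear the 0.3 similarity cut-off. It does not check whether the right passage was retrieved. A common alternative is hit rate@k, the fraction of questions whose reference answer appears among the top-k retrieved answers; it is only meaningful when the reference answer is actually in the knowledge base (for example, paraphrased questions sharing an answer). A minimal sketch with hypothetical inputs:

```python
def hit_rate_at_k(retrieved_answers, reference_answers, k=5):
    """Fraction of queries whose reference answer is among their first k retrieved answers.

    retrieved_answers: one ranked list of answers per query (hypothetical input format).
    reference_answers: gold answers, aligned with retrieved_answers.
    """
    if not reference_answers:
        return 0.0
    hits = sum(
        ref in ranked[:k]
        for ranked, ref in zip(retrieved_answers, reference_answers)
    )
    return hits / len(reference_answers)


retrieved = [["A", "B", "C"], ["D", "E"], ["F"]]
gold = ["B", "X", "F"]
print(hit_rate_at_k(retrieved, gold, k=2))  # 0.6666666666666666
```

With this notebook's objects, each ranked list could be built as `[c['answer'] for c in rag_system.retrieve_relevant_context(question)]`.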

# ===== STEP 7: REAL-WORLD APPLICATION DEMO =====
class MedicalRAGDemo:
    def __init__(self, rag_system):
        self.rag_system = rag_system

    def run_demo_questions(self):
        print("\n" + "="*60)
        print("🏥 MEDICAL RAG SYSTEM - REAL-WORLD DEMO")
        print("="*60)

        demo_questions = [
            "What are the early signs of heart disease?",
            "How can diabetes be prevented?",
            "What should I do if I have chest pain?",
            "What are the side effects of blood pressure medication?",
            "How is COVID-19 different from the flu?",
            "What lifestyle changes help with high cholesterol?",
            "When should I see a doctor for headaches?",
            "What are the symptoms of a heart attack?",
        ]

        print("🤖 AI Medical Assistant ready! Here are some example consultations:\n")

        for i, question in enumerate(demo_questions, 1):
            print(f"👤 Patient Question {i}: {question}")
            try:
                answer, contexts = self.rag_system.generate_answer(question)
                print(f"🏥 AI Assistant: {answer}")
                if contexts:
                    print(f"📚 Based on {len(contexts)} relevant medical sources")
                    print(f"🏷️ Primary category: {contexts[0]['category']}")
            except Exception as e:
                print(f"❌ Error: {e}")
            print("-" * 50)

        print("\n⚠️ MEDICAL DISCLAIMER:")
        print("This AI system is for educational purposes only.")
        print("Always consult with qualified healthcare professionals for medical advice.")
        print("Do not use this system for emergency medical situations.")

# Run the demo
demo = MedicalRAGDemo(rag_system)
demo.run_demo_questions()

# ===== STEP 8: MODEL PERFORMANCE ANALYSIS =====
def analyze_model_performance():
    print("\n" + "="*60)
    print("📈 DETAILED PERFORMANCE ANALYSIS")
    print("="*60)

    category_performance = {}

    for category in test_df['category'].unique():
        category_data = test_df[test_df['category'] == category]
        if len(category_data) >= 2:
            sample_size = min(5, len(category_data))
            category_sample = category_data.sample(n=sample_size, random_state=42)

            rouge_scores = []
            for _, row in category_sample.iterrows():
                try:
                    answer, _ = rag_system.generate_answer(row['question'])
                    if answer.strip() and answer != "Unable to generate answer. Please consult a medical professional.":
                        rouge_score = evaluator.rouge_scorer.score(row['answer'], answer)
                        rouge_scores.append(rouge_score['rouge1'].fmeasure)
                    else:
                        rouge_scores.append(0)
                except Exception as e:
                    print(f"Error analyzing category {category}: {e}")
                    rouge_scores.append(0)

            category_performance[category] = {
                'avg_rouge1': np.mean(rouge_scores),
                'sample_count': sample_size,
                'total_questions': len(category_data)
            }

    print("\n🏷️ PERFORMANCE BY MEDICAL CATEGORY:")
    print("-" * 50)
    for category, metrics in sorted(category_performance.items(),
                                   key=lambda x: x[1]['avg_rouge1'],
                                   reverse=True):
        print(f"{category:20} | ROUGE-1: {metrics['avg_rouge1']:.3f} | "
              f"Samples: {metrics['sample_count']}/{metrics['total_questions']}")

    print(f"\n🔧 SYSTEM CAPABILITIES SUMMARY:")
    print("-" * 50)
    print(f"📊 Knowledge Base Size: {len(rag_system.knowledge_base):,} medical Q&A pairs")
    print(f"🧮 Embedding Dimension: {rag_system.faiss_index.d}")
    print(f"🎯 Retrieval Method: FAISS with cosine similarity")
    print(f"🤖 Generator Model: {rag_system.generator_model_name}")
    print(f"📝 Max Answer Length: 300 tokens")
    print(f"🔍 Context Window: 512 tokens")

analyze_model_performance()

# ===== STEP 9: SAVE AND EXPORT MODEL =====
def save_complete_system():
    print("\n💾 SAVING COMPLETE RAG SYSTEM")
    print("="*50)

    os.makedirs('./saved_medical_rag', exist_ok=True)

    try:
        rag_system.generator_model.save_pretrained('./saved_medical_rag/generator')
        rag_system.generator_tokenizer.save_pretrained('./saved_medical_rag/generator')
        faiss.write_index(rag_system.faiss_index, './saved_medical_rag/faiss_index.bin')

        with open('./saved_medical_rag/knowledge_base.pkl', 'wb') as f:
            pickle.dump(rag_system.knowledge_base, f)

        train_df.to_csv('./saved_medical_rag/train_data.csv', index=False)
        val_df.to_csv('./saved_medical_rag/val_data.csv', index=False)
        test_df.to_csv('./saved_medical_rag/test_data.csv', index=False)

        with open('./saved_medical_rag/evaluation_results.json', 'w') as f:
            json.dump(results, f, indent=2)

        print("✅ Complete RAG system saved successfully!")
        print("📁 Files saved in './saved_medical_rag/' directory")

        print("\n📋 SAVED FILES:")
        for root, dirs, files in os.walk('./saved_medical_rag'):
            for file in files:
                file_path = os.path.join(root, file)
                file_size = os.path.getsize(file_path) / (1024*1024)
                print(f"  📄 {file}: {file_size:.2f} MB")

        # Zip and download results
        print("\n📦 Zipping and downloading results...")
        output_zip = 'medical_rag_output.zip'
        with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Add saved_medical_rag directory (paths relative to it, so files sit at the archive root)
            for root, dirs, files in os.walk('./saved_medical_rag'):
                for file in files:
                    zipf.write(os.path.join(root, file),
                              os.path.relpath(os.path.join(root, file), './saved_medical_rag'))
            # Add W&B logs if available, under a "wandb/" prefix so they do not
            # get mixed in with (or collide with) the model files at the archive root
            if os.path.exists('./wandb'):
                for root, dirs, files in os.walk('./wandb'):
                    for file in files:
                        file_path = os.path.join(root, file)
                        zipf.write(file_path, os.path.relpath(file_path, '.'))

        print(f"✅ Zipped output to {output_zip}")

        try:
            from google.colab import files as colab_files  # avoid shadowing the os.walk loop variable
            colab_files.download(output_zip)
            print(f"✅ Download initiated for {output_zip}")
        except ImportError:
            print(f"⚠️ Not in Colab. Please manually download {output_zip} from your working directory.")

    except Exception as e:
        print(f"❌ Error saving system: {e}")

save_complete_system()

# ===== STEP 10: USAGE INSTRUCTIONS AND DOCUMENTATION =====
def print_usage_instructions():
    print("\n" + "="*60)
    print("📚 USAGE INSTRUCTIONS & DOCUMENTATION")
    print("="*60)

    print("""
🚀 HOW TO USE THIS MEDICAL RAG SYSTEM:

1️⃣ SETUP:
   • Run all cells in order
   • Ensure kaggle.json is uploaded or configured
   • Wait for dataset preprocessing and model training

2️⃣ MAKING PREDICTIONS:
   question = "What are the symptoms of diabetes?"
   answer, contexts = rag_system.generate_answer(question)
   print(answer)

3️⃣ LOADING SAVED MODEL:
   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
   import faiss
   import pickle

   tokenizer = AutoTokenizer.from_pretrained('./saved_medical_rag/generator')
   model = AutoModelForSeq2SeqLM.from_pretrained('./saved_medical_rag/generator')
   index = faiss.read_index('./saved_medical_rag/faiss_index.bin')
   with open('./saved_medical_rag/knowledge_base.pkl', 'rb') as f:
       knowledge_base = pickle.load(f)

4️⃣ SYSTEM ARCHITECTURE:
   📥 Input Question
   ↓
   🔍 Retrieval (FAISS + Sentence Transformers)
   ↓
   🤖 Generation (Fine-tuned T5)
   ↓
   📤 Medical Answer + Context

5️⃣ KEY FEATURES:
   ✅ Transfer Learning with T5-small
   ✅ FAISS-based efficient retrieval
   ✅ Medical domain fine-tuning
   ✅ Comprehensive evaluation metrics
   ✅ Real-world application demo

6️⃣ EVALUATION METRICS:
   • ROUGE-1, ROUGE-2, ROUGE-L scores
   • BLEU score
   • Retrieval quality assessment
   • Category-wise performance analysis
   • Sample predictions review

7️⃣ APPLICATIONS:
   • Medical Q&A systems
   • Healthcare chatbots
   • Clinical decision support
   • Medical education tools
   • Symptom checking applications

⚠️ IMPORTANT NOTES:
   • This is for educational/research purposes only
   • Always consult healthcare professionals for medical advice
   • Model performance depends on training data quality
   • Fine-tune on domain-specific data for best results
""")

    print("\n🎓 PROJECT COMPLETION CHECKLIST:")
    print("-" * 50)
    checklist = [
        "✅ Dataset loaded (Kaggle or backup)",
        "✅ Data preprocessing completed",
        "✅ RAG system implemented",
        "✅ Transfer learning applied",
        "✅ Model fine-tuning completed",
        "✅ Comprehensive evaluation performed",
        "✅ Real-world demo created",
        "✅ Model saved for future use",
        "✅ Documentation provided"
    ]

    for item in checklist:
        print(f"  {item}")

    print(f"\n🏆 PROJECT STATUS: COMPLETE!")
    print(f"📊 Final Model Performance: ROUGE-1 = {results['ROUGE-1']:.3f}, BLEU = {results['BLEU']:.3f}")
    print(f"🎯 Knowledge Base: {len(rag_system.knowledge_base):,} medical Q&A pairs")

print_usage_instructions()

print("\n🎮 OPTIONAL: Run Interactive Demo")
print("Uncomment the line below to start interactive chat:")
print("# demo.interactive_chat()")

# Uncomment to run interactive demo:
# demo.interactive_chat()

print("\n🎉 Medical RAG System Implementation Complete!")
print("="*60)

Saving kaggle.json to kaggle.json
✅ Kaggle API setup complete!
📊 Downloading Comprehensive Medical Q&A Dataset...
📂 Extracting dataset...
✅ Dataset downloaded and extracted!
📂 Found CSV files: ['train.csv']
📖 Loading main dataset: train.csv
✅ Dataset loaded successfully!
📊 Dataset shape: (16407, 3)
📋 Columns: ['qtype', 'Question', 'Answer']

🔍 Sample data:
             qtype                                           Question  \
0   susceptibility  Who is at risk for Lymphocytic Choriomeningiti...   
1         symptoms  What are the symptoms of Lymphocytic Choriomen...   
2   susceptibility  Who is at risk for Lymphocytic Choriomeningiti...   
3  exams and tests  How to diagnose Lymphocytic Choriomeningitis (...   
4        treatment  What are the treatments for Lymphocytic Chorio...   

                                              Answer  
0  LCMV infections can occur after exposure to fr...  
1  LCMV is most commonly recognized as causing ne...  
2  Individuals of all ages who come i

🔧 Preprocessing dataset...
📊 Shape after standardization: (16406, 3)
📊 Shape after deduplication: (14977, 3)
📊 Shape after length filtering: (8321, 3)
📊 Shape after category filtering: (8320, 3)
📋 Category counts after filtering:
category
information        2057
treatment          1743
inheritance        1188
frequency          1118
causes              427
outlook             338
genetic changes     320
symptoms            283
exams and tests     275
research            246
susceptibility      147
prevention           84
considerations       69
complications        18
stages                7
Name: count, dtype: int64
✅ Preprocessed dataset shape: (8320, 3)
📊 Valid categories: ['information', 'treatment', 'inheritance', 'frequency', 'causes', 'outlook', 'genetic changes', 'symptoms', 'exams and tests', 'research', 'susceptibility', 'prevention', 'considerations', 'complications', 'stages']
📊 Training samples: 5824
📊 Validation samples: 1248
📊 Test samples: 1248
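The split sizes above are exactly 70/15/15 of the 8,320 filtered rows (8,320 × 0.15 = 1,248 each for validation and test, leaving 5,824 for training). The split code is not part of this section, so the block below is only a minimal, dependency-free sketch of a seeded shuffle-and-slice split that yields the same sizes. `split_70_15_15` is a hypothetical helper; the notebook itself may use scikit-learn's `train_test_split` (imported at the top) with different shuffling or stratification.

```python
import random

def split_70_15_15(rows, seed=42):
    """Hypothetical helper: shuffle a copy of `rows` and cut it into
    train/val/test partitions of roughly 70/15/15 percent."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is reproducible
    n = len(shuffled)
    n_val = n * 15 // 100                  # integer arithmetic avoids float rounding surprises
    n_test = n * 15 // 100
    n_train = n - n_val - n_test           # remainder goes to training
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test

train, val, test = split_70_15_15(range(8320))
print(len(train), len(val), len(test))     # 5824 1248 1248
```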

📈 Category distribution 

📥 Loading generator model: t5-small


🔍 Building knowledge base with FAISS indexing...
🧮 Generating embeddings...


✅ Knowledge base built with 5824 entries
📐 Embedding dimension: 384
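The retriever ranks knowledge-base entries by cosine similarity between 384-dimensional sentence embeddings. The index construction is not shown in this section; assuming it is an exact inner-product index such as `faiss.IndexFlatIP` over L2-normalized vectors (for example normalized with `faiss.normalize_L2`), the scores it returns are cosine similarities. The sketch below reproduces that brute-force search in plain Python on toy 3-dimensional vectors so the ranking logic is easy to inspect; `l2_normalize` and `top_k_cosine` are hypothetical helper names, not part of the notebook.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length so that inner product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def top_k_cosine(query, corpus, k=5):
    """Exact brute-force search: return (index, score) pairs for the k corpus
    vectors most similar to `query`, highest cosine similarity first."""
    q = l2_normalize(query)
    scored = []
    for i, doc in enumerate(corpus):
        d = l2_normalize(doc)
        scored.append((i, sum(a * b for a, b in zip(q, d))))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]

# Toy 3-dimensional "embeddings" standing in for the 384-dimensional ones above.
corpus = [[1.0, 0.0, 0.0], [0.7, 0.7, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
hits = top_k_cosine([1.0, 0.1, 0.0], corpus, k=2)
print([i for i, _ in hits])  # [0, 1]
```

If the notebook does use a flat index, FAISS performs this same exact computation in optimized native code rather than approximating it.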

🧪 SYSTEM TEST:
Question: What are the symptoms of diabetes?
Generated Answer: thirsty - urinating
Retrieved 5 relevant contexts
🎯 Starting fine-tuning process...
🔄 Creating training examples from 2000 samples...
🔄 Creating training examples from 400 samples...




🏃 Training started...


wandb: Currently logged in as: rehanreigns123 (rehanreigns123-ist-iislamabad) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step  Training Loss  Validation Loss
 200         2.3776         1.804257
 400         1.7914         1.607861
 600         1.6957         1.544476
 800         1.7137         1.514376
1000         1.6153         1.499445
1200         1.6994         1.492892


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


✅ Fine-tuning completed!
📊 Evaluating model on 10 samples...

🎯 MEDICAL RAG EVALUATION RESULTS

📊 OVERALL PERFORMANCE METRICS:
ROUGE-1: 0.1995
ROUGE-2: 0.0549
ROUGE-L: 0.1393
BLEU: 0.0178
Retrieval Score: 1.0000
Total Samples: 10
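The category analysis in `analyze_model_performance` averages each category's ROUGE-1 F-measure over at most five sampled test questions, scoring empty or fallback answers as 0. To illustrate what that number measures, the sketch below computes ROUGE-1 as clipped unigram overlap and then averages per category, sorted best first. It is a simplified stand-in, assuming lowercase alphanumeric tokenization and no stemming, so its values can differ slightly from the `rouge_score` package used in the notebook; `rouge1_f` and `rouge1_by_category` are hypothetical names.

```python
import re
from collections import Counter, defaultdict

def rouge1_f(reference, prediction):
    """Unigram-overlap F-measure. Tokenization (lowercase, alphanumeric runs)
    approximates, but is not identical to, the rouge_score package."""
    ref = Counter(re.findall(r"[a-z0-9]+", reference.lower()))
    pred = Counter(re.findall(r"[a-z0-9]+", prediction.lower()))
    overlap = sum((ref & pred).values())  # matches clipped to the smaller count
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

def rouge1_by_category(samples):
    """samples: iterable of (category, reference, prediction) tuples.
    Returns {category: mean ROUGE-1 F}, ordered best first."""
    scores = defaultdict(list)
    for category, reference, prediction in samples:
        # Empty predictions count as 0, mirroring the notebook's fallback handling
        scores[category].append(rouge1_f(reference, prediction) if prediction.strip() else 0.0)
    means = {c: sum(v) / len(v) for c, v in scores.items()}
    return dict(sorted(means.items(), key=lambda kv: kv[1], reverse=True))

samples = [
    ("symptoms", "frequent urination and thirst", "thirst and frequent urination"),
    ("symptoms", "blurred vision", ""),
    ("treatment", "insulin therapy", "diet changes"),
]
print(rouge1_by_category(samples))
```

With only five samples per category, these averages are noisy, so the category ranking is a rough indication rather than a reliable comparison.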

📝 SAMPLE PREDICTIONS:
------------------------------------------------------------

🔍 Example 1:
Category: information
Question: What is (are) What I need to know about Erectile Dysfunction
True Answer: Erectile dysfunction is when you cannot get or keep an erection firm enough to have sex. You may have ED if you - can get an erection sometimes, thoug...
Predicted Answer: Erectile dysfunction (ED) is a common type of male sexual dysfunction. It can be a sign of health problems. It can be a sign of health problems. It ca...
----------------------------------------

🔍 Example 2:
Category: frequency
Question: How many people are affected by Osteoarthritis
True Answer: The chance of developing osteoarthritis increases with age. It is estimated that 33.6 (12.4 mil

✅ Download initiated for medical_rag_output.zip
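When several directories are bundled into one archive, the `arcname` passed to `ZipFile.write` decides where each file lands. Computing it relative to each directory's parent keeps the directory's own name as a top-level folder, so model files and W&B logs stay separated. Below is a minimal standard-library sketch of that idea; `zip_directories` is a hypothetical helper, and the demo runs in a temporary directory so it does not touch the real outputs.

```python
import os
import tempfile
import zipfile

def zip_directories(output_zip, directories):
    """Write every file under each directory into one archive, keeping the
    directory's own name as the top-level folder inside the zip."""
    with zipfile.ZipFile(output_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
        for directory in directories:
            if not os.path.isdir(directory):
                continue  # e.g. skip ./wandb when W&B logging was not used
            base = os.path.dirname(os.path.abspath(directory))
            for root, _dirs, files in os.walk(directory):
                for name in files:
                    path = os.path.abspath(os.path.join(root, name))
                    # arcname keeps the top-level folder, e.g. "wandb/..."
                    zipf.write(path, os.path.relpath(path, base))
    return output_zip

# Small self-contained demo in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for folder in ("saved_medical_rag", "wandb"):
        os.makedirs(os.path.join(tmp, folder), exist_ok=True)
        with open(os.path.join(tmp, folder, "config.json"), "w") as f:
            f.write("{}")
    archive = zip_directories(os.path.join(tmp, "out.zip"),
                              [os.path.join(tmp, "saved_medical_rag"),
                               os.path.join(tmp, "wandb"),
                               os.path.join(tmp, "missing")])
    with zipfile.ZipFile(archive) as zf:
        names = sorted(zf.namelist())
print(names)  # ['saved_medical_rag/config.json', 'wandb/config.json']
```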
