# In [None]:
# notebooks/01_text_prototype_improved.ipynb

# ==============================
# Phase 2 - Enhanced Text-only Prototype
# ==============================

# --- Imports ---
import os
import json
import joblib
import faiss
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.metrics import classification_report
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import re
from typing import Dict, List, Any

# =================================
# 1. Load and preprocess dataset
# =================================
def load_and_split_data(data_path: str) -> tuple:
    """Read the diabetes CSV and produce a stratified 80/20 train/test split.

    Args:
        data_path: Location of a CSV file that contains an ``Outcome`` column.

    Returns:
        The tuple ``(X_train, X_test, y_train, y_test)``.

    Raises:
        Exception: Any failure is reported on stdout and then re-raised as-is.
    """
    try:
        frame = pd.read_csv(data_path)
        n_rows, n_cols = frame.shape
        print(f"Dataset loaded: {n_rows} rows, {n_cols} columns")

        # Everything except the label column is treated as a feature.
        features = frame.drop("Outcome", axis=1)
        labels = frame["Outcome"]

        # Fixed seed plus stratification keeps the split reproducible and
        # preserves the class balance on both sides.
        parts = train_test_split(
            features, labels, test_size=0.2, random_state=42, stratify=labels
        )
        X_train, X_test, y_train, y_test = parts

        print(f"Train set: {X_train.shape[0]} samples")
        print(f"Test set: {X_test.shape[0]} samples")
        print(f"Positive class ratio - Train: {y_train.mean():.2f}, Test: {y_test.mean():.2f}")

        return X_train, X_test, y_train, y_test

    except Exception as err:
        print(f"Error loading data: {err}")
        raise

# Load the dataset from a path relative to the working directory and create
# the train/test split that the training and test sections below reuse.
data_path = "../data/raw/diabetes.csv"
X_train, X_test, y_train, y_test = load_and_split_data(data_path)

# =================================
# 2. Train and evaluate baseline classifiers
# =================================
def train_classifiers(X_train, y_train, X_test, y_test) -> tuple:
    """Fit the two baseline classifiers, print their reports and save them.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_test: Held-out features used for the printed reports.
        y_test: Held-out labels used for the printed reports.

    Returns:
        ``(logreg, xgb)``: the fitted logistic-regression and XGBoost models.

    Side effects:
        Creates ``../models`` if needed and writes ``logreg.pkl`` and
        ``xgb.pkl`` into it.
    """

    def _fit_and_report(estimator, report_title):
        # Shared fit -> predict -> report sequence for both estimators.
        estimator.fit(X_train, y_train)
        predictions = estimator.predict(X_test)
        print(report_title)
        print(classification_report(y_test, predictions))
        return estimator

    # Logistic Regression
    print("\n=== Training Logistic Regression ===")
    logistic_model = _fit_and_report(
        LogisticRegression(max_iter=1000, random_state=42),
        "Logistic Regression Report:",
    )

    # XGBoost
    print("\n=== Training XGBoost ===")
    boosted_model = _fit_and_report(
        XGBClassifier(
            use_label_encoder=False,
            eval_metric="logloss",
            random_state=42,
            n_estimators=100,
        ),
        "XGBoost Report:",
    )

    # Persist both fitted estimators.
    os.makedirs("../models", exist_ok=True)
    artifacts = (
        (logistic_model, "../models/logreg.pkl"),
        (boosted_model, "../models/xgb.pkl"),
    )
    for estimator, target_path in artifacts:
        joblib.dump(estimator, target_path)
    print("\nModels saved to ../models/")

    return logistic_model, boosted_model

# Fit both baseline classifiers. `logreg` and `xgb` are bound at module level
# because enhanced_rag_pipeline looks them up by these names.
logreg, xgb = train_classifiers(X_train, y_train, X_test, y_test)

# =================================
# 3. Enhanced FAISS RAG index with metadata
# =================================
def build_knowledge_base() -> tuple:
    """Embed the built-in diabetes guideline snippets and index them in FAISS.

    Returns:
        ``(embedder, index, docs_content, docs_metadata)`` where ``embedder``
        is the SentenceTransformer used for encoding, ``index`` is the FAISS
        index holding one vector per snippet, ``docs_content`` is the list of
        snippet texts, and ``docs_metadata`` is a parallel list of
        ``{"title", "source"}`` dicts.
    """

    # One (title, content, source) triple per guideline snippet.
    catalogue = [
        (
            "WHO Diabetes Definition",
            "Diabetes is a chronic disease characterized by elevated blood glucose levels. When the body cannot produce enough insulin or cannot effectively use the insulin it produces, glucose builds up in the bloodstream.",
            "WHO Guidelines",
        ),
        (
            "Lifestyle Interventions",
            "Lifestyle interventions such as diet and exercise are highly effective in managing diabetes. A combination of regular physical activity and healthy eating can significantly reduce diabetes risk and improve glycemic control.",
            "IDF Clinical Practice Guidelines",
        ),
        (
            "WHO Screening Recommendations",
            "The WHO recommends screening for diabetes in adults with risk factors such as obesity, family history, age over 45, and sedentary lifestyle. Early detection enables timely intervention and better outcomes.",
            "WHO Diabetes Screening Guidelines",
        ),
        (
            "Global Diabetes Statistics",
            "The IDF estimates that over 537 million adults worldwide are living with diabetes, with projections reaching 783 million by 2045. The majority have type 2 diabetes, which is largely preventable.",
            "IDF Diabetes Atlas 10th Edition",
        ),
        (
            "Diabetes Complications",
            "Uncontrolled diabetes can lead to serious complications including cardiovascular disease, kidney disease, diabetic retinopathy leading to blindness, and diabetic neuropathy. These complications are largely preventable with proper management.",
            "WHO Diabetes Complications Report",
        ),
        (
            "Risk Factors Assessment",
            "Key diabetes risk factors include high BMI, glucose intolerance, insulin resistance, family history, age, ethnicity, and gestational diabetes history. Multiple risk factors compound the overall diabetes risk.",
            "ADA Risk Assessment Guidelines",
        ),
        (
            "Diagnostic Criteria",
            "Diabetes diagnosis is confirmed by fasting plasma glucose ≥126 mg/dL, random plasma glucose ≥200 mg/dL with symptoms, or HbA1c ≥6.5%. Oral glucose tolerance test showing 2-hour glucose ≥200 mg/dL also confirms diagnosis.",
            "WHO Diagnostic Criteria",
        ),
    ]

    # Split the catalogue into the text to embed and the metadata used for
    # attribution; both lists share the same ordering as the index rows.
    docs_content = [content for _title, content, _source in catalogue]
    docs_metadata = [
        {"title": title, "source": source}
        for title, _content, source in catalogue
    ]

    print("Loading embedding model...")
    embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    print("Creating document embeddings...")
    vectors = embedder.encode(docs_content, convert_to_numpy=True)

    # Flat L2 index sized to the embedding width.
    embedding_dim = vectors.shape[1]
    faiss_index = faiss.IndexFlatL2(embedding_dim)
    faiss_index.add(vectors)

    print(f"FAISS index built with {len(docs_content)} documents, dimension: {embedding_dim}")

    return embedder, faiss_index, docs_content, docs_metadata

# Build the knowledge base once. retrieve_relevant_docs reads `embedder`,
# `index`, `docs_content` and `docs_metadata` as module-level names.
embedder, index, docs_content, docs_metadata = build_knowledge_base()

# =================================
# 4. Enhanced RAG + LLM pipeline
# =================================
def retrieve_relevant_docs(query: str, k: int = 3) -> List[Dict]:
    """Return up to ``k`` knowledge-base documents closest to ``query``.

    Reads the module-level ``embedder``, ``index``, ``docs_content`` and
    ``docs_metadata``.

    Args:
        query: Free-text query to embed and search for.
        k: Maximum number of documents to return. Values larger than the
            number of indexed documents are capped; non-positive values
            yield an empty list.

    Returns:
        A list of dicts with keys ``rank`` (1-based), ``content``, ``title``,
        ``source`` and ``similarity_score`` (``1 / (1 + distance)``, using the
        distance reported by the index). On any retrieval error the error is
        printed and an empty list is returned (best-effort behaviour).
    """
    try:
        # Never ask for more neighbours than the index holds: FAISS pads the
        # missing slots with label -1, which would otherwise be used as a
        # Python negative index and silently return the last document.
        k = min(k, index.ntotal)
        if k <= 0:
            return []

        query_vec = embedder.encode([query], convert_to_numpy=True)
        distances, indices = index.search(query_vec, k)

        retrieved_docs = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx < 0:
                # Defensive: skip padding labels if the index still emits any.
                continue
            idx = int(idx)
            retrieved_docs.append({
                "rank": len(retrieved_docs) + 1,
                "content": docs_content[idx],
                "title": docs_metadata[idx]["title"],
                "source": docs_metadata[idx]["source"],
                "similarity_score": float(1 / (1 + dist))  # Convert distance to similarity
            })

        return retrieved_docs

    except Exception as e:
        # Retrieval is best-effort: callers receive an empty result instead
        # of an exception.
        print(f"Error in retrieval: {e}")
        return []

def format_risk_level(probability: float) -> str:
    """Map a predicted probability onto a coarse risk label.

    Returns ``"High Risk"`` at 0.7 and above, ``"Moderate Risk"`` from 0.4 up
    to (but excluding) 0.7, and ``"Low Risk"`` for everything else.
    """
    # Bands are checked from the highest threshold down; the first match wins.
    bands = (
        (0.7, "High Risk"),
        (0.4, "Moderate Risk"),
    )
    for threshold, label in bands:
        if probability >= threshold:
            return label
    return "Low Risk"

def create_enhanced_prompt(patient_features: List[float], probability: float, retrieved_docs: List[Dict]) -> str:
    """Assemble the LLM prompt for one patient.

    Args:
        patient_features: Feature values, paired positionally with the eight
            feature names listed below (extra or missing values are dropped
            by ``zip``).
        probability: Predicted diabetes probability; rendered as a percentage
            together with its ``format_risk_level`` label.
        retrieved_docs: Dicts providing ``title``, ``content`` and ``source``.

    Returns:
        The prompt text, which asks the model for a JSON object with
        ``conclusion``, ``reasoning`` and ``sources`` fields.
    """
    # Labels attached to the feature values, in positional order.
    column_labels = (
        "Pregnancies", "Glucose", "BloodPressure", "SkinThickness",
        "Insulin", "BMI", "DiabetesPedigreeFunction", "Age",
    )

    patient_summary = ", ".join(
        f"{label}: {value}" for label, value in zip(column_labels, patient_features)
    )
    risk_label = format_risk_level(probability)

    # One bullet per retrieved guideline, including its source attribution.
    guideline_bullets = []
    for entry in retrieved_docs:
        guideline_bullets.append(
            f"- {entry['title']}: {entry['content']} (Source: {entry['source']})"
        )
    guideline_block = "\n".join(guideline_bullets)

    return f"""You are a medical AI assistant providing diabetes risk assessment explanations.

PATIENT DATA: {patient_summary}
PREDICTED DIABETES RISK: {probability:.1%} ({risk_label})

RELEVANT MEDICAL GUIDELINES:
{guideline_block}

Please provide a JSON response with exactly these fields:
- "conclusion": A clear summary of the diabetes risk assessment
- "reasoning": Detailed explanation of why this risk level was determined, referencing specific patient factors and guidelines
- "sources": List of guideline sources that support this assessment

Ensure the response is valid JSON format."""

def extract_json_from_response(response_text: str) -> Dict:
    """Extract the first valid JSON object embedded in an LLM response.

    The text is scanned for ``{`` characters and a JSON object is decoded
    starting at each one until a decode succeeds. Unlike a greedy
    ``{...}`` match, this still works when the object is followed by further
    text that contains braces.

    Args:
        response_text: Raw text produced by the LLM.

    Returns:
        The decoded JSON object, or a fallback dict with ``conclusion``,
        ``reasoning`` (the stripped raw text) and ``sources`` keys when no
        JSON object can be decoded.
    """
    decoder = json.JSONDecoder()

    start = response_text.find("{")
    while start != -1:
        try:
            # raw_decode parses exactly one JSON value beginning at `start`
            # and ignores whatever follows it.
            candidate, _end = decoder.raw_decode(response_text, start)
        except json.JSONDecodeError:
            # Not valid JSON from this brace; try the next one.
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = response_text.find("{", start + 1)

    # Fallback: create structured response
    return {
        "conclusion": "Unable to parse structured response from LLM",
        "reasoning": response_text.strip(),
        "sources": ["Response parsing error"]
    }

# Model names accepted by enhanced_rag_pipeline.
_SUPPORTED_MODEL_NAMES = ("xgb", "logreg")

# Holds the text-generation pipeline after first use so the LLM is loaded once
# per session instead of on every call.
_LLM_CACHE: Dict[str, Any] = {}


def _get_text_generator():
    """Return the shared text-generation pipeline, creating it on first use."""
    if "generator" not in _LLM_CACHE:
        # max_length is deliberately not set here: output length is bounded
        # by max_new_tokens at call time, and passing both is contradictory.
        _LLM_CACHE["generator"] = pipeline(
            "text-generation",
            model="microsoft/DialoGPT-medium",
            do_sample=True,
            temperature=0.7,
        )
    return _LLM_CACHE["generator"]


def _as_model_input(model, patient_features):
    """Shape one patient's features into a single-row input for ``model``."""
    row = np.array(patient_features).reshape(1, -1)
    # Models fitted on a DataFrame record their column names; reuse them so
    # prediction receives the same named columns as training did.
    feature_names = getattr(model, "feature_names_in_", None)
    if feature_names is not None and len(feature_names) == row.shape[1]:
        return pd.DataFrame(row, columns=list(feature_names))
    return row


def enhanced_rag_pipeline(patient_features: List[float], model_name: str = "xgb") -> Dict[str, Any]:
    """Run classifier prediction, guideline retrieval and LLM explanation.

    Reads the module-level ``xgb`` and ``logreg`` classifiers.

    Args:
        patient_features: Feature values for a single patient, in the column
            order the classifiers were trained on.
        model_name: ``"xgb"`` or ``"logreg"``. Any other value is reported as
            an error instead of silently falling back to another model.

    Returns:
        On success, a dict with ``model_used``, ``prediction`` (1 when the
        probability exceeds 0.5, else 0), ``probability``, ``risk_level``,
        ``retrieved_documents``, ``llm_explanation`` and
        ``raw_llm_response``. On any failure, a dict with ``error``,
        ``model_used`` and ``prediction``/``probability`` set to ``None``;
        no exception is propagated.
    """

    try:
        # 1. Get classifier prediction
        if model_name not in _SUPPORTED_MODEL_NAMES:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {_SUPPORTED_MODEL_NAMES}"
            )
        model = xgb if model_name == "xgb" else logreg
        input_data = _as_model_input(model, patient_features)
        probability = float(model.predict_proba(input_data)[0, 1])
        prediction = int(probability > 0.5)

        # 2. Retrieve relevant documents
        query = f"diabetes risk factors prediction probability {probability:.1%}"
        retrieved_docs = retrieve_relevant_docs(query, k=3)

        # 3. Generate LLM response
        prompt = create_enhanced_prompt(patient_features, probability, retrieved_docs)

        # Reuse the cached LLM (you might want to replace it with an OpenAI
        # API call for better results).
        llm = _get_text_generator()
        response = llm(prompt, max_new_tokens=200, num_return_sequences=1)[0]["generated_text"]

        # Extract the response part (remove the prompt)
        llm_response = response[len(prompt):].strip()

        # 4. Parse LLM response
        structured_response = extract_json_from_response(llm_response)

        # 5. Compile final result
        return {
            "model_used": model_name,
            "prediction": prediction,
            "probability": probability,
            "risk_level": format_risk_level(probability),
            "retrieved_documents": retrieved_docs,
            "llm_explanation": structured_response,
            "raw_llm_response": llm_response
        }

    except Exception as e:
        # Top-level boundary: every failure is converted to an error payload.
        return {
            "error": f"Pipeline error: {str(e)}",
            "model_used": model_name,
            "prediction": None,
            "probability": None
        }

# =================================
# 5. Test enhanced pipeline
# =================================
def test_pipeline():
    """Run the pipeline on two sample patients and print a readable report.

    The first case uses the first row of the module-level ``X_test`` with the
    XGBoost model; the second uses a hand-written profile with logistic
    regression. Cases whose result contains an ``error`` key are reported and
    skipped.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("TESTING ENHANCED RAG PIPELINE")
    print(banner)

    # Test with different patient profiles
    scenarios = [
        {
            "name": "High Risk Patient",
            "features": X_test.iloc[0].tolist(),
            "model": "xgb",
        },
        {
            "name": "Low Risk Patient",
            # Manually created low-risk profile
            "features": [1, 85, 66, 29, 0, 26.6, 0.351, 31],
            "model": "logreg",
        },
    ]

    for number, scenario in enumerate(scenarios, start=1):
        print(f"\n--- Test Case {number}: {scenario['name']} ---")

        outcome = enhanced_rag_pipeline(scenario["features"], scenario["model"])

        if "error" in outcome:
            print(f"ERROR: {outcome['error']}")
            continue

        print(f"Model: {outcome['model_used']}")
        print(f"Prediction: {outcome['prediction']} (Risk: {outcome['risk_level']})")
        print(f"Probability: {outcome['probability']:.1%}")

        print("\nRetrieved Documents:")
        for hit in outcome["retrieved_documents"]:
            print(f"  {hit['rank']}. {hit['title']} (Score: {hit['similarity_score']:.3f})")

        print("\nLLM Explanation:")
        explanation = outcome["llm_explanation"]
        if not isinstance(explanation, dict):
            print(f"  {explanation}")
        else:
            for field, text in explanation.items():
                print(f"  {field.title()}: {text}")

        print("\n" + "-" * 30)

# Run the smoke test; this invokes enhanced_rag_pipeline once per test case.
test_pipeline()

# =================================
# 6. Success criteria validation
# =================================
def validate_success_criteria():
    """Check the prototype's success criteria and print PASS/FAIL for each.

    Verifies that both model files exist, the FAISS index is non-empty,
    retrieval returns three documents, and one pipeline run on the first
    ``X_test`` row completes without an ``error`` key and with all required
    output fields.

    Returns:
        ``True`` when every criterion passed, otherwise ``False``.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("SUCCESS CRITERIA VALIDATION")
    print(rule)

    end_to_end_key = "✓ Pipeline runs end-to-end"
    json_fields_key = "✓ JSON output with required fields"

    models_saved = os.path.exists("../models/logreg.pkl") and os.path.exists("../models/xgb.pkl")
    index_populated = index.ntotal > 0
    retrieval_ok = len(retrieve_relevant_docs("diabetes test", k=3)) == 3

    # Insertion order here is the order in which results are printed.
    criteria = {
        "✓ Classifier trained and saved": models_saved,
        "✓ FAISS index built": index_populated,
        "✓ RAG retrieval working": retrieval_ok,
        end_to_end_key: True,   # Overwritten by the pipeline run below
        json_fields_key: True,  # Overwritten by the pipeline run below
    }

    # Test pipeline execution
    try:
        outcome = enhanced_rag_pipeline(X_test.iloc[0].tolist())

        # Check required output structure
        required = ('prediction', 'probability', 'retrieved_documents', 'llm_explanation')
        missing = [name for name in required if name not in outcome]
        criteria[end_to_end_key] = 'error' not in outcome
        criteria[json_fields_key] = not missing

    except Exception as exc:
        criteria[end_to_end_key] = False
        criteria[json_fields_key] = False
        print(f"Pipeline test failed: {exc}")

    # Print results
    for label, ok in criteria.items():
        verdict = "PASS" if ok else "FAIL"
        print(f"{label}: {verdict}")

    all_passed = all(criteria.values())
    overall = "SUCCESS" if all_passed else "NEEDS IMPROVEMENT"
    print(f"\nOVERALL STATUS: {overall}")

    return all_passed

# Run the success-criteria checks (the boolean result is not used here), then
# print the closing banner regardless of the outcome.
validate_success_criteria()

print("\n" + "="*50)
print("PHASE 2 PROTOTYPE COMPLETE")
print("="*50)