# In [None]:
# MedQA Retrieval System - Complete Implementation

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
import faiss
import streamlit as st
from typing import List, Dict, Tuple

# 1. Data Preparation

def load_medqa_data(file_path: str) -> pd.DataFrame:
    """Load a MedQA dataset stored as a single JSON document.

    Args:
        file_path: Path to a JSON file whose top-level value can be handed
            to ``pandas.DataFrame`` (for example a list of records).

    Returns:
        A DataFrame built from the parsed JSON.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Pin the encoding: without it open() falls back to the platform default,
    # which breaks on non-ASCII text when that default is not UTF-8.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return pd.DataFrame(data)

def process_textbooks(textbook_dir: str) -> List[Dict[str, str]]:
    """Split every ``.txt`` file in a directory into blank-line separated sections.

    Args:
        textbook_dir: Directory containing the textbook ``.txt`` files. Only
            direct children are read; other extensions are ignored.

    Returns:
        One dict per non-empty section with the keys ``'id'``
        (``"<filename>_<position>"``, where position is the index of the
        section within the file's ``'\\n\\n'`` split) and ``'content'`` (the
        section text with surrounding whitespace removed). Files are visited
        in sorted filename order.

    Raises:
        OSError: If the directory or one of its files cannot be read.
    """
    sections = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # document order (and therefore index positions) reproducible across runs.
    for filename in sorted(os.listdir(textbook_dir)):
        if not filename.endswith('.txt'):
            continue
        with open(os.path.join(textbook_dir, filename), 'r', encoding='utf-8') as f:
            content = f.read()
        # Split content into sections (this is a simple split, you might need
        # a more sophisticated approach).
        for i, section in enumerate(content.split('\n\n')):
            text = section.strip()
            # Runs of blank lines produce empty chunks; they carry no
            # information and would only pollute the search indexes.
            if not text:
                continue
            sections.append({
                'id': f"{filename}_{i}",
                'content': text,
            })
    return sections

# 2. Embedding and Indexing

class Embedder:
    """Turn texts into fixed-size vectors with a Hugging Face encoder model."""

    def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        """Load the tokenizer and encoder weights for ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def embed(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Create embeddings for a list of texts.

        Texts are encoded in batches so that a large corpus does not have to
        fit into a single forward pass. Each embedding is the mean of the
        token vectors, weighted by the attention mask so padding does not
        contribute; this is the pooling the default sentence-transformers
        model is documented to use.

        Args:
            texts: Texts to embed.
            batch_size: Number of texts encoded per forward pass.

        Returns:
            A C-contiguous ``float32`` array of shape ``(len(texts), dim)``.
            For empty input the shape is ``(0, model.config.hidden_size)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        texts = list(texts)
        if not texts:
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        pooled_batches = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            encoded_input = self.tokenizer(
                batch, padding=True, truncation=True, return_tensors='pt'
            )
            with torch.no_grad():
                model_output = self.model(**encoded_input)
                hidden = model_output.last_hidden_state
                # Zero out padded positions before averaging so the result
                # does not depend on how much padding the batch needed.
                mask = encoded_input['attention_mask'].unsqueeze(-1).to(hidden.dtype)
                summed = (hidden * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1e-9)
                pooled_batches.append((summed / counts).cpu().numpy())

        return np.ascontiguousarray(
            np.concatenate(pooled_batches, axis=0), dtype=np.float32
        )

def create_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    """Create a flat (exact) L2 FAISS index from a matrix of embeddings.

    Args:
        embeddings: Array of shape ``(num_vectors, dim)``. It is converted to
            a C-contiguous ``float32`` array before being added.

    Returns:
        A ``faiss.IndexFlatL2`` of dimension ``dim`` holding all the rows.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional.
    """
    # FAISS only accepts C-contiguous float32 matrices; coerce up front so
    # float64 or sliced inputs do not fail inside the native call.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (num_vectors, dim), got shape {vectors.shape}"
        )
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index

# 3. Hybrid Search Implementation

def sparse_retrieval(query: str, documents: List[str], top_k: int = 5) -> List[Tuple[int, float]]:
    """Rank documents against a query with TF-IDF and cosine similarity.

    The TF-IDF vocabulary is fitted on ``documents`` on every call.

    Args:
        query: Free-text query.
        documents: Corpus to rank.
        top_k: Maximum number of results to return.

    Returns:
        Up to ``top_k`` ``(document_index, similarity)`` pairs, best first,
        as plain ``int``/``float`` values. An empty list is returned when
        ``top_k`` is not positive or ``documents`` is empty.
    """
    # Guard first: with top_k == 0 the slice [-0:] would select *every*
    # document, and the vectorizer raises on an empty corpus.
    if top_k <= 0 or len(documents) == 0:
        return []

    vectorizer = TfidfVectorizer()
    doc_vectors = vectorizer.fit_transform(documents)
    query_vector = vectorizer.transform([query])
    similarities = cosine_similarity(query_vector, doc_vectors).flatten()
    # argsort is ascending: reverse for best-first, then keep the head.
    top_indices = similarities.argsort()[::-1][:top_k]
    return [(int(i), float(similarities[i])) for i in top_indices]

def dense_retrieval(query_embedding: np.ndarray, index: faiss.IndexFlatL2, top_k: int = 5) -> List[Tuple[int, float]]:
    """Find the nearest indexed vectors to a query embedding using FAISS.

    Args:
        query_embedding: A single query vector (any shape that flattens to
            one row); it is converted to ``float32`` for FAISS.
        index: FAISS index to search.
        top_k: Maximum number of neighbours to return.

    Returns:
        Up to ``top_k`` ``(vector_index, score)`` pairs in the order FAISS
        returns them, where ``score = 1 / (1 + distance)`` maps the distance
        into ``(0, 1]``. An empty list is returned when ``top_k`` is not
        positive.
    """
    if top_k <= 0:
        return []

    query = np.ascontiguousarray(query_embedding, dtype=np.float32).reshape(1, -1)
    distances, indices = index.search(query, top_k)
    # When the index holds fewer than top_k vectors FAISS pads the result
    # with label -1; drop those so callers never use -1 as a list index.
    return [
        (int(idx), float(1.0 / (1.0 + dist)))  # Convert distance to similarity score
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]

def hybrid_search(query: str, dense_index: faiss.IndexFlatL2, documents: List[str], embedder: Embedder, alpha: float = 0.5, top_k: int = 5) -> List[Tuple[int, float]]:
    """Combine dense (FAISS) and sparse (TF-IDF) retrieval into one ranking.

    Each retriever contributes its own ``top_k`` candidates. A candidate's
    final score is ``alpha * dense_score + (1 - alpha) * sparse_score``; a
    retriever that did not return the candidate contributes nothing.

    Args:
        query: Free-text query.
        dense_index: FAISS index built over the embeddings of ``documents``.
        documents: Corpus the returned indices refer to.
        embedder: Object whose ``embed`` method encodes the query.
        alpha: Weight of the dense score; ``1 - alpha`` weights the sparse one.
        top_k: Maximum number of results to return.

    Returns:
        Up to ``top_k`` ``(document_index, combined_score)`` pairs, best
        first. An empty list is returned when ``top_k`` is not positive or
        ``documents`` is empty.
    """
    num_documents = len(documents)
    if top_k <= 0 or num_documents == 0:
        return []

    query_embedding = embedder.embed([query])
    dense_results = dense_retrieval(query_embedding, dense_index, top_k)
    sparse_results = sparse_retrieval(query, documents, top_k)

    combined_results = {}
    weighted_sources = ((alpha, dense_results), (1 - alpha, sparse_results))
    for weight, results in weighted_sources:
        for idx, score in results:
            idx = int(idx)
            # FAISS reports missing neighbours as -1; such an index would
            # silently resolve to documents[-1] in callers, so discard
            # anything that does not address a real document.
            if not 0 <= idx < num_documents:
                continue
            combined_results[idx] = combined_results.get(idx, 0.0) + weight * float(score)

    ranked = sorted(combined_results.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]

# 4. Question Answering Model

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

class QuestionAnsweringModel:
    """Generate answers from a question and retrieved context with a seq2seq model."""

    def __init__(self, model_name: str = 't5-base'):
        """Load the tokenizer and seq2seq weights for ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def generate_answer(self, question: str, context: str) -> str:
        """Generate an answer to ``question`` given ``context``.

        The prompt is ``"question: <question> context: <context>"``. Because
        the question comes first, truncation of an over-long prompt removes
        the tail of the context rather than the question.

        Args:
            question: The user's question.
            context: Retrieved text the answer should be based on.

        Returns:
            The decoded model output (at most 50 generated tokens) with
            special tokens removed.
        """
        input_text = f"question: {question} context: {context}"
        # Several retrieved sections joined together can exceed the model's
        # input limit; truncate to the tokenizer's configured maximum.
        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True)
        # Forward the attention mask together with the ids.
        outputs = self.model.generate(**inputs, max_length=50)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# 5. Streamlit UI

def create_ui(qa_model: QuestionAnsweringModel, embedder: Embedder, dense_index: faiss.IndexFlatL2, documents: List[str]):
    """Render the Streamlit chatbot page.

    Shows a question box and an alpha slider. When a question is entered,
    runs hybrid search, generates an answer from the concatenated retrieved
    documents, and lists each retrieved document with its score.

    Args:
        qa_model: Model used to generate the answer.
        embedder: Embedder used to encode the query for dense retrieval.
        dense_index: FAISS index built over ``documents``.
        documents: Corpus that search results index into.
    """
    st.title("MedQA Chatbot")

    query = st.text_input("Ask a medical question:")
    alpha = st.slider("Set alpha value for hybrid search:", 0.0, 1.0, 0.5)

    # Treat whitespace-only input as "no question" to avoid a pointless
    # retrieval and generation round trip.
    if query and query.strip():
        results = hybrid_search(query, dense_index, documents, embedder, alpha)
        context = " ".join(documents[idx] for idx, _ in results)
        answer = qa_model.generate_answer(query, context)
        st.write(f"Answer: {answer}")

        st.subheader("Retrieved Contexts:")
        for idx, score in results:
            st.write(f"Score: {score:.4f}")
            # Show the passage itself, not just its score, so the user can
            # see what the answer was based on.
            st.write(documents[idx])

    st.markdown("---")
    st.subheader("About")
    st.write("This chatbot uses a hybrid search method to retrieve relevant information from medical textbooks and generate answers to your questions.")

# 6. Evaluation

def evaluate_system(qa_model: QuestionAnsweringModel, embedder: Embedder, dense_index: faiss.IndexFlatL2, documents: List[str], test_data: pd.DataFrame) -> Dict[str, float]:
    """Measure exact-match accuracy of the retrieval + generation pipeline.

    For every row, retrieves context with hybrid search, generates an
    answer, and counts it as correct when it equals the reference answer
    ignoring case and surrounding whitespace.

    Args:
        qa_model: Model used to generate answers.
        embedder: Embedder used to encode each question.
        dense_index: FAISS index built over ``documents``.
        documents: Corpus that search results index into.
        test_data: DataFrame with ``'question'`` and ``'answer'`` columns.

    Returns:
        ``{"accuracy": fraction_correct}``; the accuracy is ``0.0`` when
        ``test_data`` has no rows.
    """
    total = len(test_data)
    # An empty split would otherwise end in a division by zero.
    if total == 0:
        return {"accuracy": 0.0}

    correct = 0
    for _, row in test_data.iterrows():
        question = row['question']
        correct_answer = row['answer']

        results = hybrid_search(question, dense_index, documents, embedder)
        context = " ".join(documents[idx] for idx, _ in results)
        generated_answer = qa_model.generate_answer(question, context)

        # str() keeps non-string reference values (e.g. missing entries)
        # from raising; strip() ignores incidental whitespace.
        if generated_answer.strip().lower() == str(correct_answer).strip().lower():
            correct += 1

    return {"accuracy": correct / total}

# Main execution

@st.cache_resource
def _load_resources(medqa_path: str, textbook_dir: str):
    """Load the data, models and index once and reuse them across reruns."""
    # Load and process data
    medqa_data = load_medqa_data(medqa_path)
    textbook_sections = process_textbooks(textbook_dir)
    documents = [section['content'] for section in textbook_sections]

    # Create embeddings and index
    embedder = Embedder()
    doc_embeddings = embedder.embed(documents)
    dense_index = create_faiss_index(doc_embeddings)

    # Initialize QA model
    qa_model = QuestionAnsweringModel()

    return medqa_data, documents, embedder, dense_index, qa_model


def main(medqa_path: str = 'path_to_medqa_us_data.json',
         textbook_dir: str = 'path_to_textbooks_directory') -> None:
    """Run the Streamlit app: build resources, render the UI, offer evaluation.

    Args:
        medqa_path: Path to the MedQA JSON file.
        textbook_dir: Directory containing the textbook ``.txt`` files.
    """
    # Streamlit re-executes the script on every widget interaction; without
    # caching, both models and the index would be rebuilt for each keystroke.
    medqa_data, documents, embedder, dense_index, qa_model = _load_resources(
        medqa_path, textbook_dir
    )

    # Create UI
    create_ui(qa_model, embedder, dense_index, documents)

    # Evaluation runs the full pipeline over the whole test split, so only do
    # it on request instead of on every rerun.
    st.sidebar.subheader("System Evaluation")
    if st.sidebar.button("Run evaluation"):
        test_data = medqa_data[medqa_data['split'] == 'test']
        evaluation_results = evaluate_system(qa_model, embedder, dense_index, documents, test_data)
        st.sidebar.write(f"Accuracy: {evaluation_results['accuracy']:.2f}")


if __name__ == "__main__":
    main()