# In [None]:
# ==============================
# Core Imports
# ==============================
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
import re
from typing import List, Dict, Tuple, Optional, Union
import spacy
from collections import defaultdict

# The spaCy English pipeline is optional. Start from the "not available"
# state and only overwrite it when the model loads; code that uses `nlp`
# checks it for truthiness and falls back to basic text processing.
nlp = None
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("Warning: spaCy English model not found. Using basic text processing.")

# ==============================
# Question Splitter
# ==============================
class QuestionSplitter:
    """Heuristically split one user utterance into separate questions.

    The text is cut at every ``?``. Segments with more than
    ``_MIN_WORDS_FOR_CONJUNCTION_SPLIT`` words are additionally cut once at
    the first separator from ``conjunction_patterns`` that yields two
    question-like halves. Everything is rule based; no model is involved.
    """

    # Words that commonly open an English question. A fragment whose first
    # three tokens contain one of them is accepted even without a final "?".
    _QUESTION_WORDS = frozenset([
        'what', 'how', 'when', 'where', 'why', 'who', 'which', 'can',
        'could', 'would', 'should', 'is', 'are', 'do', 'does', 'did',
    ])

    # Only segments longer than this many words are tried for a
    # conjunction split; shorter ones are assumed to be a single question.
    _MIN_WORDS_FOR_CONJUNCTION_SPLIT = 15

    def __init__(self):
        # NOTE: not referenced by any method of this class; kept so that
        # existing attribute access on instances keeps working.
        self.question_patterns = [
            r'\?[^?]*\?',  # Multiple question marks
            r'\?[^\w]*(?=[A-Z])',  # Question mark followed by capital letter
            r'\?[^\w]*(?:what|how|when|where|why|who|which|can|could|would|should|is|are|do|does|did)',  # Question mark followed by question word
        ]

        # Tried in this order by _split_by_conjunctions.
        self.conjunction_patterns = [
            r'\band\b',
            r'\bor\b',
            r'\bthen\b',
            r'\balso\b',
            r'\bplus\b',
            r';',
            r','
        ]

    def split_questions(self, text: str) -> List[str]:
        """Split text into multiple questions.

        Args:
            text: Raw user input, possibly containing several questions.

        Returns:
            The detected questions in input order. Segments that were
            terminated by ``?`` get the ``?`` back; a trailing segment
            without ``?`` is kept as typed when it looks like a question.
            If nothing qualifies, a one-element list with the stripped
            input is returned.
        """
        # re.split always yields at least one element, so the unpacking is safe.
        *terminated, trailing = re.split(r'\?', text)

        questions = []
        for part in terminated:
            question = part.strip()
            if question:
                # Add back the question mark that the split consumed, then
                # give long segments a chance to be split further.
                questions.extend(self._split_by_conjunctions(question + '?'))

        # The text after the last "?" used to be discarded unconditionally,
        # which lost a final question typed without a question mark.
        trailing = trailing.strip()
        if trailing:
            questions.append(trailing)

        # Clean and filter questions
        questions = [q.strip() for q in questions if self._is_valid_question(q)]

        return questions if questions else [text.strip()]

    def _split_by_conjunctions(self, question: str) -> List[str]:
        """Split a long question once at a conjunction, if that gives two questions.

        Args:
            question: A single question segment.

        Returns:
            ``[first_part, second_part]`` when the segment is long enough
            and both halves pass ``_is_valid_question``; otherwise
            ``[question]`` unchanged. The first half is returned without
            an appended ``?``.
        """
        if len(question.split()) <= self._MIN_WORDS_FOR_CONJUNCTION_SPLIT:
            return [question]

        for pattern in self.conjunction_patterns:
            # Split with the same case-insensitivity the detection would use,
            # so "AND" is treated like "and".
            parts = re.split(pattern, question, maxsplit=1, flags=re.IGNORECASE)
            if len(parts) != 2:
                continue

            first_part = parts[0].strip()
            second_part = parts[1].strip()

            # Ensure the second part looks like a question
            if not second_part.endswith('?'):
                second_part += '?'

            if self._is_valid_question(first_part) and self._is_valid_question(second_part):
                return [first_part, second_part]

        return [question]

    def _is_valid_question(self, text: str) -> bool:
        """Check if text looks like a valid question.

        A text qualifies when it has at least 5 characters after stripping
        and either ends with ``?`` or has a question word among its first
        three whitespace-separated tokens.
        """
        text = text.strip()
        if len(text) < 5:
            return False

        if text.endswith('?'):
            return True

        leading_tokens = text.lower().split()[:3]
        return any(token in self._QUESTION_WORDS for token in leading_tokens)

# ==============================
# Conversation Context Manager
# ==============================
class ConversationContext:
    """In-memory record of a conversation: interactions, entities and topics.

    Attributes are plain containers and are read directly by callers.
    """

    def __init__(self):
        # One dict per interaction, oldest first (see add_interaction).
        self.history = []
        # Entity label -> list of entity strings, merged over the conversation
        # with dict.update (a later interaction replaces the list for a label).
        self.entities = {}
        # Every topic ever added, in order, duplicates included.
        self.topics = []
        # Topics of the most recent interaction that supplied any.
        self.last_topics = []

    def add_interaction(self, question: str, answer: str,
                        entities: Optional[Dict] = None,
                        topics: Optional[List] = None):
        """Add interaction to conversation history.

        Args:
            question: The question text to record.
            answer: The answer that was given.
            entities: Optional mapping of entity label to entity strings.
            topics: Optional list of topic names for this interaction.
        """
        # Copy the topic list so later mutation by the caller cannot change
        # the stored state.
        topics = list(topics) if topics else []

        self.history.append({
            'question': question,
            'answer': answer,
            'entities': entities or {},
            'topics': topics,
            # Position of this entry in the history, not a wall-clock time.
            'timestamp': len(self.history)
        })

        # Update global entities and topics
        if entities:
            self.entities.update(entities)
        if topics:
            self.last_topics = list(topics)
            self.topics.extend(topics)

    def get_recent_context(self, n: int = 3) -> List[Dict]:
        """Get last n interactions, oldest first.

        Returns an empty list when ``n`` is zero or negative. Without that
        guard ``history[-0:]`` would return the entire history.
        """
        if n <= 0 or not self.history:
            return []
        return self.history[-n:]

    def get_relevant_entities(self) -> Dict:
        """Return the entity map accumulated over the whole conversation.

        This is the live ``entities`` dict, not a copy and not limited to
        recent interactions.
        """
        return self.entities

    def clear(self):
        """Clear conversation history, entities and topics."""
        self.history = []
        self.entities = {}
        self.topics = []
        self.last_topics = []

# ==============================
# Follow-up Question Processor
# ==============================
class FollowUpProcessor:
    """Entity extraction and pronoun resolution for follow-up questions."""

    def __init__(self):
        # Checked in this order by resolve_references; the first pronoun
        # found in the question is the only one that gets replaced.
        self.pronouns = ['it', 'this', 'that', 'these', 'those', 'they', 'them']
        self.demonstratives = ['this', 'that', 'these', 'those']

    def extract_entities(self, text: str) -> Dict:
        """Extract entities from text using spaCy if available.

        Args:
            text: The text to scan.

        Returns:
            Mapping of entity label to a list of entity strings. With spaCy
            the labels are spaCy's own ``ent.label_`` values; in the
            fallback the only labels are ``'MEDICAL'`` and ``'LEGAL'``,
            each holding lower-cased, de-duplicated keyword hits in order
            of first appearance.
        """
        entities = {}

        if nlp:
            doc = nlp(text)
            for ent in doc.ents:
                entities.setdefault(ent.label_, []).append(ent.text)
        else:
            # Basic entity extraction using patterns
            lowered = text.lower()

            # Medical terms
            medical_patterns = r'\b(covid-19|coronavirus|fever|cough|symptoms|treatment|medicine|disease|illness|pain|headache|infection)\b'
            medical_matches = re.findall(medical_patterns, lowered)
            if medical_matches:
                # dict.fromkeys de-duplicates while keeping a stable order;
                # resolve_references picks element [0], so the order matters.
                entities['MEDICAL'] = list(dict.fromkeys(medical_matches))

            # Legal terms
            legal_patterns = r'\b(divorce|will|court|lawyer|legal|law|contract|agreement|property|custody|petition)\b'
            legal_matches = re.findall(legal_patterns, lowered)
            if legal_matches:
                entities['LEGAL'] = list(dict.fromkeys(legal_matches))

        return entities

    def resolve_references(self, question: str, context: "ConversationContext") -> str:
        """Resolve pronouns and references in follow-up questions.

        Args:
            question: The follow-up question as typed.
            context: Conversation state; only ``history`` and
                ``get_recent_context`` are used.

        Returns:
            ``question`` unchanged when there is no history. Otherwise the
            lower-cased question, with one pronoun replaced by an entity
            from the two most recent interactions when one is available:

            * "treat it" / "cure it" prefers a ``'MEDICAL'`` entity;
            * otherwise the first pronoun from ``self.pronouns`` that
              occurs as a whole word is replaced by the first entity of
              the most recent interaction that has any.
        """
        if not context.history:
            return question

        resolved_question = question.lower()

        # Collect candidates newest-first so index 0 is the most recent.
        relevant_entities = []
        medical_entities = []
        for interaction in reversed(context.get_recent_context(2)):
            interaction_entities = interaction.get('entities') or {}
            for entity_list in interaction_entities.values():
                relevant_entities.extend(entity_list)
            medical_entities.extend(interaction_entities.get('MEDICAL') or [])

        if not relevant_entities:
            return resolved_question

        # Handle "how can I treat it?" -> "how can I treat covid?"
        # Runs before the generic rule so a medical entity wins for these
        # phrasings even if another entity type is more recent.
        if medical_entities and re.search(r'\b(?:treat|cure) it\b', resolved_question):
            return self._replace_word(resolved_question, 'it', medical_entities[0])

        # Generic rule. Whole-word matching (instead of requiring spaces on
        # both sides) also catches a pronoun followed by punctuation, such
        # as the very common "...it?".
        for pronoun in self.pronouns:
            if re.search(rf'\b{re.escape(pronoun)}\b', resolved_question):
                return self._replace_word(resolved_question, pronoun, relevant_entities[0])

        return resolved_question

    @staticmethod
    def _replace_word(text: str, word: str, replacement: str) -> str:
        """Replace every whole-word occurrence of ``word`` in ``text``.

        Word boundaries keep substrings of longer words intact ("it" inside
        "with"). The replacement is supplied through a function so that
        backslashes in entity text are inserted literally rather than being
        read as regex group references.
        """
        return re.sub(
            rf'\b{re.escape(word)}\b',
            lambda _match: replacement,
            text,
            flags=re.IGNORECASE,
        )

# ==============================
# Domain Classifier (Enhanced)
# ==============================
class DomainClassifier:
    """Zero-shot check of whether a text belongs to a given domain."""

    def __init__(self, model: str = "facebook/bart-large-mnli"):
        """Create the zero-shot classification pipeline.

        Args:
            model: Model identifier passed to ``pipeline``. Device 0 is
                requested when CUDA is available, otherwise -1.
        """
        self.classifier = pipeline(
            "zero-shot-classification",
            model=model,
            device=0 if torch.cuda.is_available() else -1
        )

    def preprocess_label(self, label: str) -> str:
        """Make labels more descriptive for the NLI model."""
        return f"This text is about {label.lower()} topics"

    def is_relevant(
        self,
        text: str,
        domain: str,
        contrast_labels: Optional[List[str]] = None,
        threshold: float = 0.7,
        rephrase_labels: bool = True,
    ) -> Tuple[bool, Dict[str, Union[float, str, List[str]]]]:
        """Check if a text is relevant to the given domain.

        Args:
            text: The text to classify.
            domain: The domain label of interest.
            contrast_labels: Other labels scored alongside ``domain``;
                a built-in list is used when omitted.
            threshold: The domain score must be strictly greater than this
                for the text to count as relevant.
            rephrase_labels: When true, every label is expanded with
                ``preprocess_label`` before classification.

        Returns:
            ``(relevant, info)`` where ``info`` holds ``domain_score``,
            ``top_label``, ``top_score``, ``all_labels`` and ``all_scores``
            taken from the classifier result. Labels in ``info`` are in the
            form sent to the classifier (rephrased when requested).
        """
        if contrast_labels is None:
            contrast_labels = [
                "sports", "technology", "weather", "health",
                "education", "entertainment", "unrelated", "legal",
            ]

        candidate_labels = [domain, *contrast_labels]
        if rephrase_labels:
            candidate_labels = [self.preprocess_label(label) for label in candidate_labels]

        # The domain is always first, in whichever form is actually sent.
        domain_label = candidate_labels[0]

        # The default contrast list contains "legal", so domain="legal"
        # would otherwise be submitted twice. Drop repeated labels while
        # keeping their order.
        candidate_labels = list(dict.fromkeys(candidate_labels))

        result = self.classifier(text, candidate_labels, multi_label=True)

        # Extract domain score
        domain_score = result["scores"][result["labels"].index(domain_label)]

        return domain_score > threshold, {
            "domain_score": domain_score,
            "top_label": result["labels"][0],
            "top_score": result["scores"][0],
            "all_labels": result["labels"],
            "all_scores": result["scores"]
        }

# ==============================
# Enhanced FAQ Matching
# ==============================
def match_question_with_suggestions(user_question, model, faqs, question_embeddings, threshold=0.65, suggestion_threshold=0.4, max_suggestions=3):
    """Find the best match and additional suggestions.

    Args:
        user_question: The question text to look up.
        model: Encoder exposing ``encode(text, convert_to_tensor=True)``.
        faqs: Sequence of dicts with ``question``, ``answer``, ``domain``
            and ``consultant_id`` keys.
        question_embeddings: Embeddings of the FAQ questions; expected to
            be aligned with ``faqs`` (row ``i`` belongs to ``faqs[i]``).
        threshold: Minimum cosine similarity for the top FAQ to count as
            the main answer.
        suggestion_threshold: Minimum cosine similarity for a suggestion.
        max_suggestions: Number of ranks inspected for suggestions after
            the main answer (or from the top when nothing matched).

    Returns:
        Dict with ``matched`` (bool), ``main_answer`` (dict or ``None``),
        ``suggestions`` (list of dicts, best first) and ``all_scores``
        (up to ten highest similarities, rounded to 4 decimals). With no
        FAQs the unmatched result is returned without touching the model.
    """
    result = {
        "matched": False,
        "main_answer": None,
        "suggestions": [],
        "all_scores": []
    }

    # Nothing to compare against: indexing the best score below would
    # raise, so report "no match" instead.
    if len(faqs) == 0:
        return result

    user_embedding = model.encode(user_question, convert_to_tensor=True)
    scores = cos_sim(user_embedding, question_embeddings)[0]

    # Rank FAQs by similarity, best first, and move to plain Python values.
    sorted_indices = torch.argsort(scores, descending=True)
    ranked_scores = scores[sorted_indices].tolist()
    ranked_faq_indices = sorted_indices.tolist()

    def describe(rank):
        """Build the result entry for the FAQ at the given rank."""
        faq = faqs[ranked_faq_indices[rank]]
        return {
            "score": round(ranked_scores[rank], 4),
            "question": faq["question"],
            "answer": faq["answer"],
            "domain": faq["domain"],
            "consultant_id": faq["consultant_id"]
        }

    # Check if best match exceeds threshold
    if ranked_scores[0] >= threshold:
        result["matched"] = True
        result["main_answer"] = describe(0)

    # Suggestions start after the main answer so it is not listed twice.
    first_rank = 1 if result["matched"] else 0
    last_rank = min(len(ranked_scores), max_suggestions + first_rank)
    result["suggestions"] = [
        describe(rank)
        for rank in range(first_rank, last_rank)
        if ranked_scores[rank] >= suggestion_threshold
    ]
    result["all_scores"] = [round(score, 4) for score in ranked_scores[:10]]

    return result

# ==============================
# Multi-Question Handler
# ==============================
class MultiQuestionHandler:
    """Run split -> resolve -> match -> verify for each question in an input.

    Holds the encoder, the FAQ list with its embeddings, the domain
    classifier, and a conversation context that is carried across calls.
    """

    def __init__(self, model, faqs, question_embeddings, domain_classifier):
        self.model = model
        self.faqs = faqs
        self.question_embeddings = question_embeddings
        self.domain_classifier = domain_classifier
        self.splitter = QuestionSplitter()
        self.followup_processor = FollowUpProcessor()
        self.context = ConversationContext()

    def process_input(self, user_input: str, expert_domain: str = "general",
                     similarity_threshold: float = 0.75,
                     suggestion_threshold: float = 0.4,
                     domain_threshold: float = 0.7) -> Dict:
        """Process user input (single or multiple questions).

        Each detected question is resolved against the conversation
        context, matched against the FAQs and, when a main answer exists
        and ``expert_domain`` is not ``"general"``, checked for domain
        relevance. Matched questions are recorded in the context so later
        questions (including later ones in the same input) can refer back.

        Returns:
            Dict with ``input``, ``detected_questions``, ``question_count``
            and ``results`` (one entry per detected question).
        """
        detected = self.splitter.split_questions(user_input)

        results = {
            "input": user_input,
            "detected_questions": detected,
            "question_count": len(detected),
            "results": []
        }

        for number, question in enumerate(detected, start=1):
            # Replace pronouns using what earlier interactions mentioned.
            resolved = self.followup_processor.resolve_references(question, self.context)

            match_result = match_question_with_suggestions(
                resolved, self.model, self.faqs, self.question_embeddings,
                threshold=similarity_threshold, suggestion_threshold=suggestion_threshold
            )

            entities = self.followup_processor.extract_entities(resolved)

            # Report the original wording when resolution only lower-cased it.
            if resolved == question.lower():
                reported_question = question
            else:
                reported_question = resolved

            entry = {
                "question_number": number,
                "original_question": question,
                "resolved_question": reported_question,
                "entities": entities,
                "match_result": match_result
            }

            matched = match_result["matched"]

            if matched and expert_domain != "general":
                best = match_result["main_answer"]
                question_ok, question_info = self.domain_classifier.is_relevant(
                    best["question"], expert_domain, threshold=domain_threshold
                )
                answer_ok, answer_info = self.domain_classifier.is_relevant(
                    best["answer"], expert_domain, threshold=domain_threshold
                )
                question_score = question_info["domain_score"]
                answer_score = answer_info["domain_score"]

                entry["domain_verification"] = {
                    "domain_relevant": question_ok or answer_ok,
                    "question_domain_score": question_score,
                    "answer_domain_score": answer_score,
                    "max_domain_score": max(question_score, answer_score)
                }

            results["results"].append(entry)

            # Only answered questions feed the follow-up context.
            if matched:
                best = match_result["main_answer"]
                self.context.add_interaction(
                    question=resolved,
                    answer=best["answer"],
                    entities=entities,
                    topics=[best["domain"]]
                )

        return results

    def get_conversation_summary(self) -> Dict:
        """Summarise the conversation: size, entities, distinct topics, last 3 turns."""
        ctx = self.context
        distinct_topics = list(set(ctx.topics))
        return {
            "total_interactions": len(ctx.history),
            "entities": ctx.entities,
            "topics": distinct_topics,
            "recent_context": ctx.get_recent_context(3)
        }

    def clear_context(self):
        """Reset the conversation context."""
        self.context.clear()

# ==============================
# Pretty Printer
# ==============================
def print_results(results: Dict):
    """Write a human-readable report of a processed input to stdout.

    Prints a header with the raw input and question count, then one
    section per entry in ``results["results"]``: the question, its
    resolved form when different, extracted entities, and either the best
    match (with optional domain verdict and other answers) or a no-match
    notice with similar questions.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"INPUT: {results['input']}")
    print(f"DETECTED {results['question_count']} QUESTION(S)")
    print(f"{rule}")

    for entry in results["results"]:
        print(f"\n📌 QUESTION {entry['question_number']}: {entry['original_question']}")

        resolved = entry["resolved_question"]
        if resolved != entry["original_question"]:
            print(f"🔄 RESOLVED TO: {resolved}")

        if entry["entities"]:
            print(f"🏷️  ENTITIES: {entry['entities']}")

        match_result = entry["match_result"]

        if not match_result["matched"]:
            print("\n❌ NO MATCH FOUND")
            if match_result["suggestions"]:
                print("💡 SIMILAR QUESTIONS:")
                for position, suggestion in enumerate(match_result["suggestions"], 1):
                    print(f"   {position}. (Score: {suggestion['score']:.3f}) {suggestion['question']}")
            print("-" * 50)
            continue

        best = match_result["main_answer"]
        print(f"\n✅ BEST MATCH (Score: {best['score']:.3f})")
        print(f"   Q: {best['question']}")
        print(f"   A: {best['answer']}")
        print(f"   Domain: {best['domain']} | Consultant: {best['consultant_id']}")

        # The domain verdict is only present when verification was run.
        if "domain_verification" in entry:
            verdict = entry["domain_verification"]
            if verdict["domain_relevant"]:
                status = "✅ RELEVANT"
            else:
                status = "⚠️  NOT RELEVANT"
            print(f"   🎯 DOMAIN: {status} (Score: {verdict['max_domain_score']:.3f})")

        if match_result["suggestions"]:
            print("\n💡 OTHER RELEVANT ANSWERS:")
            for position, suggestion in enumerate(match_result["suggestions"], 1):
                print(f"   {position}. (Score: {suggestion['score']:.3f}) {suggestion['question']}")
                # Show at most the first 100 characters of the answer.
                preview = suggestion['answer'][:100]
                ellipsis = '...' if len(suggestion['answer']) > 100 else ''
                print(f"      {preview}{ellipsis}")

        print("-" * 50)

# ==============================
# Enhanced Sample Data
# ==============================
# Built-in demo knowledge base.
# Layout: a list with one dict per consultant, each holding
#   "consultant_id" - identifier string of the consultant,
#   "domain"        - domain name shared by all of that consultant's FAQs,
#   "faqs"          - list of {"question": ..., "answer": ...} dicts.
# The inline comments inside each "faqs" list only group entries by subject.
sample_data = [
    {
        "consultant_id": "med_001",
        "domain": "medical",
        "faqs": [
            # COVID-19
            {"question": "What are the symptoms of COVID-19?",
             "answer": "The symptoms include fever, cough, shortness of breath, fatigue, body aches, and loss of taste or smell."},
            {"question": "How can I treat a mild fever at home?",
             "answer": "Rest, stay hydrated, take acetaminophen or ibuprofen, and monitor your temperature regularly."},
            {"question": "How can I treat COVID-19 at home?",
             "answer": "For mild COVID-19, rest, drink plenty of fluids, take fever reducers, isolate yourself, and monitor symptoms closely."},
            {"question": "What is the difference between COVID-19 and flu?",
             "answer": "COVID-19 may cause loss of taste/smell and has longer incubation period, while flu typically has more sudden onset."},
            {"question": "How long should I isolate with COVID-19?",
             "answer": "Isolate for at least 5 days from symptom onset and until fever-free for 24 hours without medication."},

            # Diabetes
            {"question": "What are common symptoms of diabetes?",
             "answer": "Increased thirst, frequent urination, fatigue, blurred vision, and unexplained weight loss."},
            {"question": "How can I manage type 2 diabetes?",
             "answer": "Maintain a healthy diet, exercise regularly, monitor blood sugar, and take prescribed medication."},
            {"question": "What are the complications of uncontrolled diabetes?",
             "answer": "Can include neuropathy, kidney disease, vision problems, and increased risk of heart disease."},
            {"question": "What is the difference between type 1 and type 2 diabetes?",
             "answer": "Type 1 is autoimmune and insulin-dependent, often diagnosed in children; type 2 usually develops in adults and can be managed with lifestyle changes and medication."},

            # Hypertension
            {"question": "What are the symptoms of high blood pressure?",
             "answer": "High blood pressure often has no symptoms, but severe cases may cause headaches, shortness of breath, or nosebleeds."},
            {"question": "How can I lower my blood pressure naturally?",
             "answer": "Reduce salt intake, exercise regularly, manage stress, avoid smoking, and maintain a healthy weight."},
            {"question": "What is normal blood pressure?",
             "answer": "Normal blood pressure is generally around 120/80 mmHg."},

            # Asthma
            {"question": "What triggers asthma attacks?",
             "answer": "Common triggers include allergens, exercise, cold air, respiratory infections, and air pollution."},
            {"question": "How can asthma be treated?",
             "answer": "Use prescribed inhalers, avoid triggers, and follow an asthma action plan provided by your doctor."},
            {"question": "What is the difference between rescue and maintenance inhalers?",
             "answer": "Rescue inhalers provide immediate relief during attacks; maintenance inhalers prevent symptoms over time."}
        ]
    },
    {
        "consultant_id": "legal_101",
        "domain": "legal",
        "faqs": [
            # Divorce
            {"question": "How can I file for divorce?",
             "answer": "You need to submit a divorce petition to your local court with required documents and fees."},
            {"question": "What documents do I need to file for divorce?",
             "answer": "Marriage certificate, financial statements, property agreements, and child custody arrangements if applicable."},
            {"question": "How long does the divorce process take?",
             "answer": "Divorce can take 6 months to 2 years depending on complexity, state requirements, and whether it's contested."},
            {"question": "What is the cost of filing for divorce?",
             "answer": "Filing fees range from $200-$500, plus attorney fees which can be $1,000-$10,000+ depending on complexity."},

            # Wills & Estates
            {"question": "What is the process of creating a will?",
             "answer": "Draft your will, sign it in front of witnesses, and store it safely with copies for executors."},
            {"question": "How can I contest a will?",
             "answer": "File a petition in probate court stating the reasons for contesting within the legal timeframe."},
            {"question": "What is a power of attorney?",
             "answer": "It is a legal document that gives someone the authority to act on your behalf in financial or healthcare matters."},
            {"question": "How can I revoke a power of attorney?",
             "answer": "You can revoke it in writing and notify the agent and relevant institutions."},

            # Real Estate
            {"question": "How do I buy a house?",
             "answer": "Hire a real estate agent, get mortgage pre-approval, view properties, and close the deal with a lawyer or notary."},
            {"question": "What is property tax?",
             "answer": "Property tax is paid by the owner of a property to the local government, usually annually."},
            {"question": "How can I rent out my property legally?",
             "answer": "Ensure compliance with local tenancy laws, draft a lease agreement, and register the rental if required."},
            {"question": "What are the steps to sell a property?",
             "answer": "Prepare the property, hire an agent, market the property, negotiate offers, and complete legal paperwork for transfer."}
        ]
    }
]


# Flatten the per-consultant groups into one list of FAQ records, copying
# the owning consultant's domain and id onto every record.
faqs = []
for consultant in sample_data:
    for faq in consultant["faqs"]:
        faqs.append(dict(
            question=faq["question"],
            answer=faq["answer"],
            domain=consultant["domain"],
            consultant_id=consultant["consultant_id"],
        ))

# ==============================
# Interactive Main Execution
# ==============================
def interactive_chat():
    """Interactive chat interface for the FAQ system.

    Loads the sentence encoder and the domain classifier, embeds the FAQ
    questions, then reads lines from stdin until the user quits. A line is
    either a command (``quit``/``exit``/``q``, ``clear``, ``context``,
    ``domain <name>``, ``help``) or text that is processed as one or more
    questions against the currently selected domain.

    The loop ends on a quit command, on Ctrl-C, or when stdin reaches end
    of file. Any other exception raised while handling a line is reported
    and the loop continues.
    """
    print("🚀 Loading models...")

    # Load models
    model = SentenceTransformer('all-MiniLM-L6-v2')
    classifier = DomainClassifier()

    # Encode FAQ questions; the row order follows `faqs`.
    print("📊 Encoding FAQ questions...")
    questions = [faq["question"] for faq in faqs]
    question_embeddings = model.encode(questions, convert_to_tensor=True)

    # Initialize handler
    handler = MultiQuestionHandler(model, faqs, question_embeddings, classifier)

    print("✅ System ready!")
    print("\n" + "="*80)
    print("🤖 INTERACTIVE FAQ SYSTEM")
    print("="*80)
    print("Available domains: medical, legal, general")
    print("Commands:")
    print("  - 'quit' or 'exit': Exit the system")
    print("  - 'clear': Clear conversation history")
    print("  - 'context': Show conversation context")
    print("  - 'domain medical' or 'domain legal': Switch domain")
    print("  - 'help': Show this help message")
    print("="*80)

    current_domain = "general"
    domain_prefix = 'domain '

    while True:
        try:
            print(f"\n[Domain: {current_domain.upper()}]")
            user_input = input("🧑 You: ").strip()

            if not user_input:
                continue

            # Commands are matched case-insensitively.
            command = user_input.lower()

            if command in ['quit', 'exit', 'q']:
                print("👋 Goodbye! Thank you for using the FAQ system.")
                break

            elif command == 'clear':
                handler.clear_context()
                print("🗑️  Conversation history cleared!")
                continue

            elif command == 'context':
                summary = handler.get_conversation_summary()
                print("\n📚 CONVERSATION SUMMARY:")
                print(f"   Total interactions: {summary['total_interactions']}")
                print(f"   Entities: {summary['entities']}")
                print(f"   Topics: {summary['topics']}")
                if summary['recent_context']:
                    print("   Recent questions:")
                    for i, ctx in enumerate(summary['recent_context'][-3:], 1):
                        print(f"     {i}. {ctx['question']}")
                continue

            elif command.startswith(domain_prefix):
                # Strip only the leading command word; str.replace would
                # also delete later occurrences of "domain " in the argument.
                new_domain = command[len(domain_prefix):].strip()
                if new_domain in ['medical', 'legal', 'general']:
                    current_domain = new_domain
                    print(f"🎯 Domain switched to: {current_domain.upper()}")
                else:
                    print("⚠️  Invalid domain. Use: medical, legal, or general")
                continue

            elif command == 'help':
                print("\n📖 HELP:")
                print("  - Ask any question (single or multiple)")
                print("  - Use follow-up questions like 'How can I treat it?'")
                print("  - Examples:")
                print("    • 'What are COVID symptoms? How to treat it?'")
                print("    • 'How to file for divorce?'")
                print("    • 'What about the cost?' (follow-up)")
                print("  - Commands: quit, clear, context, domain [name], help")
                continue

            # Process the question
            print("🤔 Processing your question(s)...")

            results = handler.process_input(
                user_input,
                expert_domain=current_domain,
                similarity_threshold=0.65,
                suggestion_threshold=0.4
            )

            # Print results
            print_results(results)

            # Show quick context if relevant
            if results["question_count"] > 1:
                print(f"\n💡 Processed {results['question_count']} questions in your input.")

            # Check if any follow-up references were resolved
            for result in results["results"]:
                if result["resolved_question"] != result["original_question"]:
                    print("🔄 Resolved follow-up reference based on conversation context.")
                    break

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is exhausted (Ctrl-D, closed pipe). It
            # must end the session: every further input() call would raise
            # again, and the generic handler below would loop forever.
            print("\n\n👋 Session interrupted. Goodbye!")
            break
        except Exception as e:
            # Top-level boundary of the chat loop: report and keep going.
            print(f"\n❌ Error occurred: {str(e)}")
            print("Please try again or type 'help' for assistance.")

# Start the interactive session only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    interactive_chat()