# In [None]:
# Natural Language to Cypher Converter - Backend Component
# Jupyter Notebook Version - Rule-based + Ollama Only (No Groq)

import json
import logging
import re
import requests
import os
import hashlib
from typing import Dict, Any, Optional, List
from datetime import datetime
from enum import Enum
from dataclasses import dataclass


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Domain(str, Enum):
    """Knowledge domains a query can be classified into.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value (for example ``Domain.MEDICAL == "medical"``).
    """
    MEDICAL = "medical"
    LEGAL = "legal"
    TECHNICAL = "technical"
    GENERAL = "general"
    SCIENCE = "science"
    BUSINESS = "business"
    SOCIAL = "social"
    FINANCE = "finance"

class LLMProvider(str, Enum):
    """Back ends that can produce a Cypher query.

    Members are also ``str`` instances and compare equal to their string value.
    """
    OLLAMA = "ollama"
    RULE_BASED = "rule_based"

@dataclass
class LLMConfig:
    """Settings that select and parameterize the query-generation back end."""
    # Which back end generates queries.
    provider: LLMProvider
    # Model identifier; sent as the "model" field of Ollama requests.
    model_name: str
    # NOTE(review): not read by any code visible in this file.
    api_key: Optional[str] = None
    # Ollama server URL; the client falls back to http://localhost:11434 when unset.
    base_url: Optional[str] = None
    # Sent to Ollama as the "num_predict" option.
    max_tokens: int = 512
    # Sent to Ollama as the "temperature" option.
    temperature: float = 0.1

@dataclass
class ConversionResult:
    """Outcome of converting one natural-language query to Cypher."""
    # The generated Cypher text.
    cypher_query: str
    # String value of the Domain used for the conversion.
    domain: str
    # Heuristic score; the converter in this file produces values from 0.5 to 1.0.
    confidence: float
    # Wall-clock duration of the conversion, in seconds.
    processing_time: float
    # Free-form details (provider, model, cache key, fallback/error info, ...).
    metadata: Dict[str, Any]

class DomainDetector:
    """Automatically detect the domain of a query by keyword overlap.

    Each domain owns a set of lowercase keywords; the domain whose keyword
    set shares the most words with the query wins.
    """

    def __init__(self):
        self.domain_keywords = {
            Domain.MEDICAL: {
                'disease', 'symptom', 'treatment', 'medication', 'diagnosis',
                'patient', 'hospital', 'doctor', 'therapy', 'surgery', 'clinic',
                'prescription', 'medicine', 'health', 'medical', 'cure', 'illness',
                'fever', 'pain', 'infection', 'virus', 'bacteria', 'cancer',
                'fabry', 'diabetes', 'condition', 'disorder'
            },
            Domain.LEGAL: {
                'law', 'court', 'case', 'judge', 'attorney', 'contract',
                'lawsuit', 'legal', 'jurisdiction', 'precedent', 'ruling',
                'litigation', 'lawyer', 'defendant', 'plaintiff', 'statute',
                'crime', 'trial', 'evidence', 'witness', 'verdict'
            },
            Domain.TECHNICAL: {
                'software', 'hardware', 'algorithm', 'system', 'network',
                'database', 'api', 'framework', 'technology', 'programming',
                'code', 'application', 'server', 'cloud', 'platform',
                'computer', 'development', 'bug', 'feature', 'deployment'
            },
            Domain.SCIENCE: {
                'research', 'experiment', 'theory', 'hypothesis', 'study',
                'scientist', 'publication', 'journal', 'laboratory', 'analysis',
                'data', 'discovery', 'innovation', 'academic', 'scientific',
                'physics', 'chemistry', 'biology', 'mathematics', 'astronomy'
            },
            Domain.BUSINESS: {
                'company', 'business', 'market', 'product', 'revenue',
                'profit', 'customer', 'client', 'sales', 'marketing',
                'strategy', 'competition', 'industry', 'corporate', 'financial',
                'startup', 'investment', 'entrepreneur', 'management', 'finance'
            },
            Domain.SOCIAL: {
                'social', 'network', 'user', 'friend', 'follow', 'post',
                'comment', 'like', 'share', 'community', 'group', 'profile',
                'facebook', 'twitter', 'instagram', 'linkedin', 'message'
            },
            Domain.FINANCE: {
                'stock', 'investment', 'portfolio', 'trading', 'market',
                'financial', 'money', 'capital', 'asset', 'fund', 'investor',
                'bank', 'loan', 'credit', 'debt', 'currency', 'exchange'
            }
        }

    def detect_domain(self, query: str) -> "Domain":
        """Return the domain whose keywords overlap most with ``query``.

        The query is lowercased and split into runs of word characters, so
        punctuation attached to a word ("disease?", "(court)") does not
        prevent a keyword match.

        Args:
            query: Natural-language query text.

        Returns:
            The best-scoring domain. Ties go to the domain listed first in
            ``domain_keywords``; ``Domain.GENERAL`` is returned when no
            keyword matches at all.
        """
        query_words = set(re.findall(r"\w+", query.lower()))

        best_domain = None
        best_score = 0
        for domain, keywords in self.domain_keywords.items():
            score = len(query_words & keywords)
            # Strict '>' keeps the earliest domain on ties.
            if score > best_score:
                best_domain, best_score = domain, score

        return best_domain if best_score > 0 else Domain.GENERAL

class QueryCache:
    """Simple in-memory FIFO cache for query results.

    Entries are evicted in insertion order once ``max_size`` is reached.
    Hit and miss counters are updated on every ``get``.
    """

    def __init__(self, max_size: int = 1000):
        """Create an empty cache holding at most ``max_size`` entries.

        A ``max_size`` of zero or less disables storage entirely.
        """
        self.cache = {}
        self.max_size = max_size
        self.hits = 0
        self.misses = 0

    def get(self, query_hash: str) -> Optional["ConversionResult"]:
        """Return the cached result for ``query_hash``, or ``None`` on a miss."""
        try:
            result = self.cache[query_hash]
        except KeyError:
            self.misses += 1
            return None
        self.hits += 1
        return result

    def set(self, query_hash: str, result: "ConversionResult"):
        """Store ``result`` under ``query_hash``, evicting the oldest entries if full."""
        if self.max_size <= 0:
            # Nothing can be stored; avoids evicting from an empty dict.
            return

        if query_hash in self.cache:
            # Overwriting does not grow the cache, so no eviction is needed.
            self.cache[query_hash] = result
            return

        # dicts preserve insertion order, so the first key is the oldest.
        while len(self.cache) >= self.max_size:
            del self.cache[next(iter(self.cache))]

        self.cache[query_hash] = result

    def get_stats(self):
        """Return hit/miss counters, the hit rate and the current entry count."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": hit_rate,
            "cache_size": len(self.cache)
        }

class RuleBasedConverter:
    """Rule-based converter used as the default path and as the LLM fallback.

    Produces a simple read-only ``MATCH ... RETURN ... LIMIT`` query from
    keyword rules. Queries that look like create/delete requests currently
    fall back to the generic ``MATCH (n) RETURN n`` form; no write query is
    ever generated.
    """

    # Operation verbs, matched against whole words of the query.
    _READ_WORDS = frozenset({"find", "show", "get", "list", "search", "who", "what", "where"})
    _CREATE_WORDS = frozenset({"create", "add", "insert"})
    _DELETE_WORDS = frozenset({"delete", "remove"})
    # Words after which the next word is taken as a name filter value.
    _NAME_MARKERS = frozenset({"name", "called", "named"})
    # Characters stripped from both ends of each query word.
    _EDGE_PUNCTUATION = "\"'.,;:!?()[]{}"

    def __init__(self):
        self.domain_mappings = {
            Domain.MEDICAL: {
                "patient": "Patient",
                "doctor": "Doctor",
                "disease": "Disease",
                "symptom": "Symptom",
                "treatment": "Treatment",
                "medication": "Medication",
                "condition": "Disease",
                "disorder": "Disease",
                "fabry": "Disease"
            },
            Domain.BUSINESS: {
                "company": "Company",
                "person": "Person",
                "product": "Product",
                "customer": "Customer",
                "employee": "Employee"
            },
            Domain.SOCIAL: {
                "user": "User",
                "post": "Post",
                "friend": "User",
                "group": "Group",
                "comment": "Comment"
            },
            Domain.TECHNICAL: {
                "system": "System",
                "component": "Component",
                "technology": "Technology",
                "software": "Software",
                "hardware": "Hardware"
            },
            Domain.GENERAL: {
                "person": "Person",
                "organization": "Organization",
                "entity": "Entity",
                "concept": "Concept",
                "event": "Event"
            }
        }

    def convert(self, query: str, domain: "Domain") -> str:
        """Convert ``query`` to a Cypher string using keyword rules.

        Args:
            query: Natural-language query text.
            domain: Domain whose keyword-to-label mapping is used.

        Returns:
            A ``MATCH`` query. One recognized entity type yields a single-node
            pattern, two yield an undirected relationship pattern, and
            anything else yields the generic ``MATCH (n)``. A name filter is
            added when the query contains "where"/"with" followed somewhere by
            "name"/"called"/"named"; the first number in the query becomes the
            ``LIMIT`` (default 10).
        """
        words = self._tokenize(query)
        word_set = set(words)
        operation = self._detect_operation(word_set)

        # Extract entities based on domain
        entities = self.extract_entities(query, domain)

        # filter_var is the node variable the pattern actually binds, so the
        # optional WHERE clause never references an undefined variable.
        if operation == "MATCH" and len(entities) == 1:
            match_clause = f"MATCH (n:{entities[0]})"
            return_clause = "RETURN n"
            filter_var = "n"
        elif operation == "MATCH" and len(entities) == 2:
            match_clause = f"MATCH (a:{entities[0]})-[r]-(b:{entities[1]})"
            return_clause = "RETURN a, r, b"
            filter_var = "a"
        else:
            match_clause = "MATCH (n)"
            return_clause = "RETURN n"
            filter_var = "n"

        clauses = [match_clause]
        if "where" in word_set or "with" in word_set:
            name_value = self._find_name_value(words)
            if name_value is not None:
                clauses.append(f"WHERE {filter_var}.name CONTAINS '{name_value}'")
        clauses.append(return_clause)

        # int() normalizes forms such as "05" to a plain integer literal.
        numbers = re.findall(r'\d+', query)
        limit = int(numbers[0]) if numbers else 10
        clauses.append(f"LIMIT {limit}")

        return " ".join(clauses)

    def extract_entities(self, query: str, domain: "Domain") -> List[str]:
        """Return the node labels whose keywords occur in ``query``.

        Keywords are matched as substrings of the lowercased query, which
        also catches plural forms ("patients" contains "patient"). Labels are
        de-duplicated and returned in the order of the domain mapping, so the
        result is the same on every run. Domains without a mapping use the
        GENERAL mapping.
        """
        query_lower = query.lower()

        domain_map = self.domain_mappings.get(domain)
        if domain_map is None:
            domain_map = self.domain_mappings[Domain.GENERAL]

        matched = (
            entity_type
            for keyword, entity_type in domain_map.items()
            if keyword in query_lower
        )
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(matched))

    def _tokenize(self, query: str) -> List[str]:
        """Split the lowercased query on whitespace and trim edge punctuation."""
        stripped = (raw.strip(self._EDGE_PUNCTUATION) for raw in query.lower().split())
        return [word for word in stripped if word]

    def _detect_operation(self, word_set) -> str:
        """Classify the query as MATCH, CREATE or DELETE from whole-word verbs."""
        if word_set & self._READ_WORDS:
            return "MATCH"
        if word_set & self._CREATE_WORDS:
            return "CREATE"
        if word_set & self._DELETE_WORDS:
            return "DELETE"
        return "MATCH"

    def _find_name_value(self, words: List[str]) -> Optional[str]:
        """Return the escaped word following the first name marker, if any."""
        for marker, following in zip(words, words[1:]):
            if marker in self._NAME_MARKERS:
                # Escape backslashes and single quotes so the value cannot
                # terminate the quoted Cypher string literal.
                return following.replace("\\", "\\\\").replace("'", "\\'")
        return None

class OllamaClient:
    """Client for a local Ollama LLM server.

    Talks to the server's HTTP API: ``/api/tags`` for the connectivity check
    and ``/api/generate`` for completions.
    """

    def __init__(self, config: "LLMConfig"):
        """Store the configuration and verify the server is reachable.

        Args:
            config: Supplies ``base_url``, ``model_name``, ``temperature`` and
                ``max_tokens``.

        Raises:
            ConnectionError: If the connectivity check fails.
        """
        self.config = config
        # Drop trailing slashes so joined paths never contain "//api/...".
        self.base_url = (config.base_url or "http://localhost:11434").rstrip("/")

        # Test connection
        self.test_connection()

    def test_connection(self):
        """Check that the Ollama server answers on ``/api/tags``.

        Raises:
            ConnectionError: If the request fails or the server responds with
                a status other than 200. The underlying ``requests`` error,
                if any, is chained as the cause.
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
        except requests.exceptions.RequestException as e:
            raise ConnectionError(
                f"Cannot connect to Ollama. Make sure it's running: {e}"
            ) from e

        if response.status_code != 200:
            raise ConnectionError(
                f"Ollama server responded with status {response.status_code}"
            )
        print(f"✅ Ollama is running at {self.base_url}")

    def generate(self, system_prompt: str, user_query: str) -> str:
        """Generate a completion for the given system prompt and user query.

        Args:
            system_prompt: Instructions placed before the user text.
            user_query: The user's request.

        Returns:
            The stripped ``"response"`` field of the server's JSON reply, or
            an empty string when that field is missing or null.

        Raises:
            Exception: Any error from the HTTP request or JSON decoding is
                logged and re-raised unchanged.
        """
        try:
            url = f"{self.base_url}/api/generate"

            prompt = f"System: {system_prompt}\n\nUser: {user_query}\n\nAssistant:"

            payload = {
                "model": self.config.model_name,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": self.config.temperature,
                    "num_predict": self.config.max_tokens
                }
            }

            response = requests.post(url, json=payload, timeout=30)
            response.raise_for_status()

            result = response.json()
            # "or" guards against an explicit null in the reply.
            return (result.get("response") or "").strip()

        except Exception as e:
            logger.error("Ollama generation error: %s", e)
            raise

class NaturalLanguageToCypherConverter:
    """Main converter class - Rule-based + Ollama only.

    Detects the query domain, generates Cypher with the configured LLM (or
    the rule-based converter when no LLM client is available), scores the
    result and caches it.
    """

    # A fenced code block with an optional language tag on the opening line.
    _FENCED_BLOCK = re.compile(r"```[\w+-]*[ \t]*\n(.*?)```", re.DOTALL)
    # Clauses a generated query may start with, matched as whole words.
    _VALID_START = re.compile(
        r"(?:OPTIONAL\s+MATCH|MATCH|CREATE|MERGE|DELETE|SET|WITH|UNWIND|RETURN|CALL)\b",
        re.IGNORECASE,
    )

    def __init__(self, llm_config: Optional["LLMConfig"] = None):
        """Build the converter.

        Args:
            llm_config: Back-end configuration; defaults to the rule-based
                provider when omitted.
        """
        # Default to rule-based if no config provided
        if llm_config is None:
            llm_config = LLMConfig(
                provider=LLMProvider.RULE_BASED,
                model_name="pattern-matching"
            )

        self.llm_config = llm_config
        self.llm_client = None
        self.domain_detector = DomainDetector()
        self.rule_based_converter = RuleBasedConverter()
        self.query_cache = QueryCache()

        # Initialize LLM client
        self.init_llm_client()

        # Setup domain-specific prompts
        self.setup_domain_prompts()

    def init_llm_client(self):
        """Initialize the LLM client for the configured provider.

        Never raises: if the client cannot be created the error is logged and
        ``self.llm_client`` stays ``None``, which makes every conversion use
        the rule-based converter.
        """
        try:
            if self.llm_config.provider == LLMProvider.OLLAMA:
                self.llm_client = OllamaClient(self.llm_config)
            else:
                self.llm_client = None  # Use rule-based

            # .value keeps the log text stable across Python versions.
            logger.info("Initialized %s client", self.llm_config.provider.value)
        except Exception as e:
            logger.error("Failed to initialize LLM client: %s", e)
            logger.info("Falling back to rule-based converter")
            self.llm_client = None

    def setup_domain_prompts(self):
        """Setup domain-specific system prompts.

        Domains without an entry here use the GENERAL prompt at generation
        time.
        """
        self.domain_prompts = {
            Domain.MEDICAL: """
            You are an expert in medical knowledge graphs and Cypher queries. 
            Convert natural language medical queries into Cypher for Neo4j.
            
            Common Medical Schema:
            - Diseases: (:Disease {name, icd_code, severity})
            - Symptoms: (:Symptom {name, type, severity})
            - Treatments: (:Treatment {name, type, duration})
            - Medications: (:Medication {name, dosage, drug_class})
            - Patients: (:Patient {name, age, gender})
            - Doctors: (:Doctor {name, specialty})
            
            Relationships:
            - (Disease)-[:HAS_SYMPTOM]->(Symptom)
            - (Disease)-[:TREATED_BY]->(Treatment)
            - (Patient)-[:HAS_CONDITION]->(Disease)
            - (Doctor)-[:TREATS]->(Patient)
            - (Treatment)-[:INCLUDES]->(Medication)
            
            Return only a valid Cypher query without explanation.
            """,
            
            Domain.BUSINESS: """
            You are an expert in business knowledge graphs and Cypher queries.
            Convert natural language business queries into Cypher for Neo4j.
            
            Common Business Schema:
            - Companies: (:Company {name, industry, revenue})
            - People: (:Person {name, position, department})
            - Products: (:Product {name, category, price})
            - Markets: (:Market {name, region, size})
            
            Relationships:
            - (Person)-[:WORKS_FOR]->(Company)
            - (Company)-[:PRODUCES]->(Product)
            - (Company)-[:OPERATES_IN]->(Market)
            
            Return only a valid Cypher query without explanation.
            """,
            
            Domain.TECHNICAL: """
            You are an expert in technical knowledge graphs and Cypher queries.
            Convert natural language technical queries into Cypher for Neo4j.
            
            Common Technical Schema:
            - Technologies: (:Technology {name, version, category})
            - Components: (:Component {name, function, type})
            - Systems: (:System {name, architecture, purpose})
            - Companies: (:Company {name, industry})
            
            Relationships:
            - (Technology)-[:USES]->(Component)
            - (System)-[:IMPLEMENTS]->(Technology)
            - (Company)-[:DEVELOPS]->(Technology)
            
            Return only a valid Cypher query without explanation.
            """,
            
            Domain.GENERAL: """
            You are an expert Cypher query generator for general knowledge graphs.
            Convert natural language queries into Cypher for Neo4j.
            
            General Schema:
            - Entities: (:Entity {name, type})
            - People: (:Person {name, role})
            - Organizations: (:Organization {name, type})
            - Events: (:Event {name, date})
            
            Relationships:
            - (Entity)-[:RELATED_TO]->(Entity)
            - (Person)-[:WORKS_FOR]->(Organization)
            - (Person)-[:PARTICIPATED_IN]->(Event)
            
            Return only a valid Cypher query without explanation.
            """
        }

    def convert(self, query: str, domain: "Domain" = None, context: Dict = None) -> "ConversionResult":
        """Convert a natural-language query to Cypher.

        Args:
            query: Natural-language query text.
            domain: Domain to use, as a ``Domain`` member or its string
                value; auto-detected from the query when ``None``.
            context: Optional extra information passed to the LLM and
                included in the cache key.

        Returns:
            A ``ConversionResult``. ``metadata["provider"]`` names the back
            end that actually produced the query, and
            ``metadata["fallback_used"]`` is true when that differs from the
            configured one. On any unexpected error a rule-based result with
            confidence 0.5 and the error text in the metadata is returned
            instead of raising.
        """
        start_time = datetime.now()

        try:
            # Auto-detect domain if not provided
            if domain is None:
                domain = self.domain_detector.detect_domain(query)
            elif not isinstance(domain, Domain):
                # Accept plain string values such as "medical".
                domain = Domain(domain)

            # Check cache first
            cache_key = self.generate_cache_key(query, domain, context)
            cached_result = self.query_cache.get(cache_key)

            if cached_result is not None:
                logger.info("Cache hit for query: %s...", query[:50])
                return cached_result

            # Generate Cypher query
            cypher_query, provider = self._generate_with_provider(query, domain, context)
            fallback_used = provider != self.llm_config.provider

            # Calculate confidence
            confidence = self.calculate_confidence(query, cypher_query)

            processing_time = (datetime.now() - start_time).total_seconds()

            result = ConversionResult(
                cypher_query=cypher_query,
                domain=domain.value,
                confidence=confidence,
                processing_time=processing_time,
                metadata={
                    "query_length": len(query),
                    "generated_at": datetime.now().isoformat(),
                    "provider": provider.value,
                    "model": "pattern-matching" if fallback_used else self.llm_config.model_name,
                    "cache_key": cache_key,
                    "fallback_used": fallback_used
                }
            )

            # Do not cache a rule-based answer produced only because a live
            # LLM client failed; the next call should try the LLM again.
            llm_failed = self.llm_client is not None and provider == LLMProvider.RULE_BASED
            if not llm_failed:
                self.query_cache.set(cache_key, result)

            return result

        except Exception as e:
            logger.error("Error in conversion: %s", e)
            # Fallback to rule-based; an unusable domain argument becomes GENERAL
            fallback_domain = domain if isinstance(domain, Domain) else Domain.GENERAL
            cypher_query = self.rule_based_converter.convert(query, fallback_domain)
            processing_time = (datetime.now() - start_time).total_seconds()

            return ConversionResult(
                cypher_query=cypher_query,
                domain=fallback_domain.value,
                confidence=0.5,
                processing_time=processing_time,
                metadata={
                    "fallback_used": True,
                    "error": str(e),
                    "provider": "rule_based"
                }
            )

    def generate_cypher(self, query: str, domain: "Domain", context: Dict = None) -> str:
        """Generate a Cypher query with the LLM, or rule-based when it is unavailable or fails."""
        cypher_query, _ = self._generate_with_provider(query, domain, context)
        return cypher_query

    def _generate_with_provider(self, query: str, domain: "Domain", context: Dict = None):
        """Return ``(cypher_query, provider)`` naming the back end that produced it."""
        if self.llm_client:
            try:
                return self.generate_cypher_with_llm(query, domain, context), self.llm_config.provider
            except Exception as e:
                logger.error("Error generating Cypher: %s", e)
        return self.rule_based_converter.convert(query, domain), LLMProvider.RULE_BASED

    def generate_cypher_with_llm(self, query: str, domain: "Domain", context: Dict = None) -> str:
        """Generate Cypher using the LLM client.

        Raises:
            ValueError: If the model output is not recognizably a Cypher query.
            Exception: Whatever the LLM client raises.
        """
        system_prompt = self.domain_prompts.get(domain, self.domain_prompts[Domain.GENERAL])

        enhanced_query = query
        if context:
            # default=str keeps non-JSON values (dates, sets, ...) from aborting the call.
            context_str = json.dumps(context, indent=2, default=str)
            enhanced_query = f"Context: {context_str}\n\nQuery: {query}"

        cypher_query = self.llm_client.generate(system_prompt, enhanced_query)
        return self.clean_cypher_query(cypher_query)

    def clean_cypher_query(self, cypher_query: str) -> str:
        """Extract and normalize the Cypher text from raw model output.

        If the output contains a fenced code block, only the content of the
        first block is kept, so explanations before or after it are dropped.
        Otherwise stray fence markers are removed. Whitespace is collapsed to
        single spaces.

        Raises:
            ValueError: If the cleaned text does not start with a Cypher
                clause keyword (matched as a whole word, case-insensitively).
        """
        fenced = self._FENCED_BLOCK.search(cypher_query)
        if fenced:
            cypher_query = fenced.group(1)
        else:
            cypher_query = cypher_query.replace('```cypher', '').replace('```', '')
        cypher_query = ' '.join(cypher_query.split())

        if not self._VALID_START.match(cypher_query):
            raise ValueError("Generated query doesn't start with a valid Cypher command")

        return cypher_query

    def calculate_confidence(self, nl_query: str, cypher_query: str) -> float:
        """Return a heuristic confidence score between 0.5 and 1.0.

        Starts at 0.5, adds up to 0.3 for the share of non-stop-words from the
        natural-language query that also appear as tokens of the Cypher text,
        0.1 when the query contains both MATCH and RETURN, and 0.1 when it
        contains a directed relationship pattern. Returns 0.5 if scoring fails.
        """
        try:
            confidence = 0.5

            nl_words = set(nl_query.lower().split())
            cypher_words = set(cypher_query.lower().split())

            common_words = {'the', 'and', 'or', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
            nl_words -= common_words
            cypher_words -= common_words

            if nl_words:
                overlap = len(nl_words & cypher_words)
                confidence += min(overlap / len(nl_words), 0.3)

            if 'MATCH' in cypher_query and 'RETURN' in cypher_query:
                confidence += 0.1

            if '-[' in cypher_query and ']->' in cypher_query:
                confidence += 0.1

            return min(confidence, 1.0)

        except Exception:
            # Scoring is best-effort; never let it break a conversion.
            return 0.5

    def generate_cache_key(self, query: str, domain: "Domain", context: Dict = None) -> str:
        """Return a hex digest identifying (query, domain, context).

        The query is lowercased and stripped first, so case and surrounding
        whitespace do not affect the key.
        """
        key_data = {
            "query": query.lower().strip(),
            "domain": domain.value,
            "context": context or {}
        }
        # default=str lets contexts with non-JSON values still produce a key.
        key_string = json.dumps(key_data, sort_keys=True, default=str)
        # MD5 is used only to derive a cache key, not for security.
        return hashlib.md5(key_string.encode()).hexdigest()

    def get_cache_stats(self):
        """Get cache statistics"""
        return self.query_cache.get_stats()

# Convenience functions for easy usage
def create_converter(
    provider: str = "rule_based",
    model: Optional[str] = None,
    base_url: Optional[str] = None,
) -> NaturalLanguageToCypherConverter:
    """Create a converter for the given provider (rule_based or ollama only).

    Args:
        provider: ``"ollama"`` or ``"rule_based"``; case and surrounding
            whitespace are ignored. Any other value selects the rule-based
            converter and logs a warning.
        model: Ollama model name; defaults to ``"llama3.1:8b"``. Ignored for
            the rule-based provider.
        base_url: Ollama server URL; defaults to ``"http://localhost:11434"``.
            Ignored for the rule-based provider.

    Returns:
        A ready-to-use converter.
    """
    provider_name = (provider or "rule_based").strip().lower()

    if provider_name == "ollama":
        config = LLMConfig(
            provider=LLMProvider.OLLAMA,
            model_name=model or "llama3.1:8b",
            base_url=base_url or "http://localhost:11434"
        )
    else:  # rule_based
        if provider_name != "rule_based":
            logger.warning(
                "Unknown provider %r; using the rule-based converter", provider
            )
        config = LLMConfig(
            provider=LLMProvider.RULE_BASED,
            model_name="pattern-matching"
        )

    return NaturalLanguageToCypherConverter(config)

def convert_query(query: str, domain: str = None, provider: str = "rule_based", **kwargs) -> ConversionResult:
    """Convert a single query with a freshly created converter.

    Args:
        query: Natural-language query text.
        domain: Domain name such as ``"medical"``; case and surrounding
            whitespace are ignored. Auto-detected when omitted or empty.
        provider: Passed to ``create_converter``.
        **kwargs: Extra keyword arguments for ``create_converter``.

    Returns:
        The conversion result.

    Raises:
        ValueError: If ``domain`` is not a known domain name.
    """
    converter = create_converter(provider, **kwargs)
    # Normalize so "Medical" or " medical " resolve like "medical".
    domain_enum = Domain(domain.strip().lower()) if domain else None
    return converter.convert(query, domain_enum)

# Interactive Interface
def simple_query_interface():
    """Run an interactive console loop that converts queries to Cypher.

    Asks which conversion method to use, builds the converter, then reads
    queries until the user types quit/exit/q, presses Ctrl+C, or input ends.
    """

    print(" Natural Language to Cypher Converter")
    print(" Available Methods: Rule-based + Ollama")
    print("=" * 50)

    # Choose method
    print("\nChoose conversion method:")
    print("1. Rule-based (Free, instant)")
    print("2. Ollama (Free, better results - requires setup)")

    method_choice = input("Enter choice (1-2): ").strip()

    # Setup converter
    if method_choice == "2":
        print("\n🦙 Setting up Ollama...")
        model = input("Enter Ollama model (default=llama3.1:8b): ").strip() or "llama3.1:8b"

        setup_error = None
        try:
            converter = create_converter("ollama", model=model)
        except Exception as e:
            setup_error = e
        else:
            # The converter logs connection failures and continues without a
            # client instead of raising, so a missing client must be checked
            # explicitly.
            if converter.llm_client is None:
                setup_error = "Ollama client could not be initialized (see log output)"

        if setup_error is None:
            print(f"Using Ollama with model: {model}")
        else:
            print(f" Ollama setup failed: {setup_error}")
            print(" To fix this:")
            print("   1. Install Ollama: https://ollama.ai/")
            print("   2. Pull model: ollama pull llama3.1:8b")
            print("   3. Start server: ollama serve")
            print("\n Falling back to rule-based...")
            converter = create_converter("rule_based")
    else:
        converter = create_converter("rule_based")
        print(" Using Rule-based converter")

    print("\n" + "="*50)
    print(" Try queries like:")
    print("   - What are the treatments for fabry disease?")
    print("   - Find all patients with diabetes")
    print("   - Show me companies in tech industry")
    print("="*50)

    # Main query loop
    while True:
        try:
            user_query = input("\nEnter a natural language query: ").strip()

            if not user_query:
                print("Please enter a query!")
                continue

            if user_query.lower() in ['quit', 'exit', 'q']:
                print(" Goodbye!")
                break

            # Convert query
            result = converter.convert(user_query)

            # Display result
            print("\n```")
            print("Generated Cypher:")
            print(result.cypher_query)
            print("```")

            print(f"\n Confidence: {result.confidence:.2f}")
            print(f" Domain: {result.domain}")
            print(f" Method: {result.metadata.get('provider', 'unknown')}")
            print(f" Time: {result.processing_time:.3f}s")

            print("\n" + "-"*50)

        except (KeyboardInterrupt, EOFError):
            # EOFError means input is exhausted (closed or piped stdin);
            # retrying would loop forever, so leave like Ctrl+C does.
            print("\n Goodbye!")
            break
        except Exception as e:
            print(f" Error: {e}")
            continue

# Test the converter
if __name__ == "__main__":
    # Demo run: a few rule-based conversions, then the interactive prompt.
    print(" Natural Language to Cypher Converter - Rule-based + Ollama")
    print(60 * "=")

    # The rule-based converter is created without an LLM client.
    print("\n🧪 Testing Rule-based converter:")
    converter = create_converter("rule_based")

    # (query text, explicit domain) pairs
    test_queries = [
        ("What are the treatments for fabry disease?", Domain.MEDICAL),
        ("Find all patients with diabetes", Domain.MEDICAL),
        ("Show me companies in the tech industry", Domain.BUSINESS),
    ]

    for query, domain in test_queries:
        result = converter.convert(query, domain)
        print(f"\n🔍 Query: {query}")
        print(f"⚡ Cypher: {result.cypher_query}")
        print(f" Confidence: {result.confidence:.2f}")

    print(f"\n Cache Stats: {converter.get_cache_stats()}")

    # Hand over to the interactive loop
    print("\n" + 60 * "=")
    print(" Starting Interactive Interface...")
    print(60 * "=")

    simple_query_interface()