In [5]:
# Improved Natural Language to Cypher Converter
# Enhanced rule-based converter with better pattern matching

import json
import logging
import re
import requests
import os
import hashlib
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime
from enum import Enum
from dataclasses import dataclass

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Domain(str, Enum):
    MEDICAL = "medical"
    LEGAL = "legal"
    TECHNICAL = "technical"
    GENERAL = "general"
    SCIENCE = "science"
    BUSINESS = "business"
    SOCIAL = "social"
    FINANCE = "finance"

class LLMProvider(str, Enum):
    OLLAMA = "ollama"
    RULE_BASED = "rule_based"

@dataclass
class LLMConfig:
    provider: LLMProvider
    model_name: str
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    max_tokens: int = 512
    temperature: float = 0.1

@dataclass
class ConversionResult:
    cypher_query: str
    domain: str
    confidence: float
    processing_time: float
    metadata: Dict[str, Any]

class ImprovedRuleBasedConverter:
    """Enhanced rule-based converter with better pattern matching"""
    
    def __init__(self):
        self.setup_patterns()
        self.setup_domain_schemas()
    
    def setup_patterns(self):
        """Setup improved query patterns"""
        self.query_patterns = {
            # Medical patterns
            'medical_treatments': {
                'pattern': r'(?:treatments?|therapies?|cures?)\s+(?:for|of)\s+(.+?)(?:\?|$)',
                'template': "MATCH (d:Disease {{name: '{}'}})-[:TREATED_BY]->(t:Treatment) RETURN d, t",
                'confidence': 0.85
            },
            'medical_symptoms': {
                'pattern': r'(?:symptoms?|signs?)\s+(?:of|for)\s+(.+?)(?:\?|$)',
                'template': "MATCH (d:Disease {{name: '{}'}})-[:HAS_SYMPTOM]->(s:Symptom) RETURN d, s",
                'confidence': 0.85
            },
            'medical_patients': {
                'pattern': r'(?:patients?|people)\s+(?:with|having)\s+(.+?)(?:\?|$)',
                'template': "MATCH (p:Patient)-[:HAS_CONDITION]->(d:Disease {{name: '{}'}}) RETURN p, d",
                'confidence': 0.85
            },
            'medical_find_disease': {
                'pattern': r'(?:find|show|get|list)\s+(?:all\s+)?(?:diseases?|conditions?)\s+(.+?)(?:\?|$)',
                'template': "MATCH (d:Disease) WHERE d.name CONTAINS '{}' RETURN d",
                'confidence': 0.80
            },
            
            # Business patterns
            'business_companies': {
                'pattern': r'(?:companies?|businesses?)\s+(?:in|from)\s+(?:the\s+)?(.+?)(?:\s+industry|\s+sector|$|\?)',
                'template': "MATCH (c:Company {{industry: '{}'}}) RETURN c",
                'confidence': 0.85
            },
            'business_employees': {
                'pattern': r'(?:employees?|people|workers?)\s+(?:at|in|from)\s+(.+?)(?:\?|$)',
                'template': "MATCH (p:Person)-[:WORKS_FOR]->(c:Company {{name: '{}'}}) RETURN p, c",
                'confidence': 0.85
            },
            'business_products': {
                'pattern': r'(?:products?|items?)\s+(?:by|from|made by)\s+(.+?)(?:\?|$)',
                'template': "MATCH (c:Company {{name: '{}'}})-[:PRODUCES]->(p:Product) RETURN c, p",
                'confidence': 0.85
            },
            
            # General patterns
            'general_find': {
                'pattern': r'(?:find|show|get|list)\s+(?:all\s+)?(.+?)(?:\?|$)',
                'template': "MATCH (n:{}) RETURN n",
                'confidence': 0.70
            },
            'general_relationship': {
                'pattern': r'(.+?)\s+(?:related to|connected to|associated with)\s+(.+?)(?:\?|$)',
                'template': "MATCH (a)-[r]-(b) WHERE a.name CONTAINS '{}' AND b.name CONTAINS '{}' RETURN a, r, b",
                'confidence': 0.75
            }
        }
    
    def setup_domain_schemas(self):
        """Setup domain-specific node mappings"""
        self.domain_schemas = {
            Domain.MEDICAL: {
                'nodes': {
                    'disease': 'Disease',
                    'diseases': 'Disease',
                    'condition': 'Disease',
                    'conditions': 'Disease',
                    'disorder': 'Disease',
                    'disorders': 'Disease',
                    'illness': 'Disease',
                    'patient': 'Patient',
                    'patients': 'Patient',
                    'doctor': 'Doctor',
                    'doctors': 'Doctor',
                    'physician': 'Doctor',
                    'treatment': 'Treatment',
                    'treatments': 'Treatment',
                    'therapy': 'Treatment',
                    'therapies': 'Treatment',
                    'medication': 'Medication',
                    'medications': 'Medication',
                    'drug': 'Medication',
                    'drugs': 'Medication',
                    'symptom': 'Symptom',
                    'symptoms': 'Symptom'
                },
                'relationships': {
                    'has_symptom': 'HAS_SYMPTOM',
                    'treated_by': 'TREATED_BY',
                    'has_condition': 'HAS_CONDITION',
                    'treats': 'TREATS',
                    'prescribed': 'PRESCRIBED'
                }
            },
            Domain.BUSINESS: {
                'nodes': {
                    'company': 'Company',
                    'companies': 'Company',
                    'business': 'Company',
                    'businesses': 'Company',
                    'organization': 'Company',
                    'person': 'Person',
                    'people': 'Person',
                    'employee': 'Person',
                    'employees': 'Person',
                    'worker': 'Person',
                    'workers': 'Person',
                    'product': 'Product',
                    'products': 'Product',
                    'service': 'Product',
                    'services': 'Product'
                },
                'relationships': {
                    'works_for': 'WORKS_FOR',
                    'produces': 'PRODUCES',
                    'owns': 'OWNS',
                    'manages': 'MANAGES'
                }
            }
        }
    
    def convert(self, query: str, domain: Domain) -> str:
        """Convert using improved rule-based approach"""
        query_lower = query.lower().strip()
        
        # First try domain-specific patterns
        cypher_query = self.try_domain_patterns(query_lower, domain)
        if cypher_query:
            return cypher_query
        
        # Try general patterns
        cypher_query = self.try_general_patterns(query_lower, domain)
        if cypher_query:
            return cypher_query
        
        # Fallback to basic pattern matching
        return self.basic_pattern_matching(query_lower, domain)
    
    def try_domain_patterns(self, query: str, domain: Domain) -> Optional[str]:
        """Try domain-specific patterns"""
        domain_prefix = domain.value
        
        for pattern_name, pattern_info in self.query_patterns.items():
            if pattern_name.startswith(domain_prefix):
                match = re.search(pattern_info['pattern'], query, re.IGNORECASE)
                if match:
                    entity = match.group(1).strip()
                    entity = self.clean_entity_name(entity)
                    
                    # Format the template
                    cypher_query = pattern_info['template'].format(entity)
                    return cypher_query
        
        return None
    
    def try_general_patterns(self, query: str, domain: Domain) -> Optional[str]:
        """Try general patterns"""
        for pattern_name, pattern_info in self.query_patterns.items():
            if pattern_name.startswith('general_'):
                match = re.search(pattern_info['pattern'], query, re.IGNORECASE)
                if match:
                    if pattern_name == 'general_find':
                        entity = match.group(1).strip()
                        node_type = self.map_entity_to_node(entity, domain)
                        cypher_query = pattern_info['template'].format(node_type)
                        return cypher_query + " LIMIT 10"
                    elif pattern_name == 'general_relationship':
                        entity1 = match.group(1).strip()
                        entity2 = match.group(2).strip()
                        cypher_query = pattern_info['template'].format(entity1, entity2)
                        return cypher_query + " LIMIT 10"
        
        return None
    
    def basic_pattern_matching(self, query: str, domain: Domain) -> str:
        """Basic pattern matching as fallback"""
        # Extract entities
        entities = self.extract_entities(query, domain)
        
        # Determine operation
        if any(word in query for word in ["find", "show", "get", "list", "what", "who", "where"]):
            if len(entities) == 1:
                return f"MATCH (n:{entities[0]}) RETURN n LIMIT 10"
            elif len(entities) >= 2:
                return f"MATCH (a:{entities[0]})-[r]-(b:{entities[1]}) RETURN a, r, b LIMIT 10"
        
        # Default fallback
        return "MATCH (n) RETURN n LIMIT 10"
    
    def extract_entities(self, query: str, domain: Domain) -> List[str]:
        """Extract entities from query"""
        entities = []
        query_words = query.lower().split()
        
        schema = self.domain_schemas.get(domain)
        if schema:
            for word in query_words:
                # Remove punctuation
                clean_word = re.sub(r'[^\w\s]', '', word)
                if clean_word in schema['nodes']:
                    entities.append(schema['nodes'][clean_word])
        
        return list(set(entities)) if entities else ['Entity']
    
    def map_entity_to_node(self, entity: str, domain: Domain) -> str:
        """Map entity string to node type"""
        entity_lower = entity.lower()
        
        schema = self.domain_schemas.get(domain)
        if schema:
            # Direct mapping
            if entity_lower in schema['nodes']:
                return schema['nodes'][entity_lower]
            
            # Partial matching
            for key, value in schema['nodes'].items():
                if key in entity_lower or entity_lower in key:
                    return value
        
        # Default mapping
        return entity.replace(' ', '').title()
    
    def clean_entity_name(self, entity: str) -> str:
        """Clean entity name for Cypher"""
        # Remove common stop words
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
        words = entity.split()
        cleaned_words = [word for word in words if word.lower() not in stop_words]
        return ' '.join(cleaned_words).strip()

class DomainDetector:
    """Automatically detect domain from query"""
    
    def __init__(self):
        self.domain_keywords = {
            Domain.MEDICAL: {
                'disease', 'symptom', 'treatment', 'medication', 'diagnosis',
                'patient', 'hospital', 'doctor', 'therapy', 'surgery', 'clinic',
                'prescription', 'medicine', 'health', 'medical', 'cure', 'illness',
                'fever', 'pain', 'infection', 'virus', 'bacteria', 'cancer',
                'fabry', 'diabetes', 'condition', 'disorder'
            },
            Domain.BUSINESS: {
                'company', 'business', 'market', 'product', 'revenue',
                'profit', 'customer', 'client', 'sales', 'marketing',
                'strategy', 'competition', 'industry', 'corporate', 'financial',
                'startup', 'investment', 'entrepreneur', 'management', 'finance',
                'tech', 'technology', 'sector'
            }
        }
    
    def detect_domain(self, query: str) -> Domain:
        """Detect domain from query text"""
        query_lower = query.lower()
        query_words = set(query_lower.split())
        
        domain_scores = {}
        for domain, keywords in self.domain_keywords.items():
            score = len(query_words.intersection(keywords))
            domain_scores[domain] = score
        
        max_domain = max(domain_scores, key=domain_scores.get)
        return max_domain if domain_scores[max_domain] > 0 else Domain.GENERAL

# Test the improved converter
def test_improved_converter():
    print("🧪 Testing Improved Rule-based Converter")
    print("=" * 50)
    
    converter = ImprovedRuleBasedConverter()
    detector = DomainDetector()
    
    test_queries = [
        "What are the treatments for fabry disease?",
        "Find all patients with diabetes",
        "Show me companies in tech industry",
        "What are the symptoms of cancer?",
        "List employees at Google",
        "Find products by Apple"
    ]
    
    for query in test_queries:
        print(f"\n Query: {query}")
        
        # Detect domain
        domain = detector.detect_domain(query)
        print(f" Domain: {domain.value}")
        
        # Convert to Cypher
        cypher = converter.convert(query, domain)
        print(f" Cypher: {cypher}")
        
        print("-" * 40)

if __name__ == "__main__":
    test_improved_converter()

 Testing Improved Rule-based Converter

 Query: What are the treatments for fabry disease?
 Domain: medical
 Cypher: MATCH (d:Disease {name: 'fabry disease'})-[:TREATED_BY]->(t:Treatment) RETURN d, t
----------------------------------------

 Query: Find all patients with diabetes
 Domain: medical
 Cypher: MATCH (p:Patient)-[:HAS_CONDITION]->(d:Disease {name: 'diabetes'}) RETURN p, d
----------------------------------------

 Query: Show me companies in tech industry
 Domain: business
 Cypher: MATCH (c:Company {industry: 'tech'}) RETURN c
----------------------------------------

 Query: What are the symptoms of cancer?
 Domain: general
 Cypher: MATCH (n:Entity) RETURN n LIMIT 10
----------------------------------------

 Query: List employees at Google
 Domain: general
 Cypher: MATCH (n:employees google) RETURN n
----------------------------------------

 Query: Find products by Apple
 Domain: general
 Cypher: MATCH (n:products apple) RETURN n
----------------------------------------


In [6]:
# Interactive test in Jupyter

converter = ImprovedRuleBasedConverter()
detector = DomainDetector()

# Input from user
user_query = input(" Enter your query: ")

# Detect domain
domain = detector.detect_domain(user_query)
print(f" Detected Domain: {domain.value}")

# Convert query to Cypher
cypher = converter.convert(user_query, domain)
print(f" Generated Cypher: {cypher}")


 Detected Domain: medical
 Generated Cypher: MATCH (d:Disease {name: 'piles disease'})-[:TREATED_BY]->(t:Treatment) RETURN d, t
