# Advanced Neural-Symbolic Reasoning

Advanced neuro-symbolic patterns with constraint satisfaction, knowledge graphs, and explainable AI.

## Contents

1. **Constraint Satisfaction** - Rule validation and consistency
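2. **Knowledge Graph Integration** - In-memory entity/relation graph with path finding and prediction validation (RDF/OWL import is not implemented in this notebook)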
3. **Explanation Generation** - Human-readable reasoning paths
4. **Active Learning** - Constraint-guided annotation *(not implemented in this notebook; section 4 below is the summary)*


## 1. Constraint Satisfaction Module


In [None]:
import numpy as np
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Set, Callable
from enum import Enum
from abc import ABC, abstractmethod
from datetime import datetime
import logging

# Configure the root logger at INFO level for this notebook session.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the classes defined below.
logger = logging.getLogger(__name__)

class ConstraintType(Enum):
    """Strictness categories for constraints.

    ``ConstraintSatisfactionEngine.validate_all`` in this file reports a failed
    HARD constraint with severity "error" and a failed constraint of any other
    type with severity "warning"; no penalty or bonus scoring is visible here.
    """
    HARD = "hard"  # Must be satisfied
    SOFT = "soft"  # Should be satisfied (penalty if violated)
    PREFERENCE = "preference"  # Bonus if satisfied


@dataclass
class Constraint:
    """A named validation rule applied to a predictions dictionary.

    The rule itself lives in ``validate_fn``, which receives the predictions
    dict and returns an ``(is_valid, message)`` tuple.
    """
    name: str  # Identifier; the engine also uses it as the key of its name index
    constraint_type: ConstraintType  # Strictness category (HARD / SOFT / PREFERENCE)
    description: str  # Human-readable summary of the rule
    validate_fn: Callable[[Dict], Tuple[bool, str]]  # predictions -> (is_valid, message)
    priority: int = 0  # Higher = more important; the engine evaluates higher values first
    
    def validate(self, predictions: Dict) -> Tuple[bool, str]:
        """Run ``validate_fn`` on ``predictions`` and return its result unchanged.

        Args:
            predictions: Dictionary of model predictions to check.

        Returns:
            The ``(is_valid, message)`` tuple produced by ``validate_fn``.
        """
        return self.validate_fn(predictions)


@dataclass
class ConstraintViolation:
    """Outcome of checking one constraint against a set of predictions.

    Despite the name, ``ConstraintSatisfactionEngine.validate_all`` emits one
    record per constraint, including satisfied ones (``violated=False`` with
    severity "info").
    """
    constraint: str  # Name of the constraint that was checked
    violated: bool  # True when the constraint's validator reported failure
    message: str  # Validator message, or "Constraint satisfied" on success
    severity: str  # error, warning, info


class ConstraintSatisfactionEngine:
    """Holds registered constraints and evaluates predictions against them."""

    def __init__(self):
        # Registration order is kept in the list; the dict maps name -> constraint.
        self.constraints: List[Constraint] = []
        self._constraint_index: Dict[str, Constraint] = {}

    def add_constraint(self, constraint: Constraint):
        """Register ``constraint`` in both the ordered list and the name index."""
        self.constraints.append(constraint)
        self._constraint_index[constraint.name] = constraint
        logger.info(f"Added constraint: {constraint.name}")

    def validate_all(self, predictions: Dict) -> List[ConstraintViolation]:
        """Evaluate every registered constraint, highest priority first.

        Returns:
            One ``ConstraintViolation`` record per constraint. Satisfied
            constraints are included with ``violated=False`` and severity
            "info"; failed HARD constraints get severity "error", all other
            failures "warning".
        """
        by_priority = sorted(self.constraints, key=lambda item: -item.priority)
        report = []

        for rule in by_priority:
            passed, detail = rule.validate(predictions)

            if passed:
                record = ConstraintViolation(
                    constraint=rule.name,
                    violated=False,
                    message="Constraint satisfied",
                    severity="info"
                )
            else:
                is_hard = rule.constraint_type == ConstraintType.HARD
                record = ConstraintViolation(
                    constraint=rule.name,
                    violated=True,
                    message=detail,
                    severity="error" if is_hard else "warning"
                )
            report.append(record)

        return report

    def get_hard_violations(self, violations: List[ConstraintViolation]) -> List[ConstraintViolation]:
        """Return only the records whose severity is "error"."""
        return list(filter(lambda record: record.severity == "error", violations))

    def is_satisfiable(self, predictions: Dict) -> bool:
        """Return True when no constraint produces an "error"-severity record."""
        hard_failures = self.get_hard_violations(self.validate_all(predictions))
        return not hard_failures


class ConstraintTemplates:
    """Factory methods that build commonly used ``Constraint`` objects.

    Every template returns a HARD constraint whose ``validate_fn`` inspects a
    predictions dictionary and returns an ``(is_valid, message)`` tuple.
    """

    @staticmethod
    def mutual_exclusion(
        entity_a: str,
        entity_b: str,
        predicate: str
    ) -> Constraint:
        """Create a constraint forbidding both entities from having ``predicate``.

        The validator scans ``predictions["relations"]`` (a list of dicts with
        "subject" and "predicate" keys) and fails when both ``entity_a`` and
        ``entity_b`` appear as the subject of a relation with ``predicate``.

        Returns:
            A HARD constraint with priority 10.
        """
        def validate_fn(predictions: Dict) -> Tuple[bool, str]:
            relations = predictions.get("relations", [])
            has_a = any(r.get("subject") == entity_a and r.get("predicate") == predicate for r in relations)
            has_b = any(r.get("subject") == entity_b and r.get("predicate") == predicate for r in relations)

            if has_a and has_b:
                return False, f"Entities {entity_a} and {entity_b} cannot both have relation {predicate}"
            return True, "No mutual exclusion violation"

        return Constraint(
            name=f"mutual_exclusion_{entity_a}_{entity_b}",
            constraint_type=ConstraintType.HARD,
            description=f"{entity_a} and {entity_b} are mutually exclusive for {predicate}",
            validate_fn=validate_fn,
            priority=10
        )

    @staticmethod
    def at_most_one(values: List[str], predicate: str) -> Constraint:
        """Create a constraint allowing at most one relation with ``predicate``.

        The validator counts every entry of ``predictions["relations"]`` whose
        "predicate" equals ``predicate`` and fails when the count exceeds one.

        Args:
            values: Accepted for interface compatibility but not consulted by
                the validator. NOTE(review): presumably meant to restrict which
                entities are counted - confirm the intended semantics.
            predicate: Relation name to count.

        Returns:
            A HARD constraint with priority 10.
        """
        def validate_fn(predictions: Dict) -> Tuple[bool, str]:
            relations = predictions.get("relations", [])
            count = sum(1 for r in relations if r.get("predicate") == predicate)

            if count > 1:
                return False, f"At most one entity can have {predicate}, found {count}"
            return True, "At-most-one constraint satisfied"

        return Constraint(
            name=f"at_most_one_{predicate}",
            constraint_type=ConstraintType.HARD,
            description=f"At most one entity can have {predicate}",
            validate_fn=validate_fn,
            priority=10
        )

    @staticmethod
    def range_constraint(
        field: str,
        min_val: float,
        max_val: float
    ) -> Constraint:
        """Create a constraint requiring ``predictions[field]`` in [min_val, max_val].

        Both bounds are inclusive. A missing field is read as 0. A value that
        cannot be ordered against the bounds (for example ``None``) fails the
        constraint instead of raising.

        Returns:
            A HARD constraint with the default priority.
        """
        def validate_fn(predictions: Dict) -> Tuple[bool, str]:
            value = predictions.get(field, 0)

            try:
                if value < min_val:
                    return False, f"{field}={value} is below minimum {min_val}"
                if value > max_val:
                    # The tuple must be returned; a bare expression here would
                    # fall through and report the value as in range.
                    return False, f"{field}={value} is above maximum {max_val}"
            except TypeError:
                return False, f"{field}={value!r} cannot be compared with range [{min_val}, {max_val}]"
            return True, f"{field}={value} is within range [{min_val}, {max_val}]"

        return Constraint(
            name=f"range_{field}",
            constraint_type=ConstraintType.HARD,
            description=f"{field} must be in [{min_val}, {max_val}]",
            validate_fn=validate_fn
        )

    @staticmethod
    def type_constraint(
        entity: str,
        required_types: List[str]
    ) -> Constraint:
        """Create a constraint requiring an entity to carry one of the given types.

        The validator looks up the first entry of ``predictions["entities"]``
        whose "id" equals ``entity``. It fails when no such entry exists or
        when none of the entry's "types" appears in ``required_types``.

        Returns:
            A HARD constraint with the default priority.
        """
        def validate_fn(predictions: Dict) -> Tuple[bool, str]:
            entities = predictions.get("entities", [])
            target = next((e for e in entities if e.get("id") == entity), None)

            if not target:
                return False, f"Entity {entity} not found"

            entity_types = target.get("types", [])
            if not any(t in required_types for t in entity_types):
                return False, f"Entity {entity} must be one of {required_types}"
            return True, f"Entity {entity} has valid type"

        return Constraint(
            name=f"type_{entity}",
            constraint_type=ConstraintType.HARD,
            description=f"Entity {entity} must have type in {required_types}",
            validate_fn=validate_fn
        )


## 2. Knowledge Graph Integration


In [None]:
from collections import defaultdict
import json

@dataclass
class KGEntity:
    """Knowledge graph entity (a node)."""
    id: str  # Unique key; KnowledgeGraph stores entities under this id
    label: str  # Display name; KGReasoner starts its path strings with it
    types: List[str]  # Type labels, e.g. ["City", "Capital"] in the demo
    properties: Dict = field(default_factory=dict)  # Free-form attributes; fresh dict per instance
    embeddings: Optional[np.ndarray] = None  # Optional vector; not exported by KnowledgeGraph.to_dict

@dataclass
class KGRelation:
    """Knowledge graph relation: a (subject, predicate, object) edge between entity ids."""
    subject: str  # Entity id at the source end of the edge
    predicate: str  # Relation name, e.g. "capital_of" in the demo
    object: str  # Entity id at the target end of the edge
    confidence: float = 1.0  # Defaults to 1.0; no range check is applied here


class KnowledgeGraph:
    """In-memory knowledge graph for reasoning."""
    
    def __init__(self):
        self.entities: Dict[str, KGEntity] = {}
        self.relations: List[KGRelation] = []
        self.index: Dict[str, List[KGRelation]] = defaultdict(list)
    
    def add_entity(self, entity: KGEntity):
        """Add entity to graph."""
        self.entities[entity.id] = entity
        logger.debug(f"Added entity: {entity.id}")
    
    def add_relation(self, relation: KGRelation):
        """Add relation to graph."""
        self.relations.append(relation)
        self.index[relation.subject].append(relation)
        self.index[relation.object].append(relation)  # Also index by object
        logger.debug(f"Added relation: {relation.subject} -> {relation.predicate} -> {relation.object}")
    
    def get_entity(self, entity_id: str) -> Optional[KGEntity]:
        """Get entity by ID."""
        return self.entities.get(entity_id)
    
    def get_neighbors(self, entity_id: str, relation: str = None) -> List[Tuple[KGRelation, KGEntity]]:
        """Get neighboring entities."""
        neighbors = []
        
        for rel in self.index.get(entity_id, []):
            if relation and rel.predicate != relation:
                continue
            
            target_id = rel.object if rel.subject == entity_id else rel.subject
            target = self.get_entity(target_id)
            
            if target:
                neighbors.append((rel, target))
        
        return neighbors
    
    def find_path(
        self,
        source: str,
        target: str,
        max_depth: int = 3
    ) -> List[List[KGRelation]]:
        """Find paths between two entities (BFS)."""
        from collections import deque
        
        queue = deque([(source, [source])])
        visited = {source}
        paths = []
        
        while queue and len(paths) < 10:
            current, path = queue.popleft()
            
            if current == target:
                paths.append(path)
                continue
            
            if len(path) > max_depth:
                continue
            
            for rel in self.index.get(current, []):
                next_entity = rel.object if rel.subject == current else rel.subject
                
                if next_entity not in visited:
                    visited.add(next_entity)
                    queue.append((next_entity, path + [next_entity]))
        
        return paths
    
    def to_dict(self) -> Dict:
        """Export graph to dictionary."""
        return {
            "entities": {
                eid: {
                    "id": ent.id,
                    "label": ent.label,
                    "types": ent.types,
                    "properties": ent.properties
                }
                for eid, ent in self.entities.items()
            },
            "relations": [
                {
                    "subject": r.subject,
                    "predicate": r.predicate,
                    "object": r.object,
                    "confidence": r.confidence
                }
                for r in self.relations
            ]
        }


class KGReasoner:
    """Reason over knowledge graph."""

    def __init__(self, knowledge_graph: KnowledgeGraph):
        self.kg = knowledge_graph

    def find_related_entities(
        self,
        entity_id: str,
        relation_types: List[str] = None,
        depth: int = 2
    ) -> List[Tuple[KGEntity, str, float]]:
        """Find entities related to a given entity via depth-first traversal.

        Args:
            entity_id: Id of the start entity; an unknown id yields ``[]``.
            relation_types: When given and non-empty, only relations whose
                predicate is in this list are followed.
            depth: Nodes at traversal depth ``0..depth`` are expanded, so a
                result can lie up to ``depth + 1`` hops from the start.
                NOTE(review): confirm whether a strict hop limit was intended.

        Returns:
            ``(entity, path, confidence)`` tuples, where ``path`` is the start
            entity's label followed by the predicates traversed, joined with
            " -> ", and ``confidence`` is that of the final relation. Entities
            already visited (including the start entity) are not reported
            again.
        """
        results = []

        def dfs(current: str, current_depth: int, visited: Set[str], path: List[str]):
            if current_depth > depth or current in visited:
                return

            visited.add(current)

            for rel, neighbor in self.kg.get_neighbors(current):
                if relation_types and rel.predicate not in relation_types:
                    continue

                # Edges are undirected in the graph index, so every neighbour
                # list contains the node we arrived from; skip visited nodes so
                # the start entity and back-tracking paths are not reported.
                if neighbor.id in visited:
                    continue

                results.append((
                    neighbor,
                    " -> ".join(path + [rel.predicate]),
                    rel.confidence
                ))

                dfs(neighbor.id, current_depth + 1, visited, path + [rel.predicate])

        entity = self.kg.get_entity(entity_id)
        if entity:
            dfs(entity_id, 0, set(), [entity.label])

        return results

    def validate_predictions(self, predictions: Dict) -> Tuple[bool, List[str]]:
        """Validate predictions against knowledge graph.

        Checks that every id in ``predictions["entity_ids"]`` is a known
        entity and that both the "subject" and "object" of every entry in
        ``predictions["relations"]`` are known entities.

        Returns:
            ``(is_valid, issues)`` where ``issues`` lists one message per
            problem found and ``is_valid`` is True when the list is empty.
        """
        issues = []

        # Check entity existence
        for entity_id in predictions.get("entity_ids", []):
            if not self.kg.get_entity(entity_id):
                issues.append(f"Unknown entity: {entity_id}")

        # Check relation validity
        for rel in predictions.get("relations", []):
            subj = self.kg.get_entity(rel.get("subject"))
            obj = self.kg.get_entity(rel.get("object"))

            if not subj or not obj:
                issues.append(f"Invalid relation: {rel}")

        return len(issues) == 0, issues


## 3. Explanation Generation


In [None]:
@dataclass
class ExplanationStep:
    """A single step in an explanation."""
    step_id: int  # Position assigned by ExplanationGenerator (1 is the neural-model step)
    description: str  # Human-readable text shown in the reasoning chain
    confidence: float  # Averaged with the other steps to form the overall confidence
    rule_applied: str  # Rule identifier, or "NEURAL_MODEL" for the model step
    evidence: List[str]  # Supporting strings; Explanation.to_text prints at most three

@dataclass
class Explanation:
    """Full reasoning record behind one prediction."""
    prediction: str
    confidence: float
    steps: List[ExplanationStep]
    rule_overrides: List[str]
    uncertainty_factors: List[str]

    def to_text(self) -> str:
        """Render the explanation as a multi-line, human-readable report.

        The report has a prediction/confidence header and a numbered reasoning
        chain (up to three evidence items per step). Rule-override and
        uncertainty sections are appended only when they are non-empty.
        """
        header = [
            f"PREDICTION: {self.prediction}",
            f"CONFIDENCE: {self.confidence:.2%}",
            "\nREASONING CHAIN:",
        ]

        chain = []
        for number, step in enumerate(self.steps, start=1):
            chain.append(f"  {number}. [{step.rule_applied}] {step.description}")
            if step.evidence:
                # Only the first three evidence items are shown.
                chain.extend(f"     → {item}" for item in step.evidence[:3])

        tail = []
        if self.rule_overrides:
            tail.append("\nRULE OVERRIDES: " + ", ".join(self.rule_overrides))
        if self.uncertainty_factors:
            tail.append("\nUNCERTAINTY FACTORS:")
            tail.extend(f"  • {factor}" for factor in self.uncertainty_factors)

        return "\n".join(header + chain + tail)


class ExplanationGenerator:
    """Builds ``Explanation`` objects from model outputs and applied rules."""

    def __init__(self, knowledge_graph: KnowledgeGraph = None):
        self.kg = knowledge_graph
        # rule id -> human-readable description
        self.rule_registry: Dict[str, str] = {}

    def register_rule(self, rule_id: str, description: str):
        """Store the description to show when ``rule_id`` appears in an explanation."""
        self.rule_registry[rule_id] = description

    def explain_prediction(
        self,
        prediction: str,
        model_outputs: Dict,
        applied_rules: List[str],
        constraints_violated: List[str] = None,
        uncertainty: float = 0.0
    ) -> Explanation:
        """Assemble the reasoning chain and confidence for ``prediction``.

        Step 1 always describes the neural model, using
        ``model_outputs["model_confidence"]`` (0.8 when absent). One further
        step is added per applied rule, numbered from 2, with the registered
        description or ``"Rule <id>"`` as a fallback. The overall confidence is
        the mean step confidence scaled by ``1 - uncertainty``.
        """
        neural_step = ExplanationStep(
            step_id=1,
            description="Neural model analyzed input features",
            confidence=model_outputs.get("model_confidence", 0.8),
            rule_applied="NEURAL_MODEL",
            evidence=self._extract_evidence(model_outputs)
        )

        # Each rule is drawn from applied_rules, so for list input the
        # membership test below evaluates to 1.0.
        rule_steps = [
            ExplanationStep(
                step_id=position,
                description=self.rule_registry.get(rule, f"Rule {rule}"),
                confidence=1.0 if rule in applied_rules else 0.0,
                rule_applied=rule,
                evidence=[]
            )
            for position, rule in enumerate(applied_rules, 2)
        ]
        steps = [neural_step, *rule_steps]

        mean_confidence = np.mean([step.confidence for step in steps])
        overall_confidence = mean_confidence * (1.0 - uncertainty)

        # Collect anything that should lower the reader's trust in the result.
        caveats = []
        if uncertainty > 0.2:
            caveats.append(f"High input uncertainty ({uncertainty:.1%})")
        if constraints_violated:
            caveats.append(f"Constraint violations: {len(constraints_violated)}")
        if model_outputs.get("low_confidence_features"):
            caveats.append("Some input features had low model confidence")

        return Explanation(
            prediction=prediction,
            confidence=overall_confidence,
            steps=steps,
            rule_overrides=applied_rules,
            uncertainty_factors=caveats
        )

    def _extract_evidence(self, model_outputs: Dict) -> List[str]:
        """Format up to three entries of ``model_outputs["top_features"]`` as evidence strings."""
        leading = model_outputs.get("top_features", [])[:3]
        return [
            f"Feature '{feat.get('name')}' score: {feat.get('score', 0):.3f}"
            for feat in leading
        ]


def demonstrate_neurosymbolic_advanced():
    """Walk through the constraint engine, knowledge graph and explainer.

    Prints a short report for each of the three parts and returns the objects
    that were built, keyed "constraint_engine", "knowledge_graph" and
    "explainer".
    """
    banner = "=" * 70
    print(banner)
    print("ADVANCED NEURAL-SYMBOLIC REASONING DEMO")
    print(banner)

    # ---- Part 1: constraint satisfaction -------------------------------
    print("\n[1] Constraint Satisfaction Engine")
    engine = ConstraintSatisfactionEngine()

    exclusion = ConstraintTemplates.mutual_exclusion("Entity_A", "Entity_B", "located_in")
    engine.add_constraint(exclusion)

    confidence_range = ConstraintTemplates.range_constraint("confidence", 0.0, 1.0)
    engine.add_constraint(confidence_range)

    # Predictions expected to pass both constraints.
    valid_preds = {
        "relations": [
            {"subject": "Entity_A", "predicate": "located_in", "object": "France"}
        ],
        "confidence": 0.85
    }
    violated_count = sum(
        1 for outcome in engine.validate_all(valid_preds) if outcome.violated
    )
    print(f"  Valid predictions violations: {violated_count}")

    # Predictions built to break both constraints.
    invalid_preds = {
        "relations": [
            {"subject": "Entity_A", "predicate": "located_in", "object": "France"},
            {"subject": "Entity_B", "predicate": "located_in", "object": "Germany"}
        ],
        "confidence": 1.5  # Out of range
    }
    hard_violations = engine.get_hard_violations(engine.validate_all(invalid_preds))
    print(f"  Invalid predictions hard violations: {len(hard_violations)}")

    # ---- Part 2: knowledge graph ---------------------------------------
    print("\n[2] Knowledge Graph Reasoning")
    kg = KnowledgeGraph()

    demo_entities = (
        KGEntity(
            id="paris",
            label="Paris",
            types=["City", "Capital"],
            properties={"population": 2148000}
        ),
        KGEntity(
            id="france",
            label="France",
            types=["Country"],
            properties={"population": 67390000}
        ),
        KGEntity(
            id="europe",
            label="Europe",
            types=["Continent"]
        ),
    )
    for node in demo_entities:
        kg.add_entity(node)

    demo_edges = (
        ("paris", "capital_of", "france"),
        ("france", "located_in", "europe"),
    )
    for subject, predicate, obj in demo_edges:
        kg.add_relation(KGRelation(subject, predicate, obj))

    paths = kg.find_path("paris", "europe", max_depth=3)
    print(f"  Paris -> Europe paths found: {len(paths)}")

    # ---- Part 3: explanation generation --------------------------------
    print("\n[3] Explanation Generation")
    explainer = ExplanationGenerator(kg)
    explainer.register_rule("MUTUAL_EXCLUSION", "Entities cannot both have this relation")

    model_outputs = {
        "model_confidence": 0.92,
        "top_features": [
            {"name": "location_context", "score": 0.88}
        ]
    }
    explanation = explainer.explain_prediction(
        prediction="Paris is the capital of France",
        model_outputs=model_outputs,
        applied_rules=["CAPITAL_VERIFICATION"],
        constraints_violated=[],
        uncertainty=0.08
    )

    print("\n" + explanation.to_text())

    return {
        "constraint_engine": engine,
        "knowledge_graph": kg,
        "explainer": explainer
    }


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    demonstrate_neurosymbolic_advanced()

## 4. Summary

This notebook demonstrates:

- **Constraint Satisfaction**: Hard/soft constraints with validation
- **Knowledge Graphs**: In-memory graph with path finding
- **Explanation Generation**: Human-readable reasoning chains
- **Integration**: Combined neuro-symbolic reasoning pipeline

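Together these components form a compact, in-memory reference implementation of explainable neuro-symbolic reasoning. Persistence, RDF/OWL import, and active learning are not covered here.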
