# Setup and Configuration

In [12]:
import torch
import os
from datetime import datetime, timedelta
from transformers import BertTokenizer, BertModel
from dateutil import parser
import re
from dataclasses import dataclass, field
import time
from typing import List, Tuple, Any, Set, Dict, Optional, Union, Callable
import itertools
import random
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from collections import Counter
import json
import pprint

class Config:
    BERT_MODEL_NAME: str = 'bert-base-uncased'
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MAX_LENGTH: int = 128
    MAX_REASONING_DEPTH: int = 5
    INITIAL_RULE_CONFIDENCE: float = 0.9
    PLAUSIBILITY_THRESHOLD: float = 0.5
    ADAPTATION_LEARNING_RATE_NEURAL: float = 1e-5
    ADAPTATION_LEARNING_RATE_SYMBOLIC: float = 0.1
    RULE_GEN_MIN_SUPPORT: int = 2
    RULE_GEN_CLUSTER_THRESHOLD: float = 0.8
    DATA_DIR: str = "data"
    SAMPLE_QA_FILENAME: str = "sample_qa.json"
    DATA_PATH: str = os.path.join(DATA_DIR, SAMPLE_QA_FILENAME)

config = Config()

print(f"Using device: {config.DEVICE}")

Using device: cuda


# Utility Functions

In [13]:
class TimeUtils:
    @staticmethod
    def parse_time(time_str: str) -> Optional[datetime]:
        if not isinstance(time_str, str):
             print(f"Warning: parse_time expected string, got {type(time_str)}")
             return None
        try:
            return parser.parse(time_str)
        except (ValueError, OverflowError, TypeError) as e:
            print(f"Warning: Could not parse time string: '{time_str}'. Error: {e}")
            return None

    @staticmethod
    def parse_duration(duration_str: str) -> Optional[timedelta]:
        if not isinstance(duration_str, str):
             print(f"Warning: parse_duration expected string, got {type(duration_str)}")
             return None

        duration_str_lower = duration_str.lower()
        minutes = 0
        hours = 0

        hour_match = re.search(r'(\d+)\s*(?:hr|hour|hours)', duration_str_lower)
        min_match = re.search(r'(\d+)\s*(?:min|minute|minutes)', duration_str_lower)

        if hour_match:
            hours += int(hour_match.group(1))
        if min_match:
            minutes += int(min_match.group(1))

        if hours == 0 and minutes == 0:
            try:
                num_match = re.search(r'^(\d+)$', duration_str.strip())
                if num_match:
                     minutes += int(num_match.group(1))
                     print(f"Warning: Interpreting standalone number '{duration_str}' as minutes.")
            except (AttributeError, ValueError):
                 pass

        if minutes > 0 or hours > 0:
            return timedelta(hours=hours, minutes=minutes)
        else:
            print(f"Warning: Could not parse duration string: '{duration_str}'")
            return None

    @staticmethod
    def format_time(dt_obj: datetime) -> str:
        if not isinstance(dt_obj, datetime):
            return "Invalid Time"
        return dt_obj.strftime("%I:%M %p")

    @staticmethod
    def format_timedelta(td_obj: timedelta) -> str:
        if not isinstance(td_obj, timedelta):
            return "Invalid Duration"

        total_minutes = int(td_obj.total_seconds() / 60)
        if total_minutes < 0:
             sign = "-"
             total_minutes = abs(total_minutes)
        else:
             sign = ""

        hours = total_minutes // 60
        minutes = total_minutes % 60
        res = []
        if hours > 0:
            res.append(f"{hours} hour{'s' if hours > 1 else ''}")
        if minutes > 0:
            res.append(f"{minutes} minute{'s' if minutes > 1 else ''}")

        if not res:
            return "0 minutes"
        else:
            return sign + " and ".join(res)

print("--- Time Parsing ---")
t1 = TimeUtils.parse_time("3:00 PM")
t2 = TimeUtils.parse_time("14:15")
t3 = TimeUtils.parse_time("noon")
print(f"'3:00 PM' -> {t1}")
print(f"'14:15' -> {t2}")
print(f"'noon' -> {t3}")
print(f"'invalid time' -> {TimeUtils.parse_time('invalid time')}")

print("\n--- Duration Parsing ---")
d1 = TimeUtils.parse_duration("30 minutes")
d2 = TimeUtils.parse_duration("1 hr and 15 min")
d3 = TimeUtils.parse_duration("45min")
d4 = TimeUtils.parse_duration("2 hours")
d5 = TimeUtils.parse_duration("invalid duration")
d6 = TimeUtils.parse_duration("120")
print(f"'30 minutes' -> {d1}")
print(f"'1 hr and 15 min' -> {d2}")
print(f"'45min' -> {d3}")
print(f"'2 hours' -> {d4}")
print(f"'invalid duration' -> {d5}")
print(f"'120' -> {d6}")

print("\n--- Formatting ---")
if t1: print(f"t1 formatted: {TimeUtils.format_time(t1)}")
if d2: print(f"d2 formatted: {TimeUtils.format_timedelta(d2)}")
if t1 and d1:
    departure = t1 - d1
    print(f"Meeting at {TimeUtils.format_time(t1)}, travel {TimeUtils.format_timedelta(d1)}, leave by {TimeUtils.format_time(departure)}")

print(f"Formatting None: {TimeUtils.format_time(None)}, {TimeUtils.format_timedelta(None)}")

--- Time Parsing ---
'3:00 PM' -> 2025-03-31 15:00:00
'14:15' -> 2025-03-31 14:15:00
'noon' -> None
'invalid time' -> None

--- Duration Parsing ---
'30 minutes' -> 0:30:00
'1 hr and 15 min' -> 1:15:00
'45min' -> 0:45:00
'2 hours' -> 2:00:00
'invalid duration' -> None
'120' -> 2:00:00

--- Formatting ---
t1 formatted: 03:00 PM
d2 formatted: 1 hour and 15 minutes
Meeting at 03:00 PM, travel 30 minutes, leave by 02:30 PM
Formatting None: Invalid Time, Invalid Duration


# Neural Component

In [14]:
class BertExtractor:
    def __init__(self, model_name: str = config.BERT_MODEL_NAME, device: torch.device = config.DEVICE):
        print(f"Loading BERT model: {model_name}...")
        try:
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
            self.model = BertModel.from_pretrained(model_name)
            self.device = device
            self.model.to(self.device)
            self.model.eval()
            print("BERT model loaded.")
        except Exception as e:
            print(f"ERROR: Failed to load BERT model '{model_name}'. Check model name and internet connection. Error: {e}")
            raise e

    def get_embeddings(self, text: str) -> Optional[torch.Tensor]:
        try:
            inputs = self.tokenizer(
                text,
                return_tensors='pt',
                max_length=config.MAX_LENGTH,
                truncation=True,
                padding='max_length'
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            cls_embedding = outputs.last_hidden_state[:, 0, :]
            return cls_embedding.cpu()
        except Exception as e:
            print(f"Error generating embedding for text: '{text}'. Error: {e}")
            return None

    def extract_entities_relations(self, question: str) -> Dict[str, Any]:
        entities = {'times': [], 'durations': []}
        relations = []
        event_id_counter = 1
        event_name = f"event{event_id_counter}"

        time_patterns = r'\b(\d{1,2}:\d{2}\s*(?:AM|PM)?|\d{1,2}\s*(?:AM|PM))\b'
        try:
            found_times = re.findall(time_patterns, question, re.IGNORECASE)
            for t_str in found_times:
                parsed_time = TimeUtils.parse_time(t_str)
                if parsed_time and not any(t['text'] == t_str for t in entities['times']):
                    entities['times'].append({'text': t_str, 'value': parsed_time})
        except Exception as e:
            print(f"Error during time regex extraction: {e}")

        duration_patterns = r'(\d+\s*(?:minute|minutes|min|hour|hours|hr)s?(?:\s*(?:and|,)?\s*\d+\s*(?:minute|minutes|min|hour|hours|hr)s?)?)'
        simple_duration_patterns = r'(?:takes?|travel|drive|walk|lasts?|runs? for)\s*(\d+)\b'
        try:
            found_durations = re.findall(duration_patterns, question, re.IGNORECASE)
            found_simple_durations = re.findall(simple_duration_patterns, question, re.IGNORECASE)

            all_duration_texts = found_durations + [f"{m} minutes" for m in found_simple_durations]

            processed_duration_texts = set()
            for d_str in all_duration_texts:
                 d_str_cleaned = d_str.strip()
                 if d_str_cleaned in processed_duration_texts: continue

                 is_part_of_time = any(d_str_cleaned in t['text'] for t in entities['times'])
                 if not is_part_of_time:
                     parsed_duration = TimeUtils.parse_duration(d_str_cleaned)
                     if parsed_duration:
                         entities['durations'].append({'text': d_str_cleaned, 'value': parsed_duration})
                         processed_duration_texts.add(d_str_cleaned)

        except Exception as e:
            print(f"Error during duration regex extraction: {e}")

        question_lower = question.lower()
        if entities['times']:
            first_time_val = entities['times'][0]['value']
            if any(kw in question_lower for kw in ["meeting", "appointment", "deadline", "flight is at", "event at"]):
                 relations.append(('event_time', event_name, first_time_val))
            elif "start" in question_lower or "begins" in question_lower:
                 relations.append(('start_time', event_name, first_time_val))
            elif "depart" in question_lower or "leaves at" in question_lower:
                 relations.append(('departure_time', event_name, first_time_val))

        if entities['durations']:
            first_duration_val = entities['durations'][0]['value']
            if any(w in question_lower for w in ["get there", "travel", "drive", "walk", "commute", "journey"]):
                 relations.append(('travel_time', event_name, first_duration_val))
            elif any(kw in question_lower for kw in ["runs for", "lasts", "duration of", "need", "takes"]):
                 if any(r[0] in ['start_time', 'event_time'] for r in relations):
                      relations.append(('duration', event_name, first_duration_val))
                 else:
                      relations.append(('duration', event_name, first_duration_val))

        query_variable = None
        if "when should i leave" in question_lower or "departure time" in question_lower or "latest departure" in question_lower:
            query_variable = ('departure_time', event_name, '?')
        elif "when does it end" in question_lower or "end time" in question_lower or "what time does it finish" in question_lower or "when will i arrive" in question_lower:
             if any(r[0] == 'travel_time' for r in relations) or any(kw in question_lower for kw in ["arrive", "get there"]):
                  query_variable = ('arrival_time', event_name, '?')
             else:
                  query_variable = ('end_time', event_name, '?')
        elif "how long" in question_lower or "duration" in question_lower:
             query_variable = ('duration', event_name, '?')
        elif "when should i start" in question_lower or "start work" in question_lower:
             query_variable = ('start_work_time', event_name, '?')

        embedding = self.get_embeddings(question)

        return {
            "question": question,
            "entities": entities,
            "relations": relations,
            "query": query_variable,
            "embedding": embedding
        }

    def fine_tune_step(self, batch_questions: List[str], batch_labels: List[Any], optimizer, criterion):
        print("Warning: BertExtractor.fine_tune_step is a placeholder and not implemented.")
        pass


extractor = BertExtractor(model_name=config.BERT_MODEL_NAME, device=config.DEVICE)

questions = [
    "If I have a meeting at 3 PM and it takes 30 minutes to get there, when should I leave?",
    "My train departs at 10:00 AM. The journey lasts 2 hours and 15 min. When will I arrive?",
    "The workshop starts at 9 AM and runs for 3 hours. What time does it finish?",
    "How long is a meeting from 2pm to 4:30pm?"
]

for i, q in enumerate(questions):
    print(f"\n--- Analyzing Question {i+1} ---")
    extracted_info = extractor.extract_entities_relations(q)
    print(f"Question: {extracted_info['question']}")
    print("Entities:")
    for entity_type, entity_list in extracted_info['entities'].items():
        print(f"  {entity_type}:")
        for entity in entity_list:
            val_str = TimeUtils.format_time(entity['value']) if isinstance(entity['value'], datetime) else TimeUtils.format_timedelta(entity['value'])
            print(f"    - Text: '{entity['text']}', Value: {val_str}")
    print("Relations:")
    for rel in extracted_info['relations']:
        obj_str = TimeUtils.format_time(rel[2]) if isinstance(rel[2], datetime) else TimeUtils.format_timedelta(rel[2])
        print(f"  - ({rel[0]}, {rel[1]}, {obj_str})")
    print(f"Query: {extracted_info['query']}")
    if extracted_info['embedding'] is not None:
        print(f"Embedding Shape: {extracted_info['embedding'].shape}")
    else:
        print("Embedding: Failed to generate.")

Loading BERT model: bert-base-uncased...
BERT model loaded.

--- Analyzing Question 1 ---
Question: If I have a meeting at 3 PM and it takes 30 minutes to get there, when should I leave?
Entities:
  times:
    - Text: '3 PM', Value: 03:00 PM
  durations:
    - Text: '30 minutes', Value: 30 minutes
Relations:
  - (event_time, event1, 03:00 PM)
  - (travel_time, event1, 30 minutes)
Query: ('departure_time', 'event1', '?')
Embedding Shape: torch.Size([1, 768])

--- Analyzing Question 2 ---
Question: My train departs at 10:00 AM. The journey lasts 2 hours and 15 min. When will I arrive?
Entities:
  times:
    - Text: '10:00 AM', Value: 10:00 AM
  durations:
    - Text: '2 hours and 15 min', Value: 2 hours and 15 minutes
    - Text: '2 minutes', Value: 2 minutes
Relations:
  - (departure_time, event1, 10:00 AM)
  - (travel_time, event1, 2 hours and 15 minutes)
Query: ('arrival_time', 'event1', '?')
Embedding Shape: torch.Size([1, 768])

--- Analyzing Question 3 ---
Question: The workshop st

# Symbolic Component

In [15]:
@dataclass(frozen=True, eq=True)
class Fact:
    predicate: str
    subject: str
    object: Any

    def __str__(self) -> str:
        obj_str = self.object
        if isinstance(self.object, datetime):
            obj_str = TimeUtils.format_time(self.object)
        elif isinstance(self.object, timedelta):
            obj_str = TimeUtils.format_timedelta(self.object)
        elif isinstance(self.object, str) and self.object == '?':
             obj_str = '?'
        return f"{self.predicate}({self.subject}, {obj_str})"

@dataclass
class Rule:
    conditions: List[Tuple[str, str, str]]
    conclusion: Tuple[str, str, str]
    confidence: float = field(default=config.INITIAL_RULE_CONFIDENCE)
    source: str = field(default="manual")
    id: int = field(default_factory=itertools.count().__next__, init=False)

    def __str__(self) -> str:
        cond_str = " AND ".join([f"{p}({s}, {o})" for p, s, o in self.conditions])
        conc_str = f"{self.conclusion[0]}({self.conclusion[1]}, {self.conclusion[2]})"
        return f"Rule {self.id} ({self.confidence:.2f}, {self.source}): IF {cond_str} THEN {conc_str}"

class SymbolicReasoner:
    def __init__(self):
        self.facts: Set[Fact] = set()
        self.rules: List[Rule] = []
        self._add_core_temporal_rules()

    def _add_core_temporal_rules(self):
        core_rules_data = [
            {
                "conditions": [('event_time', '?e', '?t'), ('travel_time', '?e', '?d')],
                "conclusion": ('departure_time', '?e', 'calculate_departure(?t, ?d)'),
            },
            {
                "conditions": [('start_time', '?e', '?t'), ('duration', '?e', '?d')],
                "conclusion": ('end_time', '?e', 'calculate_end(?t, ?d)'),
            },
             {
                "conditions": [('event_time', '?e', '?t'), ('duration', '?e', '?d')],
                "conclusion": ('end_time', '?e', 'calculate_end(?t, ?d)'),
            },
            {
                "conditions": [('start_time', '?e', '?t1'), ('end_time', '?e', '?t2')],
                "conclusion": ('duration', '?e', 'calculate_duration(?t1, ?t2)'),
            },
             {
                "conditions": [('event_time', '?e', '?t'), ('duration', '?e', '?d')],
                "conclusion": ('start_work_time', '?e', 'calculate_departure(?t, ?d)'),
            },
             {
                "conditions": [('departure_time', '?e', '?t'), ('travel_time', '?e', '?d')],
                "conclusion": ('arrival_time', '?e', 'calculate_end(?t, ?d)'),
            },
        ]
        for rule_data in core_rules_data:
             self.add_rule(
                 conditions=rule_data["conditions"],
                 conclusion=rule_data["conclusion"],
                 source="core_temporal",
                 confidence=1.0
             )
        print(f"Added {len(self.rules)} core temporal rules.")


    def add_fact(self, fact: Fact) -> bool:
        if not isinstance(fact, Fact):
             print(f"Warning: Attempted to add non-Fact object: {fact}")
             return False
        if fact not in self.facts:
            self.facts.add(fact)
            return True
        return False

    def add_rule(self, conditions: List[Tuple[str, str, str]],
                 conclusion: Tuple[str, str, str],
                 confidence: float = config.INITIAL_RULE_CONFIDENCE,
                 source: str = "manual") -> Rule:
        rule = Rule(conditions=conditions, conclusion=conclusion, confidence=confidence, source=source)
        self.rules.append(rule)
        return rule

    def reason(self, max_depth: int = config.MAX_REASONING_DEPTH) -> Tuple[Set[Fact], List[Tuple[Rule, Dict[str, Any], Fact]]]:
        all_newly_derived_facts_overall = set()
        reasoning_trace = []
        current_facts = self.facts.copy()
        depth = 0

        while depth < max_depth:
            made_change_this_iteration = False
            newly_derived_this_iteration = set()

            active_rules = [r for r in self.rules if r.confidence > 0.1]

            for rule in active_rules:
                bindings_list = self._find_bindings(rule.conditions, current_facts)

                for bindings in bindings_list:
                    derived_fact = self._derive_conclusion(rule.conclusion, bindings)

                    if derived_fact and derived_fact not in current_facts and derived_fact not in newly_derived_this_iteration:
                        newly_derived_this_iteration.add(derived_fact)
                        reasoning_trace.append((rule, bindings, derived_fact))
                        made_change_this_iteration = True

            if not made_change_this_iteration:
                break

            current_facts.update(newly_derived_this_iteration)
            all_newly_derived_facts_overall.update(newly_derived_this_iteration)
            depth += 1

        if depth == max_depth:
             print(f"Warning: Reasoning reached maximum depth ({max_depth}).")

        self.facts = current_facts
        return all_newly_derived_facts_overall, reasoning_trace

    def _find_bindings(self, conditions: List[Tuple[str, str, str]], current_facts: Set[Fact]) -> List[Dict[str, Any]]:
        possible_bindings_list = []
        current_bindings_list = [{}]

        for predicate, subj_var, obj_var in conditions:
            next_bindings_list = []
            is_subj_variable = subj_var.startswith('?')
            is_obj_variable = obj_var.startswith('?')

            for existing_binding in current_bindings_list:
                for fact in current_facts:
                    if fact.predicate == predicate:
                        new_binding = existing_binding.copy()
                        consistent = True

                        if is_subj_variable:
                            if subj_var in new_binding:
                                if new_binding[subj_var] != fact.subject:
                                    consistent = False
                            else:
                                new_binding[subj_var] = fact.subject
                        elif subj_var != fact.subject:
                            consistent = False

                        if consistent:
                            if is_obj_variable:
                                if obj_var in new_binding:
                                    if type(new_binding[obj_var]) != type(fact.object) or new_binding[obj_var] != fact.object:
                                        consistent = False
                                else:
                                    new_binding[obj_var] = fact.object
                            elif type(obj_var) != type(fact.object) or obj_var != fact.object:
                                consistent = False

                        if consistent:
                            next_bindings_list.append(new_binding)

            current_bindings_list = next_bindings_list
            if not current_bindings_list:
                return []

        unique_bindings_tuples = {tuple(sorted(b.items())) for b in current_bindings_list}
        final_bindings_list = [dict(t) for t in unique_bindings_tuples]

        return final_bindings_list


    def _derive_conclusion(self, conclusion_pattern: Tuple[str, str, str], bindings: Dict[str, Any]) -> Optional[Fact]:
        conc_pred, conc_subj_var, conc_obj_expr = conclusion_pattern

        try:
            subj = bindings.get(conc_subj_var)
            if subj is None:
                 if not conc_subj_var.startswith('?'):
                      subj = conc_subj_var
                 else:
                      return None

            obj = None
            if isinstance(conc_obj_expr, str) and conc_obj_expr.startswith('calculate_'):
                match = re.match(r'calculate_(\w+)\((.+)\)', conc_obj_expr)
                if not match:
                     print(f"Warning: Invalid calculation format: {conc_obj_expr}")
                     return None
                func_name = "calculate_" + match.group(1)
                arg_vars = [v.strip() for v in match.group(2).split(',')]

                args = []
                for var in arg_vars:
                    if var not in bindings:
                        return None
                    args.append(bindings[var])

                if func_name == 'calculate_departure' and len(args) == 2 and isinstance(args[0], datetime) and isinstance(args[1], timedelta):
                    obj = args[0] - args[1]
                elif func_name == 'calculate_end' and len(args) == 2 and isinstance(args[0], datetime) and isinstance(args[1], timedelta):
                    obj = args[0] + args[1]
                elif func_name == 'calculate_duration' and len(args) == 2 and isinstance(args[0], datetime) and isinstance(args[1], datetime):
                    obj = args[1] - args[0]
                    if obj < timedelta(0):
                         print(f"Warning: Calculated negative duration for {conc_pred} with args {args}. Start time might be after end time.")
                else:
                    print(f"Warning: Unknown calculation function '{func_name}' or incorrect argument types/count for args: {args}")
                    return None
            else:
                obj = bindings.get(conc_obj_expr)
                if obj is None:
                    if not isinstance(conc_obj_expr, str) or not conc_obj_expr.startswith('?'):
                         obj = conc_obj_expr
                    else:
                         return None

            if isinstance(subj, str) and subj.startswith('?'): return None
            if isinstance(obj, str) and obj.startswith('?'): return None

            if obj is None:
                 return None

            return Fact(predicate=conc_pred, subject=str(subj), object=obj)

        except KeyError as e:
            print(f"Error: KeyError during conclusion derivation for {conclusion_pattern}. Missing binding for {e}. Bindings: {bindings}")
            return None
        except TypeError as e:
            print(f"Error: TypeError during calculation for {conclusion_pattern}. Bindings: {bindings}. Error: {e}")
            return None
        except Exception as e:
            print(f"Error: Unexpected error during conclusion derivation for {conclusion_pattern}. Bindings: {bindings}. Error: {e}")
            return None


    def query(self, query_pattern: Tuple[str, str, str]) -> List[Fact]:
        if not isinstance(query_pattern, tuple) or len(query_pattern) != 3:
             print(f"Error: Invalid query pattern format: {query_pattern}")
             return []

        q_pred, q_subj, q_obj = query_pattern
        results = []
        for fact in self.facts:
            match = True
            if fact.predicate != q_pred:
                match = False
            if q_subj != '?' and fact.subject != q_subj:
                match = False
            if q_obj != '?':
                if type(fact.object) != type(q_obj) or fact.object != q_obj:
                    match = False


            if match:
                results.append(fact)
        return results

    def get_all_facts(self) -> Set[Fact]:
        return self.facts

    def get_all_rules(self) -> List[Rule]:
        return self.rules

    def clear_facts(self):
        self.facts.clear()

    def clear_rules(self):
        self.rules.clear()


reasoner = SymbolicReasoner()

print("--- Example 1: Departure Time ---")
reasoner.clear_facts()
t1 = TimeUtils.parse_time("3 PM")
d1 = TimeUtils.parse_duration("30 minutes")
if t1: reasoner.add_fact(Fact('event_time', 'event1', t1))
if d1: reasoner.add_fact(Fact('travel_time', 'event1', d1))

print("Initial Facts:")
for f in reasoner.get_all_facts(): print(f"  {f}")
print("\nApplying Rules...")
new_facts, trace = reasoner.reason()

print("\nNew Facts Derived:")
for f in new_facts: print(f"  {f}")

print("\nQuerying for departure_time(event1, ?):")
results = reasoner.query(('departure_time', 'event1', '?'))
if results:
    print(f"  Found Answer: {results[0]}")
else:
    print("  Answer not found.")

print("\n--- Example 2: End Time ---")
reasoner.clear_facts()
t2 = TimeUtils.parse_time("9:00 AM")
d2 = TimeUtils.parse_duration("3 hours")
if t2: reasoner.add_fact(Fact('start_time', 'event2', t2))
if d2: reasoner.add_fact(Fact('duration', 'event2', d2))

print("Initial Facts:")
for f in reasoner.get_all_facts(): print(f"  {f}")
print("\nApplying Rules...")
new_facts, trace = reasoner.reason()

print("\nNew Facts Derived:")
for f in new_facts: print(f"  {f}")

print("\nQuerying for end_time(event2, ?):")
results = reasoner.query(('end_time', 'event2', '?'))
if results:
    print(f"  Found Answer: {results[0]}")
else:
    print("  Answer not found.")

print("\n--- Example 3: Duration ---")
reasoner.clear_facts()
t3_start = TimeUtils.parse_time("2 PM")
t3_end = TimeUtils.parse_time("4:30 PM")
if t3_start: reasoner.add_fact(Fact('start_time', 'event3', t3_start))
if t3_end: reasoner.add_fact(Fact('end_time', 'event3', t3_end))

print("Initial Facts:")
for f in reasoner.get_all_facts(): print(f"  {f}")
print("\nApplying Rules...")
new_facts, trace = reasoner.reason()

print("\nNew Facts Derived:")
for f in new_facts: print(f"  {f}")

print("\nQuerying for duration(event3, ?):")
results = reasoner.query(('duration', 'event3', '?'))
if results:
    print(f"  Found Answer: {results[0]}")
else:
    print("  Answer not found.")

print("\n--- All Rules in Reasoner ---")
for r in reasoner.get_all_rules(): print(r)

Added 6 core temporal rules.
--- Example 1: Departure Time ---
Initial Facts:
  event_time(event1, 03:00 PM)
  travel_time(event1, 30 minutes)

Applying Rules...

New Facts Derived:
  arrival_time(event1, 03:00 PM)
  departure_time(event1, 02:30 PM)

Querying for departure_time(event1, ?):
  Found Answer: departure_time(event1, 02:30 PM)

--- Example 2: End Time ---
Initial Facts:
  duration(event2, 3 hours)
  start_time(event2, 09:00 AM)

Applying Rules...

New Facts Derived:
  end_time(event2, 12:00 PM)

Querying for end_time(event2, ?):
  Found Answer: end_time(event2, 12:00 PM)

--- Example 3: Duration ---
Initial Facts:
  start_time(event3, 02:00 PM)
  end_time(event3, 04:30 PM)

Applying Rules...

New Facts Derived:
  duration(event3, 2 hours and 30 minutes)

Querying for duration(event3, ?):
  Found Answer: duration(event3, 2 hours and 30 minutes)

--- All Rules in Reasoner ---
Rule 0 (1.00, core_temporal): IF event_time(?e, ?t) AND travel_time(?e, ?d) THEN departure_time(?e, ca

# Integration Layer

In [16]:
class NeuralToSymbolicTranslator:
    def translate(self, neural_output: Dict[str, Any], reasoner: SymbolicReasoner) -> int:
        added_facts_count = 0
        if not isinstance(neural_output, dict) or 'relations' not in neural_output:
             print("Error: Invalid neural_output format for translation. Expected dict with 'relations'.")
             return -1
        if not isinstance(reasoner, SymbolicReasoner):
             print("Error: Invalid reasoner object provided for translation.")
             return -1

        relations = neural_output.get('relations', [])
        if not isinstance(relations, list):
             print("Warning: 'relations' in neural_output is not a list.")
             relations = []

        for relation in relations:
            if isinstance(relation, tuple) and len(relation) == 3:
                predicate, subject, obj = relation
                if isinstance(predicate, str) and isinstance(subject, str):
                    try:
                        fact = Fact(predicate=predicate, subject=subject, object=obj)
                        if reasoner.add_fact(fact):
                            added_facts_count += 1
                    except Exception as e:
                         print(f"Error creating Fact from relation {relation}: {e}")
                else:
                    print(f"Warning: Skipping invalid relation format: {relation}")
            else:
                 print(f"Warning: Skipping invalid relation format: {relation}")


        return added_facts_count

extractor = BertExtractor()
reasoner = SymbolicReasoner()
translator = NeuralToSymbolicTranslator()

q1 = "If I have a meeting at 3 PM and it takes 30 minutes to get there, when should I leave?"
neural_out1 = extractor.extract_entities_relations(q1)

print("--- Neural Output ---")

print(f"Relations: {neural_out1.get('relations')}")
print(f"Query: {neural_out1.get('query')}")


print("\n--- Translating to Symbolic Facts ---")
num_added = translator.translate(neural_out1, reasoner)
print(f"Added {num_added} facts.")

print("\n--- Facts in Reasoner (Before Reasoning) ---")
if not reasoner.get_all_facts():
    print("  No facts in reasoner.")
else:
    for f in reasoner.get_all_facts(): print(f"  {f}")

print("\n--- Performing Symbolic Reasoning ---")
new_facts, trace = reasoner.reason()
print(f"Derived {len(new_facts)} new facts.")

print("\n--- Querying for Result ---")
query_pattern = neural_out1.get('query')
if query_pattern:
    results = reasoner.query(query_pattern)
    if results:
        print(f"Answer Found for {query_pattern}: {results[0]}")
    else:
        print(f"Answer not found for query: {query_pattern}")
else:
    print("No query identified in neural output.")

print("\n--- All Facts After Reasoning ---")
if not reasoner.get_all_facts():
    print("  No facts in reasoner.")
else:
    for f in reasoner.get_all_facts(): print(f"  {f}")

Loading BERT model: bert-base-uncased...
BERT model loaded.
Added 6 core temporal rules.
--- Neural Output ---
Relations: [('event_time', 'event1', datetime.datetime(2025, 3, 31, 15, 0)), ('travel_time', 'event1', datetime.timedelta(seconds=1800))]
Query: ('departure_time', 'event1', '?')

--- Translating to Symbolic Facts ---
Added 2 facts.

--- Facts in Reasoner (Before Reasoning) ---
  event_time(event1, 03:00 PM)
  travel_time(event1, 30 minutes)

--- Performing Symbolic Reasoning ---
Derived 2 new facts.

--- Querying for Result ---
Answer Found for ('departure_time', 'event1', '?'): departure_time(event1, 02:30 PM)

--- All Facts After Reasoning ---
  arrival_time(event1, 03:00 PM)
  departure_time(event1, 02:30 PM)
  event_time(event1, 03:00 PM)
  travel_time(event1, 30 minutes)


# Meta-Cognitive Component

In [17]:
class MetaCognitiveEvaluator:
    def evaluate_plausibility(self,
                              answer_fact: Fact,
                              context_facts: Set[Fact]
                             ) -> Tuple[bool, str, float]:
        if not isinstance(answer_fact, Fact):
             return False, "Invalid answer fact provided for evaluation.", 0.0
        if not isinstance(context_facts, set):
             context_facts = set() if context_facts is None else set(context_facts)

        predicate = answer_fact.predicate
        subject = answer_fact.subject
        value = answer_fact.object

        if predicate == 'departure_time' and isinstance(value, datetime):
            event_time_fact = next((f for f in context_facts if f.predicate == 'event_time' and f.subject == subject and isinstance(f.object, datetime)), None)
            if event_time_fact:
                if value >= event_time_fact.object:
                    explanation = (f"Evaluation Failed: Departure time ({TimeUtils.format_time(value)}) "
                                   f"is not before the event time ({TimeUtils.format_time(event_time_fact.object)}).")
                    return False, explanation, 0.0

        elif predicate == 'end_time' and isinstance(value, datetime):
            start_time_fact = next((f for f in context_facts if f.predicate == 'start_time' and f.subject == subject and isinstance(f.object, datetime)), None)
            if start_time_fact:
                if value <= start_time_fact.object:
                    explanation = (f"Evaluation Failed: End time ({TimeUtils.format_time(value)}) "
                                   f"is not after the start time ({TimeUtils.format_time(start_time_fact.object)}).")
                    return False, explanation, 0.0

                duration_fact = next((f for f in context_facts if f.predicate == 'duration' and f.subject == subject and isinstance(f.object, timedelta)), None)
                if duration_fact:
                    expected_end_time = start_time_fact.object + duration_fact.object

                    if abs((value - expected_end_time).total_seconds()) > 60:
                        explanation = (f"Evaluation Warning: Calculated end time ({TimeUtils.format_time(value)}) "
                                       f"differs significantly from start time + duration "
                                       f"({TimeUtils.format_time(expected_end_time)}).")

                        return True, explanation, 0.7

        elif predicate == 'arrival_time' and isinstance(value, datetime):
             departure_time_fact = next((f for f in context_facts if f.predicate == 'departure_time' and f.subject == subject and isinstance(f.object, datetime)), None)
             if departure_time_fact:
                  if value <= departure_time_fact.object:
                       explanation = (f"Evaluation Failed: Arrival time ({TimeUtils.format_time(value)}) "
                                      f"is not after the departure time ({TimeUtils.format_time(departure_time_fact.object)}).")
                       return False, explanation, 0.0

                  travel_time_fact = next((f for f in context_facts if f.predicate == 'travel_time' and f.subject == subject and isinstance(f.object, timedelta)), None)
                  if travel_time_fact:
                      expected_arrival_time = departure_time_fact.object + travel_time_fact.object
                      if abs((value - expected_arrival_time).total_seconds()) > 60:
                           explanation = (f"Evaluation Warning: Calculated arrival time ({TimeUtils.format_time(value)}) "
                                          f"differs significantly from departure time + travel time "
                                          f"({TimeUtils.format_time(expected_arrival_time)}).")
                           return True, explanation, 0.7

        elif predicate == 'duration' and isinstance(value, timedelta):
            if value.total_seconds() < 0:
                explanation = f"Evaluation Failed: Duration ({TimeUtils.format_timedelta(value)}) cannot be negative."
                return False, explanation, 0.0

        explanation = "Evaluation Passed: Answer appears plausible based on implemented checks."
        return True, explanation, 1.0

    def generate_explanation(self,
                             query: Optional[Tuple[str, str, str]],
                             answer_fact: Optional[Fact],
                             reasoning_trace: List[Tuple[Rule, Dict[str, Any], Fact]],
                             plausibility_result: Optional[Tuple[bool, str, float]]
                            ) -> str:
        explanation_lines = []
        query_str = f"{query[0]}({query[1]}, {query[2]})" if query else "No query specified"
        explanation_lines.append(f"Query: {query_str}")

        explanation_lines.append("\nReasoning Steps:")
        if not reasoning_trace:
            if answer_fact:
                 explanation_lines.append(f" - Answer '{answer_fact}' was likely provided as an initial fact (no rules applied).")
            else:
                 explanation_lines.append(" - No rules could be applied, or no relevant facts were available.")
        else:
            relevant_trace = []
            processed_rules = set()
            if answer_fact:
                 final_steps = [step for step in reasoning_trace if step[2] == answer_fact]
                 if final_steps:
                      for step in final_steps:
                           rule_id = step[0].id
                           if (rule_id, step[2]) not in processed_rules:
                                relevant_trace.append(step)
                                processed_rules.add((rule_id, step[2]))
                 else:
                      explanation_lines.append(f" - Trace does not directly show derivation of answer {answer_fact}. Showing all steps:")
                      relevant_trace = reasoning_trace
            else:
                 relevant_trace = reasoning_trace
            if not relevant_trace and reasoning_trace:
                 relevant_trace = reasoning_trace
            relevant_trace.sort(key=lambda step: (step[0].id, str(step[2])))
            for rule, bindings, derived_fact in relevant_trace:
                 if (rule.id, derived_fact) in processed_rules and len(relevant_trace) > 1: continue
                 processed_rules.add((rule.id, derived_fact))
                 bound_vars_str = {}
                 for k, v in bindings.items():
                      if isinstance(v, datetime): bound_vars_str[k] = TimeUtils.format_time(v)
                      elif isinstance(v, timedelta): bound_vars_str[k] = TimeUtils.format_timedelta(v)
                      else: bound_vars_str[k] = str(v)
                 explanation_lines.append(f" - Applied Rule {rule.id} ({rule.source}, conf: {rule.confidence:.2f}):")
                 explanation_lines.append(f"     IF {' AND '.join([f'{p}({s}, {o})' for p, s, o in rule.conditions])}")
                 explanation_lines.append(f"     THEN {rule.conclusion[0]}({rule.conclusion[1]}, {rule.conclusion[2]})")
                 explanation_lines.append(f"     With Bindings: {bound_vars_str}")
                 explanation_lines.append(f"     Derived Fact: {derived_fact}")

        explanation_lines.append(f"\nFinal Answer: {answer_fact if answer_fact else 'Not Found'}")

        explanation_lines.append("\nPlausibility Check:")
        if plausibility_result:
            is_plausible, plaus_explanation, plaus_confidence = plausibility_result
            status = 'Passed' if is_plausible else ('Warning' if plaus_confidence > config.PLAUSIBILITY_THRESHOLD and plaus_confidence < 1.0 else 'Failed')
            explanation_lines.append(f" - Status: {status} (Confidence: {plaus_confidence:.2f})")
            explanation_lines.append(f" - Details: {plaus_explanation}")
        elif answer_fact:
             explanation_lines.append(" - Not performed (or failed).")
        else:
            explanation_lines.append(" - Not applicable (no answer found).")

        return "\n".join(explanation_lines)

class AdaptiveLearner:
    def __init__(self,
                 neural_model: 'BertExtractor',
                 symbolic_reasoner: 'SymbolicReasoner',
                 rule_generator: 'DynamicRuleGenerator'):
        self.neural_model = neural_model
        self.symbolic_reasoner = symbolic_reasoner
        self.rule_generator = rule_generator

    def update_on_feedback(self,
                           question_info: Dict[str, Any],
                           answer_fact: Optional[Fact],
                           reasoning_trace: List[Tuple[Rule, Dict[str, Any], Fact]],
                           is_correct: bool,
                           ground_truth_fact: Optional[Fact] = None):

        print(f"\n--- Adaptation Triggered ---")
        print(f"  Question: \"{question_info.get('question', 'N/A')[:50]}...\"")
        print(f"  Answer Correct: {is_correct}")
        print(f"  Provided Answer: {answer_fact}")

        relevant_rules = set()
        if answer_fact and reasoning_trace:

             final_steps = [step for step in reasoning_trace if step[2] == answer_fact]
             if final_steps:
                  relevant_rules.add(final_steps[0][0])

        if is_correct:
            if answer_fact and relevant_rules:
                print("  Action: Strengthening rules involved in correct answer.")
                for rule in relevant_rules:
                    delta = config.ADAPTATION_LEARNING_RATE_SYMBOLIC * (1.0 - rule.confidence)
                    rule.confidence = min(1.0, rule.confidence + delta)
                    print(f"    - Increased confidence of Rule {rule.id} ({rule.source}) to {rule.confidence:.3f}")
            elif answer_fact:
                 print("  Info: Correct answer was likely an initial fact or trace is incomplete. No rule confidence updated.")
            else:
                 print("  Warning: Feedback indicates 'correct' but no answer was provided.")

        else:
            if answer_fact:
                print("  Action: Weakening rules involved in incorrect answer.")
                if relevant_rules:
                    for rule in relevant_rules:
                        delta = config.ADAPTATION_LEARNING_RATE_SYMBOLIC * rule.confidence
                        rule.confidence = max(0.05, rule.confidence - delta)
                        print(f"    - Decreased confidence of Rule {rule.id} ({rule.source}) to {rule.confidence:.3f}")
                else:
                     print("  Info: Cannot identify specific rule(s) leading to the incorrect answer from trace. No confidence updated.")

            else:
                print("  Action: Attempting to address missing answer.")
                query = question_info.get('query')
                relations = question_info.get('relations', [])
                if query:
                     query_pred = query[0]
                     needs_time = any(p in query_pred for p in ['time', 'when', 'arrive', 'depart', 'end', 'start'])
                     needs_duration = any(p in query_pred for p in ['duration', 'long', 'takes', 'travel'])
                     has_time = any(isinstance(r[2], datetime) for r in relations)
                     has_duration = any(isinstance(r[2], timedelta) for r in relations)
                     if needs_time and not has_time:
                          print("    - Possible Issue: Query requires time, but none reliably extracted.")
                     if needs_duration and not has_duration:
                          print("    - Possible Issue: Query requires duration, but none reliably extracted.")
                else:
                     print("    - Possible Issue: Query itself was not identified by neural component.")

                print("  Action: Triggering dynamic rule generation.")

                temp_reasoner = SymbolicReasoner()
                temp_translator = NeuralToSymbolicTranslator()
                temp_translator.translate(question_info, temp_reasoner)
                initial_facts = temp_reasoner.get_all_facts()

                if hasattr(self, 'rule_generator') and self.rule_generator:
                    new_rule = self.rule_generator.generate_rule_from_failure(
                        question_info=question_info,
                        initial_facts=initial_facts,
                        existing_rules=self.symbolic_reasoner.get_all_rules(),
                        ground_truth_fact=ground_truth_fact
                    )

                    if new_rule:
                        print(f"    - Generated new rule: {new_rule}")
                        is_redundant = self.rule_generator._is_rule_redundant(new_rule, self.symbolic_reasoner.get_all_rules())
                        if not is_redundant:
                             added_rule_obj = self.symbolic_reasoner.add_rule(
                                  conditions=new_rule.conditions,
                                  conclusion=new_rule.conclusion,
                                  confidence=new_rule.confidence,
                                  source=new_rule.source
                             )
                             print(f"    - Added Rule {added_rule_obj.id} to the reasoner.")
                        else:
                             print(f"    - Generated rule is redundant with existing rules. Not added.")
                    else:
                        print("    - Rule generation did not produce a new rule for this specific case.")
                        self.rule_generator.store_failure(question_info, initial_facts, ground_truth_fact)
                else:
                     print("   - ERROR: Rule generator not available in AdaptiveLearner.")

# Dynamic Rule Generator

In [18]:
class DynamicRuleGenerator:
    def __init__(self):
        self.templates = [
            {
                "conditions": [('event_time', '?e', '?t'), ('duration', '?e', '?d')],
                "conclusion": ('start_work_time', '?e', 'calculate_departure(?t, ?d)'),
                "required_query": "start_work_time"
            },
            {
                 "conditions": [('start_time', '?e', '?t1'), ('end_time', '?e', '?t2')],
                 "conclusion": ('duration', '?e', 'calculate_duration(?t1, ?t2)'),
                 "required_query": "duration"
            },
             {
                 "conditions": [('departure_time', '?e', '?t'), ('travel_time', '?e', '?d')],
                 "conclusion": ('arrival_time', '?e', 'calculate_end(?t, ?d)'),
                 "required_query": "arrival_time"
            },
             {
                 "conditions": [('start_time', '?e', '?t')],
                 "conclusion": ('event_time', '?e', '?t'),
                 "required_query": "event_time"
            },
            {
                 "conditions": [('start_time', '?e', '?t'), ('travel_time', '?e', '?d')],
                 "conclusion": ('arrival_time', '?e', 'calculate_end(?t, ?d)'),
                 "required_query": "arrival_time"
            },
        ]
        self.failed_cases_store: List[Dict[str, Any]] = []
        print(f"Initialized DynamicRuleGenerator with {len(self.templates)} templates.")

    def store_failure(self, question_info: Dict, initial_facts: Set[Fact], ground_truth: Optional[Fact] = None):
         if not isinstance(question_info, dict) or not isinstance(initial_facts, set):
              print("Warning: Invalid input for store_failure.")
              return

         self.failed_cases_store.append({
             "question_info": question_info,
             "initial_facts": initial_facts,
             "ground_truth": ground_truth
         })

    def _is_rule_redundant(self, new_rule: Rule, existing_rules: List[Rule]) -> bool:
        new_rule_cond_preds = set(c[0] for c in new_rule.conditions)
        new_rule_conc_pred = new_rule.conclusion[0]

        new_rule_conc_calc = None
        if isinstance(new_rule.conclusion[2], str) and new_rule.conclusion[2].startswith("calculate_"):
             match = re.match(r'(calculate_\w+)\(.*\)', new_rule.conclusion[2])
             if match: new_rule_conc_calc = match.group(1)

        for rule in existing_rules:
            if rule.conclusion[0] == new_rule_conc_pred:
                existing_rule_cond_preds = set(c[0] for c in rule.conditions)
                if existing_rule_cond_preds == new_rule_cond_preds:
                     existing_rule_conc_calc = None
                     if isinstance(rule.conclusion[2], str) and rule.conclusion[2].startswith("calculate_"):
                          match = re.match(r'(calculate_\w+)\(.*\)', rule.conclusion[2])
                          if match: existing_rule_conc_calc = match.group(1)
                     if new_rule_conc_calc == existing_rule_conc_calc:
                          return True

        return False

    def _apply_templates(self,
                         query_predicate: str,
                         sample_facts: Set[Fact],
                         existing_rules: List[Rule]
                        ) -> Optional[Rule]:
         available_fact_predicates = {f.predicate for f in sample_facts}

         for template in self.templates:
            if template["required_query"] == query_predicate:
                required_cond_preds = {cond[0] for cond in template["conditions"]}
                if required_cond_preds.issubset(available_fact_predicates):
                    new_rule = Rule(
                        conditions=template["conditions"],
                        conclusion=template["conclusion"],
                        confidence=max(0.1, config.INITIAL_RULE_CONFIDENCE * 0.8),
                        source="generated"
                    )
                    if not self._is_rule_redundant(new_rule, existing_rules):
                         return new_rule
                    else:
                         pass

         return None

    def generate_rule_from_failure(self,
                                   question_info: Dict,
                                   initial_facts: Set[Fact],
                                   existing_rules: List[Rule],
                                   ground_truth_fact: Optional[Fact] = None
                                  ) -> Optional[Rule]:
        query = question_info.get("query")
        if not query or not isinstance(query, tuple) or len(query) != 3:
            print("WARN: Rule generation skipped: Invalid or missing query in question_info.")
            return None

        query_predicate = query[0]
        generated_rule = self._apply_templates(query_predicate, initial_facts, existing_rules)

        if generated_rule:
             return generated_rule
        else:
             return None

    def generate_rules_from_stored_failures(self, existing_rules: List[Rule]) -> List[Rule]:
        newly_generated_rules = []
        num_failures = len(self.failed_cases_store)

        if num_failures < config.RULE_GEN_MIN_SUPPORT:
             return newly_generated_rules

        print(f"INFO: Attempting batch rule generation from {num_failures} stored failures.")

        embeddings = []
        valid_indices = []
        for i, case in enumerate(self.failed_cases_store):
             q_info = case.get("question_info", {})
             emb = q_info.get("embedding")

             if emb is not None and isinstance(emb, torch.Tensor) and emb.ndim == 2 and emb.shape[0] == 1:
                  embeddings.append(emb.numpy().flatten())
                  valid_indices.append(i)
             else:
                  emb_info = f"type {type(emb)}" if emb is not None else "None"
                  emb_shape = f"shape {emb.shape}" if isinstance(emb, torch.Tensor) else ""
                  print(f"Warning: Skipping failure case {i} due to missing or invalid embedding ({emb_info} {emb_shape}).")


        if len(valid_indices) < config.RULE_GEN_MIN_SUPPORT:
             print("INFO: Batch rule generation skipped: Not enough valid embeddings found.")
             self.failed_cases_store.clear()
             return newly_generated_rules

        embeddings_np = np.vstack(embeddings)

        try:
            clustering = AgglomerativeClustering(
                n_clusters=None,
                distance_threshold=1.0 - config.RULE_GEN_CLUSTER_THRESHOLD,
                metric='cosine',
                linkage='average'
            )
            labels = clustering.fit_predict(embeddings_np)
            num_clusters = (max(labels) + 1) if labels.size > 0 and max(labels) > -1 else 0
            print(f"INFO: Clustering resulted in {num_clusters} potential clusters.")
        except Exception as e:
            print(f"Error during clustering: {e}. Aborting batch rule generation.")
            self.failed_cases_store.clear()
            return newly_generated_rules

        generated_rules_in_batch = []
        for i in range(num_clusters):
            cluster_member_indices = [valid_indices[idx] for idx, label in enumerate(labels) if label == i]

            if len(cluster_member_indices) >= config.RULE_GEN_MIN_SUPPORT:
                print(f"\nINFO: Analyzing Cluster {i} with {len(cluster_member_indices)} members.")
                cluster_cases = [self.failed_cases_store[idx] for idx in cluster_member_indices]


                common_query_pred = self._find_common_query_predicate(cluster_cases)

                if common_query_pred:
                    print(f"  - Common query predicate in cluster: '{common_query_pred}'")

                    sample_facts = cluster_cases[0]['initial_facts']

                    generated_rule = self._apply_templates(
                        common_query_pred,
                        sample_facts,
                        existing_rules + generated_rules_in_batch
                    )

                    if generated_rule:
                         print(f"  - Generated candidate Rule {generated_rule.id} for cluster {i}.")
                         newly_generated_rules.append(generated_rule)
                         generated_rules_in_batch.append(generated_rule)
                    else:
                         print(f"  - No suitable template found or rule was redundant for cluster {i}.")
                else:
                    print(f"  - Could not determine a common query predicate for cluster {i}.")

        print(f"\nINFO: Finished batch processing. Generated {len(newly_generated_rules)} new rules.")
        self.failed_cases_store.clear()
        return newly_generated_rules

    def _find_common_query_predicate(self, cases: List[Dict]) -> Optional[str]:
        query_preds = []
        for case in cases:
             q_info = case.get("question_info", {})
             query = q_info.get("query")
             if query and isinstance(query, tuple) and len(query) == 3 and isinstance(query[0], str):
                  query_preds.append(query[0])

        if not query_preds:
            return None

        predicate_counts = Counter(query_preds)
        most_common = predicate_counts.most_common(1)
        return most_common[0][0] if most_common else None

gen_reasoner = SymbolicReasoner()
generator = DynamicRuleGenerator()

q3 = "The project deadline is 5 PM. I need 2 hours to review it. When should I start?"

bert_hidden_size = 768
neural_out3 = {
    "question": q3,
    "embedding": torch.randn(1, bert_hidden_size),
    "relations": [
        ('event_time', 'event1', TimeUtils.parse_time("5 PM")),
        ('duration', 'event1', TimeUtils.parse_duration("2 hours"))
    ],
    "query": ('start_work_time', 'event1', '?')
}
initial_facts3 = {
    Fact('event_time', 'event1', TimeUtils.parse_time("5 PM")),
    Fact('duration', 'event1', TimeUtils.parse_duration("2 hours"))
}
ground_truth3 = Fact(predicate='start_work_time', subject='event1', object=TimeUtils.parse_time("3:00 PM"))

print("--- Attempting Rule Generation from Single Failure ---")

rules_before = gen_reasoner.get_all_rules().copy()
print(f"Rules before generation: {len(rules_before)}")

new_rule = generator.generate_rule_from_failure(
    neural_out3,
    initial_facts3,
    rules_before,
    ground_truth3
)

if new_rule:
    print(f"\nSuccessfully generated new rule:\n{new_rule}")

    added_rule = gen_reasoner.add_rule(
         conditions=new_rule.conditions,
         conclusion=new_rule.conclusion,
         confidence=new_rule.confidence,
         source=new_rule.source
    )
    print(f"Rule {added_rule.id} added to reasoner.")
    print(f"Rules after generation: {len(gen_reasoner.get_all_rules())}")
    print("\n--- Reasoning again with the new rule ---")
    gen_reasoner.clear_facts()
    for f in initial_facts3: gen_reasoner.add_fact(f)
    new_facts, trace = gen_reasoner.reason()
    print("New facts derived:")
    for f in new_facts: print(f"  {f}")
    results = gen_reasoner.query(neural_out3['query'])
    print(f"Query result: {results[0] if results else 'Not Found'}")

else:
    print("\nFailed to generate a rule for this case from template.")

    generator.store_failure(neural_out3, initial_facts3, ground_truth3)

print("\n--- Attempting Batch Rule Generation ---")

q4 = "Flight leaves at 10 AM, checkin takes 1 hour. When to start checkin?"
neural_out4 = {
    "question": q4,
    "embedding": torch.randn(1, bert_hidden_size) * 0.9,
    "relations": [('event_time', 'event2', TimeUtils.parse_time("10 AM")), ('duration', 'event2', TimeUtils.parse_duration("1 hour"))],
    "query": ('start_work_time', 'event2', '?'),
}
initial_facts4 = { Fact('event_time', 'event2', TimeUtils.parse_time("10 AM")), Fact('duration', 'event2', TimeUtils.parse_duration("1 hour"))}
ground_truth4 = Fact(predicate='start_work_time', subject='event2', object=TimeUtils.parse_time("9:00 AM"))

if not new_rule:
     generator.store_failure(neural_out3, initial_facts3, ground_truth3)
generator.store_failure(neural_out4, initial_facts4, ground_truth4)


batch_rules = generator.generate_rules_from_stored_failures(gen_reasoner.get_all_rules())

if batch_rules:
     print("\nGenerated rules from batch processing:")
     for rule in batch_rules:
          print(f" - {rule}")
          if not generator._is_rule_redundant(rule, gen_reasoner.get_all_rules()):
              added_rule = gen_reasoner.add_rule(
                   conditions=rule.conditions, conclusion=rule.conclusion,
                   confidence=rule.confidence, source=rule.source
              )
              print(f"   Added Rule {added_rule.id} to reasoner.")
          else:
              print(f"   (Rule {rule.id} already exists or is redundant, not re-adding)")
     print(f"Rules after batch generation: {len(gen_reasoner.get_all_rules())}")
else:
     print("\nNo new rules generated from batch processing.")


print(f"Stored failures after batch processing: {len(generator.failed_cases_store)}")

Added 6 core temporal rules.
Initialized DynamicRuleGenerator with 5 templates.
--- Attempting Rule Generation from Single Failure ---
Rules before generation: 6

Failed to generate a rule for this case from template.

--- Attempting Batch Rule Generation ---
INFO: Attempting batch rule generation from 3 stored failures.
INFO: Clustering resulted in 2 potential clusters.

INFO: Analyzing Cluster 0 with 2 members.
  - Common query predicate in cluster: 'start_work_time'
  - No suitable template found or rule was redundant for cluster 0.

INFO: Finished batch processing. Generated 0 new rules.

No new rules generated from batch processing.
Stored failures after batch processing: 0


# Data Loading

In [19]:
def load_qa_data(file_path: str) -> List[Dict[str, Any]]:
    print(f"Attempting to load QA data from: {file_path}")
    if not os.path.exists(file_path):
        print(f"Error: Data file not found at {file_path}")
        return []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: Could not decode JSON from {file_path}. Error: {e}")
        return []
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return []

    if not isinstance(data, list):
        print(f"Error: Invalid data format in {file_path}. Expected a JSON list.")
        return []

    processed_data = []
    required_keys = {"question", "answer_predicate", "answer_value"}
    for i, item in enumerate(data):
        if not isinstance(item, dict):
            print(f"Warning: Skipping item {i} in {file_path}, not a dictionary.")
            continue

        essential_keys = {"question", "answer_predicate", "answer_value"}
        if not essential_keys.issubset(item.keys()):
            missing_keys = essential_keys - item.keys()
            print(f"Warning: Skipping item {i} ('{item.get('question', 'N/A')[:30]}...') due to missing keys: {missing_keys}")
            continue

        gt_pred = item['answer_predicate']
        gt_val_str = item['answer_value']
        parsed_gt_value = None
        if isinstance(gt_val_str, str):
             if 'time' in gt_pred.lower() or 'when' in gt_pred.lower():
                  parsed_gt_value = TimeUtils.parse_time(gt_val_str)
             elif 'duration' in gt_pred.lower() or 'long' in gt_pred.lower():
                  parsed_gt_value = TimeUtils.parse_duration(gt_val_str)

             elif gt_val_str.isdigit():
                  parsed_gt_value = int(gt_val_str)


        item['parsed_answer_value'] = parsed_gt_value
        processed_data.append(item)

    print(f"Successfully loaded and processed {len(processed_data)} QA pairs from {file_path}.")
    return processed_data

def create_sample_data_file(file_path: str = config.DATA_PATH):
    if os.path.exists(file_path):
        return
    sample_data = [
        {
            "id": "q1",
            "question": "If I have a meeting at 3 PM and it takes 30 minutes to get there, when should I leave?",
            "answer_predicate": "departure_time",
            "answer_value": "02:30 PM"
        },
        {
            "id": "q2",
            "question": "The workshop starts at 9 AM and runs for 3 hours. What time does it finish?",
            "answer_predicate": "end_time",
            "answer_value": "12:00 PM"
        },
        {
            "id": "q3",
            "question": "My flight is at 18:00 and I need 1 hour 15 minutes for travel. What's the latest departure time?",
            "answer_predicate": "departure_time",
            "answer_value": "04:45 PM"
        },
        {
            "id": "q4",
            "question": "The project deadline is 5 PM. I need 2 hours to review it. When should I start?",
            "answer_predicate": "start_work_time",
            "answer_value": "03:00 PM"
        },
        {
            "id": "q5",
            "question": "How long did the meeting last if it started at 2pm and ended at 4:30 PM?",
            "answer_predicate": "duration",
            "answer_value": "2 hours and 30 minutes"
        },
         {
            "id": "q6",
            "question": "My train leaves at 10:00 AM. The journey is 1 hr 15 min long. When do I arrive?",
            "answer_predicate": "arrival_time",
            "answer_value": "11:15 AM"
        }
    ]

    try:
        data_dir = os.path.dirname(file_path)
        if data_dir and not os.path.exists(data_dir):
            os.makedirs(data_dir)
            print(f"Created data directory: {data_dir}")

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(sample_data, f, indent=2)
        print(f"Created sample data file: {file_path}")
    except Exception as e:
        print(f"Error creating sample data file {file_path}: {e}")



create_sample_data_file(config.DATA_PATH)

loaded_data = load_qa_data(config.DATA_PATH)

if loaded_data:
    print(f"\nSuccessfully loaded {len(loaded_data)} QA pairs.")
    print("\nFirst item:")
    import pprint
    pprint.pprint(loaded_data[0], width=100)
    print("\nLast item:")
    pprint.pprint(loaded_data[-1], width=100)

    print("\nChecking parsed ground truth values (first few):")
    for i in range(min(3, len(loaded_data))):
         item = loaded_data[i]
         parsed_val = item.get('parsed_answer_value', 'N/A')
         print(f"  Item {i}: Pred='{item['answer_predicate']}', StrVal='{item['answer_value']}', ParsedVal='{parsed_val}' ({type(parsed_val)})")

else:
    print("\nFailed to load data.")

Attempting to load QA data from: data/sample_qa.json
Successfully loaded and processed 6 QA pairs from data/sample_qa.json.

Successfully loaded 6 QA pairs.

First item:
{'answer_predicate': 'departure_time',
 'answer_value': '02:30 PM',
 'id': 'q1',
 'parsed_answer_value': datetime.datetime(2025, 3, 31, 14, 30),
 'question': 'If I have a meeting at 3 PM and it takes 30 minutes to get there, when should I '
             'leave?'}

Last item:
{'answer_predicate': 'arrival_time',
 'answer_value': '11:15 AM',
 'id': 'q6',
 'parsed_answer_value': datetime.datetime(2025, 3, 31, 11, 15),
 'question': 'My train leaves at 10:00 AM. The journey is 1 hr 15 min long. When do I arrive?'}

Checking parsed ground truth values (first few):
  Item 0: Pred='departure_time', StrVal='02:30 PM', ParsedVal='2025-03-31 14:30:00' (<class 'datetime.datetime'>)
  Item 1: Pred='end_time', StrVal='12:00 PM', ParsedVal='2025-03-31 12:00:00' (<class 'datetime.datetime'>)
  Item 2: Pred='departure_time', StrVal='04

# Evaluation

In [20]:
def evaluate_answer(predicted_fact: Optional[Fact],
                    ground_truth_pred: str,
                    ground_truth_val_str: str,
                    parsed_ground_truth_val: Any) -> bool:
    if predicted_fact is None:
        return parsed_ground_truth_val is None

    if not isinstance(predicted_fact, Fact):
         print("Warning: evaluate_answer received non-Fact object as prediction.")
         return False

    if predicted_fact.predicate != ground_truth_pred:
        return False

    pred_val = predicted_fact.object
    gt_val = parsed_ground_truth_val

    if gt_val is None and ground_truth_val_str is not None:
         print(f"Warning: Ground truth value '{ground_truth_val_str}' could not be parsed. Comparing prediction '{pred_val}' as string.")
         return str(pred_val) == ground_truth_val_str

    try:
        if isinstance(pred_val, datetime) and isinstance(gt_val, datetime):
            return pred_val.strftime("%H:%M") == gt_val.strftime("%H:%M")
        elif isinstance(pred_val, timedelta) and isinstance(gt_val, timedelta):
            return abs(pred_val.total_seconds() - gt_val.total_seconds()) < 60
        elif type(pred_val) == type(gt_val):
             return pred_val == gt_val
        else:
             return str(pred_val) == ground_truth_val_str

    except Exception as e:
        print(f"Error during value comparison: Predicted='{pred_val}' ({type(pred_val)}), "
              f"GT='{gt_val}' ({type(gt_val)}), GT_Str='{ground_truth_val_str}'. Error: {e}")
        return False

def run_evaluation(dataset: List[Dict[str, Any]],
                   pipeline_func: Callable[[str], Dict[str, Any]]
                  ) -> Dict[str, float]:
    correct_count = 0
    total_count = len(dataset)
    total_processing_time = 0
    results_log = []
    processed_count = 0

    if total_count == 0:
        print("Evaluation dataset is empty.")
        return {"accuracy": 0.0, "average_time_ms": 0.0}

    print(f"\n--- Starting Evaluation on {total_count} Questions ---")

    for i, item in enumerate(dataset):
        question = item.get('question')
        gt_pred = item.get('answer_predicate')
        gt_val_str = item.get('answer_value')
        parsed_gt_val = item.get('parsed_answer_value')
        item_id = item.get('id', f'item_{i}')

        if not all([isinstance(question, str), isinstance(gt_pred, str), gt_val_str is not None]):
             print(f"Warning: Skipping invalid dataset item {item_id}: Missing or invalid required fields.")
             continue

        print(f"\n--- Evaluating {item_id} ({i+1}/{len(dataset)}) ---")
        print(f"  Q: {question}")
        gt_repr = f"{gt_pred}(..., {gt_val_str})"

        start_time = time.time()
        is_correct = False
        pipeline_failed = False
        try:
            result = pipeline_func(question)
        except Exception as e:
             print(f"  ERROR: Pipeline execution failed for question {item_id}: {e}")
             result = {}
             pipeline_failed = True
        end_time = time.time()
        processing_time = (end_time - start_time) * 1000

        if not pipeline_failed:
            processed_count += 1
            total_processing_time += processing_time

            predicted_fact = result.get('answer_fact')
            adapter = result.get('adapter')
            question_info = result.get('question_info')
            reasoning_trace = result.get('reasoning_trace', [])
            explanation = result.get('explanation', "No explanation generated.")

            is_correct = evaluate_answer(predicted_fact, gt_pred, gt_val_str, parsed_gt_val)

            if is_correct:
                correct_count += 1
                print(f"  Result: CORRECT")
                print(f"  Predicted: {predicted_fact}")
            else:
                print(f"  Result: INCORRECT")
                print(f"  Predicted: {predicted_fact}")
                print(f"  Expected:  {gt_repr} (Parsed: {parsed_gt_val})")

            print(f"  Time: {processing_time:.2f} ms")

            if adapter and question_info:
                 ground_truth_fact_obj = None
                 if gt_pred and parsed_gt_val is not None:
                      subject = 'event_gt'
                      if query := question_info.get('query'):
                           if isinstance(query, tuple) and len(query) > 1 and isinstance(query[1], str):
                                subject = query[1]

                      ground_truth_fact_obj = Fact(predicate=gt_pred, subject=subject, object=parsed_gt_val)

                 try:
                      adapter.update_on_feedback(
                          question_info=question_info,
                          answer_fact=predicted_fact,
                          reasoning_trace=reasoning_trace,
                          is_correct=is_correct,
                          ground_truth_fact=ground_truth_fact_obj
                      )
                 except Exception as e:
                      print(f"  ERROR: Adaptation step failed for question {item_id}: {e}")
            else:
                 print("  WARN: Adaptation info missing in pipeline result or adapter not found. Skipping update.")

        results_log.append({
            "id": item_id,
            "question": question,
            "ground_truth_pred": gt_pred,
            "ground_truth_val_str": gt_val_str,
            "predicted_fact": str(predicted_fact) if not pipeline_failed and predicted_fact else None,
            "is_correct": is_correct,
            "pipeline_failed": pipeline_failed,
            "processing_time_ms": processing_time if not pipeline_failed else None,
        })

    accuracy = correct_count / processed_count if processed_count > 0 else 0.0
    average_time = total_processing_time / processed_count if processed_count > 0 else 0.0
    failure_rate = (total_count - processed_count) / total_count if total_count > 0 else 0.0

    print(f"\n--- Evaluation Summary ---")
    print(f"Total Questions Attempted: {total_count}")
    print(f"Successfully Processed: {processed_count}")
    print(f"Pipeline Failures: {total_count - processed_count}")
    print(f"Correct Answers (among processed): {correct_count}")
    print(f"Accuracy (on processed): {accuracy:.4f}")
    print(f"Failure Rate: {failure_rate:.4f}")
    print(f"Average Processing Time (on processed): {average_time:.2f} ms")

    return {
        "accuracy": accuracy,
        "average_time_ms": average_time,
        "processed_count": processed_count,
        "total_questions": total_count,
        "failure_rate": failure_rate
        }

# Main Pipeline and Execution

In [21]:
class NeuroSymbolicPipeline:
    def __init__(self):
        print("--- Initializing Neuro-Symbolic Pipeline ---")
        start_init = time.time()

        try:
            self.extractor = BertExtractor(model_name=config.BERT_MODEL_NAME, device=config.DEVICE)
            self.reasoner = SymbolicReasoner()
            self.translator = NeuralToSymbolicTranslator()
            self.evaluator = MetaCognitiveEvaluator()

            self.rule_generator = DynamicRuleGenerator()
            self.adapter = AdaptiveLearner(self.extractor, self.reasoner, self.rule_generator)
        except Exception as e:
             print(f"FATAL ERROR during pipeline initialization: {e}")

             raise RuntimeError("Pipeline initialization failed") from e
        end_init = time.time()
        print(f"--- Pipeline Initialization Complete ({end_init - start_init:.2f}s) ---")

    def process_question(self, question: str) -> Dict[str, Any]:

        result_dict = {
            "question": question, "question_info": None, "initial_facts": set(),
            "derived_facts": set(), "all_facts": set(), "reasoning_trace": [],
            "answer_fact": None, "plausibility": None, "explanation": "Processing failed early.",
            "adapter": self.adapter
        }

        try:

            neural_output = self.extractor.extract_entities_relations(question)
            result_dict["question_info"] = neural_output
            query_pattern = neural_output.get('query')


            self.reasoner.clear_facts()
            num_translated = self.translator.translate(neural_output, self.reasoner)
            initial_facts = self.reasoner.get_all_facts().copy()
            result_dict["initial_facts"] = initial_facts

            derived_facts, reasoning_trace = self.reasoner.reason()
            all_facts = self.reasoner.get_all_facts()
            result_dict["derived_facts"] = derived_facts
            result_dict["reasoning_trace"] = reasoning_trace
            result_dict["all_facts"] = all_facts

            if query_pattern:
                results = self.reasoner.query(query_pattern)
                if results:

                    answer_fact = results[0]

            result_dict["answer_fact"] = answer_fact

            plausibility_result = None
            if answer_fact:
                plausibility_result = self.evaluator.evaluate_plausibility(answer_fact, all_facts)

            result_dict["plausibility"] = plausibility_result

            explanation = self.evaluator.generate_explanation(
                query=query_pattern,
                answer_fact=answer_fact,
                reasoning_trace=reasoning_trace,
                plausibility_result=plausibility_result
            )
            result_dict["explanation"] = explanation

        except Exception as e:
             print(f"ERROR during pipeline processing for question '{question[:50]}...': {e}")
             result_dict["explanation"] = f"Pipeline processing failed: {e}"

        return result_dict

def run_pipeline_evaluation_main(pipeline_instance: NeuroSymbolicPipeline, dataset_path: str):
    create_sample_data_file(dataset_path)
    dataset = load_qa_data(dataset_path)

    if not dataset:
        print("Evaluation skipped: Failed to load dataset.")
        return None

    def pipeline_func_wrapper(question: str) -> Dict[str, Any]:
         try:
             return pipeline_instance.process_question(question)
         except Exception as e:
             print(f"ERROR in pipeline_func_wrapper for question '{question[:50]}...': {e}")

             return {
                 "question": question, "answer_fact": None, "adapter": pipeline_instance.adapter,
                 "question_info": None, "reasoning_trace": [], "initial_facts": set(),
                 "derived_facts": set(), "all_facts": set(), "plausibility": None,
                 "explanation": f"Pipeline failed: {e}"
             }

    print("\n--- Running Evaluation Loop ---")
    evaluation_metrics = run_evaluation(dataset, pipeline_func_wrapper)
    print(f"\n--- Final Evaluation Metrics ---")
    pprint.pprint(evaluation_metrics)

    print("\n--- Attempting Post-Evaluation Batch Rule Generation ---")
    if hasattr(pipeline_instance, 'rule_generator') and hasattr(pipeline_instance, 'reasoner'):
        batch_rules = pipeline_instance.rule_generator.generate_rules_from_stored_failures(
            pipeline_instance.reasoner.get_all_rules()
        )
        if batch_rules:
            print(f"Generated {len(batch_rules)} new rules from batch processing:")
            added_count = 0
            for rule in batch_rules:
                print(f" - {rule}")

                if not pipeline_instance.rule_generator._is_rule_redundant(rule, pipeline_instance.reasoner.get_all_rules()):
                     added_rule = pipeline_instance.reasoner.add_rule(
                          conditions=rule.conditions, conclusion=rule.conclusion,
                          confidence=rule.confidence, source=rule.source
                     )
                     print(f"   Added Rule {added_rule.id} to reasoner.")
                     added_count += 1
                else:
                     print(f"   (Rule {rule.id} already exists or is redundant, not re-adding)")
            print(f"Added {added_count} unique rules from batch generation.")
        else:
            print("No new rules generated from batch processing.")
    else:
        print("WARN: Rule generator or reasoner not found in pipeline instance for batch generation.")

    return evaluation_metrics

def main():
    print("=============================================")
    print("=== Neuro-Symbolic Meta-Cognitive System ===")
    print("=============================================")

    try:
        pipeline = NeuroSymbolicPipeline()
    except Exception as e:
        print(f"Could not initialize pipeline: {e}. Exiting.")
        return

    run_pipeline_evaluation_main(pipeline, config.DATA_PATH)

    print("\n=============================================")
    print("=== Processing New Question Post-Adaptation ===")
    print("=============================================")

    new_q = "My train journey starts at 1 PM and takes 90 minutes. What time do I arrive?"

    arrival_rule_exists = any(
        r.conclusion[0] == 'arrival_time' and
        any(c[0]=='start_time' for c in r.conditions)
        for r in pipeline.reasoner.get_all_rules()
    )
    if not arrival_rule_exists:
         print(f"INFO: No direct start_time -> arrival_time rule found. Processing '{new_q}' might fail or require generation.")

    try:
        result = pipeline.process_question(new_q)
        print(f"\nProcessed New Question: '{new_q}'")
        print(f"Predicted Answer: {result.get('answer_fact')}")
        print("\n--- Explanation for New Question ---")
        print(result.get('explanation', 'No explanation available.'))
    except Exception as e:
        print(f"ERROR processing new question '{new_q}': {e}")

    print("\n--- Final State of Rules ---")
    if hasattr(pipeline, 'reasoner'):
        rules = pipeline.reasoner.get_all_rules()
        if rules:
             rules.sort(key=lambda r: r.id)
             for rule in rules:
                  print(rule)
        else:
             print("No rules found in the reasoner.")
    else:
        print("Pipeline or reasoner object not available.")

    print("\n=============================================")
    print("=== Execution Finished ===")
    print("=============================================")

if __name__ == "__main__":
    main()

=== Neuro-Symbolic Meta-Cognitive System ===
--- Initializing Neuro-Symbolic Pipeline ---
Loading BERT model: bert-base-uncased...
BERT model loaded.
Added 6 core temporal rules.
Initialized DynamicRuleGenerator with 5 templates.
--- Pipeline Initialization Complete (0.88s) ---
Attempting to load QA data from: data/sample_qa.json
Successfully loaded and processed 6 QA pairs from data/sample_qa.json.

--- Running Evaluation Loop ---

--- Starting Evaluation on 6 Questions ---

--- Evaluating q1 (1/6) ---
  Q: If I have a meeting at 3 PM and it takes 30 minutes to get there, when should I leave?
  Result: CORRECT
  Predicted: departure_time(event1, 02:30 PM)
  Time: 16.88 ms

--- Adaptation Triggered ---
  Question: "If I have a meeting at 3 PM and it takes 30 minute..."
  Answer Correct: True
  Provided Answer: departure_time(event1, 02:30 PM)
  ERROR: Adaptation step failed for question q1: unhashable type: 'Rule'

--- Evaluating q2 (2/6) ---
  Q: The workshop starts at 9 AM and runs f