<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/Reflective_Updater.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, Protocol, Tuple, List, Sequence, Iterator
import logging
import itertools
import statistics


# --- Protocols / lightweight contracts --------------------------------------

class EpisodeLike(Protocol):
    """Structural type for one recorded episode: an input plus its observed outcome."""

    @property
    def input(self) -> Any:
        """Value fed to the world model's ``simulate``."""

    @property
    def outcome(self) -> Any:
        """Recorded actual result, scored against the world model's prediction."""


class MemoryLike(Protocol):
    """Structural type for an episode store that can hand back recent episodes."""

    def recent(self) -> Iterable[EpisodeLike]:
        """Return the recent episodes, in whatever order the store keeps them."""

    # Optional fast path for bounded retrieval; callers probe for it with hasattr.
    # NOTE(review): being a Protocol member, static type checkers will still
    # demand it from implementers even though it is meant to be optional.
    def recent_n(self, n: int) -> Sequence[EpisodeLike]:
        """Return at most *n* recent episodes."""


class WorldModelLike(Protocol):
    """Structural type for a predictive world model that can be corrected."""

    def simulate(self, x: Any) -> Any:
        """Predict the outcome for a single input *x*."""

    def update(self, contradictory: List[EpisodeLike]) -> None:
        """Learn from episodes whose outcomes contradicted the predictions."""

    # Optional vectorized variants, used only when present.
    # NOTE(review): as Protocol members they are formally required by type checkers.
    def batch_simulate(self, xs: Sequence[Any]) -> Sequence[Any]:
        """Predict outcomes for many inputs; one prediction per input, in order."""

    def batch_update(self, contradictory: Sequence[EpisodeLike]) -> None:
        """Vectorized counterpart of ``update``."""


class SelfModelLike(Protocol):
    """Structural type for the agent's self model."""

    def adapt(self, contradictory: List[EpisodeLike]) -> None:
        """Adjust the self model given episodes the world model got wrong."""


class CurriculumScheduler(Protocol):
    """
    Per-step curriculum hooks: the threshold to apply and which
    contradictions to train on.
    Implement any subset you need; defaults are provided by the updater.
    """

    # NOTE(review): as Protocol members, both hooks are formally required by
    # static type checkers even though the docstring calls them optional.
    def decide_threshold(self, step: int, prev_report: Optional["TrainingReport"]) -> Optional[float]:
        """Return the threshold for *step*, or None to defer to the updater's choice."""

    def select(self, contradictions: List["Contradiction"], step: int) -> List["Contradiction"]:
        """Return the subset of *contradictions* to use for the update at *step*."""


# Scorer signature: (prediction, actual) -> float; larger values mean "more consistent".
SimilarityFn = Callable[[Any, Any], float]


# --- Data structures ---------------------------------------------------------

@dataclass(frozen=True)
class Contradiction:
    """Immutable audit record for one episode flagged as inconsistent."""

    # The episode whose recorded outcome was checked.
    episode: EpisodeLike
    # What the world model predicted for the episode's input.
    prediction: Any
    # Consistency score; NaN when the similarity check could not produce one.
    score: float
    # Why it was flagged: "score_below_threshold", "simulate_error" or "similarity_error".
    reason: str = "score_below_threshold"


@dataclass(frozen=True)
class TrainingReport:
    """Immutable summary of a single training step."""

    step: int  # value of the updater's step counter for this step (first step is 1)
    processed: int  # number of episodes pulled from memory
    contradictory: int  # number of contradictions used for the model update
    threshold: float  # consistency threshold applied during the step
    avg_score_all: Optional[float]  # mean score over scored episodes, None if none
    avg_score_contradictions: Optional[float]  # mean score of used contradictions, None if none


# --- Utilities ---------------------------------------------------------------

def _chunked(seq: Sequence[Any], n: int) -> Iterator[Sequence[Any]]:
    if n <= 0:
        raise ValueError("batch_size must be > 0")
    for i in range(0, len(seq), n):
        yield seq[i:i+n]


def _take(iterable: Iterable[Any], limit: Optional[int]) -> List[Any]:
    if limit is None:
        return list(iterable)
    return list(itertools.islice(iterable, int(limit)))


# --- Implementation -----------------------------------------------------------

class ReflectiveUpdater:
    """
    Cross-check recent episodes against the world model; on contradiction,
    update both the world and self models. Adds auditable detection, batching,
    and curriculum-aware training.

    Every detection path uses the same rules. An episode is a contradiction when:

    * its consistency score is below the threshold
      (``reason="score_below_threshold"``);
    * the similarity function raises or returns NaN
      (``reason="similarity_error"``, score NaN); or
    * ``simulate`` raises for its input
      (``reason="simulate_error"``, prediction None, score NaN).

    Parameters
    ----------
    episodic_memory : MemoryLike
        Provides recent() -> iterable of episodes with .input and .outcome.
        If it implements recent_n(n), that is used whenever a limit is given.
    world_model : WorldModelLike
        Must implement simulate(x) and update(episodes).
        If it implements batch_simulate/batch_update, those are used.
    self_model : SelfModelLike
        Must implement adapt(episodes).
    similarity_fn : Callable[[pred, actual], float]
        Returns a consistency score; higher means more similar.
    threshold : float, default 0.8
        Minimum score to consider pred and actual consistent (inclusive).
    logger : logging.Logger, optional
        For warnings/debugging; defaults to module logger.
    """

    def __init__(
        self,
        episodic_memory: MemoryLike,
        world_model: WorldModelLike,
        self_model: SelfModelLike,
        similarity_fn: SimilarityFn,
        *,
        threshold: float = 0.8,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self.memory = episodic_memory
        self.world = world_model
        self.self_model = self_model
        self.similarity_fn = similarity_fn
        self.threshold = float(threshold)
        self.log = logger or logging.getLogger(__name__)
        self.step: int = 0
        self._last_report: Optional[TrainingReport] = None

    # --- Basic consistency utilities ----------------------------------------

    def consistency_score(self, pred: Any, actual: Any) -> float:
        """Return ``similarity_fn(pred, actual)`` coerced to float."""
        return float(self.similarity_fn(pred, actual))

    def _is_consistent(self, pred: Any, actual: Any, *, threshold: Optional[float] = None) -> bool:
        """True when the score is >= the threshold; a NaN score is never consistent."""
        th = self._resolve_threshold(threshold)
        return self.consistency_score(pred, actual) >= th

    def _resolve_threshold(self, threshold: Optional[float]) -> float:
        """Return *threshold* as a float, or the instance default when it is None."""
        return self.threshold if threshold is None else float(threshold)

    # --- Detection: simple and detailed -------------------------------------

    def detect_contradictions(self, *, limit: Optional[int] = None, threshold: Optional[float] = None) -> Iterable[EpisodeLike]:
        """
        Yield episodes whose predicted outcome is inconsistent with actual.

        Lazy: each episode is simulated and scored only when it is consumed.
        Simulation or similarity failures are logged and the episode is yielded.
        """
        th = self._resolve_threshold(threshold)
        for episode in self._iter_recent(limit):
            try:
                prediction = self._simulate_one(episode.input)
            except Exception as e:
                self.log.warning("simulate() failed; treating as contradiction: %s", e)
                yield episode
                continue

            try:
                if not self._is_consistent(prediction, episode.outcome, threshold=th):
                    yield episode
            except Exception as e:
                self.log.warning("similarity check failed; treating as contradiction: %s", e)
                yield episode

    def detect_contradictions_detailed(
        self,
        *,
        limit: Optional[int] = None,
        threshold: Optional[float] = None,
        batch_size: Optional[int] = None,
    ) -> List[Contradiction]:
        """
        Return detailed contradictions with predictions and scores for audit.

        Uses vectorized simulate if available; otherwise (or if it fails)
        falls back to per-item simulate, recording per-item failures as
        ``"simulate_error"`` contradictions instead of aborting.
        """
        th = self._resolve_threshold(threshold)
        episodes = self._collect_recent(limit)
        if not episodes:
            return []
        predictions = self._predict(episodes, batch_size=batch_size)
        _scores, contradictions = self._score(episodes, predictions, th)
        return contradictions

    # --- Revision: immediate and batched ------------------------------------

    def revise_model(self) -> int:
        """
        Collect contradictions and update world/self models.
        Returns the number of contradictory episodes processed.
        """
        contradictory = list(self.detect_contradictions())
        if not contradictory:
            self.log.debug("No contradictions detected.")
            return 0

        self.log.debug("Updating with %d contradictory episodes.", len(contradictory))
        self._update_models(contradictory)
        return len(contradictory)

    def revise_model_batched(self, *, batch_size: int = 64, threshold: Optional[float] = None) -> int:
        """
        Detect contradictions and update in batches for memory/latency control.
        Returns the number of contradictory episodes used for the update.
        """
        th = self._resolve_threshold(threshold)
        detailed = self.detect_contradictions_detailed(threshold=th, batch_size=batch_size)
        if not detailed:
            self.log.debug("No contradictions detected.")
            return 0
        episodes = [c.episode for c in detailed]
        self.log.debug("Batched update with %d contradictory episodes (batch_size=%d).", len(episodes), batch_size)
        self._update_models(episodes, batch_size=batch_size)
        return len(episodes)

    # --- Training step with curriculum --------------------------------------

    def training_step(
        self,
        *,
        max_episodes: int = 256,
        batch_size: Optional[int] = None,
        scheduler: Optional[CurriculumScheduler] = None,
        threshold: Optional[float] = None,
    ) -> TrainingReport:
        """
        One training step:
        1) Pull up to max_episodes recent episodes.
        2) Compute predictions and scores (batched if available).
        3) Apply threshold: scheduler.decide_threshold (if implemented and not
           None) > ``threshold`` argument > ``self.threshold``.
        4) Optionally let scheduler.select (if implemented) pick a subset.
        5) Update world/self (batched if requested).

        Returns a TrainingReport with counts and averages. NaN scores are
        excluded from both averages.
        """
        self.step += 1
        th = self._training_threshold(scheduler, threshold)

        episodes = self._collect_recent(max_episodes)
        if not episodes:
            report = TrainingReport(self.step, 0, 0, th, None, None)
            self._last_report = report
            return report

        predictions = self._predict(episodes, batch_size=batch_size)
        scores, contradictions = self._score(episodes, predictions, th)
        use_contras = self._curriculum_select(scheduler, contradictions)

        if use_contras:
            self._update_models([c.episode for c in use_contras], batch_size=batch_size)

        avg_all = statistics.fmean(scores) if scores else None
        only_scores = [c.score for c in use_contras if c.score == c.score]  # NaN != NaN
        avg_contra = statistics.fmean(only_scores) if only_scores else None

        report = TrainingReport(
            step=self.step,
            processed=len(episodes),
            contradictory=len(use_contras),
            threshold=th,
            avg_score_all=avg_all,
            avg_score_contradictions=avg_contra,
        )
        self._last_report = report
        return report

    def _training_threshold(self, scheduler: Optional[CurriculumScheduler], threshold: Optional[float]) -> float:
        """Resolve the step threshold: scheduler hook > explicit argument > self.threshold."""
        # getattr guard: schedulers may implement only some hooks (and None has none).
        decide = getattr(scheduler, "decide_threshold", None)
        th_sched = decide(self.step, self._last_report) if decide is not None else None
        if th_sched is not None:
            return float(th_sched)
        return self._resolve_threshold(threshold)

    def _curriculum_select(
        self, scheduler: Optional[CurriculumScheduler], contradictions: List[Contradiction]
    ) -> List[Contradiction]:
        """Apply scheduler.select if implemented; on failure use all contradictions."""
        select = getattr(scheduler, "select", None)
        if select is None:
            return contradictions
        try:
            # list() also rejects a None/non-iterable result, triggering the fallback.
            return list(select(contradictions, self.step))
        except Exception as e:
            self.log.warning("scheduler.select failed; using all contradictions: %s", e)
            return contradictions

    # --- Internals: predict and score ---------------------------------------

    def _predict(
        self, episodes: Sequence[EpisodeLike], *, batch_size: Optional[int]
    ) -> List[Tuple[Any, Optional[Exception]]]:
        """
        Return one ``(prediction, error)`` pair per episode, in order.

        Tries ``batch_simulate`` first when the world model has it; if that
        raises (including a prediction-count mismatch) falls back to per-item
        ``simulate``. A per-item failure becomes ``(None, exception)``.
        """
        inputs = [e.input for e in episodes]
        if hasattr(self.world, "batch_simulate"):
            try:
                preds = self._simulate_many(inputs, batch_size=batch_size)
            except Exception as e:
                self.log.warning("batch simulate failed; falling back to itemwise: %s", e)
            else:
                return [(p, None) for p in preds]

        results: List[Tuple[Any, Optional[Exception]]] = []
        for x in inputs:
            try:
                results.append((self._simulate_one(x), None))
            except Exception as e:
                self.log.warning("simulate() failed; treating as contradiction: %s", e)
                results.append((None, e))
        return results

    def _score(
        self,
        episodes: Sequence[EpisodeLike],
        predictions: Sequence[Tuple[Any, Optional[Exception]]],
        th: float,
    ) -> Tuple[List[float], List[Contradiction]]:
        """
        Score each (episode, prediction) pair against threshold *th*.

        Returns ``(finite_scores, contradictions)``; NaN scores are left out
        of the score list but still counted as contradictions.
        """
        scores: List[float] = []
        contradictions: List[Contradiction] = []
        for ep, (pred, err) in zip(episodes, predictions):
            if err is not None:
                contradictions.append(Contradiction(ep, None, float("nan"), reason="simulate_error"))
                continue
            try:
                s = self.consistency_score(pred, ep.outcome)
            except Exception as e:
                self.log.warning("similarity check failed; marking contradiction: %s", e)
                contradictions.append(Contradiction(ep, pred, float("nan"), reason="similarity_error"))
                continue

            is_nan = s != s
            if not is_nan:
                scores.append(s)
            # `not s >= th` (rather than `s < th`) so a NaN score counts as a
            # contradiction, matching _is_consistent / detect_contradictions.
            if not s >= th:
                reason = "similarity_error" if is_nan else "score_below_threshold"
                contradictions.append(Contradiction(ep, pred, s, reason=reason))
        return scores, contradictions

    # --- Internals: recent collection, simulate, update ---------------------

    def _iter_recent(self, limit: Optional[int]) -> Iterable[EpisodeLike]:
        """Return up to *limit* recent episodes (all when None)."""
        # Prefer recent_n if available for efficiency
        if limit is not None and hasattr(self.memory, "recent_n"):
            try:
                return self.memory.recent_n(int(limit))  # type: ignore[attr-defined]
            except Exception as e:
                self.log.debug("recent_n(%s) failed; falling back to recent(): %s", limit, e)
        return _take(self.memory.recent(), limit)

    def _collect_recent(self, limit: Optional[int]) -> List[EpisodeLike]:
        """Like _iter_recent, but always returns a list."""
        recent = self._iter_recent(limit)
        return list(recent) if not isinstance(recent, list) else recent

    def _simulate_one(self, x: Any) -> Any:
        """Predict a single input via the world model."""
        return self.world.simulate(x)

    def _simulate_many(self, xs: Sequence[Any], *, batch_size: Optional[int]) -> List[Any]:
        """
        Predict many inputs, using batch_simulate (chunked by *batch_size*) if present.

        Raises ValueError if batch_simulate returns a different number of
        predictions than inputs, so callers never silently misalign results.
        """
        items = list(xs)
        if not items:
            return []
        if hasattr(self.world, "batch_simulate"):
            bs = int(batch_size) if batch_size else len(items)
            preds: List[Any] = []
            for chunk in _chunked(items, bs):
                part = list(self.world.batch_simulate(chunk))  # type: ignore[attr-defined]
                if len(part) != len(chunk):
                    raise ValueError(
                        f"batch_simulate returned {len(part)} predictions for {len(chunk)} inputs"
                    )
                preds.extend(part)
            return preds
        # Fallback itemwise
        return [self.world.simulate(x) for x in items]

    def _update_models(self, episodes: Sequence[EpisodeLike], *, batch_size: Optional[int] = None) -> None:
        """Feed *episodes* to the world model (optionally chunked), then to the self model."""
        # World model update: prefer batch_update if available; otherwise chunked calls to update
        if hasattr(self.world, "batch_update"):
            if batch_size and batch_size > 0 and batch_size < len(episodes):
                for chunk in _chunked(list(episodes), int(batch_size)):
                    self.world.batch_update(chunk)  # type: ignore[attr-defined]
            else:
                self.world.batch_update(episodes)  # type: ignore[attr-defined]
        else:
            # Fallback to update() expecting a list
            if batch_size and batch_size > 0 and batch_size < len(episodes):
                for chunk in _chunked(list(episodes), int(batch_size)):
                    self.world.update(list(chunk))
            else:
                self.world.update(list(episodes))
        # Self model always takes a list, and sees all episodes in one call
        self.self_model.adapt(list(episodes))

class LinearThresholdScheduler:
    """
    Linearly increase threshold from start -> end over total_steps.
    Also selects the hardest half of contradictions (lowest scores).

    Step 1 (and earlier) uses ``start``; when ``total_steps >= 2``, step
    ``total_steps`` and later use ``end``. With ``total_steps <= 1`` the
    threshold stays at ``start``.
    """
    def __init__(self, start: float = 0.6, end: float = 0.95, total_steps: int = 100):
        self.start = float(start)
        self.end = float(end)
        self.total_steps = int(total_steps)

    def decide_threshold(self, step: int, prev_report: Optional[TrainingReport]) -> Optional[float]:
        """Return the interpolated threshold for *step*; *prev_report* is unused."""
        # Clamp into [1, total_steps] so the ramp holds flat at its endpoints.
        t = max(1, min(step, self.total_steps))
        # max(1, ...) guards the division when total_steps <= 1.
        alpha = (t - 1) / max(1, self.total_steps - 1)
        return (1 - alpha) * self.start + alpha * self.end

    def select(self, contradictions: List[Contradiction], step: int) -> List[Contradiction]:
        """
        Return the lower-scoring half (at least one item) of *contradictions*.

        NaN scores (failed checks) rank as hardest, ahead of every finite score.
        """
        if not contradictions:
            return []

        def key(c: Contradiction) -> float:
            # NaN != NaN; map it to -inf (not a finite sentinel like -1.0) so it
            # sorts first even when the similarity function yields scores < -1.
            return c.score if c.score == c.score else float("-inf")

        sorted_cs = sorted(contradictions, key=key)
        k = max(1, len(sorted_cs) // 2)
        return sorted_cs[:k]

# 1. Dummy episode record used by the demo below.
@dataclass
class Episode:
    """Mutable demo episode: a string input and the outcome actually observed."""

    input: str  # key the dummy world model looks up
    outcome: str  # ground-truth result the prediction is compared against

# 2. Dummy memory
class DummyMemory:
    """Fixed episode store that always reports the same three episodes."""

    def recent(self) -> Iterable[Episode]:
        """Return a fresh list of episodes A, B and C with their outcomes."""
        pairs = (("A", "outcome1"), ("B", "outcome2"), ("C", "outcome3"))
        return [Episode(inp, out) for inp, out in pairs]

# 3. Dummy world model
class DummyWorldModel:
    """Lookup-table world model: right about A and C, wrong about B."""

    def simulate(self, x: Any) -> str:
        """Return the canned prediction for *x*, or "unknown" for unseen inputs."""
        canned = {
            "A": "outcome1",       # correct
            "B": "wrong_outcome",  # incorrect
            "C": "outcome3",       # correct
        }
        return canned.get(x, "unknown")

    def update(self, contradictory: List[Episode]) -> None:
        """Print the inputs of the episodes passed in; no state is changed."""
        inputs = [ep.input for ep in contradictory]
        print(f"[WorldModel] Updated with: {inputs}")

# 4. Dummy self model
class DummySelfModel:
    """Self model stand-in that only reports what it was asked to adapt to."""

    def adapt(self, contradictory: List[Episode]) -> None:
        """Print the inputs of the episodes passed in; no state is changed."""
        inputs = [ep.input for ep in contradictory]
        print(f"[SelfModel] Adapted from: {inputs}")

# 5. Similarity function
def basic_similarity(a: str, b: str) -> float:
    """Exact-match similarity: 1.0 when the values compare equal, otherwise 0.0."""
    if a == b:
        return 1.0
    return 0.0

# 6. Wire the dummy components into an updater (scores below 0.8 are contradictions).
updater = ReflectiveUpdater(
    episodic_memory=DummyMemory(),
    world_model=DummyWorldModel(),
    self_model=DummySelfModel(),
    similarity_fn=basic_similarity,
    threshold=0.8,
)

# 7. Audited detection: print every contradiction with its prediction and score.
contradictions = list(
    updater.detect_contradictions_detailed(limit=512, batch_size=128)
)
for c in contradictions:
    print(
        f"[Contradiction] Input={c.episode.input} → Pred={c.prediction}, "
        f"Actual={c.episode.outcome}, Score={c.score}"
    )

# 8. Batched revision: feed the contradictions to the world and self models.
n = updater.revise_model_batched(batch_size=128)
print(f"[Revision] Contradictions processed: {n}")

# 9. One curriculum step; the scheduler ramps the threshold 0.7 -> 0.95 over 50 steps.
sched = LinearThresholdScheduler(start=0.7, end=0.95, total_steps=50)
report = updater.training_step(max_episodes=512, batch_size=128, scheduler=sched)

# 10. Summarize the training step.
print(
    f"[Training Report] Step={report.step}, "
    f"Processed={report.processed}, "
    f"Contradictory={report.contradictory}, "
    f"Threshold={report.threshold}, "
    f"Avg All={report.avg_score_all}, "
    f"Avg Contra={report.avg_score_contradictions}"
)