<a href="https://colab.research.google.com/github/Bhuvanesh2218K/Chatbot_to_Known_Individual_Prakriti/blob/main/Chaturya.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Phase => 1 : Dataset preparation and clustering

Phase 1.1 => Feature Selector

In [None]:

%%writefile /content/FeatureSelector.py
# FeatureSelector.py

import pandas as pd

class FeatureSelector:
    def __init__(self, cat_max_unique=20):
        self.cat_max_unique = cat_max_unique
        self.selected_features = []
        self.dropped_features = []
        self.encoders = {}

    def fit(self, df: pd.DataFrame):
        """Decide which columns to keep/drop and build encoders for categoricals."""
        self.selected_features.clear()
        self.dropped_features.clear()
        self.encoders.clear()

        for col in df.columns:
            lower = col.lower()

            # Drop IDs, dates, names
            if "id" in lower or "date" in lower or "name" in lower:
                self.dropped_features.append(col)
                continue

            # Categorical
            if df[col].dtype == "object" or df[col].dtype.name == "category":
                nunq = df[col].nunique(dropna=True)
                if nunq > 0.5 * len(df) or nunq > self.cat_max_unique:
                    self.dropped_features.append(col)
                else:
                    self.selected_features.append(col)
                    self.encoders[col] = df[col].dropna().unique().tolist()
                continue

            # Numeric
            if pd.api.types.is_numeric_dtype(df[col]):
                self.selected_features.append(col)
                continue

            # Otherwise drop
            self.dropped_features.append(col)

        return self

    def transform(self, df: pd.DataFrame):
        """Apply selection + one-hot encoding to dataset."""
        df_sel = df[self.selected_features].copy()

        # Expand categoricals
        for col, vals in self.encoders.items():
            for v in vals:
                df_sel[f"{col}_{v}"] = (df_sel[col] == v).astype(int)
            df_sel.drop(columns=[col], inplace=True)

        return df_sel

    def fit_transform(self, df: pd.DataFrame):
        return self.fit(df).transform(df)

    def get_selected_features(self):
        return list(self.selected_features)

    def get_dropped_features(self):
        return list(self.dropped_features)

Writing /content/FeatureSelector.py


Phase 1.2 => DiffcultyScore

In [None]:
%%writefile /content/DifficultyScorer.py
# DifficultyScorer.py

import pandas as pd

class DifficultyScorer:
    def __init__(self,
                 weights=None,
                 threshold=0.75,
                 important_features=None,
                 domain_factor=0.0):
        """
        Computes row-wise difficulty scores + tree_type for a DataFrame.
        """
        self.weights = weights or {
            "keyword":   0.4,
            "length":    0.3,
            "ambiguity": 0.2,
            "domain":    0.1
        }
        self.threshold = threshold
        self.important_features = important_features or []
        self.domain_factor = domain_factor

    def fit(self, df: pd.DataFrame, y=None):
        # No learning needed
        return self

    def transform(self, df: pd.DataFrame):
        records = []
        for _, row in df.iterrows():
            feats = row.to_dict()
            n_total = len(feats)
            if n_total == 0:
                records.append({
                    "difficulty_score": 0.0,
                    "tree_type": "binary"
                })
                continue

            def is_filled(x):
                return x is not None and x == x and x != ""

            n_filled = sum(1 for v in feats.values() if is_filled(v))
            n_missing = n_total - n_filled

            if self.important_features:
                k_pres = sum(
                    1 for k in self.important_features
                    if k in feats and is_filled(feats[k])
                )
                K = k_pres / len(self.important_features)
            else:
                K = n_filled / n_total

            L = n_filled / n_total
            A = n_missing / n_total
            D = float(self.domain_factor)

            w = self.weights
            score = (
                w["keyword"]   * K +
                w["length"]    * L +
                w["ambiguity"] * A +
                w["domain"]    * D
            )

            tree = "three-tree" if score > self.threshold else "binary"
            records.append({
                "difficulty_score": round(score, 4),
                "tree_type": tree
            })

        return pd.DataFrame(records, index=df.index)

    def fit_transform(self, df: pd.DataFrame, y=None):
        return self.fit(df, y).transform(df)

Writing /content/DifficultyScorer.py


Phase 3 => Integration Phase

In [None]:
%%writefile /content/Phase1Pipeline.py

import pandas as pd
from FeatureSelector import FeatureSelector
from DifficultyScorer import DifficultyScorer # Corrected import

class Phase1Pipeline:
    def __init__(self,
                 cat_max_unique=20,
                 threshold=0.75,
                 important_features=None,
                 domain_factor=0.0):
        self.selector = FeatureSelector(cat_max_unique=cat_max_unique)
        self.scorer   = DifficultyScorer(
            threshold=threshold,
            important_features=important_features,
            domain_factor=domain_factor
        )

    def fit(self, df: pd.DataFrame):
        """Fit selector on raw df, then scorer on selected features."""
        self.selector.fit(df)
        selected = self.selector.transform(df)
        self.scorer.fit(selected)
        return self

    def transform(self, df: pd.DataFrame):
        """Run selection + scoring and merge into one DataFrame."""
        X_sel = self.selector.transform(df)
        scores = self.scorer.transform(X_sel)
        return pd.concat(
            [X_sel.reset_index(drop=True),
             scores.reset_index(drop=True)],
            axis=1
        )

    def fit_transform(self, df: pd.DataFrame):
        return self.fit(df).transform(df)

Writing /content/Phase1Pipeline.py


# Phase => 2 : Tree Construction

In [None]:

%%writefile /content/TreeNodeV1.py
import math
import logging
from datetime import datetime

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class TreeNodeV1:
    def __init__(self,
                 node_id=None,
                 value=None,
                 children=None,
                 level: int = 0,
                 rule=None,
                 confidence: float = 0.0,
                 lock_flag: str = "unlocked",  # "soft", "hard", "unlocked"
                 provenance: str = None,
                 difficulty_tag: str = None,
                 branch_tag: str = None):
        self.node_id = node_id
        self.value = value
        self.children = children or []
        self.level = max(0, level)
        self.rule = rule
        self.confidence = confidence
        self.lock_flag = lock_flag
        self.updated_at = datetime.utcnow()
        self.provenance = provenance
        self.conflicts = []
        self.difficulty_tag = difficulty_tag
        self.branch_tag = branch_tag  # "L", "M", "R" for three-tree
        self.cached_vector = None

    # --- Existing methods ---
    def is_leaf(self) -> bool:
        return not self.children

    def add_child(self, node):
        if node is not None:
            self.children.append(node)
            self.updated_at = datetime.utcnow()
            logger.debug(f"Added child to node {self.node_id} at level {self.level}")

    def store_vector(self, vec):
        self.cached_vector = vec
        self.updated_at = datetime.utcnow()

    def get_vector(self):
        return self.cached_vector

    def get_depth(self) -> int:
        if self.is_leaf():
            return 1
        depths = [child.get_depth() for child in self.children if child is not None]
        return 1 + max(depths) if depths else 1

    def get_confidence(self) -> float:
        depth = self.get_depth()
        self.confidence = min(1.0, math.log(depth + 1) / 5)
        return self.confidence

    # --- New helpers ---
    def lock(self, mode="soft"):
        """Lock node to prevent overwrites."""
        self.lock_flag = mode
        self.updated_at = datetime.utcnow()

    def unlock(self):
        self.lock_flag = "unlocked"
        self.updated_at = datetime.utcnow()

    def add_conflict(self, conflict_note: str):
        self.conflicts.append(conflict_note)
        self.updated_at = datetime.utcnow()

Writing /content/TreeNodeV1.py


In [None]:

%%writefile /content/TreeBuilderV2.py
import logging
import torch
from collections import deque
from typing import List, Tuple, Optional, Union
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class TreeBuilderV2:
    def __init__(self, device: str = "cpu", dim: int = 50, mode: str = "binary"):
        """
        mode: 'binary' or 'three'
        """
        self.device = device
        self.dim = dim
        self.mode = mode
        self.max_warnings = 5
        self.failed_samples = set()

    def build_tree(self, vec_pairs: List[Tuple[str, Optional[torch.Tensor]]], sample_id="unknown") -> Optional[TreeNodeV1]:
        """Build a binary or three-tree from token-vector pairs."""
        if not vec_pairs:
            logger.warning(f"TreeBuilderV2 failed for {sample_id}: Empty vec_pairs")
            return None

        # Root node
        root = TreeNodeV1(node_id=sample_id, value="root", level=0)

        # BFS queue
        queue = deque([(root, vec_pairs)])

        while queue:
            parent, pairs = queue.popleft()

            # Decide branching factor
            if self.mode == "binary":
                chunk_size = 2
                branch_tags = ["L", "R"]
            else:  # three-tree
                chunk_size = 3
                branch_tags = ["L", "M", "R"]

            # Split pairs into chunks
            for i in range(0, len(pairs), chunk_size):
                chunk = pairs[i:i+chunk_size]
                for j, (token, vector) in enumerate(chunk):
                    if vector is None or (vector.shape[0] != self.dim):
                        logger.warning(f"Skipping invalid vector for token {token}")
                        continue

                    child = TreeNodeV1(
                        node_id=f"{sample_id}_{i}_{j}",
                        value=token,
                        level=parent.level + 1,
                        branch_tag=branch_tags[j % len(branch_tags)]
                    )
                    child.store_vector(vector.to(self.device))
                    parent.add_child(child)

                    # For now, stop at one level (can extend deeper with rules)
                    # queue.append((child, next_pairs))  # placeholder for deeper BFS

        # Compute root vector as mean of children
        if root.children:
            child_vectors = [c.get_vector() for c in root.children if c.get_vector() is not None]
            if child_vectors:
                root_vector = torch.stack(child_vectors).mean(dim=0)
                root.store_vector(root_vector)

        return root

    # --- DFS tracing for explanations ---
    def trace_dfs(self, node: TreeNodeV1, path=None) -> List[str]:
        """Return a list of node values in DFS order."""
        if path is None:
            path = []
        path.append(node.value)
        for child in node.children:
            self.trace_dfs(child, path)
        return path

Writing /content/TreeBuilderV2.py


In [None]:
%%writefile /content/LockManager.py
import logging
from datetime import datetime
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class LockManager:
    def __init__(self):
        self.audit_log = []  # list of dicts with {node_id, action, mode, timestamp}

    def lock_node(self, node: TreeNodeV1, mode="soft"):
        """Lock a node (soft/hard)."""
        if node.lock_flag == "hard":
            logger.warning(f"Node {node.node_id} already hard-locked, cannot override")
            return False
        node.lock(mode)
        event = {
            "node_id": node.node_id,
            "action": "lock",
            "mode": mode,
            "timestamp": datetime.utcnow().isoformat()
        }
        self.audit_log.append(event)
        logger.info(f"Locked node {node.node_id} with mode={mode}")
        return True

    def unlock_node(self, node: TreeNodeV1):
        """Unlock a node."""
        if node.lock_flag == "hard":
            logger.warning(f"Node {node.node_id} is hard-locked, cannot unlock")
            return False
        node.unlock()
        event = {
            "node_id": node.node_id,
            "action": "unlock",
            "mode": "unlocked",
            "timestamp": datetime.utcnow().isoformat()
        }
        self.audit_log.append(event)
        logger.info(f"Unlocked node {node.node_id}")
        return True

    def get_audit_log(self):
        return self.audit_log

Writing /content/LockManager.py


In [None]:

%%writefile /content/TreeSnapshot.py
import math
from collections import deque
from TreeNodeV1 import TreeNodeV1

class TreeSnapshot:
    def __init__(self, root: TreeNodeV1):
        self.root = root
        self.depth = 0
        self.node_count = 0
        self.leaf_count = 0
        self.branching_factor = 0.0
        self.entropy = 0.0
        self.weak_leaves = 0

        # Extras to keep EIS stable even if caller forgets to set them
        self.branch_flip_rate = 0.0  # 0..1 (if you compute flips, overwrite upstream)

        if root:
            self._analyze()

    def _analyze(self):
        queue = deque([self.root])
        total_children = 0
        non_leaf_count = 0
        leaf_confidences = []
        max_level = 0
        total_nodes = 0

        while queue:
            node = queue.popleft()
            total_nodes += 1
            max_level = max(max_level, getattr(node, "level", 0))

            if node.is_leaf():
                self.leaf_count += 1
                try:
                    leaf_confidences.append(float(node.get_confidence()))
                except Exception:
                    leaf_confidences.append(0.0)
            else:
                non_leaf_count += 1
                c = getattr(node, "children", []) or []
                total_children += len(c)
                for child in c:
                    queue.append(child)

        self.node_count = total_nodes
        self.depth = max_level

        # Branching factor over non-leaves
        if non_leaf_count > 0:
            self.branching_factor = total_children / float(non_leaf_count)

        # Root-level entropy over child fanouts
        if self.root and getattr(self.root, "children", []):
            counts = [len(c.children or []) for c in self.root.children]
            total = sum(counts) if sum(counts) > 0 else 1
            probs = [(c / total) for c in counts]
            self.entropy = -sum(p * math.log(p + 1e-9, 2) for p in probs if p > 0)

        # Weak leaves by confidence threshold
        self.weak_leaves = sum(1 for c in leaf_confidences if c < 0.3)

    def to_dict(self):
        return {
            "depth": int(self.depth),
            "node_count": int(self.node_count),
            "leaf_count": int(self.leaf_count),
            "branching_factor": round(float(self.branching_factor), 3),
            "entropy": round(float(self.entropy), 4),
            "weak_leaves": int(self.weak_leaves),

            # keep EIS fields present to avoid KeyErrors upstream
            "branch_flip_rate": 0.0,
        }

Writing /content/TreeSnapshot.py


In [None]:
%%writefile /content/RowBatchSummary.py
# RowBatchSummary.py
from collections import defaultdict, Counter
import numpy as np

class RowBatchSummary:
    def __init__(self, batch_results: list):
        """
        Robust aggregator over a batch of row-level tree results.

        Expected item format (keys are optional; handled defensively):
        {
          "row_index": int,
          "features": dict,                # row feature map (can be mixed types)
          "path": list[str],               # DFS path (leaf last)
          "prev_path": list[str],          # previous path (for flip detection)
          "snapshot": dict                 # optional structural stats
        }
        """
        self.batch_results = batch_results or []
        self.branch_flip_rate = 0.0
        self.stability_score = 0.0
        self.feature_stats = {}
        self.coverage = {}
        self._analyze()

    def _analyze(self):
        flips = 0
        total = len(self.batch_results)
        branch_counts = Counter()
        feature_accum = defaultdict(list)

        for row in self.batch_results:
            # Defensive access
            prev_path = row.get("prev_path")
            path = row.get("path")
            features = row.get("features", {})

            # Flip detection
            if isinstance(prev_path, list) and isinstance(path, list) and prev_path != path:
                flips += 1

            # Coverage (count leaf occurrences)
            if isinstance(path, list) and len(path) > 0:
                leaf_id = path[-1]
                branch_counts[leaf_id] += 1

            # Feature stats (numeric-only, robust casting)
            if isinstance(features, dict):
                for k, v in features.items():
                    # Try to coerce to float; skip non-numeric
                    val = None
                    try:
                        # Handle numpy scalars, booleans, ints, floats, numeric strings
                        if isinstance(v, (bool, int, float, np.number)):
                            val = float(v)
                        elif isinstance(v, str):
                            # Attempt to parse numeric strings; skip categorical text
                            val = float(v)
                        # Else: unsupported type → skip
                    except Exception:
                        val = None

                    # Accept only finite numbers
                    if val is not None and np.isfinite(val):
                        feature_accum[k].append(val)

        # Flip rate and stability
        self.branch_flip_rate = (flips / total) if total > 0 else 0.0
        self.stability_score = 1.0 - self.branch_flip_rate

        # Means over accumulated numeric features
        self.feature_stats = {
            k: float(np.mean(vals))
            for k, vals in feature_accum.items()
            if isinstance(vals, list) and len(vals) > 0
        }

        # Coverage distribution of leaves
        self.coverage = dict(branch_counts)

    def to_dict(self):
        return {
            "branch_flip_rate": round(self.branch_flip_rate, 3),
            "stability_score": round(self.stability_score, 3),
            "feature_stats": self.feature_stats,
            "coverage": self.coverage,
        }

Writing /content/RowBatchSummary.py


# Phase => 3 : Neural Intergation

In [None]:
%%writefile anchor_extractor.py
"""
Tree-first Anchor Extractor.
Emits anchors tied to tree nodes based on node features and snapshot stats.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Anchor:
    anchor_id: str
    node_id: str
    anchor_type: str
    features: Dict
    provenance: str = "tree_based"
    span: Optional[Dict] = None  # {"start": int, "end": int}


@dataclass
class AnchorConfig:
    branching_threshold: int = 2
    entropy_threshold: float = 0.6
    leaf_token_threshold: int = 8


def _has_digits(tokens: List[str]) -> bool:
    for t in tokens or []:
        if any(ch.isdigit() for ch in t):
            return True
    return False


def _has_symbols(tokens: List[str]) -> bool:
    SYMBOLS = set("=+-*/^%")
    for t in tokens or []:
        if any(ch in SYMBOLS for ch in t):
            return True
    return False


def extract_anchors(snapshot: Dict, config: Optional[AnchorConfig] = None) -> List[Dict]:
    """
    Extract anchors from a TreeSnapshot.

    Expected minimal snapshot format:
      snapshot = {
        "nodes": {
          node_id: {
            "depth": int,
            "entropy": float,
            "tokens": List[str],
            "is_leaf": bool,
            "children": List[str],
            "locked": bool,
            "weak": bool,               # optional
            "is_heading": bool,         # optional
            "is_code_block": bool,      # optional
            "span": {"start": int, "end": int}  # optional
          },
          ...
        },
        "root_id": str
      }

    Returns: List[Anchor as plain dict]
    """
    cfg = config or AnchorConfig()
    nodes = snapshot.get("nodes", {})
    root_id = snapshot.get("root_id")

    anchors: List[Anchor] = []

    for node_id, meta in nodes.items():
        depth = meta.get("depth", 0)
        entropy = float(meta.get("entropy", 0.0))
        tokens = meta.get("tokens", []) or []
        is_leaf = bool(meta.get("is_leaf", False))
        children = meta.get("children", []) or []
        branching_factor = len(children)
        locked = bool(meta.get("locked", False))
        weak = bool(meta.get("weak", False))
        span = meta.get("span")

        features = {
            "depth": depth,
            "entropy": entropy,
            "token_count": len(tokens),
            "is_leaf": is_leaf,
            "branching_factor": branching_factor,
            "locked": locked,
        }

        def emit(anchor_type: str):
            anchors.append(Anchor(
                anchor_id=f"{node_id}::{anchor_type}",
                node_id=node_id,
                anchor_type=anchor_type,
                features=features,
                span=span
            ))

        # Root
        if node_id == root_id:
            emit("root_anchor")

        # Structure rules
        if branching_factor >= cfg.branching_threshold:
            emit("branching_point")
        if entropy >= cfg.entropy_threshold:
            emit("unstable_branch")
        if is_leaf and len(tokens) >= cfg.leaf_token_threshold:
            emit("leaf_dense")

        # Content rules (tree-derived)
        if is_leaf and _has_digits(tokens):
            emit("number_leaf")
        if is_leaf and _has_symbols(tokens):
            emit("symbol_leaf")

        # State rules
        if locked:
            emit("locked_node")
        if weak:
            emit("weak_leaf")

        # Optional: heading/code-like flags
        if meta.get("is_heading", False):
            emit("text_heading_like")
        if meta.get("is_code_block", False):
            emit("code_block_like")

    # Deduplicate by anchor_id (determinism)
    unique = {a.anchor_id: a for a in anchors}

    # Convert dataclasses to plain dicts
    return [vars(a) for a in unique.values()]

Writing anchor_extractor.py


In [None]:

%%writefile embedding_index.py
"""
Lightweight embedding index for retrieval (TF-IDF + cosine).
Domain-agnostic and deterministic.
"""

from typing import List, Dict, Optional
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class EmbeddingIndex:
    def __init__(self):
        self.fragments: List[str] = []
        self.vectorizer: Optional[TfidfVectorizer] = None
        self.embeddings = None

    def build_index(self, fragments: List[str]):
        """
        Build index from list of text fragments.
        """
        self.fragments = fragments or []
        # TF-IDF with unigrams+bigrams for better matching
        self.vectorizer = TfidfVectorizer(
            lowercase=True,
            analyzer="word",
            ngram_range=(1, 2),
            min_df=1,
            max_df=1.0
        )
        if self.fragments:
            self.embeddings = self.vectorizer.fit_transform(self.fragments)
        else:
            self.embeddings = None

    def query(self, text: str, k: int = 5) -> List[Dict]:
        """
        Query top-k similar fragments.
        Returns list of dicts:
          {
            "fragmentid": str,
            "fragmenttext": str,
            "embeddingrank": int,
            "retrievalconfidence": float
          }
        """
        if not self.fragments or self.embeddings is None or self.vectorizer is None:
            return []
        q_vec = self.vectorizer.transform([text])
        sims = cosine_similarity(q_vec, self.embeddings)[0]  # shape: (N,)
        ranked = sorted(
            [(i, float(sims[i])) for i in range(len(self.fragments))],
            key=lambda x: x[1],
            reverse=True
        )
        top = ranked[:max(1, k)]
        results: List[Dict] = []
        for rank, (idx, score) in enumerate(top, start=1):
            results.append({
                "fragmentid": f"frag{idx}",
                "fragmenttext": self.fragments[idx],
                "embeddingrank": rank,
                "retrievalconfidence": round(score, 6)
            })
        return results

Writing embedding_index.py


In [None]:

%%writefile embedding_matching.py
"""
Wrapper around EmbeddingIndex for retrieval queries.
"""

from typing import List, Dict
from embedding_index import EmbeddingIndex

_index = EmbeddingIndex()

def build_fragments(fragments: List[str]):
    """
    Initialize or rebuild the index with given fragments.
    """
    _index.build_index(fragments)

def retrieve(text: str, k: int = 5) -> List[Dict]:
    """
    Retrieve top-k fragments for the given text.
    """
    return _index.query(text, k)

Writing embedding_matching.py


In [None]:

%%writefile rl_critic.py
"""
RL Critic: quick scoring estimates for candidate actions.
Heuristic-based, deterministic, tree-feature aware.
"""

from typing import Dict

# Domain-specific modifiers (tunable)
DOMAIN_MODIFIERS = {
    "math": 0.05,
    "code": 0.02,
    "science": 0.00,
    "chess": -0.01
}

def score_candidate(state: Dict, action: Dict) -> Dict:
    """
    Estimate pos/neg/quality scores for a candidate.

    Args:
        state: dict with features like
            {"depth": int, "entropy": float, "token_count": int,
             "branching_factor": int, "is_leaf": bool,
             "coverage_balance": float, "domain": str}
        action: dict like
            {"type": "split"|"lock"|"promote", "target": node_id, "anchor_type": str}

    Returns:
        {
            "est_pos": float,
            "est_neg": float,
            "est_quality": float
        }
    """
    depth = float(state.get("depth", 0))
    entropy = float(state.get("entropy", 0.0))
    token_count = int(state.get("token_count", 0))
    branching = int(state.get("branching_factor", 0))
    is_leaf = bool(state.get("is_leaf", False))
    coverage_balance = float(state.get("coverage_balance", 0.0))
    domain = (state.get("domain") or "").lower()

    action_type = (action.get("type") or "").lower()
    anchor_type = (action.get("anchor_type") or "").lower()

    # Base positive signal: entropy, branching, token richness
    pos = 0.0
    pos += min(1.0, entropy) * 0.5
    pos += min(1.0, branching / 3.0) * 0.3
    pos += min(1.0, token_count / 12.0) * 0.2

    # Coverage balance bonus
    pos += min(0.2, coverage_balance * 0.2)

    # Action alignment boosts
    if action_type == "split":
        pos += 0.2 if (entropy >= 0.6 or branching >= 2) else 0.0
    if action_type == "lock":
        pos += 0.15 if anchor_type in ("number_leaf", "text_heading_like") else 0.0
    if action_type == "promote":
        pos += 0.15 if (depth >= 1 and not is_leaf) else 0.0

    # Negative signals
    neg = 0.0
    if is_leaf and entropy >= 0.7 and token_count < 3:
        neg += 0.3
    if action_type == "lock" and anchor_type not in ("number_leaf", "symbol_leaf", "text_heading_like"):
        neg += 0.25
    if action_type == "split" and branching == 0:
        neg += 0.2

    # Quality = positive minus weighted negative
    quality = pos - 0.5 * neg

    # Apply domain modifier
    quality += DOMAIN_MODIFIERS.get(domain, 0.0)

    # Clamp to [0,1]
    quality = max(0.0, min(1.0, quality))

    return {
        "est_pos": round(pos, 4),
        "est_neg": round(neg, 4),
        "est_quality": round(quality, 4)
    }

Writing rl_critic.py


In [None]:

%%writefile decoder.py
"""
Decoder: generate short explanation text from snapshot + retrieval context.
Deterministic, template-based.
"""

from typing import Dict, List

def decode_snapshot(snapshot: Dict, anchors: List[Dict], retrievals: List[Dict]) -> str:
    """
    Produce explanation text for reviewer clarity.
    Chooses first anchor and first retrieval for a concise rationale.
    """
    if not anchors:
        return "No anchors found for this snapshot."
    anchor = anchors[0]
    node_id = anchor.get("node_id")
    a_type = anchor.get("anchor_type")
    entropy = anchor.get("features", {}).get("entropy", 0.0)
    branching = anchor.get("features", {}).get("branching_factor", 0)

    base = f"Node {node_id} anchored as {a_type}; entropy {entropy:.2f}, branching {branching}."
    if retrievals:
        r = retrievals[0]
        frag = r.get("fragmenttext", "")
        conf = r.get("retrievalconfidence", 0.0)
        base += f" Retrieved: '{frag}' (confidence {conf:.2f})."

    return base

Writing decoder.py


In [None]:

%%writefile smoother.py
"""
Smoother: polish explanation text for readability.
Simple rule-based cleanup.
"""

def smooth_text(text: str) -> str:
    if not text:
        return ""
    cleaned = text.strip()
    # Capitalize first letter
    cleaned = cleaned[0].upper() + cleaned[1:]
    # Ensure trailing period
    if cleaned[-1] not in ".!?":
        cleaned += "."
    # Compact extra spaces
    cleaned = " ".join(cleaned.split())
    return cleaned

Writing smoother.py


## Phase 3.2 => Dual‑Valence RL‑lite Learner

In [None]:

%%writefile /content/tokenizer_and_embedding.py
import torch, re, logging
from typing import List

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class TokenEmbedding:
    def __init__(self, vocab: List[str], dim:int=50, device:str='cpu'):
        self.dim, self.device = dim, device
        self.vocab = ['<unk>'] + vocab
        self.word2idx = {w:i for i,w in enumerate(self.vocab)}
        self.embeddings = torch.randn(len(self.vocab), dim, device=device) / (dim**0.5)
        logger.info(f"TokenEmbedding: vocab_size={len(self.vocab)} dim={dim}")

    def lookup(self, token:str) -> torch.Tensor:
        idx = self.word2idx.get(token, 0)
        return self.embeddings[idx]

def universal_tokenizer(text: str) -> List[str]:
    if not text:
        return []
    # split numbers, identifiers, symbols
    return re.findall(r'\d+\.\d+|\d+|[A-Za-z]+|[+\-*/^=():]', text)

import torch
import torch.nn.functional as F

def build_trait_embedding_index(trait_vpk_table, embedder):
    index = {}
    for trait in trait_vpk_table.keys():
        tokens = universal_tokenizer(trait)
        if not tokens:
            continue
        vec = sum(embedder.lookup(tok) for tok in tokens) / len(tokens)
        index[trait] = vec / (vec.norm() + 1e-9)  # normalized
    return index


def semantic_match_traits(query, trait_index, embedder, top_k=8):
    tokens = universal_tokenizer(query.lower())
    if not tokens:
        return []
    q_vec = sum(embedder.lookup(tok) for tok in tokens) / len(tokens)
    q_vec = q_vec / (q_vec.norm() + 1e-9)

    sims = []
    for trait, t_vec in trait_index.items():
        s = F.cosine_similarity(q_vec.unsqueeze(0), t_vec.unsqueeze(0)).item()
        sims.append((trait, s))

    sims.sort(key=lambda x: x[1], reverse=True)

    strong = [t for t, s in sims if s >= 0.72][:top_k]     # strong signals
    medium = [t for t, s in sims if 0.55 <= s < 0.72][:top_k - len(strong)]

    return strong + medium

Overwriting /content/tokenizer_and_embedding.py


In [None]:

%%writefile /content/model_utils.py
import torch
import torch.nn as nn
import logging

# Setup logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class InputProjector(nn.Module):
    def __init__(self, input_dim, target_dim):
        super().__init__()
        self.project = nn.Linear(input_dim, target_dim)
        logger.info(f"Initialized InputProjector: {input_dim} -> {target_dim}")

    def forward(self, x):
        return self.project(x)

def patch_model_input(model, input_vector, expected_dim):
    actual_dim = input_vector.shape[-1]
    if actual_dim != expected_dim:
        if hasattr(model, 'input_projector'):
            logger.warning(f"Model already has input_projector, skipping patch")
            return
        logger.info(f"Auto-patching model input: {actual_dim} -> {expected_dim}")
        projector = InputProjector(actual_dim, expected_dim).to(input_vector.device)
        old_forward = model.forward

        def new_forward(x):
            x_proj = projector(x)
            return old_forward(x_proj)

        model.forward = new_forward
        model.input_projector = projector

Writing /content/model_utils.py


In [None]:

%%writefile Phase2Env.py
import numpy as np
import math
import time

from TreeSnapshot import TreeSnapshot
from check_action_allowed import check_action_allowed
from apply_action import apply_action
from compute_reward import compute_reward
from generative_decision_loop_safe import generative_decision_loop as generative_decision_loop_safe


class Phase2Env:
    def __init__(self, builder, data, feature_vectors, max_edits=20):
        self.builder = builder
        self.data = data
        self.feature_vectors = feature_vectors
        self.current_idx = None
        self.tree = None
        self.snapshot = None
        self.steps = 0
        self.max_edits = max_edits

        # Economy state
        self.tickets = {"G": 0, "B": 0, "Y": 0, "R": 0, "P": 0}
        self.failure_memory = {}
        self.temp_tickets = {"G": 0, "B": 0, "Y": 0, "R": 0}
        self.decay_queue = []

        # Adaptive coeffs container (not directly used by compute_reward;
        # you can surface them to UI or curriculum later)
        self.coeffs = {"alpha": 1.0, "beta": 1.0, "gamma": 1.0, "delta": 1.0}

    # --------------------------------------------------------------------------
    # Reset (supports global prakriti tree pointer)
    # --------------------------------------------------------------------------
    def reset(self, idx: int = None, use_global_tree=False, global_tree=None):
        if use_global_tree and global_tree is not None:
            self.tree = global_tree
            self.snapshot = TreeSnapshot(self.tree).to_dict()
            self.steps = 0

            # reset economy
            self.tickets = {"G": 0, "B": 0, "Y": 0, "R": 0, "P": 0}
            self.failure_memory = {}
            self.temp_tickets = {"G": 0, "B": 0, "Y": 0, "R": 0}
            self.decay_queue = []
            self.coeffs = {"alpha": 1.0, "beta": 1.0, "gamma": 1.0, "delta": 1.0}
            return self._snapshot_to_obs(self.snapshot)

        # Local (row) tree
        row = self.data.iloc[idx]
        active_tokens = [c for c, v in row.items() if v == 1 and c not in ["difficulty_score", "tree_type"]]
        vec_pairs = [(tok, self.feature_vectors[tok]) for tok in active_tokens if tok in self.feature_vectors]

        self.tree = self.builder.build_tree(vec_pairs, sample_id=f"row{idx}")
        self.snapshot = TreeSnapshot(self.tree).to_dict()
        self.steps = 0

        self.tickets = {"G": 0, "B": 0, "Y": 0, "R": 0, "P": 0}
        self.failure_memory = {}
        self.temp_tickets = {"G": 0, "B": 0, "Y": 0, "R": 0}
        self.decay_queue = []
        self.coeffs = {"alpha": 1.0, "beta": 1.0, "gamma": 1.0, "delta": 1.0}

        # small structural budget by default
        self.tickets["B"] = max(self.tickets.get("B", 0), 12)
        return self._snapshot_to_obs(self.snapshot)

    # --------------------------------------------------------------------------
    def _normalize_action(self, action):
        if isinstance(action, dict):
            op = action.get("type") or action.get("op_type")
            tgt = action.get("target") or action.get("target_id") or action.get("node_id")
            extra = action.get("extra") if "extra" in action else action.get("payload", None)
            return (op, tgt, extra)
        if isinstance(action, (list, tuple)) and len(action) >= 2:
            op, tgt = action[0], action[1]
            extra = action[2] if len(action) > 2 else None
            return (op, tgt, extra)
        return (None, None, None)

    def _validate_snapshot(self, snap):
        if not isinstance(snap, dict):
            return False
        required = ["depth", "node_count", "leaf_count", "branching_factor", "entropy", "weak_leaves"]
        return all(k in snap for k in required)

    # --------------------------------------------------------------------------
    def step(self, action, config=None,
             generator=None, scorer=None, sim_engine=None,
             verifier=None, policy_module=None,
             apply_engine=None, provenance_logger=None,
             fallback_module=None):

        action = self._normalize_action(action)

        if not self._validate_snapshot(self.snapshot):
            return np.zeros(6, dtype=np.float32), -1.0, True, {"log": {"error": "invalid_snapshot"}}

        prev_snap = dict(self.snapshot)

        # EIS (for generative trigger only; the authoritative EIS is recomputed inside compute_reward)
        entropy = self.snapshot.get("entropy", 0.0)
        stability = 1.0 - self.snapshot.get("branch_flip_rate", 0.0)
        depth_ratio = self.snapshot.get("depth", 1) / max(1, self.snapshot.get("node_count", 1))
        EIS = 0.4 * entropy + 0.3 * (1 - stability) + 0.3 * depth_ratio

        # -------------------- Generative path --------------------
        if config and EIS >= config.get("generative_threshold", 0.7):
            if generator is None or provenance_logger is None:
                return self._snapshot_to_obs(self.snapshot), -0.5, True, {
                    "log": {"status": "generative_blocked", "reason": "missing_generator_or_logger"}
                }

            new_snapshot, status = generative_decision_loop_safe(
                self.snapshot,
                self.tickets,
                self.failure_memory,
                config,
                generator,
                scorer,
                sim_engine,
                verifier,
                policy_module,
                apply_engine,
                provenance_logger,
                fallback_module
            )
            self.snapshot = new_snapshot

            reward, self.tickets, self.failure_memory, log_entry = compute_reward(
                prev_snap, self.snapshot, self.builder.mode,
                self.tickets, self.failure_memory,
                self.temp_tickets, self.decay_queue,
                proposed_fix=None
            )

            self.steps += 1
            done = (self.steps >= self.max_edits) or (reward < -5.0)
            return self._snapshot_to_obs(self.snapshot), reward, done, {
                "log": {"mode": "generative", **status, **log_entry}
            }

        # -------------------- Structural path --------------------
        allowed, cost, reason = check_action_allowed(action, self.tickets)
        if not allowed:
            return self._snapshot_to_obs(self.snapshot), -1.0, True, {
                "log": {"status": "blocked", "reason": reason}
            }

        # ticket spend
        for k, v in (cost or {}).items():
            self.tickets[k] = max(0, self.tickets.get(k, 0) - int(v))

        # apply edit
        log = apply_action(self.tree, action)

        # refresh snapshot + add shaping
        self.snapshot = TreeSnapshot(self.tree).to_dict()
        cur = self.snapshot

        # compute deltas for shaping
        d_ent    = float(prev_snap.get("entropy", 0.0) - cur.get("entropy", 0.0))
        d_weak   = float(prev_snap.get("weak_leaves", 0) - cur.get("weak_leaves", 0))
        d_br     = float(prev_snap.get("branching_factor", 0.0) - cur.get("branching_factor", 0.0))

        pos = 0.20 * max(0.0, d_ent) + 0.15 * max(0.0, d_weak) + 0.07 * max(0.0, d_br)
        neg = 0.20 * max(0.0, -d_ent) + 0.10 * max(0.0, -d_br)

        # anti-gaming cap
        cur["pos_score"] = float(min(pos, 5.0))
        cur["neg_score"] = float(min(neg, 5.0))

        # outcome hint (reward will still re-infer robustly if absent)
        if d_ent > 0 or d_weak > 0:
            cur["outcome"] = "success"
        elif neg > pos:
            cur["outcome"] = "repeat_fail"
        else:
            cur["outcome"] = "neutral"

        self.snapshot.setdefault("pos_score",0.0)
        self.snapshot.setdefault("neg_score",0.0)
        # compute reward
        reward, self.tickets, self.failure_memory, log_entry = compute_reward(
            prev_snap, self.snapshot, self.builder.mode,
            self.tickets, self.failure_memory,
            self.temp_tickets, self.decay_queue,
            proposed_fix=None
        )

        self.steps += 1
        done = (self.steps >= self.max_edits) or (reward < -5.0)

        obs = self._snapshot_to_obs(self.snapshot)
        return obs, reward, done, {"log": {**(log or {}), **log_entry}}

    # --------------------------------------------------------------------------
    def _snapshot_to_obs(self, snap):
        return np.array([
            snap["depth"],
            snap["node_count"],
            snap["leaf_count"],
            snap["branching_factor"],
            snap["entropy"],
            snap["weak_leaves"],
        ], dtype=np.float32)

    # --------------------------------------------------------------------------
    def update_feedback(self, episode_stats):
        # optional curriculum hooks (you already had these; preserved)
        if episode_stats.get("avg_red", 0) > 0.4:
            self.coeffs["alpha"] *= 0.9
            self.coeffs["beta"] *= 0.9

        if episode_stats.get("avg_success", 0) > 0.8:
            self.coeffs["alpha"] *= 1.1
            self.coeffs["beta"] *= 1.1

        if episode_stats.get("entropy_plateau", False):
            self.coeffs["gamma"] += 0.1

        if episode_stats.get("purple_conversions", 0) > 0.5:
            self.coeffs["delta"] += 0.1

        if episode_stats.get("loan_tickets", 0) > 3:
            self.coeffs["alpha"] *= 0.8
            self.coeffs["beta"] *= 0.8

        return {"coeffs": dict(self.coeffs), "adjustments": episode_stats}

Writing Phase2Env.py


In [None]:

# ============================================================
# Null-Shims + Safe Generative Wrapper (Phase 2 - A)
# Paste this BELOW the Phase2Env class
# ============================================================

class NullProvLogger:
    def log_request(self, *a, **k): pass
    def log_candidates(self, *a, **k): pass
    def log_scores(self, *a, **k): pass
    def log_accept(self, *a, **k): pass
    def log_fallback(self, *a, **k): pass

class NullGenerator:
    def generate(self, request, k=1):
        # returns empty candidate list (safe no-op)
        return []

class NullScorer:
    def score_batch(self, candidates, snapshot):
        # assign default neutral scores
        for c in candidates:
            c["gen_confidence"] = c.get("gen_confidence", 0.5)
            c["novelty_score"] = c.get("novelty_score", 0.0)
        return candidates
    def select_top_k(self, cands, k=3):
        return cands[:k]

class NullSimEngine:
    def dual_valence(self, snapshot, candidate):
        # pos, neg, entropy_delta, stability_delta
        return 0.1, 0.0, 0.0, 0.0

class NullVerifier:
    def check(self, snapshot, candidate):
        return True, "null"

class NullPolicy:
    def apply_ticket_penalties(self, scored, tickets, costs):
        return scored
    def select(self, scored, policy="greedy", epsilon=0.1, temperature=1.0):
        if not scored:
            return None
        return max(scored, key=lambda s: s.get("final_score", 0.0))

class NullApplyEngine:
    # Return snapshot unchanged — "no structural effect"
    def apply(self, *args, **kwargs):
        snap = args[1]  # snapshot reference
        return snap, snap, False, "noop"

class NullFallback:
    def choose(self, snapshot, scored, extras, failure_memory):
        return {"type": "noop", "reason": "no candidate"}

# ---- Safe Generative Loop Wrapper ----

def generative_decision_loop_safe(snapshot,
                                  tickets,
                                  failure_memory,
                                  config,
                                  generator=None,
                                  scorer=None,
                                  sim_engine=None,
                                  verifier=None,
                                  policy_module=None,
                                  apply_engine=None,
                                  provenance_logger=None,
                                  fallback_module=None):

    # Fill missing modules with null shims
    generator = generator or NullGenerator()
    scorer = scorer or NullScorer()
    sim_engine = sim_engine or NullSimEngine()
    verifier = verifier or NullVerifier()
    policy_module = policy_module or NullPolicy()
    apply_engine = apply_engine or NullApplyEngine()
    provenance_logger = provenance_logger or NullProvLogger()
    fallback_module = fallback_module or NullFallback()

    try:
        return generative_decision_loop(
            snapshot,
            tickets,
            failure_memory,
            config,
            generator,
            scorer,
            sim_engine,
            verifier,
            policy_module,
            apply_engine,
            provenance_logger,
            fallback_module
        )
    except Exception as e:
        # Safe fallback (no crash)
        return snapshot, {"status": "safe_fallback", "reason": str(e)}

# Decoder

In [None]:

%%writefile TLiteComponents.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import numpy as np

# Optional niceties
try:
    import nltk
    nltk.download('wordnet', quiet=True)
except Exception:
    # not fatal - tokenizer/embedding code may handle missing resources at runtime
    pass

logger = logging.getLogger(__name__)
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
logger.setLevel(logging.INFO)


# -------------------------
# Small utility / missing pieces
# -------------------------
class InputProjector(nn.Module):
    """
    Simple linear projector used to match unexpected input dims to model.expected_dim.
    """
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# -------------------------
# Token embedding wrapper (expects TokenEmbedding in your project)
# -------------------------
class EnhancedTokenEmbedding(nn.Module):
    def __init__(self, vocab, dim, device='cpu'):
        super().__init__()
        # `TokenEmbedding` is assumed to be provided elsewhere in your repo
        try:
            from tokenizer_and_embedding import TokenEmbedding
            tokens = list(vocab) if isinstance(vocab, (list, tuple, dict)) else vocab
            self.embedding = TokenEmbedding(tokens, dim, device)
        except Exception as e:
            logger.warning(f"TokenEmbedding unavailable or failed to import: {e}")
            self.embedding = None
        self.device = device
        self.dim = dim
        self.oov_vector = nn.Parameter(torch.randn(dim, device=device) * 0.01)
        logger.info(f"Initialized EnhancedTokenEmbedding (dim={dim}) on device={device}")

    def lookup(self, token: str) -> torch.Tensor:
        if self.embedding is None:
            return self.oov_vector
        try:
            vec = self.embedding.lookup(token)
            if isinstance(vec, torch.Tensor) and vec.dim() == 1:
                return vec.to(self.device)
            # if returned shape is (1, dim) or similar
            if isinstance(vec, torch.Tensor) and vec.dim() > 1:
                return vec.squeeze(0).to(self.device)
            return self.oov_vector
        except Exception as e:
            logger.warning(f"Token {token} lookup failed, using OOV: {e}")
            return self.oov_vector


# -------------------------
# Tree encoder with attention
# -------------------------
class TreeEncoderWithAttention(nn.Module):
    """
    Encodes a tree structure by recursively encoding children and running self-attention
    over the child vectors. Returns a vector of size `dim`.
    """
    def __init__(self, dim: int, num_heads: int = 4, device: str = 'cpu'):
        super().__init__()
        self.dim = dim
        self.device = device
        self.projectors = {}  # cache for per-input-dim linear projectors

        # normalize and MHA
        # MultiheadAttention in PyTorch expects embed_dim divisible by num_heads
        if dim % num_heads != 0:
            # choose nearest divisor-friendly heads
            orig_heads = num_heads
            num_heads = max(1, dim // (dim // max(1, num_heads)))
            logger.warning(f"Adjusted num_heads {orig_heads} -> {num_heads} for embed_dim={dim}")

        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.to(device)
        logger.info(f"Initialized TreeEncoderWithAttention dim={dim}, heads={num_heads} on {device}")

    def _ensure_dim(self, vec: torch.Tensor) -> torch.Tensor:
        vec = vec.to(self.device)
        if vec.shape[-1] == self.dim:
            return vec
        input_dim = vec.shape[-1]
        key = f"{input_dim}>{self.dim}"
        if key not in self.projectors:
            logger.info(f"Auto-projecting leaf vector: {input_dim} -> {self.dim}")
            self.projectors[key] = nn.Linear(input_dim, self.dim).to(self.device)
        return self.projectors[key](vec)

    def encode(self, node, get_vector_fn):
        """
        Recursively encode node. node.is_leaf() and node.get_vector() are expected.
        Returns torch.Tensor shape (dim,)
        """
        # defensive checks
        if node is None:
            return torch.zeros(self.dim, device=self.device)

        if getattr(node, "is_leaf", None) and node.is_leaf():
            vec = get_vector_fn(node)
            if vec is None:
                return torch.zeros(self.dim, device=self.device)
            if not isinstance(vec, torch.Tensor):
                vec = torch.tensor(np.asarray(vec), dtype=torch.float32, device=self.device)
            return self._ensure_dim(vec)

        # gather child vectors
        vectors = []
        for child in getattr(node, "children", []):
            try:
                vec = self.encode(child, get_vector_fn)
            except Exception as e:
                logger.debug(f"child encode failed: {e}")
                vec = torch.zeros(self.dim, device=self.device)
            if vec is not None:
                vectors.append(vec)

        if not vectors:
            return torch.zeros(self.dim, device=self.device)

        # stack into shape (batch=1, seq_len, embed_dim) for batch_first attention
        stacked = torch.stack(vectors, dim=0).unsqueeze(0)  # [1, seq_len, dim]
        attn_output, _ = self.attention(stacked, stacked, stacked)  # returns [1, seq_len, dim]
        pooled = attn_output.mean(dim=1).squeeze(0)  # [dim]
        return self.norm(pooled)

    def forward(self, node, get_vector_fn=lambda n: n.get_vector()):
        return self.encode(node, get_vector_fn)


# -------------------------
# TLite modules (V4, V5, V6)
# -------------------------
class TLiteV4_SearchEncoder(nn.Module):
    """
    Small MLP that reduces embedding to a non-negative score.
    Expects input vector size `expected_dim` (if not, will project).
    """
    def __init__(self, dim=50, hidden_dim=256, device='cpu'):
        super().__init__()
        self.expected_dim = dim
        self.device = device
        self.projector = None
        self.encoder = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Softplus()
        )
        self.to(device)
        logger.info(f"Initialized TLiteV4_SearchEncoder dim={dim} on {device}")

    def forward(self, x: torch.Tensor):
        x = x.to(self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)  # make batch
        if x.shape[-1] != self.expected_dim:
            if self.projector is None:
                self.projector = InputProjector(x.shape[-1], self.expected_dim).to(self.device)
                logger.info(f"Created projector for TLiteV4: {x.shape[-1]} -> {self.expected_dim}")
            x = self.projector(x)
        out = self.encoder(x)  # [batch, 1]
        return out.squeeze(-1)  # [batch]


class TLiteV5_ReasoningModule(nn.Module):
    """
    Residual stack reasoning head: input -> depth of residual blocks -> scalar head.
    """
    def __init__(self, dim=50, hidden_dim=128, depth=4, device='cpu'):
        super().__init__()
        self.expected_dim = dim
        self.device = device
        self.projector = None
        layers = []
        for _ in range(depth):
            layers.append(nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, dim)
            ))
        self.layers = nn.ModuleList(layers)
        self.final_norm = nn.LayerNorm(dim)
        self.head = nn.Sequential(nn.Linear(dim, 1), nn.Softplus())
        self.to(device)
        logger.info(f"Initialized TLiteV5_ReasoningModule dim={dim}, depth={depth} on {device}")

    def forward(self, x: torch.Tensor):
        x = x.to(self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.shape[-1] != self.expected_dim:
            if self.projector is None:
                self.projector = InputProjector(x.shape[-1], self.expected_dim).to(self.device)
                logger.info(f"Created projector for TLiteV5: {x.shape[-1]} -> {self.expected_dim}")
            x = self.projector(x)
        # residual stack
        for layer in self.layers:
            x = x + layer(x)
        x = self.final_norm(x)
        out = self.head(x)
        return out.squeeze(-1)  # [batch]


class TLiteExpert(nn.Module):
    def __init__(self, dim=50, hidden_dim=64, device='cpu'):
        super().__init__()
        self.device = device
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.to(device)
        logger.info(f"Initialized TLiteExpert dim={dim} on {device}")

    def forward(self, x: torch.Tensor):
        x = x.to(self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)
            squeezed = True
        else:
            squeezed = False
        y = self.norm(x)
        y = F.gelu(self.fc1(y))
        y = self.fc2(y)
        if squeezed:
            return y.squeeze(0)
        return y


class TLiteRouter(nn.Module):
    def __init__(self, dim=50, num_experts=8, top_k=2, device='cpu'):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)
        self.num_experts = num_experts
        self.top_k = top_k
        self.device = device
        self.to(device)
        logger.info(f"Initialized TLiteRouter num_experts={num_experts}, top_k={top_k} on {device}")

    def forward(self, x: torch.Tensor):
        """
        x: [batch, dim]
        returns topk_indices [batch, top_k], topk_weights [batch, top_k]
        """
        x = x.to(self.device)
        logits = self.gate(x)  # [batch, num_experts]
        topk_scores, topk_indices = torch.topk(logits, k=self.top_k, dim=-1)  # both [batch, top_k]
        topk_weights = F.softmax(topk_scores, dim=-1)  # [batch, top_k]
        return topk_indices, topk_weights


class TLiteV6(nn.Module):
    """
    Mixture of small experts. Input: [batch, dim] -> produces scalar per batch.
    """
    def __init__(self, dim=48, hidden_dim=64, num_experts=8, top_k=2, device='cpu'):
        super().__init__()
        self.device = device
        self.num_experts = num_experts
        self.top_k = top_k
        # experts return [batch, dim]
        self.experts = nn.ModuleList([TLiteExpert(dim, hidden_dim, device) for _ in range(num_experts)])
        self.router = TLiteRouter(dim, num_experts, top_k, device)
        self.final_head = nn.Sequential(nn.Linear(dim, 1), nn.Softplus())
        self.to(device)
        logger.info(f"Initialized TLiteV6 num_experts={num_experts} on {device}")

    def forward(self, x: torch.Tensor):
        """
        x: [batch, dim] or [dim] (will be batched)
        returns: [batch] scalar scores
        """
        x = x.to(self.device)
        if x.dim() == 1:
            x = x.unsqueeze(0)
            squeezed = True
        else:
            squeezed = False

        batch_size = x.shape[0]
        topk_indices, topk_weights = self.router(x)  # [batch, top_k], [batch, top_k]

        # aggregate expert outputs
        # prepare a container with same shape as x to accumulate weighted expert outputs
        out = torch.zeros_like(x, device=self.device)  # [batch, dim]

        for b in range(batch_size):
            for i in range(self.top_k):
                idx = int(topk_indices[b, i].item())
                weight = topk_weights[b, i]
                expert_out = self.experts[idx](x[b])  # expert_out: [dim] or [1,dim] squeezed to [dim]
                if expert_out.dim() == 1:
                    expert_vec = expert_out
                else:
                    expert_vec = expert_out.squeeze(0)
                out[b] += weight * expert_vec

        final = self.final_head(out)  # [batch,1]
        final = final.squeeze(-1)  # [batch]
        if squeezed:
            return final.squeeze(0)
        return final

Writing TLiteComponents.py


In [None]:

%%writefile /content/DecoderV1.py
import logging
from TreeBuilderV2 import TreeBuilderV2
from TreeSnapshot import TreeSnapshot
from TLiteComponents import TreeEncoderWithAttention, TLiteV6

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class DecoderV1:
    def __init__(self, dim=48, device="cpu", use_smoother=True):
        self.encoder = TreeEncoderWithAttention(dim=dim, device=device)
        self.smoother = TLiteV6(dim=dim, device=device) if use_smoother else None
        self.device = device

    def explain(self, root_node):
        if root_node is None:
            return "No tree available to explain."

        # 1. Structural snapshot
        snapshot = TreeSnapshot(root_node)
        stats = snapshot.to_dict()

        # 2. Path extraction
        builder = TreeBuilderV2(mode="binary")  # for tracing
        path = builder.trace_dfs(root_node)

        # 3. Encode tree vector
        vec = self.encoder(root_node, get_vector_fn=lambda n: n.get_vector())
        vec_str = f"[encoded-vector norm={vec.norm().item():.3f}]"

        # 4. Construct raw explanation
        explanation = (
            f"Tree has depth {stats['depth']} with {stats['node_count']} nodes "
            f"({stats['leaf_count']} leaves). Branching factor={stats['branching_factor']}, "
            f"entropy={stats['entropy']}. Weak leaves={stats['weak_leaves']}.\n"
            f"Traversal path: {' -> '.join(str(x) for x in path)}.\n"
            f"Neural embedding summary: {vec_str}"
        )

        # 5. Optional smoother
        if self.smoother:
            try:
                vec_in = vec.unsqueeze(0)  # batchify
                smooth_score = self.smoother(vec_in).item()
                explanation += f"\n[Polished score={smooth_score:.3f}]"
            except Exception as e:
                logger.warning(f"Smoother failed: {e}")

        return explanation

Writing /content/DecoderV1.py


# Phase-4 : RLTeacher prototype




In [None]:
import random

class RLTeacher:
    def __init__(self):
        self.history = []

    # Step 1: Propose candidate actions
    def propose_actions(self, anchors):
        actions = []
        for anchor in anchors:
            actions.append({
                "type": "split",
                "target": anchor["id"],
                "anchor_type": anchor["type"]
            })
            actions.append({
                "type": "merge",
                "target": anchor["id"],
                "anchor_type": anchor["type"]
            })
        return actions

    # Step 2: Evaluate actions with a simple critic
    def evaluate_actions(self, state, actions):
        scored = []
        for action in actions:
            # Dummy critic: random quality for now
            est_quality = random.uniform(0, 1)
            score = {
                "est_pos": est_quality,
                "est_neg": 1 - est_quality,
                "est_quality": est_quality
            }
            scored.append({"action": action, "score": score})
        return scored

    # Step 3: Select best action (greedy)
    def select_action(self, scored_actions):
        if not scored_actions:
            return None, None
        sorted_actions = sorted(
            scored_actions,
            key=lambda x: x["score"].get("est_quality", 0.0),
            reverse=True
        )
        best = sorted_actions[0]
        return best["action"], best["score"]

    # Step 4: Apply action (stub)
    def apply_action(self, state, action):
        new_state = state.copy()
        new_state["applied_action"] = action["type"]
        new_state["applied_target"] = action["target"]
        return new_state

    # Step 5: Log decision
    def log_decision(self, state, action, score, explanation):
        self.history.append({
            "state": state,
            "action": action,
            "score": score,
            "explanation": explanation
        })

    # Step 6: Adaptive run loop
    def run(self, snapshot, anchors, retrievals,
            base_steps=2, max_steps=6, success_threshold=0.7):
        steps_remaining = base_steps
        current_state = snapshot.copy()

        while steps_remaining > 0 and len(self.history) < max_steps:
            actions = self.propose_actions(anchors)
            scored = self.evaluate_actions(current_state, actions)
            action, score = self.select_action(scored)

            if not action:
                break

            new_state = self.apply_action(current_state, action)

            # Simple explanation (stub)
            explanation = (
                f"Applied {action['type']} on {action['target']} "
                f"with quality {score['est_quality']:.2f}"
            )

            self.log_decision(new_state, action, score, explanation)

            # Adaptive adjustment
            if score["est_quality"] >= success_threshold:
                steps_remaining += 1
            else:
                steps_remaining -= 1

            current_state = new_state

        return self.history

In [None]:

class FailureMemory:
    """Persistent memory for clustered failures and their solutions."""
    def __init__(self, filepath="failure_memory.json"):
        self.filepath = filepath
        self.memory = self._load()

    def _load(self):
        if os.path.exists(self.filepath):
            with open(self.filepath, "r") as f:
                return json.load(f)
        return {}

    def _save(self):
        with open(self.filepath, "w") as f:
            json.dump(self.memory, f, indent=2)

    def record_failure(self, signature, snapshot_dict, details):
        """
        Record a failure into its cluster.
        signature: str (stable hash of failure pattern)
        snapshot_dict: dict from TreeSnapshot.to_dict()
        details: dict with error_type, node_id, etc.
        """
        cluster = self.memory.get(signature, {
            "signature": signature,
            "problems": [],
            "repeat_count": 0,
            "impact_score": 0.0,
            "solution": None,
            "stats": {"last_seen": None, "resolved": False, "uses": 0}
        })

        cluster["problems"].append({
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "details": details
        })
        cluster["repeat_count"] += 1
        cluster["stats"]["last_seen"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

        # Compute impact score using snapshot metrics
        eis = snapshot_dict.get("entropy", 0.0)
        depth = snapshot_dict.get("depth", 1)
        max_depth = max(1, snapshot_dict.get("node_count", 1))
        cluster["impact_score"] = 0.5*(1 - depth/max_depth) + 0.3*math.log(1+cluster["repeat_count"]) + 0.2*eis

        self.memory[signature] = cluster
        self._save()
        return cluster

    def update_solution(self, signature, natural_text, patch, quality, provenance):
        """Attach a verified solution to a cluster."""
        cluster = self.memory.get(signature)
        if not cluster:
            return None
        cluster["solution"] = {
            "natural_text": natural_text,
            "patch": patch,
            "quality": quality,
            "provenance": provenance
        }
        cluster["stats"]["resolved"] = quality >= 0.6
        self.memory[signature] = cluster
        self._save()
        return cluster

    def reuse_solution(self, signature):
        """Fetch a stored solution if available."""
        cluster = self.memory.get(signature)
        if cluster and cluster.get("solution"):
            cluster["stats"]["uses"] += 1
            self._save()
            return cluster["solution"]
        return None

    def get_penalty(self, signature):
        """Penalty scales with impact score instead of flat count."""
        cluster = self.memory.get(signature)
        if not cluster:
            return 0.0
        return 0.5 * cluster["impact_score"]  # tune factor

In [None]:
%%writefile update_failure_memory.py
import math
import time

def update_failure_memory(failure_memory, failure_sig, eis, proposed_fix=None):
    """
    Update or create a failure cluster entry.
    Returns: (updated_failure_memory, outcome, lesson_flag, insight_flag)
    """

    now = time.time()
    lesson_flag, insight_flag = 0, 0

    if failure_sig in failure_memory:
        # --- Existing cluster ---
        cluster = failure_memory[failure_sig]
        cluster["repeat_count"] += 1
        cluster["stats"]["last_seen"] = now

        # Compute Impact Score (IS)
        is_val = 0.5*(1 - cluster.get("tree_score", 0.5)) \
                 + 0.3*math.log(1+cluster["repeat_count"]) \
                 + 0.2*eis
        cluster["impact_score"] = round(is_val, 3)

        # Try to resolve with proposed fix
        if proposed_fix:
            quality = proposed_fix.get("quality", 0.0)
            thresh = max(0.5, 0.7 - 0.2*eis)
            if quality > thresh:
                cluster["solution"] = proposed_fix
                cluster["resolved"] = True
                insight_flag = 1   # Purple ticket trigger
                outcome = "resolved_cluster"
            else:
                outcome = "repeat_fail"
                lesson_flag = 1    # Yellow ticket trigger
        else:
            outcome = "repeat_fail"
            lesson_flag = 1

    else:
        # --- New cluster ---
        failure_memory[failure_sig] = {
            "signature": failure_sig,
            "repeat_count": 1,
            "impact_score": 0.2*eis,
            "solution": None,
            "resolved": False,
            "stats": {"created": now, "last_seen": now}
        }
        outcome = "unique_fail"
        lesson_flag = 1

    return failure_memory, outcome, lesson_flag, insight_flag

Writing update_failure_memory.py


In [None]:

from rl_critic import score_candidate
from TreeSnapshot import TreeSnapshot

class DualValenceCritic:
    """
    Dual‑valence critic that uses TreeSnapshot features + rl_critic.score_candidate.
    Adjusts quality with failure memory penalties.
    """
    def __init__(self, failure_memory, pos_weight=1.0, neg_weight=1.0):
        self.failure_memory = failure_memory
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight

    def optimistic_sim(self, snapshot_dict, action):
        """
        Optimistic sim: use rl_critic.score_candidate with normal weights.
        """
        return score_candidate(snapshot_dict, action)["est_pos"]

    def pessimistic_sim(self, snapshot_dict, action):
        """
        Pessimistic sim: invert perspective by treating negatives more heavily.
        """
        return score_candidate(snapshot_dict, action)["est_neg"]

    def evaluate(self, tree, action):
        """
        Evaluate an action on a given tree snapshot.
        """
        # Build snapshot features
        snapshot = TreeSnapshot(tree).to_dict()

        # Run dual‑valence sims
        pos = self.optimistic_sim(snapshot, action)
        neg = self.pessimistic_sim(snapshot, action)
        raw_quality = self.pos_weight * pos - self.neg_weight * neg

        # Action signature for memory
        signature = f"{action.get('type','unknown')}_{action.get('target','na')}"
        penalty = self.failure_memory.get_penalty(signature)
        adjusted_quality = raw_quality - penalty

        return {
            "action_signature": signature,
            "pos_score": pos,
            "neg_score": neg,
            "raw_quality": raw_quality,
            "penalty": penalty,
            "adjusted_quality": adjusted_quality,
            "snapshot": snapshot
        }

In [None]:

import random

class PolicyModule:
    """
    Chooses actions based on critic scores, curriculum thresholds, and exploration.
    - Curriculum gating: structural always; rule/generative unlocked per domain thresholds.
    - Curriculum pruning: cap at max_actions (default 32).
    - Selection: greedy by max pos_score, then adjusted_quality; ε-greedy optional.
    - Rationale includes thresholds, failure-memory effects, and domain modifiers.
    """
    def __init__(self, domain_thresholds, epsilon=0.0, max_actions=32, domain_modifiers=None):
        self.domain_thresholds = domain_thresholds
        self.epsilon = epsilon
        self.max_actions = max_actions
        # Domain modifiers: small adjustments to adjusted_quality
        # e.g. {"math": +0.05, "code": +0.02, "science": 0.0, "chess": -0.01}
        self.domain_modifiers = domain_modifiers or {}

    def _curriculum_allowed(self, domain, stability_score, action_type):
        th = self.domain_thresholds.get(domain, {})
        if action_type == "structural":
            return True
        if action_type == "rule":
            return stability_score >= th.get("rule", 1.0)
        if action_type == "generative":
            return stability_score >= th.get("generative", 1.0)
        return False

    def filter_by_curriculum(self, domain, stability_score, scored_candidates):
        allowed, locked_out = [], []
        for c in scored_candidates:
            a_type = c["action"].get("type", "structural")
            if self._curriculum_allowed(domain, stability_score, a_type):
                allowed.append(c)
            else:
                locked_out.append(c)

        # 🔒 Curriculum pruning: cap at max_actions
        if len(allowed) > self.max_actions:
            allowed = sorted(
                allowed,
                key=lambda c: (c["scores"]["pos_score"], c["scores"]["adjusted_quality"]),
                reverse=True
            )[:self.max_actions]

        return allowed, locked_out

    def _apply_domain_modifier(self, domain, score):
        """Apply small domain-specific adjustment to adjusted_quality."""
        return score + self.domain_modifiers.get(domain, 0.0)

    def choose_action(self, domain, stability_score, scored_candidates):
        allowed, locked_out = self.filter_by_curriculum(domain, stability_score, scored_candidates)
        if not allowed:
            rationale = f"No unlocked actions for domain={domain} at stability={stability_score:.2f}"
            return None, rationale

        # ε-greedy exploration
        if random.random() < self.epsilon:
            chosen = random.choice(allowed)
            rationale = (
                f"Exploration (ε={self.epsilon}): sampled among unlocked actions. "
                f"Stability={stability_score:.2f}, domain={domain}."
            )
            return chosen, rationale

        # Greedy: sort by (pos_score, adjusted_quality + domain_modifier), descending
        chosen = sorted(
            allowed,
            key=lambda c: (
                c["scores"]["pos_score"],
                self._apply_domain_modifier(domain, c["scores"]["adjusted_quality"])
            ),
            reverse=True
        )[0]

        # Build rationale including curriculum, failure-memory, and domain modifier
        s = chosen["scores"]
        a = chosen["action"]
        fm_influence = ("yes" if (s.get("penalty", 0.0) or 0.0) > 0 else "no")
        dom_mod = self.domain_modifiers.get(domain, 0.0)
        rationale = (
            f"Greedy selection: max pos_score then adjusted_quality. "
            f"Action={a.get('type','?')} on target={a.get('target','?')}. "
            f"Curriculum: unlocked for domain={domain} at stability={stability_score:.2f}. "
            f"Failure-memory influence={fm_influence} (penalty={s.get('penalty',0.0):.3f}). "
            f"Domain modifier applied={dom_mod:+.3f}."
        )
        return chosen, rationale

In [None]:

# ✅ Define thresholds outside the class
domain_thresholds = {
    "text": {"rule": 0.70, "generative": 0.90},
    "code": {"rule": 0.60, "generative": 0.80},
    "math": {"rule": 0.65, "generative": 0.85}
}

In [None]:

%%writefile check_action_allowed.py
ROOT_IDS = {"prakriti_global"}  # do not allow structural edits on root

def check_action_allowed(action, tickets):
    """
    Enforce ticket gating rules before executing an action.
    Returns (allowed: bool, cost: dict, reason: str)
    """
    op_type, target_id, _ = action

    # --- hard protection on root ---
    if target_id in ROOT_IDS and op_type in ("prune", "split", "reorder"):
        return False, {}, "Root is protected"

    cost = {}
    reason = "ok"

    if op_type in ("prune", "split", "reorder"):
        # Structural actions
        if tickets.get("B", 0) >= 1:
            cost = {"B": 1}
        else:
            return False, {}, "Insufficient Blue tickets"

    elif op_type in ("lock", "unlock"):
        # Rule actions
        if tickets.get("G", 0) >= 5:
            cost = {"G": 2}
        else:
            return False, {}, "Need ≥5 Green tickets"

    elif op_type in ("retrieve", "generate"):
        # Generative actions
        if (tickets.get("G", 0) >= 10 and tickets.get("Y", 0) >= 1):
            cost = {"G": 3, "Y": 1}
        elif tickets.get("P", 0) >= 2:
            cost = {"P": 2}  # Purple shortcut
        else:
            return False, {}, "Not enough tickets for generative action"

    else:
        return False, {}, "Unknown action"

    return True, cost, reason

Writing check_action_allowed.py


In [None]:

%%writefile apply_action.py
from LockManager import LockManager
from TreeNodeV1 import TreeNodeV1

# protect root from destructive structural edits
ROOT_IDS = {"prakriti_global"}

lock_manager = LockManager()

def apply_action(tree, action, embedding_index=None):
    """
    Apply an action to the tree.
    Supports structural (split, prune, lock/unlock, reorder) and generative (retrieve, generate).

    Args:
        tree: TreeNodeV1 root
        action: tuple (op_type, target_id, extra)
        embedding_index: optional EmbeddingIndex for retrieval
    """
    op_type, target_id, extra = action
    log = {"action": op_type, "target": target_id, "extra": extra, "status": "ok"}

    # --- hard protection: never mutate the root ---
    if target_id in ROOT_IDS and op_type in ("prune", "split", "reorder"):
        log["status"] = "blocked_root"
        return log

    # --- Structural actions ---
    if op_type == "prune":
        # remove children of target node
        node = find_node(tree, target_id)
        if node:
            node.children = []
        else:
            log["status"] = "fail"

    elif op_type == "split":
        # placeholder: split node into two children
        node = find_node(tree, target_id)
        if node:
            left = TreeNodeV1(node_id=f"{target_id}_L", value="split_left", level=node.level+1)
            right = TreeNodeV1(node_id=f"{target_id}_R", value="split_right", level=node.level+1)
            node.children = [left, right]
        else:
            log["status"] = "fail"

    elif op_type == "reorder":
        node = find_node(tree, target_id)
        if node and len(node.children) > 1:
            # simple deterministic rotate (stable)
            node.children = node.children[1:] + node.children[:1]
        else:
            log["status"] = "fail"

    elif op_type == "lock":
        node = find_node(tree, target_id)
        if node:
            node.lock(mode=extra or "soft")
        else:
            log["status"] = "fail"

    elif op_type == "unlock":
        node = find_node(tree, target_id)
        if node:
            node.unlock()
        else:
            log["status"] = "fail"

    # --- Generative actions ---
    elif op_type == "retrieve":
        if embedding_index and isinstance(extra, str):
            results = embedding_index.query(extra, k=3)
            log["retrievals"] = results
        else:
            log["status"] = "fail"

    elif op_type == "generate":
        # (keep as-is if your workspace has the decoder pieces wired)
        try:
            snapshot = TreeSnapshot(tree).to_dict()
            anchors = extract_anchors(snapshot)
            retrievals = []
            if embedding_index and extra:
                retrievals = embedding_index.query(extra, k=1)
            explanation = decode_snapshot(snapshot, anchors, retrievals)
            explanation = smooth_text(explanation)
            node = find_node(tree, target_id)
            if node:
                node.value = f"{node.value} :: {explanation}"
                log["generated"] = explanation
            else:
                log["status"] = "fail"
        except Exception as e:
            log["status"] = f"generate_fail: {e}"

    else:
        log["status"] = "unknown_action"

    return log


def find_node(root, node_id):
    """DFS search for a node by id."""
    if getattr(root, "node_id", None) == node_id:
        return root
    for child in getattr(root, "children", []) or []:
        found = find_node(child, node_id)
        if found:
            return found
    return None

Writing apply_action.py


In [None]:

import time
import json

class Logger:
    """
    Logs each decision with scores, rationale, curriculum gating,
    failure memory, and provenance.
    Outputs JSONL records for easy review.
    """
    def __init__(self, filepath="rlteacher_log.jsonl"):
        self.filepath = filepath

    def log_decision(self, step_idx, state, action, scores, rationale,
                     curriculum_status=None, failure_memory_penalty=None,
                     provenance=None, bestaction_flag=True):
        record = {
            "timestamp": time.time(),
            "step": step_idx,
            "state": state,              # snapshot features
            "action": action,            # chosen action
            "scores": scores,            # pos, neg, raw, adjusted
            "rationale": rationale,      # why chosen
            "curriculum_status": curriculum_status or {},
            "failure_memory_penalty": failure_memory_penalty or 0.0,
            "provenance": provenance or {},
            "bestaction_flag": bool(bestaction_flag)  # ✅ new field
        }
        with open(self.filepath, "a") as f:
            f.write(json.dumps(record) + "\n")
        return record

def explain_choice(chosen, rationale, curriculum_status, fm_penalty):
    """
    Generate a reviewer-friendly explanation string.
    """
    sig = chosen["scores"]["action_signature"]
    pos = round(chosen["scores"]["pos_score"], 3)
    neg = round(chosen["scores"]["neg_score"], 3)
    adj = round(chosen["scores"]["adjusted_quality"], 3)
    return (
        f"Action {sig} chosen with pos={pos}, neg={neg}, adj={adj}. "
        f"Reason: {rationale}. "
        f"Curriculum: {curriculum_status}. "
        f"Failure memory penalty={fm_penalty:.3f}. "
        f"BestActionFlag=True"
    )

In [None]:
def normalize_action_type(a_type: str) -> str:
    """
    Normalize raw action types into canonical categories:
    - split/merge → structural
    - lock/unlock → rule
    - generative → generative
    - anything else → structural (default fallback)
    """
    if not a_type:
        return "structural"
    a_type = a_type.lower()
    if a_type in ["split", "merge"]:
        return "structural"
    if a_type in ["lock", "unlock"]:
        return "rule"
    if a_type == "generative":
        return "generative"
    return "structural"

In [None]:

from decoder import decode_snapshot   # TLiteV6 Decoder
from smoother import smooth_text      # TLiteV6 Language Polisher

def build_explanation(snapshot, anchors, retrievals=None):
    """
    Convert tree snapshot → structured reasoning → polished natural language.
    snapshot: dict from TreeSnapshot(root).to_dict()
    anchors: dict anchor group metadata
    retrievals: optional retrieval evidence (list)
    """
    if retrievals is None:
        retrievals = []

    # Step 1 — Convert symbolic structure -> rough explanation
    raw_text = decode_snapshot(snapshot, anchors, retrievals)

    # Step 2 — Smooth grammar + clarity (does not change meaning)
    final_text = smooth_text(raw_text)

    return final_text

In [None]:

def coeff_a(eis: float) -> float:
    """Positive weight a(EIS) = 1 + 0.3*EIS"""
    return 1.0 + 0.3 * eis

def coeff_b(eis: float) -> float:
    """Negative weight β(EIS) = 1 - 0.2*EIS"""
    return 1.0 - 0.2 * eis

def coeff_gamma(eis: float) -> float:
    """Lesson weight γ(EIS) = 1 + 0.5*EIS"""
    return 1.0 + 0.5 * eis

def coeff_delta(eis: float) -> float:
    """Insight weight δ(EIS) = 1 + 0.6*EIS"""
    return 1.0 + 0.6 * eis

In [None]:

import math

def compute_impact_score(atree_score: float,
                         max_depth: int,
                         repeat_count: int,
                         eis: float) -> float:
    """
    Compute Impact Score (IS) for a failure cluster.
    IS = 0.5*(1 - ATreeScore/max_depth)
       + 0.3*log(1 + repeat_count)
       + 0.2*EIS
    """
    depth_term = 1.0 - (atree_score / max(1, max_depth))
    repeat_term = math.log(1 + repeat_count)
    eis_term = eis

    is_val = 0.5 * depth_term + 0.3 * repeat_term + 0.2 * eis
    return round(is_val, 4)

In [None]:

%%writefile compute_reward.py
import math, time
import numpy as np

# ---------- helpers ----------
def _safe_float(x, default=0.0):
    try:
        return float(x)
    except Exception:
        try:
            return float(default)
        except Exception:
            return 0.0

def _clip(x, lo, hi):
    x = _safe_float(x, 0.0)
    return max(lo, min(hi, x))

def _sanitize_snapshot(snap):
    """Ensure every field compute_reward relies on is a real number."""
    if snap is None: snap = {}
    s = dict(snap)

    # structural
    s["entropy"]           = _clip(s.get("entropy", 0.0), 0.0, 32.0)
    s["branch_flip_rate"]  = _clip(s.get("branch_flip_rate", 0.0) or 0.0, 0.0, 1.0)
    s["depth"]             = max(0, int(_safe_float(s.get("depth", 0), 0)))
    s["node_count"]        = max(1, int(_safe_float(s.get("node_count", 1), 1)))
    s["leaf_count"]        = max(0, int(_safe_float(s.get("leaf_count", 0), 0)))
    s["branching_factor"]  = _clip(s.get("branching_factor", 0.0) or 0.0, 0.0, 64.0)

    # shaping (optional; default to 0.0)
    s["pos_score"]         = _safe_float(s.get("pos_score", 0.0), 0.0)
    s["neg_score"]         = _safe_float(s.get("neg_score", 0.0), 0.0)

    # outcome flags (optional)
    if "outcome" not in s:
        s["outcome"] = "neutral"

    return s

def _sanitize_tickets(t):
    base = {"G":0,"B":0,"Y":0,"R":0,"P":0}
    if not isinstance(t, dict): return base
    out = {}
    for k in base:
        try: out[k] = int(t.get(k, 0))
        except Exception: out[k] = 0
    return out

def _sanitize_temp_tickets(t):
    base = {"G":0,"B":0,"Y":0,"R":0}
    if not isinstance(t, dict): return base
    out = {}
    for k in base:
        try: out[k] = int(t.get(k, 0))
        except Exception: out[k] = 0
    return out

# ---------- main ----------
def compute_reward(prev_snap, new_snap, mode,
                   tickets, failure_memory,
                   temp_tickets, decay_queue,
                   proposed_fix=None):
    """
    A-CRES v1.3 with failure-memory + economy + shaping.
    Returns: (reward, tickets, failure_memory, log)
    """

    # 0) sanitize inputs
    prev = _sanitize_snapshot(prev_snap)
    new  = _sanitize_snapshot(new_snap)
    tickets       = _sanitize_tickets(tickets)
    temp_tickets  = _sanitize_temp_tickets(temp_tickets)
    if decay_queue is None: decay_queue = []
    if failure_memory is None: failure_memory = {}

    # 1) EIS
    entropy = new["entropy"]
    stability = 1.0 - new["branch_flip_rate"]
    depth_ratio = new["depth"] / max(1, new["node_count"])
    EIS = 0.4*entropy + 0.3*(1.0 - stability) + 0.3*depth_ratio
    EIS = float(_clip(EIS, 0.0, 4.0))

    # 2) Failure memory
    failure_sig = new.get("failure_sig", None)
    lesson_flag, insight_flag = 0, 0
    outcome = new.get("outcome", "neutral")

    IS = 0.0
    if failure_sig:
        cluster = failure_memory.get(failure_sig)
        if cluster:
            cluster["repeat_count"] = int(cluster.get("repeat_count", 0)) + 1
            stats = cluster.setdefault("stats", {})
            stats["last_seen"] = time.time()

            # Impact Score
            denom = max(1, new["depth"])
            IS = 0.5*(1.0 - stability/denom) + 0.3*math.log(1 + cluster["repeat_count"]) + 0.2*EIS
            IS = float(_clip(IS, 0.0, 10.0))
            cluster["impact_score"] = round(IS, 3)

            # proposed fix path
            if proposed_fix:
                quality = _safe_float(proposed_fix.get("quality", 0.0), 0.0)
                thresh = max(0.5, 0.7 - 0.2*EIS)
                if quality > thresh:
                    cluster["solution"] = proposed_fix
                    cluster["resolved"] = True
                    outcome = "resolved_cluster"
                    insight_flag = 1
                else:
                    outcome = "repeat_fail"
                    lesson_flag = 1
            else:
                outcome = "repeat_fail"
                lesson_flag = 1
        else:
            failure_memory[failure_sig] = {
                "signature": failure_sig,
                "repeat_count": 1,
                "impact_score": round(0.2*EIS, 3),
                "solution": None,
                "resolved": False,
                "stats": {"created": time.time(), "last_seen": time.time()}
            }
            outcome = "unique_fail"
            lesson_flag = 1
    else:
        IS = float(round(0.2*EIS, 3))

    # 3) Adaptive coeffs
    alpha = 1 + 0.3*EIS
    beta  = 1 - 0.2*EIS
    gamma = 1 + 0.5*EIS
    delta = 1 + 0.6*EIS

    # 4) Ticket updates
    if outcome == "success":
        tickets["G"] += int(round(3 * alpha))
    elif outcome == "neutral":
        tickets["B"] += max(1, int(round(2 * max(0.1, beta))))
    elif outcome == "unique_fail":
        tickets["Y"] += int(round(1 * gamma))
    elif outcome == "repeat_fail":
        tickets["R"] += int(round(max(1, beta)))
    elif outcome == "resolved_cluster":
        tickets["P"] += 1

    # Purple conversions/loans (optional flags in new snapshot)
    if tickets.get("P",0) >= 1 and new.get("convert_purple", False):
        tickets["P"] -= 1
        target = new.get("convert_target", "G")
        tickets[target] = tickets.get(target, 0) + 3

    if tickets.get("P",0) >= 1 and new.get("loan_purple", False):
        tickets["P"] -= 1
        target = new.get("loan_target", "G")
        temp_tickets[target] = temp_tickets.get(target, 0) + 1
        decay_queue.append(target)

    # Decay one temporary ticket per step (if any)
    if decay_queue:
        expired = decay_queue.pop(0)
        temp_tickets[expired] = max(0, temp_tickets.get(expired, 0) - 1)

    # 5) Reward (use shaping from new snapshot; may be 0.0 if absent)
    pos = _safe_float(new.get("pos_score", 0.0), 0.0)
    neg = _safe_float(new.get("neg_score", 0.0), 0.0)
    reward = alpha*pos - beta*neg + gamma*lesson_flag + delta*insight_flag - 2.0*IS
    reward = float(round(reward, 6))

    # 6) Log
    log_entry = {
        "EIS": round(EIS, 3),
        "IS": round(IS, 3),
        "coeffs": {"alpha": round(alpha,3), "beta": round(beta,3),
                   "gamma": round(gamma,3), "delta": round(delta,3)},
        "tickets": dict(tickets),
        "temp_tickets": dict(temp_tickets),
        "outcome": outcome,
        "pos": round(pos, 6), "neg": round(neg, 6),
        "lesson": int(lesson_flag), "insight": int(insight_flag),
        "reward": round(reward, 6),
    }
    return reward, tickets, failure_memory, log_entry

Writing compute_reward.py


In [None]:
import uuid
import time
from typing import Dict, List

class GeneratorAdapter:
    def __init__(self, model_name="tlite-v4.1", backend=None):
        self.model_name = model_name
        self.backend = backend  # hook to TLite or HF model

    def build_prompt(self, request: Dict) -> str:
        """Format prompt based on request type."""
        pt = request.get("prompt_type", "subtree_expand")
        snapshot = request.get("snapshot", {})
        anchors = snapshot.get("anchors", [])
        constraints = request.get("constraints", {})

        if pt == "subtree_expand":
            return f"""CONTEXT:
{snapshot.get('tree','')}
ANCHOR: {anchors[0] if anchors else 'root'}
CONSTRAINTS: {constraints}
TASK: Propose up to {request.get('k',3)} valid subtree expansions that reduce entropy and improve stability.
Provide both: (a) structured subtree JSON, (b) natural explanation."""

        elif pt == "code_fix":
            return f"""CONTEXT:
{snapshot.get('tree','')}
ERROR: {constraints.get('error_trace','')}
TASK: Suggest a minimal patch (<=10 lines) to fix the bug and explain reasoning in 2 sentences.
Provide patch as unified diff and as AST JSON."""

        elif pt == "rule_synth":
            return f"""CONTEXT:
{constraints.get('example_pairs','')}
TASK: Propose a concise rule (pseudo-code) that generalizes the transformation.
Format: rule_name, pattern, replacement, guard_conditions, cost."""

        else:
            return f"Default prompt: expand snapshot {snapshot.get('id','unknown')}"

    def generate(self, request: Dict, k: int = 3) -> List[Dict]:
        """Generate candidates using backend model."""
        prompt = self.build_prompt(request)
        start = time.time()

        # --- Call backend (stubbed here) ---
        # Replace with actual model call: self.backend.generate(prompt, k=k)
        outputs = [f"Mock output {i} for {prompt[:40]}..." for i in range(k)]

        latency = int((time.time() - start) * 1000)
        candidates = []
        for i, out in enumerate(outputs):
            candidates.append({
                "candidate_id": f"cand-{uuid.uuid4().hex[:6]}",
                "request_id": request["request_id"],
                "generated_subtree": {"node_id": f"gen_{i}", "value": "mock"},
                "natural_text": out,
                "gen_confidence": 0.7 + 0.05*i,
                "generator_version": self.model_name,
                "generation_latency_ms": latency
            })
        return candidates

In [None]:

import hashlib
import numpy as np
from typing import List, Dict

class QuickScorer:
    def __init__(self, novelty_cache=None):
        self.novelty_cache = novelty_cache or set()

    def score_batch(self, candidates: List[Dict], snapshot: Dict) -> List[Dict]:
        scored = []
        for cand in candidates:
            text = cand.get("natural_text", "")
            # --- Novelty via hash ---
            h = hashlib.md5(text.encode()).hexdigest()
            novelty = 1.0 if h not in self.novelty_cache else 0.0
            self.novelty_cache.add(h)

            # --- Coverage heuristic ---
            node_count = snapshot.get("node_count", 1)
            subtree_size = len(cand.get("generated_subtree", {}).get("children", []))
            coverage = min(1.0, subtree_size / max(1, node_count))

            # --- Anchor alignment (stub: reward if anchor mentioned) ---
            anchors = snapshot.get("anchors", [])
            anchor_bonus = 0.2 if any(a["node_id"] in text for a in anchors) else 0.0

            cand["novelty_score"] = novelty
            cand["coverage_score"] = coverage
            cand["anchor_bonus"] = anchor_bonus
            scored.append(cand)
        return scored

    def select_top_k(self, candidates: List[Dict], k: int = 4) -> List[Dict]:
        return sorted(candidates, key=lambda c: (
            c.get("novelty_score", 0.0) +
            c.get("coverage_score", 0.0) +
            c.get("anchor_bonus", 0.0)
        ), reverse=True)[:k]

In [None]:

import random
import numpy as np
from typing import Dict, Tuple

class SimEngine:
    def __init__(self, n_rollouts: int = 3):
        self.n_rollouts = n_rollouts

    def dual_valence(self, snapshot: Dict, candidate: Dict) -> Tuple[float,float,float,float]:
        """
        Run optimistic and pessimistic rollouts.
        Returns: (pos, neg, delta_entropy, delta_stability)
        """
        entropy = snapshot.get("entropy", 0.5)
        stability = 1.0 - snapshot.get("branch_flip_rate", 0.0)

        pos_scores, neg_scores, dEs, dSs = [], [], [], []

        for _ in range(self.n_rollouts):
            # optimistic rollout
            pos = max(0.0, min(1.0, entropy*0.6 + stability*0.4 + random.uniform(0,0.1)))
            # pessimistic rollout
            neg = max(0.0, min(1.0, (1-stability)*0.7 + entropy*0.3 + random.uniform(0,0.1)))

            dE = random.uniform(-0.1, 0.1)  # entropy delta
            dS = random.uniform(-0.1, 0.1)  # stability delta

            pos_scores.append(pos); neg_scores.append(neg)
            dEs.append(dE); dSs.append(dS)

        return (
            float(np.mean(pos_scores)),
            float(np.mean(neg_scores)),
            float(np.mean(dEs)),
            float(np.mean(dSs))
        )

In [None]:
%%writefile generative_decision_loop_safe.py
import uuid
import numpy as np

def generative_decision_loop(snapshot,
                             tickets,
                             failure_memory,
                             config,
                             generator,          # GeneratorAdapter
                             scorer,             # QuickScorer
                             sim_engine,         # SimEngine
                             verifier,           # Verifier
                             policy_module,      # Policy
                             apply_engine,       # ApplyEngine
                             provenance_logger,  # ProvenanceLogger
                             fallback_module):   # FallbackStrategies
    """
    End-to-end generative decision loop with QuickScorer + SimEngine integrated.
    """

    # --- 1. Build generator request ---
    req_id = f"req-{uuid.uuid4().hex[:6]}"
    request = {
        "request_id": req_id,
        "snapshot_id": snapshot.get("id", "unknown"),
        "snapshot": snapshot,
        "prompt_type": config.get("prompt_type", "subtree_expand"),
        "constraints": config.get("constraints", {}),
        "ticket_budget": dict(tickets),
        "k": config.get("k", 4)
    }
    provenance_logger.log_request(request)

    # --- 2. Generate candidates ---
    candidates = generator.generate(request, k=request["k"])
    provenance_logger.log_candidates(req_id, candidates)

    # --- 3. Quick scoring + filter (novelty, coverage, anchor bonus) ---
    candidates = scorer.score_batch(candidates, snapshot)
    top_candidates = scorer.select_top_k(candidates, k=min(config.get("k_quick", 4), len(candidates)))

    scored = []
    # --- 4. Micro-sim scoring (dual valence) ---
    for cand in top_candidates:
        pos, neg, dE, dS = sim_engine.dual_valence(snapshot, cand)
        novelty = cand.get("novelty_score", 0.0)
        gconf = cand.get("gen_confidence", 0.5)

        # --- 5. Fuse scores ---
        quality_raw = 1.0 * pos - 0.9 * neg + 0.8 * dE + 0.6 * dS
        fused = 0.7 * quality_raw + 0.2 * gconf + 0.1 * novelty
        final = 1 / (1 + np.exp(-fused))  # sigmoid squash

        ok, reason = verifier.check(snapshot, cand)

        scored.append({
            "candidate": cand,
            "pos": pos,
            "neg": neg,
            "delta_entropy": dE,
            "delta_stability": dS,
            "novelty": novelty,
            "final_score": final,
            "verifier_ok": ok,
            "verifier_reason": reason
        })
    provenance_logger.log_scores(req_id, scored)

    # --- 6. Policy selection (ticket-aware) ---
    scored = policy_module.apply_ticket_penalties(
        scored,
        tickets,
        config.get("ticket_costs", {})
    )
    chosen = policy_module.select(
        scored,
        policy=config.get("policy", "greedy"),
        epsilon=config.get("epsilon", 0.1),
        temperature=config.get("temperature", 1.0)
    )

    # --- 7. Verify + Apply or Fallback ---
    if chosen and chosen["verifier_ok"] and chosen["final_score"] >= config.get("min_accept_score", 0.6):
        pre_snap, post_snap, success, reason = apply_engine.apply(
            None, snapshot, chosen["candidate"], verifier
        )
        if success:
            provenance_logger.log_accept(request, chosen, post_snap)
            return post_snap, {"status": "accepted", "reason": reason,
                               "ticket_penalty": chosen.get("ticket_penalty_factor", 1.0)}
        else:
            provenance_logger.log_fallback(request, chosen, pre_snap)
            return pre_snap, {"status": "apply_failed", "reason": reason}
    else:
        fallback = fallback_module.choose(snapshot, scored, [], failure_memory)
        provenance_logger.log_fallback(request, fallback, snapshot)
        return snapshot, {"status": "fallback", "reason": "no valid candidate"}

Writing generative_decision_loop_safe.py


In [None]:
import hashlib
import time
from typing import Dict, Tuple, Optional

def _now():
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

def make_signature(details: dict, node_id: str, context_id: str) -> str:
    """
    Build a stable cluster signature from salient failure features.
    Replace fields to match your domain (e.g., rule_id, error_type, anchor path).
    """
    key = f"{node_id}|{details.get('error_type','unknown')}|{details.get('anchor','')}"
    return hashlib.sha1(key.encode("utf-8")).hexdigest()[:12]

def compute_cluster_is(tree_depth: int, max_depth: int, repeat_count: int, eis: float) -> float:
    """
    Mirror spec: IS = 0.5*(1 - ATreeScore/max_depth) + 0.3*log(1+repeat_count) + 0.2*EIS
    Use tree_depth as proxy for ATreeScore if you don’t have it yet.
    """
    import math
    atree_score = max(1, tree_depth)
    return 0.5*(1 - atree_score/max(1, max_depth)) + 0.3*math.log(1+repeat_count) + 0.2*eis

def record_failure(failure_memory: Dict[str, dict],
                   context_id: str, node_id: str, details: dict,
                   tree_depth: int, max_depth: int, eis: float) -> Tuple[str, dict]:
    """
    Upsert a failure into its cluster and update impact metrics.
    Returns (signature, cluster_ref).
    """
    sig = make_signature(details, node_id, context_id)
    cluster = failure_memory.get(sig)
    if cluster is None:
        cluster = {
            "signature": sig,
            "problems": [],
            "repeat_count": 0,
            "impact_score": 0.0,
            "cluster_weight": 1.0,
            "solution": None,
            "stats": {"last_seen": _now(), "resolved": False, "uses": 0}
        }
        failure_memory[sig] = cluster

    cluster["problems"].append({
        "timestamp": _now(),
        "context_id": context_id,
        "node_id": node_id,
        "details": details
    })
    cluster["repeat_count"] += 1
    cluster["stats"]["last_seen"] = _now()

    # Update metrics
    cluster["impact_score"] = compute_cluster_is(tree_depth, max_depth, cluster["repeat_count"], eis)
    # Optional: scale cluster weight with repeat_count to reflect severity
    cluster["cluster_weight"] = min(5.0, 1.0 + 0.25 * cluster["repeat_count"])

    return sig, cluster

def update_cluster_solution(failure_memory: Dict[str, dict],
                            signature: str,
                            natural_text: str,
                            patch: dict,
                            quality: float,
                            provenance: dict) -> Optional[dict]:
    """
    Attach a verified solution to a cluster; mark resolved if quality passes threshold.
    """
    cluster = failure_memory.get(signature)
    if not cluster:
        return None
    cluster["solution"] = {
        "natural_text": natural_text,
        "patch": patch,
        "quality": quality,
        "provenance": provenance
    }
    cluster["stats"]["resolved"] = quality >= 0.6   # tune threshold to your verifier
    return cluster

def reuse_solution_if_available(failure_memory: Dict[str, dict],
                               signature: str) -> Optional[dict]:
    """
    Fetch a stored solution for reapplication.
    """
    cluster = failure_memory.get(signature)
    if not cluster or not cluster.get("solution"):
        return None
    cluster["stats"]["uses"] += 1
    return cluster["solution"]

def prune_clusters(failure_memory: Dict[str, dict],
                   max_clusters: int = 50) -> None:
    """
    Keep memory bounded; drop the lowest-impact clusters if above cap.
    """
    if len(failure_memory) <= max_clusters:
        return
    # Sort by impact_score descending; keep top max_clusters
    sorted_items = sorted(failure_memory.items(), key=lambda kv: kv[1].get("impact_score", 0.0), reverse=True)
    to_keep = dict(sorted_items[:max_clusters])
    failure_memory.clear()
    failure_memory.update(to_keep)

In [None]:

def award_purple_on_resolution(tickets: Dict[str, int], cluster: dict) -> None:
    """
    If cluster is resolved with acceptable quality, grant Purple insight.
    """
    if cluster.get("stats", {}).get("resolved", False):
        tickets["P"] = tickets.get("P", 0) + 1

def convert_purple(tickets: Dict[str, int],
                   temp_tickets: Dict[str, int],
                   decay_queue: list,
                   target_color: str = "G",
                   loan: bool = False) -> str:
    """
    v1.3 conversion: 1P -> 3 any OR 1 loan ticket (expires next cycle).
    """
    if tickets.get("P", 0) < 1:
        return "Insufficient Purple"
    tickets["P"] -= 1
    if loan:
        temp_tickets[target_color] = temp_tickets.get(target_color, 0) + 1
        decay_queue.append(target_color)
        return f"Loaned 1{target_color}, expires next cycle"
    else:
        tickets[target_color] = tickets.get(target_color, 0) + 3
        return f"Converted 1P -> 3{target_color}"

def decay_loans(temp_tickets: Dict[str, int], tickets: Dict[str, int], decay_queue: list) -> None:
    """
    Apply loan decay at episode boundary.
    """
    while decay_queue:
        color = decay_queue.pop(0)
        temp_tickets[color] = max(0, temp_tickets.get(color, 0) - 1)

In [None]:
import re
from typing import Tuple, Dict

class Verifier:
    def __init__(self, max_nodes: int = 50):
        self.max_nodes = max_nodes

    def check(self, snapshot: Dict, candidate: Dict) -> Tuple[bool, str]:
        """
        Verify candidate before applying.
        Returns (pass: bool, reason: str).
        """
        # --- 1. Structural checks ---
        subtree = candidate.get("generated_subtree", {})
        if not isinstance(subtree, dict):
            return False, "Invalid subtree format"

        # --- 2. Lock checks ---
        locked_nodes = snapshot.get("locked_nodes", [])
        must_preserve = candidate.get("constraints", {}).get("must_preserve_nodes", [])
        for nid in must_preserve:
            if nid in locked_nodes:
                return False, f"Lock violation on node {nid}"

        # --- 3. Complexity cap ---
        node_count = snapshot.get("node_count", 0)
        added_nodes = len(subtree.get("children", []))
        if node_count + added_nodes > self.max_nodes:
            return False, "Complexity cap exceeded"

        # --- 4. Domain-specific checks ---
        natxt = candidate.get("natural_text", "")
        if self._contains_profanity(natxt):
            return False, "Prohibited language"

        if snapshot.get("domain") == "code":
            if not self._lint_code(natxt):
                return False, "Code lint failed"

        if snapshot.get("domain") == "math":
            if not self._math_simplifiable(natxt):
                return False, "Math expression invalid"

        # --- 5. Passed all checks ---
        return True, "Verifier passed"

    # --- Helpers ---
    def _contains_profanity(self, text: str) -> bool:
        banned = ["badword1", "badword2"]  # extend as needed
        return any(b in text.lower() for b in banned)

    def _lint_code(self, code: str) -> bool:
        # Placeholder: could integrate flake8/pylint
        return "import os" not in code  # trivial sandbox rule

    def _math_simplifiable(self, expr: str) -> bool:
        # Placeholder: integrate sympy.simplify
        return bool(expr and re.match(r"^[0-9x+\-*/^() ]+$", expr))

In [None]:

import numpy as np
import random
from typing import List, Dict, Optional

class Policy:
    @staticmethod
    def apply_ticket_penalties(scored: List[Dict], tickets: Dict, ticket_costs: Dict) -> List[Dict]:
        """
        Apply ticket-aware penalties to candidate scores.
        If candidate requires more tickets than available, scale score down.
        """
        adjusted = []
        for cand in scored:
            action_type = cand["candidate"].get("action_type", "generative")
            cost = ticket_costs.get(action_type, {"G": 10, "Y": 1})  # default
            available = sum(tickets.get(k, 0) for k in cost.keys())

            required = sum(cost.values())
            if required == 0:
                factor = 1.0
            else:
                factor = min(1.0, (available + 1e-6) / (required + 1e-6))

            cand["final_score"] *= factor
            cand["ticket_blocked"] = (factor < 1.0)
            cand["ticket_penalty_factor"] = round(factor, 3)
            adjusted.append(cand)
        return adjusted

    @staticmethod
    def select(scored: List[Dict],
               policy: str = "greedy",
               epsilon: float = 0.1,
               temperature: float = 1.0) -> Optional[Dict]:
        """
        Select candidate based on policy.
        """
        valid = [c for c in scored if c["final_score"] > 0]  # allow penalized but nonzero
        if not valid:
            return None

        if policy == "greedy":
            return max(valid, key=lambda c: c["final_score"])

        elif policy == "epsilon_greedy":
            if random.random() < epsilon:
                return random.choice(valid)
            return max(valid, key=lambda c: c["final_score"])

        elif policy == "softmax":
            scores = np.array([c["final_score"] for c in valid])
            exp_scores = np.exp(scores / max(temperature, 1e-6))
            probs = exp_scores / exp_scores.sum()
            return np.random.choice(valid, p=probs)

        else:
            # Default fallback: greedy
            return max(valid, key=lambda c: c["final_score"])

In [None]:

import json
import time
from typing import Dict, List

class ProvenanceLogger:
    def __init__(self, filepath="provenance_log.jsonl"):
        self.filepath = filepath

    def _write(self, record: Dict):
        record["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        with open(self.filepath, "a") as f:
            f.write(json.dumps(record) + "\n")

    def log_request(self, request: Dict):
        self._write({
            "event": "request",
            "request": request
        })

    def log_candidates(self, request_id: str, candidates: List[Dict]):
        self._write({
            "event": "candidates",
            "request_id": request_id,
            "candidates": candidates
        })

    def log_scores(self, request_id: str, scores: List[Dict]):
        self._write({
            "event": "scores",
            "request_id": request_id,
            "scores": scores
        })

    def log_accept(self, request: Dict, chosen: Dict, snapshot: Dict):
        self._write({
            "event": "accept",
            "request_id": request["request_id"],
            "chosen": chosen,
            "snapshot_post": snapshot
        })

    def log_fallback(self, request: Dict, fallback: Dict, snapshot: Dict):
        self._write({
            "event": "fallback",
            "request_id": request["request_id"],
            "fallback": fallback,
            "snapshot_post": snapshot
        })

    def log_reject(self, request: Dict, rejected: List[Dict], snapshot: Dict):
        self._write({
            "event": "reject",
            "request_id": request["request_id"],
            "rejected": rejected,
            "snapshot_post": snapshot
        })

In [None]:
import copy
import time
from typing import Dict, Tuple

class ApplyEngine:
    def __init__(self):
        pass

    def apply(self, tree, snapshot: Dict, candidate: Dict, verifier) -> Tuple[Dict, Dict, bool, str]:
        """
        Apply a candidate edit to the tree atomically.
        Returns (pre_snapshot, post_snapshot, success, reason).
        """
        # --- 1. Save pre-snapshot ---
        pre_snapshot = copy.deepcopy(snapshot)

        # --- 2. Verify candidate before apply ---
        ok, reason = verifier.check(snapshot, candidate)
        if not ok:
            return pre_snapshot, snapshot, False, f"Verifier failed: {reason}"

        try:
            # --- 3. Apply candidate edit ---
            # Placeholder: actual logic depends on your tree structure
            # Example: add/replace nodes
            subtree = candidate.get("generated_subtree", {})
            if subtree:
                # naive example: attach subtree to root
                if hasattr(tree, "children"):
                    tree.children.append(subtree)

            # --- 4. Build post-snapshot ---
            post_snapshot = TreeSnapshot(tree).to_dict()

            return pre_snapshot, post_snapshot, True, "Applied successfully"

        except Exception as e:
            # --- 5. Rollback on failure ---
            return pre_snapshot, snapshot, False, f"Apply failed: {str(e)}"

In [None]:
import random
from typing import Dict, List, Optional

class FallbackStrategies:
    @staticmethod
    def conservative_generator(snapshot: Dict) -> Dict:
        """Produce a minimal safe patch instead of a full subtree."""
        return {
            "candidate_id": f"fallback-{random.randint(1000,9999)}",
            "generated_subtree": {"node_id": "hint", "value": "TODO: refine"},
            "natural_text": "Conservative fallback: added placeholder node.",
            "gen_confidence": 0.3,
            "verifier_ok": True,
            "verifier_reason": "Conservative safe patch"
        }

    @staticmethod
    def structural_best(snapshot: Dict, structural_actions: List[Dict]) -> Optional[Dict]:
        """Pick the best structural action if available."""
        if not structural_actions:
            return None
        # Assume each action has a 'score' field
        return max(structural_actions, key=lambda a: a.get("score", 0.0))

    @staticmethod
    def escalation(snapshot: Dict, failure_memory: Dict, reason: str) -> Dict:
        """Escalate unresolved case into failure memory and issue Yellow ticket."""
        sig = f"escalation-{random.randint(1000,9999)}"
        failure_memory[sig] = {
            "signature": sig,
            "problems": [{"timestamp": "now", "details": {"reason": reason}}],
            "repeat_count": 1,
            "impact_score": 0.5,
            "solution": None,
            "stats": {"last_seen": "now", "resolved": False, "uses": 0}
        }
        return {
            "candidate_id": sig,
            "natural_text": f"Escalated unresolved case: {reason}",
            "verifier_ok": False,
            "verifier_reason": "Escalation"
        }

    @staticmethod
    def choose(snapshot: Dict, scored: List[Dict], structural_actions: List[Dict], failure_memory: Dict) -> Dict:
        """
        Decide which fallback to use.
        Priority: conservative generator > structural best > escalation.
        """
        if scored:
            return FallbackStrategies.conservative_generator(snapshot)
        elif structural_actions:
            return FallbackStrategies.structural_best(snapshot, structural_actions)
        else:
            return FallbackStrategies.escalation(snapshot, failure_memory, "No valid candidates")

In [None]:
%%writefile tlite_state_action_encoder.py
import torch
import torch.nn as nn

class StateActionEncoder(nn.Module):
    """
    Shared encoder to transform ANY feature vector (state or action)
    → into a 64D neural embedding TLite can use.
    """
    def __init__(self, input_dim, embed_dim=64, device='cpu'):
        super().__init__()
        self.device = device
        self.net = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, embed_dim * 2),
            nn.GELU(),
            nn.Linear(embed_dim * 2, embed_dim)
        ).to(device)

    def forward(self, raw_vec):
        if not isinstance(raw_vec, torch.Tensor):
            raw_vec = torch.tensor(raw_vec, dtype=torch.float32)
        return self.net(raw_vec.to(self.device))

Writing tlite_state_action_encoder.py


In [None]:

%%writefile tlite_rl_bridge.py
import torch
import torch.nn as nn
import numpy as np

# -----------------------------------------------------------
# TLite scorer wrapper
# -----------------------------------------------------------
class TLiteActionScorer:
    def __init__(self, model, device='cpu'):
        """
        model: a torch.nn.Module with forward(x) and attribute .expected_dim
               (combined feature dim). If dims mismatch, a projector is created.
        """
        self.model = model.to(device)
        self.device = device

        # Align action embedding (50) to state embedding (64) for elementwise interaction.
        self.align_action = nn.Linear(50, 64).to(device)

        # Optional projector if model.expected_dim != computed combined dim.
        self.projector = None

    def _as_tensor1d(self, vec):
        # Accept list/np/tensor; return shape (1, D) on device
        if not isinstance(vec, torch.Tensor):
            vec = torch.tensor(vec, dtype=torch.float32)
        return vec.to(self.device).unsqueeze(0)

    def score_actions(self, state_vector, action_vectors):
        """
        state_vector: 1D np.ndarray / list / torch.Tensor (length 64)
        action_vectors: list of 1D vectors (each length 50)
        returns: list[float] scores
        """
        state_rep = self._as_tensor1d(state_vector)  # (1, 64)

        scores = []
        for act_vec in action_vectors:
            act_rep = self._as_tensor1d(act_vec)     # (1, 50)
            act_rep = self.align_action(act_rep)     # (1, 64) <- aligned to state

            # Elementwise interaction now safe (both 64)
            interaction = state_rep * act_rep        # (1, 64)

            # Combined features: [state, action_aligned, interaction] = 64+64+64 = 192
            combined = torch.cat([state_rep, act_rep, interaction], dim=-1)  # (1, 192)

            # Match model expected_dim if provided
            if hasattr(self.model, "expected_dim"):
                expected = int(self.model.expected_dim)
                if combined.shape[-1] != expected:
                    if self.projector is None:
                        self.projector = nn.Linear(combined.shape[-1], expected).to(self.device)
                    combined = self.projector(combined)

            score = self.model(combined).item()
            scores.append(score)

        return scores


# -----------------------------------------------------------
# FEATURE EXTRACTION
# -----------------------------------------------------------
STATE_KEYS_ORDER = [
    "depth",
    "node_count",
    "leaf_count",
    "branching_factor",
    "entropy",
    "weak_leaves",
]

ACTION_TYPES = ["split", "prune", "reorder", "lock", "unlock"]


def extract_features_from_state(state):
    """
    Accepts:
      - dict snapshot with keys in STATE_KEYS_ORDER
      - list/tuple/np.ndarray numeric of length 6 (already vectorized)
    Returns:
      - list[float] of length 6
    """
    if isinstance(state, dict):
        out = []
        for k in STATE_KEYS_ORDER:
            v = state.get(k, 0.0)
            try:
                out.append(float(v))
            except Exception:
                out.append(0.0)
        return out

    if isinstance(state, (list, tuple, np.ndarray)):
        arr = np.asarray(state, dtype=np.float32).flatten()
        if arr.shape[0] < 6:
            arr = np.concatenate([arr, np.zeros(6 - arr.shape[0], dtype=np.float32)], axis=0)
        elif arr.shape[0] > 6:
            arr = arr[:6]
        return arr.tolist()

    return [0.0] * 6


def extract_features_from_action(action):
    """
    Accepts:
      - tuple/list: (op_type, target_id, extra)
      - dict: {"type": "...", ...}
    Returns:
      - one-hot list[float] over ACTION_TYPES (len=5)
    """
    op = None
    if isinstance(action, (list, tuple)) and len(action) >= 1:
        op = action[0]
    elif isinstance(action, dict):
        op = action.get("type") or action.get("op_type")

    return [1.0 if op == t else 0.0 for t in ACTION_TYPES]


# -----------------------------------------------------------
# SHARED ENCODERS (separate for state/action to avoid dim clash)
# -----------------------------------------------------------
from tlite_state_action_encoder import StateActionEncoder

STATE_EMBED_DIM = 64
ACTION_EMBED_DIM = 50

_state_encoder = None
_action_encoder = None


def encode_state_to_vector(state):
    """
    Returns a 1D numpy vector of length STATE_EMBED_DIM (64).
    """
    global _state_encoder
    raw_list = extract_features_from_state(state)
    raw = torch.tensor(raw_list, dtype=torch.float32)

    if _state_encoder is None:
        _state_encoder = StateActionEncoder(input_dim=raw.shape[-1], embed_dim=STATE_EMBED_DIM).to(raw.device)

    with torch.no_grad():
        emb = _state_encoder(raw)  # (64,)
    return emb.detach().cpu().numpy()


def encode_action_to_vector(action):
    """
    Returns a 1D numpy vector of length ACTION_EMBED_DIM (50).
    """
    global _action_encoder
    raw_list = extract_features_from_action(action)
    raw = torch.tensor(raw_list, dtype=torch.float32)

    if _action_encoder is None:
        _action_encoder = StateActionEncoder(input_dim=raw.shape[-1], embed_dim=ACTION_EMBED_DIM).to(raw.device)

    with torch.no_grad():
        emb = _action_encoder(raw)  # (50,)
    return emb.detach().cpu().numpy()


# -----------------------------------------------------------
# OPTIONAL: tiny smoke-test if run as script
# -----------------------------------------------------------
if __name__ == "__main__":
    class TinyModel(nn.Module):
        # combined = state(64) + aligned_action(64) + interaction(64) = 192
        def __init__(self, expected_dim=192):
            super().__init__()
            self.expected_dim = expected_dim
            self.net = nn.Sequential(
                nn.Linear(self.expected_dim, 64),
                nn.ReLU(),
                nn.Linear(64, 1)
            )

        def forward(self, x):
            return self.net(x)

    # Fake snapshot (dict or 6-list is fine)
    snap = {
        "depth": 3,
        "node_count": 100,
        "leaf_count": 60,
        "branching_factor": 3.3,
        "entropy": 2.1,
        "weak_leaves": 20,
    }
    state_vec = encode_state_to_vector(snap)  # (64,)

    # Fake actions
    actions = [
        ("split", "node_1", None),
        ("prune", "node_5", None),
        ("reorder", "node_2", None),
        ("lock", "node_9", "soft"),
        ("unlock", "node_9", None),
    ]
    action_vecs = [encode_action_to_vector(a) for a in actions]  # each (50,)

    model = TinyModel()
    bridge = TLiteActionScorer(model, device="cpu")
    scores = bridge.score_actions(state_vec, action_vecs)
    print("scores:", scores)

Writing tlite_rl_bridge.py


In [None]:
"""
from tlite_rl_bridge import TLiteActionScorer

bridge = TLiteActionScorer(tlite_model, device='cpu')

def neural_choose_action(state, candidate_actions):
    state_vec = encode_state_to_vector(state)
    action_vecs = [encode_action_to_vector(a) for a in candidate_actions]

    neural_scores = bridge.score_actions(state_vec, action_vecs)
    shaped_scores = acres.apply_shaping(neural_scores, state)

    best_index = shaped_scores.index(max(shaped_scores))
    return candidate_actions[best_index]"""

"\nfrom tlite_rl_bridge import TLiteActionScorer\n\nbridge = TLiteActionScorer(tlite_model, device='cpu')\n\ndef neural_choose_action(state, candidate_actions):\n    state_vec = encode_state_to_vector(state)\n    action_vecs = [encode_action_to_vector(a) for a in candidate_actions]\n\n    neural_scores = bridge.score_actions(state_vec, action_vecs)\n    shaped_scores = acres.apply_shaping(neural_scores, state)\n\n    best_index = shaped_scores.index(max(shaped_scores))\n    return candidate_actions[best_index]"

In [None]:

def neural_choose_action(state, candidate_actions):
    """
    Placeholder decision policy before TLite distillation.
    Uses ACReS reward shaping to pick the best action.
    """
    scored = [(a, acres.quick_score(state, a)) for a in candidate_actions]
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[0][0] if scored else None

In [None]:

def score_candidates_with_neural(self, state_vector, candidates):
    # Step 1: Neural scoring
    action_vectors = [encode_action_to_vector(c["action"]) for c in candidates]
    neural_scores = bridge.score_actions(state_vector, action_vectors)

    for cand, ns in zip(candidates, neural_scores):
        cand["scores"]["neural_score"] = ns

    # Step 2: Apply A-CRES shaping per candidate
    for cand in candidates:
        base = cand["scores"]["neural_score"]
        shaped = acres.apply_shaping([base], state_vector)[0]   # convert list → scalar
        cand["scores"]["shaped"] = shaped

    # Step 3: Final combined score (RL + Neural + A-CRES)
    for cand in candidates:
        cand["scores"]["final_score"] = (
            0.4  * cand["scores"]["neural_score"] +
            0.35 * cand["scores"].get("acres_score", cand["scores"]["shaped"]) +
            0.25 * cand["scores"]["rl_value"]
        )

    return candidates

In [None]:

def mix_scores(c, w_neural, w_rl, w_acres):
    s = c["scores"]
    neural = s.get("neural_score", 0.0)
    rl_val = s.get("adjusted_quality", 0.0)
    acres = s.get("shaped", 0.0)  # <-- will add next line

    return (w_neural * neural) + (w_rl * rl_val) + (w_acres * acres)

In [None]:

import copy
import torch

class ConversationEnv:
    """
    Mode A: Fixed step conversation mode
    Agent generates response → Feedback computed → Next step.
    Episode ends after N steps.
    """
    def __init__(self, max_turns=6):
        self.max_turns = max_turns
        self.turn = 0

    def reset(self, initial_input):
        self.turn = 0
        self.history = [initial_input]
        return initial_input

    def step(self, action, reward_system, stability_score_fn):
        """
        action = selected candidate response (text or tree action)
        """
        self.history.append(action)
        self.turn += 1

        # Compute conversational stability (no confusion / topic drift)
        stability = stability_score_fn(self.history)
        reward = reward_system.shaping_for_conversation(stability)

        done = (self.turn >= self.max_turns)
        return action, reward, done

In [None]:

class TreeEnv:
    """
    Mode B: Stop when stability threshold is reached.
    Agent edits / refines tree representation gradually.
    """
    def __init__(self, stability_threshold=0.72):
        self.stability_threshold = stability_threshold

    def reset(self, tree):
        self.tree = copy.deepcopy(tree)
        return self.tree

    def step(self, action, apply_edit_fn, reward_system, compute_stability):
        """
        action = structured tree edit (add/merge/replace/prune)
        """
        self.tree = apply_edit_fn(self.tree, action)

        stability = compute_stability(self.tree)
        reward = reward_system.shaping_for_tree(stability)

        done = (stability >= self.stability_threshold)

        return self.tree, reward, done

In [None]:

class TaskEnv:
    """
    Mode C: STOP is a valid action.
    Agent executes pipeline steps: Summarize → Classify → Generate → Refine.
    Episode ends when action == "STOP".
    """
    def reset(self, input_data):
        self.context = input_data
        self.output = None
        return input_data

    def step(self, action, apply_task_op_fn, reward_system):
        if action.get("type") == "STOP":
            # End of multi-stage task
            final_quality = reward_system.evaluate_final_output(self.output)
            return self.output, final_quality, True

        # Perform transformation step
        self.output = apply_task_op_fn(self.context, action)
        reward = reward_system.shaping_for_task_step(self.output)

        return self.output, reward, False

In [None]:

import random
import math

class TreeStabilityEnv:
    """
    Agent must construct a stable expression tree.
    State = current tree signature + stability score.
    Actions = grow / modify / prune nodes.
    Episode ends on success or instability collapse.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.tree = {"nodes": 1, "depth": 1, "balance": 1.0}  # minimal stable seed
        self.steps = 0
        return self._get_state()

    def _get_state(self):
        # return compact numeric state representation
        return [
            self.tree["nodes"],
            self.tree["depth"],
            self.tree["balance"]
        ]

    def get_actions(self):
        """
        Possible structural modifications.
        """
        return [
            {"type": "add_child"},
            {"type": "add_sibling"},
            {"type": "prune_branch"},
            {"type": "duplicate_subtree"}
        ]

    def step(self, action):
        self.steps += 1
        t = self.tree

        # APPLY ACTION EFFECTS
        if action["type"] == "add_child":
            t["nodes"] += 1
            t["depth"] += random.choice([0, 1])
            t["balance"] *= random.uniform(0.95, 1.05)

        elif action["type"] == "add_sibling":
            t["nodes"] += 1
            t["balance"] *= random.uniform(0.9, 1.1)

        elif action["type"] == "prune_branch":
            t["nodes"] = max(1, t["nodes"] - random.randint(1, 3))
            t["balance"] *= random.uniform(0.95, 1.05)

        elif action["type"] == "duplicate_subtree":
            t["nodes"] += random.randint(1, 4)
            t["depth"] += random.choice([0, 1])
            t["balance"] *= random.uniform(0.85, 1.15)

        # Define stability score (EIS)
        stability = math.exp(-abs(t["balance"] - 1.0) * t["depth"])

        # End conditions
        done = False
        if stability < 0.15:   # collapse
            reward = -3
            done = True
        elif t["nodes"] >= 12 and stability > 0.6:
            reward = +5       # successful tree
            done = True
        else:
            reward = stability

        return self._get_state(), reward, done

# Phase 6: Deep Tree Expansion

In [None]:

# Phase 6 — Combined, fixed minimal issues (keeps your exact pipeline logic)

# Install
!pip install sentence-transformers --quiet

# Standard imports
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import KMeans
import warnings
warnings.filterwarnings("ignore")

# ---------------------------
# Load data (you used df earlier) — make data alias so later code works
# ---------------------------
df = pd.read_csv("/content/Updated_Prakriti_With_Features.csv")
df = df.dropna(axis=0, how='any').reset_index(drop=True)
data = df   # <-- minimal fix so 'data' references work
print("✅ Data loaded. Shape:", df.shape)

# ---------------------------
# MPNet + PCA embedding block (kept as you wrote it)
#   -> store into mpnet_column_value_vectors to avoid overwrite
# ---------------------------
print("Phase 6: MPNet embedding + dynamic PCA (kept)")
mpnet_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

def embed_text_list_mpnet(text_list, batch_size=64):
    text_list = [str(x) for x in text_list]
    vecs = mpnet_model.encode(text_list, batch_size=batch_size, show_progress_bar=True)
    return np.array(vecs, dtype=np.float32)

mpnet_column_value_vectors = {}

for col in df.columns:
    # Only process textual / categorical columns
    if df[col].dtype != object:
        continue

    values = list(df[col].unique())
    if not values:
        continue

    # Embed unique values using MPNet
    vectors = embed_text_list_mpnet(values, batch_size=64)

    # Dynamic PCA dimension — prevents error
    n_values = len(values)
    original_dim = vectors.shape[1]
    n_components = max(2, min(50, n_values, original_dim))  # never <2

    pca = PCA(n_components=n_components, random_state=42)
    compressed = pca.fit_transform(vectors)

    mpnet_column_value_vectors[col] = {
        "values": values,
        "vectors": compressed,
        "pca_dim": n_components
    }

print("✅ MPNet feature embeddings generated with dynamic PCA dimensioning (stored in mpnet_column_value_vectors).")

# ---------------------------
# Anchor grouping using TF-IDF over column NAMES (as in your cell)
# ---------------------------
print("\nPhase 6: Column-name concept grouping (TF-IDF + Agglomerative)")

col_names = list(data.columns)
vectorizer = TfidfVectorizer()
name_vectors = vectorizer.fit_transform(col_names).toarray()
name_vectors = normalize(name_vectors)

sim = cosine_similarity(name_vectors)

n_clusters = min(8, len(col_names)//3) if len(col_names) > 8 else 3
clustering = AgglomerativeClustering(n_clusters=n_clusters, metric="precomputed", linkage="average")
labels = clustering.fit_predict(1 - sim)

anchor_map = {}
for col, group_id in zip(col_names, labels):
    anchor_map.setdefault(group_id, []).append(col)

print("✅ Column concept groups formed:\n")
for gid, cols in anchor_map.items():
    print(f"Group {gid}: {cols}")

final_anchors = {
    gid: {
        "anchor_label": f"ANCHOR_{gid}",
        "columns": cols
    }
    for gid, cols in anchor_map.items()
}

print("\n✅ Final Anchors Created:")
for gid, info in final_anchors.items():
    print(f"{info['anchor_label']}: {info['columns']}")

# ---------------------------
# TF-IDF on VALUE strings + PCA -> this will be the column_value_vectors used for vec_pairs
# (keeps your TF-IDF→PCA code; this will overwrite variable column_value_vectors by design in your original flow)
# ---------------------------
print("\nPhase 6: TF-IDF + PCA on column VALUEs (this creates column_value_vectors used for vec_pairs)")

column_value_vectors = {}

for col in data.columns:
    vals = data[col].astype(str).unique()  # unique class values
    if len(vals) == 0:
        continue

    # TF-IDF vectorize the unique value strings for the column
    vectorizer_val = TfidfVectorizer()
    X = vectorizer_val.fit_transform(vals)  # shape: (n_unique_vals, vocab_size)

    # determine pca_dim safely
    pca_dim = min(50, X.shape[1], X.shape[0])  # avoid PCA errors
    if pca_dim < 2:
        # fallback: embed as one-hot padded to 50
        X_dense = np.eye(len(vals), 50, dtype=np.float32)
    else:
        pca_val = PCA(n_components=pca_dim, random_state=42)
        X_dense = pca_val.fit_transform(X.toarray()).astype(np.float32)
        # pad to length 50 if needed
        if X_dense.shape[1] < 50:
            X_dense = np.pad(X_dense, ((0,0), (0, 50-X_dense.shape[1])), mode='constant')

    # normalize rows
    X_dense = normalize(X_dense)

    column_value_vectors[col] = {
        "values": list(vals),
        "vectors": X_dense.astype(np.float32)
    }

print("✅ Column value embedding forms created (variable: column_value_vectors).")

# ---------------------------
# Prepare vec_pairs from final_anchors using column_value_vectors (your original logic)
# ---------------------------
print("\nPhase 6: Build vec_pairs (anchors per column -> mean/padded 50-d vectors)")

import torch
vec_pairs = []
anchor_meta = []

for group_id, info in final_anchors.items():
    cols = info["columns"]
    for col in cols:
        if col not in column_value_vectors:
            print(f"[WARN] Column {col} missing in column_value_vectors — skipping.")
            continue
        vals = column_value_vectors[col]["values"]
        vecs = column_value_vectors[col]["vectors"]  # shape: (n_vals, pca_dim) where pca_dim <= 50

        padded_vecs = []
        for v in vecs:
            v = np.asarray(v, dtype=np.float32)
            if v.ndim == 0:
                v = np.expand_dims(v, 0)
            if v.shape[0] >= 50:
                pv = v[:50]
            else:
                pv = np.pad(v, (0, 50 - v.shape[0]), mode='constant')
            padded_vecs.append(pv)
        if not padded_vecs:
            continue
        padded_vecs = np.array(padded_vecs, dtype=np.float32)

        # cluster these padded vectors to produce anchors per column
        n_vals = padded_vecs.shape[0]
        n_clusters_col = 3 if n_vals >= 6 else (2 if n_vals >= 3 else 1)
        if n_clusters_col == 1:
            mean_vec = padded_vecs.mean(axis=0)
            token = f"{col}::A0"
            vec_pairs.append((token, torch.tensor(mean_vec, dtype=torch.float32)))
            anchor_meta.append((group_id, col, 0, token))
        else:
            kmeans = KMeans(n_clusters=n_clusters_col, random_state=42).fit(padded_vecs)
            for lab in range(n_clusters_col):
                idxs = np.where(kmeans.labels_ == lab)[0].tolist()
                if not idxs:
                    continue
                mean_vec = padded_vecs[idxs].mean(axis=0)
                token = f"{col}::A{lab}"
                vec_pairs.append((token, torch.tensor(mean_vec, dtype=torch.float32)))
                anchor_meta.append((group_id, col, lab, token))

print(f"Prepared {len(vec_pairs)} vec_pairs for global tree build (device=cpu).")

# ---------------------------
# Build tree using TreeBuilderV2 (dim=50 to match TF-IDF→PCA padded vectors)
# ---------------------------
from TreeBuilderV2 import TreeBuilderV2
from TreeNodeV1 import TreeNodeV1

device = "cuda" if torch.cuda.is_available() else "cpu"
builder = TreeBuilderV2(device=device, dim=50, mode="three")  # dim=50 matches padded vectors

prakriti_root = builder.build_tree(vec_pairs, sample_id="prakriti_global")
print("Builder returned root:", prakriti_root.node_id if prakriti_root else None)
print("Depth (builder):", prakriti_root.get_depth() if prakriti_root else "None")
# count_nodes may be defined elsewhere in your notebook; fallback safe count:
def count_nodes_safe(root):
    if root is None:
        return 0
    cnt = 0
    stack = [root]
    seen = set()
    while stack:
        n = stack.pop()
        if id(n) in seen:
            continue
        seen.add(id(n))
        cnt += 1
        for c in getattr(n, "children", []) or []:
            stack.append(c)
    return cnt

print("Node count (calc):", count_nodes_safe(prakriti_root))

prakriti_tree = prakriti_root

# ---------------------------
# Phase 6 — Cell 6: Snapshot
# ---------------------------
from TreeSnapshot import TreeSnapshot
if prakriti_tree is not None:
    snapshot = TreeSnapshot(prakriti_tree).to_dict()
    print("Snapshot Metrics:", snapshot)
else:
    print("No tree produced; snapshot unavailable.")

# ---------------------------
# Optional: Column name / embedding consistency check (keeps your final block behavior)
# ---------------------------
print("\nChecking column name consistency...\n")
dataset_columns = set(df.columns)
embedding_columns = set()  # feature_vectors was not built in this cell; preserve your logic by making it safe

if 'feature_vectors' in globals() and isinstance(feature_vectors, dict):
    embedding_columns = set(feature_vectors.keys())
else:
    embedding_columns = None

print("Columns in dataset:", len(dataset_columns))
print("Columns in embeddings:", len(embedding_columns) if embedding_columns is not None else "feature_vectors missing")

if embedding_columns:
    missing_in_embed = dataset_columns - embedding_columns
    missing_in_data = embedding_columns - dataset_columns

    print("\nMissing in embeddings:", missing_in_embed)
    print("Missing in dataset:", missing_in_data)

    if len(missing_in_embed) > 0:
        print("\n🛠️ REBUILDING feature_vectors with correct column names...")
        feature_vectors = {col: torch.randn(50) for col in dataset_columns}
        print("✅ feature_vectors rebuilt.")
else:
    print("feature_vectors missing — no rebuild attempted.")

✅ Data loaded. Shape: (1200, 30)
Phase 6: MPNet embedding + dynamic PCA (kept)


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

✅ MPNet feature embeddings generated with dynamic PCA dimensioning (stored in mpnet_column_value_vectors).

Phase 6: Column-name concept grouping (TF-IDF + Agglomerative)
✅ Column concept groups formed:

Group 0: ['Body Size', 'Body Weight', 'Height', 'Bone Structure', 'Complexion', 'Eyelashes', 'Cheeks', 'Nose', 'Teeth and gums', 'Lips', 'Nails', 'Appetite', 'Liking tastes', 'Dosha', 'Metabolism Type', 'Climate Preference']
Group 1: ['General feel of skin', 'Texture of Skin', 'Hair Color', 'Appearance of Hair', 'Shape of face', 'Eyes', 'Blinking of Eyes', 'Skin Sensitivity']
Group 7: ['Stress Levels']
Group 6: ['Sleep Patterns']
Group 5: ['Dietary Habits']
Group 4: ['Physical Activity Level']
Group 3: ['Water Intake']
Group 2: ['Digestion Quality']

✅ Final Anchors Created:
ANCHOR_0: ['Body Size', 'Body Weight', 'Height', 'Bone Structure', 'Complexion', 'Eyelashes', 'Cheeks', 'Nose', 'Teeth and gums', 'Lips', 'Nails', 'Appetite', 'Liking tastes', 'Dosha', 'Metabolism Type', 'Climate P

# Phase 6.5: Stability Correction

In [None]:

# Phase 6.5 Helpers (run once)
import copy, random, traceback, numpy as np
from collections import defaultdict, deque
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Cycle-safe traversals & utilities
def collect_all_nodes_safe(root):
    if root is None: return []
    out, q, seen = [], deque([root]), set()
    while q:
        n = q.popleft()
        if id(n) in seen: continue
        seen.add(id(n)); out.append(n)
        for c in getattr(n, "children", []) or []: q.append(c)
    return out

def collect_leaves_safe(root):
    return [n for n in collect_all_nodes_safe(root) if not (getattr(n, "children", None))]

def count_nodes_safe(root):
    return len(collect_all_nodes_safe(root))

def get_tree_depth_safe(root):
    if root is None: return 0
    q, seen = deque([(root,0)]), set(); maxd=0
    while q:
        n,d = q.popleft()
        if id(n) in seen: continue
        seen.add(id(n)); maxd = max(maxd, d)
        for c in getattr(n, "children", []) or []: q.append((c,d+1))
    return maxd+1

def reset_levels_safe(root):
    from collections import deque
    if root is None: return
    q, seen = deque([(root,0)]), set()
    while q:
        n, lv = q.popleft()
        if id(n) in seen: continue
        seen.add(id(n)); n.level = lv
        for c in getattr(n, "children", []) or []: q.append((c, lv+1))

def detect_cycle(root):
    if root is None: return False
    q, seen = deque([root]), set()
    while q:
        n = q.popleft()
        if id(n) in seen: return True
        seen.add(id(n))
        for c in getattr(n, "children", []) or []: q.append(c)
    return False

def copy_tree_remove_cycles(root):
    # Shallow-copy nodes, avoid linking child that points to ancestor id
    if root is None: return None
    from copy import deepcopy
    new_root = type(root)(node_id=root.node_id, value=root.value, level=root.level)
    stack = [(root, new_root, set([id(root)]))]
    while stack:
        orig, newn, ancestors = stack.pop()
        for c in getattr(orig, "children", []) or []:
            if id(c) in ancestors:
                continue
            cn = type(c)(node_id=c.node_id, value=c.value, level=newn.level+1)
            # copy cached_vector if present
            if hasattr(c, "cached_vector"):
                try: cn.cached_vector = copy.deepcopy(c.cached_vector)
                except: cn.cached_vector = None
            newn.add_child(cn)
            new_anc = set(ancestors); new_anc.add(id(c))
            stack.append((c, cn, new_anc))
    reset_levels_safe(new_root)
    return new_root

print("Phase 6.5 helpers loaded.")

Phase 6.5 helpers loaded.


In [None]:

# Pass 1: Balanced deepening (high granularity)
from sklearn.cluster import KMeans
import numpy as np

root = prakriti_tree  # from your Phase 6 result
if detect_cycle(root):
    print("Cycle found — cleaning first...")
    root = copy_tree_remove_cycles(root)

leaves = [n for n in collect_all_nodes_safe(root) if getattr(n, "cached_vector", None) is not None or getattr(n, "value", None)]
# collect vectors for those leaves (if cached_vector missing we will skip)
leaf_objs = []
leaf_vecs = []
for n in leaves:
    vec = getattr(n, "cached_vector", None)
    if vec is None:
        # fallback: try to create vector from name via TF-IDF (light)
        leaf_text = str(n.value)
        # simple char-level fallback vector
        arr = np.array([ord(ch)%100 for ch in leaf_text[:50]], dtype=np.float32)
        if arr.size < 50:
            arr = np.pad(arr, (0, 50-arr.size))
        vec = arr
    else:
        # ensure it's numpy
        try:
            vec = np.array(vec[:50], dtype=np.float32)
        except Exception:
            vec = np.zeros(50, dtype=np.float32)
    leaf_objs.append(n); leaf_vecs.append(vec)

leaf_vecs = np.stack(leaf_vecs)
leaf_count = len(leaf_vecs)
target_clusters = min(max(12, leaf_count//2), leaf_count)  # aggressive high granularity
if target_clusters < 2: target_clusters = 2

kmeans = KMeans(n_clusters=target_clusters, random_state=42).fit(leaf_vecs)
labels = kmeans.labels_

# Build new root and mid groups (clone nodes to avoid cycles)
new_root = type(root)(node_id=root.node_id, value=root.value, level=0)
groups = {}
for lab, obj in zip(labels, leaf_objs):
    groups.setdefault(lab, []).append(obj)

import copy
for lab, members in groups.items():
    gnode = type(root)(node_id=f"mid_g_{lab}", value=f"MidGroup_{lab}", level=1)
    for m in members:
        gnode.add_child(copy.deepcopy(m))
    new_root.add_child(gnode)

reset_levels_safe(new_root)
prakriti_tree = new_root
print("Pass1 complete. Depth:", get_tree_depth_safe(prakriti_tree), "Node count:", count_nodes_safe(prakriti_tree))

Pass1 complete. Depth: 4 Node count: 155


In [None]:
# Pass 2: Strict hierarchical clustering inside mid-groups
from sklearn.cluster import AgglomerativeClustering
root = prakriti_tree
for g in list(getattr(root, "children", [])):
    # take its child leaves (direct children might be cloned value-nodes)
    direct = [c for c in getattr(g, "children", [])]
    if len(direct) <= 4:
        continue
    vecs = []
    nodes_list = []
    for c in direct:
        v = getattr(c, "cached_vector", None)
        if v is None:
            # try to derive a lightweight vector from text
            arr = np.array([ord(ch)%100 for ch in (c.value or "")[:50]], dtype=np.float32)
            if arr.size < 50: arr = np.pad(arr, (0, 50-arr.size))
            v = arr
        nodes_list.append(c); vecs.append(np.array(v[:50], dtype=np.float32))
    if len(vecs) <= 3: continue
    n_sub = min(max(2, len(vecs)//3), 8)
    agg = AgglomerativeClustering(n_clusters=n_sub, metric='euclidean', linkage='ward')
    labels = agg.fit_predict(np.stack(vecs))
    # create subnodes
    from collections import defaultdict
    submap = defaultdict(list)
    for node_obj, lab in zip(nodes_list, labels):
        submap[lab].append(node_obj)
    new_subnodes = []
    for lab, members in submap.items():
        subnode = type(g)(node_id=f"{g.node_id}_sub_{lab}", value=f"{g.value}_sub{lab}", level=g.level+1)
        for m in members:
            subnode.add_child(copy.deepcopy(m))
        new_subnodes.append(subnode)
    # append any other children not in nodes_list
    others = [c for c in g.children if c not in nodes_list]
    g.children = new_subnodes + others

reset_levels_safe(root)
prakriti_tree = root
print("Pass2 complete. Depth:", get_tree_depth_safe(prakriti_tree), "Node count:", count_nodes_safe(prakriti_tree))

Pass2 complete. Depth: 4 Node count: 164


In [None]:

# Pass 3: Semantic label refinement (human friendly)
from sklearn.feature_extraction.text import TfidfVectorizer
root = prakriti_tree

def extract_terms_from_leaves(node, top_k=2):
    leaves = [str(l.value) for l in collect_leaves_safe(node)]
    if not leaves: return node.value
    corpus = [" ".join(leaves)]
    vec = TfidfVectorizer(stop_words='english')
    try:
        X = vec.fit_transform(corpus)
        fn = np.array(vec.get_feature_names_out())
        if fn.size == 0: return node.value
        terms = fn[:top_k].tolist()
    except Exception:
        terms = (node.value or "").split()[:top_k]
    return f"{node.value} — {' '.join(terms)}"

for n in collect_all_nodes_safe(root):
    if getattr(n, "children", None) and len(n.children) >= 2 and n.level <= get_tree_depth_safe(root)-2:
        try:
            n.value = extract_terms_from_leaves(n, top_k=2)
        except Exception:
            pass

reset_levels_safe(root)
prakriti_tree = root
print("Pass3 complete. Depth:", get_tree_depth_safe(prakriti_tree), "Node count:", count_nodes_safe(prakriti_tree))

Pass3 complete. Depth: 4 Node count: 164


In [None]:

# Stability test + normalization
def random_action_apply(root, n_steps=200):
    import random
    import copy
    rcopy = copy.deepcopy(root)
    failures = []
    snaps = []
    for i in range(n_steps):
        # sample random internal node (not root)
        nodes = [n for n in collect_all_nodes_safe(rcopy) if n.node_id != rcopy.node_id]
        if not nodes: break
        tgt = random.choice(nodes)
        op = random.choice(["prune","split","reorder","lock","unlock"])
        try:
            if op == "prune":
                tgt.children = []
            elif op == "split":
                a = type(tgt)(node_id=f"{tgt.node_id}_s1", value="split1", level=tgt.level+1)
                b = type(tgt)(node_id=f"{tgt.node_id}_s2", value="split2", level=tgt.level+1)
                tgt.children = [a,b]
            elif op == "reorder" and len(tgt.children)>1:
                random.shuffle(tgt.children)
            elif op == "lock":
                if hasattr(tgt, "lock"): tgt.lock("soft")
            elif op == "unlock":
                if hasattr(tgt, "unlock"): tgt.unlock()
            if 'TreeSnapshot' in globals():
                snaps.append(TreeSnapshot(rcopy).to_dict())
            # normalization
            reset_levels_safe(rcopy)
        except Exception as e:
            failures.append((i, str(e)))
            break
    return rcopy, snaps, failures

sim_tree, sim_snaps, sim_failures = random_action_apply(prakriti_tree, n_steps=300)
print("Stability failures:", len(sim_failures), "Snapshots:", len(sim_snaps))
try:
    final_snap = TreeSnapshot(sim_tree).to_dict()
    print("Final snapshot after sim:", final_snap)
except Exception as e:
    print("Snapshot error:", e)

# normalization passes: prune tiny leaves and enforce fanout/depth
def prune_weak_leaves(root, min_tokens=0):
    changed=False
    for n in reversed(collect_all_nodes_safe(root)):
        if n.children: continue
        tokens = [t for t in str(n.value).split(",") if t.strip()]
        if len(tokens) <= min_tokens:
            # find parent and remove
            for p in collect_all_nodes_safe(root):
                if n in getattr(p, "children", []):
                    p.children = [c for c in p.children if c is not n]; changed=True; break
    return changed

iter=0
while iter<3:
    prune_weak_leaves(prakriti_tree, min_tokens=0)
    reset_levels_safe(prakriti_tree)
    iter+=1

try:
    saved_snap = TreeSnapshot(prakriti_tree).to_dict()
    print("Normalization complete. Snapshot:", saved_snap)
except Exception as e:
    print("Could not produce final snapshot:", e)

# Optionally save artifacts
import os, json
out_dir = "/content/phase6_artifacts_high"
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "final_tree_snapshot.json"), "w") as f:
    json.dump(saved_snap, f, indent=2)
print("Artifacts saved to", out_dir)

Stability failures: 0 Snapshots: 300
Final snapshot after sim: {'depth': 6, 'node_count': 199, 'leaf_count': 138, 'branching_factor': 3.246, 'entropy': 4.3171, 'weak_leaves': 138, 'branch_flip_rate': 0.0}
Normalization complete. Snapshot: {'depth': 3, 'node_count': 164, 'leaf_count': 122, 'branching_factor': 3.881, 'entropy': 4.7916, 'weak_leaves': 122, 'branch_flip_rate': 0.0}
Artifacts saved to /content/phase6_artifacts_high


In [None]:

# Ensure local .py modules can be imported
import sys, importlib, os
if '/content' not in sys.path:
    sys.path.append('/content')

# Import RL helpers
from check_action_allowed import check_action_allowed
from apply_action import apply_action
from compute_reward import compute_reward

# The function in generative_decision_loop_safe.py is named `generative_decision_loop`
# Phase2Env expects `generative_decision_loop_safe`, so alias it on import:
from generative_decision_loop_safe import generative_decision_loop as generative_decision_loop_safe

# (Optional) hot-reload if you edit files frequently
importlib.reload(sys.modules['check_action_allowed'])
importlib.reload(sys.modules['apply_action'])
importlib.reload(sys.modules['compute_reward'])
importlib.reload(sys.modules['generative_decision_loop_safe'])

print("✅ Helpers imported & aliased.")

✅ Helpers imported & aliased.


In [None]:

prev_snap = TreeSnapshot(prakriti_tree).to_dict()

entropy = prev_snap.get("entropy", 0.0)
stability = 1.0 - prev_snap.get("branch_flip_rate", 0.0) if "branch_flip_rate" in prev_snap else 1.0
depth_ratio = prev_snap.get("depth", 1) / max(1, prev_snap.get("node_count", 1))
EIS = 0.4*entropy + 0.3*(1-stability) + 0.3*depth_ratio
print("EIS:", EIS)

# create test new snapshot
new_snap = dict(prev_snap)
new_snap["entropy"] = prev_snap["entropy"] - 0.1

# **Add missing required fields**
new_snap["pos_score"] = 0.0
new_snap["neg_score"] = 0.0

tickets = {"G":0,"B":0,"Y":0,"R":0,"P":0}
temp_tickets = {"G":0,"B":0,"Y":0,"R":0}
decay_q = []

reward, tickets, failure_memory, log = compute_reward(
    prev_snap, new_snap, None, tickets, {}, temp_tickets, decay_q
)

print("Reward test:", reward)
print("Tickets:", tickets)
print("Log:", log)

EIS: 1.9221278048780488
Reward test: -0.752
Tickets: {'G': 0, 'B': 1, 'Y': 0, 'R': 0, 'P': 0}
Log: {'EIS': 1.882, 'IS': 0.376, 'coeffs': {'alpha': 1.565, 'beta': 0.624, 'gamma': 1.941, 'delta': 2.129}, 'tickets': {'G': 0, 'B': 1, 'Y': 0, 'R': 0, 'P': 0}, 'temp_tickets': {'G': 0, 'B': 0, 'Y': 0, 'R': 0}, 'outcome': 'neutral', 'pos': 0.0, 'neg': 0.0, 'lesson': 0, 'insight': 0, 'reward': -0.752}


# Phase 7: RL Training

In [None]:

# Phase 7 — Cell 1: Build (data_bin, feature_vectors) for Phase2Env
import numpy as np
import pandas as pd
import torch

assert 'df' in globals(), "df missing (Phase 6 output)."
assert 'column_value_vectors' in globals(), "column_value_vectors missing (Phase 6 value embeddings)."

# 1) build the token universe and the feature_vectors dict (token -> 50-d vector)
token_list = []
feature_vectors = {}
for col, pack in column_value_vectors.items():
    vals = pack["values"]
    vecs = pack["vectors"]  # 2D np array, padded/trimmed to 50 in Phase 6
    for v, vv in zip(vals, vecs):
        tok = f"{col}:{v}"
        token_list.append(tok)
        feature_vectors[tok] = torch.tensor(vv[:50], dtype=torch.float32)

token_list = sorted(set(token_list))
token_index = {t:i for i,t in enumerate(token_list)}
print(f"✅ Tokens built: {len(token_list)}")

# 2) build a binary matrix rows x tokens
rows = []
for _, row in df.iterrows():
    binrow = np.zeros(len(token_list), dtype=np.int8)
    for col in df.columns:
        val = str(row[col])
        tok = f"{col}:{val}"
        if tok in token_index:
            binrow[token_index[tok]] = 1
    rows.append(binrow)

data_bin = pd.DataFrame(rows, columns=token_list)
print("✅ data_bin shape:", data_bin.shape)

# Keep a tiny view
print(data_bin.head(3).iloc[:, : min(10, data_bin.shape[1])])

✅ Tokens built: 93
✅ data_bin shape: (1200, 93)
   Appearance of Hair:Dry, black, knotted, brittle  \
0                                                0   
1                                                1   
2                                                1   

   Appearance of Hair:Straight, oily  Appearance of Hair:Thick, curly  \
0                                  1                                0   
1                                  0                                0   
2                                  0                                0   

   Appetite:Irregular, Scanty  Appetite:Slow but steady  \
0                           0                         1   
1                           0                         1   
2                           0                         1   

   Appetite:Strong, Unbearable  Blinking of Eyes:Excessive Blinking  \
0                            0                                    0   
1                            0                                 

In [None]:

def reset_global_env(env, global_tree):
    import copy
    env.tree = copy.deepcopy(global_tree)  # freeze master tree
    env.snapshot = TreeSnapshot(env.tree).to_dict()
    env.steps = 0

    # seed tickets
    env.tickets = {"G": 5, "B": 10, "Y": 3, "R": 0, "P": 0}
    env.temp_tickets = {"G":0,"B":0,"Y":0,"R":0}
    env.decay_queue = []

    return env._snapshot_to_obs(env.snapshot)

In [None]:

#def seed_env_tickets(env):
   # env.tickets = {"G": 5, "B": 8, "Y": 2, "R": 0, "P": 0}

In [None]:

# Phase 7 — Cell 2: Environment wiring
import torch
from TreeBuilderV2 import TreeBuilderV2
from Phase2Env import Phase2Env  # your existing env
from TreeSnapshot import TreeSnapshot

device = "cuda" if torch.cuda.is_available() else "cpu"

# IMPORTANT: Phase2Env.builder.dim must match the vectors used inside env->build_tree.
# Our feature_vectors are 50-D (TF-IDF+PCA padded), so:
builder = TreeBuilderV2(device=device, dim=50, mode="three")

# Construct env with data_bin & feature_vectors
env = Phase2Env(builder, data_bin, feature_vectors, max_edits=20)
print("✅ RL Environment ready.")

# quick reset
obs = env.reset( idx=np.random.randint(0, len(data_bin)) )
print("First obs:", obs, "shape:", obs.shape)

✅ RL Environment ready.
First obs: [ 1. 31. 30. 30.  0. 30.] shape: (6,)


In [None]:

# --- Put this in a new cell ---

import random
from collections import deque

def seed_env_tickets(env, B=8, G=5, Y=2, R=0, P=0):
    env.tickets.update({"B": B, "G": G, "Y": Y, "R": R, "P": P})

def sample_valid_action_basic(tree):
    """Always returns a concrete action; lets env.step() enforce gates."""
    # collect all nodes
    nodes = []
    q = deque([tree])
    while q:
        n = q.popleft()
        nodes.append(n)
        for c in getattr(n, "children", []) or []:
            q.append(c)

    # fallback to root if somehow empty
    if not nodes:
        return ("prune", getattr(tree, "node_id", None), None)

    tgt = random.choice(nodes)
    op  = random.choice(["split", "prune", "lock", "unlock"])  # (skip reorder for now)
    extra = "soft" if op == "lock" else None
    return (op, getattr(tgt, "node_id", None), extra)

In [None]:

# Phase 7 — Cell 3: Sanity rollout with random valid actions (root-protected)
import random
import copy
import numpy as np
from collections import deque

ROOT_IDS = {"prakriti_global"}  # protect the real root

def sample_valid_action(tree, tickets):
    def is_protected(node_id):
        return node_id in ROOT_IDS

    nodes = []
    q = deque([tree])
    while q:
        n = q.popleft()
        # skip protected nodes
        if not is_protected(getattr(n, "node_id", None)):
            nodes.append(n)
        for c in getattr(n, "children", []) or []:
            q.append(c)

    if not nodes:
        return ("noop", None, None)  # nothing safe to do

    for _ in range(40):
        tgt = random.choice(nodes)
        op = random.choice(["split","prune","reorder","lock","unlock"])
        extra = "soft" if op=="lock" else None
        action = (op, tgt.node_id, extra)

        if "check_action_allowed" in globals():
            ok, cost, reason = check_action_allowed(action, tickets)
            if ok:
                return action
        else:
            return action

    return ("noop", None, None)


def run_random_episode(env, max_steps=10):
    # reset to the global deep tree
    #obs = reset_global_env(env, prakriti_tree)
    obs = env.reset(idx=np.random.randint(0, len(data_bin)))
    seed_env_tickets(env)   # IMPORTANT
    total_reward = 0.0
    info = {}

    for step in range(max_steps):
        action = sample_valid_action(env.tree, env.tickets)
        if action[0] == "noop":
            break
        obs, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward, info


# run episodes
scores = []
for ep in range(20):
    R, info = run_random_episode(env, max_steps=12)
    scores.append(R)
    print(f"Episode {ep+1}: total_reward={R:.3f} info={info.get('log',{})}")

print("Random-policy mean reward:", float(np.mean(scores)))

Episode 1: total_reward=-0.025 info={'action': 'split', 'target': 'row825_12_1', 'extra': None, 'status': 'ok', 'EIS': 0.65, 'IS': 0.13, 'coeffs': {'alpha': 1.195, 'beta': 0.87, 'gamma': 1.325, 'delta': 1.39}, 'tickets': {'G': 3, 'B': 21, 'Y': 2, 'R': 0, 'P': 0}, 'temp_tickets': {'G': 0, 'B': 0, 'Y': 0, 'R': 0}, 'outcome': 'neutral', 'pos': 0.16331, 'neg': 0.117, 'lesson': 0, 'insight': 0, 'reward': -0.166619}
Episode 2: total_reward=0.917 info={'action': 'reorder', 'target': 'row1056_24_0_R', 'extra': None, 'status': 'fail', 'EIS': 0.018, 'IS': 0.004, 'coeffs': {'alpha': 1.005, 'beta': 0.996, 'gamma': 1.009, 'delta': 1.011}, 'tickets': {'G': 3, 'B': 21, 'Y': 2, 'R': 0, 'P': 0}, 'temp_tickets': {'G': 0, 'B': 0, 'Y': 0, 'R': 0}, 'outcome': 'neutral', 'pos': 0.0, 'neg': 0.0, 'lesson': 0, 'insight': 0, 'reward': -0.008}
Episode 3: total_reward=-1.115 info={'action': 'reorder', 'target': 'row1087_18_2', 'extra': None, 'status': 'fail', 'EIS': 0.943, 'IS': 0.189, 'coeffs': {'alpha': 1.283, 

In [None]:

def greedy_action(env):
    best_reward = -1e9
    best_action = None

    #def is_protected_node(node_id):
        #return node_id in ["prakriti_global", "root", "prakriti_root"]

    # enumerate candidate actions
    actions = []
    from collections import deque
    q = deque([env.tree])
    nodes = []
    while q:
        n = q.popleft()
        nodes.append(n)
        for c in getattr(n, "children", []) or []:
            q.append(c)

    ops = ["split", "prune", "lock", "unlock"]
    for node in nodes:
        for op in ops:
            extra = "soft" if op=="lock" else None
            action = (op, node.node_id, extra)

            ok, cost, reason = check_action_allowed(action, env.tickets)
            if not ok:
                continue

            # simulate step without modifying real tree
            snap_before = TreeSnapshot(env.tree).to_dict()
            tree_copy = copy.deepcopy(env.tree)
            tickets_copy = env.tickets.copy()

            log = apply_action(tree_copy, action)
            snap_after = TreeSnapshot(tree_copy).to_dict()

            reward, _, _, _ = compute_reward(
                snap_before, snap_after,
                mode="train",
                tickets=tickets_copy,
                failure_memory={},
                temp_tickets={},
                decay_queue=[]
            )

            if reward > best_reward:
                best_reward = reward
                best_action = action

    return best_action

In [None]:

def run_greedy_episode(env, max_steps=10):
    obs = reset_global_env(env, prakriti_tree)
    total_reward = 0
    for step in range(max_steps):
        action = greedy_action(env)
        if action is None:
            break
        obs, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward, info

scores = []
for ep in range(10):
    R, info = run_greedy_episode(env)
    scores.append(R)
    print("Episode", ep+1, "Reward:", R)

print("Greedy mean reward:", np.mean(scores))

In [None]:

# Phase 7 — Cell 4: Greedy baseline via A-CRES gating + quick scoring
def candidate_set(tree):
    # small candidate pool per step
    from collections import deque
    nodes = []
    q = deque([tree])
    while q:
        n = q.popleft()
        nodes.append(n)
        for c in getattr(n, "children", []) or []:
            q.append(c)
    if not nodes: return []
    cands = []
    for _ in range(12):
        tgt = random.choice(nodes)
        op = random.choice(["split","prune","reorder","lock","unlock"])
        extra = "soft" if op=="lock" else None
        cands.append((op, tgt.node_id, extra))
    return cands

def greedy_episode(env, max_steps=12):
    obs = env.reset(idx=np.random.randint(0, len(data_bin)))
    total_R = 0.0
    for step in range(max_steps):
        cands = candidate_set(env.tree)
        # filter by gating
        allowed = []
        for a in cands:
            if "check_action_allowed" in globals():
                ok, cost, reason = check_action_allowed(a, env.tickets)
                if ok: allowed.append((a, cost))
            else:
                allowed.append((a, {}))
        if not allowed:
            # fallback
            a = ("prune", getattr(env.tree,"node_id",None), None)
        else:
            # rudimentary shape score using EIS delta heuristic on a copy
            best = None; best_score = -1e9
            for (a, _) in allowed:
                # just prefer structure-shaping ops
                op = a[0]
                score = {"split": +2.0, "reorder": +0.5, "lock": +0.1, "unlock": +0.1, "prune": -0.1}.get(op, 0.0)
                if score > best_score:
                    best = a; best_score = score
            a = best
        obs, reward, done, info = env.step(a)
        total_R += reward
        if done: break
    return total_R

greedy_scores = [greedy_episode(env) for _ in range(10)]
print("Greedy baseline mean reward:", float(np.mean(greedy_scores)))

Greedy baseline mean reward: -0.24034649999999963


In [None]:

# Phase 7 — Cell 5: Evaluation metrics + explanation

# snapshot BEFORE a greedy episode
_ = env.reset(idx=np.random.randint(0, len(data_bin)))
snap_before = TreeSnapshot(env.tree).to_dict()

_ = greedy_episode(env)
snap_after = TreeSnapshot(env.tree).to_dict()

def pick(x,k): return {kk:x.get(kk) for kk in k}
keys = ["depth","node_count","leaf_count","branching_factor","entropy","weak_leaves"]
print("Before:", pick(snap_before, keys))
print("After :", pick(snap_after,  keys))

# Build a human explanation
try:
    expl = build_explanation(snap_after, final_anchors, retrievals=[])
    print("\n=== Explanation (sample) ===\n", expl[:2000])
except Exception as e:
    print("Explanation error:", e)

Before: {'depth': 1, 'node_count': 31, 'leaf_count': 30, 'branching_factor': 30.0, 'entropy': 0.0, 'weak_leaves': 30}
After : {'depth': 2, 'node_count': 5, 'leaf_count': 3, 'branching_factor': 2.0, 'entropy': -0.0, 'weak_leaves': 3}

=== Explanation (sample) ===
 Node None anchored as None; entropy 0.00, branching 0.


# Phase 8(Clean version) (use this only)

In [None]:

# 1) No failure_sig → neutral + blue budget
prev = TreeSnapshot(prakriti_tree).to_dict()
new  = dict(prev)
new["pos_score"] = 0.4
tickets = {"G":0,"B":0,"Y":0,"R":0,"P":0}
temp   = {"G":0,"B":0,"Y":0,"R":0}
decay  = []
r, tix, fm, log = compute_reward(prev, new, None, tickets, {}, temp, decay)
print("r:", r, "| outcome:", log["outcome"], "| EIS:", log["EIS"])

# 2) New failure cluster (unique_fail) → yellow ticket
new2 = dict(prev)
new2["failure_sig"] = "E:dup-branch"
r2, tix2, fm2, log2 = compute_reward(prev, new2, None, tix, fm, temp, decay)
print("outcome:", log2["outcome"], "| Y:", tix2["Y"], "| IS:", log2["IS"])

# 3) Repeat failure (repeat_fail) → red ticket grows
new3 = dict(prev)
new3["failure_sig"] = "E:dup-branch"
r3, tix3, fm3, log3 = compute_reward(prev, new3, None, tix2, fm2, temp, decay)
print("outcome:", log3["outcome"], "| R:", tix3["R"], "| IS:", log3["IS"])

# 4) Resolved cluster with good proposed_fix → purple ticket
fix = {"quality": 0.95}
new4 = dict(prev); new4["failure_sig"] = "E:dup-branch"
r4, tix4, fm4, log4 = compute_reward(prev, new4, None, tix3, fm3, temp, decay, proposed_fix=fix)
print("outcome:", log4["outcome"], "| P:", tix4["P"])

r: -0.137345 | outcome: neutral | EIS: 1.922
outcome: unique_fail | Y: 2 | IS: 0.0
outcome: repeat_fail | R: 1 | IS: 1.047
outcome: resolved_cluster | P: 1


In [None]:

teacher_pos3 = []
experience_neg3 = []

In [None]:

%%writefile phase8_clean.py
# ===============================
# PHASE 8 — CLEAN (self-contained)
# ===============================
import os, pickle, copy, random
import numpy as np
import torch, torch.nn as nn
from collections import deque
from TLiteComponents import TLiteV6
from tlite_rl_bridge import encode_state_to_vector, encode_action_to_vector
from check_action_allowed import check_action_allowed
from TreeSnapshot import TreeSnapshot
from apply_action import apply_action

# -------------------
# Device & policy
# -------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

def init_policy(existing=None, dim=50, lr=1e-4):
    """
    Reuse an existing TLiteV6 if provided, else create a fresh one.
    Returns (policy, optimizer, mse_loss)
    """
    if existing is None:
        policy = TLiteV6(dim=dim, device=device).to(device)
    else:
        policy = existing
    opt = torch.optim.Adam(policy.parameters(), lr=lr)
    mse = nn.MSELoss()
    return policy, opt, mse

# -------------------
# Simple teacher score
# -------------------
def teacher_score(prev, new):
    """
    Higher is better.
    Entropy ↓, weak_leaves ↓, branching_factor → 3.0
    """
    dE = prev["entropy"] - new["entropy"]
    dW = prev["weak_leaves"] - new["weak_leaves"]
    dB = abs(3.0 - new["branching_factor"])
    return 1.4*dE + 1.6*dW - 0.3*dB

# -------------------
# Legal actions (structural only)
# -------------------
ROOT_IDS = {"prakriti_global", "root", "prakriti_root"}
STRUCT_OPS = ("split", "prune")  # cheap & safe

def _is_protected(nid): return nid in ROOT_IDS
def _is_leaf(n): return not getattr(n, "children", [])


def legal_candidates_clean(tree, tickets, ops=STRUCT_OPS):
    cands, q = [], deque([tree])
    while q:
        n = q.popleft()
        nid = getattr(n, "node_id", None)
        if _is_protected(nid):
            for ch in getattr(n, "children", []) or []:
                q.append(ch)
            continue
        for op in ops:
            a = (op, nid, None)
            ok, _, _ = check_action_allowed(a, tickets)
            if not ok:
                continue
            if op == "prune" and _is_leaf(n):  # skip pointless prune
                continue
            cands.append(a)
        for ch in getattr(n, "children", []) or []:
            q.append(ch)
    return cands

# -------------------
# Clone + evaluate gain offline
# -------------------
def clone_tree(root):
    if root is None: return None
    cloned = type(root)(
        node_id=root.node_id, value=root.value, children=[],
        level=root.level, rule=root.rule, confidence=root.confidence,
        lock_flag=root.lock_flag, provenance=root.provenance,
        difficulty_tag=root.difficulty_tag, branch_tag=root.branch_tag,
    )
    if getattr(root, "cached_vector", None) is not None:
        try: cloned.cached_vector = copy.deepcopy(root.cached_vector)
        except Exception: cloned.cached_vector = None
    for ch in getattr(root, "children", []) or []:
        cloned.add_child(clone_tree(ch))
    return cloned

def snapshot_of(root): return TreeSnapshot(root).to_dict()

def eval_action_gain(snapshot_before, tree_root, action):
    try:
        tmp = clone_tree(tree_root)
        log = apply_action(tmp, action)
        if log.get("status") in ("ok", "unknown_action", "fail"):
            snap_after = snapshot_of(tmp)
            return teacher_score(snapshot_before, snap_after), snap_after
    except Exception:
        pass
    return None, None

# -------------------
# STOP-aware actor
# -------------------
def select_action_stop_aware(snapshot, env, policy, score_floor=0.25, temperature=1.2, epsilon=0.15):
    cands = legal_candidates_clean(env.tree, env.tickets)
    if not cands: return None

    if np.random.rand() < epsilon:
        return random.choice(cands)

    scores = []
    with torch.no_grad():
        for a in cands:
            v = encode_action_to_vector(a)
            s = policy(torch.tensor(v, dtype=torch.float32, device=device).unsqueeze(0)).item()
            # tiny heuristic nudge
            bf = snapshot.get("branching_factor", 0.0)
            ent = snapshot.get("entropy", 0.0)
            if a[0] == "split" and bf < 3.0: s += 0.10
            if a[0] == "prune" and (bf > 6.0 or ent > 1.0): s += 0.10
            scores.append(s)

    if not scores:
        return None
    best = max(scores)
    if best <= score_floor:
        return None  # STOP

    z = np.array(scores, dtype=np.float64) / max(1e-6, float(temperature))
    z -= z.max()
    p = np.exp(z); p /= p.sum()
    idx = int(np.random.choice(len(cands), p=p))
    return cands[idx]

def run_episode_stop_aware(env, global_tree, policy, max_steps=10, score_floor=0.25):
    _ = env.reset(use_global_tree=True, global_tree=global_tree)
    env.tickets.update({"B":20,"G":6,"Y":2,"R":0,"P":0})
    total = 0.0
    for _ in range(max_steps):
        a = select_action_stop_aware(env.snapshot, env, policy, score_floor=score_floor)
        if a is None: break
        _, r, done, _ = env.step(a)
        total += float(r)
        if done: break
    return total

# -------------------
# Buffers (state, action, label)
# -------------------
teacher_pos3 = []     # (s_vec, a_vec, gain) including STOP-positive labels
experience_neg3 = []  # (s_vec, a_vec, reward<0)

def migrate_old_buffers():
    """
    If older globals teacher_pos / experience_neg exist in the notebook,
    normalize them into (s_vec, a_vec, label) triplets.
    """
    def _zeros_state_vec(dim=64): return np.zeros(dim, dtype=np.float32)
    added_t = added_n = 0
    old_t = globals().get("teacher_pos", [])
    old_n = globals().get("experience_neg", [])
    for item in old_t:
        if len(item) == 3:
            s,a,g = item
            teacher_pos3.append((np.asarray(s, np.float32), np.asarray(a, np.float32), float(g))); added_t += 1
        elif len(item) == 2:
            a,g = item
            teacher_pos3.append((_zeros_state_vec(), np.asarray(a, np.float32), float(g))); added_t += 1
    for item in old_n:
        if len(item) == 3:
            s,a,r = item
            experience_neg3.append((np.asarray(s, np.float32), np.asarray(a, np.float32), float(r))); added_n += 1
        elif len(item) == 2:
            a,r = item
            experience_neg3.append((_zeros_state_vec(), np.asarray(a, np.float32), float(r))); added_n += 1
    print(f"↻ Migrated teacher_pos: +{added_t} | experience_neg: +{added_n}")

# -------------------
# Data collection
# -------------------
def collect_teacher_strong(env, global_tree, policy, episodes=6, max_steps=6, beam=20, min_gain=1e-6, stop_label=1.5):
    """
    Adds (s,a,gain) for actions with positive true teacher gain.
    If no candidate improves (or no actions), adds a STOP-positive label.
    """
    added = 0
    for _ in range(episodes):
        _ = env.reset(use_global_tree=True, global_tree=global_tree)
        env.tickets.update({"B":20,"G":6,"Y":2})
        for _ in range(max_steps):
            snap0 = dict(env.snapshot)
            cands = legal_candidates_clean(env.tree, env.tickets)

            if not cands:
                teacher_pos3.append((encode_state_to_vector(snap0),
                                     encode_action_to_vector(("STOP", None, None)),
                                     float(stop_label)))
                added += 1
                break

            # Rank by current actor
            with torch.no_grad():
                scored = []
                for a in cands:
                    v = encode_action_to_vector(a)
                    s = policy(torch.tensor(v, dtype=torch.float32, device=device).unsqueeze(0)).item()
                    scored.append((a, s))
                scored.sort(key=lambda x: x[1], reverse=True)
                top = [a for (a,_) in scored[:beam]]

            # Evaluate true gains
            gains = []
            for a in top:
                g, _ = eval_action_gain(snap0, env.tree, a)
                gains.append((a, g))

            # Pick best
            best_action, best_gain = max(gains, key=lambda x: (x[1] if x[1] is not None else -1e9))

            # If nothing helps, label STOP
            if best_gain is None or best_gain <= 0:
                teacher_pos3.append((encode_state_to_vector(snap0),
                                     encode_action_to_vector(("STOP", None, None)),
                                     float(stop_label)))
                added += 1
                break

            # Otherwise add all positive gains
            s_vec0 = encode_state_to_vector(snap0)
            for a, g in gains:
                if g is not None and g > min_gain:
                    teacher_pos3.append((s_vec0, encode_action_to_vector(a), float(g)))
                    added += 1

            # Move with best-gain action
            env.step(best_action)
    print(f"🧑‍🏫 Strong teacher positives: +{added} (total {len(teacher_pos3)})")

def collect_experience_negatives_clean(env, global_tree, policy=None, episodes=6, max_steps=6):
    """
    Adds (s,a,reward) only when reward < 0 (strict negatives).
    """
    added = 0
    for _ in range(episodes):
        _ = env.reset(use_global_tree=True, global_tree=global_tree)
        env.tickets.update({"B":20,"G":6,"Y":2})
        prev = dict(env.snapshot)
        for _ in range(max_steps):
            cands = legal_candidates_clean(env.tree, env.tickets)
            if not cands: break
            a = random.choice(cands)
            s_prev = encode_state_to_vector(prev)
            _, r, done, _ = env.step(a)
            if float(r) < 0:
                experience_neg3.append((s_prev, encode_action_to_vector(a), float(r)))
                added += 1
            prev = dict(env.snapshot)
            if done: break
    print(f"📦 Experience negatives: +{added} (total {len(experience_neg3)})")

def collect_stop_preference(env, global_tree, policy, trials=8, score_floor=0.25):
    """
    Compare STOP vs BEST-ACTION.
    If STOP is better → reinforce STOP as positive.
    If ACTION is better → reinforce ACTION as positive.
    """
    added = 0
    for _ in range(trials):
        _ = env.reset(use_global_tree=True, global_tree=global_tree)
        env.tickets.update({"B":20,"G":6,"Y":2})

        snap0 = dict(env.snapshot)

        # Evaluate STOP outcome
        stop_reward = 0.0  # STOP = do nothing

        # Evaluate best continuation
        cands = legal_candidates_clean(env.tree, env.tickets)
        if not cands:
            continue_reward = -0.1  # weak penalty if nothing to do
        else:
            with torch.no_grad():
                best, best_score = None, -1e9
                for a in cands:
                    v = encode_action_to_vector(a)
                    s = policy(torch.tensor(v, dtype=torch.float32, device=device).unsqueeze(0)).item()
                    if s > best_score: best_score, best = s, a
            _, continue_reward, _, _ = env.step(best)
            continue_reward = float(continue_reward)

        if stop_reward >= continue_reward:
            teacher_pos3.append((
                encode_state_to_vector(snap0),
                encode_action_to_vector(("STOP", None, None)),
                +2.0
            ))
        else:
            teacher_pos3.append((
                encode_state_to_vector(snap0),
                encode_action_to_vector(best),
                +1.5
            ))
        added += 1
    print(f"🟢 STOP-Preference labels: +{added} (total {len(teacher_pos3)})")

def collect_balanced_stop_depth(env, global_tree, policy, trials=8, score_floor=0.25):
    """
    Phase 11 Balanced depth control:
    If tree entropy is low → prefer STOP.
    If entropy is high → encourage more expansion.
    """
    added = 0
    for _ in range(trials):
        _ = env.reset(use_global_tree=True, global_tree=global_tree)
        env.tickets.update({"B":20,"G":6,"Y":2})

        snap = dict(env.snapshot)
        ent = snap["entropy"]

        # STOP is better when tree is already neat
        if ent < 2.2:
            teacher_pos3.append((
                encode_state_to_vector(snap),
                encode_action_to_vector(("STOP", None, None)),
                +1.8
            ))
        # CONTINUE is better when tree is messy
        else:
            cands = legal_candidates_clean(env.tree, env.tickets)
            if not cands:
                continue
            with torch.no_grad():
                scored = []
                for a in cands:
                    v = encode_action_to_vector(a)
                    s = policy(torch.tensor(v, dtype=torch.float32, device=device).unsqueeze(0)).item()
                    scored.append((a, s))
                best, _ = max(scored, key=lambda x: x[1])

            teacher_pos3.append((
                encode_state_to_vector(snap),
                encode_action_to_vector(best),
                +1.6
            ))
        added += 1
    print(f"⚖️ Balanced STOP/Depth labels added: +{added} (total {len(teacher_pos3)})")

# -------------------
# Training
# -------------------
def make_weighted_pairs_clean(positives, negatives, policy, k_neg=3, max_pairs=2048):
    if not positives or not negatives:
        print("⚠️ Need both positives & negatives."); return []
    t_scores = np.array([r for (_,_,r) in positives], dtype=np.float32)
    t_min, t_max = float(t_scores.min()), float(t_scores.max())
    def w_norm(r):
        if t_max <= t_min: return 1.0
        z = (r - t_min) / (t_max - t_min)
        return 0.1 + 0.9*float(z)
    # Hard negatives by current actor
    neg_scored = []
    with torch.no_grad():
        for (_, a_vec, r_neg) in negatives:
            s = policy(torch.tensor(a_vec, dtype=torch.float32, device=device).unsqueeze(0)).item()
            neg_scored.append((a_vec, s))
    neg_scored.sort(key=lambda x: x[1], reverse=True)

    pairs = []
    for (_, a_pos, r_pos) in positives:
        w = w_norm(r_pos)
        pool = neg_scored[:min(512, len(neg_scored))]
        for _ in range(k_neg):
            a_neg, _ = random.choice(pool)
            pairs.append((a_pos, a_neg, w))
            if len(pairs) >= max_pairs: break
        if len(pairs) >= max_pairs: break
    random.shuffle(pairs)
    print(f"✅ Weighted pairs: {len(pairs)}")
    return pairs

def train_hybrid_pairwise_clean(policy, opt, mse, positives, negatives,
                                epochs=6, batch_size=128,
                                margin=0.2, w_pair=0.7, w_reg=0.3,
                                max_pairs=2048, k_neg=3):
    pairs = make_weighted_pairs_clean(positives, negatives, policy, k_neg=k_neg, max_pairs=max_pairs)
    if not pairs: return

    reg_actions = np.asarray([a for (_,a,_) in positives], dtype=np.float32)
    reg_targets = np.asarray([np.tanh(r/3.0) for (_,_,r) in positives], dtype=np.float32).reshape(-1,1)

    def batched(lst, bsz):
        for i in range(0, len(lst), bsz): yield lst[i:i+bsz]

    for ep in range(epochs):
        random.shuffle(pairs)
        perm = np.random.permutation(len(reg_actions))
        regA, regY = reg_actions[perm], reg_targets[perm]
        pw_iter = batched(pairs, batch_size)
        rg_iter = batched(list(zip(regA, regY)), batch_size)

        losses = []
        while True:
            did = False
            try:
                pwb = next(pw_iter); did = True
                a_pos = torch.tensor(np.asarray([p[0] for p in pwb]), dtype=torch.float32, device=device)
                a_neg = torch.tensor(np.asarray([p[1] for p in pwb]), dtype=torch.float32, device=device)
                s_pos = policy(a_pos); s_neg = policy(a_neg)
                loss_pair = torch.relu(margin - s_pos + s_neg).mean()
            except StopIteration:
                loss_pair = None
            try:
                rgb = next(rg_iter); did = True
                acts = torch.tensor(np.asarray([x[0] for x in rgb]), dtype=torch.float32, device=device)
                tgts = torch.tensor(np.asarray([x[1] for x in rgb]), dtype=torch.float32, device=device)
                preds = policy(acts)
                loss_reg = mse(preds, tgts)
            except StopIteration:
                loss_reg = None

            if not did: break
            loss_total = 0.0
            if loss_pair is not None: loss_total += w_pair * loss_pair
            if loss_reg  is not None: loss_total += w_reg  * loss_reg

            opt.zero_grad()
            loss_total.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            opt.step()
            losses.append(loss_total.item())
        print(f"Epoch {ep+1} | Hybrid loss: {np.mean(losses):.6f}")

# -------------------
# Eval
# -------------------
def show_clean_stats():
    if teacher_pos3:
        ps = np.array([r for (_,_,r) in teacher_pos3], dtype=np.float32)
        print(f"Teacher+  n={len(ps)} | min={ps.min():.3f} max={ps.max():.3f} mean={ps.mean():.3f}")
    else:
        print("Teacher+  n=0")
    if experience_neg3:
        ns = np.array([r for (_,_,r) in experience_neg3], dtype=np.float32)
        print(f"Exp−      n={len(ns)} | min={ns.min():.3f} max={ns.max():.3f} mean={ns.mean():.3f}")
    else:
        print("Exp−      n=0")

def eval_policy_stop_and_greedy(env, global_tree, policy, rounds=10, score_floor=0.25):
    soft_scores = [run_episode_stop_aware(env, global_tree, policy, max_steps=10, score_floor=score_floor)
                   for _ in range(rounds)]
    print(f"STOP-AWARE mean reward: {np.mean(soft_scores):.4f}")
    print("STOP-AWARE scores:", soft_scores)

    def run_greedy_once():
        _ = env.reset(use_global_tree=True, global_tree=global_tree)
        env.tickets.update({"B":20,"G":6,"Y":2})
        total = 0.0
        for _ in range(10):
            cands = legal_candidates_clean(env.tree, env.tickets)
            if not cands: break
            best, best_score = None, -1e9
            with torch.no_grad():
                for a in cands:
                    v = encode_action_to_vector(a)
                    s = policy(torch.tensor(v, dtype=torch.float32, device=device).unsqueeze(0)).item()
                    if s > best_score:
                        best_score, best = s, a
            _, r, done, _ = env.step(best)
            total += float(r)
            if done: break
        return total

    greedy_scores = [run_greedy_once() for _ in range(rounds)]
    print(f"Greedy mean reward: {np.mean(greedy_scores):.4f}")
    print("Greedy scores:", greedy_scores)
    return float(np.mean(soft_scores)), float(np.mean(greedy_scores))

# -------------------
# Checkpointing (survive Colab)
# -------------------
def _ensure_dir(d):
    try: os.makedirs(d, exist_ok=True)
    except Exception: pass

def save_checkpoint(dirpath, name, policy, opt, meta=None):
    _ensure_dir(dirpath)
    torch.save({
        "state_dict": policy.state_dict(),
        "optimizer": opt.state_dict(),
        "meta": meta or {}
    }, os.path.join(dirpath, f"{name}.pt"))
    # save buffers too
    with open(os.path.join(dirpath, f"{name}_buffers.pkl"), "wb") as f:
        pickle.dump({"teacher_pos3": teacher_pos3, "experience_neg3": experience_neg3}, f)
    print(f"✅ Saved checkpoint → {name}")

def load_checkpoint(dirpath, name, policy, opt):
    path = os.path.join(dirpath, f"{name}.pt")
    bufp = os.path.join(dirpath, f"{name}_buffers.pkl")
    if not os.path.exists(path): return False
    data = torch.load(path, map_location=device)
    policy.load_state_dict(data["state_dict"])
    if "optimizer" in data:
        try: opt.load_state_dict(data["optimizer"])
        except Exception: pass
    if os.path.exists(bufp):
        try:
            with open(bufp, "rb") as f:
                bufs = pickle.load(f)
                teacher_pos3.clear(); teacher_pos3.extend(bufs.get("teacher_pos3", []))
                experience_neg3.clear(); experience_neg3.extend(bufs.get("experience_neg3", []))
        except Exception: pass
    print(f"✅ Loaded checkpoint ← {name}")
    return True

# -------------------
# Multi-cycle trainer (auto-save each cycle + BEST)
# -------------------
def train_cycles(env,
                 global_tree,
                 existing_policy=None,
                 cycles=10,
                 episodes_per_cycle=6,
                 max_steps_per_ep=6,
                 beam=12,
                 min_gain=1e-6,
                 stop_label=1.5,
                 margin=0.2,
                 w_pair=0.7,
                 w_reg=0.3,
                 batch_size=128,
                 max_pairs=2048,
                 score_floor_eval=0.25,
                 ckpt_dir="/content/phase8_ckpts",
                 ckpt_prefix="cycle"):
    policy, opt, mse = init_policy(existing_policy)
    migrate_old_buffers()

    best_soft = -1e9
    for i in range(1, cycles+1):
        print(f"\n===== TRAINING CYCLE {i} =====")
        # collect
        collect_teacher_strong(env, global_tree, policy,
                               episodes=episodes_per_cycle,
                               max_steps=max_steps_per_ep,
                               beam=beam,
                               min_gain=min_gain,
                               stop_label=stop_label)
        collect_experience_negatives_clean(env, global_tree, policy,
                                           episodes=episodes_per_cycle,
                                           max_steps=max_steps_per_ep)

        # stats
        show_clean_stats()

        # train
        train_hybrid_pairwise_clean(policy, opt, mse,
                                    teacher_pos3, experience_neg3,
                                    epochs=6, batch_size=batch_size,
                                    margin=margin, w_pair=w_pair, w_reg=w_reg,
                                    max_pairs=max_pairs, k_neg=3)

        # eval & save
        soft_mean, _ = eval_policy_stop_and_greedy(env, global_tree, policy, rounds=10, score_floor=score_floor_eval)
        meta = {"cycle": i, "soft_mean": soft_mean}

        save_checkpoint(ckpt_dir, f"{ckpt_prefix}_{i}", policy, opt, meta)
        if soft_mean > best_soft:
            best_soft = soft_mean
            save_checkpoint(ckpt_dir, "BEST", policy, opt, {"cycle": i, "soft_mean": soft_mean})

    return policy

Writing phase8_clean.py


In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

from phase8_clean import *
policy, opt, mse = init_policy()

2025-11-08 09:49:51,076 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,085 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,090 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,102 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,107 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,110 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,118 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 09:49:51,124 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponent

Mounted at /content/drive


In [None]:

import os, pickle, torch, numpy as np

SAVE_PATH = "/content/drive/MyDrive/chaturya_saves"
os.makedirs(SAVE_PATH, exist_ok=True)

best_score = -9999  # Higher STOP-aware score = better model

def save_all(tag):
    """Save policy + teacher_pos3 + experience_neg3 + prakriti_tree."""
    global policy, teacher_pos3, experience_neg3, prakriti_tree

    torch.save(policy.state_dict(), f"{SAVE_PATH}/policy_{tag}.pt")

    with open(f"{SAVE_PATH}/teacher_pos3_{tag}.pkl", "wb") as f:
        pickle.dump(teacher_pos3, f)

    with open(f"{SAVE_PATH}/experience_neg3_{tag}.pkl", "wb") as f:
        pickle.dump(experience_neg3, f)

    with open(f"{SAVE_PATH}/prakriti_tree_{tag}.pkl", "wb") as f:
        pickle.dump(prakriti_tree, f)

    print(f"✅ Saved checkpoint → {tag}")

In [None]:

# ---- PHASE 9 TRAINING LOOP (Corrected) ----

best_score = -9999

for cycle in range(12):   # or any number of cycles
    print(f"\n===== TRAINING CYCLE {cycle+1} =====")

    # (1) Collect training data
    collect_teacher_strong(env, prakriti_tree, policy, episodes=5, max_steps=6)
    collect_experience_negatives_clean(env, prakriti_tree, policy, episodes=5, max_steps=6)
    show_clean_stats()

    # (2) Train policy
    train_hybrid_pairwise_clean(
        policy, opt, mse,
        teacher_pos3, experience_neg3,
        epochs=6, batch_size=128,
        margin=0.2, w_pair=0.7, w_reg=0.3,
        max_pairs=2048, k_neg=3
    )

    # (3) Evaluate performance
    stop_scores = [run_episode_stop_aware(env, prakriti_tree, policy, max_steps=10, score_floor=0.25)
                   for _ in range(10)]
    mean_stop = np.mean(stop_scores)
    print(f"STOP-AWARE mean reward = {mean_stop:.4f}")

    # (4) Save model
    save_all(f"cycle_{cycle+1}")

    if mean_stop > best_score:
        best_score = mean_stop
        save_all("BEST")


===== TRAINING CYCLE 1 =====
🧑‍🏫 Strong teacher positives: +90 (total 137)
📦 Experience negatives: +29 (total 59)
Teacher+  n=137 | min=0.400 max=95.925 mean=3.028
Exp−      n=59 | min=-0.712 max=-0.069 mean=-0.525
✅ Weighted pairs: 411
Epoch 1 | Hybrid loss: 0.092942
Epoch 2 | Hybrid loss: 0.070395
Epoch 3 | Hybrid loss: 0.052965
Epoch 4 | Hybrid loss: 0.056668
Epoch 5 | Hybrid loss: 0.052509
Epoch 6 | Hybrid loss: 0.049455
STOP-AWARE mean reward = -4.6559
✅ Saved checkpoint → cycle_1
✅ Saved checkpoint → BEST

===== TRAINING CYCLE 2 =====
🧑‍🏫 Strong teacher positives: +68 (total 205)
📦 Experience negatives: +30 (total 89)
Teacher+  n=205 | min=0.306 max=95.925 mean=2.443
Exp−      n=89 | min=-0.712 max=-0.047 mean=-0.476
✅ Weighted pairs: 615
Epoch 1 | Hybrid loss: 0.055487
Epoch 2 | Hybrid loss: 0.055025
Epoch 3 | Hybrid loss: 0.053375
Epoch 4 | Hybrid loss: 0.052182
Epoch 5 | Hybrid loss: 0.050626
Epoch 6 | Hybrid loss: 0.050775
STOP-AWARE mean reward = -5.1481
✅ Saved checkpoint 

# Phase 9: Stop aware

In [None]:

# === Phase 9 — Cell 1: Prakriti profile + STOP scoring ===

# Prakriti (classification-tree) targets
PRAKRITI_CFG = {
    "bf_target": 3.0,          # branching factor target
    "bf_tol": 0.7,             # acceptable band: [2.3 .. 3.7]
    "entropy_good": 0.65,      # "ordered enough"
    "weak_ratio_good": 0.12,   # weak_leaves <= 12% of nodes
    "depth_cap": 7             # avoid overly deep trees
}

def _ratio(x, y):
    y = max(1, y)
    return float(x)/float(y)

def stop_score(snapshot, cfg=PRAKRITI_CFG):
    """
    Returns a score in ~[0,1.5] (higher = safer to STOP).
    Combines: low entropy, near-target branching factor, few weak leaves, not-too-deep.
    """
    bf = snapshot.get("branching_factor", 0.0)
    ent = snapshot.get("entropy", 1.0)
    nodes = snapshot.get("node_count", 1)
    weak = snapshot.get("weak_leaves", 0)
    depth = snapshot.get("depth", 1)

    # (1) entropy contribution
    s_ent = max(0.0, (cfg["entropy_good"] - ent) / max(1e-6, cfg["entropy_good"]))  # ≤0 when ent>good, up to 1.0 when ent≈0

    # (2) branching closeness to 3.0
    bf_err = abs(cfg["bf_target"] - bf)
    s_bf = max(0.0, 1.0 - bf_err / cfg["bf_tol"])  # within band → near 1

    # (3) weak leaves small ratio
    wr = _ratio(weak, nodes)
    s_weak = max(0.0, 1.0 - (wr / cfg["weak_ratio_good"]))  # if weak ratio ≤ target → near 1

    # (4) depth not too big (soft)
    s_depth = 1.0 if depth <= cfg["depth_cap"] else max(0.0, 1.0 - 0.15*(depth - cfg["depth_cap"]))

    # weights tuned for Prakriti (structure quality > depth)
    score = 0.38*s_ent + 0.38*s_bf + 0.18*s_weak + 0.06*s_depth
    return float(score)

def should_stop(snapshot, tau=0.62):
    """Return True if STOP score exceeds threshold."""
    return stop_score(snapshot) >= float(tau)

In [None]:

# === Phase 9 — Cell 2: STOP-aware selector (no changes to action encoder) ===
import numpy as np
import torch

def select_action_or_stop(snapshot, env, policy, device,
                          tau=0.62,               # STOP threshold (Prakriti tuned)
                          epsilon=0.15,           # exploration
                          temperature=1.2,        # softmax over policy
                          beam=None):             # optional top-K pruning
    # 1) STOP gate first
    if should_stop(snapshot, tau=tau):
        return ("STOP", None, None)

    # 2) Otherwise choose a structural action (your existing legal_candidates)
    cands = legal_candidates(env.tree, env.tickets)
    if not cands:  # nothing to do → STOP
        return ("STOP", None, None)

    # ε-greedy explore
    if np.random.rand() < epsilon:
        return random.choice(cands)

    # score with policy
    with torch.no_grad():
        scored = []
        for a in cands:
            avec = encode_action_to_vector(a)
            at = torch.tensor(avec, dtype=torch.float32, device=device).unsqueeze(0)
            s = policy(at).item()
            scored.append((a, s))

    # optional beam to stabilize
    if beam is not None and beam > 0:
        scored.sort(key=lambda x: x[1], reverse=True)
        scored = scored[:min(beam, len(scored))]

    # softmax sampling
    vals = np.array([s for _, s in scored], dtype=np.float64) / float(temperature)
    vals -= vals.max()
    p = np.exp(vals); p /= p.sum() if p.sum() > 0 else 1.0
    idx = np.random.choice(len(scored), p=p)
    return scored[idx][0]

In [None]:

# === Phase 9 — Cell 3: STOP-aware episode + reward shaping ===

def stop_bonus(snapshot, cfg=PRAKRITI_CFG):
    """
    Shaping reward when STOP is chosen:
      +2.0 if structure is strong,
      +1.0 if decent,
      -2.0 if premature.
    """
    sc = stop_score(snapshot, cfg)
    if sc >= 0.85:  # excellent
        return +2.0
    if sc >= 0.62:  # good enough (τ)
        return +1.0
    return -2.0     # shouldn't have stopped

def run_stopaware_episode(env, policy, device,
                          max_steps=12,
                          tau=0.62, epsilon=0.15, temperature=1.2, beam=12):
    """
    Uses select_action_or_stop(). If STOP returned, apply STOP bonus and end episode.
    """
    _ = env.reset(use_global_tree=True, global_tree=prakriti_tree)
    seed_env_tickets(env, B=20, G=6, Y=2)

    total_reward = 0.0
    for _ in range(max_steps):
        action = select_action_or_stop(env.snapshot, env, policy, device,
                                       tau=tau, epsilon=epsilon,
                                       temperature=temperature, beam=beam)
        if action and action[0] == "STOP":
            total_reward += stop_bonus(env.snapshot)
            break

        if action is None:
            # no legal actions → treat as STOP-attempt
            total_reward += stop_bonus(env.snapshot)
            break

        _, r, done, _ = env.step(action)
        total_reward += float(r)
        if done:
            # if environment ends early, apply *final* STOP check once
            total_reward += max(0.0, stop_bonus(env.snapshot))
            break

    return total_reward

In [None]:

# === Phase 9 — Cell 4: Final evaluation (tables-ready) ===
import numpy as np

def eval_random_baseline(env, rounds=10, max_steps=12):
    def _random_episode():
        _ = env.reset(use_global_tree=True, global_tree=prakriti_tree)
        seed_env_tickets(env, B=20, G=6, Y=2)
        tot=0.0
        for _ in range(max_steps):
            c = legal_candidates(env.tree, env.tickets)
            if not c:
                break
            a = random.choice(c)
            _, r, done, _ = env.step(a)
            tot += float(r)
            if done: break
        return tot
    scores = [ _random_episode() for _ in range(rounds) ]
    return float(np.mean(scores)), scores

def eval_soft_policy(env, policy, rounds=10, max_steps=12):
    scores = [ run_policy_episode(env, max_steps=max_steps, greedy=False) for _ in range(rounds) ]
    return float(np.mean(scores)), scores

def eval_greedy_policy(env, policy, rounds=10, max_steps=12):
    scores = [ run_policy_episode(env, max_steps=max_steps, greedy=True) for _ in range(rounds) ]
    return float(np.mean(scores)), scores

def eval_stopaware(env, policy, rounds=10, max_steps=12, tau=0.62):
    scores = [ run_stopaware_episode(env, policy, device, max_steps=max_steps, tau=tau) for _ in range(rounds) ]
    return float(np.mean(scores)), scores

# ---- Run all & print table ----
rnd_mean, rnd_scores   = eval_random_baseline(env, rounds=10, max_steps=12)
soft_mean, soft_scores = eval_soft_policy(env, policy, rounds=10, max_steps=12)
grdy_mean, grdy_scores = eval_greedy_policy(env, policy, rounds=10, max_steps=12)
stp_mean,  stp_scores  = eval_stopaware(env, policy, rounds=10, max_steps=12, tau=0.58)

print("\n=== Phase 9 — Final Benchmarks (Prakriti) ===")
print(f"Random baseline      : {rnd_mean:.3f}  | {rnd_scores}")
print(f"Policy (soft)        : {soft_mean:.3f}  | {soft_scores}")
print(f"Policy (greedy)      : {grdy_mean:.3f}  | {grdy_scores}")
print(f"STOP-aware (ours)    : {stp_mean:.3f}  | {stp_scores}")

# Optional: quick ablation of τ
for tau in [0.55, 0.62, 0.70]:
    m,_ = eval_stopaware(env, policy, rounds=10, max_steps=12, tau=tau)
    print(f"STOP-aware tau={tau:.2f} → mean {m:.3f}")

# Phase 10: Evaluation Harness

In [None]:

%%writefile benchmarks.py
import json, time, numpy as np
from pathlib import Path
import torch

def run_stop_aware(env, policy, global_tree, rounds=10, score_floor=0.25):
    from phase8_clean import run_episode_stop_aware
    scores = []
    for _ in range(rounds):
        s = run_episode_stop_aware(env, global_tree, policy, max_steps=10, score_floor=score_floor)
        scores.append(float(s))
    return scores

def run_greedy(env, policy, global_tree, rounds=10):
    from phase8_clean import legal_candidates_clean
    from tlite_rl_bridge import encode_action_to_vector

    scores = []
    for _ in range(rounds):
        _ = env.reset(use_global_tree=True, global_tree=global_tree)
        env.tickets.update({"B":20,"G":6,"Y":2,"R":0,"P":0})
        total = 0.0

        for _ in range(10):
            cands = legal_candidates_clean(env.tree, env.tickets)
            if not cands:
                break

            best = None
            best_score = -1e9
            with torch.no_grad():
                for a in cands:
                    v = encode_action_to_vector(a)
                    s = policy(torch.tensor(v, dtype=torch.float32).unsqueeze(0)).item()
                    if s > best_score:
                        best_score = s
                        best = a

            _, r, done, _ = env.step(best)
            total += float(r)
            if done: break

        scores.append(total)
    return scores

def summarize(scores):
    arr = np.array(scores, dtype=np.float32)
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
        "n": int(len(arr)),
    }

def run_benchmark(env, policy, global_tree, name="phase9_eval", rounds=10, score_floor=0.25):
    t0 = time.time()

    stop_scores   = run_stop_aware(env, policy, global_tree, rounds, score_floor)
    greedy_scores = run_greedy(env, policy, global_tree, rounds)

    result = {
        "name": name,
        "rounds": rounds,
        "score_floor": score_floor,
        "stop_aware": summarize(stop_scores),
        "greedy": summarize(greedy_scores),
        "stop_raw": stop_scores,
        "greedy_raw": greedy_scores,
        "time_sec": round(time.time() - t0, 3),
    }

    Path("benchmarks").mkdir(exist_ok=True)
    out = Path(f"benchmarks/{name}.json")
    out.write_text(json.dumps(result, indent=2))

    print("\n=== BENCHMARK RESULT ===")
    print(json.dumps(result, indent=2))
    print(f"Saved → {out}")

    return result

Overwriting benchmarks.py


In [None]:

import importlib, phase8_clean
importlib.reload(phase8_clean)

<module 'phase8_clean' from '/content/phase8_clean.py'>

In [None]:
import importlib, benchmarks
importlib.reload(benchmarks)

<module 'benchmarks' from '/content/benchmarks.py'>

In [None]:
result = benchmarks.run_benchmark(env, policy, prakriti_tree, name="phase9_test", rounds=10, score_floor=0.25)


=== BENCHMARK RESULT ===
{
  "name": "phase9_test",
  "rounds": 10,
  "score_floor": 0.25,
  "stop_aware": {
    "mean": -4.266352653503418,
    "std": 0.9958307147026062,
    "min": -5.303299903869629,
    "max": -1.804226040840149,
    "n": 10
  },
  "greedy": {
    "mean": -1.4933630228042603,
    "std": 0.41320204734802246,
    "min": -2.5925400257110596,
    "max": -0.8016899824142456,
    "n": 10
  },
  "stop_raw": [
    -1.8042259999999999,
    -4.412287,
    -4.978529,
    -5.3033,
    -3.27971,
    -4.878804000000001,
    -4.078247,
    -4.217779,
    -5.147636,
    -4.563009
  ],
  "greedy_raw": [
    -0.80169,
    -2.5925399999999996,
    -1.4424249999999994,
    -1.4424249999999994,
    -1.4424249999999994,
    -1.4424249999999994,
    -1.4424249999999994,
    -1.4424249999999994,
    -1.4424249999999994,
    -1.4424249999999994
  ],
  "time_sec": 27.931
}
Saved → benchmarks/phase9_test.json


Phase 9.2: STOP Preference Stabilization

In [None]:

import importlib
import phase8_clean
importlib.reload(phase8_clean)

from phase8_clean import (
    collect_stop_preference,
    train_hybrid_pairwise_clean,
    teacher_pos3,
    experience_neg3,
    show_clean_stats,
    run_episode_stop_aware
)

In [None]:

collect_experience_negatives_clean(env, prakriti_tree, policy, episodes=6, max_steps=6)
show_clean_stats()

📦 Experience negatives: +35 (total 35)
Teacher+  n=0
Exp−      n=35 | min=-0.705 max=-0.203 mean=-0.597


In [None]:

for step in range(8):  # just ~8 rounds is enough
    collect_stop_preference(env, prakriti_tree, policy, trials=8)
    show_clean_stats()
    train_hybrid_pairwise_clean(policy, opt, mse, teacher_pos3, experience_neg3,
                                epochs=4, batch_size=128, max_pairs=1024)
    eval_policy_stop_and_greedy(env, prakriti_tree, policy, rounds=10, score_floor=0.25)

🟢 STOP-Preference labels: +8 (total 72)
Teacher+  n=72 | min=1.500 max=2.000 mean=1.750
Exp−      n=31 | min=-0.542 max=-0.047 mean=-0.358
✅ Weighted pairs: 216
Epoch 1 | Hybrid loss: 0.096932
Epoch 2 | Hybrid loss: 0.094097
Epoch 3 | Hybrid loss: 0.093561
Epoch 4 | Hybrid loss: 0.094047
STOP-AWARE mean reward: -3.8799
STOP-AWARE scores: [-2.254638, -2.556985, -2.470725, -4.067169000000001, -4.340739, -4.277401, -3.9980580000000008, -4.575845, -5.237764, -5.019211]
Greedy mean reward: -1.5384
Greedy scores: [-1.8323489999999998, -2.01194, -1.4424249999999994, -1.4424249999999994, -1.4424249999999994, -1.4424249999999994, -1.4424249999999994, -1.4424249999999994, -1.4424249999999994, -1.4424249999999994]
🟢 STOP-Preference labels: +8 (total 80)
Teacher+  n=80 | min=1.500 max=2.000 mean=1.750
Exp−      n=31 | min=-0.542 max=-0.047 mean=-0.358
✅ Weighted pairs: 240
Epoch 1 | Hybrid loss: 0.091628
Epoch 2 | Hybrid loss: 0.083996
Epoch 3 | Hybrid loss: 0.083426
Epoch 4 | Hybrid loss: 0.08177

In [None]:
eval_policy_stop_and_greedy(env, prakriti_tree, policy, rounds=10, score_floor=0.25)

STOP-AWARE mean reward: -4.8766
STOP-AWARE scores: [-5.358662000000001, -5.82298, -4.627963, -5.356586, -4.9787, -3.9705399999999997, -4.696218999999999, -5.399184, -4.340446999999999, -4.21488]
Greedy mean reward: -5.8000
Greedy scores: [-5.8, -5.8, -5.8, -5.8, -5.8, -5.8, -5.8, -5.8, -5.8, -5.8]


(-4.876616100000001, -5.799999999999999)

In [None]:

for step in range(8):
    collect_teacher_strong(env, prakriti_tree, policy, episodes=5, max_steps=6, beam=20)
    collect_experience_negatives_clean(env, prakriti_tree, policy, episodes=5, max_steps=6)
    show_clean_stats()
    train_hybrid_pairwise_clean(policy, opt, mse, teacher_pos3, experience_neg3,
                                epochs=4, batch_size=128, max_pairs=1024)
    mean_stop = np.mean([run_episode_stop_aware(env, prakriti_tree, policy) for _ in range(10)])
    print("STOP-AWARE:", mean_stop)

🧑‍🏫 Strong teacher positives: +40 (total 124)
📦 Experience negatives: +29 (total 135)
Teacher+  n=124 | min=1.417 max=6.365 mean=1.804
Exp−      n=135 | min=-0.705 max=-0.117 mean=-0.565
✅ Weighted pairs: 372
Epoch 1 | Hybrid loss: 0.090747
Epoch 2 | Hybrid loss: 0.090613
Epoch 3 | Hybrid loss: 0.090534
Epoch 4 | Hybrid loss: 0.090409
STOP-AWARE: -5.5009144
🧑‍🏫 Strong teacher positives: +25 (total 149)
📦 Experience negatives: +29 (total 164)
Teacher+  n=149 | min=1.417 max=6.365 mean=1.776
Exp−      n=164 | min=-0.705 max=-0.117 mean=-0.568
✅ Weighted pairs: 447
Epoch 1 | Hybrid loss: 0.086419
Epoch 2 | Hybrid loss: 0.083844
Epoch 3 | Hybrid loss: 0.086020
Epoch 4 | Hybrid loss: 0.086497
STOP-AWARE: -5.1621974
🧑‍🏫 Strong teacher positives: +20 (total 169)
📦 Experience negatives: +30 (total 194)
Teacher+  n=169 | min=1.417 max=6.365 mean=1.774
Exp−      n=194 | min=-0.705 max=-0.017 mean=-0.574
✅ Weighted pairs: 507
Epoch 1 | Hybrid loss: 0.089822
Epoch 2 | Hybrid loss: 0.089930
Epoch 3

In [None]:

import pickle

with open(f"{SAVE_PATH}/phase9_teacher_pos_final.pkl", "wb") as f:
    pickle.dump(teacher_pos3, f)

with open(f"{SAVE_PATH}/phase9_experience_neg_final.pkl", "wb") as f:
    pickle.dump(experience_neg3, f)

torch.save(policy.state_dict(), f"{SAVE_PATH}/phase9_policy.pt")

print("✅ Phase 9 saved. Ready for Phase 10.")

✅ Phase 9 saved. Ready for Phase 10.


# Phase 11: Balanced Model Behaviour

In [None]:

from phase8_clean import *
policy, opt, mse = init_policy()

2025-11-08 13:56:48,741 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,745 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,749 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,755 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,762 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,768 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,773 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponents:Initialized TLiteExpert dim=50 on cpu
2025-11-08 13:56:48,778 - INFO - Initialized TLiteExpert dim=50 on cpu
INFO:TLiteComponent

In [None]:

for step in range(12):   # ~12 rounds works well
    print(f"\n===== PHASE 11 — Balanced STOP/Depth ROUND {step+1} =====")

    collect_balanced_stop_depth(env, prakriti_tree, policy, trials=8)
    collect_experience_negatives_clean(env, prakriti_tree, policy, episodes=4, max_steps=6)
    show_clean_stats()

    train_hybrid_pairwise_clean(policy, opt, mse,
                                teacher_pos3, experience_neg3,
                                epochs=5, batch_size=128,
                                margin=0.15, w_pair=0.6, w_reg=0.3,
                                max_pairs=1800, k_neg=3)

    soft, greedy = eval_policy_stop_and_greedy(env, prakriti_tree, policy,
                                               rounds=10, score_floor=0.20)

    print("STOP-AWARE:", soft, "| GREEDY:", greedy)


===== PHASE 11 — Balanced STOP/Depth ROUND 1 =====
⚖️ Balanced STOP/Depth labels added: +8 (total 8)
📦 Experience negatives: +23 (total 23)
Teacher+  n=8 | min=1.600 max=1.600 mean=1.600
Exp−      n=23 | min=-0.666 max=-0.157 mean=-0.530
✅ Weighted pairs: 24
Epoch 1 | Hybrid loss: 0.088832
Epoch 2 | Hybrid loss: 0.087944
Epoch 3 | Hybrid loss: 0.087066
Epoch 4 | Hybrid loss: 0.086196
Epoch 5 | Hybrid loss: 0.085336
STOP-AWARE mean reward: -5.1054
STOP-AWARE scores: [-5.325594, -5.828528000000001, -4.355714000000001, -4.754284999999999, -5.819572000000001, -5.470591, -4.568402, -4.474152, -5.214576999999999, -5.242301]
Greedy mean reward: -6.6007
Greedy scores: [-6.606702000000001, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005, -6.6000000000000005]
STOP-AWARE: -5.1053716 | GREEDY: -6.600670200000001

===== PHASE 11 — Balanced STOP/Depth ROUND 2 =====
⚖️ Balanced STOP

In [None]:
save_all("phase11_final")

✅ Saved checkpoint → phase11_final


# Phase Live

Phase 0:

In [None]:

import pandas as pd
import json
import re

# ------------------------------------------------------
# Helper: Convert Dosha label to numeric weights
# ------------------------------------------------------
def dosha_weights(label):
    label = label.lower().replace(" ", "")
    parts = label.split("+")

    weight = {"vata": 0.0, "pitta": 0.0, "kapha": 0.0}
    for p in parts:
        if p in weight:
            weight[p] += 1.0

    total = sum(weight.values())
    if total > 0:
        for k in weight:
            weight[k] /= total
    return weight


# ------------------------------------------------------
# Helper: Convert cell text to trait tokens
# ------------------------------------------------------
def extract_traits(cell_value):
    if not isinstance(cell_value, str):
        return []

    # Lowercase + remove punctuation
    text = re.sub(r"[^a-zA-Z\s]", " ", cell_value.lower())

    # Split into words
    tokens = text.split()

    # Remove very common low-information words
    stop = {"and", "to", "the", "of", "a", "in", "is", "very", "but", "no"}
    tokens = [t for t in tokens if t not in stop and len(t) > 2]

    return tokens


# ------------------------------------------------------
# PHASE 0 MAIN
# Input: CSV dataset
# Output: trait_map.jsonl
# ------------------------------------------------------
def build_trait_map(csv_path, label_col="Dosha", output="trait_map.jsonl"):
    df = pd.read_csv(csv_path)

    # All columns except label are trait sources
    trait_columns = [c for c in df.columns if c != label_col]

    trait_stats = {}  # trait -> dict(vata, pitta, kapha)

    for _, row in df.iterrows():
        label = str(row[label_col])
        weights = dosha_weights(label)

        for col in trait_columns:
            traits = extract_traits(row[col])
            for t in traits:
                if t not in trait_stats:
                    trait_stats[t] = {"vata": 0.0, "pitta": 0.0, "kapha": 0.0}

                # Add weights
                trait_stats[t]["vata"] += weights["vata"]
                trait_stats[t]["pitta"] += weights["pitta"]
                trait_stats[t]["kapha"] += weights["kapha"]

    # Normalize each trait vector
    final_map = []
    for trait, w in trait_stats.items():
        total = w["vata"] + w["pitta"] + w["kapha"]
        if total > 0:
            w = {k: round(v/total, 4) for k, v in w.items()}
        final_map.append({"trait": trait, **w})

    # Sort traits by dominance confidence
    final_map.sort(key=lambda x: max(x["vata"], x["pitta"], x["kapha"]), reverse=True)

    # Write JSONL
    with open(output, "w", encoding="utf-8") as f:
        for item in final_map:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    print(f"[OK] Trait map written to {output}")
    print(f"Total traits extracted: {len(final_map)}")


# ------------------------------------------------------
# Example Run
# ------------------------------------------------------
if __name__ == "__main__":
    build_trait_map("/content/Updated_Prakriti_With_Features.csv", label_col="Dosha", output="trait_map.jsonl")

[OK] Trait map written to trait_map.jsonl
Total traits extracted: 132


In [None]:

import pandas as pd
import re, json
from collections import Counter

# If DifficultyScorer is available, we import it. Otherwise, define fallback.
try:
    from DifficultyScorer import DifficultyScorer
except:
    class DifficultyScorer:
        def __init__(self, threshold=0.75):
            self.threshold = threshold
        def fit(self, df, y=None): return self
        def transform(self, df):
            records = []
            for _, row in df.iterrows():
                filled = sum(v is not None and v==v and v!="" for v in row.values)
                total = len(row.values)
                score = filled / total
                tree_type = "three-tree" if score > self.threshold else "binary"
                records.append({"difficulty_score": round(score,4), "tree_type": tree_type})
            return pd.DataFrame(records)


def extract_tokens(text):
    if not isinstance(text, str): return []
    text = text.lower()
    text = re.sub(r"[^a-zA-Z\s]", " ", text)
    tokens = [t for t in text.split() if len(t) > 2]
    return tokens

def phase1_process(df, label_col="Dosha", out_dir="phase1_out"):
    import os
    os.makedirs(out_dir, exist_ok=True)

    # All columns except label
    trait_columns = [c for c in df.columns if c != label_col]

    # ------------------------
    # (1) Trait Frequency Mining
    # ------------------------
    freq = Counter()
    for _, row in df.iterrows():
        for col in trait_columns:
            tokens = extract_tokens(row[col])
            freq.update(tokens)

    # Filter out very rare traits
    freq = {t: c for t, c in freq.items() if c >= 2}

    # Save traits.jsonl
    with open(f"{out_dir}/traits.jsonl", "w") as f:
        for t, c in sorted(freq.items(), key=lambda x: -x[1]):
            f.write(json.dumps({"trait": t, "count": c}, ensure_ascii=False) + "\n")

    # ------------------------
    # (2) Difficulty + Tree Type
    # ------------------------
    scorer = DifficultyScorer(threshold=0.6)  # lower/raise if needed
    scores = scorer.transform(df[trait_columns])

    with open(f"{out_dir}/row_scores.jsonl", "w") as f:
        for idx, row in scores.iterrows():
            f.write(json.dumps({
                "row_id": int(idx),
                "difficulty_score": row["difficulty_score"],
                "tree_type": row["tree_type"]
            }, ensure_ascii=False) + "\n")

    print("✅ Phase 1 Complete")
    print(f"Saved → {out_dir}/traits.jsonl")
    print(f"Saved → {out_dir}/row_scores.jsonl")
    return freq

In [None]:

df = pd.read_csv("/content/Updated_Prakriti_With_Features.csv")
phase1_process(df, label_col="Dosha")

✅ Phase 1 Complete
Saved → phase1_out/traits.jsonl
Saved → phase1_out/row_scores.jsonl


{'medium': 2664,
 'moderate': 6369,
 'difficulties': 1200,
 'gaining': 996,
 'losing': 852,
 'weight': 1200,
 'average': 828,
 'large': 1212,
 'broad': 180,
 'shoulders': 180,
 'heavy': 384,
 'bone': 768,
 'structure': 768,
 'white': 384,
 'pale': 252,
 'tans': 624,
 'easily': 1728,
 'dry': 2268,
 'and': 1836,
 'thin': 1560,
 'cool': 814,
 'touch': 456,
 'rough': 1008,
 'pigments': 192,
 'aging': 192,
 'black': 1320,
 'brown': 1200,
 'dull': 780,
 'straight': 420,
 'oily': 2016,
 'long': 818,
 'angular': 576,
 'sized': 1512,
 'penetrating': 444,
 'light': 1164,
 'sensitive': 799,
 'eyes': 1200,
 'eyelashes': 1200,
 'blinking': 1044,
 'wrinkled': 480,
 'sunken': 480,
 'rounded': 564,
 'open': 312,
 'nostrils': 312,
 'big': 312,
 'strong': 742,
 'teeth': 1200,
 'healthy': 132,
 'gums': 1200,
 'tight': 528,
 'lips': 1200,
 'which': 528,
 'chaps': 528,
 'thick': 828,
 'smooth': 1308,
 'polished': 264,
 'slow': 576,
 'but': 336,
 'steady': 336,
 'sweet': 1068,
 'sour': 816,
 'salty': 816,
 

In [None]:
!grep -o '"tree_type": "[^"]*"' phase1_out/row_scores.jsonl | sort | uniq -c

   1200 "tree_type": "three-tree"


In [None]:

# === Phase 1: Dataset prep + tree_type scoring ===

import os, json
import pandas as pd

# ====== YOU EDIT THESE ======
CSV_PATH         = "/content/Updated_Prakriti_With_Features.csv"   # your dataset
OUT_DIR          = "/content/phase1_out"
K_CHUNK          = 5000
SAMPLE_ROWS      = 10000
CAT_MAX_UNIQUE   = 20
DIFF_THRESHOLD   = 0.75   # DifficultyScorer threshold
# ============================

os.makedirs(OUT_DIR, exist_ok=True)

OUT_CSV     = os.path.join(OUT_DIR, "phase1_augmented.csv")
ART_SELECTED= os.path.join(OUT_DIR, "selected_features.json")
ART_DROPPED = os.path.join(OUT_DIR, "dropped_features.json")
ART_SCHEMA  = os.path.join(OUT_DIR, "schema.json")
ART_COUNTS  = os.path.join(OUT_DIR, "tree_type_counts.json")

from FeatureSelector import FeatureSelector
from DifficultyScorer import DifficultyScorer

# ---------- 0) Save Schema ----------
first_chunk = pd.read_csv(CSV_PATH, nrows=5)
schema_info = {
    "columns": first_chunk.columns.tolist(),
    "preview_rows": len(first_chunk)
}
with open(ART_SCHEMA, "w", encoding="utf-8") as f:
    json.dump(schema_info, f, ensure_ascii=False, indent=2)

print("Schema saved:", ART_SCHEMA)
print("Columns:", schema_info["columns"])

# ---------- 1) Feature Selection ----------
sample_df = pd.read_csv(CSV_PATH, nrows=SAMPLE_ROWS, low_memory=False)
selector = FeatureSelector(cat_max_unique=CAT_MAX_UNIQUE).fit(sample_df)

selected = selector.get_selected_features()
dropped  = selector.get_dropped_features()

with open(ART_SELECTED, "w", encoding="utf-8") as f: json.dump(selected, f, ensure_ascii=False, indent=2)
with open(ART_DROPPED,  "w", encoding="utf-8") as f: json.dump(dropped,  f, ensure_ascii=False, indent=2)

print(f"Selected ({len(selected)}) saved →", ART_SELECTED)
print(f"Dropped ({len(dropped)}) saved  →", ART_DROPPED)

# ---------- 2) Difficulty Scorer ----------
scorer = DifficultyScorer(threshold=DIFF_THRESHOLD)

# ---------- 3) Stream file in chunks ----------
if os.path.exists(OUT_CSV):
    os.remove(OUT_CSV)

tree_type_counts = {"binary": 0, "three-tree": 0}
total_rows = 0

chunk_iter = pd.read_csv(CSV_PATH, chunksize=K_CHUNK, low_memory=False)

for ci, chunk in enumerate(chunk_iter, start=1):

    # Select + encode features
    X_sel = selector.transform(chunk)

    # Apply difficulty scoring
    scores = scorer.transform(X_sel)  # returns difficulty_score + tree_type

    # Merge original + scores
    out = pd.concat([chunk.reset_index(drop=True), scores.reset_index(drop=True)], axis=1)

    # Update counters
    vc = scores["tree_type"].value_counts().to_dict()
    tree_type_counts["binary"]     += int(vc.get("binary", 0))
    tree_type_counts["three-tree"] += int(vc.get("three-tree", 0))
    total_rows += len(chunk)

    # Append output safely
    out.to_csv(OUT_CSV, mode="a", header=not os.path.exists(OUT_CSV), index=False)

    print(f"[Chunk {ci}] rows={len(chunk)} → totals={tree_type_counts}")

# ---------- 4) Save Summary ----------
with open(ART_COUNTS, "w", encoding="utf-8") as f:
    json.dump({"total_rows": total_rows, **tree_type_counts}, f, ensure_ascii=False, indent=2)

print("\n=== Phase 1 COMPLETE ===")
print("Output CSV:", OUT_CSV)
print("Tree Type Counts:", tree_type_counts)
print("Total Rows Processed:", total_rows)

Schema saved: /content/phase1_out/schema.json
Columns: ['Body Size', 'Body Weight', 'Height', 'Bone Structure', 'Complexion', 'General feel of skin', 'Texture of Skin', 'Hair Color', 'Appearance of Hair', 'Shape of face', 'Eyes', 'Eyelashes', 'Blinking of Eyes', 'Cheeks', 'Nose', 'Teeth and gums', 'Lips', 'Nails', 'Appetite', 'Liking tastes', 'Dosha', 'Metabolism Type', 'Climate Preference', 'Stress Levels', 'Sleep Patterns', 'Dietary Habits', 'Physical Activity Level', 'Water Intake', 'Digestion Quality', 'Skin Sensitivity']
Selected (30) saved → /content/phase1_out/selected_features.json
Dropped (0) saved  → /content/phase1_out/dropped_features.json
[Chunk 1] rows=1200 → totals={'binary': 1200, 'three-tree': 0}

=== Phase 1 COMPLETE ===
Output CSV: /content/phase1_out/phase1_augmented.csv
Tree Type Counts: {'binary': 1200, 'three-tree': 0}
Total Rows Processed: 1200


Phase 2:

In [None]:

import os, json
import pandas as pd
from collections import Counter

# use your already existing components
from TreeBuilderV2 import TreeBuilderV2
from TreeSnapshot import TreeSnapshot
from RowBatchSummary import RowBatchSummary
from tokenizer_and_embedding import TokenEmbedding


def row_to_tokens(row, cols):
    toks = []
    for col in cols:
        val = row.get(col)
        if pd.isna(val):
            toks.append(f"{col}=NA")
        else:
            toks.append(f"{col}={str(val).strip()}")
    return toks


def phase2_process(phase1_csv, phase1_dir, out_dir, dim=50, device="cpu", chunk=400):
    os.makedirs(out_dir, exist_ok=True)

    # load the selected features from phase1
    with open(os.path.join(phase1_dir, "selected_features.json"), "r") as f:
        selected_cols = json.load(f)

    # read entire CSV (small-medium dataset)
    df = pd.read_csv(phase1_csv)

    # build vocab
    vocab = set()
    for col in selected_cols:
        for v in df[col].astype(str).unique():
            vocab.add(f"{col}={v.strip()}")
    vocab = sorted(vocab)

    # create embedding & builders
    emb = TokenEmbedding(vocab=vocab, dim=dim, device=device)
    builder_bin = TreeBuilderV2(mode="binary", dim=dim, device=device)
    builder_tri = TreeBuilderV2(mode="three", dim=dim, device=device)

    snapshots_path = os.path.join(out_dir, "snapshots.jsonl")
    paths_path = os.path.join(out_dir, "paths.jsonl")
    open(snapshots_path, "w").close()
    open(paths_path, "w").close()

    records = []
    counts = Counter()

    for idx, row in df.iterrows():
        mode = str(row.get("tree_type", "binary")).strip()
        builder = builder_bin if mode == "binary" else builder_tri

        tokens = row_to_tokens(row, selected_cols)
        vec_pairs = [(t, emb.lookup(t)) for t in tokens]

        root = builder.build_tree(vec_pairs, sample_id=f"row{idx}")
        if root is None:
            snapshot = {"depth": 0, "node_count": 0, "leaf_count": 0}
            path = []
        else:
            snapshot = TreeSnapshot(root).to_dict()
            path = builder.trace_dfs(root)

        # write outputs
        with open(snapshots_path, "a") as f:
            f.write(json.dumps({"row_index": idx, **snapshot}) + "\n")
        with open(paths_path, "a") as f:
            f.write(json.dumps({"row_index": idx, "path": path}) + "\n")

        counts[mode] += 1
        records.append({"row_index": idx, "path": path, "snapshot": snapshot})

    # batch summary
    summary = RowBatchSummary(records).to_dict()
    with open(os.path.join(out_dir, "batch_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    with open(os.path.join(out_dir, "tree_type_counts.json"), "w") as f:
        json.dump(dict(counts), f, indent=2)

    print("=== Phase 2 DONE ===")
    print("Snapshots:", snapshots_path)
    print("Paths:", paths_path)
    print("Batch Summary:", os.path.join(out_dir, "batch_summary.json"))
    print("Tree type counts:", counts)

In [None]:

phase2_process(
    phase1_csv="/content/phase1_out/phase1_augmented.csv",
    phase1_dir="/content/phase1_out",
    out_dir="/content/phase2_out"
)

=== Phase 2 DONE ===
Snapshots: /content/phase2_out/snapshots.jsonl
Paths: /content/phase2_out/paths.jsonl
Batch Summary: /content/phase2_out/batch_summary.json
Tree type counts: Counter({'binary': 1200})


Phase 3

In [None]:

import os, json, re
from collections import Counter
import pandas as pd

from TreeBuilderV2 import TreeBuilderV2
from TreeSnapshot import TreeSnapshot
from tokenizer_and_embedding import TokenEmbedding, universal_tokenizer
from embedding_matching import build_fragments, retrieve  # retrieval engine (TF-IDF)


# ---------------------------
# 1) Load selected features
# ---------------------------
def load_selected_features(phase1_dir):
    with open(os.path.join(phase1_dir, "selected_features.json"), "r") as f:
        return json.load(f)


# ---------------------------
# 2) Build fragment corpus for reference mode
# ---------------------------
def make_fragments(csv_path, selected_cols):
    df = pd.read_csv(csv_path)
    frags = []
    for _, row in df.iterrows():
        parts = []
        for c in selected_cols:
            v = row[c]
            if pd.isna(v):
                continue
            parts.append(f"{c}={str(v).strip()}")
        frags.append(" | ".join(parts))
    return frags


def init_reference_index(csv_path, phase1_dir):
    selected_cols = load_selected_features(phase1_dir)
    fragments = make_fragments(csv_path, selected_cols)
    build_fragments(fragments)   # initializes TF-IDF index
    return selected_cols


# ---------------------------
# 3) Quick Reference Gate
# ---------------------------
def reference_gate(rets, threshold=0.75):
    if not rets:
        return False
    avg = sum(r["retrievalconfidence"] for r in rets) / len(rets)
    return avg >= threshold


# ---------------------------
# 4) RL-Lite (entropy smoothing)
# ---------------------------
def smooth_snapshot(snapshot, retrievals):
    # reduce entropy slightly based on retrieval confidence
    if "entropy" not in snapshot:
        snapshot["entropy"] = 0.5

    for r in retrievals:
        c = r["retrievalconfidence"]
        snapshot["entropy"] = max(0, snapshot["entropy"] - (0.03 + 0.05 * c))
        if snapshot["entropy"] < 0.15:
            break
    return snapshot


# ---------------------------
# 5) Very small answer generator (Dosha-aware)
# ---------------------------
def infer_dosha(text):
    t = text.lower()
    if "vata" in t and "pitta" in t and "kapha" in t:
        return "tri-dosha"
    if "vata" in t and "pitta" in t:
        return "vata+pitta"
    if "pitta" in t and "kapha" in t:
        return "pitta+kapha"
    if "vata" in t and "kapha" in t:
        return "vata+kapha"
    if "vata" in t:
        return "vata"
    if "pitta" in t:
        return "pitta"
    if "kapha" in t:
        return "kapha"
    return "unknown"


def generate_answer(dosha):
    d = dosha.lower()
    if d.startswith("vata"):
        return "You show Vata tendencies. Favor warm, moist foods and stable daily routine."
    if d.startswith("pitta"):
        return "Pitta influence detected. Prefer cooling foods, reduce excessive heat/spice."
    if d.startswith("kapha"):
        return "Kapha traits seen. Use light, warm foods and regular active movement."
    if "vata" in d and "pitta" in d:
        return "Vata-Pitta mix. Combine warm + cooling foods; keep schedule consistent."
    if "pitta" in d and "kapha" in d:
        return "Pitta-Kapha mix. Avoid heavy + spicy overload; choose light cooling meals."
    if "vata" in d and "kapha" in d:
        return "Vata-Kapha mix. Favor warm, light meals; avoid cold heavy foods."
    return "Maintain balanced diet, consistent sleep, and moderate exercise."


# ---------------------------
# 6) Query Runtime Function
# ---------------------------
def run_query(query, dim=50, device="cpu", top_k=5):
    # tokenize query and embed
    tokens = universal_tokenizer(query)
    if not tokens:
        return {"error": "empty query"}

    vocab = tokens[:12]
    emb = TokenEmbedding(vocab=vocab, dim=dim, device=device)
    pairs = [(t, emb.lookup(t)) for t in vocab]

    builder = TreeBuilderV2(mode="binary", dim=dim, device=device)
    root = builder.build_tree(pairs, sample_id="query")
    snapshot = TreeSnapshot(root).to_dict() if root else {"entropy": 0.5}

    # reference retrieval
    rets = retrieve(query, k=top_k)

    if reference_gate(rets):
        best = rets[0]["fragmenttext"]
        dosha = infer_dosha(best)
        return {
            "mode": "reference",
            "dosha": dosha,
            "advice": generate_answer(dosha),
            "retrieval_used": best
        }
    else:
        snapshot = smooth_snapshot(snapshot, rets)
        best = rets[0]["fragmenttext"] if rets else ""
        dosha = infer_dosha(best)
        return {
            "mode": "rl_generative",
            "dosha": dosha,
            "advice": generate_answer(dosha),
            "entropy_after_rl": snapshot.get("entropy", None)
        }

In [None]:

init_reference_index(
    csv_path="/content/phase1_out/phase1_augmented.csv",
    phase1_dir="/content/phase1_out"
)

['Body Size',
 'Body Weight',
 'Height',
 'Bone Structure',
 'Complexion',
 'General feel of skin',
 'Texture of Skin',
 'Hair Color',
 'Appearance of Hair',
 'Shape of face',
 'Eyes',
 'Eyelashes',
 'Blinking of Eyes',
 'Cheeks',
 'Nose',
 'Teeth and gums',
 'Lips',
 'Nails',
 'Appetite',
 'Liking tastes',
 'Dosha',
 'Metabolism Type',
 'Climate Preference',
 'Stress Levels',
 'Sleep Patterns',
 'Dietary Habits',
 'Physical Activity Level',
 'Water Intake',
 'Digestion Quality',
 'Skin Sensitivity']

Phase 4

In [None]:

import os, json
import pandas as pd
from collections import Counter, defaultdict
from TreeBuilderV2 import TreeBuilderV2
from TreeSnapshot import TreeSnapshot
from tokenizer_and_embedding import TokenEmbedding, universal_tokenizer


# --------------------------------------------------
# 1) Build Trait → Vata/Pitta/Kapha Weight Table
# --------------------------------------------------
def build_trait_vpk_table(csv_path, selected_cols):
    df = pd.read_csv(csv_path)

    counts = defaultdict(lambda: {"vata":0, "pitta":0, "kapha":0})

    for _, row in df.iterrows():
        dosha_str = str(row["Dosha"]).lower().strip()
        # Identify dosha(s)
        parts = dosha_str.replace(" ", "").split("+")
        # T2 strategy: if tri-dosha → neutral start, traits decide later
        if len(parts) == 3:
            weight = {"vata":1/3, "pitta":1/3, "kapha":1/3}
        else:
            weight = {
                "vata": 1.0 if "vata" in parts else 0.0,
                "pitta":1.0 if "pitta" in parts else 0.0,
                "kapha":1.0 if "kapha" in parts else 0.0
            }

        for c in selected_cols:
            val = str(row[c]).strip()
            tok = f"{c}={val}"
            counts[tok]["vata"]  += weight["vata"]
            counts[tok]["pitta"] += weight["pitta"]
            counts[tok]["kapha"] += weight["kapha"]

    # Normalize each trait row → sum = 1
    trait_vpk = {}
    for tok, d in counts.items():
        s = d["vata"] + d["pitta"] + d["kapha"]
        if s == 0:
            trait_vpk[tok] = [1/3, 1/3, 1/3]  # uninformative trait => neutral
        else:
            trait_vpk[tok] = [d["vata"]/s, d["pitta"]/s, d["kapha"]/s]

    return trait_vpk


# --------------------------------------------------
# 2) Build Constitution Snapshot for a User Query
# --------------------------------------------------
def compute_user_vpk(query, trait_vpk, dim=50, device="cpu"):
    tokens = universal_tokenizer(query)
    if not tokens:
        return None

    vocab = tokens[:12]
    emb = TokenEmbedding(vocab=vocab, dim=dim, device=device)
    pairs = [(t, emb.lookup(t)) for t in vocab]

    builder = TreeBuilderV2(mode="binary", dim=dim, device=device)
    root = builder.build_tree(pairs, sample_id="query")
    snap = TreeSnapshot(root).to_dict() if root else {}

    # Collect trait vectors
    vectors = []
    for t in vocab:
        if t in trait_vpk:
            vectors.append(trait_vpk[t])
    if not vectors:
        vectors = [[1/3,1/3,1/3]]

    # Mean pool → final constitution scores
    v = sum(x[0] for x in vectors) / len(vectors)
    p = sum(x[1] for x in vectors) / len(vectors)
    k = sum(x[2] for x in vectors) / len(vectors)

    # Confidence = dominance * (1 - balance)
    dominant = max(v,p,k)
    balance = abs(v-p) + abs(p-k) + abs(v-k)
    confidence = dominant * (1 - 0.5*balance)

    return {
        "vata": round(v,3),
        "pitta": round(p,3),
        "kapha": round(k,3),
        "dominant": ["vata","pitta","kapha"][ [v,p,k].index(max(v,p,k)) ],
        "confidence": round(max(0.01, confidence), 3)
    }


# --------------------------------------------------
# 3) Top Driving Traits (for explanation)
# --------------------------------------------------
def top_driving_traits(query, trait_vpk, top_n=5):
    tokens = universal_tokenizer(query)
    scored = []
    for t in tokens:
        if t in trait_vpk:
            v,p,k = trait_vpk[t]
            score = max(v,p,k)
            scored.append((t, score))
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[:top_n]


# --------------------------------------------------
# Wrapper Function: Full constitution reasoning
# --------------------------------------------------
def analyze_user(query, csv_path, phase1_dir, dim=50, device="cpu"):
    selected_cols = json.load(open(os.path.join(phase1_dir, "selected_features.json")))
    trait_vpk = build_trait_vpk_table(csv_path, selected_cols)
    constitution = compute_user_vpk(query, trait_vpk, dim, device)
    drivers = top_driving_traits(query, trait_vpk)
    return {
        "constitution": constitution,
        "top_driving_traits": drivers
    }

In [None]:

%%writefile /content/trait_interpretation_map.py
trait_interpretation_map = {
    # Skin
    "dry skin": "texture of skin (dry, pigments and aging)",
    "rough skin": "texture of skin (dry, pigments and aging)",
    "oily skin": "general feel of skin (smooth and warm, oily t-zone)",
    "sensitive skin": "skin sensitivity (sensitive)",

    # Sleep
    "light sleep": "sleep patterns (short)",
    "difficulty sleeping": "sleep patterns (moderate)",
    "deep sleep": "sleep patterns (long)",

    # Hunger
    "irregular hunger": "appetite (irregular, scanty)",
    "low appetite": "appetite (slow but steady)",
    "high appetite": "appetite (strong, unbearable)",

    # Body Temperature
    "body feels hot": "complexion (fair-skin sunburns easily)",
    "heat in body": "complexion (fair-skin sunburns easily)",
    "cold hands and feet": "general feel of skin (cold and dry)",

    # Weight / Body Size
    "slim": "slim body frame, difficulty gaining weight",
    "skinny": "slim body frame, difficulty gaining weight",
    "cannot gain weight": "body weight (low - difficulties in gaining weight)",
    "overweight": "body weight (heavy - difficulties in losing weight)",

    # Breathing / Chest
    "breathing problem": "digestion quality (weak)",  # Ayurvedic mapping: Prana imbalance → Vata
    "shortness of breath": "digestion quality (weak)",
    "chest congestion": "body weight (heavy - difficulties in losing weight)",  # Kapha cluster
}

Overwriting /content/trait_interpretation_map.py


In [None]:

from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('all-MiniLM-L6-v2')

def semantic_match(query, trait_list, threshold=0.43):
    q_emb = model.encode(query, convert_to_tensor=True)
    best = None
    best_score = 0

    for trait in trait_list:
        t_emb = model.encode(trait, convert_to_tensor=True)
        score = float(util.cos_sim(q_emb, t_emb))
        if score > best_score:
            best, best_score = trait, score

    return best if best_score >= threshold else None

In [None]:

from embedding_matching import retrieve

import numpy as np

from trait_interpretation_map import trait_interpretation_map

def map_query_to_dataset_traits(query, top_k=7):
    query = query.lower()
    matched = []

    # 1) direct phrase → canonical trait
    for phrase, canonical in trait_interpretation_map.items():
        if phrase in query:
            matched.append(canonical)

    # 2) semantic fallback for remaining words
    tokens = [t.strip() for t in query.split() if t.strip()]
    for token in tokens:
        best = semantic_match(token, list(trait_vpk_table.keys()))
        if best:
            matched.append(best)

    # remove duplicates
    matched = list(dict.fromkeys(matched))
    return matched[:top_k]
def compute_user_vpk_mapped(query, trait_vpk, top_k=5):
    matched_raw = map_query_to_dataset_traits(query, top_k=7)

    # APPLY NORMALIZATION HERE
    matched_traits = [normalize_trait_name(t) for t in matched_raw]
    driving = []
    vectors = []

    for t in matched_traits:
        if t in trait_vpk_table:
            vpk = trait_vpk[t]

            # Normalize to (v,p,k)
            if isinstance(vpk, dict):
                vpk = (float(vpk.get("vata", 0)),
                       float(vpk.get("pitta", 0)),
                       float(vpk.get("kapha", 0)))
            elif isinstance(vpk, (list, tuple)) and len(vpk) == 3:
                vpk = (float(vpk[0]), float(vpk[1]), float(vpk[2]))
            else:
                continue

            driving.append((normalize_trait_name(t), vpk))
            vectors.append(vpk)

    # If no match → still return constitution but empty driving_traits
    if not vectors:
        v, p, k = (1/3, 1/3, 1/3)
        return (
            {"vata": v, "pitta": p, "kapha": k, "dominant": "vata", "confidence": 0.3},
            driving,
            0.3
        )

    # Average V P K
    v = sum(x[0] for x in vectors) / len(vectors)
    p = sum(x[1] for x in vectors) / len(vectors)
    k = sum(x[2] for x in vectors) / len(vectors)

    dominant = ["vata","pitta","kapha"][[v,p,k].index(max(v,p,k))]
    balance = abs(v-p) + abs(p-k) + abs(v-k)
    confidence = (max(v,p,k)) * (1 - 0.5*balance)

    return (
        {
            "vata": round(v,3),
            "pitta": round(p,3),
            "kapha": round(k,3),
            "dominant": dominant,
            "confidence": round(confidence,3)
        },
        driving,         # <- now contains the trait signals
        round(confidence,3)
    )

Phase 5

In [None]:

def explain_constitution(result):
    c = result["constitution"]
    drivers = result["top_driving_traits"]

    dominant = c["dominant"]
    v, p, k = c["vata"], c["pitta"], c["kapha"]
    conf = c["confidence"]

    # Constitution label formatting
    if c["vata"] > c["pitta"] > c["kapha"]:
        const_type = "Vata-Pitta (Vata dominant)"
    elif c["pitta"] > c["vata"] > c["kapha"]:
        const_type = "Pitta-Vata (Pitta dominant)"
    else:
        const_type = f"{dominant.capitalize()} dominant"

    text = f"""
Your constitution is **{const_type}**.

**Vata = {v}**, **Pitta = {p}**, **Kapha = {k}**
Confidence Score = **{conf}**

This suggests your system tends toward:

• **Vata** → movement, dryness, creativity, speed
• **Pitta** → metabolic intensity, sharpness, heat
• **Kapha** → stability, lubrication (low in your case)

**Top contributing traits (signals used from your input):**
"""

    for trait, score in drivers:
        text += f"• {trait}  (signal strength: {round(score,3)})\n"

    return text.strip()

In [None]:

def generate_final_response(query, csv_path, phase1_dir):
    result = analyze_user(query, csv_path, phase1_dir)
    text = explain_constitution(result)
    rec = recommend_for(result)

    final = text + "\n\n**Recommended Diet:**\n"
    final += "\n".join(f"• {d}" for d in rec["diet"])
    final += "\n\n**Lifestyle Guidance:**\n"
    final += "\n".join(f"• {l}" for l in rec["lifestyle"])

    return final

In [None]:
def extract_top_driving_traits(matched_traits, trait_vpk_table, dominant, top_n=8):
    scored = []
    for t in matched_traits:
        if t in trait_vpk_table:
            v, p, k = trait_vpk_table[t]
            if dominant == "vata":
                score = v - max(p, k)
            elif dominant == "pitta":
                score = p - max(v, k)
            else:
                score = k - max(v, p)
            scored.append((t, score))

    scored = sorted(scored, key=lambda x: x[1], reverse=True)
    return scored[:top_n]

In [None]:

import pandas as pd

df = pd.read_csv("/content/Updated_Prakriti_With_Features.csv")

# Convert to mapping (example logic - matches your original pipeline)
trait_vpk_table = {}
for col in df.columns:
    if col not in ["Vata", "Pitta", "Kapha"]:  # ignore target columns
        trait_vpk_table[col] = {
            "vata": df[col].str.contains("vata", case=False, na=False).mean(),
            "pitta": df[col].str.contains("pitta", case=False, na=False).mean(),
            "kapha": df[col].str.contains("kapha", case=False, na=False).mean(),
        }

Phase 6

In [None]:

def explain_doctor_style(constitution, driving_traits):
    # Determine dominant dosha based on actual score values
    scores = {
        "Vata": constitution["vata"],
        "Pitta": constitution["pitta"],
        "Kapha": constitution["kapha"]
    }
    dominant = max(scores, key=scores.get)

    text = []
    text.append(f"Based on your reported characteristics, your constitution is predominantly **{dominant}**, with the following proportions:")
    text.append(f"• Vata = {constitution['vata']:.3f}")
    text.append(f"• Pitta = {constitution['pitta']:.3f}")
    text.append(f"• Kapha = {constitution['kapha']:.3f}")
    text.append(f"Confidence Score: {constitution['confidence']:.3f}\n")

    text.append("This suggests the following physiological tendencies:")
    text.append("• **Vata** – movement, nerve activity, dryness, variability")
    text.append("• **Pitta** – metabolic intensity, heat generation, digestion")
    text.append("• **Kapha** – structure, lubrication, metabolic steadiness\n")

    text.append("Key traits from your input influencing this determination:")
    for trait, weight in driving_traits:
        text.append(f"• {trait} (signal weight: {weight:.3f})")

    return "\n".join(text)

In [None]:

def normalize_trait_name(trait):
    trait = trait.strip().lower()

    # Convert "Key=Value" → (key, value)
    if "=" in trait:
        key, value = trait.split("=", 1)
        key = key.strip().lower()
        value = value.strip().lower()

        # Mapping rules (add more slowly over time, not all at once)
        if key.startswith("texture of skin"):
            return "dry skin" if "dry" in value else f"skin texture ({value})"

        if key.startswith("complexion"):
            return f"complexion ({value})"

        if key.startswith("sleep patterns"):
            if "short" in value:
                return "light sleep"
            if "long" in value:
                return "heavy sleep"
            return "balanced sleep"

    return trait  # fallback

In [None]:

def build_anchor_snapshot(root):
    """
    Convert a TreeNodeV1 tree into the node metadata structure
    expected by anchor_extractor.
    """
    nodes = {}

    def dfs(node, depth):
        if node is None:
            return
        # Unique node id (if not present, use memory address)
        nid = getattr(node, "id", id(node))

        # Collect metadata
        nodes[nid] = {
            "token": node.value,           # trait token like "dry" or "body=slim"
            "depth": depth,
            "child_count": len(node.children),
            "is_leaf": (len(node.children) == 0),
            "vector": node.vector.tolist() if hasattr(node, "vector") else None,
        }

        for child in node.children:
            dfs(child, depth+1)

    dfs(root, depth=0)
    return nodes

In [None]:

from anchor_extractor import extract_anchors

nodes = build_anchor_snapshot(root)
anchors = extract_anchors(nodes)

In [None]:
def compute_final_confidence(retrievals, constitution, w_ret=0.5, w_const=0.5):
    if retrievals:
        ret_conf = sum(r["retrievalconfidence"] for r in retrievals) / len(retrievals)
    else:
        ret_conf = 0.0

    const_conf = constitution["confidence"]
    return (w_ret * ret_conf) + (w_const * const_conf)

In [None]:

from Phase2Env import Phase2Env

def rl_refine_tree(root, max_steps=3):
    env = Phase2Env(root)
    for _ in range(max_steps):
        env.step()   # This applies safe structural adjustments
        if env.is_stable():  # if entropy stopped improving, exit early
            break
    return env.tree   # return the improved tree

In [None]:

def process_query(query):
    # Step 1: Tokenize → Embed → Tree
    tokens = universal_tokenizer(query)
    pairs = [(t, emb.lookup(t)) for t in tokens[:12]]
    root = TreeBuilderV2(mode="binary").build_tree(pairs, sample_id="user")

    # Step 2: Anchors
    anchors = extract_anchors(build_anchor_snapshot(root))

    # Step 3: Retrieval
    retrievals = retrieve(query, k=5)

    # Step 4: Constitution scoring
    constitution = compute_user_vpk_mapped(query, trait_vpk_table)

    # Step 5: Confidence combine
    conf = compute_final_confidence(retrievals, constitution)

    # Step 6: High / Low decision
    if conf >= 0.45:     # threshold adjustable
        return generate_explanation(root, anchors, constitution, friendly=False), \
               generate_explanation(root, anchors, constitution, friendly=True)
    else:
        refined = rl_refine_tree(root)   # RL correction
        return generate_explanation(refined, anchors, constitution, friendly=False), \
               generate_explanation(refined, anchors, constitution, friendly=True)

In [None]:

# phase_live_runtime.py

from tokenizer_and_embedding import TokenEmbedding, universal_tokenizer
from TreeBuilderV2 import TreeBuilderV2
from anchor_extractor import extract_anchors
from embedding_matching import retrieve
from smoother import smooth_text
from DecoderV1 import DecoderV1


def build_anchor_snapshot(root):
    nodes = {}

    def dfs(node, d):
        nid = id(node)
        nodes[nid] = {
            "depth": d,
            "token": node.value,
            "children": [id(c) for c in node.children],
            "is_leaf": (len(node.children) == 0)
        }
        for c in node.children:
            dfs(c, d+1)

    dfs(root, 0)
    return {"root": id(root), "nodes": nodes}


def rl_lite_refine(root, threshold=0.30):
    def prune(node):
        keep = []
        for c in node.children:
            if len(c.children) == 0:
                if c.get_confidence() >= threshold:
                    keep.append(c)
            else:
                prune(c)
                keep.append(c)
        node.children = keep
    prune(root)
    return root


def analyze_user_input(query):
    # 1) Tokenize
    tokens = universal_tokenizer(query)

    # 2) Embeddings
    emb = TokenEmbedding(vocab=tokens, dim=50)
    vec_pairs = [(t, emb.lookup(t)) for t in tokens[:12]]

    # 3) Tree Build
    root = TreeBuilderV2(mode="binary").build_tree(vec_pairs, sample_id="live")

    # 4) Anchors
    anchors = extract_anchors(build_anchor_snapshot(root))

    # 5) Retrieve reference samples
    retrieved = retrieve(query, k=5)

    # 6) Constitution calculation (calls your in-memory function)
    constitution = compute_user_vpk_mapped(query, trait_vpk_table)

    # 7) Driving traits
    driving = extract_top_driving_traits(
        constitution["matched_traits"],
        trait_vpk_table,
        constitution["dominant"]
    )

    # 8) Confidence fusion
    if retrieved:
        ret_conf = sum(r["retrievalconfidence"] for r in retrieved) / len(retrieved)
    else:
        ret_conf = 0.0
    confidence = 0.5 * ret_conf + 0.5 * constitution["confidence"]

    # 9) RL-lite refine if weak confidence
    if confidence < 0.45:
        root = rl_lite_refine(root)

    # 10) Convert Tree → Natural Explanation
    structure = DecoderV1(dim=48, device="cpu", use_smoother=True).explain(root)
    structure = smooth_text(structure)

    # 11) Explanation (Doctor style)
    doctor = explain_doctor_style(constitution, driving)

    # (Optional) Friendly style available too:
    friendly = explain_friendly_style(constitution, driving)

    # 12) Diet + Lifestyle
    rec = recommend_for({"constitution": constitution})

    return {
        "confidence": confidence,
        "structure": structure,
        "anchors": anchors,
        "doctor": doctor,
        "friendly": friendly,
        "diet": rec["diet"],
        "lifestyle": rec["lifestyle"]
    }

In [None]:

def get_dominant_ordered_traits(constitution):
    dom = constitution["dominant"]  # "vata" / "pitta" / "kapha"
    traits = constitution["matched_traits"]  # this is a list, but shape may vary

    scored = []
    for entry in traits:
        # Case 1: (trait, weights)
        if len(entry) == 2:
            trait, weights = entry
        # Case 2: (trait, weights, extra_value)
        elif len(entry) == 3:
            trait, weights, _ = entry
        # Otherwise skip invalid entry
        else:
            continue

        # weights can be dict or list
        if isinstance(weights, dict):
            score = weights.get(dom, 0.0)
        elif isinstance(weights, (list, tuple)) and len(weights) == 3:
            idx = {"vata": 0, "pitta": 1, "kapha": 2}[dom]
            score = weights[idx]
        else:
            score = float(weights) if isinstance(weights, (int, float)) else 0.0

        scored.append((trait, score))

    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[:8]

In [None]:

def adjust_pitta_boost(constitution, text):
    text = text.lower()
    pitta_boost = 0.0

    # Warm body / heat signals
    if "warm" in text or "heat" in text or "hot" in text:
        pitta_boost += 0.08

    # Irritation / anger / intensity signals
    if "irrit" in text or "anger" in text or "frustrat" in text:
        pitta_boost += 0.10

    # Apply boost
    constitution["pitta"] += pitta_boost
    constitution["vata"] -= pitta_boost * 0.5  # reduce Vata slightly to re-balance

    # Re-normalize so total = 1.0
    total = constitution["vata"] + constitution["pitta"] + constitution["kapha"]
    constitution["vata"] /= total
    constitution["pitta"] /= total
    constitution["kapha"] /= total

    return constitution

In [None]:

def _normalize_driving_traits(driving_traits, constitution, max_items=8):
    """
    driving_traits can be:
      - [("trait", 0.324), ...]                       # score as float
      - [("trait", {"vata":0.5,"pitta":0.2,"kapha":0.3}), ...]
      - [("trait", [v, p, k]), ...]
      - ["trait1", "trait2", ...]                      # no scores
    constitution needs: {"dominant": "vata"|"pitta"|"kapha"}
    """
    dom = (constitution.get("dominant") or "vata").lower()
    dom_idx = {"vata": 0, "pitta": 1, "kapha": 2}[dom]

    normalized = []
    for item in (driving_traits or []):
        # Case 1: plain string trait
        if isinstance(item, str):
            normalized.append((item, 0.0))
            continue

        # Case 2: tuple/list like (trait, weights)
        if isinstance(item, (list, tuple)) and len(item) >= 1:
            trait = str(item[0])

            score = 0.0
            if len(item) >= 2:
                weights = item[1]
                # float score
                if isinstance(weights, (int, float)):
                    score = float(weights)
                # dict per dosha
                elif isinstance(weights, dict):
                    score = float(weights.get(dom, 0.0))
                # list/tuple [v,p,k]
                elif isinstance(weights, (list, tuple)) and len(weights) >= 3:
                    try:
                        score = float(weights[dom_idx])
                    except Exception:
                        score = 0.0

            normalized.append((trait, score))
            continue

        # Fallback: unknown shape → stringify
        normalized.append((str(item), 0.0))

    # sort by score desc, keep top
    normalized.sort(key=lambda x: x[1], reverse=True)
    return normalized[:max_items]


def explain_friendly_style(constitution, driving_traits, max_items=8):
    dom = constitution["dominant"].capitalize()
    messages = {
        "Vata":  "Your energy tends to move quickly — creativity, expressiveness, fast thinking. "
                 "If balance slips, dryness, irregular digestion, or light sleep can show up.",
        "Pitta": "Your metabolism and mind are strong — intensity, focus, decisiveness. "
                 "If balance slips, heat, irritability, or overwork can show up.",
        "Kapha": "You are steady and grounded — calm, patient, resilient. "
                 "If balance slips, heaviness, sluggishness, or low motivation can show up.",
    }

    ordered = _normalize_driving_traits(driving_traits, constitution, max_items=max_items)
    ordered = merge_sleep_traits(ordered)  # <--- ADD THIS
    trait_lines = format_trait_list(ordered)

    text = (
        f"You show a **{dom}-dominant** constitution.\n"
        f"{messages.get(dom, '')}\n\n"
        "Things your body signals clearly:\n\n"
        f"{trait_lines if trait_lines else '(no strong signals detected)'}\n\n"
        "Your body works best when balance is maintained. I’ll guide your diet and daily rhythm next."
    )
    return text


def recommend_for(constitution):
    dom = constitution["dominant"].lower()

    if dom == "vata":
        diet = [
            "Warm cooked meals (soups, stews, kichadi)",
            "Healthy fats like ghee, coconut, olive oil",
            "Avoid cold salads, dry snacks and skipping meals"
        ]
        lifestyle = [
            "Sleep before 10:30 PM",
            "Keep a consistent routine",
            "Gentle yoga / stretching; avoid high intensity late evening"
        ]

    elif dom == "pitta":
        diet = [
            "Cooling foods like cucumber, coconut water, sweet fruits",
            "Avoid excessive spicy, sour, or fermented foods"
        ]
        lifestyle = [
            "Avoid late-night work (adds heat)",
            "Practice slow breathing & meditation daily"
        ]

    else:  # kapha
        diet = [
            "Light warm meals; reduce heavy dairy and fried foods",
            "Use ginger and black pepper to stimulate digestion"
        ]
        lifestyle = [
            "Regular morning physical activity",
            "Avoid oversleeping and daytime naps"
        ]

    return diet, lifestyle   # <-- THIS part is MOST IMPORTANT


# ---------- Minimal glue you can call right now ----------
# Supply your already-computed objects:
#   constitution = {"vata":..., "pitta":..., "kapha":..., "dominant":"pitta", "confidence":...}
#   driving_traits = [("Dosha=Vata", 1.0), ("Appetite=Irregular, Scanty", 0.267), ...]  # any of the accepted shapes
def make_friendly_output(constitution, driving_traits):
    text = explain_friendly_style(constitution, driving_traits)   # <- here
    diet, lifestyle = recommend_for(constitution)
    return text, diet, lifestyle

# Example (replace with your real objects):
#text, diet, life = make_friendly_output(constitution, driving_traits)
#print(text); print("\nDIET:", diet); print("\nLIFESTYLE:", life)

In [None]:

def finalize_constitution(constitution):
    # Recalculate dominant dosha based on final corrected scores
    scores = {
        "vata": constitution["vata"],
        "pitta": constitution["pitta"],
        "kapha": constitution["kapha"]
    }
    constitution["dominant"] = max(scores, key=scores.get)
    return constitution

In [None]:

# ---- robust reorder (handles float, dict, or [v,p,k] tuples) ----
def reorder_traits_by_dominance(driving_traits, dominant):
    idx_map = {"vata": 0, "pitta": 1, "kapha": 2}
    ordered = []
    for trait, weights in driving_traits:
        if isinstance(weights, (list, tuple)) and len(weights) == 3:
            score = float(weights[idx_map[dominant]])
        elif isinstance(weights, dict):
            score = float(weights.get(dominant, 0.0))
        else:  # already a scalar score
            score = float(weights)
        ordered.append((trait, score))
    ordered.sort(key=lambda x: x[1], reverse=True)
    return ordered

# ---- quick test runner (friendly style = 2) ----

In [None]:
# Build a usable root tree from trait_vpk_table keys (phase-live only)
def build_root_from_traits(trait_vpk_table, dim=48, device="cpu"):
    # 1) Prepare vocab and embedding table
    tokens = list(trait_vpk_table.keys())
    emb = TokenEmbedding(tokens, dim=dim, device=device)

    # 2) Make proper (token, vector) pairs (each vector must be a 1D torch.Tensor of length=dim)
    import torch
    vec_pairs = []
    for t in tokens:
        v = emb.lookup(t)                  # shape: [dim]
        if isinstance(v, torch.Tensor) and v.dim() == 1 and v.shape[0] == dim:
            vec_pairs.append((t, v))
    # 3) Build a small binary tree
    builder = TreeBuilderV2(device=device, dim=dim, mode="binary")
    return builder.build_tree(vec_pairs, sample_id="prakriti_root")

In [None]:

def deduplicate_trait_values(driving_traits):
    """
    driving_traits = [(trait_string, score), ...]
    Keep only highest scoring trait per category.
    """
    best = {}
    for trait, score in driving_traits:
        if "=" in trait:
            category = trait.split("=")[0].strip()
        else:
            category = trait

        if category not in best or score > best[category][1]:
            best[category] = (trait, score)

    return list(best.values())

In [None]:

import pickle

CONFIDENCE_THRESHOLD = 0.45
TREE_SAVE_PATH = "chaturya_tree.pkl"

def load_tree():
    try:
        with open(TREE_SAVE_PATH, "rb") as f:
            return pickle.load(f)
    except Exception:
        return None

def save_tree(root):
    try:
        with open(TREE_SAVE_PATH, "wb") as f:
            pickle.dump(root, f)
    except Exception:
        pass

# ---- RL "repair" stub uses your Phase2Env safely if present, else no-op ----
def rl_generative_repair(root, query):
    try:
        # Your Phase2Env expects (builder, data, feature_vectors, ...) in the full pipeline.
        # For phase-live console we keep this harmless & local.
        return root  # no-op in phase-live; your RL path stays intact elsewhere
    except Exception:
        return root

# analyze_user_input(query) must exist in memory (your phase-live cell).
# It should return (constitution, driving_traits, retrieval_score).
# DecoderV1 → TreeEncoderWithAttention now works with num_heads=5. 3

def chatbot_answer(query):
    constitution, driving_traits, retrieval_score = compute_user_vpk_mapped(query, trait_vpk_table)

    # NEW CHECK: If matched traits are empty → ask user to describe traits
    if not constitution.get("matched_traits"):
        return (
            "I didn’t detect any physical or lifestyle traits in your input.\n"
            "Please describe what you experience physically.\n"
            "Example: 'dry skin', 'oily hair', 'irregular hunger', 'light sleep', 'heat in body'.",
            [], [], None
        )

    driving_traits = deduplicate_trait_values(driving_traits)
    text, diet, lifestyle = make_friendly_output(constitution, driving_traits)

    return text, diet, lifestyle, None

def init_or_load_root(trait_vpk_table):
    root = load_tree()
    if root is not None:
        return root
    print("No saved tree found → building tree from embeddings...")
    root = build_root_from_traits(trait_vpk_table, dim=50, device="cpu")
    save_tree(root)
    return root

def start_console_chatbot(root):
    print("\n=== Chaturya Ayurvedic Chatbot ===")
    print("Type your symptoms or traits. Type 'exit' to stop.\n")

    while True:
        user = input("You: ").strip()
        if user.lower() in ["exit", "quit", "bye"]:
            print("Chaturya: Wishing you balance and well-being.")
            break

        text, diet, lifestyle, new_root = chatbot_answer(user)  # <-- FIXED (remove root)

        # If no traits detected
        if diet == [] and lifestyle == []:
            print("\nChaturya:", text, "\n")
            continue

        print("\n--- RESULT ---\n")
        print(text)

        print("\nDIET RECOMMENDATIONS:")
        for d in diet:
            print("• " + d)

        print("\nLIFESTYLE GUIDANCE:")
        for l in lifestyle:
            print("• " + l)

        print("\n----------------\n")

        # RL update placeholder (still safe to leave)
        if new_root is not None:
            root = new_root

    save_tree(root)

#start_console_chatbot()

In [None]:

def interpret_trait(trait_name):
    interpretations = {
        "texture of skin (dry, pigments and aging)":
            "Dry or easily dehydrated skin suggests elevated Vata regulating moisture in the body.",
        "sleep patterns (short)":
            "Your sleep is light and easily disturbed — a key Vata characteristic.",
        "sleep patterns (moderate)":
            "Your sleep is generally balanced but affected by stress or irregular routine.",
        "sleep patterns (long)":
            "Deep and heavy sleep suggests Kapha grounding influence.",
    }
    return interpretations.get(trait_name.lower(), None)

In [None]:

def merge_sleep_traits(traits):
    sleep_signals = [t for t in traits if "sleep" in t[0].lower()]
    if not sleep_signals:
        return traits

    # pick the strongest sleep signal
    sleep_signals.sort(key=lambda x: x[1], reverse=True)
    best = sleep_signals[0]

    interpreted = ("Your sleep tends to be light and easily disturbed — "
                   "a classic sign of Vata affecting the nervous system.")

    # remove all sleep-related entries from list
    traits = [t for t in traits if "sleep" not in t[0].lower()]

    # add interpreted summary back with strongest score
    traits.insert(0, (interpreted, best[1]))

    return traits

In [None]:

def format_trait_list(driving_traits):
    lines = []
    for trait, score in driving_traits:
        interpretation = interpret_trait(trait.lower())
        if interpretation:
            lines.append(f"• {interpretation} (signal strength: {score:.3f})")
        else:
            lines.append(f"• {trait} (signal strength: {score:.3f})")
    return "\n".join(lines)

In [None]:

from sentence_transformers import SentenceTransformer
import numpy as np

model_embed = SentenceTransformer('all-MiniLM-L6-v2')

trait_texts = list(trait_vpk_table.keys())
trait_vectors = model_embed.encode(trait_texts, normalize_embeddings=True)

# Store for fast lookup
trait_index = list(zip(trait_texts, trait_vectors))

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
trait_synonyms = {}

In [None]:

import pandas as pd

df = pd.read_csv("/content/Updated_Prakriti_With_Features.csv")

trait_vpk_table = {}  # new corrected table

for col in df.columns:
    if df[col].dtype == object:  # only categorical columns
        values = df[col].dropna().unique()
        for v in values:
            key = f"{col}={v}".strip()
            # initialize placeholder VPK, will fill next step
            trait_vpk_table[key] = {"vata":0, "pitta":0, "kapha":0}

In [None]:

for col in df.columns:
    if df[col].dtype == object:
        grouped = df.groupby(col)["Dosha"]  # Uses your labeled Prakriti column
        for value, subset in grouped:
            key = f"{col}={value}".strip()
            counts = subset.value_counts(normalize=True).to_dict()
            trait_vpk_table[key] = {
                "vata":  counts.get("Vata", 0.0),
                "pitta": counts.get("Pitta", 0.0),
                "kapha": counts.get("Kapha", 0.0)
            }

In [None]:

import json

with open("trait_vpk_table.json", "w") as f:
    json.dump(trait_vpk_table, f, indent=2)

In [None]:

from tokenizer_and_embedding import TokenEmbedding, build_trait_embedding_index

embedder = TokenEmbedding(list(trait_vpk_table.keys()), dim=48, device="cpu")
trait_index = build_trait_embedding_index(trait_vpk_table, embedder)

In [None]:

normalized_trait_vpk = {}

for raw, vpk in trait_vpk_table.items():
    norm = normalize_trait_name(raw)
    normalized_trait_vpk[norm] = vpk

trait_vpk_table = normalized_trait_vpk

In [None]:

query = "dry skin light sleep"
constitution, driving, score = compute_user_vpk_mapped(query, trait_vpk_table)
text, diet, lifestyle = make_friendly_output(constitution, driving)
print(text)
print(diet)
print(lifestyle)

You show a **Vata-dominant** constitution.
Your energy tends to move quickly — creativity, expressiveness, fast thinking. If balance slips, dryness, irregular digestion, or light sleep can show up.

Things your body signals clearly:

• Your sleep tends to be light and easily disturbed — a classic sign of Vata affecting the nervous system. (signal strength: 0.230)
• Dry or easily dehydrated skin suggests elevated Vata regulating moisture in the body. (signal strength: 0.250)

Your body works best when balance is maintained. I’ll guide your diet and daily rhythm next.
['Warm cooked meals (soups, stews, kichadi)', 'Healthy fats like ghee, coconut, olive oil', 'Avoid cold salads, dry snacks and skipping meals']
['Sleep before 10:30 PM', 'Keep a consistent routine', 'Gentle yoga / stretching; avoid high intensity late evening']


# Phase live 2.0

In [None]:
DIM = 48  # <- use 48 across TokenEmbedding, TreeBuilderV2, DecoderV1, TLite, etc.
CONFIDENCE_THRESHOLD = 0.45

In [None]:

import re, torch, torch.nn.functional as F

def universal_tokenizer(text: str):
    if not text: return []
    return re.findall(r'\d+\.\d+|\d+|[A-Za-z]+|[+\-*/^=():]', text.lower())

class TokenEmbedding:
    def __init__(self, vocab, dim=DIM, device='cpu'):
        self.dim, self.device = dim, device
        self.vocab = ['<unk>'] + list(vocab)
        self.word2idx = {w:i for i,w in enumerate(self.vocab)}
        self.emb = torch.randn(len(self.vocab), dim, device=device) / (dim**0.5)

    def lookup(self, token:str) -> torch.Tensor:
        return self.emb[self.word2idx.get(token, 0)]

def build_trait_embedding_index(trait_names, embedder):
    idx = {}
    for name in trait_names:
        toks = universal_tokenizer(name)
        if not toks: continue
        v = sum(embedder.lookup(t) for t in toks) / len(toks)
        v = v / (v.norm() + 1e-9)
        idx[name] = v
    return idx

def semantic_match_traits(query, trait_index, embedder, top_k=8):
    toks = universal_tokenizer(query)
    if not toks: return []
    q = sum(embedder.lookup(t) for t in toks) / len(toks)
    q = q / (q.norm() + 1e-9)
    sims = []
    for trait, vec in trait_index.items():
        s = F.cosine_similarity(q.unsqueeze(0), vec.unsqueeze(0)).item()
        sims.append((trait, s))
    sims.sort(key=lambda x: x[1], reverse=True)
    strong = [t for t,s in sims if s >= 0.72][:top_k]
    med = [t for t,s in sims if 0.55 <= s < 0.72][:max(0, top_k-len(strong))]
    return strong + med

In [None]:

def build_trait_catalog_from_dataset_columns(df_like_columns):
    """
    Input: list of column names (exactly the 30 you've shown).
    Output: a set of *canonical* trait phrases we expect, e.g.:
        "height (short)", "body size (slim)", ...
    If you also have value vocabularies per column, you can pass a dict instead.
    For now we'll just accept already-canonicalized strings you built offline.
    """
    # If you already have canonical keys (as in your 93-list), just return them.
    # Here, we keep a hook if you later want to generate from unique values per column.
    return set()

def normalize_trait_key(k: str) -> str:
    # unify case/spaces to match your 93-key style as in your sample
    return k.strip().lower()

In [None]:

def _get_vpk_tuple(weights):
    # weights can be dict {'vata':..,'pitta':..,'kapha':..} or list/tuple [v,p,k]
    if isinstance(weights, dict):
        return (float(weights.get('vata',0)), float(weights.get('pitta',0)), float(weights.get('kapha',0)))
    if isinstance(weights, (list, tuple)) and len(weights) >= 3:
        return (float(weights[0]), float(weights[1]), float(weights[2]))
    # fallback single scalar: treat as vata
    if isinstance(weights, (int, float)):
        return (float(weights), 0.0, 0.0)
    return (1/3, 1/3, 1/3)

def compute_user_vpk_mapped(query, trait_vpk_table, trait_index, embedder, top_k=8):
    """
    Returns: constitution(dict), driving_traits(list[(trait,score)]), retrieval_score(float)
    - constitution has keys: vata, pitta, kapha, dominant, confidence, matched_traits
    - driving_traits is ranked by dominant dosha's weight where available
    """
    matched = semantic_match_traits(query, trait_index, embedder, top_k=top_k)
    retrieval_score = 0.0
    if matched:
        # simple retrieval score: average of top-5 cosine ranks approximated by list length
        retrieval_score = min(1.0, len(matched) / float(top_k))

    vectors = []
    driving_traits = []
    for t in matched:
        if t in trait_vpk_table:
            w = trait_vpk_table[t]
            v,p,k = _get_vpk_tuple(w)
            vectors.append((v,p,k))
            # for driving traits, store the *dominant* weight later
            driving_traits.append((t, {'vata':v, 'pitta':p, 'kapha':k}))

    if not vectors:
        vectors = [(1/3, 1/3, 1/3)]

    v = sum(x[0] for x in vectors) / len(vectors)
    p = sum(x[1] for x in vectors) / len(vectors)
    k = sum(x[2] for x in vectors) / len(vectors)

    dom_name = ['vata','pitta','kapha'][[v,p,k].index(max(v,p,k))]
    balance = abs(v-p) + abs(p-k) + abs(v-k)
    confidence = max(0.0, min(1.0, max(v,p,k) * (1 - 0.5*balance)))

    constitution = {
        "vata": round(v,3), "pitta": round(p,3), "kapha": round(k,3),
        "dominant": dom_name, "confidence": round(confidence,3),
        "matched_traits": matched
    }
    return constitution, driving_traits, retrieval_score

In [None]:

trait_explanations = {
    "height (short)": "Shorter body frame is commonly seen in Vata-dominant constitutions.",
    "height (average)": "Balanced height indicates neutral structural development.",
    "height (tall)": "Taller body frames are more common in Kapha constitutions.",
    "slim body frame, difficulty gaining weight": "Leanness and difficulty gaining weight strongly reflect elevated Vata.",
    "large/heavy frame, gains weight easily": "Heaviness and easy weight gain are associated with Kapha.",
    "texture of skin (dry, pigments and aging)": "Dry or easily dehydrated skin suggests elevated Vata affecting moisture.",
    "complexion (fair-skin sunburns easily)": "Sensitive, heat-reactive skin suggests Pitta influence.",
    "regular balanced sleep": "Your sleep rhythm is stable and balanced.",
    "light sleep / wakes easily": "Light or easily disturbed sleep is a sign of Vata affecting the nervous system.",
}

In [None]:

def _normalize_driving_traits(driving_traits, constitution, max_items=8):
    dom = (constitution.get("dominant") or "vata").lower()
    dom_idx = {"vata":0,"pitta":1,"kapha":2}[dom]

    out = []
    for item in driving_traits or []:
        if isinstance(item, str):
            out.append((item, 0.0))
            continue

        if isinstance(item, (list, tuple)) and len(item) >= 1:
            t = str(item[0])
            w = 0.0

            if len(item) >= 2:
                weights = item[1]
                if isinstance(weights, (int, float)):
                    w = float(weights)
                elif isinstance(weights, dict):
                    w = float(weights.get(dom, 0.0))
                elif isinstance(weights, (list, tuple)) and len(weights) >= 3:
                    try:
                        w = float(weights[dom_idx])
                    except:
                        w = 0.0

            out.append((t, w))
            continue

        out.append((str(item), 0.0))

    # Sort by strength
    out.sort(key=lambda x: x[1], reverse=True)

    # FILTER: keep only meaningful traits (strength >= 0.18)
    filtered = [(t, w) for (t, w) in out if w >= 0.18]

    # If filtering removes everything → show top 2 signals
    if not filtered:
        filtered = out[:2]

    return filtered[:max_items]

def collapse_by_dimension(driving_traits, trait_dimensions):
    grouped = {}
    for trait, weight in driving_traits:
        dim = trait_dimensions.get(trait, None)
        if dim is None:
            # traits we don't yet have labeled → keep as-is, no grouping
            grouped[trait] = weight
            continue
        if dim not in grouped or weight > grouped[dim][1]:
            grouped[dim] = (trait, weight)

    # Flatten result
    result = []
    for val in grouped.values():
        if isinstance(val, tuple):
            result.append(val)
        else:
            result.append((val,0.0))
    return result

def explain_friendly_style(constitution, driving_traits, max_items=6):
    dom = constitution["dominant"].capitalize()

    messages = {
        "Vata": "Your energy moves quickly — creativity, expression, fast thinking. If out of balance, dryness, irregular digestion, or light sleep may appear.",
        "Pitta": "Your metabolism and focus are strong — decisiveness, clarity. If out of balance, heat, irritability, or intensity may rise.",
        "Kapha": "Steady and grounded — patience, resilience. If out of balance, heaviness or low motivation may develop."
    }

    ordered = _normalize_driving_traits(driving_traits, constitution, max_items=max_items)
    ordered = merge_sleep_traits(ordered)

    trait_text = explain_traits_pretty(ordered)

    return (
        f"You show a **{dom}-dominant** constitution.\n"
        f"{messages.get(dom,'')}\n\n"
        f"Things your body signals clearly:\n\n"
        f"{trait_text}\n\n"
        f"Your body works best when balance is maintained. I’ll guide your diet and daily rhythm next."
    )
def recommend_for(constitution):
    dom = (constitution.get("dominant") or "").lower()
    if dom == "pitta":
        return ([
            "Cooling foods like cucumber, coconut water, sweet fruits",
            "Avoid excessive spicy, sour, or fermented foods",
        ],[
            "Avoid late-night work (adds heat)",
            "Practice slow breathing & meditation daily",
        ])
    if dom == "kapha":
        return ([
            "Light warm meals; reduce heavy/dairy foods",
            "Use ginger and black pepper to stimulate metabolism",
        ],[
            "Regular morning physical activity",
            "Avoid oversleeping / daytime naps",
        ])
    # default vata
    return ([
        "Warm cooked meals (soups, stews, kichadi)",
        "Healthy fats like ghee, coconut, olive oil",
        "Avoid cold salads, dry snacks and skipping meals",
    ],[
        "Sleep before 10:30 PM",
        "Keep a consistent routine",
        "Gentle yoga / stretching; avoid high intensity late evening",
    ])
def merge_sleep_traits(ordered):
    final = []
    sleep_score = 0.0

    for trait, score in ordered:
        t = trait.lower()
        if "sleep patterns" in t or "light sleep" in t or "wakes easily" in t:
            sleep_score = max(sleep_score, score)  # keep strongest
        else:
            final.append((trait, score))

    if sleep_score > 0:
        final.insert(0, (
            "Your sleep becomes easily disturbed or inconsistent — a classic Vata influence on the nervous system.",
            sleep_score
        ))
    return final

def make_friendly_output(constitution, driving_traits):
    # Normalize + rank driving traits
    ordered = _normalize_driving_traits(driving_traits, constitution, max_items=6)
    ordered = merge_sleep_traits(ordered)
    # Convert dataset trait labels → meaningful Ayurvedic phrases
    readable_traits = []
    for trait, score in ordered:
        key = trait.lower().strip()
        explanation = trait_explanations.get(key, trait)  # fallback to raw label if missing
        readable_traits.append((explanation, score))

    # Compose text output using readable trait explanations
    lines = "\n".join(
        f"• {txt} (signal strength: {round(score,3)})"
        for txt, score in readable_traits
    ) if readable_traits else "(no strong signals detected)"

    dom = constitution["dominant"].capitalize()
    text = (
        f"You show a **{dom}-dominant** constitution.\n"
        f"{AYURVEDIC_MESSAGES[dom]}\n\n"
        "Things your body signals clearly:\n\n"
        f"{lines}\n\n"
        "Your body works best when balance is maintained. I’ll guide your diet and daily rhythm next."
    )

    diet, lifestyle = recommend_for(constitution)
    return text, diet, lifestyle

In [None]:
AYURVEDIC_MESSAGES = {
    "Vata": "Your energy tends to move quickly — creativity, expressiveness, fast thinking. If balance slips, dryness, irregular digestion, or light sleep can show up.",
    "Pitta": "Your metabolism and mind are strong — intensity, focus, decisiveness. If balance slips, heat, irritability, or overwork can show up.",
    "Kapha": "You are steady and grounded — calm, patient, resilient. If balance slips, heaviness, sluggishness, or low motivation can show up."
}

In [None]:

def explain_traits_pretty(ordered_traits):
    lines = []
    for trait, score in ordered_traits:
        phrase = trait_explanations.get(trait.lower(), None)
        if phrase:
            lines.append(f"• {phrase} (signal strength: {score:.3f})")
        else:
            lines.append(f"• {trait} (signal strength: {score:.3f})")
    return "\n".join(lines) if lines else "(no strong signals detected)"

In [None]:

symptom_to_trait = {
    "breathing problem": [
        ("general feel of skin (dry and thin, cool to touch, rough)", "vata"),
        ("stress levels (high)", "vata"),
        ("sleep patterns (short)", "vata")
    ],
    "shortness of breath": [
        ("general feel of skin (dry and thin, cool to touch, rough)", "vata"),
        ("digestion quality (weak)", "vata"),
    ],
    "anxiety": [
        ("sleep patterns (short)", "vata"),
        ("stress levels (high)", "vata")
    ],
    "overthinking": [
        ("sleep patterns (light / wakes easily)", "vata"),
        ("stress levels (high)", "vata")
    ],
    "body heat": [
        ("complexion (fair-skin sunburns easily)", "pitta"),
        ("digestion quality (strong)", "pitta")
    ],
    "anger / irritability": [
        ("stress levels (high)", "pitta"),
        ("sleep patterns (short)", "pitta")
    ],
    "lethargy": [
        ("body weight (heavy - difficulties in losing weight)", "kapha"),
        ("physical activity level (low)", "kapha")
    ],
    "slow digestion": [
        ("digestion quality (weak)", "kapha"),
        ("metabolism type (slow)", "kapha")
    ],
    "craving sweets": [
        ("liking tastes (sweet / sour / salty)", "kapha")
    ],
    "dry skin": [
        ("texture of skin (dry, pigments and aging)", "vata")
    ],
    "light sleep": [
        ("sleep patterns (short)", "vata")
    ]
}

In [None]:

def map_query_to_dataset_traits(query, top_k=5):
    query = query.lower()
    found = []

    # 1) Symptom → trait mapping
    for symptom, trait_list in symptom_to_trait.items():
        if symptom in query:
            for trait, _ in trait_list:
                found.append(trait)

    # 2) Existing dataset phrase matching
    for phrase, mapped_trait in trait_synonyms.items():
        if phrase in query:
            found.append(mapped_trait)

    # De-duplicate but keep first occurrences
    found = list(dict.fromkeys(found))
    return found[:top_k]

In [None]:

def compute_user_constitution(query):
    constitution, driving_traits, retrieval_score = compute_user_vpk_mapped(
        query,
        trait_vpk_table,
        trait_index,
        embedder
    )
    return constitution, driving_traits, retrieval_score

In [None]:

# ---- Prepare trait table from your 93-sample (illustrative) ----
# Put your full dict here. Keys must be canonical, e.g. 'height (short)'.
trait_vpk_table = {
    'moderate build, stable weight': {'vata':0.1276595745,'pitta':0.1914893617,'kapha':0.0},
    'slim body frame, difficulty gaining weight': {'vata':0.4545454545,'pitta':0.0303030303,'kapha':0.0},
    'large/heavy frame, gains weight easily': {'vata':0.05,'pitta':0.1,'kapha':0.3},
    'body weight (moderate - no difficulties in gaining or losing weight)': {'vata':0.1296296296,'pitta':0.2222222222,'kapha':0.0},
    'body weight (low - difficulties in gaining weight)': {'vata':0.4482758621,'pitta':0.0,'kapha':0.0},
    'body weight (heavy - difficulties in losing weight)': {'vata':0.1176470588,'pitta':0.0,'kapha':0.3529411765},
    'height (average)': {'vata':0.1363636364,'pitta':0.2045454545,'kapha':0.0454545455},
    'height (short)': {'vata':0.4210526316,'pitta':0.0526315789,'kapha':0.0263157895},
    'height (tall)': {'vata':0.0,'pitta':0.0555555556,'kapha':0.1666666667},
    'bone structure (large, broad shoulders , heavy bone structure)': {'vata':0.0,'pitta':0.1333333333,'kapha':0.4},
    # ... add the rest of your 93 keys here ...
    'texture of skin (dry, pigments and aging)': {'vata':0.25,'pitta':0.0,'kapha':0.0},
    'complexion (fair-skin sunburns easily)': {'vata':0.0,'pitta':0.208,'kapha':0.0},
    'sleep patterns (short)': {'vata':0.196,'pitta':0.0,'kapha':0.0},
    'sleep patterns (moderate)': {'vata':0.230,'pitta':0.0,'kapha':0.0},
    'sleep patterns (long)': {'vata':0.211,'pitta':0.0,'kapha':0.0},
}

# Build semantic index from trait keys
_embedder = TokenEmbedding(vocab=set(" ".join(trait_vpk_table.keys()).split()), dim=DIM, device='cpu')
_trait_index = build_trait_embedding_index(list(trait_vpk_table.keys()), _embedder)

def analyze_user_input(query):
    #constitution, driving, score = compute_user_constitution(query)
    constitution, driving, retr = compute_user_vpk_mapped(query, trait_vpk_table, _trait_index, _embedder, top_k=8)
    # friendly output
    text, diet, lifestyle = make_friendly_output(constitution, driving)
    return {
        "constitution": constitution,
        "driving": driving,
        "retrieval": retr,
        "text": text,
        "diet": diet,
        "lifestyle": lifestyle,
    }

# ---- Quick test ----
query = "dry skin light sleep short height"
constitution, driving, score = compute_user_constitution(query)
text, diet, lifestyle = make_friendly_output(constitution, driving)

print(text)
print("\nDIET:", diet)
print("\nLIFESTYLE:", lifestyle)

In [None]:

# ========= Phase-Live Mini Runtime (self-contained) =========
from typing import List, Dict, Tuple
import torch, torch.nn.functional as F
import re

# ---------------- Token embed + tokenizer ----------------
class TokenEmbedding:
    def __init__(self, vocab: List[str], dim:int=48, device:str='cpu'):
        self.dim, self.device = dim, device
        self.vocab = ['<unk>'] + list(vocab)
        self.word2idx = {w:i for i,w in enumerate(self.vocab)}
        self.embeddings = torch.randn(len(self.vocab), dim, device=device) / (dim**0.5)

    def lookup(self, token:str) -> torch.Tensor:
        idx = self.word2idx.get(token, 0)
        return self.embeddings[idx]

def universal_tokenizer(text: str) -> List[str]:
    if not text: return []
    return re.findall(r'\d+\.\d+|\d+|[A-Za-z]+|[+\-*/^=():]', text.lower())

# ---------------- Embedding index ----------------
def build_trait_embedding_index(trait_vpk_table: Dict[str, Dict[str,float]],
                                dim:int=48,
                                device:str='cpu') -> Tuple['TokenEmbedding', Dict[str, torch.Tensor]]:
    """Return (embedder, {trait: normalized_vec})."""
    tokens = list(trait_vpk_table.keys())
    emb = TokenEmbedding(tokens, dim=dim, device=device)
    index: Dict[str, torch.Tensor] = {}
    for trait in tokens:
        toks = universal_tokenizer(trait)
        if not toks:
            continue
        v = sum(emb.lookup(t) for t in toks) / len(toks)
        v = v / (v.norm() + 1e-9)
        index[trait] = v
    return emb, index

def semantic_match_traits(query: str,
                          trait_index: Dict[str, torch.Tensor],
                          embedder: TokenEmbedding,
                          top_k:int=8) -> List[Tuple[str, float]]:
    """Return list of (trait, similarity) sorted desc."""
    toks = universal_tokenizer(query)
    if not toks:
        return []
    q = sum(embedder.lookup(t) for t in toks) / len(toks)
    q = q / (q.norm() + 1e-9)
    sims = []
    for trait, vec in trait_index.items():
        s = F.cosine_similarity(q.unsqueeze(0), vec.unsqueeze(0)).item()
        sims.append((trait, s))
    sims.sort(key=lambda x: x[1], reverse=True)
    return sims[:top_k]

# ---------------- Constitution + driving traits ----------------
def _normalize_driving_traits(driving_traits, constitution, max_items=8):
    dom = (constitution.get("dominant") or "vata").lower()
    dom_idx = {"vata":0,"pitta":1,"kapha":2}[dom]
    out = []
    for item in driving_traits or []:
        # shapes: ("trait", float) or ("trait", {"vata":..}) or ("trait", [v,p,k]) or "trait"
        if isinstance(item, str):
            out.append((item, 0.0)); continue
        if isinstance(item, (list,tuple)) and len(item)>=1:
            t = str(item[0]); w = 0.0
            if len(item)>=2:
                weights = item[1]
                if isinstance(weights,(int,float)): w=float(weights)
                elif isinstance(weights,dict): w=float(weights.get(dom,0.0))
                elif isinstance(weights,(list,tuple)) and len(weights)>=3:
                    try: w=float(weights[dom_idx])
                    except: w=0.0
            out.append((t,w)); continue
        out.append((str(item),0.0))
    out.sort(key=lambda x:x[1], reverse=True)
    return out[:max_items]

def collapse_by_dimension(driving_traits, trait_dimensions):
    grouped = {}
    for trait, weight in driving_traits:
        dim = trait_dimensions.get(trait, None)
        if dim is None:
            grouped[trait] = weight  # keep unmatched traits directly
            continue
        if dim not in grouped or weight > grouped[dim][1]:
            grouped[dim] = (trait, weight)
    return [(t, w) for t, w in grouped.values()]

def compute_user_constitution(query: str,
                              trait_vpk_table: Dict[str, Dict[str,float]],
                              trait_index=None,
                              embedder: TokenEmbedding=None,
                              top_k:int=8):
    """
    Returns: (constitution_dict, driving_traits_list, retrieval_score_float)
    trait_vpk_table: {"trait string": {"vata":..., "pitta":..., "kapha":...}}
    """
    # Build/repair index if needed
    if not isinstance(trait_index, dict) or trait_index is None or len(trait_index)==0:
        embedder, trait_index = build_trait_embedding_index(trait_vpk_table)

    matches = semantic_match_traits(query, trait_index, embedder, top_k=top_k)
    if not matches:
        # fallback neutral
        v=p=k=1/3
        constitution = {
            "vata": round(v,3), "pitta": round(p,3), "kapha": round(k,3),
            "dominant": "vata", "confidence": 0.333, "matched_traits": []
        }
        return constitution, [], 0.0

    # collect dosha vectors for matched traits
    vecs = []
    driving = []
    for trait, sim in matches:
        vpk = trait_vpk_table.get(trait, {})
        v = float(vpk.get("vata", 0.0))
        p = float(vpk.get("pitta", 0.0))
        k = float(vpk.get("kapha", 0.0))
        vecs.append((v,p,k))
        # keep sim as weight to rank later
        driving.append((trait, sim))

    # average doshas
    if vecs:
        v = sum(x[0] for x in vecs)/len(vecs)
        p = sum(x[1] for x in vecs)/len(vecs)
        k = sum(x[2] for x in vecs)/len(vecs)
    else:
        v=p=k=1/3

    # confidence (simple, tunable)
    dom_name = ["vata","pitta","kapha"][[v,p,k].index(max(v,p,k))]
    balance = abs(v-p)+abs(p-k)+abs(v-k)
    retrieval_score = sum(s for _,s in matches)/len(matches)
    confidence = max(0.0, min(1.0, 0.5*retrieval_score + 0.5*(max(v,p,k))*(1 - 0.5*balance)))

    constitution = {
        "vata": round(v,3),
        "pitta": round(p,3),
        "kapha": round(k,3),
        "dominant": dom_name,
        "confidence": round(confidence,3),
        "matched_traits": [t for t,_ in matches]
    }

    # convert driving into multi-shape acceptable list:
    # here we pass dict per dosha so _normalize can pick dominant value
    # Convert matched traits into weighted driving traits structure
    driving_traits = []
    for trait, sim in driving:
        vpk = trait_vpk_table.get(trait, {})
        driving_traits.append(
            (trait,
             {
                 "vata": vpk.get("vata", 0.0),
                 "pitta": vpk.get("pitta", 0.0),
                 "kapha": vpk.get("kapha", 0.0)
             })
        )

    # Normalize to sort traits by dominant dosha strength
    normalized = _normalize_driving_traits(driving_traits, constitution, max_items=12)

    # Remove contradictions: keep only strongest signal per dimension
    collapsed = collapse_by_dimension(normalized, trait_dimensions)

    return constitution, collapsed, float(retrieval_score)
# ---------------- Friendly explanation + recs ----------------
def explain_friendly_style(constitution, driving_traits, max_items=8):
    dom = constitution["dominant"].capitalize()
    messages = {
        "Vata":  "Your energy tends to move quickly — creativity, expressiveness, fast thinking. "
                 "If balance slips, dryness, irregular digestion, or light sleep can show up.",
        "Pitta": "Your metabolism and mind are strong — intensity, focus, decisiveness. "
                 "If balance slips, heat, irritability, or overwork can show up.",
        "Kapha": "You are steady and grounded — calm, patient, resilient. "
                 "If balance slips, heaviness, sluggishness, or low motivation can show up.",
    }
    ordered = _normalize_driving_traits(driving_traits, constitution, max_items=max_items)
    trait_lines = "\n".join(
    explain_trait(t, s, trait_dimensions)
    for t, s in ordered
    ) or "(no strong signals detected)"
    return (
        f"You show a **{dom}-dominant** constitution.\n"
        f"{messages.get(dom, '')}\n\n"
        "Things your body signals clearly:\n\n"
        f"{trait_lines}\n\n"
        "Your body works best when balance is maintained. I’ll guide your diet and daily rhythm next."
    )

def recommend_for(constitution):
    dom = (constitution.get("dominant") or "").lower()
    if dom == "vata":
        diet = [
            "Warm cooked meals (soups, stews, kichadi)",
            "Healthy fats like ghee, coconut, olive oil",
            "Avoid cold salads, dry snacks and skipping meals",
        ]
        lifestyle = [
            "Sleep before 10:30 PM",
            "Keep a consistent routine",
            "Gentle yoga / stretching; avoid high intensity late evening",
        ]
    elif dom == "pitta":
        diet = [
            "Cooling foods like cucumber, coconut water, sweet fruits",
            "Avoid excessive spicy, sour, or fermented foods",
        ]
        lifestyle = [
            "Avoid late-night work (adds heat)",
            "Practice slow breathing & meditation daily",
        ]
    else:  # kapha
        diet = [
            "Light, warm meals; reduce heavy/dairy foods",
            "Use ginger and black pepper to stimulate metabolism",
        ]
        lifestyle = [
            "Regular morning physical activity",
            "Avoid oversleeping / daytime naps",
        ]
    return diet, lifestyle

def make_friendly_output(constitution, driving_traits):
    text = explain_friendly_style(constitution, driving_traits or [])
    diet, lifestyle = recommend_for(constitution)
    return text, diet, lifestyle

# ---------------- Quick test helper ----------------
# Expect `trait_vpk_table` to exist: {"trait": {"vata":..., "pitta":..., "kapha":...}, ...}
def quick_test(query: str, trait_vpk_table: Dict[str, Dict[str,float]],
               trait_index=None, embedder=None):
    constitution, driving, score = compute_user_constitution(
        query, trait_vpk_table, trait_index=trait_index, embedder=embedder
    )
    text, diet, life = make_friendly_output(constitution, driving)
    return {"text": text, "diet": diet, "lifestyle": life,
            "matched": constitution["matched_traits"], "retrieval": score}

In [None]:

def explain_trait(trait: str, weight: float, trait_dimensions: dict) -> str:
    # Convert trait into clean label if available
    label = trait_labels.get(trait, trait.replace("(", "").replace(")", "").replace("-", " ").strip())
    dim = trait_dimensions.get(trait, None)

    if dim == "body_frame":
        return f"• {label} suggests Vata influences body build and movement. (signal strength: {weight:.3f})"
    elif dim == "body_weight":
        return f"• {label} reflects how easily your body maintains grounding and nourishment. (signal strength: {weight:.3f})"
    elif dim == "height":
        return f"• {label} often appears in Vata-dominant constitutions. (signal strength: {weight:.3f})"
    elif dim == "skin_texture":
        return f"• {label} indicates how moisture and circulation are regulated. (signal strength: {weight:.3f})"
    elif dim == "sleep":
        return f"• {label} reflects the sensitivity of your nervous system. (signal strength: {weight:.3f})"

    # fallback
    return f"• {label} (signal strength: {weight:.3f})"

In [None]:

# Human-readable clean names for traits
trait_labels = {
    "slim body frame, difficulty gaining weight": "Slim / light body frame",
    "body weight (low - difficulties in gaining weight)": "Difficulty retaining weight",
    "height (short)": "Shorter height",
    "height (average)": "Medium height",
    "height (tall)": "Tall body frame",
    "large/heavy frame, gains weight easily": "Heavier / broad body frame",
    "moderate build, stable weight": "Stable balanced build",
    "texture of skin (dry, pigments and aging)": "Dry / rough skin texture",
    "complexion (fair-skin sunburns easily)": "Fair complexion prone to heat",
    # Add more here gradually if needed
}

In [None]:

trait_dimensions = {
    # Body Size / Frame
    "moderate build, stable weight": "body_frame",
    "slim body frame, difficulty gaining weight": "body_frame",
    "large/heavy frame, gains weight easily": "body_frame",

    # Body Weight
    "body weight (moderate - no difficulties in gaining or losing weight)": "body_weight",
    "body weight (low - difficulties in gaining weight)": "body_weight",
    "body weight (heavy - difficulties in losing weight)": "body_weight",

    # Height
    "height (average)": "height",
    "height (short)": "height",
    "height (tall)": "height",

    # Skin
    "texture of skin (dry, pigments and aging)": "skin_texture",
    "general feel of skin (dry and thin, cool to touch, rough)": "skin_texture",

    # Complexion
    "complexion (fair-skin sunburns easily)": "complexion",
    "complexion (white, pale, tans easily)": "complexion",

    # Sleep
    "sleep patterns (short)": "sleep",
    "sleep patterns (moderate)": "sleep",
    "sleep patterns (long)": "sleep",
}

In [None]:

dimension_explanations = {
    "body_frame": "Your constitution tends toward lightness or stability in physical structure.",
    "body_weight": "Your body's tendency to gain or lose weight reflects underlying dosha balance.",
    "height": "Height characteristics can indicate how Vata, Pitta, or Kapha influence body structure.",
    "skin_texture": "Skin texture signals how moisture, warmth, and circulation are regulated.",
    "complexion": "Your complexion suggests how heat and metabolic intensity are expressed.",
    "hair": "Hair qualities reveal internal heat, nourishment, and oil balance.",
    "sleep": "Your sleep rhythm reflects how your nervous system maintains balance.",
    "appetite": "Hunger pattern indicates how digestive fire (Agni) is functioning.",
    "digestion": "Digestive behavior expresses how smoothly nutrients are processed.",
    "climate_tolerance": "Your comfort in climates shows how heat and cold regulation operate.",
    "stress": "Stress response indicates how grounded or sensitive the nervous system is.",
}

In [None]:

trait_phrases = {
    "slim body frame, difficulty gaining weight":
        "Your constitution tends toward lightness and quick movement (Vata trait).",
    "height (short)":
        "Shorter body frame is more frequently seen in Vata-dominant constitutions.",
    "texture of skin (dry, pigments and aging)":
        "Dry or easily dehydrated skin indicates elevated Vata affecting moisture.",
    "sleep patterns (short)":
        "Light and easily disturbed sleep suggests Vata influence on the nervous system.",
    "body weight (low - difficulties in gaining weight)":
        "Your body tends not to retain weight, reflecting Vata predominance.",
}

In [None]:

# 1) Prepare your trait_vpk_table (you already have this dict)
# Example single entry structure:
# trait_vpk_table = {
#   "height (short)": {"vata": 0.42, "pitta": 0.05, "kapha": 0.03},
#   ...
# }

# 2) Build index once (optional: you can also let compute_user_constitution rebuild if missing)
embedder, trait_index = build_trait_embedding_index(trait_vpk_table, dim=48, device="cpu")

# 3) Run a test
res = quick_test("dry skin light sleep short height", trait_vpk_table, trait_index, embedder)
print(res["text"])
print("\nDIET:", res["diet"])
print("\nLIFESTYLE:", res["lifestyle"])
print("\nMATCHED:", res["matched"])
print("RETRIEVAL SCORE:", round(res["retrieval"],3))

You show a **Vata-dominant** constitution.
Your energy tends to move quickly — creativity, expressiveness, fast thinking. If balance slips, dryness, irregular digestion, or light sleep can show up.

Things your body signals clearly:

• Slim / light body frame suggests Vata influences body build and movement. (signal strength: 0.455)
• Difficulty retaining weight reflects how easily your body maintains grounding and nourishment. (signal strength: 0.448)
• Shorter height often appears in Vata-dominant constitutions. (signal strength: 0.421)

Your body works best when balance is maintained. I’ll guide your diet and daily rhythm next.

DIET: ['Warm cooked meals (soups, stews, kichadi)', 'Healthy fats like ghee, coconut, olive oil', 'Avoid cold salads, dry snacks and skipping meals']

LIFESTYLE: ['Sleep before 10:30 PM', 'Keep a consistent routine', 'Gentle yoga / stretching; avoid high intensity late evening']

MATCHED: ['moderate build, stable weight', 'slim body frame, difficulty gaining