# Section 1:

In [None]:
%%writefile /content/TLiteComponents.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging

# Configure root logging and create this module's logger.
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so across several of these notebook modules only the first call
# takes effect.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class InputProjector(nn.Module):
    """Learned linear map from ``input_dim`` to ``target_dim`` features.

    Used to reconcile feature vectors whose size differs from what a downstream
    module expects (see ``patch_model_input`` and ``TLiteV5_ReasoningModule``).

    Args:
        input_dim: size of the last dimension of incoming tensors.
        target_dim: size of the last dimension of the output.
        device: device the linear layer is placed on (default ``'cpu'``).
    """

    def __init__(self, input_dim, target_dim, device='cpu'):
        super().__init__()
        # input_dim/target_dim are plain ints (they size nn.Linear), so there is
        # no tensor to infer a device from; use the requested device instead.
        self.project = nn.Linear(input_dim, target_dim).to(device)
        logger.info(f"Initialized InputProjector: {input_dim} -> {target_dim}")

    def forward(self, x):
        """Project ``x`` of shape ``(..., input_dim)`` to ``(..., target_dim)``."""
        return self.project(x)


def patch_model_input(model, input_vector, expected_dim):
    """Wrap ``model.forward`` with an :class:`InputProjector` when sizes mismatch.

    Compares ``input_vector.shape[-1]`` with ``expected_dim``. If they differ,
    ``model.forward`` is replaced by a wrapper that first projects its first
    positional argument from ``actual_dim`` to ``expected_dim`` and passes any
    further positional/keyword arguments through unchanged. The projector is
    stored as ``model.input_projector``; when ``model`` is an ``nn.Module`` this
    registers it as a sub-module.

    Does nothing when the sizes already match, or (with a warning) when
    ``model`` already has an ``input_projector`` attribute. Returns None.
    """
    actual_dim = input_vector.shape[-1]
    if actual_dim == expected_dim:
        return
    if hasattr(model, 'input_projector'):
        logger.warning("Model already has input_projector, skipping patch")
        return

    logger.info(f"Auto-patching model input: {actual_dim} -> {expected_dim}")
    projector = InputProjector(actual_dim, expected_dim).to(input_vector.device)
    old_forward = model.forward

    def new_forward(x, *args, **kwargs):
        # Only the primary input is projected; extra arguments are forwarded as-is.
        return old_forward(projector(x), *args, **kwargs)

    model.forward = new_forward
    model.input_projector = projector

class ShapeTagEmbedder(nn.Module):
    """Learned embedding lookup for shape-tag strings.

    Tags are indexed by their position in ``shape_vocab``; any tag that is not
    in the vocabulary falls back to index 0 (the first vocabulary entry).
    """

    def __init__(self, shape_vocab, dim, device='cpu'):
        super().__init__()
        self.device = device
        self.shape_vocab = shape_vocab
        self.tag2idx = {tag: i for i, tag in enumerate(shape_vocab)}
        self.embedding = nn.Embedding(len(shape_vocab), dim).to(device)
        logger.info(f"ShapeTagEmbedder initialized with {len(shape_vocab)} tags, dim={dim}")

    def forward(self, shape_tag: str):
        """Return the ``(dim,)`` embedding for ``shape_tag``."""
        idx = self.tag2idx.get(shape_tag, 0)
        # Build the index on the embedding's *current* device rather than the
        # construction-time self.device, so the module still works after
        # .to()/.cuda() moves its weights.
        idx_tensor = torch.tensor([idx], dtype=torch.long, device=self.embedding.weight.device)
        return self.embedding(idx_tensor).squeeze(0)

class TreeEncoderWithAttention(nn.Module):
    """Recursively encode a tree into a single ``dim``-sized vector.

    * Leaf: the token vector from ``get_vector_fn(node)`` plus a learned
      embedding of ``node.shape_tag`` (``'unknown'`` when the attribute is
      absent). Token vectors whose last dimension is neither 1 nor ``dim`` are
      first projected to ``dim`` by a lazily created ``nn.Linear``. If the token
      vector is missing or not a tensor, the shape embedding alone is returned.
    * Internal node: the children's encodings are combined with multi-head
      self-attention, mean-pooled over children and layer-normalised.
    * Objects without ``is_leaf``/``get_vector``, and internal nodes whose
      children are all falsy, encode to zeros.

    NOTE: projection layers are created on first use. They are therefore absent
    from an optimizer built earlier, and a fresh model loading a checkpoint that
    contains them will report unexpected keys unless they already exist (or
    ``strict=False`` is used).
    """

    def __init__(self, dim: int, num_heads: int = 5, device: str = 'cpu'):
        super().__init__()
        self.dim = dim
        self.device = device
        # Lazily created leaf projections keyed "in_dim>dim". A ModuleDict (not a
        # plain dict) registers them as sub-modules, so they appear in
        # parameters()/state_dict() and follow .to().
        self.projectors = nn.ModuleDict()
        if num_heads < 1 or dim % num_heads != 0:
            # nn.MultiheadAttention requires embed_dim % num_heads == 0: use the
            # largest head count <= max(1, dim // 4) that divides dim.
            num_heads = max(h for h in range(1, max(1, dim // 4) + 1) if dim % h == 0)
            logger.warning(f"Adjusted num_heads to {num_heads} for embed_dim={dim}")
        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads).to(device)
        self.norm = nn.LayerNorm(dim).to(device)
        self.shape_embedder = ShapeTagEmbedder(
            shape_vocab=['unknown', 'text', 'circle', 'line', 'arrow', 'box'],
            dim=dim, device=device
        )
        logger.info(f"Initialized TreeEncoderWithAttention with dim={dim}")

    def _ensure_dim(self, vec: torch.Tensor) -> torch.Tensor:
        """Move ``vec`` to this module's device/dtype and project its last dim to ``self.dim``."""
        param_dtype = self.norm.weight.dtype
        vec = vec.to(device=self.device, dtype=param_dtype)
        in_dim = vec.shape[-1]
        if in_dim == self.dim:
            return vec
        key = f"{in_dim}>{self.dim}"
        if key not in self.projectors:
            self.projectors[key] = nn.Linear(in_dim, self.dim).to(device=self.device, dtype=param_dtype)
            logger.info(f"Auto-projecting leaf: {in_dim} -> {self.dim}")
        return self.projectors[key](vec)

    def encode(self, node, get_vector_fn):
        """Return the encoding of ``node`` (``(dim,)`` for 1-D token vectors)."""
        if not hasattr(node, 'is_leaf') or not hasattr(node, 'get_vector'):
            logger.error(f"Invalid node type: {type(node)}")
            return torch.zeros(self.dim, device=self.device)

        if node.is_leaf():
            token_vec = get_vector_fn(node)
            shape_tag = getattr(node, 'shape_tag', 'unknown')
            shape_vec = self.shape_embedder(shape_tag)
            if token_vec is None or not isinstance(token_vec, torch.Tensor):
                logger.warning(f"Leaf missing token vector, using shape vector: {shape_tag}")
                return shape_vec
            token_vec = token_vec.to(device=self.device, dtype=shape_vec.dtype)
            if token_vec.dim() == 0 or token_vec.shape[-1] in (1, self.dim):
                # Already broadcast-compatible with the (dim,) shape embedding.
                return token_vec + shape_vec
            # Project to `dim` *before* adding: the sum would not broadcast otherwise.
            return self._ensure_dim(token_vec) + shape_vec

        vectors = [self.encode(child, get_vector_fn) for child in node.children if child]
        if not vectors:
            return torch.zeros(self.dim, device=self.device)
        # (num_children, batch=1, dim): nn.MultiheadAttention is constructed
        # without batch_first, so it takes sequence-first input.
        stacked = torch.stack(vectors).unsqueeze(1)
        attn_output, _ = self.attention(stacked, stacked, stacked)
        return self.norm(attn_output.squeeze(1).mean(dim=0))

    def forward(self, node, get_vector_fn=lambda n: n.get_vector()):
        """Encode the tree rooted at ``node``; see :meth:`encode`."""
        return self.encode(node, get_vector_fn)

class TLiteV5_ReasoningModule(nn.Module):
    """Residual MLP "reasoning" stack that maps feature vectors to positive scores.

    ``depth`` pre-norm residual blocks (LayerNorm -> Linear(dim, hidden_dim) ->
    GELU -> Linear(hidden_dim, dim)), a final LayerNorm, and a Linear(dim, 1)
    head followed by Softplus, so every score is strictly positive.

    Inputs whose last dimension differs from ``dim`` go through an
    :class:`InputProjector` created on the first such call. NOTE: because it is
    created lazily, its parameters are not in an optimizer built before that call.
    """

    def __init__(self, dim=50, hidden_dim=128, depth=4, device='cpu'):
        super().__init__()
        self.expected_dim = dim
        self.device = device
        self.projector = None  # lazily created InputProjector (see forward)
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, dim)
            ) for _ in range(depth)
        ]).to(device)
        self.final_norm = nn.LayerNorm(dim).to(device)
        self.head = nn.Sequential(
            nn.Linear(dim, 1),
            nn.Softplus()
        ).to(device)
        logger.info(f"Initialized TLiteV5_ReasoningModule with dim={dim}")

    def forward(self, x):
        """Score ``x`` of shape ``(..., in_dim)``; returns a tensor of shape ``(...)``.

        Returns a scalar zero tensor when ``x`` is None.

        Raises:
            ValueError: if ``in_dim`` differs both from ``dim`` and from the input
                size the lazily created projector was built for.
        """
        if x is None:
            logger.warning("Invalid input, returning zero")
            return torch.tensor(0.0, device=self.device)
        x = x.to(self.device)
        in_dim = x.shape[-1]
        if in_dim != self.expected_dim:
            if self.projector is None:
                self.projector = InputProjector(in_dim, self.expected_dim).to(self.device)
            elif self.projector.project.in_features != in_dim:
                # The single lazy projector only fits one input size; fail with a
                # clear message instead of an opaque matmul shape error.
                raise ValueError(
                    f"TLiteV5_ReasoningModule got input size {in_dim}, but its projector "
                    f"maps {self.projector.project.in_features} -> {self.expected_dim}"
                )
            x = self.projector(x)
        for layer in self.layers:
            x = x + layer(x)
        return self.head(self.final_norm(x)).squeeze(-1)

class TLiteExpert(nn.Module):
    """Single feed-forward expert for a Mixture-of-Experts layer.

    Computes ``fc2(gelu(fc1(layer_norm(x))))``, mapping ``(..., dim)`` to
    ``(..., dim)`` through a ``hidden_dim`` bottleneck.
    """

    def __init__(self, dim=50, hidden_dim=64, device='cpu'):
        super().__init__()
        # Stored so forward() can size its zero fallback for a None input.
        self.dim = dim
        self.norm = nn.LayerNorm(dim).to(device)
        self.fc1 = nn.Linear(dim, hidden_dim).to(device)
        self.fc2 = nn.Linear(hidden_dim, dim).to(device)
        self.device = device
        logger.info(f"Initialized TLiteExpert with dim={dim}")

    def forward(self, x):
        """Return the expert output for ``x``, or ``zeros(dim)`` if ``x`` is None."""
        if x is None:
            logger.warning("Invalid input, returning zero")
            return torch.zeros(self.dim, device=self.device)
        x = x.to(self.device)
        return self.fc2(F.gelu(self.fc1(self.norm(x))))


class TLiteRouter(nn.Module):
    """Top-k gating network for a Mixture-of-Experts layer.

    A linear gate scores all ``num_experts`` experts; the ``top_k`` highest
    scores are kept and softmax-normalised into mixing weights. ``top_k`` is
    clamped to ``num_experts``.
    """

    def __init__(self, dim=50, num_experts=8, top_k=2, device='cpu'):
        super().__init__()
        if top_k > num_experts:
            # torch.topk raises when k exceeds the number of experts.
            logger.warning(f"top_k={top_k} > num_experts={num_experts}; clamping top_k to {num_experts}")
            top_k = num_experts
        self.gate = nn.Linear(dim, num_experts).to(device)
        self.num_experts = num_experts
        self.top_k = top_k
        self.device = device
        logger.info(f"Initialized TLiteRouter with {num_experts} experts")

    def forward(self, x):
        """Route ``x`` of shape ``(..., dim)``.

        Returns:
            ``(topk_indices, topk_weights)``, each of shape ``(..., top_k)``: the
            selected expert ids (long) and their softmax weights. For ``x is None``
            both are zero vectors of length ``top_k`` (indices as long), matching
            the normal return types.
        """
        if x is None:
            logger.warning("Invalid input, returning zeros")
            return (
                torch.zeros(self.top_k, dtype=torch.long, device=self.device),
                torch.zeros(self.top_k, device=self.device),
            )
        x = x.to(self.device)
        scores = self.gate(x)
        topk_scores, topk_indices = torch.topk(scores, self.top_k, dim=-1)
        return topk_indices, F.softmax(topk_scores, dim=-1)


class TLiteV6(nn.Module):
    """Mixture-of-Experts scorer.

    The router picks ``top_k`` of ``num_experts`` :class:`TLiteExpert` networks
    for every input row; the expert outputs are mixed with the router weights
    and passed through a ``Linear(dim, 1) + Softplus`` head, so scores are positive.
    """

    def __init__(self, dim=50, hidden_dim=64, num_experts=8, top_k=2, device='cpu'):
        super().__init__()
        self.experts = nn.ModuleList([TLiteExpert(dim, hidden_dim, device) for _ in range(num_experts)])
        self.router = TLiteRouter(dim, num_experts, top_k, device)
        self.final_head = nn.Sequential(
            nn.Linear(dim, 1),
            nn.Softplus()
        ).to(device)
        self.device = device
        logger.info(f"Initialized TLiteV6 with {num_experts} experts")

    def forward(self, x):
        """Score ``x`` of shape ``(..., dim)``; returns a tensor of shape ``(...)``.

        A single ``(dim,)`` vector therefore yields a 0-d score and a ``(B, dim)``
        batch yields ``(B,)``. Returns a scalar zero tensor when ``x`` is None.
        """
        if x is None:
            logger.warning("Invalid input, returning zero")
            return torch.tensor(0.0, device=self.device)
        x = x.to(self.device)
        lead_shape = x.shape[:-1]
        # Flatten leading dims so single vectors and any batch layout share one
        # (rows, dim) dispatch path.
        flat = x.reshape(-1, x.shape[-1])
        topk_indices, topk_weights = self.router(flat)  # (rows, top_k) each

        mixed = torch.zeros_like(flat)
        for expert_id, expert in enumerate(self.experts):
            selected = topk_indices == expert_id  # (rows, top_k) bool
            rows = selected.any(dim=-1).nonzero(as_tuple=True)[0]
            if rows.numel() == 0:
                continue
            # top-k ids are distinct within a row, so this picks out the single
            # router weight this expert received for each selected row.
            weights = (topk_weights * selected).sum(dim=-1)[rows].unsqueeze(-1)
            # Out-of-place index_add keeps autograd happy.
            mixed = mixed.index_add(0, rows, weights * expert(flat[rows]))

        return self.final_head(mixed).squeeze(-1).reshape(lead_shape)

Writing /content/TLiteComponents.py


In [None]:

%%writefile /content/smart_utils.py

import re
import logging
import sympy as sp
from word2number import w2n
import torch
from typing import Any

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def smart_normalize(answer_text: Any) -> float:
    """Normalize any math answer to float, supporting numbers, words, and expressions.

    The answer is tried against these interpretations in order; the first one
    that succeeds wins:

    1. ``None`` -> ``0.0`` (logged); ``int``/``float`` -> ``float(answer_text)``.
    2. Otherwise the value is converted with ``str``, stripped and lower-cased,
       and every character except word characters, whitespace, ``.`` and the
       operators ``+ - * / ^ ( )`` is removed (so ``'1,234'`` becomes ``'1234'``).
    3. Plain decimal literal: ``'123.45'`` -> ``123.45``.
    4. Number words via ``word2number``: ``'one hundred and fifty'`` -> ``150.0``.
    5. Symbolic evaluation via ``sympy.sympify``: ``'sqrt(4)'`` -> ``2.0``,
       ``'pi / 2'`` -> ``1.5707...``. This is only attempted when every identifier
       in the text is one of the known math names (``sympy_locals`` below),
       because sympify evaluates its input with ``eval``.
    6. Fallback: mean of the numeric literals in the text
       (``'about 3 or 5 apples'`` -> ``4.0``), or ``0.0`` if there are none.

    For non-numeric inputs failures do not propagate: they are logged and end
    in step 6.
    """
    if answer_text is None:
        logger.warning("Normalization failed: answer is None, returning 0.0")
        return 0.0
    if isinstance(answer_text, (int, float)):
        return float(answer_text)

    text = ""
    try:
        text = str(answer_text).strip().lower()
        # Keep arithmetic operators and parentheses so expressions survive;
        # commas (thousands separators) and other punctuation are dropped.
        text = re.sub(r'[^\w\s.+\-*/^()]', '', text)
        # Try direct float conversion
        if re.match(r'^-?\d*\.?\d+$', text):
            return float(text)
        # Try word-to-number. word2number raises ValueError when no number words
        # are present; any failure here just means "not a number phrase".
        try:
            return float(w2n.word_to_num(text))
        except Exception:
            pass
        # Try symbolic parsing with comprehensive math functions
        sympy_locals = {
            'sin': sp.sin, 'cos': sp.cos, 'tan': sp.tan, 'cot': sp.cot, 'sec': sp.sec, 'csc': sp.csc,
            'pi': sp.pi, 'sqrt': sp.sqrt, 'log': sp.log, 'ln': sp.ln, 'exp': sp.exp,
            'arcsin': sp.asin, 'arccos': sp.acos, 'arctan': sp.atan,
            'Integral': sp.Integral, 'Sum': sp.Sum, 'Product': sp.Product
        }
        # sympify() runs eval(): refuse any identifier outside the allow-list
        # (e.g. 'chr', 'eval', attribute names) before handing the text over.
        # Numeric literals (including exponents such as 1e5) are blanked first
        # so their digits/'e' are not mistaken for identifiers.
        # NOTE(review): deeply nested powers like '9^9^9^9' may still be slow
        # to evaluate; consider a length/complexity cap if inputs are hostile.
        without_numbers = re.sub(r'\d*\.?\d+(?:e[+-]?\d+)?', ' ', text)
        unknown = set(re.findall(r'[^\W\d]\w*', without_numbers)).difference(sympy_locals)
        if unknown:
            raise ValueError(f"unsupported names for symbolic parsing: {sorted(unknown)}")
        expr = sp.sympify(text, evaluate=False, locals=sympy_locals)
        sol = float(expr.evalf())
        logger.debug(f"Symbolic normalization succeeded: {text} -> {sol}")
        return sol
    except Exception as e:  # best-effort normalizer: every failure goes to the fallback
        logger.warning(f"Normalization failed for {answer_text}: {e}, trying number extraction")

    # Fallback: extract numeric literals from the cleaned text and average them.
    num_matches = [float(n) for n in re.findall(r'-?\d*\.?\d+', text)]
    if num_matches:
        sol = sum(num_matches) / len(num_matches)
        logger.debug(f"Fallback succeeded: mean of numbers {num_matches} -> {sol}")
        return sol
    logger.warning("No numbers found, returning 0.0")
    return 0.0

Writing /content/smart_utils.py


In [None]:
!pip install word2number

Collecting word2number
  Downloading word2number-1.1.zip (9.7 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: word2number
  Building wheel for word2number (setup.py) ... [?25l[?25hdone
  Created wheel for word2number: filename=word2number-1.1-py3-none-any.whl size=5568 sha256=978771e9370ea0957453db825607427e61d206b0c00efdca18650e99f41dd1e7
  Stored in directory: /root/.cache/pip/wheels/5b/79/fb/d25928e599c7e11fe4e00d32048cd74933f34a74c633d2aea6
Successfully built word2number
Installing collected packages: word2number
Successfully installed word2number-1.1


In [None]:

%%writefile /content/smart_preprocessor_v2.py
import logging
from typing import Dict, Optional, Tuple, Union, List
from smart_utils import smart_normalize

# Root logging at INFO. NOTE: basicConfig() does nothing if the root logger
# already has handlers (e.g. configured by a module imported earlier).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SmartPreprocessorV2:
    """Turn raw dict samples into ``{'id', 'context', 'target', 'equation'}`` records.

    ``context_col``/``target_col`` name the fields to use. Any column that is not
    given is auto-detected from the first sample seen by :meth:`preprocess`
    (see :meth:`autodetect_fields`) and then reused for every later sample.
    """

    def __init__(self, context_col=None, target_col=None, debug=False):
        self.context_col = context_col
        self.target_col = target_col
        self.debug = debug
        self.min_context_len = 10  # Minimum char length to qualify as context

    def is_number(self, val):
        """Return True if ``float(val)`` succeeds (e.g. ``3``, ``"2.5"``, ``" 7 "``)."""
        try:
            float(val)
        except (TypeError, ValueError, OverflowError):
            return False
        return True

    @staticmethod
    def _is_missing(val) -> bool:
        """Return True for ``None`` and empty values; numbers (including 0) are never missing."""
        import numbers  # local: only needed for this check

        if val is None:
            return True
        if isinstance(val, numbers.Number):
            return False
        return not val

    def score_field(self, key, value) -> Tuple[float, float]:
        """Return ``(context_score, target_score)`` for one field value.

        Long (> ``min_context_len`` chars) strings containing a space look like
        context; numbers and numeric strings look like targets; other strings
        get a weak score on both; anything else scores ``(0, 0)``.
        """
        if value is None:
            return (0, 0)
        if isinstance(value, str):
            length = len(value.strip())
            if length > self.min_context_len and " " in value:
                return (1.0, 0.0)  # Likely context
            if self.is_number(value):
                return (0.0, 1.0)  # Likely numeric target
            return (0.5, 0.2)      # Possible mixed
        if isinstance(value, (int, float)):
            return (0.0, 1.0)
        return (0.0, 0.0)

    def autodetect_fields(self, sample: Dict, overwrite: bool = True):
        """Auto-detect the context and target columns from ``sample``.

        The context column is the field with the highest context score (ties
        broken by target score). The target column is the field with the
        highest target score (ties prefer less context-like fields), excluding
        the chosen context column. A column is only assigned when a candidate
        with a positive score exists.

        Args:
            sample: one raw record.
            overwrite: when False, columns that are already set are kept.
        """
        field_scores = {k: self.score_field(k, v) for k, v in sample.items()}

        if overwrite or not self.context_col:
            by_context = sorted(field_scores.items(), key=lambda kv: kv[1], reverse=True)
            context_candidates = [f for f, (c, _t) in by_context if c > 0]
            if context_candidates:
                self.context_col = context_candidates[0]

        if overwrite or not self.target_col:
            by_target = sorted(field_scores.items(), key=lambda kv: (kv[1][1], -kv[1][0]), reverse=True)
            target_candidates = [f for f, (_c, t) in by_target if t > 0 and f != self.context_col]
            if target_candidates:
                self.target_col = target_candidates[0]

        if self.debug:
            logger.info(f"[AutoDetect] Context → {self.context_col}, Target → {self.target_col}")

    def preprocess(self, sample: Dict) -> Optional[Dict]:
        """Return ``{'id', 'context', 'target', 'equation'}`` for ``sample``, or None.

        Missing columns are auto-detected first (without overwriting columns the
        caller set). Returns None (with a warning) when the context or target
        value is None or empty; numeric zero is a valid target. Numeric targets
        go through ``smart_normalize``; text targets containing ``,`` or ``;``
        are split into a list of labels, other text is kept as a stripped string.
        """
        if not self.context_col or not self.target_col:
            self.autodetect_fields(sample, overwrite=False)

        sid = sample.get('id', 'unknown')
        context_val = sample.get(self.context_col)
        target_val = sample.get(self.target_col)

        if self._is_missing(context_val) or self._is_missing(target_val):
            logger.warning(f"[{sid}] Missing context or target → Skipping sample.")
            return None

        if self.is_number(target_val) or isinstance(target_val, (int, float)) or self.is_number(str(target_val)):
            normalized_target = smart_normalize(target_val)
        else:
            target_str = str(target_val).strip()
            # Auto-split multi-label text
            if "," in target_str or ";" in target_str:
                normalized_target = [lab.strip() for lab in target_str.replace(";", ",").split(",") if lab.strip()]
            else:
                normalized_target = target_str

        return {
            'id': sid,
            'context': {self.context_col: str(context_val)},
            'target': normalized_target,
            'equation': sample.get('equation')
        }

Writing /content/smart_preprocessor_v2.py


In [None]:

%%writefile /content/tokenizer_and_embedding.py
import torch, re, logging
from typing import List

# Configure root logging and create this module's logger.
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so across several of these notebook modules only the first call
# takes effect.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class TokenEmbedding:
    """Lookup table of random token vectors with an ``'<unk>'`` fallback row.

    Row 0 belongs to ``'<unk>'``; the caller's ``vocab`` follows in order. Rows
    are ``torch.randn`` samples divided by ``sqrt(dim)`` and kept as a plain
    tensor in ``self.embeddings`` (not an ``nn.Parameter``).
    """

    def __init__(self, vocab: List[str], dim: int = 50, device: str = 'cpu'):
        self.dim = dim
        self.device = device
        self.vocab = ['<unk>'] + vocab
        # Duplicate tokens map to their last position, as with a dict comprehension.
        self.word2idx = dict(zip(self.vocab, range(len(self.vocab))))
        scale = dim ** 0.5
        self.embeddings = torch.randn(len(self.vocab), dim, device=device) / scale
        logger.info(f"TokenEmbedding: vocab_size={len(self.vocab)} dim={dim}")

    def lookup(self, token: str) -> torch.Tensor:
        """Return the embedding row for ``token``; unknown tokens map to row 0 (``'<unk>'``)."""
        row = self.word2idx.get(token, 0)
        return self.embeddings[row]

def universal_tokenizer(text: str) -> List[str]:
    """Split ``text`` into decimal numbers, integers, ASCII words and operator symbols.

    Recognised symbols are ``+ - * / ^ = ( ) :``; every other character
    (whitespace, commas, brackets, ...) is dropped. Empty or None input yields ``[]``.
    """
    # Alternation order matters: decimals are tried before bare integers.
    token_pattern = r'\d+\.\d+|\d+|[A-Za-z]+|[+\-*/^=():]'
    return re.findall(token_pattern, text) if text else []

Writing /content/tokenizer_and_embedding.py


In [None]:

%%writefile /content/target_processor.py
import torch
import logging
from typing import List, Union

# Root logging at INFO. NOTE: basicConfig() does nothing if the root logger
# already has handlers (e.g. configured by a module imported earlier).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MultiLabelTargetProcessor:
    """
    Handles multi-label targets for classification or regression.

    - Text labels: ``fit_labels`` builds a label -> index map (first-seen order)
      and ``encode`` returns a multi-hot float32 vector of length ``len(label_map)``.
    - Numeric labels: ``encode`` returns a float32 scalar ``value / max_value``,
      where ``max_value`` is the largest absolute value seen (at least 1.0), so
      fitted values land in ``[-1, 1]``.

    As soon as any numeric label is seen the processor is in numeric mode for
    all labels; text labels then encode to 0.0 (with an error logged).
    """

    def __init__(self):
        self.label_map = {}
        self.max_value = 1.0
        self.is_numeric_mode = False

    def _observe_numeric(self, value: Union[float, int]) -> None:
        """Enter numeric mode and grow ``max_value`` to cover ``abs(value)``."""
        self.is_numeric_mode = True
        # abs(): negative targets must be scaled into [-1, 1] too.
        self.max_value = max(self.max_value, abs(float(value)))

    def fit_labels(self, labels_list: List[Union[str, List[str], float, int]]):
        """
        Fit the label map / numeric scale from a dataset.

        Each item of ``labels_list`` is a single label (str or number) or a list
        of labels (multi-label case). Unsupported item types are logged and skipped.
        """
        for labels in labels_list:
            if isinstance(labels, (float, int)):
                self._observe_numeric(labels)
            elif isinstance(labels, str):
                self.label_map.setdefault(labels, len(self.label_map))
            elif isinstance(labels, list):
                for lab in labels:
                    if isinstance(lab, (float, int)):
                        self._observe_numeric(lab)
                    else:
                        self.label_map.setdefault(str(lab), len(self.label_map))
            else:
                logger.warning(f"Unsupported label type: {type(labels)}")

        if self.is_numeric_mode and self.label_map:
            logger.warning(
                "Both numeric and text labels were seen; numeric mode is used "
                "and text labels will encode to 0.0"
            )
        logger.info(f"Fitted label map: {self.label_map}")
        logger.info(f"Numeric mode: {self.is_numeric_mode}, Max value: {self.max_value}")

    def encode(self, labels: Union[str, List[str], float, int]) -> torch.Tensor:
        """
        Encode labels into a tensor.

        - Numeric mode: float32 scalar ``float(labels) / max_value``; values that
          cannot be converted are logged and encode to 0.0.
        - Otherwise: multi-hot float32 vector; labels not in the map are ignored.
        """
        if self.is_numeric_mode:
            try:
                return torch.tensor(float(labels) / self.max_value, dtype=torch.float32)
            except Exception as e:
                logger.error(f"Numeric encoding failed for {labels}: {e}")
                return torch.tensor(0.0, dtype=torch.float32)

        # Classification mode
        vec = torch.zeros(len(self.label_map), dtype=torch.float32)
        if isinstance(labels, str):
            idx = self.label_map.get(labels)
            if idx is not None:
                vec[idx] = 1.0
        elif isinstance(labels, list):
            for lab in labels:
                idx = self.label_map.get(str(lab))
                if idx is not None:
                    vec[idx] = 1.0
        else:
            logger.warning(f"Unsupported label type for encoding: {type(labels)}")

        return vec

    def decode(self, tensor: torch.Tensor) -> Union[str, float, List[str]]:
        """
        Decode a tensor back into labels or a numeric value.

        Numeric mode returns ``tensor.item() * max_value``; otherwise the labels
        whose entry is > 0.5, in label-map order.
        """
        if self.is_numeric_mode:
            return tensor.item() * self.max_value

        # Set membership keeps this linear in the number of labels.
        active = set((tensor > 0.5).nonzero(as_tuple=True)[0].tolist())
        return [lab for lab, idx in self.label_map.items() if idx in active]

Writing /content/target_processor.py


# Section 2:

In [None]:

%%writefile /content/TreeNodeV1.py
import torch
import logging
from typing import Any, Dict, Optional

# Configure root logging; this module logs under the fixed name "TreeNodeV1".
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so across several of these notebook modules only the first call
# takes effect.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("TreeNodeV1")


class TreeNodeV1:
    """
    Minimal, shape-free tree node.

    Stores a ``value``, an optional float32 ``vector``, optional label
    text/embedding/vector, a confidence, ordering info (``level``,
    ``spiral_index``) and at most ``max_children`` children.

    - NO shape normalization and NO TLite logic: ``shape_type`` is stored raw
      (``shape_tag`` takes precedence over ``shape_type`` when both are given).
    - Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        value: str,
        shape_type: Optional[str] = None,
        vector: Optional[Any] = None,
        confidence: float = 0.0,
        spiral_index: Optional[int] = None,
        id: Optional[str] = None,
        shape_tag: Optional[str] = None,
        level: int = 0,
        max_children: int = 10,
        label_text: Optional[str] = None,
        label_embedding: Optional[Any] = None,
        label_vector: Optional[Any] = None,
        **kwargs
    ):
        self.value = value
        self.id = id if id is not None else value

        # raw, unmanaged shape tag
        self.shape_type = shape_tag if shape_tag is not None else shape_type

        self.level = int(level)
        self.spiral_index = spiral_index
        self.max_children = int(max_children)
        self.children = []
        self.confidence = float(confidence)

        # Label info
        self.label_text = label_text
        self.label_embedding = None
        self.label_vector = None

        # Free-form per-node features and their confidences.
        self.sub_features: Dict[str, Any] = {}
        self.feature_confidence: Dict[str, float] = {}

        # Vectors are stored as float32 tensors, or None if missing/unconvertible.
        self._assign_vector_safe("vector", vector)
        self._assign_vector_safe("label_embedding", label_embedding, attr_name="label_embedding")
        self._assign_vector_safe("label_vector", label_vector, attr_name="label_vector")

    # ---------------- INTERNAL HELPERS ----------------
    def _assign_vector_safe(self, attr_label: str, value: Any, attr_name: Optional[str] = None):
        """Store ``value`` as a float32 tensor on ``attr_name`` (defaults to ``attr_label``).

        Tensors are copied (``detach().clone()``) so later in-place edits by the
        caller do not leak in; other values go through ``torch.tensor``. ``None``
        or a value that fails to convert leaves the attribute as ``None``
        (conversion errors are logged, never raised).
        """
        if attr_name is None:
            attr_name = attr_label
        try:
            if value is None:
                stored = None
            elif isinstance(value, torch.Tensor):
                stored = value.detach().clone().float()
            else:
                stored = torch.tensor(value, dtype=torch.float32)
        except Exception as e:
            logger.error(f"Vector assignment failed for {self.id}: {e}")
            stored = None
        setattr(self, attr_name, stored)

    @staticmethod
    def _to_serializable(x):
        """Convert tensors to (nested) Python lists/numbers; return other values unchanged."""
        if isinstance(x, torch.Tensor):
            try:
                return x.detach().cpu().tolist()
            except Exception as e:  # keep serialization best-effort
                logger.warning(f"Tensor serialization failed: {e}")
                return None
        return x

    # ---------------- API ----------------
    def store_vector(self, vector: Any):
        self._assign_vector_safe("vector", vector)

    def get_vector(self):
        return getattr(self, "vector", None)

    def get_confidence(self):
        return float(self.confidence)

    def add_child(self, child: "TreeNodeV1"):
        """Append ``child`` unless it is not a TreeNodeV1 or ``max_children`` is reached (both logged)."""
        if not isinstance(child, TreeNodeV1):
            logger.warning(f"Invalid child added to {self.id}: not TreeNodeV1")
            return
        if len(self.children) >= self.max_children:
            logger.warning(f"Node {self.id}: max children limit reached")
            return
        self.children.append(child)

    def is_leaf(self):
        return len(self.children) == 0

    # ---------------- SIMPLE SETTERS NEEDED BY PIPELINE ----------------
    def set_spiral_index(self, idx: int):
        self.spiral_index = idx

    def set_label_text(self, text: str):
        self.label_text = text

    def set_label_embedding(self, embedding: Any):
        self._assign_vector_safe("label_embedding", embedding, attr_name="label_embedding")

    def set_label_vector(self, vector: Any):
        self._assign_vector_safe("label_vector", vector, attr_name="label_vector")

    def get_label_vector(self):
        return getattr(self, "label_vector", None)

    def get_label_embedding(self):
        return getattr(self, "label_embedding", None)

    # ---------------- SERIALIZATION ----------------
    def to_dict(self):
        """Return a plain-Python summary of the node (children are only counted).

        Tensors, including tensor values inside ``sub_features``, are converted
        to lists so the result contains no torch objects.
        """
        to_plain = self._to_serializable
        return {
            "value": self.value,
            "id": self.id,
            "shape_type": self.shape_type,
            "level": self.level,
            "spiral_index": self.spiral_index,
            "confidence": self.confidence,
            "n_children": len(self.children),
            "label_text": self.label_text,
            "vector": to_plain(self.get_vector()),
            "label_vector": to_plain(self.get_label_vector()),
            "label_embedding": to_plain(self.get_label_embedding()),
            "sub_features": {k: to_plain(v) for k, v in self.sub_features.items()},
            "feature_confidence": self.feature_confidence,
        }

    def __repr__(self):
        return (
            f"<TreeNodeV1 id={self.id!r} value={self.value!r} "
            f"shape={self.shape_type!r} conf={self.confidence:.3f} "
            f"children={len(self.children)}>"
        )

Writing /content/TreeNodeV1.py


In [None]:

%%writefile /content/tree_node_utils.py
import logging
import torch
from typing import List
from TreeNodeV1 import TreeNodeV1

# Module logger under the fixed name "tree_node_utils".
logger = logging.getLogger("tree_node_utils")
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers (e.g. configured by a module imported earlier), so this format may
# not be the one in effect.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def sanitize_raw_nodes(raw_nodes: List, raise_on_invalid=False):
    """
    Keep only TreeNodeV1 instances that carry a usable tensor vector.

    Non-tensor vectors are converted to float32 tensors in place (via
    ``store_vector``); nodes whose vector is missing or cannot be converted
    are dropped and logged.

    Args:
        raw_nodes: candidate nodes (any objects).
        raise_on_invalid: raise instead of returning when anything was dropped.

    Returns:
        ``(good, stats)``: the accepted nodes in input order, and a dict with the
        counts ``"good"``, ``"invalid_type"`` (not a TreeNodeV1) and
        ``"none_vector"`` (vector missing or not convertible).

    Raises:
        ValueError: if ``raise_on_invalid`` is true and any node was dropped.
    """
    good = []
    none_vector_count = 0
    invalid_type_count = 0

    for n in raw_nodes:
        if not isinstance(n, TreeNodeV1):
            invalid_type_count += 1
            logger.warning("sanitize_raw_nodes: Invalid node type skipped: %s", type(n))
            continue

        node_id = getattr(n, "id", "<no-id>")
        vec = getattr(n, "vector", None)
        if vec is None:
            none_vector_count += 1
            logger.warning("sanitize_raw_nodes: Node %s skipped, vector is None", node_id)
            continue

        if not isinstance(vec, torch.Tensor):
            try:
                n.store_vector(torch.tensor(vec, dtype=torch.float32))
            except Exception:  # best-effort: any conversion failure rejects the node
                pass
            # store_vector() swallows its own errors and leaves vector=None, so
            # verify the result instead of trusting the call.
            if not isinstance(n.get_vector(), torch.Tensor):
                none_vector_count += 1
                logger.warning("sanitize_raw_nodes: Node %s skipped, vector conversion failed", node_id)
                continue

        good.append(n)

    stats = {
        "good": len(good),
        "invalid_type": invalid_type_count,
        "none_vector": none_vector_count
    }
    logger.info("sanitize_raw_nodes stats: %s", stats)
    if raise_on_invalid and (invalid_type_count > 0 or none_vector_count > 0):
        raise ValueError(f"sanitize_raw_nodes found invalid nodes: {stats}")
    return good, stats

Writing /content/tree_node_utils.py


In [None]:

%%writefile /content/TreeBuilderV2.py
import logging
from typing import List, Tuple, Optional, Dict
import torch
from TreeNodeV1 import TreeNodeV1

# Configure root logging; this module logs under the fixed name "TreeBuilderV2".
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so across several of these notebook modules only the first call
# takes effect.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("TreeBuilderV2")


class TreeBuilderV2:
    """
    Shape-free tree builder.

    ``build_tree`` turns ``(token, vector)`` pairs into a two-level tree: a root
    ``TreeNodeV1`` with one leaf child per usable pair. Nodes only store value,
    vector, spiral_index, confidence and optional labels; no shape type is
    assigned and no model is involved.
    """

    def __init__(self, device: str = "cpu", dim: int = 50):
        self.device = device
        self.dim = dim

    def _ensure_dim(self, vector: torch.Tensor) -> torch.Tensor:
        """Return ``vector`` as a 1-D float tensor of length ``self.dim`` on ``self.device``.

        Non-tensors are converted with ``torch.tensor``; tensors that are not 1-D
        (including 0-d scalars) are flattened; shorter vectors are zero-padded at
        the end and longer ones truncated.

        Raises:
            ValueError: if ``vector`` cannot be converted to a tensor.
        """
        if not isinstance(vector, torch.Tensor):
            try:
                vector = torch.tensor(vector, dtype=torch.float32)
            except Exception as err:
                raise ValueError("Vector not convertible to torch.Tensor") from err

        # reshape (unlike view) also works on non-contiguous tensors.
        if vector.dim() != 1:
            vector = vector.reshape(-1)

        v = vector.to(self.device).float()
        current = v.shape[0]

        if current == self.dim:
            return v
        if current < self.dim:
            pad = torch.zeros(self.dim - current, device=self.device)
            return torch.cat([v, pad], dim=0)
        return v[:self.dim]

    def build_tree(self,
                   vec_pairs: List[Tuple[str, torch.Tensor]],
                   label_texts: Optional[Dict[str, str]] = None,
                   token_embedding=None,
                   target_processor=None) -> Optional["TreeNodeV1"]:
        """
        Build a root-plus-leaves tree from ``vec_pairs``.

        Args:
            vec_pairs: ``(token, vector)`` pairs. Pairs whose vector is None or
                is rejected by ``_ensure_dim`` are skipped.
            label_texts: optional ``token -> label`` map. Matching leaves get
                ``label_text``, plus ``token_embedding.lookup(label)`` as label
                embedding and ``target_processor.encode(label)`` as label vector
                when those helpers are given (failures there are logged and ignored).
            token_embedding: object with a ``lookup(label)`` method, or None.
            target_processor: object with an ``encode(label)`` method, or None.

        Returns:
            The root node, or None when ``vec_pairs`` is empty. The root's id is
            the first token, its vector the mean of the leaf vectors and its
            confidence the mean leaf confidence (both unset if no leaf survived).
        """
        if not vec_pairs:
            logger.warning("TreeBuilderV2: empty vec_pairs")
            return None

        # NOTE(review): the root id is the *first token*, not a sample id; kept
        # unchanged for compatibility with existing callers.
        sample_id = vec_pairs[0][0]
        root = TreeNodeV1(value="root", id=sample_id, max_children=len(vec_pairs))

        spiral_index = 0
        for token, raw_vec in vec_pairs:
            if raw_vec is None:
                continue

            try:
                vec = self._ensure_dim(raw_vec)
            except Exception as err:  # best-effort: skip unusable vectors
                logger.debug("TreeBuilderV2: skipping token %r: %s", token, err)
                continue

            node = TreeNodeV1(value=str(token))  # no shape assignment
            node.set_spiral_index(spiral_index)
            spiral_index += 1
            node.store_vector(vec)

            # confidence = L2 norm scaled by sqrt(dim), clipped to [0, 1]
            norm = float(torch.norm(vec).item())
            node.confidence = min(max(norm / (self.dim ** 0.5 + 1e-8), 0.0), 1.0)

            # Optional label features
            if label_texts and token in label_texts:
                label = label_texts[token]
                node.set_label_text(label)

                if token_embedding is not None:
                    try:
                        node.set_label_embedding(token_embedding.lookup(label))
                    except Exception as err:  # optional feature: log and continue
                        logger.debug("TreeBuilderV2: label embedding failed for %r: %s", token, err)

                if target_processor is not None:
                    try:
                        node.set_label_vector(target_processor.encode(label))
                    except Exception as err:  # optional feature: log and continue
                        logger.debug("TreeBuilderV2: label vector failed for %r: %s", token, err)

            root.add_child(node)

        # compute root vector as mean of child vectors
        child_vecs = [v for v in (c.get_vector() for c in root.children) if v is not None]
        if child_vecs:
            try:
                root.store_vector(torch.stack(child_vecs).mean(dim=0))
                root.confidence = float(sum(c.confidence for c in root.children) / len(root.children))
            except RuntimeError as err:
                logger.warning("TreeBuilderV2: could not aggregate root vector: %s", err)

        return root

Writing /content/TreeBuilderV2.py


In [None]:

%%writefile /content/TreeDecoder.py
import logging
import torch
import torch.nn as nn
from typing import Dict, Any, Iterable, List, Optional

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("TreeDecoder")


class TreeDecoder(nn.Module):
    """
    Minimal, TLite-free TreeDecoder that:
      - maps a tree root vector -> a post-processed vector via a small MLP (forward)
      - provides `reconstruct_tree(nodes)` which returns a JSON-serializable dict
        mapping node_id -> { vector: [...], children: [child_id,...], meta... }
    Assumes node objects implement:
      - .id (str or int)
      - .get_vector() -> torch.Tensor | None
      - .children -> iterable of child nodes
      - optional attributes: .label_text, .shape_type, .confidence
    """

    def __init__(self, dim: int, hidden_dim: int = 128, device: str = "cpu"):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.device = device
        self.fc1 = nn.Linear(dim, hidden_dim).to(device)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, dim).to(device)

    def forward(self, tree_root) -> torch.Tensor:
        """
        Run the root's vector through fc1 -> ReLU -> fc2.

        The root vector is read via ``get_vector()`` when available, otherwise via a
        ``.vector`` attribute. It is flattened and zero-padded or truncated to ``dim``,
        so the result always has shape (dim,) on the configured device. A zero
        vector is returned when the root or its vector is missing, when the MLP
        output contains NaNs, or when any step raises.
        """
        try:
            if tree_root is None:
                logger.warning("TreeDecoder.forward received None root; returning zeros")
                return torch.zeros(self.dim, device=self.device)

            # prefer get_vector method if available
            if hasattr(tree_root, "get_vector"):
                vec = tree_root.get_vector()
            else:
                vec = getattr(tree_root, "vector", None)

            if vec is None:
                logger.warning("TreeDecoder.forward: root vector missing; returning zeros")
                return torch.zeros(self.dim, device=self.device)

            if not isinstance(vec, torch.Tensor):
                vec = torch.tensor(vec, dtype=torch.float32)

            # Always flatten so a stored (1, dim) vector still yields a (dim,) output.
            v = vec.to(self.device).float().reshape(-1)
            if v.numel() < self.dim:
                v = torch.cat([v, v.new_zeros(self.dim - v.numel())], dim=0)
            elif v.numel() > self.dim:
                v = v[: self.dim]

            x = self.fc2(self.act(self.fc1(v)))
            if torch.isnan(x).any():
                logger.error("TreeDecoder.forward produced NaNs; returning zeros")
                return torch.zeros(self.dim, device=self.device)
            return x
        except Exception as e:
            logger.warning(f"TreeDecoder.forward failed: {e}")
            return torch.zeros(self.dim, device=self.device)

    # ---------------------- Serialization utilities ----------------------
    @staticmethod
    def _tensor_to_list_safe(t: Optional[torch.Tensor]) -> Optional[List[float]]:
        """Convert a tensor (or tensor-like) to nested Python lists; None if impossible."""
        if t is None:
            return None
        try:
            if not isinstance(t, torch.Tensor):
                t = torch.tensor(t, dtype=torch.float32)
            return t.detach().cpu().numpy().tolist()
        except Exception:
            # last resort: try to convert iterables
            try:
                return [float(x) for x in list(t)]
            except Exception:
                return None

    @staticmethod
    def _node_key(obj) -> str:
        """
        Return the string key used for a node: its ``id``, else its ``value``, else ``str(obj)``.

        ``is None`` checks (not truthiness) are used so falsy ids such as 0 are kept.
        The same function keys both the parent entries and the children lists, so
        child references always match the keys in the result.
        """
        key = getattr(obj, "id", None)
        if key is None:
            key = getattr(obj, "value", None)
        if key is None:
            key = str(obj)
        return str(key)

    def reconstruct_tree(self, nodes: Iterable) -> Dict[str, Any]:
        """
        Convert an iterable of nodes into a JSON-serializable mapping:
          node_id -> {
            "value": <node.value if present>,
            "vector": [...],         # or None
            "children": [child_id,...],
            "label_text": ...,
            "shape_type": ...,
            "confidence": float,
            "meta": {...}            # any sub_features if present (tensors -> lists)
          }
        NOTE: children are represented by their ids to avoid recursive nesting.
        Nodes that raise during conversion are skipped (logged at debug level).
        """
        result: Dict[str, Any] = {}
        for obj in nodes:
            try:
                node_id = self._node_key(obj)

                # vector: prefer get_vector(), fall back to a .vector attribute
                vec = None
                if hasattr(obj, "get_vector"):
                    try:
                        vec = obj.get_vector()
                    except Exception:
                        vec = getattr(obj, "vector", None)
                else:
                    vec = getattr(obj, "vector", None)
                vec_list = self._tensor_to_list_safe(vec)

                # children -> list of ids, keyed exactly like the parent entries
                try:
                    children_ids: List[str] = [
                        self._node_key(c) for c in (getattr(obj, "children", []) or [])
                    ]
                except Exception:
                    children_ids = []

                # optional metadata
                label_text = getattr(obj, "label_text", None)
                shape_type = getattr(obj, "shape_type", None)
                raw_conf = getattr(obj, "confidence", None)
                try:
                    confidence = float(raw_conf) if raw_conf is not None else None
                except Exception:
                    confidence = None

                # collect sub_features, coercing values to JSON-friendly types
                meta: Dict[str, Any] = {}
                try:
                    subf = getattr(obj, "sub_features", None)
                    if isinstance(subf, dict):
                        for k, v in subf.items():
                            if isinstance(v, torch.Tensor):
                                meta[k] = self._tensor_to_list_safe(v)
                            elif isinstance(v, (str, int, float, bool, list, dict, type(None))):
                                meta[k] = v
                            else:
                                meta[k] = str(v)
                except Exception:
                    meta = {}

                result[node_id] = {
                    "value": getattr(obj, "value", None),
                    "vector": vec_list,
                    "children": children_ids,
                    "label_text": label_text,
                    "shape_type": shape_type,
                    "confidence": confidence,
                    "meta": meta,
                }
            except Exception as exc:
                logger.debug(f"reconstruct_tree: skipping node due to error: {exc}")
                continue

        return result

Writing /content/TreeDecoder.py


In [None]:
%%writefile /content/tree_pos_encoder.py
import torch
import torch.nn as nn
import hashlib
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TreePosEncoder(nn.Module):
    """Learned positional encoding for tree nodes.

    The output is the sum of two learned embeddings:
      - a path embedding. Integer paths (e.g. a spiral index) map to
        ``path % max_positions``; any other path is hashed with MD5 of its ``str()``.
      - a time embedding, indexed by ``int(timestamp) % max_timestamps``.
    """

    def __init__(self, pos_dim=16, max_positions=10000, device='cpu', max_timestamps=100):
        super().__init__()
        self.device = device
        self.pos_dim = pos_dim
        self.max_positions = max_positions
        # Number of distinct time buckets (previously hard-coded to 100).
        self.max_timestamps = max_timestamps
        self.path_embed = nn.Embedding(max_positions, pos_dim).to(device)
        self.time_embed = nn.Embedding(max_timestamps, pos_dim).to(device)

    def path_to_index(self, path):
        """Map an access path to a row of ``path_embed``, deterministically across runs."""
        if isinstance(path, int):  # spiral_index
            return path % self.max_positions
        # MD5 is used for a process-independent hash (unlike built-in hash()), not for security.
        digest = hashlib.md5(str(path).encode()).hexdigest()
        return int(digest, 16) % self.max_positions

    def forward(self, access_path, timestamp: int = 0):
        """Return the (pos_dim,) encoding for ``access_path`` at ``timestamp``."""
        # Build indices on the embeddings' *current* device so the module keeps
        # working after .to(...)/.cuda() moves its parameters away from self.device.
        target = self.path_embed.weight.device
        path_idx = torch.tensor(self.path_to_index(access_path), device=target)
        # int() lets float timestamps work (a float index tensor would be rejected).
        time_idx = torch.tensor(int(timestamp) % self.max_timestamps, device=target)
        return self.path_embed(path_idx) + self.time_embed(time_idx)

Writing /content/tree_pos_encoder.py


# Section 3:

In [None]:

%%writefile /content/grid_seed_opencv.py
import cv2
import numpy as np
import math
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class GridDividerOpenCV:
    """Split a 3-channel BGR image into a grid of rectangular sub-images.

    If both ``grid_rows`` and ``grid_cols`` are given, they are used directly.
    Otherwise a square grid is derived separately for each image, so that each
    cell covers roughly ``target_subpart_area`` pixels. The last row and column
    absorb any remainder pixels.
    """

    def __init__(self, grid_rows=None, grid_cols=None, target_subpart_area=32*32):
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        self.target_subpart_area = target_subpart_area

    def _grid_shape(self, h, w):
        """Return the (rows, cols) grid to use for an h x w image.

        The shape is computed locally and never written back to the instance,
        so auto-sizing for one image does not leak into later calls. Rows and
        cols are clamped to the image size; beyond that extra cells would only
        be empty crops.
        """
        rows, cols = self.grid_rows, self.grid_cols
        if rows is None or cols is None:
            num_subparts = max(1, (w * h) // max(1, self.target_subpart_area))
            rows = cols = max(1, math.isqrt(num_subparts))
        return max(1, min(rows, h)), max(1, min(cols, w))

    def divide(self, image_np):
        """Return a list of ``{"coords": (r, c), "box": (top, left, bottom, right), "image": crop}``.

        Returns [] for inputs that are not HxWx3 arrays.
        """
        if image_np is None or len(image_np.shape) != 3 or image_np.shape[2] != 3:
            logger.error("Invalid image: must be 3-channel BGR")
            return []
        h, w, _ = image_np.shape
        rows, cols = self._grid_shape(h, w)
        grid_h = max(1, h // rows)
        grid_w = max(1, w // cols)
        subparts = []
        for r in range(rows):
            top = r * grid_h
            bottom = h if r == rows - 1 else top + grid_h
            for c in range(cols):
                left = c * grid_w
                right = w if c == cols - 1 else left + grid_w
                crop = image_np[top:bottom, left:right]
                if crop.size == 0:
                    logger.warning(f"Empty crop at ({r}, {c})")
                    continue
                subparts.append({
                    "coords": (r, c),
                    "box": (top, left, bottom, right),
                    "image": crop
                })
        logger.info(f"Divided image into {len(subparts)} subparts")
        return subparts

class SeedSelectorOpenCV:
    """Pick the most salient pixel of a sub-image as a whirlpool seed.

    The sub-image is converted to grayscale and resized to 32x32. The seed is the
    argmax of either the Sobel gradient magnitude (``method="sobel"``) or the raw
    intensity (``method="intensity"``). The seed is expressed in the 32x32
    coordinate system, so callers must rescale it to the original crop size.
    """

    def __init__(self, method="sobel"):
        self.method = method

    def _sobel_gradient(self, gray):
        """Per-pixel gradient magnitude from 3x3 Sobel derivatives."""
        dx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        return cv2.magnitude(dx, dy)

    def select_seed(self, sub_image_np):
        """Return ``{"seed": (y, x), "metric_value": float}``.

        The seed is given as plain ints in 32x32 coordinates. On invalid or empty
        input, the fallback ``{"seed": (0, 0), "metric_value": 0.0}`` is returned.
        """
        if sub_image_np is None or len(sub_image_np.shape) != 3:
            logger.error("Invalid sub-image: must be 3-channel")
            return {"seed": (0, 0), "metric_value": 0.0}
        if sub_image_np.size == 0:
            # cv2.resize raises on empty input; fall back instead of crashing the caller.
            logger.error("Empty sub-image: cannot select seed")
            return {"seed": (0, 0), "metric_value": 0.0}
        gray = cv2.cvtColor(sub_image_np, cv2.COLOR_BGR2GRAY)
        gray_small = cv2.resize(gray, (32, 32))
        if self.method == "sobel":
            metric = self._sobel_gradient(gray_small)
        elif self.method == "intensity":
            metric = gray_small.astype(np.float32)
        else:
            logger.error(f"Unsupported method: {self.method}")
            return {"seed": (0, 0), "metric_value": 0.0}
        idx = np.unravel_index(np.argmax(metric), metric.shape)
        logger.info(f"Selected seed at {idx} with metric {metric[idx]}")
        return {
            # plain ints (not numpy scalars) keep results JSON/log friendly
            "seed": (int(idx[0]), int(idx[1])),
            "metric_value": float(metric[idx])
        }

Writing /content/grid_seed_opencv.py


In [None]:

%%writefile /content/whirlpool_scanner_opencv.py
import numpy as np
import logging
from TreeNodeV1 import TreeNodeV1
import torch
import cv2
import math

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class WhirlpoolScannerOpenCV:
    """Group colour-similar pixels around a seed into TreeNodeV1 nodes.

    Pixels are visited along a square spiral centred on the seed. A pixel is
    accepted when its BGR distance to the seed colour is below
    ``color_threshold``. Every ``pixels_per_node`` accepted pixels become one
    node, whose vector is a 50-d shape descriptor. At most ``max_nodes`` nodes
    are produced.
    """

    def __init__(self, color_threshold=30, max_nodes=100, pixels_per_node=50):
        self.color_threshold = color_threshold
        self.max_nodes = max_nodes
        # Accepted pixels gathered before a node is emitted (was hard-coded to 50).
        self.pixels_per_node = max(1, int(pixels_per_node))

    @staticmethod
    def _spiral_coords(center_y, center_x, h, w):
        """Yield each in-bounds (y, x) of an h x w grid exactly once, spiralling out from the centre.

        Leg lengths run 1, 1, 2, 2, 3, 3, ... in the order right, down, left, up,
        so the walk surrounds the seed on every side. Assumes the centre is in bounds.
        The walk stops once all h*w cells have been yielded.
        """
        directions = ((0, 1), (1, 0), (0, -1), (-1, 0))
        y, x = center_y, center_x
        yield y, x
        remaining = h * w - 1
        leg = 1
        turn = 0
        while remaining > 0:
            # Two legs per length: this is what makes the spiral grow outward
            # instead of closing back onto its starting point.
            for _ in range(2):
                dy, dx = directions[turn % 4]
                turn += 1
                for _ in range(leg):
                    y += dy
                    x += dx
                    if 0 <= y < h and 0 <= x < w:
                        yield y, x
                        remaining -= 1
            leg += 1

    def scan(self, image_np, seed):
        """Return the list of TreeNodeV1 nodes grown from ``seed`` (y, x).

        Returns [] if the image is not HxWx3 or the seed lies outside it.
        """
        if image_np is None or len(image_np.shape) != 3 or image_np.shape[2] != 3:
            logger.error("Invalid image: must be 3-channel BGR")
            return []
        h, w, _ = image_np.shape
        center_y, center_x = seed
        if not (0 <= center_y < h and 0 <= center_x < w):
            logger.error(f"Invalid seed: {seed}, image size: ({h}, {w})")
            return []

        origin_color = image_np[center_y, center_x].astype(np.float32)
        nodes = []
        node_pixels = []
        for y, x in self._spiral_coords(center_y, center_x, h, w):
            if len(nodes) >= self.max_nodes:
                break
            # The seed pixel always belongs to the first group.
            if not (y == center_y and x == center_x):
                dist = np.linalg.norm(image_np[y, x].astype(np.float32) - origin_color)
                if dist >= self.color_threshold:
                    continue
            node_pixels.append((y, x))
            if len(node_pixels) >= self.pixels_per_node:
                nodes.append(self._build_node(image_np, node_pixels, len(nodes)))
                node_pixels = []

        # Flush the trailing partial group without exceeding max_nodes.
        if node_pixels and len(nodes) < self.max_nodes:
            nodes.append(self._build_node(image_np, node_pixels, len(nodes)))

        logger.info(f"Whirlpool scan created {len(nodes)} nodes")
        for i, node in enumerate(nodes):
            vec = node.get_vector()
            logger.info(f"[Whirlpool] Node {i}: ID={node.id}, shape={node.shape_type}, vector_shape={None if vec is None else tuple(vec.shape)}")
        return nodes

    # -----------------------------
    #  Shape descriptor functions
    # -----------------------------
    def _extract_patch(self, image_np, pixels, target_size=64, pad=4):
        """
        Crop the padded bounding box of ``pixels``, convert it to grayscale and
        letterbox it (aspect ratio preserved, zero border) into a
        target_size x target_size uint8 image.
        """
        ys = [p[0] for p in pixels]
        xs = [p[1] for p in pixels]
        miny, maxy = max(0, min(ys)-pad), min(image_np.shape[0]-1, max(ys)+pad)
        minx, maxx = max(0, min(xs)-pad), min(image_np.shape[1]-1, max(xs)+pad)
        patch = image_np[miny:maxy+1, minx:maxx+1].copy()
        if patch.size == 0:
            patch = np.zeros((target_size, target_size, 3), dtype=np.uint8)
        patch_gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
        # resize preserving aspect ratio, pad to square
        h, w = patch_gray.shape
        if h == 0 or w == 0:
            patch_resized = np.zeros((target_size, target_size), dtype=np.uint8)
        else:
            scale = target_size / max(h, w)
            new_h, new_w = max(1, int(h*scale)), max(1, int(w*scale))
            resized = cv2.resize(patch_gray, (new_w, new_h), interpolation=cv2.INTER_AREA)
            top = (target_size - new_h) // 2
            left = (target_size - new_w) // 2
            patch_resized = np.zeros((target_size, target_size), dtype=np.uint8)
            patch_resized[top:top+new_h, left:left+new_w] = resized
        return patch_resized

    def _hog_descriptor(self, img, cells_y=2, cells_x=2, bins=8):
        """
        Simple HOG-style descriptor: divide image into cells_y x cells_x cells and
        build a magnitude-weighted, per-cell-normalised histogram of unsigned
        gradient orientations (0-180 degrees) with ``bins`` bins.
        Returns cells_y * cells_x * bins float32 values.
        """
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=3)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1, ksize=3)
        mag, ang = cv2.cartToPolar(gx, gy, angleInDegrees=True)
        ang = ang % 180.0  # unsigned gradients
        h, w = img.shape
        cell_h = h // cells_y
        cell_w = w // cells_x
        hist = []
        for i in range(cells_y):
            for j in range(cells_x):
                # last row/column of cells absorbs any remainder pixels
                y0, y1 = i*cell_h, (i+1)*cell_h if i < cells_y-1 else h
                x0, x1 = j*cell_w, (j+1)*cell_w if j < cells_x-1 else w
                mag_cell = mag[y0:y1, x0:x1].ravel()
                ang_cell = ang[y0:y1, x0:x1].ravel()
                if ang_cell.size == 0:
                    hist_cell = np.zeros(bins, dtype=np.float32)
                else:
                    hist_cell, _ = np.histogram(ang_cell, bins=bins, range=(0,180), weights=mag_cell)
                    if hist_cell.sum() > 0:
                        hist_cell = hist_cell / (hist_cell.sum() + 1e-6)
                hist.append(hist_cell)
        return np.concatenate(hist).astype(np.float32)

    def _radial_profile(self, img, bins=9):
        """
        Normalised histogram (``bins`` bins) of distances from the centroid of the
        non-zero pixels to each non-zero pixel. All zeros if the image is empty.
        """
        h, w = img.shape
        Y, X = np.indices((h, w))
        mask = img > 0  # non-zero intensities as foreground
        if not mask.any():
            return np.zeros(bins, dtype=np.float32)
        ys = Y[mask]; xs = X[mask]
        cy, cx = ys.mean(), xs.mean()
        dists = np.sqrt((ys - cy)**2 + (xs - cx)**2)
        maxd = dists.max() if dists.size>0 else 1.0
        hist, _ = np.histogram(dists, bins=bins, range=(0, maxd), weights=None)
        if hist.sum() > 0:
            hist = hist / (hist.sum() + 1e-6)
        return hist.astype(np.float32)

    def _contour_stats(self, img):
        """
        Compute basic contour statistics from an Otsu-thresholded, closed copy of img.
        Returns float32 [contour_count, mean_area_ratio, perimeter_area_ratio,
        solidity, eccentricity (largest contour), compactness]; all zeros if no contours.
        """
        _, th = cv2.threshold(img, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
        # close small holes
        kernel = np.ones((3,3), np.uint8)
        th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel)
        contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return np.array([0., 0., 0., 0., 0., 0.], dtype=np.float32)
        areas = np.array([cv2.contourArea(c) for c in contours], dtype=np.float32)
        perims = np.array([cv2.arcLength(c, True) for c in contours], dtype=np.float32)
        img_area = img.shape[0] * img.shape[1]
        mean_area_ratio = (areas.mean() / (img_area + 1e-6))
        contour_count = float(len(contours))
        # perimeter/area ratio (mean)
        per_area = (perims / (areas + 1e-6)).mean()
        # solidity: area / convex hull area mean
        hull_areas = np.array(
            [cv2.contourArea(cv2.convexHull(c)) + 1e-6 for c in contours], dtype=np.float32
        )
        solidity = (areas / hull_areas).mean()
        # eccentricity from the central second moments of the largest contour
        c = contours[int(np.argmax(areas))]
        mu = cv2.moments(c)
        if mu['mu20'] + mu['mu02'] == 0:
            eccentricity = 0.0
        else:
            common = math.sqrt((mu['mu20'] - mu['mu02'])**2 + 4*mu.get('mu11',0)**2)
            l1 = (mu['mu20'] + mu['mu02'] + common) / 2.0
            l2 = (mu['mu20'] + mu['mu02'] - common) / 2.0
            if l1 <= 0:
                eccentricity = 0.0
            else:
                eccentricity = float(math.sqrt(1 - (l2 / (l1 + 1e-9))))
        # compactness: (perimeter^2) / (4*pi*area) mean
        compact = ((perims**2) / (4 * math.pi * (areas + 1e-6))).mean()
        return np.array([contour_count, mean_area_ratio, per_area, solidity, eccentricity, compact], dtype=np.float32)

    def _edge_density(self, img):
        """Sum of the Canny edge map divided by the pixel count."""
        # NOTE(review): cv2.Canny marks edges with 255, so this lies in [0, 255]
        # rather than [0, 1]; kept as-is to avoid changing the descriptor scale.
        edges = cv2.Canny(img, 50, 150)
        return float(edges.sum()) / (img.size + 1e-6)

    def _gradient_stats(self, img):
        """Mean and standard deviation of the Sobel gradient magnitude."""
        gx = cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=3)
        gy = cv2.Sobel(img, cv2.CV_32F, 0, 1, ksize=3)
        mag = np.sqrt(gx*gx + gy*gy)
        return np.array([mag.mean(), mag.std()], dtype=np.float32)

    # -----------------------------
    #  Node construction (50-d descriptor)
    # -----------------------------
    def _build_node(self, image_np, pixels, node_id):
        """Build a TreeNodeV1 whose vector is a unit-L2 50-d shape descriptor.

        Layout: 32 HOG (2x2 cells x 8 bins) + 9 radial + 6 contour
        + 1 edge density + 2 gradient stats.
        """
        patch = self._extract_patch(image_np, pixels, target_size=64, pad=4)
        feat = np.concatenate([
            self._hog_descriptor(patch, cells_y=2, cells_x=2, bins=8),
            self._radial_profile(patch, bins=9),
            self._contour_stats(patch),
            np.array([self._edge_density(patch)], dtype=np.float32),
            self._gradient_stats(patch),
        ], axis=0)

        # Defensive: force exactly 50 dims even if a component's size changes.
        if feat.shape[0] < 50:
            feat = np.concatenate([feat, np.zeros(50 - feat.shape[0], dtype=np.float32)])
        elif feat.shape[0] > 50:
            feat = feat[:50]

        norm = np.linalg.norm(feat)
        if norm > 0:
            feat = feat / (norm + 1e-8)

        node = TreeNodeV1(id=node_id, value=f"node_{node_id}", shape_type=None)
        node.store_vector(torch.tensor(feat, dtype=torch.float32))
        return node

Writing /content/whirlpool_scanner_opencv.py


In [None]:

%%writefile /content/parallel_whirlpool_processor_opencv.py
import torch
import cv2
import numpy as np
import math
import logging
from grid_seed_opencv import GridDividerOpenCV, SeedSelectorOpenCV
from whirlpool_scanner_opencv import WhirlpoolScannerOpenCV

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ParallelWhirlpoolProcessorOpenCV:
    """Grid-divide an image, choose a seed in each cell and whirlpool-scan it.

    Cells are processed one after another (single-threaded for Colab stability).
    ``max_threads`` is stored and reported in the init log but not used for
    scheduling. Each result is either
    ``{"coords", "seed", "nodes", "metric_value"}`` or, when that cell failed,
    ``{"coords", "error"}``.
    """

    def __init__(self, grid_rows=None, grid_cols=None, target_area=1024, max_threads=1,
                 color_thresh=30, max_nodes=50):
        self.grid_rows, self.grid_cols = grid_rows, grid_cols
        self.target_area = target_area
        self.max_threads = max_threads
        self.color_thresh, self.max_nodes = color_thresh, max_nodes
        self.scanner = WhirlpoolScannerOpenCV(color_threshold=color_thresh, max_nodes=max_nodes)
        self.selector = SeedSelectorOpenCV(method="sobel")
        logger.info(f"Initialized ParallelWhirlpoolProcessor with {max_threads} threads")

    def _scan_cell(self, part):
        """Seed and scan one grid cell; any failure becomes an ``error`` entry."""
        cell_img = part["image"]
        cell_coords = part["coords"]
        try:
            picked = self.selector.select_seed(cell_img)
            small_y, small_x = picked["seed"]
            # The selector works on a 32x32 downsample; map its seed back to cell pixels.
            resized_h = resized_w = 32
            full_h, full_w, _ = cell_img.shape
            y = min(int(small_y * (full_h / resized_h)), full_h - 1)
            x = min(int(small_x * (full_w / resized_w)), full_w - 1)
            found = self.scanner.scan(cell_img, seed=(y, x))
            return {
                "coords": cell_coords,
                "seed": (y, x),
                "nodes": found,
                "metric_value": picked["metric_value"],
            }
        except Exception as e:
            logger.error(f"Error processing subpart {cell_coords}: {e}")
            return {"coords": cell_coords, "error": str(e)}

    def process(self, image_np):
        """Return one result dict per grid cell of ``image_np`` ([] for invalid input)."""
        if image_np is None or len(image_np.shape) != 3:
            logger.error("Invalid image: must be 3-channel")
            return []
        divider = GridDividerOpenCV(
            grid_rows=self.grid_rows,
            grid_cols=self.grid_cols,
            target_subpart_area=self.target_area,
        )
        results = [self._scan_cell(part) for part in divider.divide(image_np)]
        logger.info(f"Processed {len(results)} subparts")
        return results

Writing /content/parallel_whirlpool_processor_opencv.py


In [None]:

%%writefile /content/whirlpool_node_standardizer.py
import torch
import logging
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class WhirlpoolNodeStandardizer:
    """Bring node vectors to a fixed length and scale, and drop unusable nodes.

    Every vector is flattened, divided by ``max_norm`` and then zero-padded or
    truncated to ``target_dim``. Nodes without a vector, or whose standardized
    vector is all zeros (including NaN/Inf inputs), are rejected.
    """

    def __init__(self, target_dim=16, device='cpu', max_norm=255.0):
        self.target_dim = target_dim
        self.device = device
        self.max_norm = max_norm

    def _normalize_vector(self, vec) -> torch.Tensor:
        """Return a (target_dim,) float tensor on ``self.device``; no shape-tag logic.

        Accepts tensors, numpy arrays or lists of any rank; the input is flattened
        first, so e.g. a (1, D) vector is padded/truncated along its real length
        instead of raising. Missing or non-finite input yields zeros.
        """
        if vec is None:
            return torch.zeros(self.target_dim, device=self.device)

        flat = torch.as_tensor(vec).to(self.device).float().reshape(-1)

        if not torch.isfinite(flat).all():
            return torch.zeros(self.target_dim, device=self.device)

        # Scale into roughly [0, 1] for RGB-like inputs (max_norm defaults to 255).
        flat = flat / (self.max_norm + 1e-8)

        n = flat.numel()
        if n == self.target_dim:
            return flat
        if n < self.target_dim:
            padded = torch.zeros(self.target_dim, device=self.device)
            padded[:n] = flat
            return padded
        return flat[:self.target_dim]

    def clean_node(self, node):
        """Standardize ``node`` in place and return it, or return None to reject it.

        Rejection happens before any mutation, so discarded nodes keep their original
        vector and attributes. Accepted nodes get default ``spiral_index`` (-1) and
        ``level`` (0) if they lack them.
        """
        if node is None:
            return None

        try:
            vec = node.get_vector()
            if vec is None:
                logger.warning(f"Node {getattr(node, 'id', '?')} skipped: no vector")
                return None

            vec_norm = self._normalize_vector(vec)

            # Reject all-zero nodes (also covers NaN/Inf inputs)
            if torch.allclose(vec_norm, torch.zeros_like(vec_norm)):
                logger.warning(f"Node {getattr(node, 'id', '?')} skipped: zero-vector")
                return None

            node.store_vector(vec_norm)

            # Default simple attributes
            if not hasattr(node, "spiral_index"):
                node.spiral_index = -1
            if not hasattr(node, "level"):
                node.level = 0

            return node

        except Exception as e:
            logger.error(f"Failed to clean node: {e}")
            return None

    def standardize_batch(self, nodes):
        """Return the cleaned nodes from any iterable (lists or generators)."""
        nodes = list(nodes)  # materialize so generators work with the len() in the log
        cleaned = [c for c in map(self.clean_node, nodes) if c is not None]
        logger.info(f"Standardized {len(cleaned)} / {len(nodes)} nodes")
        return cleaned

Writing /content/whirlpool_node_standardizer.py


In [None]:

%%writefile /content/focused_grid_whirlpool_processor.py
import cv2
import logging
from grid_seed_opencv import GridDividerOpenCV, SeedSelectorOpenCV
from parallel_whirlpool_processor_opencv import ParallelWhirlpoolProcessorOpenCV

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FocusedGridWhirlpoolProcessor:
    """Score every grid cell by seed saliency and whirlpool-scan only the top-K cells.

    Returns the flattened list of per-subpart result dicts produced by
    ``ParallelWhirlpoolProcessorOpenCV.process`` for each selected cell. A cell
    that fails during scoring or scanning is logged and skipped, so one bad
    cell does not abort the whole image.
    """

    def __init__(self, grid_rows=14, grid_cols=14, top_k=40, seed_method="sobel"):
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        self.top_k = top_k
        self.grid_divider = GridDividerOpenCV(grid_rows=self.grid_rows, grid_cols=self.grid_cols)
        self.seed_selector = SeedSelectorOpenCV(method=seed_method)
        self.whirlpool = ParallelWhirlpoolProcessorOpenCV()

    def process(self, image_np):
        if image_np is None or len(image_np.shape) != 3:
            logger.error("Invalid input image.")
            return []

        subparts = self.grid_divider.divide(image_np)
        if not subparts:
            logger.warning("No subparts extracted.")
            return []

        # Score each subpart by its seed metric
        scored = []
        for part in subparts:
            try:
                result = self.seed_selector.select_seed(part["image"])
            except Exception as e:
                logger.error(f"Segment scoring failure at {part.get('coords')}: {e}")
                continue
            scored.append({
                "coords": part["coords"],
                "box": part["box"],
                "image": part["image"],
                "metric": result["metric_value"],
                "seed": result["seed"]
            })

        # Select top-K segments (may be fewer than top_k for small grids)
        top_segments = sorted(scored, key=lambda x: x["metric"], reverse=True)[:self.top_k]
        logger.info(f"Selected top {len(top_segments)} of {len(scored)} subparts for focused processing")

        # Process selected segments with Whirlpool
        results = []
        for seg in top_segments:
            try:
                results.extend(self.whirlpool.process(seg["image"]))
            except Exception as e:
                logger.error(f"Whirlpool error at {seg['coords']}: {e}")

        logger.info(f"Extracted {len(results)} node groups from top segments.")
        return results

Writing /content/focused_grid_whirlpool_processor.py


In [None]:

%%writefile /content/parallel_grid_seed_whirlpool_processor.py
import cv2
import logging
from grid_seed_opencv import GridDividerOpenCV, SeedSelectorOpenCV
from parallel_whirlpool_processor_opencv import ParallelWhirlpoolProcessorOpenCV

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ParallelGridSeedWhirlpoolProcessor:
    """
    SAFE VERSION (No multiprocessing)
    ---------------------------------
    • Processes subparts sequentially (but very lightweight)
    • 100% compatible with all existing pipeline code
    • Works reliably inside Google Colab / Jupyter

    Pipeline: grid-divide -> score each cell by seed saliency -> keep the top-K
    cells -> whirlpool-scan each kept cell. Returns the flattened per-subpart
    result dicts from ``ParallelWhirlpoolProcessorOpenCV.process``.
    """

    def __init__(self, grid_rows=14, grid_cols=14, top_k=40, seed_method="sobel"):
        self.grid_rows = grid_rows
        self.grid_cols = grid_cols
        self.top_k = top_k
        self.grid_divider = GridDividerOpenCV(grid_rows=grid_rows, grid_cols=grid_cols)
        self.seed_selector = SeedSelectorOpenCV(method=seed_method)
        self.whirlpool = ParallelWhirlpoolProcessorOpenCV()

    def _score_segment_safe(self, part):
        """Score one cell by its seed metric; returns None (and logs) on failure."""
        try:
            res = self.seed_selector.select_seed(part["image"])
            return {
                "coords": part["coords"],
                "box": part["box"],
                "image": part["image"],
                "metric": res["metric_value"],
                "seed": res["seed"]
            }
        except Exception as e:
            logger.error(f"Segment scoring failure: {e}")
            return None

    def process(self, image_np):
        if image_np is None or len(image_np.shape) != 3:
            logger.error("Invalid input image.")
            return []

        # Step 1: Divide into grid
        subparts = self.grid_divider.divide(image_np)
        if not subparts:
            logger.warning("No subparts extracted.")
            return []

        logger.info(f"Scoring {len(subparts)} segments (SAFE single-thread)...")

        # Step 2: Score each segment safely
        scored = [s for s in map(self._score_segment_safe, subparts) if s is not None]

        # Step 3: Take top-K by saliency metric (may be fewer than top_k)
        top_segments = sorted(scored, key=lambda x: x["metric"], reverse=True)[:self.top_k]
        logger.info(f"Selected top {len(top_segments)} of {len(scored)} subparts")

        # Step 4: Whirlpool feature extraction
        results = []
        for seg in top_segments:
            try:
                results.extend(self.whirlpool.process(seg["image"]))
            except Exception as e:
                logger.error(f"Whirlpool error: {e}")

        logger.info(f"Extracted {len(results)} node groups from top segments.")
        return results

Writing /content/parallel_grid_seed_whirlpool_processor.py


In [None]:

%%writefile /content/ReverseWhirlpoolCleanerV2.py
"""
ReverseWhirlpoolCleanerV2 (PURE SHAPE-FREE VERSION)
---------------------------------------------------
Cleans and merges nodes ONLY using:
 - vector norm
 - cosine similarity
 - confidence

All references to shape_type, TLite, idx mappings are removed.
"""

import logging
from typing import Optional, List
import torch
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("ReverseWhirlpoolCleanerV2")


def _cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    try:
        a = a.view(-1).float()
        b = b.view(-1).float()
        denom = (torch.norm(a) * torch.norm(b)).item() + 1e-8
        return float((a @ b).item() / denom)
    except Exception:
        return 0.0


class ReverseWhirlpoolCleanerV2:
    """Prune and merge the direct children of a tree root using only vector statistics.

    ``clean`` works on one level (``root.children``):
      1. drop children whose vector is missing, non-finite in norm, or whose L2
         norm is below ``min_norm``;
      2. greedily merge each remaining child into the first already-kept child
         whose cosine similarity is >= ``sim_threshold`` (confidence-weighted
         vector average);
      3. attach the survivors to a fresh ``TreeNodeV1`` root whose vector is the
         mean of their vectors.

    Note: kept/merged children are the original node objects and are mutated in
    place (vector, confidence, label text, spiral index).
    """

    def __init__(
        self,
        sim_threshold: float = 0.85,
        min_norm: float = 1e-3,
        device: str = "cpu"
    ):
        self.sim_threshold = float(sim_threshold)
        self.min_norm = float(min_norm)
        # Stored for interface compatibility; the cleaner never moves tensors itself.
        self.device = device

    # ---------------------------------------------------
    # STEP 1 — FILTER BAD NODES
    # ---------------------------------------------------
    def _valid_children(self, children: List[TreeNodeV1]) -> List[TreeNodeV1]:
        """Return the children whose vector exists and has a finite norm >= ``min_norm``."""
        valid = []
        for c in children:
            v = c.get_vector()
            if v is None:
                continue
            try:
                n = float(torch.norm(v).item())
            except Exception:
                continue
            # Written as a range test so NaN (all comparisons False) and inf norms are rejected.
            if not (self.min_norm <= n < float("inf")):
                continue
            valid.append(c)
        return valid

    # ---------------------------------------------------
    # STEP 2 — MERGE SIMILAR NODES
    # ---------------------------------------------------
    def _merge_similar(self, nodes: List[TreeNodeV1]) -> List[TreeNodeV1]:
        """Greedily merge nodes whose vectors are cosine-similar.

        Each node is compared against the representatives kept so far (in order);
        the first one with similarity >= ``sim_threshold`` absorbs it. Otherwise
        the node becomes a new representative. Returns the representatives.
        """
        merged = []
        for n in nodes:
            nv = n.get_vector()
            if nv is None:
                continue

            placed = False
            for rep in merged:
                sim = _cosine_sim(nv, rep.get_vector())
                if sim < self.sim_threshold:
                    continue
                try:
                    w1 = float(rep.get_confidence())
                    w2 = float(n.get_confidence())
                    if w1 + w2 <= 1e-8:
                        # Both confidences ~0: average equally instead of
                        # collapsing the representative to a zero vector.
                        w1 = w2 = 1.0
                    total = w1 + w2

                    new_vec = (rep.get_vector() * w1 + nv * w2) / total
                    rep.store_vector(new_vec)

                    rep.confidence = max(rep.confidence, n.confidence)

                    if not rep.label_text and n.label_text:
                        rep.set_label_text(n.label_text)
                except Exception as e:
                    # Best effort: the node is still considered absorbed by rep.
                    logger.debug(f"Cleaner: merge into node {getattr(rep, 'id', None)} failed: {e}")

                placed = True
                break

            if not placed:
                merged.append(n)

        return merged

    # ---------------------------------------------------
    # STEP 3 — MAIN CLEAN FUNCTION
    # ---------------------------------------------------
    def clean(self, root: TreeNodeV1) -> Optional[TreeNodeV1]:
        """Return a new root holding the pruned + merged children of ``root``.

        Returns None for a None root, and ``root`` unchanged when it has no
        ``children`` attribute. If no child survives pruning, returns a childless
        "group" root with the same value and id (and no vector).
        """
        if root is None:
            return None

        if not hasattr(root, "children"):
            return root

        # 1. prune
        valid = self._valid_children(root.children)
        if not valid:
            logger.info("Cleaner: no valid children")
            return TreeNodeV1(value=root.value, shape_type="group", id=root.id)

        # 2. merge
        merged = self._merge_similar(valid)

        # 3. new root (shape-free → 'group' is only a neutral placeholder label)
        new_root = TreeNodeV1(
            value=root.value,
            shape_type="group",
            id=root.id,
            max_children=len(merged)
        )

        # attach children, renumbering their spiral order
        for i, n in enumerate(merged):
            n.set_spiral_index(i)
            new_root.add_child(n)

        # 4. recompute root vector as the mean of the children's vectors
        child_vecs = [v for v in (c.get_vector() for c in new_root.children) if v is not None]
        if child_vecs:
            try:
                new_root.store_vector(torch.stack(child_vecs).mean(dim=0))
            except Exception as e:
                # e.g. children with differing vector shapes/devices; root keeps no vector
                logger.debug(f"Cleaner: could not compute root vector: {e}")

        return new_root

Writing /content/ReverseWhirlpoolCleanerV2.py


In [None]:

%%writefile /content/utils_positional.py
import torch
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _fit_length(vec, length):
    """Return *vec* flattened and cropped or zero-padded to exactly *length* elements."""
    flat = vec.reshape(-1)
    if flat.shape[0] >= length:
        return flat[:length]
    pad = torch.zeros(length - flat.shape[0], dtype=flat.dtype, device=flat.device)
    return torch.cat([flat, pad], dim=0)


def assign_positional_encoding(node, pos_encoder, temp_encoder, path="root", timestamp=0, device='cpu'):
    """
    Add positional + temporal encodings into node vectors in-place, recursively.

    - node: TreeNodeV1; a node without a vector is left unchanged, but its
      children are still visited
    - pos_encoder: callable(access_path, timestamp) -> tensor or array-like
    - temp_encoder: callable(timestamp) -> tensor or array-like
    - path: dotted path id ("root", "root.0", ...), used when the node's
      spiral_index is missing or None
    - timestamp: integer time-step; each child level receives timestamp + 1
    - device: device the updated vector is moved to before storing

    The stored vector is node_vec + pos_vec + temp_vec. If an encoding's first
    dimension differs from the node vector's, it is flattened and cropped or
    zero-padded to that length first. Per-node failures are logged as warnings,
    not raised.
    """
    if not isinstance(node, TreeNodeV1):
        logger.warning(f"assign_positional_encoding: invalid node type: {type(node)}")
        return

    node_vec = node.get_vector()
    if node_vec is None:
        # Nothing to encode here, but descendants may still carry vectors.
        logger.warning(f"assign_positional_encoding: node has no vector: {getattr(node,'id',None)}")
    else:
        try:
            # determine access path (prefer spiral_index when it is actually set)
            spiral = getattr(node, 'spiral_index', None)
            path_value = spiral if spiral is not None else path
            pos_vec = pos_encoder(path_value, timestamp)
            temp_vec = temp_encoder(timestamp)

            # ensure tensors and device
            if not isinstance(pos_vec, torch.Tensor):
                pos_vec = torch.tensor(pos_vec, dtype=torch.float32)
            if not isinstance(temp_vec, torch.Tensor):
                temp_vec = torch.tensor(temp_vec, dtype=torch.float32)

            pos_vec = pos_vec.to(device)
            temp_vec = temp_vec.to(device)
            node_vec = node_vec.to(device)

            # Align encoding lengths to the node vector. The previous fallback
            # (concatenate, then crop to len(node_vec)) returned node_vec unchanged,
            # silently discarding both encodings.
            n = node_vec.shape[0]
            if pos_vec.shape[0] != n:
                pos_vec = _fit_length(pos_vec, n)
            if temp_vec.shape[0] != n:
                temp_vec = _fit_length(temp_vec, n)

            node.store_vector(node_vec + pos_vec + temp_vec)

        except Exception as e:
            logger.warning(f"assign_positional_encoding: failed for node {getattr(node,'id', None)}: {e}")

    # recursive
    for i, child in enumerate(node.children):
        try:
            assign_positional_encoding(child, pos_encoder, temp_encoder, f"{path}.{i}", timestamp + 1, device)
        except Exception as e:
            logger.debug(f"assign_positional_encoding: child {i} failed: {e}")

    logger.debug(f"Assigned positional and temporal encoding to node: {node.value}")

Writing /content/utils_positional.py


In [None]:

%%writefile /content/explainability.py
import torch
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ExplainabilityModule:
    """Turn a TreeNodeV1 tree into depth-annotated, human-readable lines."""

    def __init__(self, device='cpu'):
        # Kept for API symmetry with the other modules; not read by explain/visualize.
        self.device = device

    def explain(self, node, depth=0):
        """
        Walk the tree depth-first (pre-order) and describe every node.

        Each line reports depth, value, confidence (4 decimals) and shape_type
        ('N/A' when absent); if that description cannot be built, a shorter
        depth + value line is used instead. A non-TreeNodeV1 node is logged and
        contributes nothing.
        """
        if not isinstance(node, TreeNodeV1):
            logger.error(f"ExplainabilityModule.explain: Invalid node type: {type(node)}")
            return []

        try:
            if hasattr(node, "get_confidence"):
                conf = node.get_confidence()
            else:
                conf = float(getattr(node, "confidence", 0.0))
            lines = [f"Depth {depth}: Node {node.value}, Confidence {conf:.4f}, Shape {getattr(node,'shape_type', 'N/A')}"]
        except Exception:
            lines = [f"Depth {depth}: Node {node.value}"]

        next_depth = depth + 1
        for sub in node.children:
            lines += self.explain(sub, next_depth)
        return lines

    def visualize(self, node):
        """
        Log the tree description at INFO level and return the list of lines
        (an empty list if anything goes wrong).
        """
        try:
            lines = self.explain(node)
            if not lines:
                logger.info("Tree Explanation: <empty>")
            else:
                logger.info("Tree Explanation:\n" + "\n".join(lines))
            return lines
        except Exception as e:
            logger.error(f"ExplainabilityModule.visualize failed: {e}")
            return []

Writing /content/explainability.py


In [None]:

%%writefile /content/positional_encoder.py
import torch

class PositionalEncoder:
    """
    Sinusoidal positional encoder returning a float32 tensor of length ``dim``
    on the configured device.

    The access path is folded into an integer index, the timestamp is added,
    and that scalar position is encoded with the sin/cos scheme
    (even slots: sin, odd slots: cos).
    """

    def __init__(self, dim=50, device='cpu'):
        self.dim = int(dim)
        self.device = device

    def _path_to_index(self, access_path):
        """
        Fold an access path into an integer.

        - int: returned as ``int``
        - str like ``'root.0.1.2'``: every dot-separated part made only of
          digits is folded as ``index = index * 31 + (part + 1)``; a string
          with no such part gives 0
        - anything else: 0
        """
        if isinstance(access_path, int):
            return int(access_path)
        if not isinstance(access_path, str):
            return 0
        index = 0
        for token in access_path.strip().split('.'):
            if token.isdigit():
                index = index * 31 + (int(token) + 1)
        return index

    def __call__(self, access_path, timestamp=0):
        """
        Encode ``access_path`` at ``timestamp`` as a length-``dim`` float32 tensor
        on ``self.device``.
        """
        position = float(self._path_to_index(access_path) + int(timestamp))
        dev, d = self.device, self.dim

        # 10000 ** (2i / d) for each even slot 2i
        even_slots = torch.arange(0, d, 2, dtype=torch.float32, device=dev)
        inv_freq = torch.pow(10000.0, even_slots / d)
        angles = torch.tensor([position], dtype=torch.float32, device=dev) / inv_freq

        encoding = torch.zeros(d, dtype=torch.float32, device=dev)
        encoding[0:d:2] = torch.sin(angles)
        # odd slots number d // 2, one fewer than even slots when d is odd
        encoding[1:d:2] = torch.cos(angles[: d // 2])
        return encoding

Writing /content/positional_encoder.py


In [None]:

%%writefile /content/temporal_encoder.py
import torch

class TemporalEncoder:
    """
    Sinusoidal temporal encoder producing a float32 tensor of length ``dim``.

    encoding[2i]   = sin(t * 10000^(-2i/dim)) / sqrt(dim)
    encoding[2i+1] = cos(t * 10000^(-2i/dim)) / sqrt(dim)

    Works for odd ``dim`` too (the last slot is then a sin term). On any failure
    an all-zero tensor is returned.
    """

    def __init__(self, dim=16, device='cpu'):
        self.dim = int(dim)
        self.device = device
        self.scale = torch.sqrt(torch.tensor(float(dim), dtype=torch.float32, device=device))

    def forward(self, timestamp: int):
        try:
            t = torch.tensor(float(timestamp), dtype=torch.float32, device=self.device)

            # even indices 0,2,4... (ceil(dim/2) of them)
            idx = torch.arange(0, self.dim, 2, dtype=torch.float32, device=self.device)

            # frequency term 10000^(-idx/dim)
            div_term = torch.exp(idx * (-torch.log(torch.tensor(10000.0, device=self.device)) / self.dim))
            angles = t * div_term

            encoding = torch.zeros(self.dim, dtype=torch.float32, device=self.device)
            encoding[0::2] = torch.sin(angles)
            # Only dim // 2 odd slots exist; without this slice an odd dim raised a
            # shape mismatch and silently fell back to an all-zero encoding.
            encoding[1::2] = torch.cos(angles[: self.dim // 2])

            return encoding / self.scale
        except Exception:
            # safe fallback
            return torch.zeros(self.dim, dtype=torch.float32, device=self.device)

    def __call__(self, timestamp):
        return self.forward(timestamp)

Writing /content/temporal_encoder.py


In [None]:

%%writefile /content/confidence_injector.py
import torch
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ConfidenceInjector")


class ConfidenceInjector:
    """
    Injects confidence scores into nodes (recursively) based on vector statistics.

    mode = "norm": confidence = clamp(||v||_2 / norm_scale, 0, 1), rounded to 4 decimals
    any other mode: constant 0.5 for every node that has a vector
    Nodes without a vector get confidence 0.0.
    """

    def __init__(self, mode="norm", device="cpu", norm_scale=10.0):
        """
        Args:
            mode: "norm" or any other value (constant-confidence fallback).
            device: device vectors are moved to before scoring.
            norm_scale: positive divisor mapping a vector norm into [0, 1]
                (default 10.0, the previously hard-coded value).

        Raises:
            ValueError: if ``norm_scale`` is not a positive number.
        """
        norm_scale = float(norm_scale)
        if not norm_scale > 0.0:  # also rejects NaN
            raise ValueError(f"norm_scale must be positive, got {norm_scale}")
        self.mode = mode
        self.device = device
        self.norm_scale = norm_scale

    def _score(self, vec):
        """Return the confidence (rounded to 4 decimals) for a non-None vector."""
        vec = vec.to(self.device)
        if self.mode == "norm":
            raw = float(torch.norm(vec).item())
            # clamp to [0, 1]; a NaN norm ends up as 0.0 via max(0.0, nan)
            conf = max(0.0, min(raw / self.norm_scale, 1.0))
        else:
            conf = 0.5   # fallback constant
        return round(conf, 4)

    def inject(self, node):
        """Set ``node.confidence`` for ``node`` and all its descendants."""
        if not isinstance(node, TreeNodeV1):
            logger.warning(f"ConfidenceInjector: invalid node type {type(node)}")
            return

        vec = node.get_vector()
        node.confidence = 0.0 if vec is None else self._score(vec)

        # recurse into children
        for child in node.children:
            self.inject(child)

Writing /content/confidence_injector.py


In [None]:

%%writefile /content/TreeCNNppRunner.py
import cv2
import torch
import logging
from typing import List, Tuple, Optional

from TreeBuilderV2 import TreeBuilderV2
from ReverseWhirlpoolCleanerV2 import ReverseWhirlpoolCleanerV2
from utils_positional import assign_positional_encoding
from positional_encoder import PositionalEncoder
from temporal_encoder import TemporalEncoder

# Extractors (optional imports)
try:
    from parallel_grid_seed_whirlpool_processor import ParallelGridSeedWhirlpoolProcessor
except Exception:
    ParallelGridSeedWhirlpoolProcessor = None

try:
    from focused_grid_whirlpool_processor import FocusedGridWhirlpoolProcessor
except Exception:
    FocusedGridWhirlpoolProcessor = None

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("TreeCNNppRunner")


def _safe_extract_vec_pairs(nodes: List) -> List[Tuple[str, torch.Tensor]]:
    """
    Accepts a list of TreeNodeV1-like objects and returns [(token, tensor), ...].
    Ignores items without get_vector / vector or with invalid vectors.
    """
    pairs = []
    skipped = 0
    for i, n in enumerate(nodes or []):
        try:
            token = getattr(n, "value", f"node_{i}")
            # prefer API get_vector(), else attribute 'vector'
            vec = None
            if hasattr(n, "get_vector"):
                vec = n.get_vector()
            elif hasattr(n, "vector"):
                vec = getattr(n, "vector")
            if vec is None:
                skipped += 1
                continue
            if not isinstance(vec, torch.Tensor):
                vec = torch.tensor(vec, dtype=torch.float32)
            pairs.append((str(token), vec))
        except Exception:
            skipped += 1
            continue
    if skipped:
        logger.debug(f"_safe_extract_vec_pairs: skipped {skipped} invalid items")
    return pairs


def _collect_raw_nodes(extractor_output) -> List:
    """
    Normalize extractor output to a flat list of node-like objects.
    Accepts:
      - list of dicts with key "nodes" (list)
      - list of TreeNodeV1 objects
      - single dict containing nodes
    """
    raw = []
    if extractor_output is None:
        return raw

    # If extractor returned a list, iterate
    if isinstance(extractor_output, list):
        for part in extractor_output:
            if isinstance(part, dict):
                nodes = part.get("nodes")
                if isinstance(nodes, list):
                    raw.extend(nodes)
                else:
                    # some extractors may place nodes directly under different keys
                    maybe = part.get("node") or part.get("data")
                    if isinstance(maybe, list):
                        raw.extend(maybe)
            else:
                # assume it's a node-like object
                raw.append(part)
    elif isinstance(extractor_output, dict):
        nodes = extractor_output.get("nodes")
        if isinstance(nodes, list):
            raw.extend(nodes)
    else:
        # single node-like object
        raw.append(extractor_output)

    return raw


def _select_extractor(use_parallel: bool):
    """Instantiate the preferred available extractor (parallel first if requested, then focused), or None."""
    if use_parallel and ParallelGridSeedWhirlpoolProcessor is not None:
        try:
            return ParallelGridSeedWhirlpoolProcessor()
        except Exception as e:
            logger.warning(f"Failed to init Parallel extractor: {e}")

    if FocusedGridWhirlpoolProcessor is not None:
        try:
            return FocusedGridWhirlpoolProcessor()
        except Exception as e:
            logger.warning(f"Failed to init Focused extractor: {e}")

    return None


def run_treecnnpp(image_path: str,
                  device: str = "cpu",
                  dim: int = 50,
                  use_parallel: bool = False) -> Optional[object]:
    """
    Shape-free runner:
      - read image (cv2.imread)
      - run extractor (parallel if requested and available, else focused)
      - build pre-tree from node vectors
      - clean (norm + merge)
      - assign positional/temporal encodings
      - rebuild final tree and return it

    Returns None when any required stage fails or yields nothing. If the final
    rebuild has no input or fails, the cleaned tree is returned instead.
    """
    # read image
    image = cv2.imread(image_path)
    if image is None:
        logger.error(f"TreeCNNppRunner: failed to read image '{image_path}'")
        return None

    # choose extractor
    extractor = _select_extractor(use_parallel)
    if extractor is None:
        logger.error("No extractor available (install parallel_grid_seed_whirlpool_processor or focused_grid_whirlpool_processor).")
        return None

    # extract
    try:
        extractor_output = extractor.process(image)
    except Exception as e:
        logger.error(f"Extractor.process failed: {e}")
        return None

    raw_nodes = _collect_raw_nodes(extractor_output)
    if not raw_nodes:
        logger.warning("Extractor returned no raw nodes")
        return None

    vec_pairs = _safe_extract_vec_pairs(raw_nodes)
    if not vec_pairs:
        logger.warning("No valid (token,vector) pairs after extraction")
        return None

    # Build pre-tree
    builder = TreeBuilderV2(device=device, dim=dim)
    try:
        pre_tree = builder.build_tree(vec_pairs)
    except Exception as e:
        logger.error(f"Pre-tree build failed: {e}")
        return None

    if pre_tree is None:
        logger.warning("Pre-tree construction returned None")
        return None

    # Clean (shape-free)
    cleaner = ReverseWhirlpoolCleanerV2(sim_threshold=0.85, min_norm=1e-3, device=device)
    try:
        cleaned = cleaner.clean(pre_tree)
    except Exception as e:
        logger.error(f"Cleaner failed: {e}")
        return None

    if cleaned is None:
        logger.warning("Cleaner returned empty tree")
        return None

    # Positional + temporal encoding. Encoders are created on the target device
    # (named kwarg) so they do not allocate on CPU and get copied for every node.
    pos_encoder = PositionalEncoder(dim, device=device)
    temp_encoder = TemporalEncoder(dim, device=device)
    try:
        assign_positional_encoding(cleaned, pos_encoder, temp_encoder, path="root", timestamp=0, device=device)
    except Exception as e:
        logger.warning(f"assign_positional_encoding failed: {e}")

    # Rebuild final tree from cleaned children (if any)
    final_pairs = _safe_extract_vec_pairs(getattr(cleaned, "children", []))
    if not final_pairs:
        # nothing to rebuild; return cleaned root
        return cleaned

    try:
        final_tree = builder.build_tree(final_pairs)
    except Exception as e:
        logger.warning(f"Final build failed, returning cleaned tree: {e}")
        return cleaned

    return final_tree


# simple CLI
if __name__ == "__main__":
    import sys

    # Expect exactly one positional argument: the image path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python TreeCNNppRunner.py path/to/image.jpg")
        sys.exit(1)
    tree = run_treecnnpp(cli_args[0], device='cpu', dim=50, use_parallel=False)
    root_value = getattr(tree, "value", None)
    child_count = len(getattr(tree, "children", []))
    print("Result:", root_value, "children:", child_count)

Writing /content/TreeCNNppRunner.py


# Section

In [None]:

%%writefile /content/lazy_recursive.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef struct { float* data; int size; int* spiral_indices; int index_count; } Vector;

/*
 * Build a Vector of `size` floats. For every index in spiral_indices that lies
 * in [0, size), data[idx] = 2 * input[idx]; every other data entry is 0.
 *
 * result->spiral_indices[i] holds spiral_indices[i] when that index is in
 * range, and -1 otherwise (previously out-of-range slots were left
 * uninitialized, and negative indices wrote out of bounds).
 *
 * Returns NULL on invalid arguments or allocation failure; release the result
 * with free_vector().
 */
Vector* lazy_recursive(float* input, int size, int* spiral_indices, int index_count) {
    if (!input || !spiral_indices || size <= 0 || index_count <= 0) return NULL;

    Vector* result = (Vector*)malloc(sizeof(Vector));
    if (!result) return NULL;
    result->size = size;
    result->index_count = index_count;
    result->data = (float*)calloc((size_t)size, sizeof(float));
    result->spiral_indices = (int*)malloc((size_t)index_count * sizeof(int));
    if (!result->data || !result->spiral_indices) {
        free(result->data);
        free(result->spiral_indices);
        free(result);
        return NULL;
    }

    for (int i = 0; i < index_count; i++) {
        int idx = spiral_indices[i];
        if (idx >= 0 && idx < size) {
            result->data[idx] = input[idx] * 2.0f;
            result->spiral_indices[i] = idx;
        } else {
            result->spiral_indices[i] = -1; /* out of range: marked invalid */
        }
    }
    return result;
}

/* Release a Vector returned by lazy_recursive(); NULL is accepted and ignored. */
void free_vector(Vector* vec) {
    if (vec == NULL) {
        return;
    }
    free(vec->data);
    free(vec->spiral_indices);
    free(vec);
}

#ifdef __cplusplus
}
#endif

Writing /content/lazy_recursive.c


In [None]:

%%writefile /content/tree_pruner.c
#include <stdio.h>
#include <stdlib.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef struct { int id; float confidence; float* vector; int vec_size; } Node;

/*
 * Compact `nodes` in place, keeping (in original order) only the entries whose
 * confidence >= threshold, and set *size to the number kept. A NaN confidence
 * never compares >= and is therefore pruned.
 *
 * Ownership: the vector buffer of every pruned node is freed. Afterwards, the
 * slots in [new_size, old_size) are cleared (vector = NULL, vec_size = 0), so a
 * caller that still walks the old length cannot double-free a buffer that was
 * moved forward or already released.
 */
void tree_pruner(Node* nodes, int* size, float threshold) {
    if (!nodes || !size || *size <= 0) return;
    int old_size = *size;
    int new_size = 0;
    for (int i = 0; i < old_size; i++) {
        if (nodes[i].confidence >= threshold) {
            nodes[new_size] = nodes[i];
            new_size++;
        } else {
            free(nodes[i].vector);
            nodes[i].vector = NULL;
        }
    }
    for (int i = new_size; i < old_size; i++) {
        nodes[i].vector = NULL;
        nodes[i].vec_size = 0;
    }
    *size = new_size;
}

#ifdef __cplusplus
}
#endif

Writing /content/tree_pruner.c


In [None]:

%%writefile /content/tree_matrix_hybrid.c
#include <stdio.h>
#include <stdlib.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef struct { float* data; int rows; int cols; int* indices; int nnz; } SparseMatrix;

/*
 * Gather up to `nnz` entries of the dense row-major node_count x dim array
 * `node_data` into a SparseMatrix. Entry i takes the flat index
 * sparse_indices[i]; nnz is capped at node_count * dim.
 *
 * Out-of-range flat indices (negative or >= node_count * dim) are not read:
 * data[i] stays 0 and indices[i] is set to -1 (previously left uninitialized).
 *
 * Returns NULL on invalid arguments or allocation failure; release the result
 * with free_matrix().
 */
SparseMatrix* tree_matrix_hybrid(float* node_data, int node_count, int dim, int* sparse_indices, int nnz) {
    if (!node_data || !sparse_indices || node_count <= 0 || dim <= 0 || nnz <= 0) return NULL;

    /* 64-bit product: node_count * dim can overflow int for large inputs. */
    long long total = (long long)node_count * (long long)dim;
    int capped = ((long long)nnz > total) ? (int)total : nnz; /* Cap nnz */

    SparseMatrix* matrix = (SparseMatrix*)malloc(sizeof(SparseMatrix));
    if (!matrix) return NULL;
    matrix->rows = node_count;
    matrix->cols = dim;
    matrix->nnz = capped;
    matrix->indices = (int*)malloc((size_t)capped * sizeof(int));
    matrix->data = (float*)calloc((size_t)capped, sizeof(float));
    if (!matrix->indices || !matrix->data) {
        free(matrix->indices);
        free(matrix->data);
        free(matrix);
        return NULL;
    }

    for (int i = 0; i < capped; i++) {
        int idx = sparse_indices[i];
        if (idx >= 0 && (long long)idx < total) {
            matrix->data[i] = node_data[idx];
            matrix->indices[i] = idx;
        } else {
            matrix->indices[i] = -1; /* invalid index; data[i] stays 0 */
        }
    }
    return matrix;
}

/* Release a SparseMatrix returned by tree_matrix_hybrid(); NULL is accepted and ignored. */
void free_matrix(SparseMatrix* matrix) {
    if (matrix == NULL) {
        return;
    }
    free(matrix->indices);
    free(matrix->data);
    free(matrix);
}

#ifdef __cplusplus
}
#endif

Writing /content/tree_matrix_hybrid.c


In [None]:
%%writefile /content/quantization.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef enum { INT8, BFLOAT16, FLOAT16, FLOAT32 } QuantMode;

/*
 * Choose a quantization mode.
 * An exact user flag wins: "--accuracy=high" -> FLOAT32, "--efficiency=high" -> INT8.
 * Otherwise, by contour count: <= 10 -> INT8, 11..20 -> BFLOAT16, > 20 -> FLOAT16.
 */
QuantMode select_quant_mode(int contour_count, const char* user_flag) {
    if (user_flag != NULL) {
        if (strcmp(user_flag, "--accuracy=high") == 0) {
            return FLOAT32;
        }
        if (strcmp(user_flag, "--efficiency=high") == 0) {
            return INT8;
        }
    }
    if (contour_count > 20) {
        return FLOAT16; /* complex images */
    }
    return (contour_count > 10) ? BFLOAT16 : INT8;
}

/* Copy a float's bit pattern into a uint32_t (memcpy avoids strict-aliasing UB). */
static uint32_t quant_float_to_bits(float value) {
    uint32_t bits;
    memcpy(&bits, &value, sizeof bits);
    return bits;
}

/* Inverse of quant_float_to_bits. */
static float quant_bits_to_float(uint32_t bits) {
    float value;
    memcpy(&value, &bits, sizeof value);
    return value;
}

/* Symmetric int8 round trip; codes are clamped to [-127, 127] so the int8_t cast never overflows. */
static float quant_int8_roundtrip(float value, float scale) {
    float code = roundf(value / scale);
    if (code > 127.0f) code = 127.0f;
    if (code < -127.0f) code = -127.0f;
    return (float)(int8_t)code * scale;
}

/* bfloat16 round trip by truncation: keep the top 16 bits of the float32 pattern. */
static float quant_bfloat16_roundtrip(float value) {
    return quant_bits_to_float(quant_float_to_bits(value) & 0xFFFF0000u);
}

/*
 * IEEE binary16 round trip emulated in float32 arithmetic:
 *  - rounds (ties-to-even under the default rounding mode) onto the binary16
 *    grid: 11 significant bits for normals, a fixed 2^-24 step below 2^-14, so
 *    zero stays zero and tiny values flush to 0 or the nearest subnormal;
 *  - magnitudes beyond the largest finite half (65504) saturate to +/-65504
 *    (NumPy's float16 cast would produce +/-inf instead).
 */
static float quant_float16_roundtrip(float value) {
    const float half_max = 65504.0f;
    int exponent = 0;
    if (value == 0.0f) return value;        /* keeps signed zero */
    (void)frexpf(value, &exponent);         /* value = m * 2^exponent, 0.5 <= |m| < 1 */
    int step_exp = exponent - 11;           /* grid spacing of an 11-bit significand */
    if (step_exp < -24) step_exp = -24;     /* binary16 subnormal spacing */
    float rounded = ldexpf(rintf(ldexpf(value, -step_exp)), step_exp);
    if (rounded > half_max) rounded = half_max;
    if (rounded < -half_max) rounded = -half_max;
    return rounded;
}

/*
 * Quantize-then-dequantize `size` floats from `input` into `output`
 * (interpreted as float*, at least `size` elements), using the mode from
 * select_quant_mode(contour_count, user_flag).
 *
 * Side effect: NaN/Inf entries of `input` are overwritten with 0 in place.
 * INT8 uses a per-call symmetric scale max|x| / 127 (1.0 when max|x| <= 1e-6).
 */
void quantize(float* input, int size, int contour_count, const char* user_flag, void* output) {
    if (!input || !output || size <= 0) return;
    QuantMode mode = select_quant_mode(contour_count, user_flag);
    float* out = (float*)output;

    /* Sanitize in place and find max |x| for the int8 scale. */
    float max_val = 0.0f;
    for (int i = 0; i < size; i++) {
        if (isnan(input[i]) || isinf(input[i])) input[i] = 0.0f;
        float abs_val = fabsf(input[i]);
        if (abs_val > max_val) max_val = abs_val;
    }

    /* max_val is finite after sanitizing, so large vectors get a real scale
       instead of 1.0 (which previously overflowed the int8 cast). */
    float scale = (max_val > 1e-6f) ? max_val / 127.0f : 1.0f;

    switch (mode) {
        case INT8: {
            for (int i = 0; i < size; i++) out[i] = quant_int8_roundtrip(input[i], scale);
            break;
        }
        case BFLOAT16: {
            for (int i = 0; i < size; i++) out[i] = quant_bfloat16_roundtrip(input[i]);
            break;
        }
        case FLOAT16: {
            for (int i = 0; i < size; i++) out[i] = quant_float16_roundtrip(input[i]);
            break;
        }
        case FLOAT32: {
            memcpy(out, input, (size_t)size * sizeof(float));
            break;
        }
    }
}

#ifdef __cplusplus
}
#endif

Writing /content/quantization.c


In [None]:
!g++ -O3 -fPIC -shared /content/quantization.c -o /content/quantization.so

In [None]:

%%writefile /content/quantization_wrapper.py
import ctypes
import torch
import numpy as np
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("QuantizationWrapper")

# ---------- Python fallback quantizers ----------
def _py_quantize_int8(vec: np.ndarray):
    # symmetric int8 with scale chosen per-vector
    max_val = float(np.max(np.abs(vec))) if vec.size > 0 else 1.0
    scale = max_val / 127.0 if max_val > 1e-6 and max_val < 1e6 else 1.0
    q = np.round(vec / scale).astype(np.int8)
    return (q.astype(np.float32) * scale), scale

def _py_quantize_bfloat16(vec: np.ndarray):
    # naive bf16-like: zero-out low 16 bits of float32 mantissa via view+mask
    out = vec.astype(np.float32).copy()
    # use numpy view for bit-level ops
    u = out.view(np.uint32)
    # drop low 16 bits (keep top 16 bits -> bfloat16 restore style)
    u[:] = (u & 0xFFFF0000)
    out = u.view(np.float32)
    return out

def _py_quantize_float16(vec: np.ndarray):
    # use numpy's float16 conversion (may be slower but correct)
    return vec.astype(np.float16).astype(np.float32)

# ---------- Wrapper class ----------
class QuantizationWrapper:
    """Quantize the vectors of a TreeNodeV1 tree in place (quantize, then dequantize to float32).

    Uses the compiled ``quantize`` routine from ``so_path`` when it can be
    loaded. Otherwise, or when a C call fails and ``allow_fallback`` is true,
    the NumPy fallbacks in this module are used. The mode follows the C
    ``select_quant_mode`` heuristic: an explicit user flag first, then the
    node's contour count (default 10).
    """

    def __init__(self, dim=64, so_path="./quantization.so", allow_fallback=True):
        self.dim = dim  # only used for the all-zeros last-resort vector
        self.so_path = so_path
        self.lib = None
        self.allow_fallback = bool(allow_fallback)

        # try to load the C shared library but do NOT raise on failure
        try:
            self.lib = ctypes.CDLL(self.so_path)
            # set argtypes/restype for safety (keeps previous signature)
            self.lib.quantize.argtypes = [
                ctypes.POINTER(ctypes.c_float),  # input
                ctypes.c_int,                    # size
                ctypes.c_int,                    # contour_count
                ctypes.c_char_p,                 # user_flag
                ctypes.c_void_p                  # output
            ]
            self.lib.quantize.restype = None
            logger.info(f"Loaded quantization library from {self.so_path}")
        except Exception as e:
            self.lib = None
            logger.warning(f"Could not load quantization library '{self.so_path}': {e}. Using Python fallback.")

    def _detect_mode(self, user_flag: str, contour_count: int):
        # match the C select_quant_mode heuristic
        if user_flag and user_flag == "--accuracy=high":
            return "FLOAT32"
        if user_flag and user_flag == "--efficiency=high":
            return "INT8"
        if contour_count <= 10:
            return "INT8"
        if contour_count <= 20:
            return "BFLOAT16"
        return "FLOAT16"

    @staticmethod
    def _contour_count(node, default=10):
        """Read the node's contour count (``get_contour_count()`` or ``contour_count``), else *default*."""
        try:
            if callable(getattr(node, "get_contour_count", None)):
                return int(node.get_contour_count())
            if hasattr(node, "contour_count"):
                return int(getattr(node, "contour_count"))
        except Exception:
            return default
        return default

    def apply_quantization(self, node: TreeNodeV1, user_flag="--efficiency=high"):
        """
        Quantize node vectors in-place, recursing into children.

        If the C library is present, it is used; otherwise a safe Python fallback
        is used. Children are always visited, even when this node's own vector
        could not be converted or quantized.
        """
        if node is None:
            return

        self._quantize_node(node, user_flag)

        # recurse (previously skipped on conversion failure or disabled fallback)
        for ch in node.children:
            self.apply_quantization(ch, user_flag=user_flag)

    def _quantize_node(self, node, user_flag):
        """Quantize ``node``'s own vector in place; children are not visited."""
        vec_t = node.get_vector()
        if vec_t is None:
            return

        try:
            vec = vec_t.detach().cpu().numpy().astype(np.float32)
        except Exception as e:
            logger.error(f"Failed to convert node vector to numpy for node {getattr(node, 'id', '?')}: {e}")
            return

        # sanitize (NaN/Inf -> 0)
        vec[~np.isfinite(vec)] = 0.0

        contour_count = self._contour_count(node)
        mode = self._detect_mode(user_flag, contour_count)

        if self.lib is None:
            # no C lib; use Python fallback
            self._py_quantize_assign(node, vec, mode)
            return

        try:
            quantized = self._c_quantize(vec, contour_count, user_flag)
            node.store_vector(torch.tensor(quantized, dtype=torch.float32))
        except Exception as e:
            logger.warning(f"C quantization failed for node {getattr(node,'id','?')}: {e}")
            # fall back to Python quantization if allowed
            if self.allow_fallback:
                self._py_quantize_assign(node, vec, mode)

    def _c_quantize(self, vec: np.ndarray, contour_count: int, user_flag):
        """Run the C ``quantize`` on a flat, C-contiguous copy of *vec*.

        Returns a float32 array with *vec*'s shape. Raises RuntimeError if the
        result contains NaN/Inf.
        """
        # The C routine reads `size` consecutive floats: pass the total element
        # count of a contiguous buffer (len() would only count the first axis).
        flat = np.ascontiguousarray(vec, dtype=np.float32).reshape(-1)
        size = int(flat.size)
        output_buf = (ctypes.c_float * size)()
        self.lib.quantize(
            flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            ctypes.c_int(size),
            ctypes.c_int(contour_count),
            ctypes.c_char_p(user_flag.encode() if user_flag else b""),
            ctypes.cast(output_buf, ctypes.c_void_p)
        )
        quantized = np.frombuffer(output_buf, dtype=np.float32).copy().reshape(vec.shape)
        if not np.isfinite(quantized).all():
            raise RuntimeError("C quantize produced invalid values (NaN/Inf)")
        return quantized

    def _py_quantize_assign(self, node, vec: np.ndarray, mode: str):
        """Quantize *vec* with the NumPy fallback for *mode* and store it on *node*."""
        try:
            if mode == "INT8":
                qvec, _scale = _py_quantize_int8(vec)
            elif mode == "BFLOAT16":
                qvec = _py_quantize_bfloat16(vec)
            elif mode == "FLOAT16":
                qvec = _py_quantize_float16(vec)
            else:  # FLOAT32
                qvec = vec.astype(np.float32)
            node.store_vector(torch.tensor(qvec, dtype=torch.float32))
        except Exception as e:
            logger.error(f"Python fallback quantization failed for node {getattr(node,'id','?')}: {e}")
            # as last resort store zeros of same shape
            current = node.get_vector()
            node.store_vector(torch.zeros_like(current if current is not None else torch.zeros(self.dim)))

Writing /content/quantization_wrapper.py


In [None]:

%%writefile /content/attention_aggregator.c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#ifdef __cplusplus
extern "C" {
#endif

void attention_aggregate(float* vectors, int num_vectors, int dim, float* output) {
    if (!vectors || !output || num_vectors <= 0 || dim <= 0) return;

    float* weights = (float*)calloc(num_vectors, sizeof(float));
    float sum = 0.0;

    for (int i = 0; i < num_vectors; i++) {
        weights[i] = expf(vectors[i * dim]);

        if (isnan(weights[i]) || isinf(weights[i])) {
            weights[i] = 0.0f;
        }

        sum += weights[i];
    }

    if (sum == 0.0f) {
        for (int i = 0; i < num_vectors; i++) {
            weights[i] = 1.0f / num_vectors;
        }
    } else {
        for (int i = 0; i < num_vectors; i++) {
            weights[i] /= sum;
        }
    }

    for (int j = 0; j < dim; j++) {
        output[j] = 0.0;
        for (int i = 0; i < num_vectors; i++) {
            output[j] += weights[i] * vectors[i * dim + j];
        }
    }

    free(weights);
}

#ifdef __cplusplus
}
#endif

Writing /content/attention_aggregator.c


In [None]:

%%writefile /content/attention_aggregator_wrapper.py
import ctypes
import torch
import logging
import numpy as np
from typing import List
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("attention_aggregator_wrapper")


def _python_attention_aggregate(vectors: List[np.ndarray], dim: int) -> torch.Tensor:
    """
    Safe Python fallback attention aggregation.

    Each input is flattened and truncated to at most ``dim`` elements, then
    zero-padded to exactly ``dim`` elements. Ragged or short inputs are
    therefore accepted, and the result always has shape ``(dim,)``.

    Weights:
      - each vector is scored by the MEAN of its own (truncated, unpadded)
        elements;
      - scores are NaN/Inf-sanitised, clipped to [-1e3, 1e3] and passed through
        a max-shifted (overflow-free) softmax.

    The output is the weighted sum of the padded vectors, with NaN/Inf replaced.

    NOTE: the C kernel ``attention_aggregate`` scores by the FIRST element of
    each vector, so the C path and this fallback can weight vectors differently.

    Args:
        vectors: sequence of array-likes (any shape; flattened), or None.
        dim: output length.

    Returns:
        float32 tensor of shape ``(dim,)``; zeros when ``vectors`` is None/empty.
    """
    # len() instead of truthiness so a 2-D NumPy array of rows is also accepted.
    if vectors is None or len(vectors) == 0:
        return torch.zeros(dim)

    rows = np.zeros((len(vectors), dim), dtype=np.float32)
    scores = np.zeros(len(vectors), dtype=np.float32)
    for i, v in enumerate(vectors):
        flat = np.asarray(v, dtype=np.float32).reshape(-1)[:dim]
        rows[i, :flat.size] = flat
        if flat.size:
            # Score over the vector's own elements, so padding does not dilute it.
            scores[i] = flat.mean()

    scores = np.clip(np.nan_to_num(scores), -1e3, 1e3)
    # Max-shifted softmax: the largest score maps to exp(0) == 1, so the sum is >= 1.
    exp_scores = np.exp(scores - scores.max())
    weights = exp_scores / exp_scores.sum()

    out = np.nan_to_num((weights[:, None] * rows).sum(axis=0))
    return torch.tensor(out, dtype=torch.float32)


class AttentionAggregatorWrapper:
    """
    Aggregate the vectors of a list of TreeNodeV1 nodes into one ``dim``-sized tensor.

    Uses ``attention_aggregate`` from ``./attention_aggregator.so``, a path
    relative to the current working directory, via ctypes when it loads.
    Otherwise, or when the C call raises or returns NaN/Inf, it uses the
    pure-Python ``_python_attention_aggregate``.
    """

    def __init__(self, dim: int = 64, device: str = 'cpu'):
        """
        Args:
            dim: length every node vector is padded/truncated to, and of the result.
            device: torch device the result is placed on.
        """
        self.dim = dim
        self.device = device
        self.lib = None
        self._load_shared_lib()

    def _load_shared_lib(self):
        """Load the C library and declare its signature; leave ``self.lib`` None on failure."""
        try:
            self.lib = ctypes.CDLL('./attention_aggregator.so')
            self.lib.attention_aggregate.argtypes = [
                ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int,
                ctypes.POINTER(ctypes.c_float)
            ]
            self.lib.attention_aggregate.restype = None
            logger.info("Loaded attention_aggregator.so successfully.")
        except Exception as e:
            logger.warning(f"Could not load attention_aggregator.so — using Python fallback. ({e})")
            self.lib = None

    def _node_to_array(self, i, node):
        """
        Convert ``node``'s vector to a float32 NumPy array of exactly ``self.dim`` elements.

        Returns None when ``node`` is not a TreeNodeV1, has no vector, or has an
        empty vector. A debug message is logged in the first two cases. NaN/Inf
        entries are replaced with 0.
        """
        if not isinstance(node, TreeNodeV1):
            logger.debug(f"Item[{i}] is not TreeNodeV1 — skipping.")
            return None
        vec = node.get_vector()
        if vec is None:
            logger.debug(f"Node[{i}] ({getattr(node,'id',i)}) has no vector — skipping.")
            return None
        vec = vec.detach().cpu().to(dtype=torch.float32).reshape(-1)
        if vec.numel() < 1:
            return None
        if torch.isnan(vec).any() or torch.isinf(vec).any():
            logger.debug(f"Node[{i}] has NaN/Inf in vector — replacing with zeros for that vector.")
            vec = torch.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)
        vec_np = vec.numpy()
        # ensure length == dim by padding/truncating
        if vec_np.size < self.dim:
            vec_np = np.concatenate([vec_np, np.zeros(self.dim - vec_np.size, dtype=np.float32)], axis=0)
        return vec_np[:self.dim]

    def _aggregate_with_c(self, vectors) -> torch.Tensor:
        """Run the C kernel on ``dim``-sized float32 vectors; return the result on ``self.device``."""
        flat = np.ascontiguousarray(np.concatenate(vectors, axis=0), dtype=np.float32)
        out_np = np.zeros(self.dim, dtype=np.float32)
        float_p = ctypes.POINTER(ctypes.c_float)
        # Hand the NumPy buffers to C directly instead of copying every float
        # through a Python list; `flat` and `out_np` stay alive for the call.
        self.lib.attention_aggregate(
            flat.ctypes.data_as(float_p), ctypes.c_int(len(vectors)),
            ctypes.c_int(self.dim), out_np.ctypes.data_as(float_p),
        )
        return torch.tensor(out_np, dtype=torch.float32, device=self.device)

    def __call__(self, nodes: List[TreeNodeV1]) -> torch.Tensor:
        """
        Aggregate ``nodes`` into one vector.

        Returns:
            Tensor of shape ``(dim,)`` on ``self.device``. It is all zeros when
            ``nodes`` is empty, when no node yields a usable vector, or on an
            unexpected error (logged).
        """
        try:
            if not nodes:
                logger.debug("AttentionAggregatorWrapper called with empty node list.")
                return torch.zeros(self.dim, device=self.device)

            vectors = []
            for i, node in enumerate(nodes):
                arr = self._node_to_array(i, node)
                if arr is not None:
                    vectors.append(arr)

            if not vectors:
                logger.warning("No valid vectors after processing nodes; returning zero vector.")
                return torch.zeros(self.dim, device=self.device)

            if self.lib is None:
                # pure-python fallback
                return _python_attention_aggregate(vectors, self.dim).to(self.device)

            try:
                out = self._aggregate_with_c(vectors)
            except Exception as e:
                logger.warning(f"C aggregator failed at runtime ({e}) — falling back to Python aggregator.")
                return _python_attention_aggregate(vectors, self.dim).to(self.device)
            if torch.isnan(out).any() or torch.isinf(out).any():
                logger.warning("C aggregator produced NaN/Inf — falling back to Python aggregator.")
                return _python_attention_aggregate(vectors, self.dim).to(self.device)
            return out

        except Exception as e:
            logger.error(f"Attention aggregation unexpected error: {e}")
            return torch.zeros(self.dim, device=self.device)

Writing /content/attention_aggregator_wrapper.py


In [None]:

%%writefile /content/symbolic_feature_extractor.py
import cv2
import numpy as np

def symbolic_feature_extract_py(bgr_img: np.ndarray, spiral_indices: list, dim: int = 64) -> np.ndarray:
    """
    Sample intensity + edge-strength features at the given pixel positions.

    Args:
        bgr_img: colour image of shape (H, W, C) with C in {3, 4}, as accepted by
            ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``. The /255 scaling below
            presumably assumes 8-bit pixel values — TODO confirm for other dtypes.
        spiral_indices: row-major linear pixel indices (``y * W + x``). Only the
            first ``dim`` are used; out-of-range indices yield 0.
        dim: length of the returned vector.

    Returns:
        float32 array of shape ``(dim,)``. Entry i is
        ``gray/255 + 0.5 * (|sobel_x| + |sobel_y|) / 255`` at index i; entries
        without a valid index are 0. The array is all zeros when the image is
        missing, not 3-D, has an unsupported channel count, or no indices are
        given.
    """
    vec = np.zeros(dim, dtype=np.float32)

    # cvtColor(BGR2GRAY) only accepts 3- or 4-channel input. Other channel
    # counts used to raise instead of returning zeros like other invalid images.
    if bgr_img is None or bgr_img.ndim != 3 or bgr_img.shape[2] not in (3, 4):
        return vec
    # len() rather than truthiness so NumPy index arrays work too.
    if spiral_indices is None or len(spiral_indices) == 0:
        return vec

    gray = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)
    sobelx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    sobely = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    h, w = gray.shape

    for i, idx in enumerate(spiral_indices[:dim]):
        y, x = divmod(idx, w)
        if 0 <= y < h and 0 <= x < w:
            intensity = gray[y, x] / 255.0
            gx, gy = sobelx[y, x], sobely[y, x]
            vec[i] = intensity + 0.5 * (abs(gx) + abs(gy)) / 255.0
    return vec

Writing /content/symbolic_feature_extractor.py


In [None]:

%%writefile /content/neural_enhancer.c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#ifdef __cplusplus
extern "C" {
#endif

void neural_enhance(float* input, int dim, float* output) {
    if (!input || !output || dim <= 0) return;
    for (int i = 0; i < dim; i++) {
        output[i] = tanh(input[i] * 2.0);
    }
}

#ifdef __cplusplus
}
#endif

Writing /content/neural_enhancer.c


In [None]:

%%writefile /content/neural_enhancer_wrapper.py
import ctypes
import torch
import logging
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("neural_enhancer_wrapper")


def _python_neural_enhance(vec: torch.Tensor) -> torch.Tensor:
    """
    Pure-Python enhancement: elementwise ``tanh(2 * x)`` on a CPU float32 copy.

    NaN and +/-Inf entries are replaced by 0 before the tanh. A ``None`` input
    is passed straight through as ``None``.
    """
    if vec is None:
        return None
    cleaned = torch.nan_to_num(vec.detach().cpu().float(), nan=0.0, posinf=0.0, neginf=0.0)
    return torch.tanh(2.0 * cleaned)


class NeuralEnhancerWrapper:
    """
    ctypes front-end for ``neural_enhance`` with a pure-Python fallback.

    The shared library is loaded from ``./libneural_enhancer.so``, relative to
    the current working directory. If it cannot be loaded, or the C call fails
    or yields NaN/Inf, ``_python_neural_enhance`` is used instead.
    """

    def __init__(self, dim: int = 64, device: str = 'cpu'):
        """
        Args:
            dim: inputs are flattened and zero-padded / truncated to this length.
            device: torch device of the returned tensors.
        """
        self.dim = dim
        self.device = device
        self.lib = None
        self._load_shared_lib()

    def _load_shared_lib(self):
        """Bind ``neural_enhance`` from the C library; ``self.lib`` stays None if that fails."""
        try:
            self.lib = ctypes.CDLL('./libneural_enhancer.so')
            enhance_fn = self.lib.neural_enhance
            enhance_fn.argtypes = [
                ctypes.POINTER(ctypes.c_float), ctypes.c_int,
                ctypes.POINTER(ctypes.c_float)
            ]
            enhance_fn.restype = None
            logger.info("Loaded libneural_enhancer.so successfully.")
        except Exception as e:
            logger.warning(f"Could not load libneural_enhancer.so — using Python fallback. ({e})")
            self.lib = None

    def _fit(self, flat: torch.Tensor) -> torch.Tensor:
        """Zero-pad or truncate a 1-D float32 tensor to exactly ``self.dim`` elements."""
        count = flat.numel()
        if count < self.dim:
            filler = torch.zeros(self.dim - count, dtype=torch.float32)
            return torch.cat([flat, filler], dim=0)
        if count > self.dim:
            return flat[:self.dim]
        return flat

    def _enhance_with_c(self, fitted: torch.Tensor) -> torch.Tensor:
        """Run the C kernel on a ``dim``-sized CPU tensor; the result lives on ``self.device``."""
        host = fitted.numpy().astype(np.float32)
        c_in = (ctypes.c_float * self.dim)(*host.tolist())
        c_out = (ctypes.c_float * self.dim)()
        self.lib.neural_enhance(c_in, ctypes.c_int(self.dim), c_out)
        return torch.tensor(np.ctypeslib.as_array(c_out), dtype=torch.float32, device=self.device)

    def __call__(self, vec: torch.Tensor) -> torch.Tensor:
        """
        Enhance ``vec`` and return a ``dim``-sized tensor on ``self.device``.

        A ``None`` input, or any unexpected error (logged), yields zeros.
        Non-tensor inputs are converted with ``torch.tensor(..., float32)``.
        """
        try:
            if vec is None:
                logger.debug("NeuralEnhancerWrapper called with None vector.")
                return torch.zeros(self.dim, device=self.device)
            if not isinstance(vec, torch.Tensor):
                vec = torch.tensor(vec, dtype=torch.float32)
            fitted = self._fit(vec.detach().cpu().float().reshape(-1))

            if self.lib is None:
                return _python_neural_enhance(fitted).to(self.device)

            try:
                result = self._enhance_with_c(fitted)
                if torch.isnan(result).any() or torch.isinf(result).any():
                    logger.warning("C neural enhancer produced NaN/Inf — falling back to Python.")
                    return _python_neural_enhance(fitted).to(self.device)
                return result
            except Exception as e:
                logger.warning(f"C neural enhancer failed ({e}) — using Python fallback.")
                return _python_neural_enhance(fitted).to(self.device)
        except Exception as e:
            logger.error(f"Neural enhancement unexpected error: {e}")
            return torch.zeros(self.dim, device=self.device)

Writing /content/neural_enhancer_wrapper.py


In [None]:

%%writefile /content/multimodal_fusion_weights.c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#ifdef __cplusplus
extern "C" {
#endif

void compute_fusion_weights(float* symbolic, float* neural, int dim, float* weights) {
    if (!symbolic || !neural || !weights || dim <= 0) return;
    float symbolic_norm = 0.0, neural_norm = 0.0;
    for (int i = 0; i < dim; i++) {
        symbolic_norm += symbolic[i] * symbolic[i];
        neural_norm += neural[i] * neural[i];
    }
    symbolic_norm = sqrt(symbolic_norm);
    neural_norm = sqrt(neural_norm);
    float total = symbolic_norm + neural_norm;
    weights[0] = total > 0 ? symbolic_norm / total : 0.5;
    weights[1] = total > 0 ? neural_norm / total : 0.5;
}

#ifdef __cplusplus
}
#endif

Writing /content/multimodal_fusion_weights.c


In [None]:

%%writefile /content/multimodal_fusion_weights_wrapper.py
import ctypes
import torch
import logging
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("multimodal_fusion_weights_wrapper")


def _python_compute_fusion_weights(symbolic_vec: torch.Tensor, neural_vec: torch.Tensor) -> torch.Tensor:
    """
    Pure-Python fusion weights: each input's share of the summed L2 norms.

    Returns ``[|s| / (|s|+|n|), |n| / (|s|+|n|)]`` as a float32 tensor.
    ``None`` inputs count as norm 0. Non-tensor inputs go through
    ``torch.tensor``. NaN/Inf entries are zeroed before taking norms. When the
    summed norm is at most 1e-12 the result is ``[0.5, 0.5]``.
    """
    def _l2(value):
        if value is None:
            return 0.0
        t = value if isinstance(value, torch.Tensor) else torch.tensor(value, dtype=torch.float32)
        t = torch.nan_to_num(t.detach().cpu().float(), nan=0.0, posinf=0.0, neginf=0.0)
        return float(torch.norm(t).item())

    sym_norm = _l2(symbolic_vec)
    neu_norm = _l2(neural_vec)
    denom = sym_norm + neu_norm
    if denom > 1e-12:
        return torch.tensor([sym_norm / denom, neu_norm / denom], dtype=torch.float32)
    return torch.tensor([0.5, 0.5], dtype=torch.float32)


class MultimodalFusionWeightsWrapper:
    """
    Compute two fusion weights ``[w_symbolic, w_neural]`` for a pair of vectors.

    Uses ``compute_fusion_weights`` from ``./multimodal_fusion_weights.so``
    (relative to the working directory) via ctypes when it loads. Otherwise, or
    if the C call raises, it uses ``_python_compute_fusion_weights``.
    """

    def __init__(self, dim: int = 64, device: str = 'cpu'):
        """
        Args:
            dim: both inputs are flattened and zero-padded / truncated to this length.
            device: torch device of the returned weights.
        """
        self.dim = dim
        self.device = device
        self.lib = None
        self._load_shared_lib()

    def _load_shared_lib(self):
        """Load the C library and declare its signature; leave ``self.lib`` None on failure."""
        try:
            self.lib = ctypes.CDLL('./multimodal_fusion_weights.so')
            self.lib.compute_fusion_weights.argtypes = [
                ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float),
                ctypes.c_int, ctypes.POINTER(ctypes.c_float)
            ]
            self.lib.compute_fusion_weights.restype = None
            logger.info("Loaded multimodal_fusion_weights.so successfully.")
        except Exception as e:
            logger.warning(f"Could not load multimodal_fusion_weights.so — using Python fallback. ({e})")
            self.lib = None

    def _fit(self, x) -> torch.Tensor:
        """Flatten ``x`` to a float32 CPU tensor and zero-pad / truncate it to ``self.dim``."""
        # as_tensor instead of torch.tensor: no extra copy, and no
        # "copy construct from a tensor" UserWarning when x is already a tensor.
        t = torch.as_tensor(x, dtype=torch.float32).detach().cpu().reshape(-1)
        if t.numel() < self.dim:
            t = torch.cat([t, torch.zeros(self.dim - t.numel(), dtype=torch.float32)], dim=0)
        return t[:self.dim]

    def compute_weights(self, symbolic_vec: torch.Tensor, neural_vec: torch.Tensor) -> torch.Tensor:
        """
        Return a 2-element float32 tensor on ``self.device``, normalised to sum to 1.

        ``[0.5, 0.5]`` is returned when either input is None, when the C result
        sums to ~0, or on an unexpected error (logged). NaN/Inf weights from the
        C path are replaced by 0.5 before normalisation.
        """
        try:
            if symbolic_vec is None or neural_vec is None:
                logger.warning("One or both input vectors are None — returning default [0.5,0.5].")
                return torch.tensor([0.5, 0.5], dtype=torch.float32, device=self.device)

            # prepare fixed-size arrays
            s = self._fit(symbolic_vec)
            n = self._fit(neural_vec)

            if self.lib is None:
                return _python_compute_fusion_weights(s, n).to(self.device)

            try:
                s_arr = (ctypes.c_float * self.dim)(*s.numpy().tolist())
                n_arr = (ctypes.c_float * self.dim)(*n.numpy().tolist())
                out_arr = (ctypes.c_float * 2)()
                self.lib.compute_fusion_weights(s_arr, n_arr, ctypes.c_int(self.dim), out_arr)
                out = torch.tensor([float(out_arr[0]), float(out_arr[1])], dtype=torch.float32, device=self.device)
                # sanitize
                out = torch.nan_to_num(out, nan=0.5, posinf=0.5, neginf=0.5)
                ssum = float(out.sum())
                if abs(ssum) < 1e-8:
                    return torch.tensor([0.5, 0.5], dtype=torch.float32, device=self.device)
                return (out / ssum).to(self.device)
            except Exception as e:
                logger.warning(f"C compute_fusion_weights failed ({e}) — using Python fallback.")
                return _python_compute_fusion_weights(s, n).to(self.device)

        except Exception as e:
            logger.error(f"Fusion weights unexpected error: {e}")
            return torch.tensor([0.5, 0.5], dtype=torch.float32, device=self.device)

Writing /content/multimodal_fusion_weights_wrapper.py


In [None]:

%%writefile /content/shape_abstraction.c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#ifdef __cplusplus
extern "C" {
#endif

void shape_embed(float* input, int dim, float* output) {
    if (!input || !output || dim <= 0) return;
    float norm = 0.0;
    for (int i = 0; i < dim; i++) {
        norm += input[i] * input[i];
    }
    norm = sqrt(norm) > 0 ? sqrt(norm) : 1.0;
    for (int i = 0; i < dim; i++) {
        output[i] = input[i] / norm;
    }
}

#ifdef __cplusplus
}
#endif

Writing /content/shape_abstraction.c


In [None]:

%%writefile /content/shape_abstraction_wrapper.py
import ctypes
import torch
import logging
import numpy as np
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("shape_abstraction_wrapper")


def _python_normalize(vec: torch.Tensor) -> torch.Tensor:
    """
    Pure-Python fallback: scale ``vec`` to unit L2 norm.

    NaN/Inf entries are zeroed first. A vector whose norm is at most 1e-8 is
    mapped to all zeros instead of being divided.
    """
    cleaned = torch.nan_to_num(vec.float(), nan=0.0, posinf=0.0, neginf=0.0)
    length = torch.norm(cleaned).item()
    return cleaned / length if length > 1e-8 else torch.zeros_like(cleaned)


class ShapeAbstractionWrapper:
    """
    L2-normalise a TreeNodeV1's vector to a ``dim``-sized unit vector.

    Uses ``shape_embed`` from ``./shape_abstraction.so`` (relative to the working
    directory) via ctypes when it loads. Otherwise, or if the C call raises, it
    uses ``_python_normalize``.
    """

    def __init__(self, dim=64, device='cpu'):
        """
        Args:
            dim: node vectors are flattened and zero-padded / truncated to this length.
            device: torch device of the returned tensor.
        """
        self.device = device
        self.dim = dim
        self.lib = None
        self._load()

    def _load(self):
        """Load the C library and declare its signature; leave ``self.lib`` None on failure."""
        try:
            self.lib = ctypes.CDLL('./shape_abstraction.so')
            self.lib.shape_embed.argtypes = [
                ctypes.POINTER(ctypes.c_float),
                ctypes.c_int,
                ctypes.POINTER(ctypes.c_float)
            ]
            self.lib.shape_embed.restype = None
            logger.info("Loaded shape_abstraction.so successfully.")
        except Exception as e:
            logger.warning(f"Could not load shape_abstraction.so — using Python fallback. ({e})")
            self.lib = None

    def __call__(self, node: TreeNodeV1) -> torch.Tensor:
        """
        Return ``node``'s normalised vector as a ``dim``-sized tensor on ``self.device``.

        Zeros are returned when the node has no vector or on an unexpected error
        (logged). NaN/Inf values in the C result are replaced via ``nan_to_num``.
        """
        try:
            vec = node.get_vector()
            if vec is None:
                return torch.zeros(self.dim, device=self.device)

            # reshape, not view: view(-1) raises on non-contiguous tensors, which
            # used to send such vectors to the zero-vector error path.
            v = vec.detach().cpu().float().reshape(-1)

            # pad/truncate to dim
            if v.numel() < self.dim:
                v = torch.cat([v, torch.zeros(self.dim - v.numel())])
            elif v.numel() > self.dim:
                v = v[:self.dim]

            # Use C library if available
            if self.lib is not None:
                try:
                    in_arr = (ctypes.c_float * self.dim)(*v.numpy().tolist())
                    out_arr = (ctypes.c_float * self.dim)()
                    self.lib.shape_embed(in_arr, self.dim, out_arr)
                    out = torch.nan_to_num(torch.tensor(list(out_arr), dtype=torch.float32))
                    return out.to(self.device)
                except Exception as e:
                    logger.warning(f"C library failed, using python fallback ({e})")
                    return _python_normalize(v).to(self.device)

            # python fallback
            return _python_normalize(v).to(self.device)

        except Exception as e:
            logger.error(f"Shape abstraction failure: {e}")
            return torch.zeros(self.dim, device=self.device)

Writing /content/shape_abstraction_wrapper.py


In [None]:

%%writefile /content/adaptive_depth.py
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("adaptive_depth")


class AdaptiveDepthController:
    """Prune a TreeNodeV1 tree by depth, node confidence and sub-feature confidence."""

    def __init__(self, max_depth=4, min_confidence=0.2, min_feature_confidence=0.1):
        """
        Fully shape-independent depth controller.

        Args:
            max_depth: nodes reached at recursion depth >= max_depth are dropped.
                The node passed to ``control_depth`` with the default depth is depth 0.
            min_confidence: nodes whose ``get_confidence()`` is below this are
                dropped together with their subtree.
            min_feature_confidence: keys whose ``feature_confidence`` is below this
                are removed from both ``feature_confidence`` and ``sub_features``.
        """
        self.max_depth = int(max_depth)
        self.min_confidence = float(min_confidence)
        self.min_feature_confidence = float(min_feature_confidence)

    def control_depth(self, node: TreeNodeV1, depth: int = 0):
        """
        Prune ``node`` and its subtree in place; return ``node``, or None if it is dropped.

        ``node`` is dropped when any of the following holds:
          - it is not a TreeNodeV1;
          - ``depth >= max_depth``;
          - its confidence is below ``min_confidence``;
          - it has no vector (checked after its children are pruned);
          - an exception occurs while processing it (logged).
        The surviving children replace ``node.children``.

        NOTE(review): a node without its own vector is dropped even if some of
        its children survive. SparseTreeOptimizer.optimize keeps such nodes;
        confirm the asymmetry is intended.
        """
        try:
            if not isinstance(node, TreeNodeV1) or depth >= self.max_depth:
                return None

            if node.get_confidence() < self.min_confidence:
                return None

            # prune low-confidence subfeatures (keys collected first: dict is mutated below)
            weak_keys = [k for k, conf in node.feature_confidence.items()
                         if conf < self.min_feature_confidence]
            for k in weak_keys:
                node.sub_features.pop(k, None)
                node.feature_confidence.pop(k, None)

            # recursive prune; compare against None explicitly so a node object
            # that happens to be falsy (e.g. defines __len__) is not discarded.
            kept = []
            for child in node.children:
                pruned = self.control_depth(child, depth + 1)
                if pruned is not None:
                    kept.append(pruned)
            node.children = kept

            return node if node.get_vector() is not None else None

        except Exception as e:
            logger.error(f"Depth control failed: {e}")
            return None

Writing /content/adaptive_depth.py


In [None]:

%%writefile /content/sparse_tree_optimizer.py
import torch
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("sparse_tree_optimizer")


class SparseTreeOptimizer:
    """Zero out small-magnitude vector entries throughout a TreeNodeV1 tree."""

    def __init__(self, threshold=0.1):
        """
        Shape-independent sparsity optimizer.
        threshold = |x| > threshold keeps values; else zero.
        """
        self.threshold = float(threshold)

    def optimize(self, node: TreeNodeV1):
        """
        Sparsify every vector in ``node``'s subtree, in place.

        Entries with ``|x| > threshold`` are kept. All others become 0, including
        NaN, which fails the comparison. The vector's dtype is preserved.
        Children for which ``optimize`` returns None are removed from
        ``node.children``.

        Returns:
            ``node`` if it still has a vector or any surviving children, else None.
            Also None when ``node`` is None or an exception occurs (logged).
        """
        try:
            if node is None:
                return None

            vec = node.get_vector()
            if vec is not None:
                keep = vec.abs() > self.threshold
                # torch.where instead of vec * mask.float(): keeps the original
                # dtype and zeroes NaNs instead of propagating NaN * 0 = NaN.
                node.store_vector(torch.where(keep, vec, torch.zeros_like(vec)))

            children = []
            for child in node.children:
                opt = self.optimize(child)
                if opt is not None:
                    children.append(opt)
            node.children = children

            return node if node.get_vector() is not None or node.children else None

        except Exception as e:
            logger.error(f"Optimize error: {e}")
            return None

Writing /content/sparse_tree_optimizer.py


In [None]:

%%writefile /content/tree_matrix_hybrid_wrapper.py
import ctypes
import torch
import logging
from TreeNodeV1 import TreeNodeV1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("tree_matrix_hybrid")


class TreeMatrixHybridWrapper:
    """
    Flatten a TreeNodeV1 tree into a ``(num_nodes_with_vectors, dim)`` matrix.

    Traversal and rows:
      - nodes are visited depth-first, each node before its children, with
        children in list order;
      - nodes without a vector are skipped, but their children are still visited;
      - each vector is flattened and zero-padded / truncated to ``dim``.

    When ``./tree_matrix_hybrid.so`` loads, the rows are passed through its
    ``tree_matrix_hybrid`` function. That transformation is defined in C source
    not shown here. Otherwise the stacked rows are returned directly.
    """

    def __init__(self, dim=64, device='cpu'):
        """
        Args:
            dim: row width; node vectors are padded / truncated to it.
            device: torch device of the returned matrix.
        """
        self.dim = dim
        self.device = device
        self.lib = None
        self._load()

    def _load(self):
        """Load the C library and declare its signatures; leave ``self.lib`` None on failure."""
        try:
            self.lib = ctypes.CDLL('./tree_matrix_hybrid.so')
            self.lib.tree_matrix_hybrid.argtypes = [
                ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int,
                ctypes.POINTER(ctypes.c_int), ctypes.c_int
            ]
            self.lib.tree_matrix_hybrid.restype = ctypes.c_void_p
            self.lib.free_matrix.argtypes = [ctypes.c_void_p]
            logger.info("Loaded tree_matrix_hybrid.so")
        except Exception as e:
            logger.warning(f"Could not load tree_matrix_hybrid.so ({e})")
            self.lib = None

    def _fit(self, v: torch.Tensor) -> torch.Tensor:
        """Flatten ``v`` to float32 on CPU and zero-pad / truncate it to ``self.dim``."""
        # reshape, not view: view(-1) raises on non-contiguous tensors.
        vec = v.detach().cpu().float().reshape(-1)
        if vec.numel() < self.dim:
            vec = torch.cat([vec, torch.zeros(self.dim - vec.numel())])
        return vec[:self.dim]

    def to_matrix(self, root: TreeNodeV1):
        """
        Build the node matrix for the tree rooted at ``root``.

        Returns:
            Tensor of shape ``(n, dim)`` on ``self.device``, where n is the number
            of nodes with a vector. Otherwise:
              - ``(1, dim)`` zeros when ``root`` is None, when no node has a
                vector, or on an unexpected error (logged);
              - ``(n, dim)`` zeros when the C function returns NULL.
        """
        try:
            if root is None:
                return torch.zeros((1, self.dim), device=self.device)

            # collect vectors
            rows = []

            def walk(n):
                v = n.get_vector()
                if v is not None:
                    rows.append(self._fit(v))
                for c in n.children:
                    walk(c)

            walk(root)

            node_count = len(rows)
            if node_count == 0:
                return torch.zeros((1, self.dim), device=self.device)

            mat = torch.stack(rows)
            if self.lib is None:
                # Python fallback
                return mat.to(self.device)

            # C call
            data = mat.reshape(-1).tolist()
            data_arr = (ctypes.c_float * len(data))(*data)
            idx_arr = (ctypes.c_int * node_count)(*range(node_count))
            ptr = self.lib.tree_matrix_hybrid(data_arr, node_count, self.dim,
                                              idx_arr, node_count)

            if not ptr:
                logger.warning("tree_matrix_hybrid returned NULL")
                return torch.zeros((node_count, self.dim), device=self.device)

            # Release the C buffer even if copying it out raises.
            try:
                mat_flat = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_float *
                                                           (node_count * self.dim))).contents
                out = torch.tensor(list(mat_flat), dtype=torch.float32).reshape(node_count, self.dim)
            finally:
                self.lib.free_matrix(ptr)
            return out.to(self.device)

        except Exception as e:
            logger.error(f"Matrix hybrid error: {e}")
            return torch.zeros((1, self.dim), device=self.device)

Writing /content/tree_matrix_hybrid_wrapper.py


In [None]:

# Install system packages; libopencv-dev provides the OpenCV headers and the
# opencv_core / opencv_imgproc libraries linked by the symbolic extractor below.
!apt-get update
!apt-get install -y libopencv-dev
# Build each C helper as a position-independent shared object for the ctypes
# wrappers. g++ compiles these .c files as C++, which is why the sources wrap
# their functions in extern "C": it keeps the unmangled names ctypes looks up.
# -v only makes g++ print verbose build details.
# The wrappers load the libraries via relative paths (e.g.
# './attention_aggregator.so'), so the notebook's working directory must be
# /content for them to be found.
# NOTE(review): lazy_recursive.c, tree_pruner.c, quantization.c,
# tree_matrix_hybrid.c and symbolic_feature_extractor.c are not written in this
# part of the notebook — confirm they exist before running this cell.
!g++ -v /content/lazy_recursive.c -o /content/lazy_recursive.so -shared -fPIC
!g++ -v /content/tree_pruner.c -o /content/tree_pruner.so -shared -fPIC
!g++ -v /content/quantization.c -o /content/quantization.so -shared -fPIC
!g++ -v /content/tree_matrix_hybrid.c -o /content/tree_matrix_hybrid.so -shared -fPIC
!g++ -v /content/symbolic_feature_extractor.c -o /content/symbolic_feature_extractor.so -lopencv_core -lopencv_imgproc -shared -fPIC
# Renamed to libneural_enhancer.so in the next cell, the name NeuralEnhancerWrapper loads.
!g++ -v /content/neural_enhancer.c -o /content/neural_enhancer.so -shared -fPIC
!g++ -v /content/attention_aggregator.c -o /content/attention_aggregator.so -shared -fPIC
!g++ -v /content/shape_abstraction.c -o /content/shape_abstraction.so -shared -fPIC
!g++ -v /content/multimodal_fusion_weights.c -o /content/multimodal_fusion_weights.so -shared -fPIC

In [None]:
# neural_enhancer_wrapper.py loads './libneural_enhancer.so', so rename the built
# library to match (paths are relative to the current working directory).
!mv neural_enhancer.so libneural_enhancer.so

In [None]:

%%writefile /content/tree_fusion_engine.py
import logging
import torch

from TreeNodeV1 import TreeNodeV1
from adaptive_depth import AdaptiveDepthController
from sparse_tree_optimizer import SparseTreeOptimizer
from attention_aggregator_wrapper import AttentionAggregatorWrapper
from neural_enhancer_wrapper import NeuralEnhancerWrapper
from quantization_wrapper import QuantizationWrapper
from multimodal_fusion_weights_wrapper import MultimodalFusionWeightsWrapper
from tree_matrix_hybrid_wrapper import TreeMatrixHybridWrapper

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TreeCNNFusionEngine:
    """
    Reduce a TreeNodeV1 tree to a single ``dim``-sized fused vector.

    ``process`` runs these steps in order:
      1. depth/confidence pruning;
      2. sparsification;
      3. in-place quantization;
      4. attention aggregation over the root's direct children that have vectors;
      5. tanh enhancement;
      6. norm-based fusion with an all-zero symbolic vector.
    """

    def __init__(self, dim=64, device='cpu', efficiency_flag="--efficiency=high"):
        """
        Args:
            dim: size of every intermediate and output vector.
            device: torch device for all produced tensors.
            efficiency_flag: passed as ``user_flag`` to the quantizer's
                ``apply_quantization``.
        """
        self.device = device
        self.dim = dim
        self.efficiency_flag = efficiency_flag

        # shape-independent modules
        self.depth_controller = AdaptiveDepthController(
            max_depth=4,
            min_confidence=0.0,
            min_feature_confidence=0.0,
        )
        self.sparse_optimizer = SparseTreeOptimizer(threshold=0.1)
        self.quantizer = QuantizationWrapper(dim=dim)
        self.attn_aggregator = AttentionAggregatorWrapper(dim=dim, device=device)
        # device must be forwarded: NeuralEnhancerWrapper defaults to 'cpu', and
        # its output is combined with tensors on self.device in process().
        self.enhancer = NeuralEnhancerWrapper(dim=dim, device=device)
        self.weights = MultimodalFusionWeightsWrapper(dim=dim, device=device)
        self.matrix_converter = TreeMatrixHybridWrapper(dim, device=device)

    def process(self, root: TreeNodeV1) -> torch.Tensor:
        """
        Fuse the tree rooted at ``root`` into one vector of shape ``(dim,)`` on ``self.device``.

        Returns zeros when pruning or sparsification removes the root, or on any
        exception (logged). The tree is modified in place.

        NOTE(review): only the root's direct children are aggregated; the root's
        own vector is not used — confirm this is intended.
        """
        try:
            root = self.depth_controller.control_depth(root)
            if root is None:
                return torch.zeros(self.dim, device=self.device)

            root = self.sparse_optimizer.optimize(root)
            # optimize() returns None when nothing survives or on error.
            if root is None:
                return torch.zeros(self.dim, device=self.device)

            self.quantizer.apply_quantization(root, user_flag=self.efficiency_flag)

            valid = [c for c in root.children if c.get_vector() is not None]
            agg = self.attn_aggregator(valid)

            enhanced = self.enhancer(agg).to(self.device)

            # fusion: symbolic removed → use neural only
            symbolic = torch.zeros(self.dim, device=self.device)
            w = self.weights.compute_weights(symbolic, enhanced).to(self.device)

            fused = (w[0] * symbolic) + (w[1] * enhanced)
            return fused.to(self.device)

        except Exception as e:
            logger.error(f"FusionEngine failed: {e}")
            return torch.zeros(self.dim, device=self.device)

    def to_matrix(self, root: TreeNodeV1) -> torch.Tensor:
        """Delegate to TreeMatrixHybridWrapper.to_matrix for the tree rooted at ``root``."""
        return self.matrix_converter.to_matrix(root)

Writing /content/tree_fusion_engine.py


# Section: Label pipeline

In [None]:
%%writefile /content/label_pipeline.py
from tokenizer_and_embedding import TokenEmbedding, universal_tokenizer
from target_processor import MultiLabelTargetProcessor
from TreeBuilderV2 import TreeBuilderV2
import torch

def run_label_pipeline(label_text, dim=50, device="cpu"):
    """
    Build a label tree from ``label_text``.

    Steps:
      - tokenize ``label_text``;
      - embed each token with a TokenEmbedding whose vocabulary is the sorted
        set of those tokens;
      - fit a MultiLabelTargetProcessor on ``[label_text]``;
      - hand one ``("label_node_<i>", embedding)`` pair per token, in token
        order, to ``TreeBuilderV2.build_tree``.

    Returns:
        Whatever ``TreeBuilderV2.build_tree`` returns.
    """
    tokens = universal_tokenizer(label_text)
    token_embedding = TokenEmbedding(vocab=sorted(set(tokens)), dim=dim, device=device)

    target_processor = MultiLabelTargetProcessor()
    target_processor.fit_labels([label_text])

    # lookup() already returns a torch.Tensor; move it to the requested device.
    vec_pairs = [
        (f"label_node_{idx}", token_embedding.lookup(tok).to(device))
        for idx, tok in enumerate(tokens)
    ]

    builder = TreeBuilderV2(device=device, dim=dim)
    return builder.build_tree(
        vec_pairs,
        label_texts={name: label_text for name, _ in vec_pairs},
        token_embedding=token_embedding,
        target_processor=target_processor,
    )

Writing /content/label_pipeline.py


# Section: TreeCNN++ pipeline

In [None]:

%%writefile /content/treecnn_pipeline.py
import torch
import os
import logging

from TreeCNNppRunner import run_treecnnpp
from tree_fusion_engine import TreeCNNFusionEngine
from label_pipeline import run_label_pipeline

logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("TreeCNNPipeline")


class TreeCNNPipeline:
    """
    FINAL SAFE + SHAPE-FREE TreeCNN++ PIPELINE
    ------------------------------------------
    • No TLite or shape tags.
    • Fully device-safe (CPU/GPU)
    • Guaranteed output: 1-D float tensor of shape [self.dim]
    • Optional on-disk caching for precompute efficiency

    Entry points:
      process_image(path)  -> fused vector for an image (tree from run_treecnnpp)
      process_text(text)   -> fused vector for a label/question string
      process_group(tree)  -> fused vector for an already-built combined tree
    """

    def __init__(self, device="cpu", dim=64, cache_dir="/content/treecnn_cache"):
        self.device = device
        self.dim = dim

        # All extracted vectors will be dim-sized
        self.shape_dim = dim

        # disabled TLite
        self.tlite_shape_model = None
        self.shape_to_idx = {}

        # Fusion engine
        self.fusion = TreeCNNFusionEngine(dim=dim, device=device)

        # Optional caching
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

        logger.info(f"TreeCNNPipeline initialized (dim={dim}, device={device})")

    # -------------------------------
    # INTERNAL: sanitize final vector
    # -------------------------------
    def _safe_vec(self, vec, tag="unknown"):
        """
        Coerce ``vec`` into a finite float32 tensor of shape [self.dim] on self.device.

        Any input shape is flattened (e.g. [1, dim] -> [dim]); a wrong element
        count is zero-padded or truncated to ``self.dim``. None, unconvertible
        input, or any NaN/Inf entry yields a zero vector instead.
        """
        try:
            if vec is None:
                raise ValueError("None vector")

            if not isinstance(vec, torch.Tensor):
                vec = torch.tensor(vec, dtype=torch.float32)

            # Always flatten: a [1, dim] input must not leak through as 2-D
            # (collate's torch.stack would then produce [B, 1, dim]).
            # reshape (unlike view) also works on non-contiguous tensors.
            vec = vec.to(self.device).float().reshape(-1)

            # Wrong dimension → fix
            if vec.numel() != self.dim:
                logger.warning(f"[{tag}] wrong dim {vec.numel()} → forced to {self.dim}")
                if vec.numel() < self.dim:
                    pad = torch.zeros(self.dim - vec.numel(), device=self.device)
                    vec = torch.cat([vec, pad], dim=0)
                vec = vec[:self.dim]

            # NaN / Inf guard
            if not torch.isfinite(vec).all():
                logger.error(f"[{tag}] NaN detected → replaced with zeros")
                return torch.zeros(self.dim, device=self.device)

            return vec

        except Exception as e:
            logger.error(f"[{tag}] safe_vec failed: {e}")
            return torch.zeros(self.dim, device=self.device)

    # --------------------------------
    # OPTIONAL CACHING (FAST TRAINING)
    # --------------------------------
    def _cache_key(self, kind: str, identity: str, readable: str) -> str:
        """
        Build a filesystem-safe, collision-resistant cache key.

        ``readable`` is reduced to ``[A-Za-z0-9._-]`` characters (so e.g. a "/"
        inside a question cannot point into a non-existent subdirectory) and
        shortened, while the SHA-1 of the full ``identity`` string keeps inputs
        that share a prefix apart. Keys differ from the older truncated format,
        so previously cached files are simply not reused.
        """
        import hashlib
        import re

        slug = re.sub(r"[^A-Za-z0-9._-]+", "_", readable)[:60]
        digest = hashlib.sha1(identity.encode("utf-8")).hexdigest()[:16]
        return f"{kind}__{slug}__{digest}"

    def _cache_path(self, key: str):
        return os.path.join(self.cache_dir, f"{key}.pt")

    def _load_cache(self, key: str):
        """Return the cached tensor for ``key``, or None on a miss or unreadable file."""
        path = self._cache_path(key)
        if not os.path.exists(path):
            return None
        try:
            return torch.load(path, map_location=self.device)
        except Exception as e:
            # Corrupt / partially written file: treat as a miss and recompute.
            logger.warning(f"[cache] could not load {path}: {e}")
            return None

    def _save_cache(self, key: str, vec):
        """Best-effort write of ``vec`` (as a CPU tensor); failures are logged, not raised."""
        try:
            torch.save(vec.detach().cpu(), self._cache_path(key))
        except Exception as e:
            logger.warning(f"[cache] could not save {key}: {e}")

    # --------------------------------
    # IMAGE → fused vector
    # --------------------------------
    def process_image(self, image_path, use_parallel=True):
        # Hash the absolute path so same-named files in different folders
        # do not share a cache entry.
        cache_key = self._cache_key(
            "img", os.path.abspath(image_path), os.path.basename(image_path)
        )

        cached = self._load_cache(cache_key)
        if cached is not None:
            return self._safe_vec(cached, tag="cache_image")

        # Build tree
        tree = run_treecnnpp(
            image_path=image_path,
            device=self.device,
            dim=self.shape_dim,
            use_parallel=use_parallel
        )

        if tree is None:
            logger.warning(f"[image] TreeCNN failed: {image_path}")
            vec = torch.zeros(self.dim, device=self.device)
            self._save_cache(cache_key, vec)
            return vec

        fused = self.fusion.process(tree)
        fused = self._safe_vec(fused, tag="image_fusion")

        self._save_cache(cache_key, fused)
        return fused

    # --------------------------------
    # TEXT → fused vector
    # --------------------------------
    def process_text(self, text):
        stripped = text.strip()
        cache_key = self._cache_key("text", stripped, stripped)

        cached = self._load_cache(cache_key)
        if cached is not None:
            return self._safe_vec(cached, tag="cache_text")

        # run_label_pipeline's parameter is `label_text`; the old `text=` keyword
        # raised TypeError on every call.
        tree = run_label_pipeline(
            label_text=text,
            dim=self.shape_dim,
            device=self.device
        )

        if tree is None:
            logger.warning(f"[text] label tree failed: {text}")
            vec = torch.zeros(self.dim, device=self.device)
            self._save_cache(cache_key, vec)
            return vec

        fused = self.fusion.process(tree)
        fused = self._safe_vec(fused, tag="text_fusion")

        self._save_cache(cache_key, fused)
        return fused

    # --------------------------------
    # GROUP → fused vector
    # --------------------------------
    def process_group(self, combined_tree):
        """Fuse an already-combined tree; None yields a zero vector. Not cached."""
        if combined_tree is None:
            return torch.zeros(self.dim, device=self.device)

        fused = self.fusion.process(combined_tree)
        return self._safe_vec(fused, tag="group_fusion")

Writing /content/treecnn_pipeline.py


# Section

In [None]:

%%writefile /content/group_tree_builder.py
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional

from TreeNodeV1 import TreeNodeV1
from TreeCNNppRunner import run_treecnnpp
from label_pipeline import run_label_pipeline

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("GroupTreeBuilder")


class GroupTreeBuilder:
    """
    Assemble one combined TreeNodeV1 per image group.

    The combined root ("group_root") gets, in this order: the image tree from
    ``run_treecnnpp``, an optional organ metadata leaf, and one label tree per
    question from ``run_label_pipeline`` (in question order).
    """

    def __init__(self, device: str = 'cpu', dim: int = 50, max_workers: int = 8):
        self.device = device
        self.dim = dim  # vector size passed to both tree builders
        self.max_workers = max_workers

    # ----------------------------------------------------------
    # IMAGE TREE
    # ----------------------------------------------------------
    def build_image_tree(self, image_path: str) -> Optional[TreeNodeV1]:
        """Return the TreeCNN++ tree for ``image_path``, or None on failure (logged)."""
        try:
            tree = run_treecnnpp(
                image_path=image_path,
                device=self.device,
                dim=self.dim,
                use_parallel=True
            )
            if tree is None:
                logger.warning(f"[image] run_treecnnpp returned None for {image_path}")
            return tree
        except Exception as e:
            logger.error(f"[image] Failed {image_path}: {e}")
            return None

    # ----------------------------------------------------------
    # LABEL TREE
    # ----------------------------------------------------------
    def build_label_tree(self, text: str) -> Optional[TreeNodeV1]:
        """Return the label tree for one question string, or None on failure (logged)."""
        try:
            tree = run_label_pipeline(
                label_text=text,  # run_label_pipeline's parameter is named `label_text`
                dim=self.dim,
                device=self.device
            )
            if tree is None:
                logger.warning(f"[label] empty label tree for text: {text}")
            return tree
        except Exception as e:
            logger.error(f"[label] Failed text='{text}': {e}")
            return None

    # ----------------------------------------------------------
    # COMBINE IMAGE + METADATA + LABEL TREES
    # ----------------------------------------------------------
    def combine_trees(
        self,
        image_tree: Optional[TreeNodeV1],
        label_trees: List[Optional[TreeNodeV1]],
        organ: Optional[str] = None
    ) -> TreeNodeV1:
        """
        Attach the available sub-trees under a fresh "group_root" node.

        Entries that are not TreeNodeV1 instances (e.g. None) are skipped; a
        missing image tree is logged. A falsy ``organ`` adds no organ node.
        """
        # Generous max_children, but safe
        root = TreeNodeV1(
            value="group_root",
            shape_type=None,
            id="group_root",
            max_children=5000
        )

        # 1. add image tree
        if isinstance(image_tree, TreeNodeV1):
            root.add_child(image_tree)
        else:
            logger.warning("[combine] Missing image tree")

        # 2. add organ metadata (minimal node)
        if organ:
            organ_node = TreeNodeV1(
                value=f"organ:{organ}",
                shape_type=None,
                id=f"organ:{organ}",
                max_children=0
            )
            root.add_child(organ_node)

        # 3. add label trees
        for lt in label_trees:
            if isinstance(lt, TreeNodeV1):
                root.add_child(lt)

        return root

    # ----------------------------------------------------------
    # MAIN ENTRY: FULL GROUP BUILDER
    # ----------------------------------------------------------
    def build_combined_for_group(
        self,
        image_path: str,
        questions: List[str],
        organ: Optional[str] = None
    ) -> Optional[TreeNodeV1]:
        """
        Build the image tree and all question trees, then combine them.

        Question trees are built concurrently but collected in submission
        order, so the combined tree's children do not depend on thread timing.
        """
        # IMAGE NODE
        image_tree = self.build_image_tree(image_path)

        # QUESTION TEXT NODES (PARALLEL, ORDER-PRESERVING)
        label_trees: List[TreeNodeV1] = []
        if questions:
            # Clamp to >= 1: ThreadPoolExecutor rejects max_workers <= 0.
            workers = max(1, min(self.max_workers, len(questions)))
            with ThreadPoolExecutor(max_workers=workers) as ex:
                futures = [ex.submit(self.build_label_tree, q) for q in questions]
                # Iterating futures in submission order (instead of as_completed)
                # keeps label trees aligned with the question order.
                for fut in futures:
                    try:
                        lt = fut.result()
                    except Exception as e:
                        logger.error(f"[label_future] {e}")
                        continue
                    if lt is not None:
                        label_trees.append(lt)

        # FINAL GROUP TREE
        return self.combine_trees(image_tree, label_trees, organ)

Writing /content/group_tree_builder.py


In [None]:

%%writefile /content/dataset_group_loader.py
import os
import re
import pandas as pd
from typing import List, Dict, Any, Tuple
from collections import defaultdict, Counter


def normalize_answer(a):
    """
    Map a free-text answer onto a small fixed set of answer classes.

    Returns one of "yes", "no", "normal", "opacity", "enlarged", "fracture",
    "abnormal" or "other". Non-string input maps to "other". Rules are checked
    in order after lower-casing and stripping; the first match wins.
    """
    if not isinstance(a, str):
        return "other"
    a = a.lower().strip()

    # yes/no group
    if a in {"yes", "y", "true", "1"}:
        return "yes"
    if a in {"no", "n", "false", "0"}:
        return "no"

    # normal — "normal" is a substring of "abnormal", so exclude that case;
    # otherwise answers like "abnormal" were mapped to the opposite class.
    if ("normal" in a and "abnormal" not in a) or "unremarkable" in a:
        return "normal"

    # opacity / opacification ("opac" also covers "opacity")
    if "opac" in a:
        return "opacity"

    # enlarged
    if "enlarg" in a:
        return "enlarged"

    # fracture
    if "fracture" in a or "fx" in a:
        return "fracture"

    # abnormal group keywords (explicit "abnormal" answers land here too)
    keywords = ("abnormal", "effusion", "pneumo", "consolid", "lesion", "mass", "nodule")
    if any(k in a for k in keywords):
        return "abnormal"

    return "other"


class GroupedDatasetLoader:
    """
    Load a VQA table and group its rows by image.

    Despite the parameter name, ``csv_path`` may point at a CSV/TSV file or an
    Excel workbook; the pandas reader is chosen from the file extension
    (anything other than .csv/.tsv is read with ``pd.read_excel``, as before).
    """

    def __init__(
        self,
        csv_path: str,
        image_root: str = None,
        image_column: str = "IMAGEID",
        question_column: str = "QUESTION",
        answer_column: str = "ANSWER",
        organ_column: str = "IMAGEORGAN",
        min_answer_freq: int = 3       # answers seen fewer times become "other"
    ):
        self.csv_path = csv_path
        self.image_root = image_root
        self.image_column = image_column
        self.question_column = question_column
        self.answer_column = answer_column
        self.organ_column = organ_column
        self.min_answer_freq = min_answer_freq

    def _extract_filename(self, image_val: str) -> str:
        """Return the file name of a local path or http(s) URL (query string dropped)."""
        if not isinstance(image_val, str):
            return str(image_val)
        if image_val.startswith("http://") or image_val.startswith("https://"):
            return os.path.basename(image_val.split("?")[0])
        return os.path.basename(image_val)

    def _read_table(self) -> pd.DataFrame:
        """Read ``self.csv_path`` with the pandas reader matching its extension."""
        ext = os.path.splitext(self.csv_path)[1].lower()
        if ext == ".csv":
            return pd.read_csv(self.csv_path)
        if ext == ".tsv":
            return pd.read_csv(self.csv_path, sep="\t")
        # .xlsx / .xls and anything else: original Excel behaviour.
        return pd.read_excel(self.csv_path)

    def load_and_group(self) -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
        """
        Group rows by image and build the answer label map.

        Returns ``(groups, label_map)``: one dict per distinct image value with
        keys image_id, image_name, image_path, questions, answers (normalized,
        rare ones mapped to "other") and organ (first organ seen), plus a
        mapping from each used answer to a contiguous index in sorted order.
        """
        df = self._read_table()

        groups = defaultdict(lambda: {
            "questions": [],
            "answers": [],
            "organs": [],
            "image_vals": []
        })

        # ------------------------------------
        # FIRST PASS → Load raw groups
        # ------------------------------------
        all_answers_raw = []

        for _, row in df.iterrows():
            img_val = row.get(self.image_column, None)
            if pd.isna(img_val):
                continue

            q = row.get(self.question_column, "")
            a = row.get(self.answer_column, "")
            organ = row.get(self.organ_column, "")

            key = str(img_val).strip()

            groups[key]["questions"].append("" if pd.isna(q) else str(q).strip())

            # normalize answer
            a_norm = normalize_answer("" if pd.isna(a) else str(a))
            groups[key]["answers"].append(a_norm)
            all_answers_raw.append(a_norm)

            groups[key]["organs"].append("" if pd.isna(organ) else str(organ).strip())
            groups[key]["image_vals"].append(img_val)

        # ------------------------------------
        # SECOND PASS → Filter Rare Answers
        # ------------------------------------
        freq = Counter(all_answers_raw)
        allowed = {a for a, c in freq.items() if c >= self.min_answer_freq}

        # anything below threshold becomes "other"
        def map_answer(a):
            return a if a in allowed and a != "" else "other"

        # all answers include "other"
        allowed.add("other")

        # ------------------------------------
        # THIRD PASS → Build groups + label map
        # ------------------------------------
        group_list = []
        used_answers = set()

        for key, val in groups.items():
            raw_img_val = val["image_vals"][0]
            filename = self._extract_filename(raw_img_val)

            if self.image_root:
                image_path = os.path.join(self.image_root, filename)
            else:
                image_path = raw_img_val

            mapped_answers = [map_answer(a) for a in val["answers"]]
            used_answers.update(mapped_answers)

            group_list.append({
                "image_id": key,
                "image_name": filename,
                "image_path": image_path,
                "questions": val["questions"],
                "answers": mapped_answers,
                "organ": val["organs"][0] if val["organs"] else None
            })

        sorted_answers = sorted(used_answers)
        label_map = {lab: idx for idx, lab in enumerate(sorted_answers)}

        print(f"⚡ Total unique answers after normalization = {len(label_map)}")

        return group_list, label_map

Writing /content/dataset_group_loader.py


# Section

In [None]:

%%writefile /content/trainer_pipeline.py
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Any, Optional

from dataset_group_loader import GroupedDatasetLoader
from group_tree_builder import GroupTreeBuilder
from treecnn_pipeline import TreeCNNPipeline

# Optional metrics backend: prefer sklearn if available, else fallback
try:
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    _HAVE_SKLEARN = True
except Exception:
    _HAVE_SKLEARN = False


# -------------------------
# Group Dataset (group-level samples)
# -------------------------
class GroupVQADataset(Dataset):
    """
    Each sample corresponds to one image group:
      - image_path
      - questions (list)
      - answers (list)
      - organ (metadata)

    Produces (as CPU tensors; the Trainer moves batches to its device):
      - fused: torch.Tensor [dim]
      - target: torch.Tensor [num_classes] (multi-hot float32)

    Also:
      - self.class_counts : torch.Tensor [num_classes] — number of groups whose
        multi-hot target has class c set (consistent with the targets, so
        0 <= count <= len(groups))
      - get_pos_weight() computes the standard BCE pos_weight = (neg / pos)
    """

    # Fused-vector length used when no pipeline is attached.
    _FALLBACK_DIM = 64

    def __init__(self,
                 groups: List[Dict[str, Any]],
                 label_map: Dict[str, int],
                 pipeline: Optional["TreeCNNPipeline"] = None,
                 device: str = 'cpu',
                 precompute: bool = False):
        self.groups = groups
        self.label_map = label_map
        self.num_classes = len(label_map)
        self.pipeline = pipeline
        self.device = device
        self.precompute = precompute

        # Build trees with the same vector size the pipeline builds its own
        # trees with (TreeCNNPipeline.shape_dim); a fixed 50 disagreed with the
        # pipeline's fusion dim. Without a pipeline the old default is kept.
        tree_dim = 50 if pipeline is None else getattr(pipeline, "shape_dim", pipeline.dim)
        self.builder = GroupTreeBuilder(device=device, dim=tree_dim, max_workers=8)

        # per-class positive counts, measured in groups (same unit as targets)
        self.class_counts = self._count_positive_groups(self.groups)

        # precompute fused vectors and multi-hot targets if requested
        if self.precompute:
            self.precomputed = []
            for g in self.groups:
                fused = self._compute_group_fused(g)
                target = self._aggregate_target(g["answers"])
                # stored on CPU; returned as-is by __getitem__
                self.precomputed.append({
                    "fused": fused.detach().cpu(),
                    "target": target.detach().cpu()
                })

    def __len__(self):
        return len(self.groups)

    def _aggregate_target(self, answers: List[str]) -> torch.Tensor:
        """
        Return multi-hot vector (float) of length num_classes.
        Empty answers and answers not in label_map are ignored.
        """
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        for a in answers:
            if not a:
                continue
            idx = self.label_map.get(a)
            if idx is not None:
                target[idx] = 1.0
        return target

    def _count_positive_groups(self, groups: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Count, per class, how many groups have that class in their target.

        Repeated answers inside one group count once, matching the multi-hot
        targets; counting raw answer occurrences could exceed len(groups) and
        make (N - pos) in get_pos_weight negative.
        """
        counts = torch.zeros(self.num_classes, dtype=torch.long)
        for g in groups:
            counts += self._aggregate_target(g.get("answers") or []).long()
        return counts

    def _compute_group_fused(self, group):
        """
        Build the combined tree and fuse it via the pipeline.

        Returns a zero vector when there is no pipeline (the tree build is then
        skipped, since nothing could consume it) or when building/fusing fails.
        """
        if self.pipeline is None:
            return torch.zeros(self._FALLBACK_DIM, device=self.device)
        try:
            combined = self.builder.build_combined_for_group(
                group["image_path"], group["questions"], organ=group.get("organ")
            )
            fused = self.pipeline.process_group(combined)
            if fused is None:
                return torch.zeros(self.pipeline.dim, device=self.device)
            return fused.to(self.device)
        except Exception as e:
            import logging
            logging.getLogger(__name__).warning(
                "group fusion failed for %s: %s", group.get("image_path"), e
            )
            return torch.zeros(self.pipeline.dim, device=self.device)

    def __getitem__(self, idx):
        """
        Return ``{"fused": [dim], "target": [num_classes]}`` as CPU tensors.

        CPU on purpose: DataLoader worker processes raise "Cannot re-initialize
        CUDA in forked subprocess" if they touch CUDA, and pin_memory=True
        only accepts CPU tensors. The Trainer moves each batch to its device.
        """
        if self.precompute:
            ex = self.precomputed[idx]
            return {"fused": ex["fused"], "target": ex["target"]}
        group = self.groups[idx]
        fused = self._compute_group_fused(group).detach().cpu()
        target = self._aggregate_target(group["answers"])
        return {"fused": fused, "target": target}

    # ----------------- helper for Trainer -----------------
    def get_pos_weight(self, eps: float = 1e-6) -> torch.Tensor:
        """
        Return pos_weight tensor for BCEWithLogitsLoss:
          pos_weight[c] = (num_neg_examples / num_pos_examples) = ((N - pos_c) / pos_c)
        Zero-count classes use pos = eps (-> huge ratio) and everything is
        clamped to [1, 10]. Returns float tensor (on CPU); Trainer moves it.
        """
        N = len(self.groups)
        pos = self.class_counts.float()
        # avoid division by zero: if pos == 0, set pos to eps (small) -> very large weight
        pos_safe = torch.where(pos <= 0.0, torch.full_like(pos, eps), pos)
        neg = float(N) - pos
        pos_weight = neg / pos_safe
        # clamp to avoid infinite/huge values
        pos_weight = torch.clamp(pos_weight, min=1.0, max=10.0)
        return pos_weight.float()


def collate_batch(batch):
    """
    Stack a list of per-sample dicts into one batch dict.

    Each sample carries 'fused' ([dim]) and 'target' ([num_classes]); the
    result holds the same two keys with a new leading batch axis.
    """
    return {
        field: torch.stack([sample[field] for sample in batch], dim=0)
        for field in ("fused", "target")
    }


# -------------------------
# Deep multi-label classification head
# -------------------------
class DeepHead(nn.Module):
    """
    MLP head producing one logit per class (multi-label; pair with BCEWithLogitsLoss).

    Each entry of ``hidden_dims`` adds a Linear -> ReLU -> LayerNorm -> Dropout
    stage; a final Linear maps to ``num_classes``. An empty ``hidden_dims``
    gives a single linear layer. The default is an immutable tuple so no
    list object is shared between instances.
    """

    def __init__(self, in_dim, num_classes, hidden_dims=(256, 128), dropout=0.3):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.ReLU(), nn.LayerNorm(h), nn.Dropout(dropout)]
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)  # [B, num_classes]


# -------------------------
# Trainer (multi-label)
# -------------------------
class Trainer:
    """
    Multi-label training loop for a model mapping fused vectors [B, dim] to logits [B, C].

    After each training epoch the decision threshold maximising macro-F1 on
    that epoch's training predictions is searched and stored in
    ``self.best_threshold``. Note this tunes on training data, so the
    reported metrics are optimistic.
    """

    # Candidate thresholds for find_best_threshold: 0.05, 0.06, ..., 0.95.
    _THRESHOLDS = tuple(i / 100 for i in range(5, 96))

    def __init__(self,
                 model: nn.Module,
                 optimizer,
                 loss_fn: Optional[nn.Module] = None,
                 device: str = 'cpu',
                 ckpt_dir: str = './checkpoints',
                 dataset_for_weights: Optional["GroupVQADataset"] = None,
                 use_pos_weight: bool = True):
        """
        model: nn.Module
        optimizer: optimizer
        loss_fn: optional; if None we'll create BCEWithLogitsLoss (multi-label).
                 If dataset_for_weights is provided and use_pos_weight=True we'll compute pos_weight automatically.
        dataset_for_weights: pass the GroupVQADataset instance to compute pos-weight from its class_counts
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.device = device
        self.ckpt_dir = ckpt_dir
        self.best_threshold = 0.5
        os.makedirs(ckpt_dir, exist_ok=True)

        # Setup loss: prefer provided; else build BCEWithLogitsLoss with optional pos_weight
        if loss_fn is not None:
            self.loss_fn = loss_fn
        else:
            pos_weight = None
            if dataset_for_weights is not None and use_pos_weight:
                try:
                    pos_weight = dataset_for_weights.get_pos_weight().to(device)
                except Exception as e:
                    print(f"[Trainer] pos_weight unavailable, using unweighted BCE: {e}")
                    pos_weight = None
            if pos_weight is not None:
                self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            else:
                self.loss_fn = nn.BCEWithLogitsLoss()

    @staticmethod
    def _macro_f1(preds: torch.Tensor, targs: torch.Tensor) -> float:
        """
        Macro-averaged F1 for 0/1 tensors of shape [N, C] (no sklearn needed).

        Per class F1 = 2·TP / (2·TP + FP + FN), taken as 0 when that class has
        no predicted and no true positives — intended to mirror sklearn's
        f1_score(average='macro', zero_division=0).
        """
        p = preds.bool()
        t = targs.bool()
        tp = (p & t).sum(dim=0).float()
        fp = (p & ~t).sum(dim=0).float()
        fn = (~p & t).sum(dim=0).float()
        denom = 2 * tp + fp + fn
        per_class = torch.where(denom > 0, 2 * tp / denom.clamp(min=1.0), torch.zeros_like(denom))
        return float(per_class.mean()) if per_class.numel() > 0 else 0.0

    def _compute_metrics(self, preds_logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5):
        """
        preds_logits: [N, C] (logits)
        targets: [N, C] (0./1.)
        Returns: dict with accuracy (element-wise), precision, recall, f1.
        Uses sklearn macro averages if available, else micro calculations.
        """
        with torch.no_grad():
            probs = torch.sigmoid(preds_logits).cpu()
            preds = (probs >= threshold).long().numpy()
            targs = targets.cpu().long().numpy()

        metrics = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}

        try:
            if _HAVE_SKLEARN:
                # accuracy here is element-wise match fraction
                acc = (preds == targs).mean()
                prec = precision_score(targs, preds, average='macro', zero_division=0)
                rec = recall_score(targs, preds, average='macro', zero_division=0)
                f1 = f1_score(targs, preds, average='macro', zero_division=0)
                metrics.update({"accuracy": float(acc), "precision": float(prec), "recall": float(rec), "f1": float(f1)})
            else:
                tp = int(((preds == 1) & (targs == 1)).sum())
                fp = int(((preds == 1) & (targs == 0)).sum())
                fn = int(((preds == 0) & (targs == 1)).sum())
                tn = int(((preds == 0) & (targs == 0)).sum())
                acc = float((tp + tn) / max(1, (tp + tn + fp + fn)))
                prec = float(tp / max(1, (tp + fp)))
                rec = float(tp / max(1, (tp + fn)))
                f1 = float(2 * prec * rec / max(1e-8, prec + rec))
                metrics.update({"accuracy": acc, "precision": prec, "recall": rec, "f1": f1})
        except Exception as e:
            # keep training going, but don't hide that the numbers are zeros
            print(f"[Trainer] metric computation failed: {e}")

        return metrics

    def train_epoch(self, dataloader: DataLoader, log_every: int = 10, threshold: float = 0.5):
        """
        Run one training epoch.

        threshold: fallback decision threshold, kept if no candidate in the
                   adaptive search reaches a macro-F1 above 0.
        Returns (mean batch loss, elapsed seconds, metrics dict at the best threshold).
        """
        self.model.train()
        total_loss = 0.0
        count = 0
        start = time.time()

        all_logits = []
        all_targets = []

        num_steps = len(dataloader)
        bar_len = 40

        for i, batch in enumerate(dataloader):
            fused = batch['fused'].to(self.device)
            targets = batch['target'].to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(fused)  # [B, C]
            loss = self.loss_fn(logits, targets)
            loss.backward()
            self.optimizer.step()

            total_loss += float(loss.item())
            count += 1

            all_logits.append(logits.detach().cpu())
            all_targets.append(targets.detach().cpu())

            # Progress / ETA
            step = i + 1
            progress = step / num_steps
            filled = int(progress * bar_len)
            bar = "█" * filled + "-" * (bar_len - filled)
            elapsed = time.time() - start
            eta = (elapsed / step) * (num_steps - step)

            if step % max(1, log_every) == 0 or step == num_steps:
                avg_loss = total_loss / max(1, count)
                print(f"\r[{bar}] {progress*100:5.1f}% | Loss: {avg_loss:.4f} | ETA: {eta:5.1f}s", end="")

        # End epoch newline
        print()

        if count == 0:
            # Empty loader: nothing to concatenate or score.
            empty = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}
            return 0.0, time.time() - start, empty

        # Aggregate metrics
        logits_all = torch.cat(all_logits, dim=0)
        targets_all = torch.cat(all_targets, dim=0)

        # ---- ADAPTIVE THRESHOLD ----
        best_thr, best_f1 = self.find_best_threshold(logits_all, targets_all, default=threshold)
        self.best_threshold = best_thr
        print(f"\n🔧 Best threshold this epoch = {best_thr:.2f} (F1 = {best_f1:.4f})")

        # compute metrics using new threshold
        metrics = self._compute_metrics(logits_all, targets_all, threshold=best_thr)

        elapsed_total = time.time() - start
        epoch_loss = total_loss / max(1, count)

        return epoch_loss, elapsed_total, metrics

    def evaluate(self, dataloader: DataLoader, threshold: float = 0.5):
        """Return (mean batch loss, metrics at ``threshold``); (0.0, {}) for an empty loader."""
        self.model.eval()
        total_loss = 0.0
        count = 0
        all_logits = []
        all_targets = []

        with torch.no_grad():
            for batch in dataloader:
                fused = batch['fused'].to(self.device)
                targets = batch['target'].to(self.device)

                logits = self.model(fused)
                loss = self.loss_fn(logits, targets)

                total_loss += float(loss.item())
                count += 1
                all_logits.append(logits.cpu())
                all_targets.append(targets.cpu())

        if count == 0:
            return 0.0, {}

        logits_all = torch.cat(all_logits, dim=0)
        targets_all = torch.cat(all_targets, dim=0)
        metrics = self._compute_metrics(logits_all, targets_all, threshold=threshold)
        return total_loss / count, metrics

    def save(self, name='latest.pt'):
        path = os.path.join(self.ckpt_dir, name)
        torch.save({'model_state': self.model.state_dict()}, path)
        return path

    def load(self, path):
        ck = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ck['model_state'])

    def find_best_threshold(self, logits, targets, default: float = 0.5):
        """
        Search ``_THRESHOLDS`` (0.05 → 0.95) for the highest macro-F1.

        Uses sklearn's f1_score when available, else ``_macro_f1``, so a
        missing sklearn no longer raises NameError. Returns (threshold, f1);
        ``default`` is returned as the threshold if no candidate beats F1 = 0.
        """
        best_thr = default
        best_f1 = 0.0

        with torch.no_grad():
            probs = torch.sigmoid(logits).cpu()
            targs_t = targets.cpu().long()
            targs_np = targs_t.numpy()

            for thr in self._THRESHOLDS:
                preds_t = (probs >= thr).long()
                if _HAVE_SKLEARN:
                    f1 = float(f1_score(targs_np, preds_t.numpy(), average='macro', zero_division=0))
                else:
                    f1 = self._macro_f1(preds_t, targs_t)

                if f1 > best_f1:
                    best_f1 = f1
                    best_thr = thr

        return best_thr, best_f1
# -------------------------
# Example usage snippet (kept here for convenience)
# -------------------------
if __name__ == "__main__":
    # Quick smoke run of the whole stack: load groups, fuse lazily, train 2 epochs.
    sample_table = "/content/vqa_rad.csv"
    groups, label_map = GroupedDatasetLoader(csv_path=sample_table).load_and_group()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pipeline = TreeCNNPipeline(device=device, dim=64)

    # precompute=False: fused vectors are computed inside __getitem__
    dataset = GroupVQADataset(groups, label_map, pipeline=pipeline, device=device, precompute=False)
    dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_batch)

    model = DeepHead(in_dim=pipeline.dim, num_classes=len(label_map)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    # loss_fn=None -> Trainer builds BCEWithLogitsLoss with pos_weight from dataset.class_counts
    trainer = Trainer(model, optimizer, loss_fn=None, device=device,
                      dataset_for_weights=dataset, use_pos_weight=True)

    for epoch in range(2):
        loss, elapsed, metrics = trainer.train_epoch(dataloader)
        print(f"Epoch {epoch} loss={loss:.4f} time={elapsed:.1f}s metrics={metrics}")

Overwriting /content/trainer_pipeline.py


In [None]:
import os
import zipfile

# Image archive and the folder the training cell reads images from (image_root).
zip_path = "/content/VQA_RAD Image Folder.zip"
extract_path = "/content/images"

os.makedirs(extract_path, exist_ok=True)

with zipfile.ZipFile(zip_path, "r") as archive:
    archive.extractall(extract_path)

print(f"Files extracted to {extract_path}")

Files extracted to /content/images


In [None]:
# importlib was used here without being imported in this cell (NameError on a
# fresh kernel); import it explicitly.
import importlib

import trainer_pipeline

# Re-execute the module so edits written by the %%writefile cells take effect.
importlib.reload(trainer_pipeline)

<module 'trainer_pipeline' from '/content/trainer_pipeline.py'>

In [None]:

import time  # NOTE(review): unused below; kept in case later cells rely on it
import torch
import torch.nn as nn  # NOTE(review): unused below; kept in case later cells rely on it
from torch.utils.data import DataLoader
import importlib

# Reload FIRST so the freshly written trainer_pipeline definitions are the ones
# imported below.
import trainer_pipeline
importlib.reload(trainer_pipeline)

from trainer_pipeline import Trainer, collate_batch, GroupVQADataset, DeepHead
from dataset_group_loader import GroupedDatasetLoader
from treecnn_pipeline import TreeCNNPipeline

# ----------------------------------------
# 1. Load Dataset
# ----------------------------------------
loader_src = GroupedDatasetLoader(
    csv_path="/content/VQA_RAD Dataset Public.xlsx",
    image_root="/content/images",
)

groups, label_map = loader_src.load_and_group()
num_classes = len(label_map)
print("Number of classes =", num_classes)

# ----------------------------------------
# 2. Pipeline + Dataset
# ----------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipeline = TreeCNNPipeline(device=device, dim=64)

dataset = GroupVQADataset(
    groups,
    label_map,
    pipeline=pipeline,
    device=device,
    precompute=True,
)

# shuffle=True: the samples are built from grouped records (load_and_group), so
# a fixed iteration order would present the same batch sequence every epoch.
# Reshuffling each epoch is the standard choice for a training loader.
train_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=2,
    pin_memory=(device == 'cuda'),
)

# ----------------------------------------
# 3. Deep Multi-Label Head
# ----------------------------------------
# Read the input width from the pipeline instead of repeating the literal 64,
# so changing `dim` above cannot silently mismatch the head.
model = DeepHead(
    in_dim=pipeline.dim,
    num_classes=num_classes,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Trainer automatically builds pos_weight + BCE loss
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=None,                 # ← Let trainer build BCEWithLogits + pos_weight
    device=device,
    dataset_for_weights=dataset,  # ← Required for pos_weight
    use_pos_weight=True,
)

# ----------------------------------------
# 4. TRAIN (NOW USING ADAPTIVE THRESHOLD)
# ----------------------------------------
EPOCHS = 10

for epoch in range(EPOCHS):
    # NOTE: DO NOT pass threshold anymore — adaptive threshold is inside Trainer
    # NOTE: these metrics come from train_epoch on train_loader only; no
    # held-out split is evaluated in this cell, so they are training metrics.
    loss, elapsed, metrics = trainer.train_epoch(train_loader)

    print(f"🟩 Epoch {epoch+1}/{EPOCHS} Summary:")
    print(f"   Loss      = {loss:.4f}")
    print(f"   Accuracy  = {metrics['accuracy']:.4f}")
    print(f"   Precision = {metrics['precision']:.4f}")
    print(f"   Recall    = {metrics['recall']:.4f}")
    print(f"   F1 Score  = {metrics['f1']:.4f}")
    print("--------------------------------------------------")

  warn(msg)


⚡ Total unique answers after normalization = 7
Number of classes = 7
[████████████████████████████████████████] 100.0% | Loss: 0.5836 | ETA:   0.0s

🔧 Best threshold this epoch = 0.05 (F1 = 0.3981)
🟩 Epoch 1/10 Summary:
   Loss      = 0.5836
   Accuracy  = 0.4445
   Precision = 0.3389
   Recall    = 0.8333
   F1 Score  = 0.3981
--------------------------------------------------
[████████████████████████████████████████] 100.0% | Loss: 0.5325 | ETA:   0.0s

🔧 Best threshold this epoch = 0.11 (F1 = 0.4013)
🟩 Epoch 2/10 Summary:
   Loss      = 0.5325
   Accuracy  = 0.6192
   Precision = 0.3406
   Recall    = 0.7381
   F1 Score  = 0.4013
--------------------------------------------------
[████████████████████████████████████████] 100.0% | Loss: 0.5363 | ETA:   0.0s

🔧 Best threshold this epoch = 0.15 (F1 = 0.4526)
🟩 Epoch 3/10 Summary:
   Loss      = 0.5363
   Accuracy  = 0.6738
   Precision = 0.3853
   Recall    = 0.7143
   F1 Score  = 0.4526
----------------------------------------------