<a href="https://www.kaggle.com/code/ryancardwell/goldenorcav3?scriptVersionId=272091048" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
#1
import time
import json
from typing import Dict, List, Any, Optional, Tuple, Callable
from collections import defaultdict
import os # Added for robust log path handling

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import label 

# --- 1. S-Tier System Configuration (Streamlined for Tuning) ---
CONFIG = {
    # Core tunable knobs. Each comment names where the value is read in this file.
    'BEAM_WIDTH': 12,                  # solve_task: open set is truncated to this many nodes after sorting
    'MAX_PROGRAM_DEPTH': 6,            # solve_task: nodes whose program_length reaches this are not expanded
    'VETONET_POLICY_THRESHOLD': 0.70,  # _apply_operation: candidates with a lower adjusted veto score are pruned
    'FALLBACK_MMSS_THRESHOLD': 0.90,   # _execute_final_audit: best score below this => identity fallback
    'FUZZY_TOLERANCE': 0.08,           # F2: is_aligned_fuzzy reports alignment when |a - b| is strictly below this

    # Internal constants (fixed values)
    'EMPTY_COLOR': 0,                  # Value treated as background by BeamNode._generate_fuzzy_hash
    'MAX_COLOR': 9,
    'Z_CAUSAL_DIM': 261,               # 256 traditional + 5 Object Role features
    'GRU_HIDDEN_DIM': 64,

    # Pre-computed causal lookups (loaded from offline analysis)
    # C5: Causal effects must be pre-computed offline and loaded here.
    # _apply_operation reads only the 'object_count' entry; operations that are
    # missing from this table are treated as having an effect of 0.
    'DO_EFFECTS': {
        'recolor_dominant': {'object_count': -0.3, 'unique_colors': 0.8},
        'extract_largest': {'object_count': 0.9, 'unique_colors': -0.1},
    }
}

# --- 2. Logging and Device Setup (Auditable) ---
class CustomLogger:
    """Placeholder logger: remembers a log path and discards every message.

    The level methods below are intentional no-ops. A complete logger with the
    same interface is defined further down in this file and rebinds the name.
    """
    # ... (Implementation remains identical to Cell 6) ...

    def __init__(self, log_path: str = '/kaggle/working/log.txt'):
        # ... (logging methods remain) ...
        self.log_path = log_path

    def info(self, msg: str):
        """Accept an INFO message and drop it."""
        return None

    def warning(self, msg: str):
        """Accept a WARN message and drop it."""
        return None

    def error(self, msg: str):
        """Accept an ERROR message and drop it."""
        return None

    def debug(self, msg: str):
        """Accept a DEBUG message and drop it."""
        return None
logger = CustomLogger() # Module-level logger; the name is rebound later in this file by the full CustomLogger

# Run tensors on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 3. Helper Functions ---
# (grid_to_np, grid_to_list remain identical to Cell 1 of previous output)
def grid_to_np(grid: List[List[int]]) -> np.ndarray:
    """Convert a nested-list grid into a new ``uint8`` NumPy array."""
    as_array = np.array(grid, dtype=np.uint8)
    return as_array

def grid_to_list(grid: np.ndarray) -> List[List[int]]:
    """Convert a NumPy grid back into plain nested Python lists."""
    nested = grid.tolist()
    return nested

# --- 4. Global Mean/Std Dev for Standardization (Improvement 1) ---
# NOTE: These values must be computed offline from the entire ARC training set
TRAIN_MEAN = 3.0  # Example mean colour value (placeholder until computed offline)
TRAIN_STD = 3.5   # Example std dev of colour value (placeholder until computed offline)


def standardize_grid(grid: np.ndarray) -> np.ndarray:
    """I1: Return ``grid`` as float32, shifted by TRAIN_MEAN and scaled by TRAIN_STD."""
    as_float = grid.astype(np.float32)
    centered = as_float - TRAIN_MEAN
    return centered / TRAIN_STD



In [2]:
#2
# --- DSL Support: Grid Object Abstraction (Identical to previous Cell 2) ---
class GridObject:
    """Placeholder for one segmented grid object (implementation elided here).

    NOTE(review): other code in this file reads ``.bbox`` and
    ``.dominant_color`` from instances; neither is defined in this stub.
    """
    # ... (Implementation remains identical) ...
    pass

# --- The Enhanced Domain-Specific Language (DSL) Interface ---
class EnhancedDSL:
    """Registry of named grid-transform primitives plus fuzzy helper checks.

    NOTE(review): in this excerpt only ``_identity`` and
    ``_recolor_perimeter_by_area_ratio``'s siblings are referenced, not all
    defined; of the methods listed in ``primitives`` only ``_identity`` is
    visible here, so constructing the class depends on the elided definitions.
    """
    def __init__(self):
        # ... (primitives definition remains) ...
        # Maps the operation name used in programs to the bound method that
        # implements it. Callers invoke entries as primitive(grid, **params).
        self.primitives: Dict[str, Callable] = {
            'identity': self._identity,
            'rotate_90': self._rotate_90,
            'recolor_dominant': self._recolor_dominant,
            'overlay_object': self._overlay_object,
            'extract_largest': self._extract_largest,
            'fill_boundary_fuzzy': self._recolor_perimeter_by_area_ratio,
            'map_relational_frame': self._map_relational_frame,
        }

    # --- DSL Core Utilities (Identical to previous Cell 2) ---
    def _segment_grid(self, grid: np.ndarray) -> List[GridObject]:
        """Split ``grid`` into objects (body elided; this stub returns None).

        NOTE(review): callers index the result and compare the first few
        entries, so the full version presumably returns an ordered list -
        confirm the ordering against the real implementation.
        """
        # ... (Implementation remains identical) ...
        pass

    # --- NSM Feature Utilities (Fuzzy Logic - Identical to previous Cell 2) ---
    def _calculate_fuzziness(self, obj: GridObject) -> float:
        """Score how fuzzy ``obj`` is (body elided; this stub returns None).

        NOTE(review): _apply_operation compares the result against 0.8, so a
        numeric return is expected from the full implementation.
        """
        # ... (Implementation remains identical) ...
        pass

    def is_aligned_fuzzy(self, coord_a: float, coord_b: float) -> bool:
        """F2: Return True when the two coordinates differ by less than FUZZY_TOLERANCE."""
        return abs(coord_a - coord_b) < CONFIG['FUZZY_TOLERANCE']

    # --- Symbolic Primitives (Simplified Set - Identical to previous Cell 2) ---
    def _identity(self, grid: np.ndarray, **params) -> np.ndarray:
        """Return a copy of ``grid``; any keyword parameters are ignored."""
        return grid.copy()

    # ... (other primitives remain identical) ...


In [3]:
#3
# --- 1. Semantic Causal Encoder (Z_causal_extended, Tanh Stability - Improvement 1) ---

class SemanticCausalEncoder(nn.Module):
    """Encode an (input grid, target grid) pair into one latent row vector.

    The two standardized grids are stacked as a 2-channel image, passed
    through ``conv_blocks``, and reduced by ``z_head`` to 256 Tanh-bounded
    features. Those are concatenated with the role features produced by the
    role head, whose definition is elided in this excerpt.
    """

    def __init__(self, z_dim: int = CONFIG['Z_CAUSAL_DIM']):
        super().__init__()
        self.z_dim = z_dim

        # ... (conv_blocks definition remains identical) ...
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, 1, 1),
        )

        # Head 1: Traditional Causal Features (256D) - NEW Tanh ACTIVATION
        self.z_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.Tanh()  # I1: Aggressive clipping for latent space stability
        )

        # ... (role_predictor definition remains identical) ...

    def forward(self, input_grid: np.ndarray, target_grid: np.ndarray) -> torch.Tensor:
        """Return the extended causal embedding for one grid pair.

        Args:
            input_grid: 2-D grid of the task input.
            target_grid: 2-D grid of the expected output. It is concatenated
                with ``input_grid`` along a new channel axis, so both grids
                must have the same height and width.

        Returns:
            The concatenation of the 256 traditional features and the role
            features, with a leading batch axis of size 1, on ``DEVICE``.
        """
        # I1: Input Standardization applied before stacking.
        # standardize_grid returns a NumPy array, which has no ``unsqueeze``;
        # convert to a tensor before adding the channel axis.
        input_tensor = torch.from_numpy(standardize_grid(input_grid)).unsqueeze(0)
        target_tensor = torch.from_numpy(standardize_grid(target_grid)).unsqueeze(0)

        # Channels-first stack, then a batch axis: (1, 2, H, W).
        # NOTE(review): torch.cat raises if the two grids differ in shape.
        x = torch.cat([input_tensor, target_tensor], dim=0).unsqueeze(0).to(DEVICE)

        # ... (rest of forward pass remains identical) ...
        x = self.conv_blocks(x)
        z_traditional = self.z_head(x)
        # ... (Role head remains) ...
        # NOTE(review): ``z_role`` is produced by the elided role head above.

        z_causal_extended = torch.cat([z_traditional.squeeze(0), z_role]).unsqueeze(0)

        return z_causal_extended.to(DEVICE)

# --- 2. Sequential VetoNet (Identical to previous Cell 3) ---
class SequentialVetoNet(nn.Module):
    """Placeholder for the sequential veto policy network (implementation elided here).

    NOTE(review): _apply_operation calls an instance with four arguments and
    unpacks a (score tensor, next hidden state) pair - confirm against the
    full definition.
    """
    # ... (Implementation remains identical, receives Tanh-clipped Z_causal) ...
    pass

# --- 3. Probabilistic Param Predictor (Identical to previous Cell 3) ---
class ProbabilisticParamPredictor(nn.Module):
    """Placeholder for the parameter predictor network (implementation elided here).

    NOTE(review): _apply_operation reads a 'propensity_scores' entry from the
    value an instance returns - confirm against the full definition.
    """
    # ... (Implementation remains identical) ...
    pass


In [4]:
# --- Beam Search Node Definition (H8 + Hierarchical Safety Check) ---

class BeamNode:
    """
    Represents a state in the beam search, storing grid, program, score,
    and neural state for sequential search.
    """
    def __init__(self, grid: np.ndarray, score: float, program_repr: List[Any],
                 gru_h_state: Optional[torch.Tensor] = None,
                 predicted_params: Optional[Dict[str, Any]] = None):
        """Create a search node.

        Args:
            grid: Grid reached by this node; stored as a ``uint8`` copy.
            score: Search score of the node.
            program_repr: Instructions applied so far, in order.
            gru_h_state: Recurrent state to hand to the next veto-net call.
            predicted_params: Optional parameter predictions for this node.
                Defaults to a fresh empty dict per node (a shared ``{}``
                default would be mutated across all nodes).
        """
        # ... (State data remains identical) ...
        self.grid = grid.astype(np.uint8)
        self.score = score
        self.program_repr = program_repr
        # The solver reads these two attributes from every node
        # (depth limit, repeat-operation pruning, veto-net hidden state).
        self.program_length = len(program_repr)
        self.gru_h_state = gru_h_state
        self.predicted_params = {} if predicted_params is None else predicted_params
        # ... (rest of init remains) ...

        # H8: Fuzzy Equivalence Hash (Used for fast initial lookup)
        self.id = self._generate_fuzzy_hash(grid)

    def _generate_fuzzy_hash(self, grid: np.ndarray) -> int:
        """
        H8: Generates a hash from coarse structural features.

        The key combines the grid shape, the sorted set of non-empty colours
        and the count of non-empty cells, so distinct grids can collide.
        """
        # ... (Implementation remains identical to previous Cell 5) ...
        unique_colors = tuple(sorted(np.unique(grid[grid != CONFIG['EMPTY_COLOR']])))
        grid_shape = grid.shape
        num_non_empty = np.sum(grid != CONFIG['EMPTY_COLOR'])

        return hash((grid_shape, unique_colors, num_non_empty))

    def check_equivalence_safe(self, other_node: 'BeamNode', dsl: 'EnhancedDSL') -> bool:
        """
        I2: Hierarchical State Pruning Safety Check. Used to confirm H8 hash collision
        is a true match, preventing pruning of distinct, valid paths.

        Returns True when the grids are pixel-identical (checked only when this
        node's score is within 0.01 below, or above, the other's) or when the
        segmented objects agree in count and in bbox/dominant colour for the
        first three objects. Any error during segmentation yields False.
        """
        # 1. Pixel check, gated on the scores being close.
        if self.score > other_node.score - 0.01:
            if np.array_equal(self.grid, other_node.grid):
                return True

        # 2. Structural check used for general H8 collision confirmation.
        try:
            my_objects = dsl._segment_grid(self.grid)
            other_objects = dsl._segment_grid(other_node.grid)

            if len(my_objects) != len(other_objects):
                return False

            # Compare bbox and colour of the first (up to) three objects.
            for mine, theirs in zip(my_objects[:3], other_objects[:3]):
                if (mine.bbox != theirs.bbox or
                        mine.dominant_color != theirs.dominant_color):
                    return False

            # If structure is highly similar and hash matched, consider equivalent
            return True
        except Exception:
            # Fallback on any segmenting error: treat the states as distinct.
            return False


In [5]:
# --- 1. Custom Logger (Console and File Output) ---
class CustomLogger:
    """Logs messages to console and saves them to a structured log file."""

    def __init__(self, log_path: str = '/kaggle/working/log.txt'):
        """Store ``log_path`` and create/truncate the log file."""
        self.log_path = log_path
        self._initialize_log()

    def _initialize_log(self):
        """Creates or clears the log file and confirms setup.

        Failures are reported on the console and otherwise ignored so that a
        read-only or missing location never stops the solver.
        """
        try:
            # A bare filename has an empty dirname, and os.makedirs('') raises;
            # only create the parent directory when there is one.
            log_dir = os.path.dirname(self.log_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            with open(self.log_path, 'w', encoding='utf-8') as f:
                f.write(f"--- ARC Solver Log Initialized: {time.ctime()} ---\n")
            print(f"INFO: Log file initialized at {self.log_path}")
        except Exception as e:
            print(f"ERROR: Could not initialize log file: {e}")

    def _write_log(self, level: str, msg: str):
        """Format one entry, print it, and append it to the log file."""
        timestamp = time.strftime("[%Y-%m-%d %H:%M:%S]")
        log_entry = f"{timestamp} [{level:<5}] {msg}"
        print(log_entry)
        try:
            # Explicit encoding keeps non-ASCII messages writable regardless
            # of the platform's default locale.
            with open(self.log_path, 'a', encoding='utf-8') as f:
                f.write(log_entry + '\n')
        except Exception:
            # Deliberate best effort: the entry was already printed, and a
            # logging failure must not crash the caller.
            pass

    def info(self, msg: str):
        """Log ``msg`` at INFO level."""
        self._write_log("INFO", msg)

    def warning(self, msg: str):
        """Log ``msg`` at WARN level."""
        self._write_log("WARN", msg)

    def error(self, msg: str):
        """Log ``msg`` at ERROR level."""
        self._write_log("ERROR", msg)

    def debug(self, msg: str):
        """Log ``msg`` at DEBUG level."""
        self._write_log("DEBUG", msg)

# Instantiate the module-level logger. Construction creates or truncates the
# log file at the default path.
logger = CustomLogger()

# --- 2. Custom TinyWandB (Structured Metrics Saving) ---
class TinyWandB:
    """Saves structured key/value metrics for remote diagnosis and knob tuning."""

    def __init__(self, stats_path: str = '/kaggle/working/wandb_stats.txt'):
        """Store ``stats_path`` and create/truncate the stats file with its header."""
        self.stats_path = stats_path
        self._initialize_stats()

    def _initialize_stats(self):
        """Creates or clears the stats file with headers."""
        header = "timestamp,task_id,final_score,search_time,nodes_expanded,nodes_pruned_veto,nodes_pruned_causal,final_program_length"
        try:
            # A bare filename has an empty dirname, and os.makedirs('') raises;
            # only create the parent directory when there is one.
            stats_dir = os.path.dirname(self.stats_path)
            if stats_dir:
                os.makedirs(stats_dir, exist_ok=True)
            with open(self.stats_path, 'w', encoding='utf-8') as f:
                f.write(header + '\n')
            logger.info(f"TinyWandB stats initialized at {self.stats_path}")
        except Exception as e:
            logger.error(f"Could not initialize stats file: {e}")

    def log_task_stats(self, stats: Dict[str, Any]):
        """Formats and saves the final stats for a single task.

        Args:
            stats: Mapping with the keys named in the header row. Missing
                keys fall back to 'N/A' (task_id) or zero.
        """
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        data = [
            timestamp,
            # Coerce to str: a non-string task id would make the join below raise.
            str(stats.get('task_id', 'N/A')),
            f"{stats.get('final_score', 0.0):.4f}",
            f"{stats.get('search_time', 0.0):.2f}",
            str(stats.get('nodes_expanded', 0)),
            str(stats.get('nodes_pruned_veto', 0)),
            str(stats.get('nodes_pruned_causal', 0)),
            str(stats.get('final_program_length', 0))
        ]
        row = ",".join(data)
        try:
            with open(self.stats_path, 'a', encoding='utf-8') as f:
                f.write(row + '\n')
            logger.debug(f"Logged stats for {stats.get('task_id', 'N/A')}")
        except Exception as e:
            # Stats are best effort, but report the loss instead of hiding it.
            logger.warning(f"Could not write stats for {stats.get('task_id', 'N/A')}: {e}")

# Instantiate the module-level stats writer. Construction creates or truncates
# the stats file at the default path and writes the header row.
twb = TinyWandB()


INFO: Log file initialized at /kaggle/working/log.txt
[2025-10-30 11:33:54] [INFO ] TinyWandB stats initialized at /kaggle/working/wandb_stats.txt


In [6]:
# --- ProductionARCSolver Core Logic (A* Search Implementation) ---

# Assumes the class structure from Cell 4 is available

class ProductionARCSolver:
    """Best-first program search over DSL primitives for one ARC task.

    NOTE(review): ``__init__`` / ``_load_weights`` are elided in this excerpt;
    ``solve_task`` relies on ``self.encoder`` and ``self.dsl`` and on the
    helper methods ``_apply_operation``, ``_execute_final_audit`` and
    ``_identity_fallback`` being available on the instance.
    """
    # ... (init and _load_weights methods remain from Cell 4) ...

    # New A* Search State Management
    def _reset_stats(self):
        """Start a fresh set of per-task search counters."""
        self.stats = defaultdict(int)
        self.stats['nodes_pruned_veto'] = 0
        self.stats['nodes_pruned_causal'] = 0
        self.stats['nodes_expanded'] = 0
        logger.debug("Solver stats reset for new task.")

    def solve_task(self, task_id: str, task_data: Dict[str, Any], time_budget: float) -> Dict[str, List[List[List[int]]]]:
        """
        Main entry point. Executes A* search guided by the VetoNet policy
        with robust time and exception handling.

        Args:
            task_id: Identifier used in log lines and stats.
            task_data: Task dict with 'train' and 'test' lists. Only the first
                training pair is used to set up the search.
            time_budget: Wall-clock seconds allowed for the search loop.

        Returns:
            The prediction dict produced by the final audit, or the identity
            fallback when the initial setup fails.
        """
        logger.info(f"--- Starting Task: {task_id} (Budget: {time_budget}s) ---")
        self._reset_stats()

        start_time = time.time()
        self.best_solution = None
        self.open_set = []  # Stores BeamNode objects
        self.visited_set: Dict[int, 'BeamNode'] = {}  # Stores Fuzzy ID -> Best Node

        # 1. Initial State Setup (Identical to previous Cell 7)
        try:
            self.initial_grid = grid_to_np(task_data['train'][0]['input'])
            self.target_grid = grid_to_np(task_data['train'][0]['output'])
            with torch.no_grad():
                self.task_z_causal = self.encoder(self.initial_grid, self.target_grid)

            initial_node = BeamNode(self.initial_grid, score=0.0, program_repr=[], gru_h_state=None)
            self.open_set.append(initial_node)
            self.visited_set[initial_node.id] = initial_node

        except Exception as e:
            logger.error(f"Initial setup failed for {task_id}: {e}")
            return self._identity_fallback(task_data)

        # --- 2. Audited A* Search Loop ---
        try:
            while self.open_set and (time.time() - start_time < time_budget):

                # A* Heuristic Sort: f(n) = g(n) + h(n)
                # We use score (MMSS) as the combined heuristic, prioritizing max score
                self.open_set.sort(key=lambda n: n.score, reverse=True)

                # Beam Width Enforcement
                if len(self.open_set) > CONFIG['BEAM_WIDTH']:
                    self.open_set = self.open_set[:CONFIG['BEAM_WIDTH']]

                current_node = self.open_set.pop(0)
                self.stats['nodes_expanded'] += 1

                logger.debug(f"A* Pop Node | Depth: {current_node.program_length} | MMSS: {current_node.score:.4f} | Open Set Size: {len(self.open_set)}")

                if current_node.score >= 0.99:  # Direct check for immediate solution
                    self.best_solution = current_node
                    logger.info(f"Solution found (MMSS: {current_node.score:.4f}) at depth {current_node.program_length}. Breaking search.")
                    break

                if current_node.program_length >= CONFIG['MAX_PROGRAM_DEPTH']:
                    continue

                # Generate candidates (using VetoNet and Param Predictor)
                new_nodes = self._apply_operation(current_node)  # Removed task_data argument (moved into self)

                # Add new nodes to the beam
                for new_node in new_nodes:
                    existing_node = self.visited_set.get(new_node.id)

                    if existing_node is None:
                        # New state
                        self.visited_set[new_node.id] = new_node
                        self.open_set.append(new_node)
                    elif new_node.score > existing_node.score:
                        # A better path to the same fuzzy state replaces the old one.
                        self.visited_set[new_node.id] = new_node
                        self.open_set.append(new_node)
                        logger.debug(f"A* Update: Found better path to fuzzy state {new_node.id}.")
                    elif new_node.check_equivalence_safe(existing_node, self.dsl):
                        # H8 Safety Check confirmed a true duplicate that is no better.
                        logger.debug("H8 Prune: Duplicate state confirmed with better/equal score.")
                        continue
                    else:
                        # The fuzzy hash collided but the states differ. Keep
                        # exploring this node without displacing the recorded
                        # best; previously it was dropped here, which made the
                        # safety check above have no effect.
                        self.open_set.append(new_node)

                    # Update global best solution found
                    if self.best_solution is None or new_node.score > self.best_solution.score:
                        self.best_solution = new_node

            # --- Final Audit and Fallback Logic ---
            search_time = time.time() - start_time
            self._execute_final_audit(task_id, search_time, task_data)

        except Exception as e:
            logger.error(f"FATAL RUNTIME ERROR on Task {task_id}: {e}. Halting search gracefully.")
            search_time = time.time() - start_time
            self._execute_final_audit(task_id, search_time, task_data, runtime_error=True)

        # The final result is returned by the audited method
        return self.final_prediction_output

# Helper methods (will be defined in Cell 8, 9)
# ...


In [7]:
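# NOTE: The functions in this cell take `self` and are called by ProductionARCSolver as methods (e.g. self._apply_operation), but they are defined at module level here. They must be moved into, or assigned onto, the ProductionARCSolver class before solve_task is run.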

# --- Helper 1: Grid Size Confounder (for C4) ---
def _get_grid_size_oh(self, grid: np.ndarray) -> torch.Tensor:
    """Return a (1, 3) one-hot tensor bucketing the grid by cell count.

    Buckets: index 0 for up to 50 cells, index 1 for up to 150 cells,
    index 2 for anything larger.
    """
    # ... (Implementation remains identical to previous Cell 8) ...
    rows, cols = grid.shape
    cell_count = rows * cols

    if cell_count > 150:
        bucket = 2
    elif cell_count > 50:
        bucket = 1
    else:
        bucket = 0

    one_hot = torch.zeros(3).to(DEVICE)
    one_hot[bucket] = 1.0
    return one_hot.unsqueeze(0)

# --- Helper 2: Multi-Metric Structural Score (MMSS) ---
def _score_solution(self, prediction: np.ndarray, target: np.ndarray, program_length: int) -> float:
    """Score ``prediction`` against ``target`` and return a value clipped to [0, 1].

    NOTE(review): the metric computation is elided in this excerpt. As
    written, ``mmss`` is read before anything assigns it, and
    ``length_penalty`` is only used in the debug message - the elided
    weighted sum presumably defines ``mmss`` and applies the penalty.
    """
    # ... (Implementation remains identical to previous Cell 8) ...
    # This must be the accurate MMSS calculation!
    
    # 1. Pixel Accuracy (Baseline IOU)
    # ...
    
    # 2. Structural Fidelity (Color and Shape)
    # ...
    
    # 3. Program Complexity Penalty
    # Grows linearly: 0.005 per instruction in the program.
    length_penalty = 0.005 * program_length 
    
    # Final MMSS (Weighted sum)
    # ...
    mmss = np.clip(mmss, 0.0, 1.0)
    logger.debug(f"MMSS Calculated: Final={mmss:.4f} (Length Penalty: {length_penalty:.3f})")
    return mmss


# --- Core: Apply Operation and Pruning ---
def _apply_operation(self, node: 'BeamNode') -> List['BeamNode']:
    """
    Expands the node, integrating Causal, VetoNet, and Fuzzy Pruning logic.

    Every primitive in ``self.dsl.primitives`` is tried in order. A candidate
    is dropped when it repeats the previous operation, when the fuzzy-object
    rule forbids it, when its tabulated object-count effect is below the veto
    threshold, when the adjusted VetoNet score is below the policy threshold,
    or when the primitive raises. Survivors are scored and returned as new
    nodes.

    Args:
        node: The node to expand.

    Returns:
        The list of child nodes, one per surviving primitive.
    """
    new_nodes: List[BeamNode] = []
    grid_size_oh = self._get_grid_size_oh(node.grid)

    # 1. Generate Primitive Candidates (H9: Propensity Bias)
    with torch.no_grad():
        param_preds = self.param_predictor(self.task_z_causal)
        propensity_scores = param_preds['propensity_scores'].squeeze(0)

    # --- Per-node values: none of these depend on the candidate operation,
    # so they are computed once instead of once per primitive. ---
    last_op = node.program_repr[-1].get('op_name') if node.program_length > 0 else None

    objects = self.dsl._segment_grid(node.grid)
    largest_is_fuzzy = bool(objects) and self.dsl._calculate_fuzziness(objects[0]) > 0.8

    # Causal Uncertainty Buffering (Improvement 3: Soften the C5 veto)
    # Use a simple proxy for confidence (e.g., max propensity score)
    causal_confidence_proxy = propensity_scores.max().item()
    # Threshold becomes less strict if confidence is low
    veto_threshold = -0.2 * (1.0 - causal_confidence_proxy * 0.5)

    candidate_operations = list(self.dsl.primitives.keys())

    for op_idx, operation in enumerate(candidate_operations):

        # --- Audited Pruning Check 1: Functional and Redundancy Pruning (I2) ---
        if operation == last_op:
            self.stats['nodes_pruned_causal'] += 1
            continue

        # --- Audited Pruning Check 2: Fuzzy Object Identity (F1) ---
        if largest_is_fuzzy and operation in ['extract_largest']:
            self.stats['nodes_pruned_causal'] += 1
            continue

        # --- Audited Pruning Check 3: Deep Compliance Audit (C5/H7) ---
        do_effect = CONFIG['DO_EFFECTS'].get(operation, {}).get('object_count', 0)
        if do_effect < veto_threshold:
            self.stats['nodes_pruned_causal'] += 1
            continue

        # --- Run VetoNet ---
        primitive_features = torch.tensor([[1.0, 0.0, 0.0, 0.0] if 'color' in operation else [0.0, 1.0, 0.0, 0.0]]).to(DEVICE)

        with torch.no_grad():
            veto_score_tensor, h_next = self.veto_net(
                self.task_z_causal, primitive_features, grid_size_oh, node.gru_h_state
            )
        veto_score = veto_score_tensor.item()

        # NOTE(review): the modulus 8 is hard-coded; confirm it matches the
        # length of the predictor's propensity vector.
        op_propensity = propensity_scores[op_idx % 8].item()
        final_veto_score = veto_score * (1 + 0.5 * op_propensity) - 0.05  # NSM D6 penalty

        # --- Audited Pruning Check 4: VetoNet Threshold (D6) ---
        if final_veto_score < CONFIG['VETONET_POLICY_THRESHOLD']:
            self.stats['nodes_pruned_veto'] += 1
            continue

        # --- Success: Execution ---
        params = {'new_color': 5, 'threshold_ratio': 0.6}
        try:
            predicted_grid = self.dsl.primitives[operation](node.grid.copy(), **params)
        except Exception as e:
            logger.error(f"Primitive execution error: {operation} failed: {e}")
            continue

        # Score and Apply Bonuses (C6, F3)
        mmss = self._score_solution(predicted_grid, self.target_grid, node.program_length + 1)

        # C6: Mediator Chain Bonus check
        if mmss > node.score and node.program_length == 1:
            mmss += 0.02

        # F3: Apply Possibility Transform
        possibility_factor = self._possibility_transform(final_veto_score)
        final_score = mmss * possibility_factor  # A* Heuristic f(n)

        # Create new node
        new_node = BeamNode(
            grid=predicted_grid,
            score=final_score,
            program_repr=node.program_repr + [{'op_name': operation, 'params': params}],
            gru_h_state=h_next
        )
        new_nodes.append(new_node)

    return new_nodes

def _possibility_transform(self, probability: float) -> float:
    """F3: Transforms VetoNet probability to Possibility Measure."""
    return 0.5 + 0.5 * np.tanh(5 * (probability - 0.5))


In [8]:
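# NOTE: The functions in this cell take `self` and are called by ProductionARCSolver as methods (e.g. self._identity_fallback, self._execute_final_audit), but they are defined at module level here. They must be moved into, or assigned onto, the ProductionARCSolver class before solve_task is run.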

def _run_program(self, initial_grid: np.ndarray, program_repr: List[Dict[str, Any]]) -> Tuple[np.ndarray, bool]:

    #Executes a list of operations sequentially. Returns the final grid and a 
    #boolean flag for stability (True if stable/no null-op).
    
    current_grid = initial_grid.copy()
    
    for step, instruction in enumerate(program_repr):
        op_name = instruction['op_name']
        params = instruction.get('params', {})
        
        try:
            next_grid = self.dsl.primitives[op_name](current_grid, **params)
            
            # I2 Check: Causal Chain Connectivity - Prune null-ops on the fly
            if np.array_equal(next_grid, current_grid) and op_name not in ['identity']:
                logger.debug(f"I2 Runtime Prune: Null-op detected for {op_name} at step {step}.")
                return current_grid, False # Unstable
            
            current_grid = next_grid
        except Exception:
            return current_grid, False # Unstable

    return current_grid, True # Stable execution


def _identity_fallback(self, task_data: Dict[str, Any]) -> Dict[str, List[List[List[int]]]]:
    """Submit each test input unchanged as its own prediction (safest return)."""
    passthrough = [test_pair['input'] for test_pair in task_data['test']]

    logger.warning("Submitting full Identity Fallback solution.")
    return {'test': passthrough}


def _format_predictions(self, task_data: Dict[str, Any], solution_node: 'BeamNode') -> Dict[str, List[List[List[int]]]]:
    """Improvement 4: Run the chosen program on every test input.

    Each test pair is handled independently: when the program run is reported
    unstable for a pair, that pair's own input grid is submitted instead.
    """
    logger.info(f"Formatting final predictions using program length {solution_node.program_length}.")

    outputs = []
    for idx, test_pair in enumerate(task_data['test']):
        source_grid = grid_to_np(test_pair['input'])

        # Execute the full, final program
        result_grid, is_stable = self._run_program(source_grid, solution_node.program_repr)

        if is_stable:
            chosen = result_grid
        else:
            # Fallback for THIS specific test pair only
            logger.warning(f"Test pair {idx} execution was unstable/null-op. Submitting Identity.")
            chosen = source_grid

        outputs.append(chosen.tolist())

    return {'test': outputs}

def _execute_final_audit(self, task_id: str, search_time: float, task_data: Dict[str, Any], runtime_error: bool = False):
    """Pick the final prediction, then record stats and a summary log line.

    The best node's program is used only when there was no runtime error, a
    best node exists, and its score is not below FALLBACK_MMSS_THRESHOLD;
    otherwise the identity fallback is stored in
    ``self.final_prediction_output``.
    """
    best = self.best_solution

    if runtime_error or best is None:
        trust_program = False
        final_score, final_program_length = 0.0, 0
    else:
        trust_program = not (best.score < CONFIG['FALLBACK_MMSS_THRESHOLD'])
        final_score = best.score
        final_program_length = best.program_length

    if trust_program:
        self.final_prediction_output = self._format_predictions(task_data, best)
    else:
        self.final_prediction_output = self._identity_fallback(task_data)

    # TinyWandB Logging
    counters = self.stats
    twb.log_task_stats({
        'task_id': task_id,
        'final_score': final_score,
        'search_time': search_time,
        'nodes_expanded': counters['nodes_expanded'],
        'nodes_pruned_veto': counters['nodes_pruned_veto'],
        'nodes_pruned_causal': counters['nodes_pruned_causal'],
        'final_program_length': final_program_length,
    })
    logger.info(f"--- Task {task_id} Final Score: {final_score:.4f} in {search_time:.2f}s ---")


In [9]:
# --- FORMAL META-TRAINING PROTOCOL DOCUMENTATION ---
"""
This protocol details the S-Tier neurosymbolic meta-training process used to generate 
the three neural weight files (.pth). This training is executed strictly OFFLINE 
on the public ARC training set (400+ tasks) to ensure compliance.

GOAL: Learn a generalizable, constrained policy (VetoNet) and distributional 
guidance (Param Predictor) to minimize symbolic search space complexity.

I. DATA GENERATION (The Audited Interaction Tuple)
   - Teacher: A simple, high-recall baseline search engine generates successful programs.
   - Tuple: (Z_causal, GRU_h_prev, Primitive_i, Target_Params, Confounder_OH, Score_Achieved)
   - Critical Audit: Samples are weighted inversely to the Causal Template Bias (H9) 
     to ensure balanced learning across all primitive types.

II. NEURAL ARCHITECTURES (Cell 3)
   - Encoder: SemanticCausalEncoder (outputs Z_causal_extended, 261D).
   - Policy: SequentialVetoNet (GRU-based policy for sequential, history-aware veto).
   - Predictor: ProbabilisticParamPredictor (outputs distributions and Propensity Scores (H9)).

III. LOSS FUNCTIONS (The Proofs of Optimization)

   A. Sequential VetoNet Loss (Policy Guidance):
      - Primary: Binary Cross-Entropy (BCE).
      - Penalty 1: **Focal Loss** (for class imbalance and hard example focus).
      - Penalty 2: **Minimal Correlation Penalty (H7)**: Penalizes high Veto Scores that 
        correlate with too many disparate target features (forcing focused causality).
      - Penalty 3: **D9 Action-Prediction Error**: Adds a term to penalize the VetoNet 
        when its predicted success score is far from the actual score achieved by the primitive.

   B. Param Predictor Loss (Distributional Guidance):
      - Primary: **Categorical Cross-Entropy (CCE)** for Color and Size distributions.
      - Secondary: **Mean Squared Error (MSE)** for coordinate predictions.
      - Optimization: The loss includes a weight that prioritizes predicting parameters 
        for primitives that have a high Propensity Score (H9), focusing network capacity.

IV. OUTPUT
   - The final state dictionaries are saved as:
     - `semantic_encoder_weights.pth`
     - `sequential_vetonet_weights.pth`
     - `probabilistic_param_weights.pth`
"""

'\nThis protocol details the S-Tier neurosymbolic meta-training process used to generate \nthe three neural weight files (.pth). This training is executed strictly OFFLINE \non the public ARC training set (400+ tasks) to ensure compliance.\n\nGOAL: Learn a generalizable, constrained policy (VetoNet) and distributional \nguidance (Param Predictor) to minimize symbolic search space complexity.\n\nI. DATA GENERATION (The Audited Interaction Tuple)\n   - Teacher: A simple, high-recall baseline search engine generates successful programs.\n   - Tuple: (Z_causal, GRU_h_prev, Primitive_i, Target_Params, Confounder_OH, Score_Achieved)\n   - Critical Audit: Samples are weighted inversely to the Causal Template Bias (H9) \n     to ensure balanced learning across all primitive types.\n\nII. NEURAL ARCHITECTURES (Cell 3)\n   - Encoder: SemanticCausalEncoder (outputs Z_causal_extended, 261D).\n   - Policy: SequentialVetoNet (GRU-based policy for sequential, history-aware veto).\n   - Predictor: Pro