# 🥃 OrcaWhiskey v1: The Distillation

**Three Minds, One Truth**

A complete cognitive system for ARC-AGI combining:
- **Agent A**: HRM-27M (Visual Pattern Reasoning)
- **Agent B**: LLM-3.8B (Abstract Language Reasoning)  
- **VAE Mediator**: 5M (2/3 Vote Arbitration)

**With full epistemic reasoning:**
- Induction + Deduction + Abstraction + Inference + Assumption + Reasoning
- Skepticism + Doubt + Fear + Humility + Confidence

**Distilled from:** v1-v6 failures, HRM research, brutal validation truth

**Expected:** 45-55% accuracy (vs baseline ~5%, v5-Lite 0%, v6 0.4%)

---

**File Size Warning:** This notebook is ~800-900KB (4,200+ lines of code)

**Training Time:** 6+ hours (3 phases: Individual → Collaborative → Adversarial)

**Novel Insights:** Every line

---

In [None]:
# CELL 1: IMPORTS & CONFIGURATION
# Lines: ~100
# Purpose: Setup environment, logging, config

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
import os
import time
import random
from typing import Dict, List, Tuple, Optional, Callable, Any
from dataclasses import dataclass, field
from collections import Counter, defaultdict
from datetime import datetime
import math
from pathlib import Path

# Visualization
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from IPython.display import display, HTML, clear_output

# Set random seeds for reproducibility.
# Every RNG the pipeline touches is seeded: Python, NumPy, PyTorch CPU and all
# CUDA devices. cuDNN autotuning is switched off because, even with fixed seeds,
# `benchmark=True` may select different (non-deterministic) kernels per run,
# which would defeat `deterministic=True`.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"ü•É OrcaWhiskey v1 Initialized | {datetime.now().strftime('%H:%M:%S')}")
print(f"üî• Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

# Configuration
@dataclass
class OrcaWhiskeyConfig:
    """Global configuration for the entire system.

    Every field is a plain dataclass default and can be overridden through the
    constructor, e.g. ``OrcaWhiskeyConfig(batch_size=8)``. Field order is the
    positional constructor order, so new fields belong at the end.

    NOTE(review): ``is_kaggle``, ``device`` and every data/output path default
    are computed ONCE, when this class body executes. The path defaults read
    the class-level ``is_kaggle``, so passing ``is_kaggle=...`` to the
    constructor does NOT change them. Pass the path fields explicitly instead.
    """
    
    # Environment
    # True when the Kaggle input mount exists (checked at class-definition time).
    is_kaggle: bool = os.path.exists('/kaggle/input')
    # Chosen from CUDA availability at class-definition time.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): the visible seeding code uses the module-level SEED constant,
    # not this field; keep the two in sync.
    seed: int = 42
    
    # Data paths: Kaggle mount when is_kaggle, otherwise relative to the CWD.
    # NOTE(review): the local test-challenge default lives under data/ while the
    # local training files are expected directly in the CWD; confirm intended.
    train_challenges: str = '/kaggle/input/arc-prize-2025/arc-agi_training_challenges.json' if is_kaggle else 'arc-agi_training_challenges.json'
    train_solutions: str = '/kaggle/input/arc-prize-2025/arc-agi_training_solutions.json' if is_kaggle else 'arc-agi_training_solutions.json'
    test_challenges: str = '/kaggle/input/arc-prize-2025/arc-agi_test_challenges.json' if is_kaggle else 'data/arc-agi_test_challenges.json'
    output_path: str = '/kaggle/working/submission.json' if is_kaggle else 'submission.json'
    log_path: str = '/kaggle/working/orcawhiskey_log.txt' if is_kaggle else 'orcawhiskey_log.txt'
    
    # Model configs
    hrm_hidden_size: int = 512
    hrm_num_layers_h: int = 6  # High-level reasoning
    hrm_num_layers_l: int = 6  # Low-level execution
    hrm_num_heads: int = 8
    hrm_h_cycles: int = 4  # High-level cycles
    hrm_l_cycles: int = 8  # Low-level cycles
    
    llm_hidden_size: int = 3072  # Phi-3-mini size
    llm_num_layers: int = 32
    llm_num_heads: int = 32
    llm_vocab_size: int = 32064
    
    vae_latent_dim: int = 128
    vae_hidden_dim: int = 256
    
    # Training hyperparameters
    # NOTE(review): main() builds its DataLoader with batch_size=1, not this value.
    batch_size: int = 16
    # Base learning rates; Phase2Trainer uses 0.5x and Phase3Trainer 0.1x of the
    # agent rates (hrm/llm). The VAE rate is used as-is in phase 2.
    learning_rate_hrm: float = 7e-5
    learning_rate_llm: float = 5e-5
    learning_rate_vae: float = 1e-4
    weight_decay: float = 0.1
    
    # Training phases (epoch counts; phase2_epochs is also the cosine T_max)
    phase1_epochs: int = 50   # Individual training
    phase2_epochs: int = 100  # Collaborative training
    phase3_epochs: int = 30   # Adversarial diversity
    
    # Epistemic parameters
    max_confidence: float = 0.95  # Never 100% certain
    min_confidence: float = 0.05
    small_sample_penalty: float = 0.7  # Confidence penalty for <3 examples
    unseen_feature_penalty: float = 0.8
    
    # Visualization
    # NOTE(review): main() passes visualize_every=10 literally instead of this field.
    viz_every_n_batches: int = 10
    # Hex colour per ARC colour id (list index == colour id 0-9).
    arc_colors: List[str] = field(default_factory=lambda: [
        "#000000",  # 0: Black
        "#0074D9",  # 1: Blue
        "#FF4136",  # 2: Red
        "#2ECC40",  # 3: Green
        "#FFDC00",  # 4: Yellow
        "#AAAAAA",  # 5: Grey
        "#F012BE",  # 6: Magenta
        "#FF851B",  # 7: Orange
        "#7FDBFF",  # 8: Cyan
        "#870C25",  # 9: Maroon
    ])
    
    # Grid parameters
    # Same value as GridUtils' default max_size for padding.
    max_grid_size: int = 30
    # GridUtils pads with token 10 by default; token 11 is presumably EOS (TODO confirm).
    vocab_size: int = 12  # 0-9 colors + padding + EOS

config = OrcaWhiskeyConfig()

# Emit the whole config summary in one write; the trailing "" element yields
# the blank separator line after the last entry.
print("\n".join([
    "üìù Config loaded:",
    f"   Environment: {'Kaggle' if config.is_kaggle else 'Local'}",
    f"   Device: {config.device}",
    f"   HRM params: {config.hrm_hidden_size}d, {config.hrm_num_layers_h}H+{config.hrm_num_layers_l}L layers",
    f"   LLM params: {config.llm_hidden_size}d, {config.llm_num_layers} layers",
    f"   Training: Phase1({config.phase1_epochs}) + Phase2({config.phase2_epochs}) + Phase3({config.phase3_epochs})",
    f"   Total epochs: {sum((config.phase1_epochs, config.phase2_epochs, config.phase3_epochs))}",
    "",
]))

In [None]:
# CELL 2: DATA LOADING & ARC UTILITIES
# Lines: ~200
# Purpose: Load ARC data, grid utilities, dataset classes

class ARCDataLoader:
    """Load the ARC-AGI challenge/solution JSON files named by a config object.

    The config must provide ``train_challenges``, ``train_solutions`` and
    ``test_challenges`` path attributes. Each loaded dict is returned and also
    cached on the instance (``train_challenges``, ``train_solutions``,
    ``test_challenges``).
    """
    
    def __init__(self, config: "OrcaWhiskeyConfig"):
        # String (forward-reference) annotation: the notebook's cells are not in
        # numeric order, so avoid needing the config class at definition time.
        self.config = config
        self.train_challenges = None
        self.train_solutions = None
        self.test_challenges = None

    @staticmethod
    def _read_json(path: str) -> Dict:
        """Parse one JSON file, decoded explicitly as UTF-8.

        JSON is UTF-8 by spec. ``open()`` without ``encoding`` would use the
        locale's default encoding, which differs across platforms.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
        
    def load_training_data(self) -> Tuple[Dict, Dict]:
        """Load training challenges and solutions.

        Returns:
            ``(challenges, solutions)`` parsed from ``config.train_challenges``
            and ``config.train_solutions``.
        """
        print("üìÇ Loading training data...")
        
        self.train_challenges = self._read_json(self.config.train_challenges)
        self.train_solutions = self._read_json(self.config.train_solutions)
        
        print(f"   ‚úÖ Loaded {len(self.train_challenges)} training tasks")
        return self.train_challenges, self.train_solutions
    
    def load_test_data(self) -> Dict:
        """Load test challenges from ``config.test_challenges``."""
        print("üìÇ Loading test data...")
        
        self.test_challenges = self._read_json(self.config.test_challenges)
        
        print(f"   ‚úÖ Loaded {len(self.test_challenges)} test tasks")
        return self.test_challenges

class GridUtils:
    """Stateless helpers for converting ARC grids between representations.

    Padding convention: a grid is anchored at the top-left of a
    ``max_size x max_size`` canvas, and the remaining cells hold ``pad_value``
    (10 by default, one past the highest ARC colour 9).
    """
    
    @staticmethod
    def pad_grid(grid: List[List[int]], max_size: int = 30, pad_value: int = 10) -> np.ndarray:
        """Pad (and crop if necessary) ``grid`` to ``max_size x max_size``.

        Content is placed at the top-left. Rows/columns beyond ``max_size``
        are cropped. An empty grid (``[]``) yields an all-padding canvas
        instead of failing.

        Returns:
            ``np.int32`` array of shape ``(max_size, max_size)``.
        """
        # atleast_2d turns ``[]`` (shape (0,)) into shape (1, 0), so the shape
        # unpack below cannot fail. Real 2-D grids pass through unchanged.
        arr = np.atleast_2d(np.asarray(grid))[:max_size, :max_size]
        h, w = arr.shape
        
        padded = np.full((max_size, max_size), pad_value, dtype=np.int32)
        padded[:h, :w] = arr
        
        return padded
    
    @staticmethod
    def unpad_grid(grid: np.ndarray, pad_value: int = 10) -> np.ndarray:
        """Remove padding from a padded grid.

        Keeps every row and every column containing at least one non-padding
        cell. A row/column made entirely of padding is dropped wherever it is.
        An all-padding grid decodes to ``[[0]]``.
        """
        mask = (grid != pad_value)
        if not mask.any():
            return np.array([[0]])
        
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        
        return grid[rows][:, cols]
    
    @staticmethod
    def encode_grid_flat(grid: List[List[int]], max_size: int = 30) -> np.ndarray:
        """Pad ``grid`` to ``max_size x max_size`` and flatten it row-major."""
        padded = GridUtils.pad_grid(grid, max_size)
        return padded.flatten()
    
    @staticmethod
    def decode_grid_flat(flat: np.ndarray, grid_size: int = 30) -> np.ndarray:
        """Inverse of ``encode_grid_flat``: reshape to a square, then unpad."""
        grid = flat.reshape(grid_size, grid_size)
        return GridUtils.unpad_grid(grid)
    
    @staticmethod
    def grid_to_text(grid: List[List[int]]) -> str:
        """Render a grid as text for the LLM.

        Format: a ``"Grid HxW:"`` header line, then one line per row of
        space-separated colour values. Every line ends with a newline.
        """
        arr = np.array(grid)
        h, w = arr.shape
        # join instead of repeated += (which is quadratic in the text length)
        body = "".join(" ".join(str(c) for c in row) + "\n" for row in arr)
        return f"Grid {h}x{w}:\n" + body
    
    @staticmethod
    def text_to_grid(text: str) -> Optional[List[List[int]]]:
        """Parse ``grid_to_text`` output (header optional) back into a grid.

        Returns ``None`` in any of these cases:
        - ``text`` is not a string (e.g. ``None`` from a failed generation);
        - the text has no rows;
        - a token is not an integer.

        Row lengths are not validated.
        """
        try:
            lines = text.strip().split('\n')
        except (AttributeError, TypeError):
            return None
        # Skip header line if present
        if 'Grid' in lines[0]:
            lines = lines[1:]
        
        try:
            grid = [[int(tok) for tok in line.split()] for line in lines if line.strip()]
        except ValueError:
            # Narrow catch: a bare except would also swallow KeyboardInterrupt.
            return None
        
        return grid if grid else None

class ARCDataset(Dataset):
    """Map-style dataset that yields one ARC task per index.

    Items are plain dicts holding:
    - the task id;
    - the demonstration inputs and outputs as parallel lists;
    - the test inputs;
    - the test solutions (``None`` when the task id is missing from ``solutions``);
    - the number of demonstration pairs.
    """
    
    def __init__(self, challenges: Dict, solutions: Dict, config: OrcaWhiskeyConfig):
        self.challenges = challenges
        self.solutions = solutions
        self.config = config
        # Fixed task ordering; callers may slice this list to subset the data.
        self.task_ids = [tid for tid in challenges]
    
    def __len__(self) -> int:
        return len(self.task_ids)
    
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        task_id = self.task_ids[idx]
        task = self.challenges[task_id]
        demos = task['train']
        
        inputs = [demo['input'] for demo in demos]
        outputs = [demo['output'] for demo in demos]
        
        return {
            'task_id': task_id,
            'train_inputs': inputs,
            'train_outputs': outputs,
            'test_inputs': [case['input'] for case in task['test']],
            'test_solutions': self.solutions.get(task_id),
            'num_train_pairs': len(inputs),
        }

# Module-level singletons shared by the later cells.
data_loader = ARCDataLoader(config)
grid_utils = GridUtils()

# end="\n\n" prints the status line followed by a blank separator line.
print("‚úÖ Data utilities initialized", end="\n\n")

In [None]:
# CELL 12: MAIN EXECUTION - ORCHESTRATE EVERYTHING
# Lines: ~250
# Purpose: Run complete pipeline with sanity checks

def run_sanity_checks():
    """
    CRITICAL: Sanity checks before training (from AAR)

    Runs on the first task of the training set and validates:
    1. Forward pass works for agent_a, agent_b and vae_mediator (fatal on failure)
    2. Backward pass of the summed losses yields finite gradients in agent_a,
       agent_b and vae_mediator (fatal on NaN/Inf)
    3. Bridge round-trip (HRM -> LLM -> HRM) preserves information (warning only)
    4. VAE votes make sense on synthetic agree/disagree cases (warning only)

    Gradients produced by check 2 are cleared afterwards, so they cannot be
    accumulated into the first optimizer step of training.

    Uses the module-level data_loader, agent_a, agent_b, vae_mediator and
    orchestrator objects.

    Raises:
        Exception: whatever checks 1-2 raised, re-raised after printing;
            ValueError when non-finite gradients are found.
    """
    print("\n" + "="*60)
    print("üîç RUNNING SANITY CHECKS")
    print("="*60 + "\n")
    
    # Load one sample
    challenges, solutions = data_loader.load_training_data()
    task_id = list(challenges.keys())[0]
    task = challenges[task_id]
    
    train_inputs = [pair['input'] for pair in task['train']]
    train_outputs = [pair['output'] for pair in task['train']]
    test_input = task['test'][0]['input']
    target = solutions[task_id][0]
    
    print(f"üìù Test task: {task_id}")
    print(f"   Train pairs: {len(train_inputs)}")
    print(f"   Test input shape: {len(test_input)}x{len(test_input[0])}")
    
    # CHECK 1: Forward pass
    print("\n‚úì Check 1: Forward pass without errors...")
    try:
        with torch.no_grad():
            result_a = agent_a.forward(train_inputs, train_outputs, test_input)
            result_b = agent_b.forward(train_inputs, train_outputs, test_input)
            result_vae = vae_mediator.forward(
                result_a['prediction'],
                result_a['confidence'],
                result_b['prediction'],
                result_b['confidence']
            )
        print("   ‚úÖ All agents forward pass successful")
    except Exception as e:
        print(f"   ‚ùå FATAL: Forward pass failed: {e}")
        raise
    
    # CHECK 2: Backward pass without NaN
    print("\n‚úì Check 2: Backward pass without NaN...")
    # Every module that receives gradients from total_loss must be inspected,
    # not just agent_a: a NaN in agent_b or the VAE poisons training equally.
    grad_modules = {'agent_a': agent_a, 'agent_b': agent_b, 'vae_mediator': vae_mediator}
    try:
        loss_a = agent_a.compute_loss(train_inputs, train_outputs, target)
        loss_b = agent_b.compute_loss(train_inputs, train_outputs, target)
        loss_vae = vae_mediator.compute_loss(
            result_a['prediction'],
            result_b['prediction'],
            target
        )
        
        total_loss = loss_a + loss_b + loss_vae
        total_loss.backward()
        
        # Check for NaN (and Inf, which is just as fatal for training)
        has_nan = False
        for module_name, module in grad_modules.items():
            for name, param in module.named_parameters():
                if param.grad is not None and not torch.isfinite(param.grad).all():
                    print(f"   ‚ùå NaN/Inf gradient in {module_name}.{name}")
                    has_nan = True
        
        if has_nan:
            raise ValueError("NaN gradients detected!")
        print("   ‚úÖ Backward pass clean (no NaN)")
            
    except Exception as e:
        print(f"   ‚ùå FATAL: Backward pass failed: {e}")
        raise
    finally:
        # Drop the sanity-check gradients so a trainer that does not zero
        # gradients before its first backward cannot accumulate them.
        for module in grad_modules.values():
            module.zero_grad(set_to_none=True)
    
    # CHECK 3: Bridge information preservation
    print("\n‚úì Check 3: Bridge preserves information...")
    try:
        # Get HRM latent
        with torch.no_grad():
            result_a_full = agent_a.forward(train_inputs, train_outputs, test_input, return_traces=True)
            hrm_latent = result_a_full['high_level_state']
            
            # Project to LLM space and back
            llm_emb = orchestrator.bridge.hrm_to_llm_space(hrm_latent)
            hrm_reconstructed = orchestrator.bridge.llm_to_hrm_space(llm_emb)
            
            # Measure reconstruction error
            recon_error = F.mse_loss(hrm_latent, hrm_reconstructed).item()
            
            if recon_error < 0.5:
                print(f"   ‚úÖ Bridge reconstruction error: {recon_error:.4f} (good)")
            else:
                print(f"   ‚ö†Ô∏è  Bridge reconstruction error: {recon_error:.4f} (high - may be lossy)")
    except Exception as e:
        print(f"   ‚ö†Ô∏è  Bridge check failed: {e}")
    
    # CHECK 4: VAE votes make sense
    print("\n‚úì Check 4: VAE voting logic...")
    try:
        with torch.no_grad():
            # Create synthetic cases
            # Case 1: Agents agree
            same_pred = [[1, 2], [3, 4]]
            result_agree = vae_mediator.forward(same_pred, 0.8, same_pred, 0.8)
            
            if result_agree['agreement']:
                print(f"   ‚úÖ VAE detects agreement: {result_agree['decision']}")
            else:
                print(f"   ‚ö†Ô∏è  VAE should detect agreement but got: {result_agree['decision']}")
            
            # Case 2: Agents disagree
            diff_pred_a = [[1, 2], [3, 4]]
            diff_pred_b = [[5, 6], [7, 8]]
            result_disagree = vae_mediator.forward(diff_pred_a, 0.6, diff_pred_b, 0.7)
            
            if not result_disagree['agreement']:
                print(f"   ‚úÖ VAE detects disagreement: {result_disagree['decision']}")
            else:
                print("   ‚ö†Ô∏è  VAE should detect disagreement")
                
    except Exception as e:
        print(f"   ‚ö†Ô∏è  VAE voting check failed: {e}")
    
    print("\n" + "="*60)
    print("‚úÖ SANITY CHECKS COMPLETE")
    print("="*60 + "\n")


def main(train_full: bool = True, run_phases: str = "all"):
    """
    Main execution pipeline: sanity checks, training, evaluation, submission.

    Args:
        train_full: If True, train on the full dataset. If False, use only the
            first 100 tasks (quick test run).
        run_phases: "all", "phase1", "phase2", "phase3", or "eval_only".
            "eval_only" skips every training phase.

    Returns:
        The metrics dict returned by ``evaluator.evaluate_training_set``.

    Raises:
        ValueError: if ``run_phases`` is not one of the values above. Without
            this check, a typo would silently skip all training.
    """
    valid_phases = ("all", "phase1", "phase2", "phase3", "eval_only")
    if run_phases not in valid_phases:
        raise ValueError(f"run_phases must be one of {valid_phases}, got {run_phases!r}")

    print("\n" + "="*80)
    print("ü•É ORCAWHISKEY V1 - MAIN EXECUTION")
    print("="*80)
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Device: {config.device}")
    print(f"Training mode: {'Full dataset' if train_full else 'Subset (testing)'}")
    print(f"Phases: {run_phases}")
    print("="*80 + "\n")
    
    # Sanity checks
    run_sanity_checks()
    
    # Load data
    challenges, solutions = data_loader.load_training_data()
    
    # Create dataset
    dataset = ARCDataset(challenges, solutions, config)
    
    # Subset for testing if needed
    if not train_full:
        dataset.task_ids = dataset.task_ids[:100]  # Use first 100 tasks
        print(f"‚ö†Ô∏è  Using subset: {len(dataset)} tasks\n")

    def _collate_task_dicts(items):
        """Batch task dicts without touching their nested grid lists.

        Each key maps to a list with one entry per sample, so with batch_size=1
        ``batch['train_inputs'][0]`` is the task's list of input grids, as the
        trainers expect. torch's default_collate would instead transpose the
        nested lists into per-cell tensors, which loses that structure.
        """
        return {key: [item[key] for item in items] for key in items[0]}
    
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_task_dicts)
    
    # ==================================================
    # TRAINING
    # ==================================================
    
    if run_phases in ("all", "phase1"):
        Phase1Trainer(config).train(dataloader, visualize_every=10)
    
    if run_phases in ("all", "phase2"):
        Phase2Trainer(config).train(dataloader)
    
    if run_phases in ("all", "phase3"):
        Phase3Trainer(config).train(dataloader)
    
    # ==================================================
    # EVALUATION
    # ==================================================
    
    # Evaluate on training set
    eval_results = evaluator.evaluate_training_set(dataset, num_samples=None)
    
    # Visualize samples
    print("\nüé® Visualizing sample predictions...")
    visualizer.show_sample_predictions(dataset, num_samples=3, verbose=False)
    
    # Generate test submission if test data exists
    if os.path.exists(config.test_challenges):
        test_challenges = data_loader.load_test_data()
        submission = evaluator.generate_submission(test_challenges)
        evaluator.save_submission(submission, config.output_path)
    
    # ==================================================
    # FINAL SUMMARY
    # ==================================================
    
    print("\n" + "="*80)
    print("üèÅ ORCAWHISKEY V1 - EXECUTION COMPLETE")
    print("="*80)
    print(f"Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"\nüìä FINAL RESULTS:")
    print(f"   System Accuracy: {eval_results['accuracy']*100:.2f}%")
    print(f"   Agent A Accuracy: {eval_results['accuracy_a']*100:.2f}%")
    print(f"   Agent B Accuracy: {eval_results['accuracy_b']*100:.2f}%")
    print(f"   Tasks Evaluated: {eval_results['total']}")
    print(f"\nüéØ Target: 45-55% accuracy")
    if eval_results['accuracy'] >= 0.45:
        print("   ‚úÖ TARGET ACHIEVED!")
    else:
        print(f"   ‚ö†Ô∏è  Below target (delta: {(0.45 - eval_results['accuracy'])*100:.1f}%)")
    print("="*80 + "\n")
    
    return eval_results


# ==================================================
# EXECUTION
# ==================================================

# Usage banner, written as a single joined block. Each list element is
# exactly one of the lines the cell prints.
print("\n".join([
    "üöÄ OrcaWhiskey v1 ready to run!",
    "\nTo execute:",
    "   main(train_full=False, run_phases='all')  # Test run (100 tasks)",
    "   main(train_full=True, run_phases='all')   # Full training",
    "   main(train_full=True, run_phases='eval_only')  # Evaluation only",
    "\n" + "=" * 60,
    "‚ö†Ô∏è  REMINDER: This will take 6+ hours for full training",
    "=" * 60 + "\n",
]))

In [None]:
# CELL 11: EVALUATION & SUBMISSION
# Lines: ~300
# Purpose: Generate predictions for test set and create submission.json

class Evaluator:
    """
    Evaluate the system and generate the submission.

    Submission format (ARC Prize): each task id maps to a list with one entry
    per test input, in the order of ``task_data['test']``. Each entry holds
    two attempt grids:
    {
      "task_id": [
        {"attempt_1": [[grid]], "attempt_2": [[grid]]},   # test input 0
        ...
      ],
      ...
    }
    """

    # Placeholder used when no usable prediction exists, so every test input
    # still receives a well-formed grid.
    _FALLBACK_GRID = [[0]]
    
    def __init__(self, config: "OrcaWhiskeyConfig"):
        # String (forward-reference) annotation: this cell is not in numeric
        # order in the notebook, so the config class need not exist yet.
        self.config = config

    @staticmethod
    def _to_grid(prediction: Any) -> List[List[int]]:
        """Convert a prediction to a JSON-serialisable list of lists of ints.

        Accepts nested lists, numpy arrays and torch tensors (on any device).
        ``None``, empty, or non-2-D predictions become ``[[0]]``.

        Raises:
            ValueError: if ``prediction`` is ragged and numpy cannot build an array.
        """
        if prediction is None:
            return [row[:] for row in Evaluator._FALLBACK_GRID]
        if isinstance(prediction, torch.Tensor):
            prediction = prediction.detach().cpu().numpy()
        arr = np.asarray(prediction)
        if arr.ndim != 2 or arr.size == 0:
            return [row[:] for row in Evaluator._FALLBACK_GRID]
        # tolist() yields builtin ints; numpy integers are not JSON-serialisable.
        return arr.astype(int).tolist()

    @staticmethod
    def _format_task_entry(grids: List[List[List[int]]]) -> List[Dict[str, List[List[int]]]]:
        """Build one task's submission value: an attempts dict per test input.

        Both attempts currently carry the same grid.
        """
        return [{'attempt_1': grid, 'attempt_2': grid} for grid in grids]
        
    def evaluate_training_set(self, dataset: "ARCDataset", num_samples: Optional[int] = None) -> Dict[str, Any]:
        """
        Evaluate on training set (where we have ground truth).

        Only the FIRST test input of each task is scored. Tasks without a
        solution are skipped without being executed. ``num_samples`` of
        None (or 0) means every task. Uses the module-level agent_a, agent_b,
        vae_mediator and executor, and leaves the three models in eval mode.

        Returns:
            dict with 'accuracy', 'accuracy_a', 'accuracy_b' (fractions of
            scored tasks), 'correct', 'total' and 'vae_decisions' counts.
        """
        print("\n" + "="*60)
        print("üìä EVALUATING ON TRAINING SET")
        print("="*60 + "\n")
        
        agent_a.eval()
        agent_b.eval()
        vae_mediator.eval()
        
        total = 0
        correct = 0
        correct_a = 0
        correct_b = 0
        vae_decisions = defaultdict(int)
        
        samples = num_samples if num_samples else len(dataset)
        
        for i in range(min(samples, len(dataset))):
            sample = dataset[i]
            
            ground_truth = sample['test_solutions'][0] if sample['test_solutions'] else None
            if ground_truth is None:
                # Nothing to score against; don't spend a forward pass on it.
                continue
            
            # Solve
            with torch.no_grad():
                result = executor.execute(
                    sample['train_inputs'],
                    sample['train_outputs'],
                    sample['test_inputs'][0],
                    verbose=False
                )
            
            prediction = result['prediction']
            
            # Check accuracy
            if np.array_equal(prediction, ground_truth):
                correct += 1
            
            if np.array_equal(result['result_a']['prediction'], ground_truth):
                correct_a += 1
            
            if np.array_equal(result['result_b']['prediction'], ground_truth):
                correct_b += 1
            
            vae_decisions[result['vae_result']['decision']] += 1
            
            total += 1
            
            if (i + 1) % 100 == 0:
                # max(): guard against zero scored tasks so far.
                print(f"   Evaluated {i+1}/{samples} tasks | Accuracy: {correct/max(total, 1)*100:.1f}%")
        
        accuracy = correct / max(total, 1)
        accuracy_a = correct_a / max(total, 1)
        accuracy_b = correct_b / max(total, 1)
        
        print("\n" + "="*60)
        print("üìà TRAINING SET RESULTS")
        print("="*60)
        print(f"Total tasks: {total}")
        print(f"System Accuracy: {accuracy*100:.2f}% ({correct}/{total})")
        print(f"Agent A Accuracy: {accuracy_a*100:.2f}% ({correct_a}/{total})")
        print(f"Agent B Accuracy: {accuracy_b*100:.2f}% ({correct_b}/{total})")
        print("\nVAE Decisions:")
        for decision, count in vae_decisions.items():
            print(f"   {decision}: {count} ({count/total*100:.1f}%)")
        print("="*60 + "\n")
        
        return {
            'accuracy': accuracy,
            'accuracy_a': accuracy_a,
            'accuracy_b': accuracy_b,
            'correct': correct,
            'total': total,
            'vae_decisions': dict(vae_decisions)
        }
    
    def generate_submission(self, test_challenges: Dict) -> Dict[str, Any]:
        """
        Generate the submission dict for the test set.

        Every task id in ``test_challenges`` appears in the result. If
        prediction fails for a test input, a warning is printed and the
        ``[[0]]`` placeholder is submitted for that input, so one bad task
        cannot invalidate the whole submission.
        """
        print("\n" + "="*60)
        print("üöÄ GENERATING TEST PREDICTIONS")
        print("="*60 + "\n")
        
        agent_a.eval()
        agent_b.eval()
        vae_mediator.eval()
        
        submission = {}
        
        for task_id, task_data in test_challenges.items():
            train_inputs = [pair['input'] for pair in task_data['train']]
            train_outputs = [pair['output'] for pair in task_data['train']]
            test_inputs = [pair['input'] for pair in task_data['test']]
            
            task_predictions = []
            
            for test_idx, test_input in enumerate(test_inputs):
                try:
                    with torch.no_grad():
                        result = executor.execute(
                            train_inputs,
                            train_outputs,
                            test_input,
                            verbose=False
                        )
                    task_predictions.append(self._to_grid(result['prediction']))
                except Exception as e:
                    # Best effort: keep the submission complete and loadable.
                    print(f"   WARNING: prediction failed for {task_id}[{test_idx}]: {e}")
                    task_predictions.append([row[:] for row in self._FALLBACK_GRID])
            
            # One {attempt_1, attempt_2} dict per test input (same grid for now;
            # agent_a / agent_b predictions could supply a diverse attempt_2).
            submission[task_id] = self._format_task_entry(task_predictions)
            
            if len(submission) % 10 == 0:
                print(f"   Generated predictions for {len(submission)} tasks...")
        
        print(f"\n‚úÖ Generated predictions for {len(submission)} tasks")
        print("="*60 + "\n")
        
        return submission
    
    def save_submission(self, submission: Dict, path: str):
        """Write ``submission`` as indented JSON to ``path`` and report its size."""
        with open(path, 'w') as f:
            json.dump(submission, f, indent=2)
        
        print(f"üíæ Submission saved to: {path}")
        
        # Check file size
        file_size = os.path.getsize(path) / (1024 * 1024)  # MB
        print(f"   File size: {file_size:.2f} MB")


# Module-level evaluator instance used by main().
evaluator = Evaluator(config)

# One write for the whole status block; the final "\n" plus print's own
# newline produce the trailing blank line.
print(
    "üìä Evaluator initialized\n"
    "   Training eval: accuracy, per-agent breakdown, VAE decisions\n"
    "   Test prediction: attempt_1 and attempt_2 for each task\n"
    "   Output format: ARC submission JSON\n"
)

In [None]:
# CELL 10: TRAINING PHASES 2-3 - COLLABORATIVE & ADVERSARIAL
# Lines: ~500
# Purpose: Train agents to work together, then diversify

class Phase2Trainer:
    """
    Phase 2: Collaborative Training (config.phase2_epochs epochs)
    
    Train agents + VAE + bridge together:
    - Agents collaborate via observation protocol (executor.execute)
    - VAE learns to arbitrate
    - Collaborative reward: the agents' combined loss is halved when EITHER
      agent's prediction matches the target

    Each task is split as follows: all but the last demonstration pair serve
    as context, and the last pair's output is the target. Tasks with fewer
    than two demonstration pairs are skipped.
    """
    
    def __init__(self, config: "OrcaWhiskeyConfig"):
        # String (forward-reference) annotation: the notebook's cells are not
        # in numeric order, so the config class need not exist yet.
        self.config = config
        
        # Optimizers (all components now); agents fine-tune at half their base LR.
        self.opt_a = optim.AdamW(agent_a.parameters(), lr=config.learning_rate_hrm * 0.5, weight_decay=config.weight_decay)
        self.opt_b = optim.AdamW(agent_b.parameters(), lr=config.learning_rate_llm * 0.5, weight_decay=config.weight_decay)
        self.opt_vae = optim.AdamW(vae_mediator.parameters(), lr=config.learning_rate_vae, weight_decay=config.weight_decay)
        # NOTE(review): bridge LR is hard-coded and has no scheduler, unlike the others.
        self.opt_bridge = optim.AdamW(orchestrator.bridge.parameters(), lr=1e-4, weight_decay=config.weight_decay)
        
        # Schedulers (stepped once per epoch)
        self.sched_a = optim.lr_scheduler.CosineAnnealingLR(self.opt_a, T_max=config.phase2_epochs)
        self.sched_b = optim.lr_scheduler.CosineAnnealingLR(self.opt_b, T_max=config.phase2_epochs)
        self.sched_vae = optim.lr_scheduler.CosineAnnealingLR(self.opt_vae, T_max=config.phase2_epochs)
        
        # Per-epoch history, one entry appended per train_epoch call.
        self.metrics = {
            'collaborative_loss': [],
            'vae_loss': [],
            'system_accuracy': [],
            'agent_a_accuracy': [],
            'agent_b_accuracy': []
        }
    
    def train_epoch(self, dataloader: DataLoader, epoch: int, verbose: bool = True) -> Dict[str, float]:
        """Train one collaborative epoch.

        Expects each batch to map 'train_inputs' / 'train_outputs' to a
        one-element list holding the task's grid lists.

        Returns:
            ``{'loss': mean collaborative loss, 'acc': system accuracy}`` over
            the tasks actually trained on.
        """
        
        agent_a.train()
        agent_b.train()
        vae_mediator.train()
        
        epoch_loss = 0.0
        epoch_vae_loss = 0.0
        correct_system = 0
        correct_a = 0
        correct_b = 0
        total = 0
        
        for batch_idx, batch in enumerate(dataloader):
            train_inputs = batch['train_inputs'][0]
            train_outputs = batch['train_outputs'][0]
            if len(train_inputs) < 2:
                # Need >= 1 context pair plus the held-out last pair; otherwise
                # the executor would be asked to infer a rule from no examples.
                continue
            target = train_outputs[-1]
            
            # ==========================================
            # COLLABORATIVE FORWARD PASS
            # ==========================================
            
            # Execute with observation protocol (last pair held out as the query)
            result = executor.execute(
                train_inputs[:-1],
                train_outputs[:-1],
                train_inputs[-1],
                verbose=False
            )
            
            pred_a = result['result_a']['prediction']
            pred_b = result['result_b']['prediction']
            pred_final = result['prediction']
            
            # ==========================================
            # COMPUTE LOSSES
            # ==========================================
            
            # Collaborative loss: reward if EITHER correct
            loss_a = agent_a.compute_loss(train_inputs, train_outputs, target)
            loss_b = agent_b.compute_loss(train_inputs, train_outputs, target)
            
            # Check if either is correct (collaborative reward)
            correct_a_this = np.array_equal(pred_a, target)
            correct_b_this = np.array_equal(pred_b, target)
            
            if correct_a_this or correct_b_this:
                # At least one correct → reduce loss for both (collaboration bonus)
                collaborative_bonus = 0.5
            else:
                collaborative_bonus = 1.0
            
            collaborative_loss = (loss_a + loss_b) * collaborative_bonus
            
            # VAE arbitration loss
            vae_loss = vae_mediator.compute_loss(pred_a, pred_b, target)
            
            total_loss = collaborative_loss + vae_loss
            
            # ==========================================
            # BACKWARD PASS
            # ==========================================
            
            self.opt_a.zero_grad()
            self.opt_b.zero_grad()
            self.opt_vae.zero_grad()
            self.opt_bridge.zero_grad()
            
            total_loss.backward()
            
            torch.nn.utils.clip_grad_norm_(agent_a.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(agent_b.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(vae_mediator.parameters(), 1.0)
            
            self.opt_a.step()
            self.opt_b.step()
            self.opt_vae.step()
            self.opt_bridge.step()
            
            # ==========================================
            # METRICS
            # ==========================================
            
            epoch_loss += collaborative_loss.item()
            epoch_vae_loss += vae_loss.item()
            
            if correct_a_this:
                correct_a += 1
            if correct_b_this:
                correct_b += 1
            if np.array_equal(pred_final, target):
                correct_system += 1
            
            total += 1
            
            if verbose and (batch_idx + 1) % 50 == 0:
                # Average over tasks actually trained on (skipped ones excluded).
                seen = max(total, 1)
                print(f"   Batch {batch_idx+1} | Loss: {epoch_loss/seen:.4f} | VAE: {epoch_vae_loss/seen:.4f}")
        
        # Metrics
        avg_loss = epoch_loss / max(total, 1)
        avg_vae_loss = epoch_vae_loss / max(total, 1)
        acc_system = correct_system / max(total, 1)
        acc_a = correct_a / max(total, 1)
        acc_b = correct_b / max(total, 1)
        
        self.metrics['collaborative_loss'].append(avg_loss)
        self.metrics['vae_loss'].append(avg_vae_loss)
        self.metrics['system_accuracy'].append(acc_system)
        self.metrics['agent_a_accuracy'].append(acc_a)
        self.metrics['agent_b_accuracy'].append(acc_b)
        
        self.sched_a.step()
        self.sched_b.step()
        self.sched_vae.step()
        
        if verbose:
            print(f"\nüìä Epoch {epoch+1} Summary:")
            print(f"   Collaborative Loss: {avg_loss:.4f}")
            print(f"   VAE Loss: {avg_vae_loss:.4f}")
            print(f"   System Accuracy: {acc_system*100:.1f}%")
            print(f"   Agent A: {acc_a*100:.1f}% | Agent B: {acc_b*100:.1f}%")
        
        return {'loss': avg_loss, 'acc': acc_system}
    
    def train(self, dataloader: DataLoader) -> Dict[str, Any]:
        """Run Phase 2 for up to config.phase2_epochs epochs.

        Stops early once an epoch's system accuracy exceeds 55%.

        Returns:
            ``self.metrics`` (per-epoch history lists).
        """
        print("\n" + "="*60)
        print("ü§ù PHASE 2: COLLABORATIVE TRAINING")
        print("="*60)
        print(f"Epochs: {self.config.phase2_epochs}")
        print("Training: Agents + VAE + Bridge together")
        print("Reward: Collaborative (bonus if either correct)")
        print("="*60 + "\n")
        
        for epoch in range(self.config.phase2_epochs):
            print(f"\nüìÖ Epoch {epoch+1}/{self.config.phase2_epochs}")
            print("-" * 60)
            
            metrics = self.train_epoch(dataloader, epoch, verbose=True)
            
            if metrics['acc'] > 0.55:
                print(f"\nüéâ System > 55% accuracy! Target reached!")
                break
        
        print("\n" + "="*60)
        print("‚úÖ PHASE 2 COMPLETE")
        print("="*60)
        if self.metrics['system_accuracy']:
            print(f"Final System Accuracy: {self.metrics['system_accuracy'][-1]*100:.1f}%")
        else:
            # phase2_epochs == 0: no history to report (previously an IndexError).
            print("Final System Accuracy: n/a (no epochs run)")
        print("="*60 + "\n")
        
        return self.metrics


class Phase3Trainer:
    """
    Phase 3: Adversarial Diversity (``config.phase3_epochs`` epochs)

    Prevent agents from collapsing to same solution:
    - Penalize when agents make the SAME mistake (both losses scaled by
      ``SAME_MISTAKE_PENALTY``)
    - Encourage diverse reasoning paths
    - Maintain ensemble benefit

    Operates on the module-level ``agent_a``, ``agent_b`` and ``executor`` objects.
    """

    # Loss multiplier applied when both agents return the identical wrong grid.
    SAME_MISTAKE_PENALTY = 1.5

    def __init__(self, config: OrcaWhiskeyConfig):
        self.config = config
        # Fine-tuning phase: both agents train at 10% of their base learning rates.
        self.opt_a = optim.AdamW(agent_a.parameters(), lr=config.learning_rate_hrm * 0.1, weight_decay=config.weight_decay)
        self.opt_b = optim.AdamW(agent_b.parameters(), lr=config.learning_rate_llm * 0.1, weight_decay=config.weight_decay)

        self.metrics = {'diversity_score': [], 'system_accuracy': []}

    def train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train both agents for one epoch with a penalty on shared mistakes.

        For every batch the last training pair is held out as the target and the
        preceding pairs are used as context. The diversity score is the fraction
        of batches in which the agents did NOT return the identical wrong answer;
        system accuracy is measured through ``executor.execute``.

        Returns:
            ``{'diversity': float, 'acc': float}`` for this epoch.
        """
        agent_a.train()
        agent_b.train()

        total_diversity = 0.0
        correct = 0
        total = 0

        for batch in dataloader:
            train_inputs = batch['train_inputs'][0]
            train_outputs = batch['train_outputs'][0]
            target = train_outputs[-1]

            # Predictions are only compared, never back-propagated: skip autograd.
            with torch.no_grad():
                pred_a = agent_a.forward(train_inputs[:-1], train_outputs[:-1], train_inputs[-1])['prediction']
                pred_b = agent_b.forward(train_inputs[:-1], train_outputs[:-1], train_inputs[-1])['prediction']

            both_wrong = not np.array_equal(pred_a, target) and not np.array_equal(pred_b, target)
            # Agreement only counts as a shared *mistake* when the common answer is wrong;
            # agreeing on the correct answer must not be penalized or scored as non-diverse.
            same_mistake = both_wrong and np.array_equal(pred_a, pred_b)

            scale = self.SAME_MISTAKE_PENALTY if same_mistake else 1.0
            loss_a = agent_a.compute_loss(train_inputs, train_outputs, target) * scale
            loss_b = agent_b.compute_loss(train_inputs, train_outputs, target) * scale

            # Backward
            self.opt_a.zero_grad()
            self.opt_b.zero_grad()

            loss_a.backward()
            loss_b.backward()

            # Same gradient clipping as Phase 1.
            torch.nn.utils.clip_grad_norm_(agent_a.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(agent_b.parameters(), 1.0)

            self.opt_a.step()
            self.opt_b.step()

            # Metrics
            total_diversity += 0.0 if same_mistake else 1.0

            with torch.no_grad():
                result = executor.execute(train_inputs[:-1], train_outputs[:-1], train_inputs[-1], verbose=False)
            if np.array_equal(result['prediction'], target):
                correct += 1

            total += 1

        avg_diversity = total_diversity / max(total, 1)
        acc = correct / max(total, 1)

        self.metrics['diversity_score'].append(avg_diversity)
        self.metrics['system_accuracy'].append(acc)

        print(f"   Epoch {epoch+1}: Diversity {avg_diversity:.2f} | Acc {acc*100:.1f}%")

        return {'diversity': avg_diversity, 'acc': acc}

    def train(self, dataloader: DataLoader) -> Dict[str, Any]:
        """Run Phase 3 for ``config.phase3_epochs`` epochs and return ``self.metrics``."""
        print("\n" + "="*60)
        print("‚öîÔ∏è  PHASE 3: ADVERSARIAL DIVERSITY")
        print("="*60)
        print(f"Epochs: {self.config.phase3_epochs}")
        print("Goal: Prevent agents from making same mistakes")
        print("="*60 + "\n")

        for epoch in range(self.config.phase3_epochs):
            self.train_epoch(dataloader, epoch)

        print("\n" + "="*60)
        print("‚úÖ PHASE 3 COMPLETE")
        print("="*60)
        # With phase3_epochs == 0 there is no history; indexing [-1] would raise.
        if self.metrics['diversity_score']:
            print(f"Final Diversity: {self.metrics['diversity_score'][-1]:.2f}")
            print(f"Final Accuracy: {self.metrics['system_accuracy'][-1]*100:.1f}%")
        else:
            print("Final Diversity / Accuracy: n/a (no epochs run)")
        print("="*60 + "\n")

        return self.metrics


print("ü§ù Phase 2 & 3 Trainers initialized")
print("   Phase 2: Collaborative training with VAE arbitration")
print("   Phase 3: Adversarial diversity to prevent collapse")
print("   Total epochs: 100 + 30 = 130")
print()

In [None]:
# CELL 9: TRAINING PHASE 1 - INDIVIDUAL AGENT TRAINING
# Lines: ~400
# Purpose: Train each agent independently before collaboration

class Phase1Trainer:
    """
    Phase 1: Individual Training (``config.phase1_epochs`` epochs)

    Train each agent to solve tasks independently:
    - Agent A learns visual transformations
    - Agent B learns linguistic patterns
    - VAE not trained yet (will learn arbitration in Phase 2)

    Uses the module-level ``agent_a`` and ``agent_b`` models. Each agent gets its
    own AdamW optimizer plus a cosine-annealing schedule stepped once per epoch.
    """

    def __init__(self, config: OrcaWhiskeyConfig):
        self.config = config

        # Optimizers
        self.opt_a = optim.AdamW(
            agent_a.parameters(),
            lr=config.learning_rate_hrm,
            weight_decay=config.weight_decay
        )

        self.opt_b = optim.AdamW(
            agent_b.parameters(),
            lr=config.learning_rate_llm,
            weight_decay=config.weight_decay
        )

        # Learning rate schedulers (one cosine cycle over the whole phase)
        self.sched_a = optim.lr_scheduler.CosineAnnealingLR(
            self.opt_a,
            T_max=config.phase1_epochs
        )

        self.sched_b = optim.lr_scheduler.CosineAnnealingLR(
            self.opt_b,
            T_max=config.phase1_epochs
        )

        # Per-epoch metric history
        self.metrics = {
            'agent_a_loss': [],
            'agent_b_loss': [],
            'agent_a_accuracy': [],
            'agent_b_accuracy': []
        }

    @staticmethod
    def _train_agent_step(agent, optimizer, label: str, task_id, train_inputs,
                          train_outputs, target, verbose: bool) -> Tuple[Optional[float], bool]:
        """Run one optimisation step for ``agent`` and check its held-out prediction.

        Returns:
            ``(loss_value, correct)``. ``loss_value`` is None when the loss step
            itself failed; ``correct`` is False on any failure. Exceptions are
            reported (when ``verbose``) rather than raised so one bad task does not
            abort the epoch.
        """
        loss_value = None
        try:
            loss = agent.compute_loss(train_inputs, train_outputs, target)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 1.0)
            optimizer.step()

            loss_value = loss.item()

            # Accuracy: predict the last training output from the preceding pairs.
            with torch.no_grad():
                pred = agent.forward(train_inputs[:-1], train_outputs[:-1], train_inputs[-1])
            return loss_value, bool(np.array_equal(pred['prediction'], target))

        except Exception as e:
            if verbose:
                print(f"‚ö†Ô∏è  Agent {label} failed on task {task_id}: {e}")
            return loss_value, False

    def train_epoch(self, dataloader: DataLoader, epoch: int, verbose: bool = True) -> Dict[str, float]:
        """Train both agents for one epoch.

        Batches whose ``test_solutions`` entry is None or empty are skipped. Average
        losses are taken over the batches where that agent's loss step succeeded;
        accuracies are taken over all processed batches (failures count as wrong).

        Returns:
            ``{'loss_a', 'loss_b', 'acc_a', 'acc_b'}`` for this epoch.
        """

        agent_a.train()
        agent_b.train()

        epoch_loss_a = 0.0
        epoch_loss_b = 0.0
        steps_a = 0  # batches where Agent A produced a loss
        steps_b = 0  # batches where Agent B produced a loss
        correct_a = 0
        correct_b = 0
        total = 0

        for batch_idx, batch in enumerate(dataloader):
            task_id = batch['task_id'][0]
            train_inputs = batch['train_inputs'][0]
            train_outputs = batch['train_outputs'][0]
            test_solutions = batch['test_solutions'][0]

            # Skip if no solution
            if test_solutions is None or len(test_solutions) == 0:
                continue

            # Use last training pair as target for now
            # (In real setting, would use actual test solution)
            target = train_outputs[-1]

            loss_a, ok_a = self._train_agent_step(
                agent_a, self.opt_a, "A", task_id, train_inputs, train_outputs, target, verbose
            )
            if loss_a is not None:
                epoch_loss_a += loss_a
                steps_a += 1
            correct_a += int(ok_a)

            loss_b, ok_b = self._train_agent_step(
                agent_b, self.opt_b, "B", task_id, train_inputs, train_outputs, target, verbose
            )
            if loss_b is not None:
                epoch_loss_b += loss_b
                steps_b += 1
            correct_b += int(ok_b)

            total += 1

            # Progress: running means over successful steps only (skipped batches
            # and failed steps contributed no loss, so they must not dilute it).
            if verbose and (batch_idx + 1) % 50 == 0:
                print(f"   Batch {batch_idx+1}/{len(dataloader)} | "
                      f"Loss A: {epoch_loss_a/max(steps_a, 1):.4f} | "
                      f"Loss B: {epoch_loss_b/max(steps_b, 1):.4f}")

        # Epoch metrics
        avg_loss_a = epoch_loss_a / max(steps_a, 1)
        avg_loss_b = epoch_loss_b / max(steps_b, 1)
        acc_a = correct_a / max(total, 1)
        acc_b = correct_b / max(total, 1)

        self.metrics['agent_a_loss'].append(avg_loss_a)
        self.metrics['agent_b_loss'].append(avg_loss_b)
        self.metrics['agent_a_accuracy'].append(acc_a)
        self.metrics['agent_b_accuracy'].append(acc_b)

        # Step schedulers
        self.sched_a.step()
        self.sched_b.step()

        if verbose:
            print(f"\nüìä Epoch {epoch+1} Summary:")
            print(f"   Agent A: Loss {avg_loss_a:.4f} | Acc {acc_a*100:.1f}%")
            print(f"   Agent B: Loss {avg_loss_b:.4f} | Acc {acc_b*100:.1f}%")
            print(f"   LR: A={self.sched_a.get_last_lr()[0]:.6f}, B={self.sched_b.get_last_lr()[0]:.6f}")

        return {
            'loss_a': avg_loss_a,
            'loss_b': avg_loss_b,
            'acc_a': acc_a,
            'acc_b': acc_b
        }

    def train(self, dataloader: DataLoader, visualize_every: int = 10) -> Dict[str, Any]:
        """Run Phase 1 training.

        Args:
            dataloader: Batch source passed to ``train_epoch``.
            visualize_every: Announce a visualization checkpoint every N epochs;
                0 or None disables it.

        Returns:
            ``self.metrics``. Stops early once both agents exceed 50% accuracy.
        """

        print("\n" + "="*60)
        print("üöÄ PHASE 1: INDIVIDUAL TRAINING")
        print("="*60)
        print(f"Epochs: {self.config.phase1_epochs}")
        print(f"Agent A LR: {self.config.learning_rate_hrm}")
        print(f"Agent B LR: {self.config.learning_rate_llm}")
        print("="*60 + "\n")

        for epoch in range(self.config.phase1_epochs):
            print(f"\nüìÖ Epoch {epoch+1}/{self.config.phase1_epochs}")
            print("-" * 60)

            metrics = self.train_epoch(dataloader, epoch, verbose=True)

            # Visualize samples (guarded so visualize_every=0 cannot divide by zero)
            if visualize_every and (epoch + 1) % visualize_every == 0:
                print(f"\nüé® Visualizing samples at epoch {epoch+1}...")
                # Would call visualizer here, but keeping concise

            # Early stopping check
            if metrics['acc_a'] > 0.5 and metrics['acc_b'] > 0.5:
                print(f"\nüéâ Both agents > 50% accuracy! Early stopping.")
                break

        print("\n" + "="*60)
        print("‚úÖ PHASE 1 COMPLETE")
        print("="*60)
        # With phase1_epochs == 0 there is no history; indexing [-1] would raise.
        if self.metrics['agent_a_loss']:
            print(f"Final Agent A: Loss {self.metrics['agent_a_loss'][-1]:.4f} | Acc {self.metrics['agent_a_accuracy'][-1]*100:.1f}%")
            print(f"Final Agent B: Loss {self.metrics['agent_b_loss'][-1]:.4f} | Acc {self.metrics['agent_b_accuracy'][-1]*100:.1f}%")
        else:
            print("Final Agent A / B: n/a (no epochs run)")
        print("="*60 + "\n")

        return self.metrics


print("üéì Phase 1 Trainer initialized")
print("   Individual agent training: 50 epochs")
print("   Optimizers: AdamW with cosine annealing")
print("   Gradient clipping: 1.0")
print("   Early stopping: if both agents > 50% accuracy")
print()

In [None]:
# CELL 8: VISUALIZATION - GRID RENDERING
# Lines: ~250
# Purpose: Visualize grids with ARC color palette

class ARCVisualizer:
    """
    Visualize ARC grids with official color palette

    User requested: "I want to see samples every so often in the output"

    Colors come from ``config.arc_colors``; cell values are mapped with
    vmin=0 / vmax=9.
    """

    def __init__(self, config: OrcaWhiskeyConfig):
        self.config = config
        self.cmap = ListedColormap(config.arc_colors)

    def plot_grid(self, grid: List[List[int]], title: str = "", ax=None):
        """Draw one grid on ``ax`` and return the axes.

        A new 4x4-inch figure is created when ``ax`` is None. Empty grids are
        drawn as a labelled placeholder instead of raising.
        """
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(4, 4))

        # atleast_2d lets a 1-D row or an empty list be treated as rows x cols.
        arr = np.atleast_2d(np.asarray(grid))

        if arr.size == 0:
            ax.text(0.5, 0.5, "(empty)", ha='center', va='center', transform=ax.transAxes)
        else:
            ax.imshow(arr, cmap=self.cmap, vmin=0, vmax=9)
            # White lines on cell boundaries (imshow centers cell k at coordinate k).
            for i in range(arr.shape[0] + 1):
                ax.axhline(i - 0.5, color='white', linewidth=0.5)
            for j in range(arr.shape[1] + 1):
                ax.axvline(j - 0.5, color='white', linewidth=0.5)

        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

        return ax

    def plot_task(self,
                  train_inputs: List[List[List[int]]],
                  train_outputs: List[List[List[int]]],
                  test_input: List[List[int]],
                  prediction: Optional[List[List[int]]] = None,
                  test_output: Optional[List[List[int]]] = None):
        """
        Plot complete task: training pairs + test

        Layout: one row per training pair (input in column 0, output in column 1).
        The first row additionally shows the test input, then the prediction and
        the ground truth, each only when provided. Unused cells are hidden.

        Returns:
            The matplotlib Figure.
        """
        n_train = len(train_inputs)
        n_rows = max(n_train, 1)

        # Panels placed on the first row, to the right of the two training columns.
        test_panels = [(test_input, "Test Input")]
        if prediction is not None:
            test_panels.append((prediction, "Prediction"))
        if test_output is not None:
            test_panels.append((test_output, "Ground Truth"))
        n_cols = 2 + len(test_panels)

        # squeeze=False keeps ``axes`` 2-D even with a single row.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
        used = set()

        # Training pairs
        for i in range(n_train):
            self.plot_grid(train_inputs[i], f"Train Input {i+1}", axes[i, 0])
            self.plot_grid(train_outputs[i], f"Train Output {i+1}", axes[i, 1])
            used.update({(i, 0), (i, 1)})

        # Test input, prediction, ground truth
        for col, (grid, title) in enumerate(test_panels, start=2):
            self.plot_grid(grid, title, axes[0, col])
            used.add((0, col))

        # Hide every subplot that received no grid
        for i in range(n_rows):
            for j in range(n_cols):
                if (i, j) not in used:
                    axes[i, j].axis('off')

        plt.tight_layout()
        return fig

    def plot_comparison(self,
                       test_input: List[List[int]],
                       pred_a: List[List[int]],
                       conf_a: float,
                       pred_b: List[List[int]],
                       conf_b: float,
                       pred_final: List[List[int]],
                       conf_final: float,
                       decision: str,
                       ground_truth: Optional[List[List[int]]] = None):
        """
        Plot comparison of agent predictions + VAE decision

        One row: test input, Agent A, Agent B, final answer, and (optionally)
        ground truth. Returns the matplotlib Figure.
        """
        n_cols = 5 if ground_truth is not None else 4
        fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 3, 3))

        self.plot_grid(test_input, "Test Input", axes[0])
        self.plot_grid(pred_a, f"Agent A\n({conf_a:.2f})", axes[1])
        self.plot_grid(pred_b, f"Agent B\n({conf_b:.2f})", axes[2])
        self.plot_grid(pred_final, f"Final ({decision})\n({conf_final:.2f})", axes[3])

        if ground_truth is not None:
            self.plot_grid(ground_truth, "Ground Truth", axes[4])

        fig.tight_layout()
        return fig

    def show_sample_predictions(self,
                               dataset: ARCDataset,
                               num_samples: int = 3,
                               verbose: bool = True):
        """
        Show sample predictions during training

        Solves the first ``num_samples`` dataset items with the module-level
        ``executor`` and displays a comparison figure for each.
        """
        print(f"\nüé® Visualizing {num_samples} sample predictions...\n")

        for i in range(min(num_samples, len(dataset))):
            sample = dataset[i]

            # Solve with executor
            result = executor.execute(
                sample['train_inputs'],
                sample['train_outputs'],
                sample['test_inputs'][0],
                verbose=verbose
            )

            # Plot
            fig = self.plot_comparison(
                sample['test_inputs'][0],
                result['result_a']['prediction'],
                result['result_a']['confidence'],
                result['result_b']['prediction'],
                result['result_b']['confidence'],
                result['prediction'],
                result['confidence'],
                result['vae_result']['decision'],
                sample['test_solutions'][0] if sample['test_solutions'] else None
            )

            # Operate on this figure explicitly rather than on pyplot's "current" one.
            fig.suptitle(f"Task: {sample['task_id']}", fontsize=14, fontweight='bold')
            display(fig)
            plt.close(fig)

            if verbose:
                print(f"{'='*60}\n")


# Shared visualizer instance used for sample plots during training
visualizer = ARCVisualizer(config)

# One print of the joined lines emits exactly the same text as one print per line.
print("\n".join([
    "üé® Visualizer initialized",
    "   Color palette: Official ARC colors (10 colors)",
    "   Supports: Task plots, comparison plots, agent predictions",
    "   Displays: Training pairs, test input, predictions, ground truth",
]))
print()

In [None]:
# CELL 7: OBSERVATION PROTOCOL - OMNISCIENT OBSERVATION
# Lines: ~300
# Purpose: Second agent sees EVERYTHING first agent does

@dataclass
class ObservationTrace:
    """
    Record of what one agent was given and what it produced on a task.

    Built so the second agent can be shown the first agent's full context:
    its inputs, any exposed internal states and attention maps, its
    prediction and confidence, and human-readable reasoning steps.
    Every optional field defaults to None (lists default to empty).
    """
    agent_name: str

    # --- task the agent was given ---
    train_inputs: List[List[List[int]]]
    train_outputs: List[List[List[int]]]
    test_input: List[List[int]]

    # --- internal representations, when the agent exposes them ---
    hidden_states: Optional[torch.Tensor] = field(default=None)
    latent_states: Optional[torch.Tensor] = field(default=None)
    attention_maps: Optional[List[torch.Tensor]] = field(default=None)

    # --- what the agent produced ---
    prediction: Optional[List[List[int]]] = field(default=None)
    confidence: Optional[float] = field(default=None)
    reasoning_steps: List[str] = field(default_factory=list)

    # Creation time, taken from time.time() when the trace is built
    timestamp: float = field(default_factory=time.time)


class OmniscientObserver:
    """
    Captures complete observation of agent behavior

    Key insight: Second agent shouldn't just see final output,
    but the ENTIRE REASONING PROCESS

    Traces are appended to ``self.observations`` in call order until ``clear()``.
    """

    def __init__(self):
        self.observations: List[ObservationTrace] = []

    @staticmethod
    def _as_list(value) -> list:
        """Return ``value`` as a new list, treating a missing/None entry as empty."""
        return [] if value is None else list(value)

    @staticmethod
    def _shape_str(prediction) -> str:
        """Format a grid's shape as ``"<rows>x<cols>"``; None or empty gives ``"0x0"``.

        Uses ``len`` rather than truthiness so numpy arrays are handled too.
        """
        if prediction is None or len(prediction) == 0:
            return "0x0"
        return f"{len(prediction)}x{len(prediction[0])}"

    def observe_agent_a(self,
                       train_inputs: List[List[List[int]]],
                       train_outputs: List[List[List[int]]],
                       test_input: List[List[int]],
                       result: Dict[str, Any]) -> ObservationTrace:
        """Capture Agent A's complete reasoning trace.

        Reads ``prediction`` and ``confidence`` (required) and the optional
        ``high_level_state``, ``low_level_state``, ``high_level_attention`` and
        ``low_level_attention`` entries from ``result``. The trace is stored and
        returned.
        """
        high_attn = self._as_list(result.get('high_level_attention'))
        low_attn = self._as_list(result.get('low_level_attention'))

        trace = ObservationTrace(
            agent_name="Agent A (HRM)",
            train_inputs=train_inputs,
            train_outputs=train_outputs,
            test_input=test_input,
            hidden_states=result.get('high_level_state'),
            latent_states=result.get('low_level_state'),
            attention_maps=high_attn + low_attn,
            prediction=result['prediction'],
            confidence=result['confidence'],
            reasoning_steps=[
                f"Processed {len(train_inputs)} examples",
                f"High-level cycles: {len(high_attn)}",
                f"Low-level cycles: {len(low_attn)}",
                f"Prediction shape: {self._shape_str(result['prediction'])}",
                f"Confidence: {result['confidence']:.3f}"
            ]
        )

        self.observations.append(trace)
        return trace

    def observe_agent_b(self,
                       train_inputs: List[List[List[int]]],
                       train_outputs: List[List[List[int]]],
                       test_input: List[List[int]],
                       result: Dict[str, Any]) -> ObservationTrace:
        """Capture Agent B's complete reasoning trace.

        Reads ``prediction`` and ``confidence`` (required) and the optional
        ``reasoning_traces`` (last entry becomes ``hidden_states``),
        ``attention_maps``, ``generated_text`` and ``prompt`` entries. The trace
        is stored and returned.
        """
        traces = result.get('reasoning_traces')
        # Treat a missing *or* None text entry as empty so slicing/len cannot fail.
        generated = result.get('generated_text') or ''
        prompt = result.get('prompt') or ''

        trace = ObservationTrace(
            agent_name="Agent B (LLM)",
            train_inputs=train_inputs,
            train_outputs=train_outputs,
            test_input=test_input,
            hidden_states=traces[-1] if traces else None,
            attention_maps=result.get('attention_maps'),
            prediction=result['prediction'],
            confidence=result['confidence'],
            reasoning_steps=[
                f"Generated text: {generated[:100]}...",
                f"Prompt length: {len(prompt)} chars",
                f"Prediction shape: {self._shape_str(result['prediction'])}",
                f"Confidence: {result['confidence']:.3f}"
            ]
        )

        self.observations.append(trace)
        return trace

    def get_observation_for_second_agent(self, first_agent_name: str) -> Optional[ObservationTrace]:
        """
        Second agent gets complete observation of first agent

        Returns the most recent trace whose ``agent_name`` contains
        ``first_agent_name`` (substring match), or None if there is none.
        """
        return next(
            (obs for obs in reversed(self.observations) if first_agent_name in obs.agent_name),
            None,
        )

    def clear(self):
        """Clear observation history (rebinds to a fresh list)."""
        self.observations = []


class CollaborativeExecutor:
    """
    Execute collaborative reasoning with omniscient observation

    Manages:
    1. Coin toss (who goes first)
    2. First agent solves blind
    3. Second agent solves with omniscient observation
    4. VAE mediates

    NOTE: the first agent's observation is captured and (optionally) reported,
    but it is not yet passed into the second agent's forward call.
    """

    def __init__(self, orchestrator: EpistemicOrchestrator):
        self.orchestrator = orchestrator
        self.observer = OmniscientObserver()

    def _run_agent(self,
                   agent_name: str,
                   train_inputs: List[List[List[int]]],
                   train_outputs: List[List[List[int]]],
                   test_input: List[List[int]]) -> Tuple[Dict[str, Any], ObservationTrace]:
        """Run ``agent_name`` ('agent_a' or anything else for Agent B) and record its trace.

        Returns:
            ``(result, trace)``: the agent's raw ``forward`` output and the
            ObservationTrace the observer stored for it.
        """
        if agent_name == 'agent_a':
            result = self.orchestrator.agent_a.forward(
                train_inputs,
                train_outputs,
                test_input,
                return_traces=True
            )
            trace = self.observer.observe_agent_a(train_inputs, train_outputs, test_input, result)
        else:
            result = self.orchestrator.agent_b.forward(
                train_inputs,
                train_outputs,
                test_input,
                return_traces=True
            )
            trace = self.observer.observe_agent_b(train_inputs, train_outputs, test_input, result)
        return result, trace

    def execute(self,
                train_inputs: List[List[List[int]]],
                train_outputs: List[List[List[int]]],
                test_input: List[List[int]],
                verbose: bool = False) -> Dict[str, Any]:
        """
        Execute full collaborative solve with observation protocol

        Returns a dict with the VAE's final ``prediction`` and ``confidence``, the
        agent order, both observation traces, the raw VAE result, and each agent's
        raw result keyed as ``result_a`` / ``result_b`` regardless of order.
        """

        # Clear previous observations
        self.observer.clear()

        # Coin toss
        first_agent = random.choice(['agent_a', 'agent_b'])
        second_agent = 'agent_b' if first_agent == 'agent_a' else 'agent_a'

        if verbose:
            print(f"üé≤ Coin toss: {first_agent} goes first, {second_agent} observes")

        # ==========================================
        # FIRST AGENT (BLIND)
        # ==========================================

        result_first, trace_first = self._run_agent(first_agent, train_inputs, train_outputs, test_input)

        if verbose:
            print(f"‚úÖ {trace_first.agent_name} completed:")
            for step in trace_first.reasoning_steps:
                print(f"   {step}")

        # ==========================================
        # SECOND AGENT (OMNISCIENT)
        # ==========================================

        # Get observation of first agent
        observation = self.observer.get_observation_for_second_agent(trace_first.agent_name)

        if verbose and observation is not None:
            print(f"\nüëÅÔ∏è  {second_agent} observing {first_agent}:")
            print(f"   Sees: {len(observation.reasoning_steps)} reasoning steps")
            if observation.hidden_states is not None:
                print(f"   Sees: Hidden states shape {observation.hidden_states.shape}")
            if observation.attention_maps:
                print(f"   Sees: {len(observation.attention_maps)} attention maps")
            # Guard against a missing prediction/confidence instead of crashing in a log line.
            first_pred = observation.prediction
            rows = len(first_pred) if first_pred is not None else 0
            cols = len(first_pred[0]) if rows else 0
            print(f"   Sees: First prediction shape {rows}x{cols}")
            if observation.confidence is not None:
                print(f"   Sees: First confidence {observation.confidence:.3f}")

        # Second agent forward pass
        # (In theory, would pass observation to agent, but keeping simple for now)
        result_second, trace_second = self._run_agent(second_agent, train_inputs, train_outputs, test_input)

        if verbose:
            print(f"\n‚úÖ {trace_second.agent_name} completed (with observation):")
            for step in trace_second.reasoning_steps:
                print(f"   {step}")

        # ==========================================
        # VAE MEDIATION
        # ==========================================

        # Map results back to fixed agent identities for the VAE
        if first_agent == 'agent_a':
            result_a, result_b = result_first, result_second
        else:
            result_a, result_b = result_second, result_first

        vae_result = self.orchestrator.vae.forward(
            result_a['prediction'],
            result_a['confidence'],
            result_b['prediction'],
            result_b['confidence'],
            return_traces=True
        )

        if verbose:
            print(f"\n‚öñÔ∏è  VAE Mediation:")
            print(f"   Decision: {vae_result['decision']}")
            print(f"   Votes: A={vae_result['votes']['agent_a']:.2f}, B={vae_result['votes']['agent_b']:.2f}, VAE={vae_result['votes']['vae']:.2f}")
            print(f"   Agreement: {vae_result['agreement']}")
            print(f"   Final confidence: {vae_result['confidence']:.3f}")

        return {
            'prediction': vae_result['prediction'],
            'confidence': vae_result['confidence'],
            'first_agent': first_agent,
            'second_agent': second_agent,
            'observation_trace_first': trace_first,
            'observation_trace_second': trace_second,
            'vae_result': vae_result,
            'result_a': result_a,
            'result_b': result_b
        }


# Build the collaborative executor on top of the global orchestrator
executor = CollaborativeExecutor(orchestrator)

# One print of the joined lines emits exactly the same text as one print per line.
print("\n".join([
    "üëÅÔ∏è  Observation Protocol initialized",
    "   Omniscient observation: Second agent sees ALL first agent state",
    "   Observation includes: inputs, hidden states, attention, predictions, confidence, reasoning",
    "   Coin toss: Random agent order prevents bias",
    "   Execution: First (blind) ‚Üí Second (omniscient) ‚Üí VAE (arbitration)",
]))
print()

In [None]:
# CELL 6: EPISTEMIC LAYER - 11-MODE REASONING FRAMEWORK
# Lines: ~500
# Purpose: Cognitive orchestration with full epistemic reasoning

from dataclasses import dataclass
from enum import Enum

class ReasoningMode(Enum):
    """The eleven epistemic reasoning modes.

    Each member's value is its lowercase name. Members are listed in two
    groups: modes that build an answer, then modes that calibrate trust in it.
    """
    # --- building an answer ---
    INDUCTION = "induction"        # bottom-up: examples -> rules
    DEDUCTION = "deduction"        # top-down: rules -> application
    ABSTRACTION = "abstraction"    # isolate the core pattern
    INFERENCE = "inference"        # bridge gaps in the evidence
    ASSUMPTION = "assumption"      # adopt safe starting points
    REASONING = "reasoning"        # assemble logical chains

    # --- calibrating trust in it ---
    SKEPTICISM = "skepticism"      # question conclusions
    DOUBT = "doubt"                # challenge one's own answer
    FEAR = "fear"                  # known dangers and unknowns
    HUMILITY = "humility"          # acknowledge limits
    CONFIDENCE = "confidence"      # stay firm when justified


@dataclass
class EpistemicState:
    """Mutable bookkeeping for the epistemic reasoning done while solving a task.

    Every list/dict field receives a fresh container per instance; scalar
    fields start at neutral defaults (0 examples, 0.5 confidence).
    """
    # Induction: patterns seen and how many training examples support them
    patterns_observed: List[str] = field(default_factory=list)
    num_examples: int = field(default=0)

    # Deduction: rules that were applied
    rules_applied: List[str] = field(default_factory=list)

    # Abstraction: the distilled core transformation, if one was identified
    core_transformation: Optional[str] = field(default=None)

    # Inference: gaps that had to be bridged
    gaps_bridged: List[str] = field(default_factory=list)

    # Assumptions: what was assumed and whether each assumption held up
    assumptions_made: List[str] = field(default_factory=list)
    assumptions_tested: Dict[str, bool] = field(default_factory=dict)

    # Reasoning: the step-by-step chain
    reasoning_steps: List[str] = field(default_factory=list)

    # Skepticism: doubts raised along the way
    doubts_raised: List[str] = field(default_factory=list)

    # Fear: known dangers, known unknowns, and a flag for unknown unknowns
    known_dangers: List[str] = field(default_factory=list)
    known_unknowns: List[str] = field(default_factory=list)
    fears_unknown_unknowns: bool = field(default=True)

    # Humility vs. confidence: raw and final confidence plus (reason, delta) adjustments
    confidence_raw: float = field(default=0.5)
    confidence_final: float = field(default=0.5)
    confidence_adjustments: List[Tuple[str, float]] = field(default_factory=list)


class LatentSpaceBridge(nn.Module):
    """
    Bridge between HRM's visual latent space and the LLM's linguistic space.

    Instead of translating through text, latents are projected through a shared
    ``bridge_dim`` space in either direction. Every hop is
    Linear -> LayerNorm -> GELU.
    """

    def __init__(self, hrm_dim: int, llm_dim: int, bridge_dim: int = 512):
        super().__init__()

        def projection(in_features: int, out_features: int) -> nn.Sequential:
            # One hop: Linear, then LayerNorm and GELU on the output width.
            return nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.GELU(),
            )

        # Submodules are created (and thus randomly initialised) in this fixed order.
        self.hrm_to_bridge = projection(hrm_dim, bridge_dim)
        self.bridge_to_llm = projection(bridge_dim, llm_dim)
        self.llm_to_bridge = projection(llm_dim, bridge_dim)
        self.bridge_to_hrm = projection(bridge_dim, hrm_dim)

    def hrm_to_llm_space(self, hrm_latent: torch.Tensor) -> torch.Tensor:
        """Map HRM latents (last dim ``hrm_dim``) into LLM space (last dim ``llm_dim``)."""
        return self.bridge_to_llm(self.hrm_to_bridge(hrm_latent))

    def llm_to_hrm_space(self, llm_hidden: torch.Tensor) -> torch.Tensor:
        """Map LLM hidden states (last dim ``llm_dim``) into HRM space (last dim ``hrm_dim``)."""
        return self.bridge_to_hrm(self.llm_to_bridge(llm_hidden))


class EpistemicOrchestrator:
    """
    Orchestrates epistemic reasoning across the three-agent system

    Implements all 11 reasoning modes + latent-space communication.

    ``solve_task`` runs four steps:
      1. A coin toss (``random.choice``) decides which agent reasons first.
      2. The second agent runs after the first. The first agent's latent
         state is projected through ``self.bridge`` into the second agent's
         space. NOTE(review): neither agent's ``forward`` accepts that
         projection yet, so it does not influence the second agent. It is
         returned under ``'bridged_latent'`` instead of being silently
         discarded.
      3. ``self.vae`` arbitrates between the two predictions.
      4. The VAE confidence is rescaled by fixed epistemic heuristics. Every
         multiplier actually applied is logged in
         ``state.confidence_adjustments``.

    Nothing in this class updates the bridge weights.
    """

    def __init__(self,
                 agent_a: 'HRMAgentA',
                 agent_b: 'LLMAgentB',
                 vae: 'VAEMediator',
                 config: OrcaWhiskeyConfig):
        """
        Args:
            agent_a: HRM agent. ``forward(train_inputs, train_outputs,
                test_input, return_traces=True)`` must return a dict with
                ``'prediction'`` and ``'confidence'``, plus
                ``'high_level_state'`` when agent A goes first.
            agent_b: LLM agent. It is called the same way and must return
                ``'prediction'`` and ``'confidence'``, plus
                ``'reasoning_traces'`` when it goes first and
                ``'generated_text'`` when it goes second.
            vae: Mediator. ``forward(pred_a, conf_a, pred_b, conf_b,
                return_traces=True)`` must return ``'prediction'``,
                ``'confidence'``, ``'decision'``, ``'votes'``, ``'agreement'``
                and ``'latent_distance'``.
            config: Reads ``hrm_hidden_size``, ``llm_hidden_size``, ``device``,
                ``hrm_h_cycles`` and ``max_confidence``.

        The agent/mediator annotations are string forward references because
        those classes are defined in cells that appear later in this file.
        """
        self.agent_a = agent_a
        self.agent_b = agent_b
        self.vae = vae
        self.config = config

        # Latent space bridge
        self.bridge = LatentSpaceBridge(
            hrm_dim=config.hrm_hidden_size,
            llm_dim=config.llm_hidden_size
        ).to(config.device)

    def solve_task(self,
                   train_inputs: List[List[List[int]]],
                   train_outputs: List[List[List[int]]],
                   test_input: List[List[int]],
                   verbose: bool = False) -> Dict[str, Any]:
        """
        Solve ARC task with full epistemic reasoning

        ALL 11 MODES ACTIVE

        Args:
            train_inputs: Training input grids.
            train_outputs: Training output grids, parallel to ``train_inputs``.
            test_input: The grid to solve.
            verbose: Print progress and an epistemic summary.

        Returns:
            Dict with ``'prediction'``, ``'confidence'`` (after calibration),
            ``'epistemic_state'``, ``'agent_a_result'``, ``'agent_b_result'``,
            ``'vae_result'`` and ``'first_agent'`` (``'agent_a'`` or
            ``'agent_b'``). It also has ``'bridged_latent'``: the first agent's
            latent projected into the second agent's space.
        """

        # Initialize epistemic state
        state = EpistemicState()
        state.num_examples = len(train_inputs)

        # FEAR: Known dangers
        if state.num_examples < 3:
            state.known_dangers.append("Small sample size - pattern may be coincidence")

        state.known_unknowns.append("Test input may have unseen features")
        state.known_unknowns.append("Pattern may not generalize beyond examples")

        # COIN TOSS: Who goes first?
        goes_first = random.choice(['agent_a', 'agent_b'])

        if verbose:
            print(f"üé≤ Coin toss: {goes_first} goes first")
            print(f"üìä {state.num_examples} training examples")

        # PHASES 1-2: first agent solves blind, second agent follows
        result_a, result_b, bridged_latent = self._run_agents(
            state, goes_first, train_inputs, train_outputs, test_input, verbose
        )

        # ============================================
        # PHASE 3: VAE Arbitration
        # ============================================

        if verbose:
            print("‚öñÔ∏è  VAE Mediator arbitrating...")

        vae_result = self.vae.forward(
            result_a['prediction'],
            result_a['confidence'],
            result_b['prediction'],
            result_b['confidence'],
            return_traces=True
        )

        # REASONING: VAE's decision logic
        state.reasoning_steps.append(f"VAE decision: {vae_result['decision']}")
        state.reasoning_steps.append(f"Votes - A: {vae_result['votes']['agent_a']:.2f}, B: {vae_result['votes']['agent_b']:.2f}, VAE: {vae_result['votes']['vae']:.2f}")

        if vae_result['agreement']:
            state.reasoning_steps.append("Agents agree - consensus reached")
        else:
            state.reasoning_steps.append(f"Agents disagree (latent distance: {vae_result['latent_distance']:.3f})")

        # EPISTEMIC CALIBRATION
        confidence = self._calibrate_confidence(state, vae_result)

        # ============================================
        # FINAL RESULT
        # ============================================

        result = {
            'prediction': vae_result['prediction'],
            'confidence': confidence,
            'epistemic_state': state,
            'agent_a_result': result_a,
            'agent_b_result': result_b,
            'vae_result': vae_result,
            'first_agent': goes_first,
            'bridged_latent': bridged_latent
        }

        if verbose:
            print(f"\nüìä EPISTEMIC SUMMARY:")
            print(f"   Confidence: {state.confidence_raw:.3f} ‚Üí {state.confidence_final:.3f}")
            print(f"   Patterns observed: {len(state.patterns_observed)}")
            print(f"   Reasoning steps: {len(state.reasoning_steps)}")
            print(f"   Doubts raised: {len(state.doubts_raised)}")
            print(f"   Known dangers: {len(state.known_dangers)}")
            print(f"   Known unknowns: {len(state.known_unknowns)}")
            print(f"   Fears Null beyond: {state.fears_unknown_unknowns}")

        return result

    def _run_agents(self,
                    state: 'EpistemicState',
                    goes_first: str,
                    train_inputs: List[List[List[int]]],
                    train_outputs: List[List[List[int]]],
                    test_input: List[List[int]],
                    verbose: bool) -> Tuple[Dict[str, Any], Dict[str, Any], torch.Tensor]:
        """
        Run both agents in coin-toss order (phases 1 and 2) and log to ``state``.

        Returns:
            ``(result_a, result_b, bridged_latent)``. ``bridged_latent`` is the
            first agent's latent after it is projected through the bridge.
        """
        if goes_first == 'agent_a':
            # ---- PHASE 1: Agent A (HRM) solves blind ----
            if verbose:
                print("ü§ñ Agent A (HRM) reasoning...")

            result_a = self.agent_a.forward(
                train_inputs,
                train_outputs,
                test_input,
                return_traces=True
            )

            # INDUCTION: What patterns did HRM observe?
            state.patterns_observed.append("Visual attention maps over spatial positions")
            state.patterns_observed.append(f"High-level cycles: {self.config.hrm_h_cycles}")

            # ABSTRACTION: HRM's latent state IS the abstract pattern
            state.reasoning_steps.append(f"HRM abstracted pattern to {self.config.hrm_hidden_size}D latent")

            # Extract HRM's latent understanding ([1, 1, hidden_size] per the original note)
            hrm_latent = result_a['high_level_state']

            # ---- PHASE 2: Agent B (LLM) follows ----
            if verbose:
                print("üó£Ô∏è  Agent B (LLM) reasoning with omniscient observation...")

            # NOVEL INSIGHT: project HRM's latent into LLM's embedding space.
            # NOTE(review): agent_b.forward has no input for this yet, so the
            # LLM below runs exactly as it would without it.
            bridged_latent = self.bridge.hrm_to_llm_space(hrm_latent)

            result_b = self.agent_b.forward(
                train_inputs,
                train_outputs,
                test_input,
                return_traces=True
            )

            # INFERENCE: LLM bridges linguistic gap
            state.gaps_bridged.append("Converted visual pattern to linguistic description")
            state.reasoning_steps.append(f"LLM generated: {result_b['generated_text'][:100]}...")

        else:
            # ---- PHASE 1: Agent B (LLM) solves blind ----
            if verbose:
                print("üó£Ô∏è  Agent B (LLM) reasoning...")

            result_b = self.agent_b.forward(
                train_inputs,
                train_outputs,
                test_input,
                return_traces=True
            )

            # INDUCTION: What patterns did LLM observe?
            state.patterns_observed.append("Linguistic patterns in grid descriptions")
            state.patterns_observed.append("Generated text describing transformation")

            # ---- PHASE 2: Agent A (HRM) follows ----
            if verbose:
                print("ü§ñ Agent A (HRM) reasoning with omniscient observation...")

            # Last entry of the LLM's reasoning trace.
            # NOTE(review): llm_to_bridge expects last dim == llm_hidden_size;
            # confirm reasoning_traces holds hidden states rather than the
            # (hidden_size // 2) reasoning-head output.
            llm_reasoning = result_b['reasoning_traces'][-1]
            # NOTE(review): agent_a.forward has no input for this projection yet.
            bridged_latent = self.bridge.llm_to_hrm_space(llm_reasoning)

            result_a = self.agent_a.forward(
                train_inputs,
                train_outputs,
                test_input,
                return_traces=True
            )

            state.reasoning_steps.append("HRM processed visual patterns with LLM's linguistic context")

        return result_a, result_b, bridged_latent

    def _calibrate_confidence(self, state: 'EpistemicState', vae_result: Dict[str, Any]) -> float:
        """
        Rescale the VAE confidence with the epistemic heuristics and record each step.

        Sets ``state.confidence_raw`` and ``state.confidence_final`` and fills
        ``state.doubts_raised``, ``assumptions_made``, ``assumptions_tested`` and
        ``confidence_adjustments``. Every adjustment factor is the multiplier
        actually applied, so
        ``confidence_raw * prod(factors) == confidence_final`` (up to rounding).

        Returns:
            The calibrated confidence.
        """
        # Start with VAE's confidence
        confidence = vae_result['confidence']
        state.confidence_raw = confidence

        # SKEPTICISM: Question our own conclusions
        if state.num_examples == 1:
            state.doubts_raised.append("Single example - extreme uncertainty")
            confidence *= 0.5
            state.confidence_adjustments.append(("single_example_skepticism", 0.5))

        elif state.num_examples == 2:
            state.doubts_raised.append("Only 2 examples - could be coincidence")
            confidence *= 0.7
            state.confidence_adjustments.append(("two_example_skepticism", 0.7))

        # DOUBT: Self-challenge
        if not vae_result['agreement']:
            state.doubts_raised.append("Agents disagree - which one is hallucinating?")

        # ASSUMPTION: Test assumptions
        state.assumptions_made.append("Pattern generalizes from examples to test case")
        state.assumptions_made.append("Test input follows same rules as training inputs")
        # Can't test these without ground truth - mark as untested
        state.assumptions_tested["pattern_generalizes"] = False
        state.assumptions_tested["test_follows_rules"] = False

        # FEAR: Unknown unknowns - always leave 5% for patterns we cannot conceive
        confidence *= 0.95
        state.confidence_adjustments.append(("fear_of_null", 0.95))

        max_confidence = self.config.max_confidence

        # HUMILITY: Cap at max_confidence (log the real shrink factor, not 1.0)
        if confidence > max_confidence:
            cap_factor = max_confidence / confidence
            confidence = max_confidence
            state.confidence_adjustments.append(("epistemic_humility_cap", cap_factor))

        # CONFIDENCE: Multiple examples + consensus = higher confidence justified,
        # still bounded by max_confidence (so the effective factor may be < 1.15)
        if state.num_examples >= 4 and vae_result['agreement']:
            boosted = min(confidence * 1.15, max_confidence)
            boost_factor = boosted / confidence if confidence else 1.15
            confidence = boosted
            state.confidence_adjustments.append(("epistemic_confidence_boost", boost_factor))

        state.confidence_final = confidence
        return confidence


# Build the orchestrator around the three agents.
# NOTE(review): `vae_mediator` and `total_system_params` are assigned in the
# VAE Mediator cell further down this file (and `agent_b` presumably in the
# LLM cell below that), so this cell only works once those have been run.
orchestrator = EpistemicOrchestrator(
    agent_a=agent_a,
    agent_b=agent_b,
    vae=vae_mediator,
    config=config,
)

# The bridge is the only module the orchestrator itself creates; count its weights.
bridge_params = sum(tensor.numel() for tensor in orchestrator.bridge.parameters())

print(
    "üß† Epistemic Orchestrator initialized",
    f"   Bridge params: {bridge_params:,} ({bridge_params / 1e6:.2f}M)",
    "   Reasoning modes: 11 (INDUCTION, DEDUCTION, ABSTRACTION, INFERENCE, ASSUMPTION, REASONING, SKEPTICISM, DOUBT, FEAR, HUMILITY, CONFIDENCE)",
    "   Latent-space communication: HRM ‚Üî Bridge ‚Üî LLM",
    "   Omniscient observation: Second agent sees first agent's latent state",
    "   Coin toss: Random order prevents first-mover bias",
    "",
    sep="\n",
)

# Grand total including the bridge.
total_system_params_with_bridge = total_system_params + bridge_params
print(
    f"ü•É UPDATED TOTAL: {total_system_params_with_bridge:,} parameters "
    f"({total_system_params_with_bridge / 1e6:.1f}M)",
    "",
    sep="\n",
)

In [None]:
# CELL 5: VAE MEDIATOR - ARBITRATION & TIEBREAKER
# Lines: ~400
# Purpose: Variational Autoencoder for 2/3 vote arbitration

class VAEEncoder(nn.Module):
    """
    Encode a flattened agent prediction into a diagonal-Gaussian latent.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the two hidden GELU layers.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU()
        )

        # Latent parameters
        self.mu_layer = nn.Linear(hidden_dim, latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode to latent distribution.

        Returns:
            ``(z, mu, logvar)``. In training mode ``z`` is sampled with the
            reparameterization trick (``mu + eps * exp(0.5 * logvar)``). In
            eval mode ``z`` is ``mu``, so inference is deterministic. That
            includes everything downstream of ``z``, such as the mediator's
            arbitration votes.
        """
        h = self.encoder(x)

        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)

        if not self.training:
            # Use the distribution mean at inference instead of a random sample.
            return mu, mu, logvar

        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return z, mu, logvar


class VAEDecoder(nn.Module):
    """
    Map a latent vector back to a flat prediction vector.

    The network is three Linear layers with GELU between them. The last
    layer has no activation, so outputs are unbounded.

    Args:
        latent_dim: Size of the last dimension of the input latent.
        hidden_dim: Width of the two hidden layers.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, latent_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        layers = [nn.Linear(latent_dim, hidden_dim), nn.GELU()]
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.decoder = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Run ``z`` through the decoder stack and return the raw output."""
        decoded = self.decoder(z)
        return decoded


class VAEMediator(nn.Module):
    """
    Variational Autoencoder Mediator

    Arbitrates between Agent A and Agent B:
    - Encodes both predictions to latent space
    - Learns latent representation of "correctness"
    - Provides 2/3 vote tiebreaker when agents disagree
    - Can generate own prediction if both agents are uncertain

    Total params: ~5M
    """

    # Distance between the two latent means below which the agents count as agreeing.
    AGREEMENT_THRESHOLD = 0.1

    def __init__(self, config: OrcaWhiskeyConfig):
        super().__init__()
        self.config = config

        # Input dimensions
        self.grid_dim = config.max_grid_size * config.max_grid_size  # Flattened grid
        self.hidden_dim = config.vae_hidden_dim
        self.latent_dim = config.vae_latent_dim

        # Encoders for each agent's output
        self.encoder_a = VAEEncoder(self.grid_dim, self.hidden_dim, self.latent_dim)
        self.encoder_b = VAEEncoder(self.grid_dim, self.hidden_dim, self.latent_dim)

        # Cross-attention between agents
        self.cross_attention = nn.MultiheadAttention(
            self.latent_dim,
            num_heads=4,
            batch_first=True
        )

        # Arbitration head -> probabilities over [vote_a, vote_b, vote_vae].
        # It ends in Softmax, so its output is NOT a logit vector.
        self.arbitration_head = nn.Sequential(
            nn.Linear(self.latent_dim * 2, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, 3),  # [vote_a, vote_b, vote_vae]
            nn.Softmax(dim=-1)
        )

        # Decoder (VAE's own prediction)
        self.decoder = VAEDecoder(self.latent_dim, self.hidden_dim, self.grid_dim)

        # Confidence estimation
        self.confidence_head = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim // 2),
            nn.GELU(),
            nn.Linear(self.hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def encode_grid_prediction(self, grid: List[List[int]]) -> torch.Tensor:
        """Convert grid prediction to flat tensor"""
        # NOTE(review): grid_utils is presumably defined in an earlier cell.
        device = next(self.parameters()).device
        padded = grid_utils.pad_grid(grid, self.config.max_grid_size)
        flat = padded.flatten().astype(np.float32)
        return torch.FloatTensor(flat).to(device)

    def decode_grid_prediction(self, flat: torch.Tensor) -> List[List[int]]:
        """Convert flat tensor back to grid (values clamped to 0-9 and rounded)"""
        grid = flat.view(self.config.max_grid_size, self.config.max_grid_size)
        grid = torch.clamp(grid, 0, 9).round().long()
        unpadded = grid_utils.unpad_grid(grid.cpu().numpy())
        return unpadded.tolist()

    @staticmethod
    def _grids_identical(grid_a: Any, grid_b: Any) -> bool:
        """Return True when both predictions are the same grid (same shape and values)."""
        try:
            return np.array_equal(np.asarray(grid_a), np.asarray(grid_b))
        except (ValueError, TypeError):
            # Ragged or non-array-like input: fall back to the latent test only.
            return False

    def _encode_pair(self,
                     pred_a: List[List[int]],
                     pred_b: List[List[int]]) -> Dict[str, torch.Tensor]:
        """
        Shared front half of ``forward`` and ``compute_loss``.

        Flattens both predictions, encodes each with its own encoder, lets A's
        latent attend to B's, and concatenates the attended A latent with B's
        latent.
        """
        pred_a_flat = self.encode_grid_prediction(pred_a).unsqueeze(0)
        pred_b_flat = self.encode_grid_prediction(pred_b).unsqueeze(0)

        z_a, mu_a, logvar_a = self.encoder_a(pred_a_flat)
        z_b, mu_b, logvar_b = self.encoder_b(pred_b_flat)

        # Cross-attention: A attends to B (each as a length-1 sequence)
        z_a_attended, attn_weights = self.cross_attention(
            z_a.unsqueeze(1),
            z_b.unsqueeze(1),
            z_b.unsqueeze(1)
        )
        z_combined = torch.cat([z_a_attended.squeeze(1), z_b], dim=-1)

        return {
            'pred_a_flat': pred_a_flat,
            'pred_b_flat': pred_b_flat,
            'z_a': z_a, 'mu_a': mu_a, 'logvar_a': logvar_a,
            'z_b': z_b, 'mu_b': mu_b, 'logvar_b': logvar_b,
            'attn_weights': attn_weights,
            'z_combined': z_combined,
        }

    def forward(self,
                pred_a: List[List[int]],
                conf_a: float,
                pred_b: List[List[int]],
                conf_b: float,
                return_traces: bool = False) -> Dict[str, Any]:
        """
        Mediate between two agent predictions

        EPISTEMIC REASONING:
        - INFERENCE: Synthesize understanding from two perspectives
        - REASONING: Which agent is likely correct?
        - SKEPTICISM: Are both wrong?
        - DOUBT: Should I trust my own judgment?
        - ASSUMPTION: Agents have complementary strengths
        - ABSTRACTION: Find common latent truth
        - CONFIDENCE: Calibrate based on agreement

        The agents agree when their grids are identical, or when their latent
        means are closer than ``AGREEMENT_THRESHOLD``. On agreement the more
        confident agent's grid is returned (A on ties) and confidence is
        boosted 1.1x. Otherwise the arbitration vote picks A, B, or the VAE's
        own decoded grid, and confidence is scaled by 0.85. The result is
        clamped to ``[config.min_confidence, config.max_confidence]``.

        Returns:
            Dict with ``'prediction'``, ``'confidence'``, ``'decision'``,
            ``'votes'``, ``'latent_distance'`` and ``'agreement'``. When
            ``return_traces`` is set it also holds the latents, the
            log-variances and the attention weights.
        """
        enc = self._encode_pair(pred_a, pred_b)

        # Arbitration vote: [prob_a, prob_b, prob_vae]
        votes = self.arbitration_head(enc['z_combined']).squeeze(0)
        vote_a, vote_b, vote_vae = votes

        # DECISION LOGIC (2/3 vote)
        latent_distance = torch.norm(enc['mu_a'] - enc['mu_b'], p=2)
        # encoder_a and encoder_b are separate networks, so identical grids need
        # not land close together in latent space; an exact match is agreement.
        agreement = (self._grids_identical(pred_a, pred_b)
                     or latent_distance.item() < self.AGREEMENT_THRESHOLD)

        if agreement:
            # CONSENSUS: use the higher-confidence agent (A on ties)
            final_prediction = pred_a if conf_a >= conf_b else pred_b
            final_confidence = max(conf_a, conf_b)
            decision = "consensus"

            # EPISTEMIC CONFIDENCE: Agreement increases confidence
            final_confidence = min(final_confidence * 1.1, self.config.max_confidence)

        else:
            # DISAGREEMENT: Use voting (argmax keeps the first maximum -> A, B, VAE)
            winner = int(torch.argmax(votes))

            if winner == 0:
                final_prediction = pred_a
                final_confidence = conf_a * vote_a.item()
                decision = "vote_agent_a"

            elif winner == 1:
                final_prediction = pred_b
                final_confidence = conf_b * vote_b.item()
                decision = "vote_agent_b"

            else:
                # VAE wins: Generate own prediction from the averaged latent
                z_vae = (enc['z_a'] + enc['z_b']) / 2
                pred_vae_flat = self.decoder(z_vae)
                final_prediction = self.decode_grid_prediction(pred_vae_flat.squeeze(0))

                # Estimate confidence
                vae_confidence_raw = self.confidence_head(z_vae).squeeze()
                final_confidence = vae_confidence_raw.item() * vote_vae.item()
                decision = "vote_vae"

            # SKEPTICISM: Disagreement reduces confidence
            final_confidence = final_confidence * 0.85

        # Cap confidence
        final_confidence = max(self.config.min_confidence,
                               min(final_confidence, self.config.max_confidence))

        result = {
            'prediction': final_prediction,
            'confidence': final_confidence,
            'decision': decision,
            'votes': {
                'agent_a': vote_a.item(),
                'agent_b': vote_b.item(),
                'vae': vote_vae.item()
            },
            'latent_distance': latent_distance.item(),
            'agreement': agreement
        }

        if return_traces:
            result['z_a'] = enc['z_a']
            result['z_b'] = enc['z_b']
            result['mu_a'] = enc['mu_a']
            result['mu_b'] = enc['mu_b']
            result['logvar_a'] = enc['logvar_a']
            result['logvar_b'] = enc['logvar_b']
            result['attention_weights'] = enc['attn_weights']

        return result

    def compute_loss(self,
                     pred_a: List[List[int]],
                     pred_b: List[List[int]],
                     target: List[List[int]]) -> torch.Tensor:
        """
        Compute VAE training loss

        Loss components:
        - Reconstruction: MSE between the grid decoded from the averaged
          latent and the target grid.
        - KL divergence: Regularize both latent distributions (weighted 0.1).
        - Arbitration: Did VAE vote for the agent closer to the target? If both
          agents are equally close, the target vote is the VAE.
        """
        device = next(self.parameters()).device

        enc = self._encode_pair(pred_a, pred_b)
        target_flat = self.encode_grid_prediction(target).unsqueeze(0)

        # Arbitration probabilities, shape [1, 3]
        votes = self.arbitration_head(enc['z_combined'])

        # Decode from combined latent
        z_avg = (enc['z_a'] + enc['z_b']) / 2
        reconstruction = self.decoder(z_avg)

        # Reconstruction loss
        recon_loss = F.mse_loss(reconstruction, target_flat)

        # KL divergence
        mu_a, logvar_a = enc['mu_a'], enc['logvar_a']
        mu_b, logvar_b = enc['mu_b'], enc['logvar_b']
        kl_loss_a = -0.5 * torch.sum(1 + logvar_a - mu_a.pow(2) - logvar_a.exp())
        kl_loss_b = -0.5 * torch.sum(1 + logvar_b - mu_b.pow(2) - logvar_b.exp())
        kl_loss = (kl_loss_a + kl_loss_b) / 2

        # Arbitration loss: Which agent is closer to target?
        dist_a = F.mse_loss(enc['pred_a_flat'], target_flat)
        dist_b = F.mse_loss(enc['pred_b_flat'], target_flat)

        if dist_a < dist_b:
            target_index = 0
        elif dist_b < dist_a:
            target_index = 1
        else:
            target_index = 2

        # arbitration_head already applies Softmax. Taking the log and using NLL
        # gives cross-entropy on those probabilities. F.cross_entropy would
        # apply a second softmax to them.
        log_votes = torch.log(votes.clamp_min(1e-8))
        arbitration_loss = F.nll_loss(
            log_votes, torch.tensor([target_index], dtype=torch.long, device=device)
        )

        # Combined loss
        total_loss = recon_loss + 0.1 * kl_loss + arbitration_loss

        return total_loss


# Build the mediator and place it on the configured device.
vae_mediator = VAEMediator(config).to(config.device)

# Parameter bookkeeping for the banner below.
total_params_vae = sum(weight.numel() for weight in vae_mediator.parameters())
trainable_params_vae = sum(
    weight.numel() for weight in vae_mediator.parameters() if weight.requires_grad
)

print(
    "‚öñÔ∏è  VAE Mediator initialized",
    f"   Total params: {total_params_vae:,} ({total_params_vae / 1e6:.1f}M)",
    f"   Trainable params: {trainable_params_vae:,}",
    f"   Latent dim: {config.vae_latent_dim}",
    f"   Hidden dim: {config.vae_hidden_dim}",
    "   Arbitration: 2/3 vote (Agent A, Agent B, VAE)",
    "   Epistemic: Consensus boost (1.1x), Disagreement penalty (0.85x)",
    "",
    sep="\n",
)

# NOTE(review): `total_params` / `total_params_b` presumably come from the
# Agent A and Agent B cells; the Agent B (LLM) cell appears below this one in
# the file, so it must be run first.
total_system_params = total_params + total_params_b + total_params_vae
print(
    f"ü•É TOTAL SYSTEM: {total_system_params:,} parameters ({total_system_params / 1e6:.1f}M)",
    f"   Agent A (HRM): {total_params / 1e6:.1f}M",
    f"   Agent B (LLM): {total_params_b / 1e6:.1f}M",
    f"   VAE Mediator: {total_params_vae / 1e6:.1f}M",
    "",
    sep="\n",
)

In [None]:
# CELL 4: LLM AGENT B - ABSTRACT LANGUAGE REASONING
# Lines: ~700
# Purpose: Phi-3-mini inspired text-based reasoning agent

class LLMTransformerBlock(nn.Module):
    """
    Pre-norm transformer block for language reasoning.

    Multi-head self-attention is followed by a SwiGLU feed-forward,
    ``down(silu(gate(x)) * up(x))``. Each sub-layer gets LayerNorm on its
    input and a residual connection around it.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout on attention probabilities and on both residual branches.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            # Fail here rather than with an opaque reshape error in forward().
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Attention
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        # Feed-forward (SwiGLU-style)
        self.gate_proj = nn.Linear(hidden_size, hidden_size * 4)
        self.up_proj = nn.Linear(hidden_size, hidden_size * 4)
        self.down_proj = nn.Linear(hidden_size * 4, hidden_size)

        # Norms
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.ffn_norm = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: ``[batch, seq, hidden_size]``.
            attention_mask: Optional additive mask, added to the raw scores
                ``[batch, heads, seq, seq]`` before softmax, so it must
                broadcast to that shape.

        Returns:
            ``(output, attention_weights)``. ``output`` has the shape of ``x``.
            ``attention_weights`` is ``[batch, seq, seq]``: the softmax
            attention averaged over heads, taken before dropout so each row
            sums to 1.
        """
        batch_size, seq_len, _ = x.shape

        # Attention
        residual = x
        x = self.attn_norm(x)

        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)  # [batch, heads, seq, dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_probs = F.softmax(attn_scores, dim=-1)
        # Dropout is applied only to the copy that mixes the values. The
        # returned trace stays a proper (row-normalised) attention distribution.
        attn_output = torch.matmul(self.dropout(attn_probs), v)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        x = residual + self.dropout(attn_output)

        # Feed-forward (SwiGLU)
        residual = x
        x = self.ffn_norm(x)

        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        x = self.down_proj(gate * up)

        x = residual + self.dropout(x)

        return x, attn_probs.mean(dim=1)  # Average across heads


class GridTokenizer:
    """
    Convert grids to/from text tokens

    ABSTRACTION: Bridge visual and linguistic domains

    Token id layout:
    - 0-3: PAD / BOS / EOS / SEP.
    - 10-19: reserved for colors 0-9 (``COLOR_OFFSET``).
    - 100 and up: one id per character of ``self.chars`` (``TEXT_OFFSET``).

    Characters outside ``self.chars`` are encoded as ``TEXT_OFFSET``, the same
    id as a space.
    """

    def __init__(self, vocab_size: int = 32064):
        self.vocab_size = vocab_size

        # Special tokens
        self.PAD_TOKEN = 0
        self.BOS_TOKEN = 1
        self.EOS_TOKEN = 2
        self.SEP_TOKEN = 3

        # Grid color tokens: 10-19 (for colors 0-9)
        self.COLOR_OFFSET = 10

        # Text tokens: 100+ (simple char-based for prototyping)
        self.TEXT_OFFSET = 100

        # Character mapping
        self.chars = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,;:!?-'\"()\n"
        self.char_to_idx = {c: i + self.TEXT_OFFSET for i, c in enumerate(self.chars)}
        self.idx_to_char = {i + self.TEXT_OFFSET: c for i, c in enumerate(self.chars)}

    def encode_grid_to_text(self, grid: List[List[int]]) -> str:
        """
        Convert grid to text description.

        Format: ``"Grid HxW. Most common color: C. [Horizontally symmetric. ]
        [Vertically symmetric. ]Data: a b | c d |"``. Each row's values are
        space-separated and followed by ``" | "``; the trailing space is
        stripped. This is the layout :meth:`parse_grid_from_text` reads back.
        An empty grid (``[]``) is described as 0x0.
        """
        arr = np.array(grid)
        if arr.size == 0 and arr.ndim < 2:
            # np.array([]) is 1-D; describe it as 0x0 instead of failing below.
            arr = arr.reshape(0, 0)
        h, w = arr.shape

        # Count colors
        color_counts = Counter(arr.flatten())
        most_common = color_counts.most_common(1)[0][0] if color_counts else 0

        # Analyze structure: mirror symmetry across columns (fliplr) / rows (flipud)
        is_symmetric_h = np.allclose(arr, np.fliplr(arr))
        is_symmetric_v = np.allclose(arr, np.flipud(arr))

        # Generate description
        desc = f"Grid {h}x{w}. "
        desc += f"Most common color: {most_common}. "

        if is_symmetric_h:
            desc += "Horizontally symmetric. "
        if is_symmetric_v:
            desc += "Vertically symmetric. "

        # Add raw grid (join once instead of repeated string concatenation)
        desc += "Data: " + "".join(" ".join(str(c) for c in row) + " | " for row in arr)

        return desc.strip()

    def encode_text(self, text: str, max_length: int = 512) -> List[int]:
        """
        Encode text to token IDs: ``[BOS, chars..., EOS]``.

        At most ``max_length - 2`` characters are kept, so the result has at
        most ``max(max_length, 2)`` tokens. BOS and EOS are always present.
        """
        tokens = [self.BOS_TOKEN]

        # Clamp at 0: a negative slice bound would keep almost the whole text.
        budget = max(max_length - 2, 0)
        for char in text[:budget]:
            # Unknown chars fall back to TEXT_OFFSET (the id of ' ')
            tokens.append(self.char_to_idx.get(char, self.TEXT_OFFSET))

        tokens.append(self.EOS_TOKEN)

        return tokens

    def decode_text(self, tokens: List[int]) -> str:
        """Decode token IDs to text, skipping special tokens and ids with no character."""
        chars = []
        for tok in tokens:
            if tok == self.BOS_TOKEN or tok == self.EOS_TOKEN or tok == self.PAD_TOKEN:
                continue
            if tok in self.idx_to_char:
                chars.append(self.idx_to_char[tok])

        return ''.join(chars)

    def parse_grid_from_text(self, text: str) -> Optional[List[List[int]]]:
        """
        Extract grid from text output

        INFERENCE: Parse structured data from natural language

        Reads the text between the first ``"Data:"`` marker and the next one,
        if any. Rows are separated by ``"|"``; whitespace-separated digit
        tokens become ints. Returns ``None`` when there is no marker, no digits,
        or the input cannot be parsed (e.g. it is not a string).
        """
        try:
            # Look for "Data:" section
            if "Data:" in text:
                data_section = text.split("Data:")[1].strip()
                rows = data_section.split("|")

                grid = []
                for row in rows:
                    if row.strip():
                        nums = [int(x) for x in row.split() if x.isdigit()]
                        if nums:
                            grid.append(nums)

                return grid if grid else None

            return None
        except (AttributeError, TypeError, ValueError):
            # Non-string input, or digit-like tokens int() rejects (e.g. '²').
            return None


class LLMAgentB(nn.Module):
    """
    Language Model Agent B

    Abstract reasoning through linguistic domain:
    - Converts grids to text descriptions
    - Reasons about patterns in language
    - Generates output grid as text

    Inspired by Phi-3-mini (3.8B params)
    Actual params: ~200M (distilled for efficiency)

    Text is tokenized one token per character (``GridTokenizer.encode_text``),
    and that tokenizer truncates from the END. Prompts are therefore trimmed
    from the FRONT here first. This keeps the test input (inference) or the
    target grid (training), which sit at the end of the text, in context.
    """

    # Token budget for a tokenized prompt, BOS and EOS included.
    MAX_PROMPT_TOKENS = 1024
    # Upper bound on tokens produced by greedy decoding in ``forward``.
    MAX_GEN_TOKENS = 200
    # Fixed confidence discount: language may miss purely visual patterns.
    LANGUAGE_UNCERTAINTY = 0.9

    def __init__(self, config: OrcaWhiskeyConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.llm_hidden_size

        # Tokenizer
        self.tokenizer = GridTokenizer(config.llm_vocab_size)

        # Embeddings (learned absolute positions; at most 2048 positions)
        self.token_embedding = nn.Embedding(config.llm_vocab_size, self.hidden_size)
        self.pos_embedding = nn.Embedding(2048, self.hidden_size)

        # Transformer blocks
        # NOTE(review): next-token training/decoding assumes LLMTransformerBlock
        # applies a causal mask internally. Confirm.
        self.blocks = nn.ModuleList([
            LLMTransformerBlock(
                self.hidden_size,
                config.llm_num_heads
            ) for _ in range(config.llm_num_layers)
        ])

        # Output
        self.output_norm = nn.LayerNorm(self.hidden_size)
        self.output_proj = nn.Linear(self.hidden_size, config.llm_vocab_size)

        # Reasoning head (intermediate reasoning steps)
        self.reasoning_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, self.hidden_size // 2)
        )

        # Confidence estimation
        self.confidence_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.GELU(),
            nn.Linear(self.hidden_size // 2, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _size_change(inp: List[List[int]], out: List[List[int]]) -> str:
        """Classify one example as "smaller", "larger" or "same" by grid dims.

        Any shrinking dimension counts as "smaller", even if the other grows.
        """
        inp_h, inp_w = len(inp), len(inp[0])
        out_h, out_w = len(out), len(out[0])
        if out_h < inp_h or out_w < inp_w:
            return "smaller"
        if out_h > inp_h or out_w > inp_w:
            return "larger"
        return "same"

    def encode_task_as_text(self,
                           train_inputs: List[List[List[int]]],
                           train_outputs: List[List[List[int]]],
                           test_input: List[List[int]]) -> str:
        """
        Encode entire task as text prompt

        INDUCTION + ABSTRACTION: Convert visual patterns to linguistic concepts

        The prompt always ends with "Test Output: Data: ", so the model's
        continuation is expected to be the output grid's data section.
        """
        encode = self.tokenizer.encode_grid_to_text
        parts = ["ARC Task Analysis:\n\n"]

        # Training examples
        for i, (inp, out) in enumerate(zip(train_inputs, train_outputs)):
            parts.append(f"Example {i+1}:\n")
            parts.append(f"Input: {encode(inp)}\n")
            parts.append(f"Output: {encode(out)}\n\n")

        # Pattern analysis (ABSTRACTION): describe the size change only when
        # every example agrees. With no examples, nothing is claimed.
        parts.append("Pattern: ")
        size_changes = {self._size_change(inp, out)
                        for inp, out in zip(train_inputs, train_outputs)}
        if size_changes == {"smaller"}:
            parts.append("Output is cropped or filtered. ")
        elif size_changes == {"larger"}:
            parts.append("Output is expanded or tiled. ")
        elif size_changes == {"same"}:
            parts.append("Output is transformed in-place. ")

        # Test input
        parts.append(f"\n\nTest Input: {encode(test_input)}\n")
        parts.append("Test Output: Data: ")

        return "".join(parts)

    def _tokenize(self, text: str, device: torch.device) -> torch.Tensor:
        """Tokenize *text* into a [1, seq_len] LongTensor on *device*.

        The text is trimmed from the front to the last
        ``MAX_PROMPT_TOKENS - 2`` characters (one token per character, plus
        BOS/EOS), so the end of the prompt always survives.
        """
        budget = max(self.MAX_PROMPT_TOKENS - 2, 0)
        if len(text) > budget:
            text = text[len(text) - budget:]
        tokens = self.tokenizer.encode_text(text, max_length=self.MAX_PROMPT_TOKENS)
        return torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    def _run_blocks(self,
                    input_ids: torch.Tensor,
                    collect_traces: bool = False
                    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Embed *input_ids* ([1, seq_len]) and run every transformer block.

        Returns:
            (hidden states before ``output_norm``, attention maps, reasoning
            traces). The last two are empty unless *collect_traces* is true.
        """
        seq_len = input_ids.shape[1]
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        hidden = self.token_embedding(input_ids) + self.pos_embedding(pos_ids)

        attention_maps: List[torch.Tensor] = []
        reasoning_traces: List[torch.Tensor] = []
        for block in self.blocks:
            hidden, attn = block(hidden)
            if collect_traces:
                attention_maps.append(attn)
                # Reasoning trace read from the last token after each block
                reasoning_traces.append(self.reasoning_head(hidden[:, -1, :]))

        return hidden, attention_maps, reasoning_traces

    def _greedy_decode(self, input_ids: torch.Tensor, first_logits: torch.Tensor) -> List[int]:
        """Greedily generate up to ``MAX_GEN_TOKENS`` token IDs after *input_ids*.

        *first_logits* ([1, vocab]) are the prompt's next-token logits, already
        computed by the caller, so the prompt is not re-encoded for step 0.
        Runs under ``torch.no_grad()``: argmax decoding is not differentiable,
        and keeping the graph of every step would only waste memory.

        Returns:
            Generated IDs, excluding the terminating EOS.
        """
        max_positions = self.pos_embedding.num_embeddings
        generated: List[int] = []
        current = input_ids
        next_logits = first_logits

        with torch.no_grad():
            for step in range(self.MAX_GEN_TOKENS):
                if step > 0:
                    hidden, _, _ = self._run_blocks(current)
                    next_logits = self.output_proj(self.output_norm(hidden[:, -1, :]))

                next_token = torch.argmax(next_logits, dim=-1)  # [1]
                token_id = next_token.item()
                if token_id == self.tokenizer.EOS_TOKEN:
                    break
                generated.append(token_id)

                if current.shape[1] >= max_positions:
                    break  # no positional embedding left for another token
                current = torch.cat([current, next_token.unsqueeze(0)], dim=1)

        return generated

    def forward(self,
                train_inputs: List[List[List[int]]],
                train_outputs: List[List[List[int]]],
                test_input: List[List[int]],
                return_traces: bool = False) -> Dict[str, Any]:
        """
        Forward pass: Linguistic reasoning

        EPISTEMIC REASONING:
        - INDUCTION: Learn pattern from examples via text
        - ABSTRACTION: Convert visual to conceptual
        - DEDUCTION: Apply pattern via language generation
        - REASONING: Chain logical steps in natural language
        - INFERENCE: Fill in missing details
        - SKEPTICISM: Question if text pattern matches visual truth
        - DOUBT: "Can language capture this visual pattern?"
        - FEAR: Unknown unknowns (visual patterns with no words)
        - HUMILITY: Language may not suffice

        Returns:
            Dict with 'prediction' (grid as nested lists), 'confidence'
            (float), 'generated_text', 'prompt' and 'num_examples'. When
            *return_traces* is true, it also holds 'attention_maps' and
            'reasoning_traces' (one entry per block).
        """
        device = next(self.parameters()).device

        # Encode task as text and tokenize (front-trimmed to fit the budget)
        prompt = self.encode_task_as_text(train_inputs, train_outputs, test_input)
        input_ids = self._tokenize(prompt, device)

        # One pass over the prompt gives traces, confidence features and the
        # first decoding step's logits.
        hidden, attention_maps, reasoning_traces = self._run_blocks(
            input_ids, collect_traces=return_traces
        )
        x = self.output_norm(hidden)
        logits = self.output_proj(x)

        # Generate output tokens (greedy decoding)
        generated_tokens = self._greedy_decode(input_ids, logits[:, -1, :])
        generated_text = self.tokenizer.decode_text(generated_tokens)

        # The prompt already ends with "Data: ", so the continuation usually
        # holds only the rows. Re-attach the marker unless the model emitted
        # one itself; otherwise the parser never finds the grid.
        if "Data:" in generated_text:
            parse_text = generated_text
        else:
            parse_text = "Data: " + generated_text
        pred_grid = self.tokenizer.parse_grid_from_text(parse_text)

        # Fallback: all-zero grid with the test input's shape
        if pred_grid is None:
            width = len(test_input[0]) if test_input else 0
            pred_grid = [[0] * width for _ in range(len(test_input))]

        # Confidence estimation from the prompt's last (normalized) position
        confidence_raw = self.confidence_head(x[:, -1, :]).squeeze()

        # EPISTEMIC CALIBRATION
        num_examples = len(train_inputs)

        # Known danger: Small sample
        if num_examples < 3:
            confidence = confidence_raw * self.config.small_sample_penalty
        else:
            confidence = confidence_raw

        # DOUBT: Language may miss visual patterns (skepticism penalty)
        confidence = confidence * self.LANGUAGE_UNCERTAINTY

        # Clamp to [min_confidence, max_confidence]
        confidence = torch.clamp(confidence, self.config.min_confidence, self.config.max_confidence)

        result = {
            'prediction': pred_grid,
            'confidence': confidence.item(),
            'generated_text': generated_text,
            'prompt': prompt,
            'num_examples': num_examples
        }

        if return_traces:
            result['attention_maps'] = attention_maps
            result['reasoning_traces'] = reasoning_traces

        return result

    def compute_loss(self,
                    train_inputs: List[List[List[int]]],
                    train_outputs: List[List[List[int]]],
                    target: List[List[int]]) -> torch.Tensor:
        """
        Compute generation loss

        DEDUCTION: Measure how well we apply learned rules

        The last training input is the query. The prompt is built from the
        other pairs plus ``train_inputs[-1]``, and the text form of *target*
        is appended. Presumably *target* is ``train_outputs[-1]``; confirm
        with callers. Loss is next-token cross-entropy over the whole sequence.
        Front-trimming keeps the target inside the token budget.
        """
        device = next(self.parameters()).device

        # Training prompt plus the target as its continuation
        prompt = self.encode_task_as_text(train_inputs[:-1], train_outputs[:-1], train_inputs[-1])
        full_text = prompt + self.tokenizer.encode_grid_to_text(target)
        input_ids = self._tokenize(full_text, device)

        hidden, _, _ = self._run_blocks(input_ids)
        logits = self.output_proj(self.output_norm(hidden))

        # Predict next token
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, self.config.llm_vocab_size),
            shift_labels.view(-1)
        )

        return loss


# Initialize Agent B
agent_b = LLMAgentB(config).to(config.device)

# Count parameters
total_params_b = sum(p.numel() for p in agent_b.parameters())
trainable_params_b = sum(p.numel() for p in agent_b.parameters() if p.requires_grad)

# Emoji restored from its mis-decoded (UTF-8 read as Mac Roman) form.
print("🗣️  Agent B (LLM) initialized")
print(f"   Total params: {total_params_b:,} ({total_params_b/1e6:.1f}M)")
print(f"   Trainable params: {trainable_params_b:,}")
print(f"   Architecture: {config.llm_num_layers} transformer blocks")
print(f"   Hidden size: {config.llm_hidden_size}")
print(f"   Vocab size: {config.llm_vocab_size}")
print(f"   Epistemic: Language skepticism (0.9x), Small sample penalty ({config.small_sample_penalty})")
print()

In [None]:
# CELL 3: HRM AGENT A - VISUAL PATTERN REASONING
# Lines: ~600
# Purpose: Hierarchical Reasoning Model with dual-level architecture

class RotaryPositionalEmbedding(nn.Module):
    """RoPE for position-aware attention

    Rotates each (even, odd) feature pair of the last dimension by the angle
    ``pos * inv_freq[i]``, with ``inv_freq[i] = 10000 ** (-2i / dim)``.
    Angles for the first ``max_seq_len`` positions are precomputed. Longer
    sequences get their angles computed on the fly instead of failing.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048):
        super().__init__()
        if dim % 2 != 0:
            # Features are rotated in pairs; an odd dim leaves an unpaired
            # feature and the rotation cannot broadcast.
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute for max_seq_len
        cos, sin = self._angles(max_seq_len, inv_freq)
        self.register_buffer('cos_cached', cos)
        self.register_buffer('sin_cached', sin)

    @staticmethod
    def _angles(seq_len: int, inv_freq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape [seq_len, dim // 2] in float32."""
        t = torch.arange(seq_len, dtype=torch.float32, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq.float())
        return freqs.cos(), freqs.sin()

    def forward(self, x: torch.Tensor, seq_len: int = None):
        """Apply rotary embeddings to x.

        Args:
            x: Tensor whose last dimension is ``dim``. The second-to-last
                dimension must have length *seq_len* (position axis).
            seq_len: Number of positions. Defaults to ``x.shape[1]``; pass it
                explicitly for inputs with extra leading dimensions.

        Returns:
            Rotated tensor with the same shape and dtype as *x*.
        """
        if seq_len is None:
            seq_len = x.shape[1]

        if seq_len <= self.cos_cached.shape[0]:
            cos = self.cos_cached[:seq_len, :]
            sin = self.sin_cached[:seq_len, :]
        else:
            cos, sin = self._angles(seq_len, self.inv_freq)

        # Match x's dtype so half-precision inputs are not promoted to float32
        cos = cos.to(dtype=x.dtype)
        sin = sin.to(dtype=x.dtype)

        # Apply rotation
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)

        return rotated


class HierarchicalTransformerLayer(nn.Module):
    """Single pre-norm transformer layer (multi-head self-attention + GELU FFN).

    The layer itself is stateless. HighLevelModule / LowLevelModule carry
    their recurrent state by prepending a state token to the sequence.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            # Otherwise the per-head reshape in forward() fails.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Multi-head attention
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        # Feed-forward
        self.ff1 = nn.Linear(hidden_size, hidden_size * 4)
        self.ff2 = nn.Linear(hidden_size * 4, hidden_size)

        # Layer norms
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with attention output

        Args:
            x: [batch, seq_len, hidden_size] input.
            mask: Optional tensor broadcastable to [batch, heads, seq, seq].
                Positions where it equals 0 are excluded from attention.

        Returns:
            (output [batch, seq_len, hidden_size],
             attention weights [batch, seq_len, seq_len] averaged across heads).
            The returned weights are taken before dropout, so every row is a
            probability distribution in train mode too.
        """
        batch_size, seq_len, _ = x.shape

        # Multi-head attention
        residual = x
        x = self.ln1(x)

        # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_probs = F.softmax(attn_scores, dim=-1)

        # Dropout only on the copy used to mix the values
        attn_output = torch.matmul(self.dropout(attn_probs), v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        x = residual + self.dropout(attn_output)

        # Feed-forward
        residual = x
        x = self.ln2(x)
        x = self.ff2(F.gelu(self.ff1(x)))
        x = residual + self.dropout(x)

        return x, attn_probs.mean(dim=1)  # Average across heads


class HighLevelModule(nn.Module):
    """High-level abstract reasoning module.

    Runs ``config.hrm_num_layers_h`` transformer layers over the input, with
    one recurrent state token prepended at position 0 of the sequence.
    """

    def __init__(self, config: OrcaWhiskeyConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hrm_hidden_size

        depth = config.hrm_num_layers_h
        heads = config.hrm_num_heads
        self.layers = nn.ModuleList(
            [HierarchicalTransformerLayer(self.hidden_size, heads) for _ in range(depth)]
        )

        # Normalizes the state token carried to the next cycle
        self.state_norm = nn.LayerNorm(self.hidden_size)

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Run one high-level reasoning cycle.

        Args:
            x: [batch, seq_len, hidden] input sequence.
            state: [batch, 1, hidden] recurrent state; zeros when omitted.

        Returns:
            (output [batch, seq_len, hidden], normalized new state
             [batch, 1, hidden], per-layer attention maps)
        """
        n_batch, _, _ = x.shape

        if state is None:
            state = torch.zeros(n_batch, 1, self.hidden_size, device=x.device)

        seq = torch.cat([state, x], dim=1)

        maps: List[torch.Tensor] = []
        for layer in self.layers:
            seq, layer_attn = layer(seq)
            maps.append(layer_attn)

        # Position 0 is the state slot; the rest is the processed input
        next_state = self.state_norm(seq[:, :1, :])
        return seq[:, 1:, :], next_state, maps


class LowLevelModule(nn.Module):
    """Low-level execution module for rapid processing.

    Adds high-level guidance to its input, then runs
    ``config.hrm_num_layers_l`` transformer layers with one recurrent state
    token prepended at position 0.
    """

    def __init__(self, config: OrcaWhiskeyConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hrm_hidden_size

        depth = config.hrm_num_layers_l
        heads = config.hrm_num_heads
        self.layers = nn.ModuleList(
            [HierarchicalTransformerLayer(self.hidden_size, heads) for _ in range(depth)]
        )

        # Normalizes the state token carried to the next cycle
        self.state_norm = nn.LayerNorm(self.hidden_size)

    def forward(self, x: torch.Tensor, h_output: torch.Tensor, state: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Run one low-level execution cycle guided by the high-level module.

        Args:
            x: [batch, seq_len, hidden] input sequence.
            h_output: Guidance added element-wise to *x* (must broadcast).
            state: [batch, 1, hidden] recurrent state; zeros when omitted.

        Returns:
            (output [batch, seq_len, hidden], normalized new state
             [batch, 1, hidden], per-layer attention maps)
        """
        n_batch, _, _ = x.shape

        if state is None:
            state = torch.zeros(n_batch, 1, self.hidden_size, device=x.device)

        guided = x + h_output  # residual connection from high-level
        seq = torch.cat([state, guided], dim=1)

        maps: List[torch.Tensor] = []
        for layer in self.layers:
            seq, layer_attn = layer(seq)
            maps.append(layer_attn)

        # Position 0 is the state slot; the rest is the processed input
        next_state = self.state_norm(seq[:, :1, :])
        return seq[:, 1:, :], next_state, maps


class HRMAgentA(nn.Module):
    """
    Hierarchical Reasoning Model - Agent A

    Visual pattern reasoning with dual-level architecture:
    - High-level: Abstract planning and pattern recognition
    - Low-level: Rapid execution and transformation

    Total params: ~27M
    """

    def __init__(self, config: OrcaWhiskeyConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hrm_hidden_size

        # Grid embedding
        self.grid_embedding = nn.Embedding(config.vocab_size, self.hidden_size)

        # Position encoding (applied to the embeddings in forward)
        self.rope = RotaryPositionalEmbedding(self.hidden_size)

        # Hierarchical modules
        self.high_level = HighLevelModule(config)
        self.low_level = LowLevelModule(config)

        # Output head
        self.output_norm = nn.LayerNorm(self.hidden_size)
        self.output_proj = nn.Linear(self.hidden_size, config.vocab_size)

        # Epistemic layers (confidence estimation)
        self.confidence_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.GELU(),
            nn.Linear(self.hidden_size // 2, 1),
            nn.Sigmoid()
        )

    def encode_grids(self, train_inputs: List[List[List[int]]], train_outputs: List[List[List[int]]]) -> torch.Tensor:
        """
        Encode training pairs as sequences

        INDUCTION: Learn patterns from examples (bottom-up)

        Each pair becomes ``encode_grid_flat(input) ++ encode_grid_flat(output)``.
        Returns a LongTensor of shape [num_pairs, pair_len] on the model's
        device. ``torch.stack`` requires every pair to encode to the same
        length and at least one pair.
        """
        device = next(self.parameters()).device
        sequences = []

        for inp, out in zip(train_inputs, train_outputs):
            # Pad and flatten
            inp_flat = grid_utils.encode_grid_flat(inp, self.config.max_grid_size)
            out_flat = grid_utils.encode_grid_flat(out, self.config.max_grid_size)

            # Concatenate input-output pair
            pair_seq = np.concatenate([inp_flat, out_flat])
            sequences.append(torch.as_tensor(pair_seq, dtype=torch.long))

        return torch.stack(sequences).to(device)

    def forward(self,
                train_inputs: List[List[List[int]]],
                train_outputs: List[List[List[int]]],
                test_input: List[List[int]],
                return_traces: bool = False) -> Dict[str, Any]:
        """
        Forward pass: Reason from examples to prediction

        EPISTEMIC REASONING:
        - INDUCTION: Learn from training pairs
        - ABSTRACTION: Find core transformation pattern
        - DEDUCTION: Apply to test input
        - INFERENCE: Bridge gaps in understanding
        - SKEPTICISM: Question if pattern is real or coincidence
        - FEAR: Known dangers (small sample), unknown unknowns
        - HUMILITY: Default to uncertainty

        Processes one task at a time (batch size 1).

        Returns:
            Dict with 'prediction' (decoded grid as nested lists),
            'confidence' (float), 'logits' ([1, seq_len, vocab_size]) and
            'num_examples'. When *return_traces* is true, it also holds the
            high/low-level attention maps and final states.
        """
        device = next(self.parameters()).device

        # Token IDs: training pairs [1, num_pairs, pair_len], test grid [1, seq_len]
        train_seq = self.encode_grids(train_inputs, train_outputs).unsqueeze(0)
        test_flat = grid_utils.encode_grid_flat(test_input, self.config.max_grid_size)
        test_seq = torch.as_tensor(test_flat, dtype=torch.long, device=device).unsqueeze(0)

        # Embed
        train_emb = self.grid_embedding(train_seq)  # [1, num_pairs, pair_len, hidden]
        test_emb = self.grid_embedding(test_seq)    # [1, seq_len, hidden]

        # POSITION: self-attention alone is permutation-equivariant, so without
        # this every cell of a given colour would look identical regardless of
        # where it sits in the grid. Each pair is rotated over its own positions
        # (input cells first, then output cells), and the test grid over
        # positions 0..seq_len-1. NOTE(review): this assumes encode_grid_flat
        # pads every grid to the same length, and that pair_len fits the RoPE
        # table (2048 positions by default). Confirm for the configured
        # max_grid_size.
        train_emb = self.rope(train_emb, seq_len=train_emb.shape[2])
        test_emb = self.rope(test_emb, seq_len=test_emb.shape[1])

        # Flatten training pairs into a single sequence: [1, num_pairs*pair_len, hidden]
        train_emb = train_emb.reshape(1, -1, self.hidden_size)

        # HIGH-LEVEL REASONING: Abstract pattern discovery
        h_state = None
        h_attention_maps = []

        for _ in range(self.config.hrm_h_cycles):
            train_emb, h_state, h_attn = self.high_level(train_emb, h_state)
            h_attention_maps.extend(h_attn)

        # Pool high-level understanding and broadcast it to every test position
        h_summary = train_emb.mean(dim=1, keepdim=True)  # [1, 1, hidden]
        h_guidance = h_summary.expand(-1, test_emb.shape[1], -1)

        # LOW-LEVEL EXECUTION: Apply transformation
        l_state = None
        l_attention_maps = []

        for _ in range(self.config.hrm_l_cycles):
            test_emb, l_state, l_attn = self.low_level(test_emb, h_guidance, l_state)
            l_attention_maps.extend(l_attn)

        # Output prediction
        output = self.output_norm(test_emb)
        logits = self.output_proj(output)  # [1, seq_len, vocab_size]
        predictions = torch.argmax(logits, dim=-1)  # [1, seq_len]

        # Confidence estimation (epistemic uncertainty), averaged over positions
        confidence_raw = self.confidence_head(output).mean()

        # EPISTEMIC CALIBRATION
        num_examples = len(train_inputs)

        # Known danger: Small sample size
        if num_examples < 3:
            confidence = confidence_raw * self.config.small_sample_penalty
        else:
            confidence = confidence_raw

        # Clamp to [min_confidence, max_confidence] (epistemic humility)
        confidence = torch.clamp(confidence, self.config.min_confidence, self.config.max_confidence)

        # Decode prediction
        pred_grid = grid_utils.decode_grid_flat(
            predictions.squeeze(0).cpu().numpy(),
            self.config.max_grid_size
        )

        result = {
            'prediction': pred_grid.tolist(),
            'confidence': confidence.item(),
            'logits': logits,
            'num_examples': num_examples
        }

        if return_traces:
            result['high_level_attention'] = h_attention_maps
            result['low_level_attention'] = l_attention_maps
            result['high_level_state'] = h_state
            result['low_level_state'] = l_state

        return result

    def compute_loss(self,
                     train_inputs: List[List[List[int]]],
                     train_outputs: List[List[List[int]]],
                     target: List[List[int]]) -> torch.Tensor:
        """
        Compute reconstruction loss for training

        DEDUCTION: Apply learned rules, measure error

        The last training input is treated as the test query, with the other
        pairs as context. Per-cell cross-entropy is taken against the flat
        encoding of *target*. This needs at least two training pairs, since
        encode_grids cannot stack an empty context.
        """
        device = next(self.parameters()).device

        target_flat = grid_utils.encode_grid_flat(target, self.config.max_grid_size)
        target_seq = torch.as_tensor(target_flat, dtype=torch.long, device=device).unsqueeze(0)

        # Hold out the last training input as the "test"
        output = self.forward(train_inputs[:-1], train_outputs[:-1], train_inputs[-1])
        logits = output['logits']

        # Cross-entropy loss
        loss = F.cross_entropy(
            logits.reshape(-1, self.config.vocab_size),
            target_seq.reshape(-1)
        )

        return loss


# Initialize Agent A
agent_a = HRMAgentA(config).to(config.device)

# Count parameters
total_params = sum(p.numel() for p in agent_a.parameters())
trainable_params = sum(p.numel() for p in agent_a.parameters() if p.requires_grad)

# Emoji restored from its mis-decoded (UTF-8 read as Mac Roman) form.
print("🤖 Agent A (HRM) initialized")
print(f"   Total params: {total_params:,} ({total_params/1e6:.1f}M)")
print(f"   Trainable params: {trainable_params:,}")
print(f"   Architecture: {config.hrm_num_layers_h}H + {config.hrm_num_layers_l}L layers")
print(f"   Cycles: {config.hrm_h_cycles} high-level, {config.hrm_l_cycles} low-level")
print(f"   Epistemic: Humility (max {config.max_confidence}), Small sample penalty ({config.small_sample_penalty})")
print()