In [2]:
# Install
try:
  !pip uninstall -qy geofractal geometricvocab
except:
  pass

!pip install -q git+https://github.com/AbstractEyes/geofractal.git

[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geofractal (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geometricvocab (pyproject.toml) ... [?25l[?25hdone


# mail - raw basic

In [None]:
"""
INCEPTIVE TOWER EXPERIMENT
===========================

Testing the InceptiveTower concept at small scale:
- Encoder A (Inception-style): what tower sees (its inceptive view)
- Encoder B (Simple conv): the "true" space tower must predict into
- Tower learns to translate from A's view to B's space via indoctrination

Indoctrination: mail from true encoder forces tower alignment until internalized.

MNIST for fast iteration.

Author: AbstractPhil + Claude
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Tuple, Dict
import math


# =============================================================================
# INCEPTION-STYLE ENCODER (The Inceptive View)
# =============================================================================

class InceptionBlock(nn.Module):
    """
    Branched convolutions - sees input differently than simple conv.

    Branches:
    - 1x1 conv (point-wise)
    - 3x3 conv (local)
    - 5x5 conv (wider context)
    - 3x3 max pool + 1x1 conv

    All concatenated → different feature organization than simple conv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # Distribute output channels across branches
        branch_channels = out_channels // 4
        remainder = out_channels % 4

        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, 1),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
        )

        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, 1),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_channels, branch_channels, 3, padding=1),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
        )

        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, 1),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_channels, branch_channels, 5, padding=2),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, branch_channels + remainder, 1),
            nn.BatchNorm2d(branch_channels + remainder),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        b1 = self.branch_1x1(x)
        b2 = self.branch_3x3(x)
        b3 = self.branch_5x5(x)
        b4 = self.branch_pool(x)
        return torch.cat([b1, b2, b3, b4], dim=1)


class InceptionEncoder(nn.Module):
    """
    Inception-style encoder for MNIST.

    Output: [B, latent_dim] features with inception's branched perspective.
    """

    def __init__(self, latent_dim: int = 128):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        self.inception1 = InceptionBlock(32, 64)
        self.pool1 = nn.MaxPool2d(2, 2)  # 28 -> 14

        self.inception2 = InceptionBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)  # 14 -> 7

        self.inception3 = InceptionBlock(128, 256)
        self.pool3 = nn.AdaptiveAvgPool2d(1)  # 7 -> 1

        self.fc = nn.Linear(256, latent_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.pool1(self.inception1(x))
        x = self.pool2(self.inception2(x))
        x = self.pool3(self.inception3(x))
        x = x.flatten(1)
        return self.fc(x)


# =============================================================================
# SIMPLE CONV ENCODER (The True Space)
# =============================================================================

class SimpleConvEncoder(nn.Module):
    """
    Standard conv encoder - the "true" target space.

    Different architecture = different feature organization.
    Tower must learn to translate from Inception view to this space.
    """

    def __init__(self, latent_dim: int = 128):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 28 -> 14
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14 -> 7
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), # 7 -> 4
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 4),                      # 4 -> 1
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(256, latent_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.encoder(x)
        x = x.flatten(1)
        return self.fc(x)


# =============================================================================
# INCEPTIVE TOWER (Sees Inception, Predicts Simple)
# =============================================================================

class InceptiveTower(nn.Module):
    """
    Tower with inceptive view, indoctrinated to true space.

    - Receives: Inception features (its inceptive view)
    - Must output: predictions in Simple encoder's space
    - Mail: Simple encoder's actual output (indoctrination signal)

    Tests whether geometric structure can bridge misaligned spaces.
    """

    def __init__(
        self,
        latent_dim: int = 128,
        hidden_dim: int = 256,
        num_layers: int = 2,
        use_mail: bool = True,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.use_mail = use_mail

        # Input projection (sees inception features)
        self.input_proj = nn.Linear(latent_dim, hidden_dim)

        # Mail integration (optional - the "answer sheet")
        if use_mail:
            self.mail_proj = nn.Linear(latent_dim, hidden_dim)
            self.mail_gate = nn.Linear(hidden_dim * 2, hidden_dim)

        # Processing layers (could be transformer, keeping simple for now)
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
            ])
        self.layers = nn.Sequential(*layers)

        # Output projection (to simple encoder space)
        self.output_proj = nn.Linear(hidden_dim, latent_dim)

    def forward(
        self,
        inception_features: Tensor,
        mail: Tensor = None,
        mail_dropout: float = 0.0,
    ) -> Tensor:
        """
        Args:
            inception_features: [B, latent_dim] from InceptionEncoder
            mail: [B, latent_dim] from SimpleConvEncoder (ground truth)
            mail_dropout: probability of dropping mail (curriculum)

        Returns:
            y_pred: [B, latent_dim] prediction in Simple encoder space
        """
        B = inception_features.shape[0]
        device = inception_features.device

        h = self.input_proj(inception_features)

        # Integrate mail if available and enabled
        if self.use_mail and mail is not None:
            # Optionally drop mail (for curriculum learning)
            if mail_dropout > 0 and self.training:
                mask = torch.rand(B, 1, device=device) > mail_dropout
                mail = mail * mask.float()

            mail_h = self.mail_proj(mail)
            combined = torch.cat([h, mail_h], dim=-1)
            gate = torch.sigmoid(self.mail_gate(combined))
            h = h * (1 - gate) + mail_h * gate

        h = self.layers(h)
        y_pred = self.output_proj(h)

        return y_pred


# =============================================================================
# CLASSIFIER HEAD (For evaluation)
# =============================================================================

class ClassifierHead(nn.Module):
    """Simple classifier to evaluate latent quality."""

    def __init__(self, latent_dim: int = 128, num_classes: int = 10):
        super().__init__()
        self.fc = nn.Linear(latent_dim, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc(x)


# =============================================================================
# EXPERIMENT HARNESS
# =============================================================================

class InceptiveTowerExperiment:
    """
    Full experiment comparing:
    1. Inception encoder alone (baseline)
    2. Simple encoder alone (target space baseline)
    3. Inceptive tower WITH mail (indoctrination)
    4. Inceptive tower WITHOUT mail (must generalize)
    """

    def __init__(
        self,
        latent_dim: int = 128,
        batch_size: int = 128,
        lr: float = 1e-3,
        device: str = 'cuda',
    ):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.latent_dim = latent_dim
        self.batch_size = batch_size

        # Data
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        self.train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
        self.test_data = datasets.MNIST('./data', train=False, transform=transform)

        self.train_loader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True, num_workers=2)
        self.test_loader = DataLoader(self.test_data, batch_size=batch_size, shuffle=False, num_workers=2)

        # Models
        self.inception_enc = InceptionEncoder(latent_dim).to(self.device)
        self.simple_enc = SimpleConvEncoder(latent_dim).to(self.device)
        self.inceptive_tower = InceptiveTower(latent_dim, use_mail=True).to(self.device)

        # Classifiers (one per latent source)
        self.clf_inception = ClassifierHead(latent_dim).to(self.device)
        self.clf_simple = ClassifierHead(latent_dim).to(self.device)
        self.clf_tower = ClassifierHead(latent_dim).to(self.device)

        # Optimizers
        self.opt_encoders = torch.optim.Adam(
            list(self.inception_enc.parameters()) +
            list(self.simple_enc.parameters()),
            lr=lr,
        )
        self.opt_tower = torch.optim.Adam(self.inceptive_tower.parameters(), lr=lr)
        self.opt_clfs = torch.optim.Adam(
            list(self.clf_inception.parameters()) +
            list(self.clf_simple.parameters()) +
            list(self.clf_tower.parameters()),
            lr=lr,
        )

    def train_epoch_encoders(self) -> Dict[str, float]:
        """Train encoders + classifiers (Phase 1)."""
        self.inception_enc.train()
        self.simple_enc.train()
        self.clf_inception.train()
        self.clf_simple.train()

        total_loss_inc = 0
        total_loss_sim = 0
        correct_inc = 0
        correct_sim = 0
        total = 0

        for images, labels in self.train_loader:
            images, labels = images.to(self.device), labels.to(self.device)

            self.opt_encoders.zero_grad()
            self.opt_clfs.zero_grad()

            # Forward through both encoders
            z_inception = self.inception_enc(images)
            z_simple = self.simple_enc(images)

            # Classify
            logits_inc = self.clf_inception(z_inception)
            logits_sim = self.clf_simple(z_simple)

            # Losses
            loss_inc = F.cross_entropy(logits_inc, labels)
            loss_sim = F.cross_entropy(logits_sim, labels)
            loss = loss_inc + loss_sim

            loss.backward()
            self.opt_encoders.step()
            self.opt_clfs.step()

            total_loss_inc += loss_inc.item() * images.size(0)
            total_loss_sim += loss_sim.item() * images.size(0)
            correct_inc += (logits_inc.argmax(1) == labels).sum().item()
            correct_sim += (logits_sim.argmax(1) == labels).sum().item()
            total += images.size(0)

        return {
            'loss_inception': total_loss_inc / total,
            'loss_simple': total_loss_sim / total,
            'acc_inception': correct_inc / total,
            'acc_simple': correct_sim / total,
        }

    def train_epoch_tower(self, mail_dropout: float = 0.0) -> Dict[str, float]:
        """Train inceptive tower to translate inception → simple (Phase 2)."""
        self.inception_enc.eval()
        self.simple_enc.eval()
        self.inceptive_tower.train()
        self.clf_tower.train()

        total_loss_trans = 0
        total_loss_clf = 0
        correct = 0
        total = 0

        for images, labels in self.train_loader:
            images, labels = images.to(self.device), labels.to(self.device)

            self.opt_tower.zero_grad()
            self.opt_clfs.zero_grad()

            with torch.no_grad():
                z_inception = self.inception_enc(images)
                z_simple = self.simple_enc(images)  # This is the mail / ground truth

            # Tower predicts simple space from inception view
            z_tower = self.inceptive_tower(z_inception, mail=z_simple, mail_dropout=mail_dropout)

            # Translation loss (tower output should match simple encoder)
            loss_trans = F.mse_loss(z_tower, z_simple)

            # Classification loss (tower output should be classifiable)
            logits = self.clf_tower(z_tower)
            loss_clf = F.cross_entropy(logits, labels)

            loss = loss_trans + loss_clf

            loss.backward()
            self.opt_tower.step()
            self.opt_clfs.step()

            total_loss_trans += loss_trans.item() * images.size(0)
            total_loss_clf += loss_clf.item() * images.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        return {
            'loss_translation': total_loss_trans / total,
            'loss_clf': total_loss_clf / total,
            'acc_tower': correct / total,
        }

    @torch.no_grad()
    def evaluate(self, use_mail: bool = True) -> Dict[str, float]:
        """Evaluate all three paths."""
        self.inception_enc.eval()
        self.simple_enc.eval()
        self.inceptive_tower.eval()
        self.clf_inception.eval()
        self.clf_simple.eval()
        self.clf_tower.eval()

        correct_inc = 0
        correct_sim = 0
        correct_tower = 0
        total = 0
        trans_error = 0

        for images, labels in self.test_loader:
            images, labels = images.to(self.device), labels.to(self.device)

            z_inception = self.inception_enc(images)
            z_simple = self.simple_enc(images)

            # Tower with or without mail
            mail = z_simple if use_mail else None
            z_tower = self.inceptive_tower(z_inception, mail=mail, mail_dropout=0.0)

            # Classify
            pred_inc = self.clf_inception(z_inception).argmax(1)
            pred_sim = self.clf_simple(z_simple).argmax(1)
            pred_tower = self.clf_tower(z_tower).argmax(1)

            correct_inc += (pred_inc == labels).sum().item()
            correct_sim += (pred_sim == labels).sum().item()
            correct_tower += (pred_tower == labels).sum().item()
            trans_error += F.mse_loss(z_tower, z_simple, reduction='sum').item()
            total += images.size(0)

        return {
            'acc_inception': correct_inc / total,
            'acc_simple': correct_sim / total,
            'acc_tower': correct_tower / total,
            'translation_mse': trans_error / total / self.latent_dim,
        }

    def run(
        self,
        encoder_epochs: int = 5,
        tower_epochs: int = 10,
        mail_curriculum: bool = True,
    ):
        """
        Full experiment:
        1. Train encoders (establish the two spaces)
        2. Train tower with indoctrination (learn translation)
        3. Evaluate tower without mail (test generalization)
        """
        print("=" * 60)
        print("INCEPTIVE TOWER EXPERIMENT")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Latent dim: {self.latent_dim}")
        print()

        # Phase 1: Train encoders
        print("Phase 1: Training encoders...")
        print("-" * 40)
        for epoch in range(encoder_epochs):
            metrics = self.train_epoch_encoders()
            print(f"Epoch {epoch+1}/{encoder_epochs} | "
                  f"Inc: {metrics['acc_inception']:.3f} | "
                  f"Sim: {metrics['acc_simple']:.3f}")

        eval_metrics = self.evaluate(use_mail=False)
        print(f"\nEncoder test accuracy:")
        print(f"  Inception: {eval_metrics['acc_inception']:.3f}")
        print(f"  Simple:    {eval_metrics['acc_simple']:.3f}")
        print()

        # Phase 2: Train tower
        print("Phase 2: Training inceptive tower with indoctrination...")
        print("-" * 40)
        for epoch in range(tower_epochs):
            # Curriculum: gradually reduce mail
            if mail_curriculum:
                mail_dropout = min(0.8, epoch / tower_epochs)
            else:
                mail_dropout = 0.0

            metrics = self.train_epoch_tower(mail_dropout=mail_dropout)
            print(f"Epoch {epoch+1}/{tower_epochs} | "
                  f"Trans: {metrics['loss_translation']:.4f} | "
                  f"Acc: {metrics['acc_tower']:.3f} | "
                  f"Mail drop: {mail_dropout:.1%}")

        print()

        # Phase 3: Evaluate
        print("Phase 3: Evaluation")
        print("-" * 40)

        metrics_with_mail = self.evaluate(use_mail=True)
        metrics_without_mail = self.evaluate(use_mail=False)

        print(f"\nWith mail (teacher forcing):")
        print(f"  Tower accuracy:    {metrics_with_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_with_mail['translation_mse']:.4f}")

        print(f"\nWithout mail (generalization):")
        print(f"  Tower accuracy:    {metrics_without_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_without_mail['translation_mse']:.4f}")

        print(f"\nBaselines:")
        print(f"  Inception direct:  {metrics_without_mail['acc_inception']:.3f}")
        print(f"  Simple direct:     {metrics_without_mail['acc_simple']:.3f}")

        # Success criteria
        print()
        print("=" * 60)
        if metrics_without_mail['acc_tower'] > 0.90:
            print("SUCCESS: Tower learned to translate without mail")
        elif metrics_without_mail['acc_tower'] > 0.80:
            print("PARTIAL: Tower shows some generalization")
        else:
            print("NEEDS WORK: Tower relies too heavily on mail")
        print("=" * 60)

        return {
            'with_mail': metrics_with_mail,
            'without_mail': metrics_without_mail,
        }


# =============================================================================
# MAIN
# =============================================================================

if __name__ == '__main__':
    experiment = InceptiveTowerExperiment(
        latent_dim=128,
        batch_size=128,
        lr=1e-3,
        device='cuda',
    )

    results = experiment.run(
        encoder_epochs=5,
        tower_epochs=15,
        mail_curriculum=True,
    )

100%|██████████| 9.91M/9.91M [00:00<00:00, 20.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 499kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.55MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.5MB/s]


INCEPTIVE TOWER EXPERIMENT
Device: cuda
Latent dim: 128

Phase 1: Training encoders...
----------------------------------------
Epoch 1/5 | Inc: 0.955 | Sim: 0.934
Epoch 2/5 | Inc: 0.989 | Sim: 0.982
Epoch 3/5 | Inc: 0.992 | Sim: 0.987
Epoch 4/5 | Inc: 0.994 | Sim: 0.991
Epoch 5/5 | Inc: 0.994 | Sim: 0.993

Encoder test accuracy:
  Inception: 0.982
  Simple:    0.986

Phase 2: Training inceptive tower with indoctrination...
----------------------------------------
Epoch 1/15 | Trans: 0.3597 | Acc: 0.995 | Mail drop: 0.0%
Epoch 2/15 | Trans: 0.1121 | Acc: 0.998 | Mail drop: 6.7%
Epoch 3/15 | Trans: 0.1233 | Acc: 0.998 | Mail drop: 13.3%
Epoch 4/15 | Trans: 0.1439 | Acc: 0.998 | Mail drop: 20.0%
Epoch 5/15 | Trans: 0.1698 | Acc: 0.998 | Mail drop: 26.7%
Epoch 6/15 | Trans: 0.1930 | Acc: 0.998 | Mail drop: 33.3%
Epoch 7/15 | Trans: 0.2108 | Acc: 0.998 | Mail drop: 40.0%
Epoch 8/15 | Trans: 0.2312 | Acc: 0.998 | Mail drop: 46.7%
Epoch 9/15 | Trans: 0.2473 | Acc: 0.998 | Mail drop: 53.3%
Ep

# mail 2 - t5 + clip_base - Cifar10 - too easy

In [None]:
"""
INCEPTIVE TOWER EXPERIMENT V2
==============================

Scaling up with real pretrained experts:
- CLIP ViT-base-patch16: Vision expert (provides mail)
- T5-base: Text expert (provides sequence features)
- Inception encoder: Inceptive view (sees differently)

Tower learns to translate inceptive view → expert space via indoctrination.

Features are cached to .pt files for speed.

Author: AbstractPhil + Claude
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from typing import Dict, Optional, Tuple
from transformers import (
    CLIPVisionModel,
    CLIPProcessor,
    T5EncoderModel,
    T5Tokenizer,
)
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import math


# =============================================================================
# INCEPTION ENCODER (The Inceptive View)
# =============================================================================

class InceptionBlock(nn.Module):
    """Multi-branch convolution block."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        branch_ch = out_channels // 4
        remainder = out_channels % 4

        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 3, padding=1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 5, padding=2),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, branch_ch + remainder, 1),
            nn.BatchNorm2d(branch_ch + remainder),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([
            self.branch_1x1(x),
            self.branch_3x3(x),
            self.branch_5x5(x),
            self.branch_pool(x),
        ], dim=1)


class InceptionEncoder(nn.Module):
    """
    Inception-style encoder that outputs CLIP-compatible dimension.

    Processes 224x224 RGB images → [B, 768] features
    """

    def __init__(self, out_features: int = 768):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),  # 224 -> 112
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),  # 112 -> 56
        )

        self.inception1 = InceptionBlock(64, 128)
        self.pool1 = nn.MaxPool2d(2, 2)  # 56 -> 28

        self.inception2 = InceptionBlock(128, 256)
        self.pool2 = nn.MaxPool2d(2, 2)  # 28 -> 14

        self.inception3 = InceptionBlock(256, 512)
        self.pool3 = nn.MaxPool2d(2, 2)  # 14 -> 7

        self.inception4 = InceptionBlock(512, 768)
        self.pool4 = nn.AdaptiveAvgPool2d(1)  # 7 -> 1

        self.fc = nn.Linear(768, out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.pool1(self.inception1(x))
        x = self.pool2(self.inception2(x))
        x = self.pool3(self.inception3(x))
        x = self.pool4(self.inception4(x))
        x = x.flatten(1)
        return self.fc(x)


# =============================================================================
# INCEPTIVE TOWER
# =============================================================================

class InceptiveTower(nn.Module):
    """
    Tower with inceptive view, indoctrinated to expert space.

    - Receives: Inception features (its inceptive view)
    - Optional: Text features from T5 for conditioning
    - Must output: predictions in CLIP expert space
    - Mail: CLIP's actual output (indoctrination signal)
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        use_text: bool = True,
    ):
        super().__init__()

        self.vision_dim = vision_dim
        self.text_dim = text_dim
        self.use_text = use_text

        # Vision input projection
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)

        # Text conditioning (optional)
        if use_text:
            self.text_proj = nn.Linear(text_dim, hidden_dim)
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads=8, batch_first=True
            )

        # Mail integration (indoctrination)
        self.mail_proj = nn.Linear(vision_dim, hidden_dim)
        self.mail_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid(),
        )

        # Processing layers
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            ])
        self.layers = nn.Sequential(*layers)

        # Output projection (to CLIP space)
        self.output_proj = nn.Linear(hidden_dim, vision_dim)

    def forward(
        self,
        inception_features: Tensor,
        text_features: Optional[Tensor] = None,
        mail: Optional[Tensor] = None,
        mail_dropout: float = 0.0,
    ) -> Tensor:
        """
        Args:
            inception_features: [B, vision_dim] from InceptionEncoder
            text_features: [B, seq_len, text_dim] from T5 (optional)
            mail: [B, vision_dim] from CLIP (ground truth)
            mail_dropout: probability of dropping mail
        """
        B = inception_features.shape[0]
        device = inception_features.device

        # Project vision
        h = self.vision_proj(inception_features)  # [B, hidden]

        # Condition on text if available
        if self.use_text and text_features is not None:
            text_h = self.text_proj(text_features)  # [B, seq_len, hidden]
            h_query = h.unsqueeze(1)  # [B, 1, hidden]
            h_attn, _ = self.cross_attn(h_query, text_h, text_h)
            h = h + h_attn.squeeze(1)  # [B, hidden]

        # Integrate mail (indoctrination)
        if mail is not None:
            if mail_dropout > 0 and self.training:
                mask = (torch.rand(B, 1, device=device) > mail_dropout).float()
                mail = mail * mask

            mail_h = self.mail_proj(mail)
            combined = torch.cat([h, mail_h], dim=-1)
            gate = self.mail_gate(combined)
            h = h * (1 - gate) + mail_h * gate

        # Process
        h = self.layers(h)

        # Output in CLIP space
        return self.output_proj(h)


# =============================================================================
# EXPERIMENT
# =============================================================================

class InceptiveTowerExperimentV2:
    """
    Experiment with real pretrained experts:
    - CLIP ViT-base-patch16 (vision expert)
    - T5-base (text expert)
    - Inception encoder (inceptive view)

    Features cached to .pt files for speed.
    """

    def __init__(
        self,
        batch_size: int = 32,
        lr: float = 1e-4,
        device: str = 'cuda',
        cache_dir: str = './cache',
    ):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

        # CIFAR-10 class names
        self.class_names = [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck'
        ]

        # Dims
        self.clip_dim = 768
        self.t5_dim = 768

        # Cache paths
        self.t5_cache_path = self.cache_dir / 't5_class_features.pt'
        self.clip_train_cache_path = self.cache_dir / 'clip_train_features.pt'
        self.clip_test_cache_path = self.cache_dir / 'clip_test_features.pt'
        self.images_train_cache_path = self.cache_dir / 'images_train.pt'
        self.images_test_cache_path = self.cache_dir / 'images_test.pt'
        self.labels_train_cache_path = self.cache_dir / 'labels_train.pt'
        self.labels_test_cache_path = self.cache_dir / 'labels_test.pt'

        # Load or create caches
        self._setup_caches()

        # Create trainable modules
        self.inception_enc = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        self.tower = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        # Classifier head
        self.classifier = nn.Linear(self.clip_dim, 10).to(self.device)

        print(f"  Inception encoder: {sum(p.numel() for p in self.inception_enc.parameters()):,} params")
        print(f"  Inceptive tower: {sum(p.numel() for p in self.tower.parameters()):,} params")

        # Optimizers
        self.opt_inception = torch.optim.AdamW(self.inception_enc.parameters(), lr=lr)
        self.opt_tower = torch.optim.AdamW(self.tower.parameters(), lr=lr)
        self.opt_clf = torch.optim.AdamW(self.classifier.parameters(), lr=lr)

        # Create dataloaders from cached tensors
        train_dataset = TensorDataset(
            self.images_train, self.labels_train, self.clip_train_features
        )
        test_dataset = TensorDataset(
            self.images_test, self.labels_test, self.clip_test_features
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )

    def _setup_caches(self):
        """Load or create all cached features."""

        # Data transforms
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ])

        # Check if all caches exist
        all_cached = all([
            self.t5_cache_path.exists(),
            self.clip_train_cache_path.exists(),
            self.clip_test_cache_path.exists(),
            self.images_train_cache_path.exists(),
            self.images_test_cache_path.exists(),
            self.labels_train_cache_path.exists(),
            self.labels_test_cache_path.exists(),
        ])

        if all_cached:
            print("Loading cached features...")
            self.t5_features = torch.load(self.t5_cache_path)
            self.clip_train_features = torch.load(self.clip_train_cache_path)
            self.clip_test_features = torch.load(self.clip_test_cache_path)
            self.images_train = torch.load(self.images_train_cache_path)
            self.images_test = torch.load(self.images_test_cache_path)
            self.labels_train = torch.load(self.labels_train_cache_path)
            self.labels_test = torch.load(self.labels_test_cache_path)
            print(f"  T5 features: {self.t5_features.shape}")
            print(f"  CLIP train: {self.clip_train_features.shape}")
            print(f"  CLIP test: {self.clip_test_features.shape}")
            return

        print("Creating feature caches (one-time)...")

        # Load raw datasets
        train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
        test_data = datasets.CIFAR10('./data', train=False, download=True, transform=transform)

        # ===== Cache T5 features (10 classes) =====
        print("  Caching T5 features for 10 classes...")
        t5_model = T5EncoderModel.from_pretrained("t5-base").to(self.device)
        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
        t5_model.eval()

        texts = [f"a photo of a {name}" for name in self.class_names]
        tokens = t5_tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True, max_length=32
        ).to(self.device)

        with torch.no_grad():
            t5_out = t5_model(**tokens)
            self.t5_features = t5_out.last_hidden_state.cpu()  # [10, seq_len, 768]

        torch.save(self.t5_features, self.t5_cache_path)
        print(f"    Saved: {self.t5_features.shape}")

        del t5_model, t5_tokenizer
        torch.cuda.empty_cache()

        # ===== Cache CLIP features =====
        print("  Loading CLIP...")
        clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
        clip_model.eval()

        def extract_clip_and_images(dataset, desc):
            loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
            all_clip = []
            all_images = []
            all_labels = []

            with torch.no_grad():
                for images, labels in tqdm(loader, desc=desc):
                    images = images.to(self.device)
                    clip_out = clip_model(pixel_values=images).pooler_output
                    all_clip.append(clip_out.cpu())
                    all_images.append(images.cpu())
                    all_labels.append(labels)

            return (
                torch.cat(all_clip, dim=0),
                torch.cat(all_images, dim=0),
                torch.cat(all_labels, dim=0),
            )

        # Train set
        self.clip_train_features, self.images_train, self.labels_train = extract_clip_and_images(
            train_data, "  Caching CLIP train"
        )
        torch.save(self.clip_train_features, self.clip_train_cache_path)
        torch.save(self.images_train, self.images_train_cache_path)
        torch.save(self.labels_train, self.labels_train_cache_path)
        print(f"    Train: {self.clip_train_features.shape}")

        # Test set
        self.clip_test_features, self.images_test, self.labels_test = extract_clip_and_images(
            test_data, "  Caching CLIP test"
        )
        torch.save(self.clip_test_features, self.clip_test_cache_path)
        torch.save(self.images_test, self.images_test_cache_path)
        torch.save(self.labels_test, self.labels_test_cache_path)
        print(f"    Test: {self.clip_test_features.shape}")

        del clip_model
        torch.cuda.empty_cache()
        print("  Caching complete!")

    def get_t5_features_for_labels(self, labels: Tensor) -> Tensor:
        """Look up cached T5 features by label index."""
        # self.t5_features: [10, seq_len, 768]
        # labels: [B]
        # return: [B, seq_len, 768]
        return self.t5_features[labels.cpu()]

    def train_epoch(
        self,
        mail_dropout: float = 0.0,
        use_text: bool = True,
    ) -> Dict[str, float]:
        """Train one epoch."""
        self.inception_enc.train()
        self.tower.train()
        self.classifier.train()

        total_loss_trans = 0
        total_loss_clf = 0
        correct_tower = 0
        correct_clip = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels, clip_features in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # Get T5 features from cache
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Get inceptive view
            inception_features = self.inception_enc(images)

            # Tower predicts CLIP space from inception view
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=clip_features,
                mail_dropout=mail_dropout,
            )

            # Losses
            loss_trans = F.mse_loss(tower_output, clip_features)

            # Classification from tower output
            logits = self.classifier(tower_output)
            loss_clf = F.cross_entropy(logits, labels)

            loss = loss_trans + loss_clf

            # Optimize
            self.opt_inception.zero_grad()
            self.opt_tower.zero_grad()
            self.opt_clf.zero_grad()
            loss.backward()
            self.opt_inception.step()
            self.opt_tower.step()
            self.opt_clf.step()

            # Track metrics
            total_loss_trans += loss_trans.item() * images.size(0)
            total_loss_clf += loss_clf.item() * images.size(0)
            correct_tower += (logits.argmax(1) == labels).sum().item()

            # CLIP baseline accuracy
            with torch.no_grad():
                clip_logits = self.classifier(clip_features)
                correct_clip += (clip_logits.argmax(1) == labels).sum().item()

            total += images.size(0)

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'acc': f'{correct_tower/total:.3f}',
            })

        return {
            'loss_translation': total_loss_trans / total,
            'loss_clf': total_loss_clf / total,
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
        }

    @torch.no_grad()
    def evaluate(self, use_mail: bool = True, use_text: bool = True) -> Dict[str, float]:
        """Evaluate all paths."""
        self.inception_enc.eval()
        self.tower.eval()
        self.classifier.eval()

        correct_tower = 0
        correct_clip = 0
        correct_inception = 0
        trans_error = 0
        total = 0

        desc = f"Eval (mail={use_mail}, text={use_text})"
        for images, labels, clip_features in tqdm(self.test_loader, desc=desc, leave=False):
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # T5 features
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Inception features
            inception_features = self.inception_enc(images)

            # Tower output
            mail = clip_features if use_mail else None
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=mail,
                mail_dropout=0.0,
            )

            # Classifications
            pred_tower = self.classifier(tower_output).argmax(1)
            pred_clip = self.classifier(clip_features).argmax(1)
            pred_inception = self.classifier(inception_features).argmax(1)

            correct_tower += (pred_tower == labels).sum().item()
            correct_clip += (pred_clip == labels).sum().item()
            correct_inception += (pred_inception == labels).sum().item()
            trans_error += F.mse_loss(tower_output, clip_features, reduction='sum').item()
            total += images.size(0)

        return {
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
            'acc_inception': correct_inception / total,
            'translation_mse': trans_error / total / self.clip_dim,
        }

    def run(
        self,
        epochs: int = 10,
        mail_curriculum: bool = True,
        use_text: bool = True,
    ):
        """Run full experiment."""
        print("\n" + "=" * 60)
        print("INCEPTIVE TOWER EXPERIMENT V2")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Vision expert: CLIP ViT-base-patch16 (768-dim)")
        print(f"Text expert: T5-base (768-dim)")
        print(f"Use text conditioning: {use_text}")
        print()

        # Training with indoctrination
        print("Training with indoctrination...")
        print("-" * 40)

        for epoch in range(epochs):
            if mail_curriculum:
                mail_dropout = min(0.8, epoch / epochs)
            else:
                mail_dropout = 0.0

            metrics = self.train_epoch(mail_dropout=mail_dropout, use_text=use_text)
            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Trans: {metrics['loss_translation']:.4f} | "
                  f"Tower: {metrics['acc_tower']:.3f} | "
                  f"CLIP: {metrics['acc_clip']:.3f} | "
                  f"Mail drop: {mail_dropout:.1%}")

        print()

        # Evaluation
        print("Evaluation")
        print("-" * 40)

        metrics_with_mail = self.evaluate(use_mail=True, use_text=use_text)
        metrics_without_mail = self.evaluate(use_mail=False, use_text=use_text)
        metrics_no_text = self.evaluate(use_mail=False, use_text=False)

        print(f"\nWith mail + text:")
        print(f"  Tower accuracy:    {metrics_with_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_with_mail['translation_mse']:.4f}")

        print(f"\nWithout mail (+ text):")
        print(f"  Tower accuracy:    {metrics_without_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_without_mail['translation_mse']:.4f}")

        print(f"\nWithout mail, without text:")
        print(f"  Tower accuracy:    {metrics_no_text['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_no_text['translation_mse']:.4f}")

        print(f"\nBaselines:")
        print(f"  CLIP direct:       {metrics_with_mail['acc_clip']:.3f}")
        print(f"  Inception direct:  {metrics_with_mail['acc_inception']:.3f}")

        # Success criteria
        print()
        print("=" * 60)
        no_mail_acc = metrics_without_mail['acc_tower']
        clip_acc = metrics_with_mail['acc_clip']

        if no_mail_acc >= clip_acc * 0.95:
            print(f"SUCCESS: Tower matches CLIP without mail ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        elif no_mail_acc >= clip_acc * 0.85:
            print(f"PARTIAL: Tower approaches CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        else:
            print(f"NEEDS WORK: Tower below CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        print("=" * 60)

        return {
            'with_mail': metrics_with_mail,
            'without_mail': metrics_without_mail,
            'no_text': metrics_no_text,
        }

    def run_side_by_side(
        self,
        epochs: int = 15,
        use_text: bool = True,
    ):
        """
        Train two towers side-by-side:
        - Tower A: WITH mail (indoctrinated)
        - Tower B: WITHOUT mail (blind)

        Compare learning curves directly.
        """
        print("\n" + "=" * 60)
        print("SIDE-BY-SIDE TRAINING: Mail vs No Mail")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Use text conditioning: {use_text}")
        print()

        # Create second tower (no mail)
        tower_no_mail = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        inception_no_mail = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        classifier_no_mail = nn.Linear(self.clip_dim, 10).to(self.device)

        opt_tower_no_mail = torch.optim.AdamW(tower_no_mail.parameters(), lr=1e-4)
        opt_inc_no_mail = torch.optim.AdamW(inception_no_mail.parameters(), lr=1e-4)
        opt_clf_no_mail = torch.optim.AdamW(classifier_no_mail.parameters(), lr=1e-4)

        print(f"Tower A (with mail): {sum(p.numel() for p in self.tower.parameters()):,} params")
        print(f"Tower B (no mail):   {sum(p.numel() for p in tower_no_mail.parameters()):,} params")
        print()

        print("Training...")
        print("-" * 70)
        print(f"{'Epoch':>5} | {'Mail Tower':^20} | {'No Mail Tower':^20} | {'CLIP':>6}")
        print(f"{'':>5} | {'Loss':>8} {'Acc':>8} | {'Loss':>8} {'Acc':>8} | {'Acc':>6}")
        print("-" * 70)

        history = {
            'mail_acc': [], 'mail_loss': [],
            'no_mail_acc': [], 'no_mail_loss': [],
            'clip_acc': [],
        }

        for epoch in range(epochs):
            # Curriculum for mail tower
            mail_dropout = min(0.8, epoch / epochs)

            # ===== Train both towers =====
            self.inception_enc.train()
            self.tower.train()
            self.classifier.train()
            inception_no_mail.train()
            tower_no_mail.train()
            classifier_no_mail.train()

            mail_loss_sum, mail_correct = 0, 0
            no_mail_loss_sum, no_mail_correct = 0, 0
            clip_correct = 0
            total = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
            for images, labels, clip_features in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)
                clip_features = clip_features.to(self.device)
                t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                # ===== Tower A: With Mail =====
                inc_feat = self.inception_enc(images)
                tower_out = self.tower(
                    inc_feat, text_features=t5_features,
                    mail=clip_features, mail_dropout=mail_dropout
                )
                loss_trans = F.mse_loss(tower_out, clip_features)
                logits = self.classifier(tower_out)
                loss_clf = F.cross_entropy(logits, labels)
                loss_mail = loss_trans + loss_clf

                self.opt_inception.zero_grad()
                self.opt_tower.zero_grad()
                self.opt_clf.zero_grad()
                loss_mail.backward()
                self.opt_inception.step()
                self.opt_tower.step()
                self.opt_clf.step()

                mail_loss_sum += loss_mail.item() * images.size(0)
                mail_correct += (logits.argmax(1) == labels).sum().item()

                # ===== Tower B: No Mail =====
                inc_feat_b = inception_no_mail(images)
                tower_out_b = tower_no_mail(
                    inc_feat_b, text_features=t5_features,
                    mail=None, mail_dropout=0.0  # Never gets mail
                )
                loss_trans_b = F.mse_loss(tower_out_b, clip_features)
                logits_b = classifier_no_mail(tower_out_b)
                loss_clf_b = F.cross_entropy(logits_b, labels)
                loss_no_mail = loss_trans_b + loss_clf_b

                opt_inc_no_mail.zero_grad()
                opt_tower_no_mail.zero_grad()
                opt_clf_no_mail.zero_grad()
                loss_no_mail.backward()
                opt_inc_no_mail.step()
                opt_tower_no_mail.step()
                opt_clf_no_mail.step()

                no_mail_loss_sum += loss_no_mail.item() * images.size(0)
                no_mail_correct += (logits_b.argmax(1) == labels).sum().item()

                # CLIP baseline
                with torch.no_grad():
                    clip_logits = self.classifier(clip_features)
                    clip_correct += (clip_logits.argmax(1) == labels).sum().item()

                total += images.size(0)

                pbar.set_postfix({
                    'mail': f'{mail_correct/total:.3f}',
                    'no_mail': f'{no_mail_correct/total:.3f}',
                })

            # Epoch stats
            mail_acc = mail_correct / total
            mail_loss = mail_loss_sum / total
            no_mail_acc = no_mail_correct / total
            no_mail_loss = no_mail_loss_sum / total
            clip_acc = clip_correct / total

            history['mail_acc'].append(mail_acc)
            history['mail_loss'].append(mail_loss)
            history['no_mail_acc'].append(no_mail_acc)
            history['no_mail_loss'].append(no_mail_loss)
            history['clip_acc'].append(clip_acc)

            print(f"{epoch+1:>5} | {mail_loss:>8.4f} {mail_acc:>8.3f} | {no_mail_loss:>8.4f} {no_mail_acc:>8.3f} | {clip_acc:>6.3f}")

        print("-" * 70)
        print()

        # ===== Final Evaluation =====
        print("Final Evaluation (test set)")
        print("-" * 40)

        def eval_tower(tower, inc_enc, clf, use_mail_eval):
            tower.eval()
            inc_enc.eval()
            clf.eval()
            correct = 0
            total = 0

            with torch.no_grad():
                for images, labels, clip_features in self.test_loader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    clip_features = clip_features.to(self.device)
                    t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                    inc_feat = inc_enc(images)
                    mail = clip_features if use_mail_eval else None
                    tower_out = tower(inc_feat, text_features=t5_features, mail=mail)
                    pred = clf(tower_out).argmax(1)
                    correct += (pred == labels).sum().item()
                    total += images.size(0)

            return correct / total

        # Mail tower - with and without mail at eval
        mail_tower_with = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=True)
        mail_tower_without = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=False)

        # No-mail tower
        no_mail_tower = eval_tower(tower_no_mail, inception_no_mail, classifier_no_mail, use_mail_eval=False)

        print(f"\nMail Tower (trained with indoctrination):")
        print(f"  With mail at eval:    {mail_tower_with:.3f}")
        print(f"  Without mail at eval: {mail_tower_without:.3f}")

        print(f"\nNo-Mail Tower (trained blind):")
        print(f"  Accuracy:             {no_mail_tower:.3f}")

        print(f"\nBaseline:")
        print(f"  CLIP direct:          {history['clip_acc'][-1]:.3f}")

        print()
        print("=" * 60)
        diff = mail_tower_without - no_mail_tower
        if diff > 0.02:
            print(f"INDOCTRINATION WINS: +{diff:.1%} accuracy from mail training")
        elif diff > -0.02:
            print(f"COMPARABLE: Both towers within 2% ({mail_tower_without:.3f} vs {no_mail_tower:.3f})")
        else:
            print(f"NO-MAIL WINS: Blind tower better by {-diff:.1%}")
        print("=" * 60)

        return history


# =============================================================================
# MAIN
# =============================================================================

if __name__ == '__main__':
    experiment = InceptiveTowerExperimentV2(
        batch_size=64,
        lr=1e-4,
        device='cuda',
        cache_dir='./cache',
    )

    # Side-by-side comparison
    history = experiment.run_side_by_side(
        epochs=15,
        use_text=True,
    )

Loading cached features...
  T5 features: torch.Size([10, 10, 768])
  CLIP train: torch.Size([50000, 768])
  CLIP test: torch.Size([10000, 768])
  Inception encoder: 3,157,440 params
  Inceptive tower: 12,601,088 params

SIDE-BY-SIDE TRAINING: Mail vs No Mail
Device: cuda
Use text conditioning: True

Tower A (with mail): 12,601,088 params
Tower B (no mail):   12,601,088 params

Training...
----------------------------------------------------------------------
Epoch |      Mail Tower      |    No Mail Tower     |   CLIP
      |     Loss      Acc |     Loss      Acc |    Acc
----------------------------------------------------------------------




    1 |   0.3756    0.992 |   0.4552    0.977 |  0.607




    2 |   0.1709    1.000 |   0.2642    1.000 |  0.913




    3 |   0.1360    1.000 |   0.2322    1.000 |  0.928




    4 |   0.1261    1.000 |   0.2147    1.000 |  0.934




    5 |   0.1255    1.000 |   0.2028    1.000 |  0.937




KeyboardInterrupt: 

# t5-base + clip_b - mail cifar100

## learned too quickly

In [None]:
"""
INCEPTIVE TOWER EXPERIMENT V2
==============================

Scaling up with real pretrained experts:
- CLIP ViT-base-patch16: Vision expert (provides mail)
- T5-base: Text expert (provides sequence features)
- Inception encoder: Inceptive view (sees differently)

Tower learns to translate inceptive view → expert space via indoctrination.

Features are cached to .pt files for speed.

Author: AbstractPhil + Claude
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from typing import Dict, Optional, Tuple
from transformers import (
    CLIPVisionModel,
    CLIPProcessor,
    T5EncoderModel,
    T5Tokenizer,
)
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import math


# =============================================================================
# INCEPTION ENCODER (The Inceptive View)
# =============================================================================

class InceptionBlock(nn.Module):
    """Multi-branch convolution block."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        branch_ch = out_channels // 4
        remainder = out_channels % 4

        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 3, padding=1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 5, padding=2),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, branch_ch + remainder, 1),
            nn.BatchNorm2d(branch_ch + remainder),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([
            self.branch_1x1(x),
            self.branch_3x3(x),
            self.branch_5x5(x),
            self.branch_pool(x),
        ], dim=1)


class InceptionEncoder(nn.Module):
    """
    Inception-style encoder that outputs CLIP-compatible dimension.

    Processes 224x224 RGB images → [B, 768] features
    """

    def __init__(self, out_features: int = 768):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),  # 224 -> 112
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),  # 112 -> 56
        )

        self.inception1 = InceptionBlock(64, 128)
        self.pool1 = nn.MaxPool2d(2, 2)  # 56 -> 28

        self.inception2 = InceptionBlock(128, 256)
        self.pool2 = nn.MaxPool2d(2, 2)  # 28 -> 14

        self.inception3 = InceptionBlock(256, 512)
        self.pool3 = nn.MaxPool2d(2, 2)  # 14 -> 7

        self.inception4 = InceptionBlock(512, 768)
        self.pool4 = nn.AdaptiveAvgPool2d(1)  # 7 -> 1

        self.fc = nn.Linear(768, out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.pool1(self.inception1(x))
        x = self.pool2(self.inception2(x))
        x = self.pool3(self.inception3(x))
        x = self.pool4(self.inception4(x))
        x = x.flatten(1)
        return self.fc(x)


# =============================================================================
# INCEPTIVE TOWER
# =============================================================================

class InceptiveTower(nn.Module):
    """
    Tower with inceptive view, indoctrinated to expert space.

    - Receives: Inception features (its inceptive view)
    - Optional: Text features from T5 for conditioning
    - Must output: predictions in CLIP expert space
    - Mail: CLIP's actual output (indoctrination signal)
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        use_text: bool = True,
    ):
        super().__init__()

        self.vision_dim = vision_dim
        self.text_dim = text_dim
        self.use_text = use_text

        # Vision input projection
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)

        # Text conditioning (optional)
        if use_text:
            self.text_proj = nn.Linear(text_dim, hidden_dim)
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads=8, batch_first=True
            )

        # Mail integration (indoctrination)
        self.mail_proj = nn.Linear(vision_dim, hidden_dim)
        self.mail_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid(),
        )

        # Processing layers
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            ])
        self.layers = nn.Sequential(*layers)

        # Output projection (to CLIP space)
        self.output_proj = nn.Linear(hidden_dim, vision_dim)

    def forward(
        self,
        inception_features: Tensor,
        text_features: Optional[Tensor] = None,
        mail: Optional[Tensor] = None,
        mail_dropout: float = 0.0,
    ) -> Tensor:
        """
        Args:
            inception_features: [B, vision_dim] from InceptionEncoder
            text_features: [B, seq_len, text_dim] from T5 (optional)
            mail: [B, vision_dim] from CLIP (ground truth)
            mail_dropout: probability of dropping mail
        """
        B = inception_features.shape[0]
        device = inception_features.device

        # Project vision
        h = self.vision_proj(inception_features)  # [B, hidden]

        # Condition on text if available
        if self.use_text and text_features is not None:
            text_h = self.text_proj(text_features)  # [B, seq_len, hidden]
            h_query = h.unsqueeze(1)  # [B, 1, hidden]
            h_attn, _ = self.cross_attn(h_query, text_h, text_h)
            h = h + h_attn.squeeze(1)  # [B, hidden]

        # Integrate mail (indoctrination)
        if mail is not None:
            if mail_dropout > 0 and self.training:
                mask = (torch.rand(B, 1, device=device) > mail_dropout).float()
                mail = mail * mask

            mail_h = self.mail_proj(mail)
            combined = torch.cat([h, mail_h], dim=-1)
            gate = self.mail_gate(combined)
            h = h * (1 - gate) + mail_h * gate

        # Process
        h = self.layers(h)

        # Output in CLIP space
        return self.output_proj(h)


# =============================================================================
# EXPERIMENT
# =============================================================================

class InceptiveTowerExperimentV2:
    """
    Experiment with real pretrained experts:
    - CLIP ViT-base-patch16 (vision expert)
    - T5-base (text expert)
    - Inception encoder (inceptive view)

    Features cached to .pt files for speed.
    """

    def __init__(
        self,
        batch_size: int = 32,
        lr: float = 1e-4,
        device: str = 'cuda',
        cache_dir: str = './cache',
    ):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

        # CIFAR-100 class names
        self.class_names = [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]
        self.num_classes = 100

        # Dims
        self.clip_dim = 768
        self.t5_dim = 768

        # Cache paths (CIFAR-100)
        self.t5_cache_path = self.cache_dir / 't5_class_features_cifar100.pt'
        self.clip_train_cache_path = self.cache_dir / 'clip_train_features_cifar100.pt'
        self.clip_test_cache_path = self.cache_dir / 'clip_test_features_cifar100.pt'
        self.images_train_cache_path = self.cache_dir / 'images_train_cifar100.pt'
        self.images_test_cache_path = self.cache_dir / 'images_test_cifar100.pt'
        self.labels_train_cache_path = self.cache_dir / 'labels_train_cifar100.pt'
        self.labels_test_cache_path = self.cache_dir / 'labels_test_cifar100.pt'

        # Load or create caches
        self._setup_caches()

        # Create trainable modules
        self.inception_enc = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        self.tower = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        # Classifier head
        self.classifier = nn.Linear(self.clip_dim, self.num_classes).to(self.device)

        print(f"  Inception encoder: {sum(p.numel() for p in self.inception_enc.parameters()):,} params")
        print(f"  Inceptive tower: {sum(p.numel() for p in self.tower.parameters()):,} params")

        # Optimizers
        self.opt_inception = torch.optim.AdamW(self.inception_enc.parameters(), lr=lr)
        self.opt_tower = torch.optim.AdamW(self.tower.parameters(), lr=lr)
        self.opt_clf = torch.optim.AdamW(self.classifier.parameters(), lr=lr)

        # Create dataloaders from cached tensors
        train_dataset = TensorDataset(
            self.images_train, self.labels_train, self.clip_train_features
        )
        test_dataset = TensorDataset(
            self.images_test, self.labels_test, self.clip_test_features
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )

    def _setup_caches(self):
        """Load or create all cached features."""

        # Data transforms
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ])

        # Check if all caches exist
        all_cached = all([
            self.t5_cache_path.exists(),
            self.clip_train_cache_path.exists(),
            self.clip_test_cache_path.exists(),
            self.images_train_cache_path.exists(),
            self.images_test_cache_path.exists(),
            self.labels_train_cache_path.exists(),
            self.labels_test_cache_path.exists(),
        ])

        if all_cached:
            print("Loading cached features...")
            self.t5_features = torch.load(self.t5_cache_path)
            self.clip_train_features = torch.load(self.clip_train_cache_path)
            self.clip_test_features = torch.load(self.clip_test_cache_path)
            self.images_train = torch.load(self.images_train_cache_path)
            self.images_test = torch.load(self.images_test_cache_path)
            self.labels_train = torch.load(self.labels_train_cache_path)
            self.labels_test = torch.load(self.labels_test_cache_path)
            print(f"  T5 features: {self.t5_features.shape}")
            print(f"  CLIP train: {self.clip_train_features.shape}")
            print(f"  CLIP test: {self.clip_test_features.shape}")
            return

        print("Creating feature caches (one-time)...")

        # Load raw datasets (CIFAR-100)
        train_data = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
        test_data = datasets.CIFAR100('./data', train=False, download=True, transform=transform)

        # ===== Cache T5 features (10 classes) =====
        print("  Caching T5 features for 10 classes...")
        t5_model = T5EncoderModel.from_pretrained("t5-base").to(self.device)
        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
        t5_model.eval()

        texts = [f"a photo of a {name}" for name in self.class_names]
        tokens = t5_tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True, max_length=32
        ).to(self.device)

        with torch.no_grad():
            t5_out = t5_model(**tokens)
            self.t5_features = t5_out.last_hidden_state.cpu()  # [10, seq_len, 768]

        torch.save(self.t5_features, self.t5_cache_path)
        print(f"    Saved: {self.t5_features.shape}")

        del t5_model, t5_tokenizer
        torch.cuda.empty_cache()

        # ===== Cache CLIP features =====
        print("  Loading CLIP...")
        clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
        clip_model.eval()

        def extract_clip_and_images(dataset, desc):
            loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
            all_clip = []
            all_images = []
            all_labels = []

            with torch.no_grad():
                for images, labels in tqdm(loader, desc=desc):
                    images = images.to(self.device)
                    clip_out = clip_model(pixel_values=images).pooler_output
                    all_clip.append(clip_out.cpu())
                    all_images.append(images.cpu())
                    all_labels.append(labels)

            return (
                torch.cat(all_clip, dim=0),
                torch.cat(all_images, dim=0),
                torch.cat(all_labels, dim=0),
            )

        # Train set
        self.clip_train_features, self.images_train, self.labels_train = extract_clip_and_images(
            train_data, "  Caching CLIP train"
        )
        torch.save(self.clip_train_features, self.clip_train_cache_path)
        torch.save(self.images_train, self.images_train_cache_path)
        torch.save(self.labels_train, self.labels_train_cache_path)
        print(f"    Train: {self.clip_train_features.shape}")

        # Test set
        self.clip_test_features, self.images_test, self.labels_test = extract_clip_and_images(
            test_data, "  Caching CLIP test"
        )
        torch.save(self.clip_test_features, self.clip_test_cache_path)
        torch.save(self.images_test, self.images_test_cache_path)
        torch.save(self.labels_test, self.labels_test_cache_path)
        print(f"    Test: {self.clip_test_features.shape}")

        del clip_model
        torch.cuda.empty_cache()
        print("  Caching complete!")

    def get_t5_features_for_labels(self, labels: Tensor) -> Tensor:
        """Look up cached T5 features by label index."""
        # self.t5_features: [10, seq_len, 768]
        # labels: [B]
        # return: [B, seq_len, 768]
        return self.t5_features[labels.cpu()]

    def train_epoch(
        self,
        mail_dropout: float = 0.0,
        use_text: bool = True,
    ) -> Dict[str, float]:
        """Train one epoch."""
        self.inception_enc.train()
        self.tower.train()
        self.classifier.train()

        total_loss_trans = 0
        total_loss_clf = 0
        correct_tower = 0
        correct_clip = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels, clip_features in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # Get T5 features from cache
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Get inceptive view
            inception_features = self.inception_enc(images)

            # Tower predicts CLIP space from inception view
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=clip_features,
                mail_dropout=mail_dropout,
            )

            # Losses
            loss_trans = F.mse_loss(tower_output, clip_features)

            # Classification from tower output
            logits = self.classifier(tower_output)
            loss_clf = F.cross_entropy(logits, labels)

            loss = loss_trans + loss_clf

            # Optimize
            self.opt_inception.zero_grad()
            self.opt_tower.zero_grad()
            self.opt_clf.zero_grad()
            loss.backward()
            self.opt_inception.step()
            self.opt_tower.step()
            self.opt_clf.step()

            # Track metrics
            total_loss_trans += loss_trans.item() * images.size(0)
            total_loss_clf += loss_clf.item() * images.size(0)
            correct_tower += (logits.argmax(1) == labels).sum().item()

            # CLIP baseline accuracy
            with torch.no_grad():
                clip_logits = self.classifier(clip_features)
                correct_clip += (clip_logits.argmax(1) == labels).sum().item()

            total += images.size(0)

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'acc': f'{correct_tower/total:.3f}',
            })

        return {
            'loss_translation': total_loss_trans / total,
            'loss_clf': total_loss_clf / total,
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
        }

    @torch.no_grad()
    def evaluate(self, use_mail: bool = True, use_text: bool = True) -> Dict[str, float]:
        """Evaluate all paths."""
        self.inception_enc.eval()
        self.tower.eval()
        self.classifier.eval()

        correct_tower = 0
        correct_clip = 0
        correct_inception = 0
        trans_error = 0
        total = 0

        desc = f"Eval (mail={use_mail}, text={use_text})"
        for images, labels, clip_features in tqdm(self.test_loader, desc=desc, leave=False):
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # T5 features
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Inception features
            inception_features = self.inception_enc(images)

            # Tower output
            mail = clip_features if use_mail else None
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=mail,
                mail_dropout=0.0,
            )

            # Classifications
            pred_tower = self.classifier(tower_output).argmax(1)
            pred_clip = self.classifier(clip_features).argmax(1)
            pred_inception = self.classifier(inception_features).argmax(1)

            correct_tower += (pred_tower == labels).sum().item()
            correct_clip += (pred_clip == labels).sum().item()
            correct_inception += (pred_inception == labels).sum().item()
            trans_error += F.mse_loss(tower_output, clip_features, reduction='sum').item()
            total += images.size(0)

        return {
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
            'acc_inception': correct_inception / total,
            'translation_mse': trans_error / total / self.clip_dim,
        }

    def run(
        self,
        epochs: int = 10,
        mail_curriculum: bool = True,
        use_text: bool = True,
    ):
        """Run full experiment."""
        print("\n" + "=" * 60)
        print("INCEPTIVE TOWER EXPERIMENT V2")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Vision expert: CLIP ViT-base-patch16 (768-dim)")
        print(f"Text expert: T5-base (768-dim)")
        print(f"Use text conditioning: {use_text}")
        print()

        # Training with indoctrination
        print("Training with indoctrination...")
        print("-" * 40)

        for epoch in range(epochs):
            if mail_curriculum:
                mail_dropout = min(0.8, epoch / epochs)
            else:
                mail_dropout = 0.0

            metrics = self.train_epoch(mail_dropout=mail_dropout, use_text=use_text)
            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Trans: {metrics['loss_translation']:.4f} | "
                  f"Tower: {metrics['acc_tower']:.3f} | "
                  f"CLIP: {metrics['acc_clip']:.3f} | "
                  f"Mail drop: {mail_dropout:.1%}")

        print()

        # Evaluation
        print("Evaluation")
        print("-" * 40)

        metrics_with_mail = self.evaluate(use_mail=True, use_text=use_text)
        metrics_without_mail = self.evaluate(use_mail=False, use_text=use_text)
        metrics_no_text = self.evaluate(use_mail=False, use_text=False)

        print(f"\nWith mail + text:")
        print(f"  Tower accuracy:    {metrics_with_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_with_mail['translation_mse']:.4f}")

        print(f"\nWithout mail (+ text):")
        print(f"  Tower accuracy:    {metrics_without_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_without_mail['translation_mse']:.4f}")

        print(f"\nWithout mail, without text:")
        print(f"  Tower accuracy:    {metrics_no_text['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_no_text['translation_mse']:.4f}")

        print(f"\nBaselines:")
        print(f"  CLIP direct:       {metrics_with_mail['acc_clip']:.3f}")
        print(f"  Inception direct:  {metrics_with_mail['acc_inception']:.3f}")

        # Success criteria
        print()
        print("=" * 60)
        no_mail_acc = metrics_without_mail['acc_tower']
        clip_acc = metrics_with_mail['acc_clip']

        if no_mail_acc >= clip_acc * 0.95:
            print(f"SUCCESS: Tower matches CLIP without mail ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        elif no_mail_acc >= clip_acc * 0.85:
            print(f"PARTIAL: Tower approaches CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        else:
            print(f"NEEDS WORK: Tower below CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        print("=" * 60)

        return {
            'with_mail': metrics_with_mail,
            'without_mail': metrics_without_mail,
            'no_text': metrics_no_text,
        }

    def run_side_by_side(
        self,
        epochs: int = 15,
        use_text: bool = True,
    ):
        """
        Train two towers side-by-side:
        - Tower A: WITH mail (indoctrinated)
        - Tower B: WITHOUT mail (blind)

        Compare learning curves directly.
        """
        print("\n" + "=" * 60)
        print("SIDE-BY-SIDE TRAINING: Mail vs No Mail")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Use text conditioning: {use_text}")
        print()

        # Create second tower (no mail)
        tower_no_mail = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        inception_no_mail = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        classifier_no_mail = nn.Linear(self.clip_dim, self.num_classes).to(self.device)

        opt_tower_no_mail = torch.optim.AdamW(tower_no_mail.parameters(), lr=1e-4)
        opt_inc_no_mail = torch.optim.AdamW(inception_no_mail.parameters(), lr=1e-4)
        opt_clf_no_mail = torch.optim.AdamW(classifier_no_mail.parameters(), lr=1e-4)

        print(f"Tower A (with mail): {sum(p.numel() for p in self.tower.parameters()):,} params")
        print(f"Tower B (no mail):   {sum(p.numel() for p in tower_no_mail.parameters()):,} params")
        print()

        print("Training...")
        print("-" * 70)
        print(f"{'Epoch':>5} | {'Mail Tower':^20} | {'No Mail Tower':^20} | {'CLIP':>6}")
        print(f"{'':>5} | {'Loss':>8} {'Acc':>8} | {'Loss':>8} {'Acc':>8} | {'Acc':>6}")
        print("-" * 70)

        history = {
            'mail_acc': [], 'mail_loss': [],
            'no_mail_acc': [], 'no_mail_loss': [],
            'clip_acc': [],
        }

        for epoch in range(epochs):
            # Curriculum for mail tower
            mail_dropout = min(0.8, epoch / epochs)

            # ===== Train both towers =====
            self.inception_enc.train()
            self.tower.train()
            self.classifier.train()
            inception_no_mail.train()
            tower_no_mail.train()
            classifier_no_mail.train()

            mail_loss_sum, mail_correct = 0, 0
            no_mail_loss_sum, no_mail_correct = 0, 0
            clip_correct = 0
            total = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
            for images, labels, clip_features in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)
                clip_features = clip_features.to(self.device)
                t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                # ===== Tower A: With Mail =====
                inc_feat = self.inception_enc(images)
                tower_out = self.tower(
                    inc_feat, text_features=t5_features,
                    mail=clip_features, mail_dropout=mail_dropout
                )
                loss_trans = F.mse_loss(tower_out, clip_features)
                logits = self.classifier(tower_out)
                loss_clf = F.cross_entropy(logits, labels)
                loss_mail = loss_trans + loss_clf

                self.opt_inception.zero_grad()
                self.opt_tower.zero_grad()
                self.opt_clf.zero_grad()
                loss_mail.backward()
                self.opt_inception.step()
                self.opt_tower.step()
                self.opt_clf.step()

                mail_loss_sum += loss_mail.item() * images.size(0)
                mail_correct += (logits.argmax(1) == labels).sum().item()

                # ===== Tower B: No Mail =====
                inc_feat_b = inception_no_mail(images)
                tower_out_b = tower_no_mail(
                    inc_feat_b, text_features=t5_features,
                    mail=None, mail_dropout=0.0  # Never gets mail
                )
                loss_trans_b = F.mse_loss(tower_out_b, clip_features)
                logits_b = classifier_no_mail(tower_out_b)
                loss_clf_b = F.cross_entropy(logits_b, labels)
                loss_no_mail = loss_trans_b + loss_clf_b

                opt_inc_no_mail.zero_grad()
                opt_tower_no_mail.zero_grad()
                opt_clf_no_mail.zero_grad()
                loss_no_mail.backward()
                opt_inc_no_mail.step()
                opt_tower_no_mail.step()
                opt_clf_no_mail.step()

                no_mail_loss_sum += loss_no_mail.item() * images.size(0)
                no_mail_correct += (logits_b.argmax(1) == labels).sum().item()

                # CLIP baseline
                with torch.no_grad():
                    clip_logits = self.classifier(clip_features)
                    clip_correct += (clip_logits.argmax(1) == labels).sum().item()

                total += images.size(0)

                pbar.set_postfix({
                    'mail': f'{mail_correct/total:.3f}',
                    'no_mail': f'{no_mail_correct/total:.3f}',
                })

            # Epoch stats
            mail_acc = mail_correct / total
            mail_loss = mail_loss_sum / total
            no_mail_acc = no_mail_correct / total
            no_mail_loss = no_mail_loss_sum / total
            clip_acc = clip_correct / total

            history['mail_acc'].append(mail_acc)
            history['mail_loss'].append(mail_loss)
            history['no_mail_acc'].append(no_mail_acc)
            history['no_mail_loss'].append(no_mail_loss)
            history['clip_acc'].append(clip_acc)

            print(f"{epoch+1:>5} | {mail_loss:>8.4f} {mail_acc:>8.3f} | {no_mail_loss:>8.4f} {no_mail_acc:>8.3f} | {clip_acc:>6.3f}")

        print("-" * 70)
        print()

        # ===== Final Evaluation =====
        print("Final Evaluation (test set)")
        print("-" * 40)

        def eval_tower(tower, inc_enc, clf, use_mail_eval):
            tower.eval()
            inc_enc.eval()
            clf.eval()
            correct = 0
            total = 0

            with torch.no_grad():
                for images, labels, clip_features in self.test_loader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    clip_features = clip_features.to(self.device)
                    t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                    inc_feat = inc_enc(images)
                    mail = clip_features if use_mail_eval else None
                    tower_out = tower(inc_feat, text_features=t5_features, mail=mail)
                    pred = clf(tower_out).argmax(1)
                    correct += (pred == labels).sum().item()
                    total += images.size(0)

            return correct / total

        # Mail tower - with and without mail at eval
        mail_tower_with = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=True)
        mail_tower_without = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=False)

        # No-mail tower
        no_mail_tower = eval_tower(tower_no_mail, inception_no_mail, classifier_no_mail, use_mail_eval=False)

        print(f"\nMail Tower (trained with indoctrination):")
        print(f"  With mail at eval:    {mail_tower_with:.3f}")
        print(f"  Without mail at eval: {mail_tower_without:.3f}")

        print(f"\nNo-Mail Tower (trained blind):")
        print(f"  Accuracy:             {no_mail_tower:.3f}")

        print(f"\nBaseline:")
        print(f"  CLIP direct:          {history['clip_acc'][-1]:.3f}")

        print()
        print("=" * 60)
        diff = mail_tower_without - no_mail_tower
        if diff > 0.02:
            print(f"INDOCTRINATION WINS: +{diff:.1%} accuracy from mail training")
        elif diff > -0.02:
            print(f"COMPARABLE: Both towers within 2% ({mail_tower_without:.3f} vs {no_mail_tower:.3f})")
        else:
            print(f"NO-MAIL WINS: Blind tower better by {-diff:.1%}")
        print("=" * 60)

        return history


# =============================================================================
# MAIN
# =============================================================================

if __name__ == '__main__':
    experiment = InceptiveTowerExperimentV2(
        batch_size=64,
        lr=1e-4,
        device='cuda',
        cache_dir='./cache',
    )

    # Side-by-side comparison
    history = experiment.run_side_by_side(
        epochs=5,
        use_text=True,
    )

Creating feature caches (one-time)...


100%|██████████| 169M/169M [00:05<00:00, 29.1MB/s]


  Caching T5 features for 10 classes...
    Saved: torch.Size([100, 12, 768])
  Loading CLIP...


  Caching CLIP train: 100%|██████████| 782/782 [02:24<00:00,  5.43it/s]


    Train: torch.Size([50000, 768])


  Caching CLIP test: 100%|██████████| 157/157 [00:33<00:00,  4.73it/s]


    Test: torch.Size([10000, 768])
  Caching complete!
  Inception encoder: 3,157,440 params
  Inceptive tower: 12,601,088 params

SIDE-BY-SIDE TRAINING: Mail vs No Mail
Device: cuda
Use text conditioning: True

Tower A (with mail): 12,601,088 params
Tower B (no mail):   12,601,088 params

Training...
----------------------------------------------------------------------
Epoch |      Mail Tower      |    No Mail Tower     |   CLIP
      |     Loss      Acc |     Loss      Acc |    Acc
----------------------------------------------------------------------




    1 |   1.0090    0.924 |   1.0906    0.903 |  0.037




    2 |   0.3551    1.000 |   0.3889    1.000 |  0.311




    3 |   0.2852    1.000 |   0.3189    1.000 |  0.530




    4 |   0.2560    1.000 |   0.2795    1.000 |  0.618




    5 |   0.2437    1.000 |   0.2543    1.000 |  0.660
----------------------------------------------------------------------

Final Evaluation (test set)
----------------------------------------

Mail Tower (trained with indoctrination):
  With mail at eval:    1.000
  Without mail at eval: 1.000

No-Mail Tower (trained blind):
  Accuracy:             1.000

Baseline:
  CLIP direct:          0.660

COMPARABLE: Both towers within 2% (1.000 vs 1.000)


## better version, more learnable

In [None]:
"""
INCEPTIVE TOWER EXPERIMENT V2
==============================

Scaling up with real pretrained experts:
- CLIP ViT-base-patch16: Vision expert (provides mail)
- T5-base: Text expert (provides sequence features)
- Inception encoder: Inceptive view (sees differently)

Tower learns to translate inceptive view → expert space via indoctrination.

Features are cached to .pt files for speed.

Author: AbstractPhil + Claude
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from typing import Dict, Optional, Tuple
from transformers import (
    CLIPVisionModel,
    CLIPProcessor,
    T5EncoderModel,
    T5Tokenizer,
)
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import math


# =============================================================================
# INCEPTION ENCODER (The Inceptive View)
# =============================================================================

class InceptionBlock(nn.Module):
    """Multi-branch convolution block."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        branch_ch = out_channels // 4
        remainder = out_channels % 4

        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 3, padding=1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 5, padding=2),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, branch_ch + remainder, 1),
            nn.BatchNorm2d(branch_ch + remainder),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([
            self.branch_1x1(x),
            self.branch_3x3(x),
            self.branch_5x5(x),
            self.branch_pool(x),
        ], dim=1)


class InceptionEncoder(nn.Module):
    """
    Inception-style encoder that outputs CLIP-compatible dimension.

    Processes 224x224 RGB images → [B, 768] features
    """

    def __init__(self, out_features: int = 768):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),  # 224 -> 112
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),  # 112 -> 56
        )

        self.inception1 = InceptionBlock(64, 128)
        self.pool1 = nn.MaxPool2d(2, 2)  # 56 -> 28

        self.inception2 = InceptionBlock(128, 256)
        self.pool2 = nn.MaxPool2d(2, 2)  # 28 -> 14

        self.inception3 = InceptionBlock(256, 512)
        self.pool3 = nn.MaxPool2d(2, 2)  # 14 -> 7

        self.inception4 = InceptionBlock(512, 768)
        self.pool4 = nn.AdaptiveAvgPool2d(1)  # 7 -> 1

        self.fc = nn.Linear(768, out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.pool1(self.inception1(x))
        x = self.pool2(self.inception2(x))
        x = self.pool3(self.inception3(x))
        x = self.pool4(self.inception4(x))
        x = x.flatten(1)
        return self.fc(x)


# =============================================================================
# INCEPTIVE TOWER
# =============================================================================

class InceptiveTower(nn.Module):
    """
    Tower with inceptive view, indoctrinated to expert space.

    - Receives: Inception features (its inceptive view)
    - Optional: Text features from T5 for conditioning
    - Must output: predictions in CLIP expert space
    - Mail: CLIP's actual output (indoctrination signal)
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        use_text: bool = True,
    ):
        super().__init__()

        self.vision_dim = vision_dim
        self.text_dim = text_dim
        self.use_text = use_text

        # Vision input projection
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)

        # Text conditioning (optional)
        if use_text:
            self.text_proj = nn.Linear(text_dim, hidden_dim)
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads=8, batch_first=True
            )

        # Mail integration (indoctrination)
        self.mail_proj = nn.Linear(vision_dim, hidden_dim)
        self.mail_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid(),
        )

        # Processing layers
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            ])
        self.layers = nn.Sequential(*layers)

        # Output projection (to CLIP space)
        self.output_proj = nn.Linear(hidden_dim, vision_dim)

    def forward(
        self,
        inception_features: Tensor,
        text_features: Optional[Tensor] = None,
        mail: Optional[Tensor] = None,
        mail_dropout: float = 0.0,
    ) -> Tensor:
        """
        Args:
            inception_features: [B, vision_dim] from InceptionEncoder
            text_features: [B, seq_len, text_dim] from T5 (optional)
            mail: [B, vision_dim] from CLIP (ground truth)
            mail_dropout: probability of dropping mail
        """
        B = inception_features.shape[0]
        device = inception_features.device

        # Project vision
        h = self.vision_proj(inception_features)  # [B, hidden]

        # Condition on text if available
        if self.use_text and text_features is not None:
            text_h = self.text_proj(text_features)  # [B, seq_len, hidden]
            h_query = h.unsqueeze(1)  # [B, 1, hidden]
            h_attn, _ = self.cross_attn(h_query, text_h, text_h)
            h = h + h_attn.squeeze(1)  # [B, hidden]

        # Integrate mail (indoctrination)
        if mail is not None:
            if mail_dropout > 0 and self.training:
                mask = (torch.rand(B, 1, device=device) > mail_dropout).float()
                mail = mail * mask

            mail_h = self.mail_proj(mail)
            combined = torch.cat([h, mail_h], dim=-1)
            gate = self.mail_gate(combined)
            h = h * (1 - gate) + mail_h * gate

        # Process
        h = self.layers(h)

        # Output in CLIP space
        return self.output_proj(h)


# =============================================================================
# EXPERIMENT
# =============================================================================

class InceptiveTowerExperimentV2:
    """
    Experiment with real pretrained experts:
    - CLIP ViT-base-patch16 (vision expert)
    - T5-base (text expert)
    - Inception encoder (inceptive view)

    Features cached to .pt files for speed.
    """

    def __init__(
        self,
        batch_size: int = 32,
        lr: float = 1e-4,
        device: str = 'cuda',
        cache_dir: str = './cache',
    ):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

        # CIFAR-100 class names
        self.class_names = [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]
        self.num_classes = 100

        # Dims
        self.clip_dim = 768
        self.t5_dim = 768

        # Cache paths (CIFAR-100)
        self.t5_cache_path = self.cache_dir / 't5_class_features_cifar100.pt'
        self.clip_train_cache_path = self.cache_dir / 'clip_train_features_cifar100.pt'
        self.clip_test_cache_path = self.cache_dir / 'clip_test_features_cifar100.pt'
        self.images_train_cache_path = self.cache_dir / 'images_train_cifar100.pt'
        self.images_test_cache_path = self.cache_dir / 'images_test_cifar100.pt'
        self.labels_train_cache_path = self.cache_dir / 'labels_train_cifar100.pt'
        self.labels_test_cache_path = self.cache_dir / 'labels_test_cifar100.pt'

        # Load or create caches
        self._setup_caches()

        # Create trainable modules
        self.inception_enc = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        self.tower = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        # Classifier head
        self.classifier = nn.Linear(self.clip_dim, self.num_classes).to(self.device)

        print(f"  Inception encoder: {sum(p.numel() for p in self.inception_enc.parameters()):,} params")
        print(f"  Inceptive tower: {sum(p.numel() for p in self.tower.parameters()):,} params")

        # Optimizers
        self.opt_inception = torch.optim.AdamW(self.inception_enc.parameters(), lr=lr)
        self.opt_tower = torch.optim.AdamW(self.tower.parameters(), lr=lr)
        self.opt_clf = torch.optim.AdamW(self.classifier.parameters(), lr=lr)

        # Create dataloaders from cached tensors
        train_dataset = TensorDataset(
            self.images_train, self.labels_train, self.clip_train_features
        )
        test_dataset = TensorDataset(
            self.images_test, self.labels_test, self.clip_test_features
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )

    def _setup_caches(self):
        """Load or create all cached features."""

        # Data transforms
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ])

        # Check if all caches exist
        all_cached = all([
            self.t5_cache_path.exists(),
            self.clip_train_cache_path.exists(),
            self.clip_test_cache_path.exists(),
            self.images_train_cache_path.exists(),
            self.images_test_cache_path.exists(),
            self.labels_train_cache_path.exists(),
            self.labels_test_cache_path.exists(),
        ])

        if all_cached:
            print("Loading cached features...")
            self.t5_features = torch.load(self.t5_cache_path)
            self.clip_train_features = torch.load(self.clip_train_cache_path)
            self.clip_test_features = torch.load(self.clip_test_cache_path)
            self.images_train = torch.load(self.images_train_cache_path)
            self.images_test = torch.load(self.images_test_cache_path)
            self.labels_train = torch.load(self.labels_train_cache_path)
            self.labels_test = torch.load(self.labels_test_cache_path)
            print(f"  T5 features: {self.t5_features.shape}")
            print(f"  CLIP train: {self.clip_train_features.shape}")
            print(f"  CLIP test: {self.clip_test_features.shape}")
            return

        print("Creating feature caches (one-time)...")

        # Load raw datasets (CIFAR-100)
        train_data = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
        test_data = datasets.CIFAR100('./data', train=False, download=True, transform=transform)

        # ===== Cache T5 features (10 classes) =====
        print("  Caching T5 features for 10 classes...")
        t5_model = T5EncoderModel.from_pretrained("t5-base").to(self.device)
        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
        t5_model.eval()

        texts = [f"a photo of a {name}" for name in self.class_names]
        tokens = t5_tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True, max_length=32
        ).to(self.device)

        with torch.no_grad():
            t5_out = t5_model(**tokens)
            self.t5_features = t5_out.last_hidden_state.cpu()  # [10, seq_len, 768]

        torch.save(self.t5_features, self.t5_cache_path)
        print(f"    Saved: {self.t5_features.shape}")

        del t5_model, t5_tokenizer
        torch.cuda.empty_cache()

        # ===== Cache CLIP features =====
        print("  Loading CLIP...")
        clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
        clip_model.eval()

        def extract_clip_and_images(dataset, desc):
            loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
            all_clip = []
            all_images = []
            all_labels = []

            with torch.no_grad():
                for images, labels in tqdm(loader, desc=desc):
                    images = images.to(self.device)
                    clip_out = clip_model(pixel_values=images).pooler_output
                    all_clip.append(clip_out.cpu())
                    all_images.append(images.cpu())
                    all_labels.append(labels)

            return (
                torch.cat(all_clip, dim=0),
                torch.cat(all_images, dim=0),
                torch.cat(all_labels, dim=0),
            )

        # Train set
        self.clip_train_features, self.images_train, self.labels_train = extract_clip_and_images(
            train_data, "  Caching CLIP train"
        )
        torch.save(self.clip_train_features, self.clip_train_cache_path)
        torch.save(self.images_train, self.images_train_cache_path)
        torch.save(self.labels_train, self.labels_train_cache_path)
        print(f"    Train: {self.clip_train_features.shape}")

        # Test set
        self.clip_test_features, self.images_test, self.labels_test = extract_clip_and_images(
            test_data, "  Caching CLIP test"
        )
        torch.save(self.clip_test_features, self.clip_test_cache_path)
        torch.save(self.images_test, self.images_test_cache_path)
        torch.save(self.labels_test, self.labels_test_cache_path)
        print(f"    Test: {self.clip_test_features.shape}")

        del clip_model
        torch.cuda.empty_cache()
        print("  Caching complete!")

    def get_t5_features_for_labels(self, labels: Tensor) -> Tensor:
        """Look up cached T5 features by label index."""
        # self.t5_features: [10, seq_len, 768]
        # labels: [B]
        # return: [B, seq_len, 768]
        return self.t5_features[labels.cpu()]

    def train_epoch(
        self,
        mail_dropout: float = 0.0,
        use_text: bool = True,
    ) -> Dict[str, float]:
        """Train one epoch."""
        self.inception_enc.train()
        self.tower.train()
        self.classifier.train()

        total_loss_trans = 0
        total_loss_clf = 0
        correct_tower = 0
        correct_clip = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels, clip_features in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # Get T5 features from cache
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Get inceptive view
            inception_features = self.inception_enc(images)

            # Tower predicts CLIP space from inception view
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=clip_features,
                mail_dropout=mail_dropout,
            )

            # Losses
            loss_trans = F.mse_loss(tower_output, clip_features)

            # Classification from tower output
            logits = self.classifier(tower_output)
            loss_clf = F.cross_entropy(logits, labels)

            loss = loss_trans + loss_clf

            # Optimize
            self.opt_inception.zero_grad()
            self.opt_tower.zero_grad()
            self.opt_clf.zero_grad()
            loss.backward()
            self.opt_inception.step()
            self.opt_tower.step()
            self.opt_clf.step()

            # Track metrics
            total_loss_trans += loss_trans.item() * images.size(0)
            total_loss_clf += loss_clf.item() * images.size(0)
            correct_tower += (logits.argmax(1) == labels).sum().item()

            # CLIP baseline accuracy
            with torch.no_grad():
                clip_logits = self.classifier(clip_features)
                correct_clip += (clip_logits.argmax(1) == labels).sum().item()

            total += images.size(0)

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'acc': f'{correct_tower/total:.3f}',
            })

        return {
            'loss_translation': total_loss_trans / total,
            'loss_clf': total_loss_clf / total,
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
        }

    @torch.no_grad()
    def evaluate(self, use_mail: bool = True, use_text: bool = True) -> Dict[str, float]:
        """Evaluate all paths."""
        self.inception_enc.eval()
        self.tower.eval()
        self.classifier.eval()

        correct_tower = 0
        correct_clip = 0
        correct_inception = 0
        trans_error = 0
        total = 0

        desc = f"Eval (mail={use_mail}, text={use_text})"
        for images, labels, clip_features in tqdm(self.test_loader, desc=desc, leave=False):
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # T5 features
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Inception features
            inception_features = self.inception_enc(images)

            # Tower output
            mail = clip_features if use_mail else None
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=mail,
                mail_dropout=0.0,
            )

            # Classifications
            pred_tower = self.classifier(tower_output).argmax(1)
            pred_clip = self.classifier(clip_features).argmax(1)
            pred_inception = self.classifier(inception_features).argmax(1)

            correct_tower += (pred_tower == labels).sum().item()
            correct_clip += (pred_clip == labels).sum().item()
            correct_inception += (pred_inception == labels).sum().item()
            trans_error += F.mse_loss(tower_output, clip_features, reduction='sum').item()
            total += images.size(0)

        return {
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
            'acc_inception': correct_inception / total,
            'translation_mse': trans_error / total / self.clip_dim,
        }

    def run(
        self,
        epochs: int = 10,
        mail_curriculum: bool = True,
        use_text: bool = True,
    ):
        """Run full experiment."""
        print("\n" + "=" * 60)
        print("INCEPTIVE TOWER EXPERIMENT V2")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Vision expert: CLIP ViT-base-patch16 (768-dim)")
        print(f"Text expert: T5-base (768-dim)")
        print(f"Use text conditioning: {use_text}")
        print()

        # Training with indoctrination
        print("Training with indoctrination...")
        print("-" * 40)

        for epoch in range(epochs):
            if mail_curriculum:
                mail_dropout = min(0.8, epoch / epochs)
            else:
                mail_dropout = 0.0

            metrics = self.train_epoch(mail_dropout=mail_dropout, use_text=use_text)
            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Trans: {metrics['loss_translation']:.4f} | "
                  f"Tower: {metrics['acc_tower']:.3f} | "
                  f"CLIP: {metrics['acc_clip']:.3f} | "
                  f"Mail drop: {mail_dropout:.1%}")

        print()

        # Evaluation
        print("Evaluation")
        print("-" * 40)

        metrics_with_mail = self.evaluate(use_mail=True, use_text=use_text)
        metrics_without_mail = self.evaluate(use_mail=False, use_text=use_text)
        metrics_no_text = self.evaluate(use_mail=False, use_text=False)

        print(f"\nWith mail + text:")
        print(f"  Tower accuracy:    {metrics_with_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_with_mail['translation_mse']:.4f}")

        print(f"\nWithout mail (+ text):")
        print(f"  Tower accuracy:    {metrics_without_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_without_mail['translation_mse']:.4f}")

        print(f"\nWithout mail, without text:")
        print(f"  Tower accuracy:    {metrics_no_text['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_no_text['translation_mse']:.4f}")

        print(f"\nBaselines:")
        print(f"  CLIP direct:       {metrics_with_mail['acc_clip']:.3f}")
        print(f"  Inception direct:  {metrics_with_mail['acc_inception']:.3f}")

        # Success criteria
        print()
        print("=" * 60)
        no_mail_acc = metrics_without_mail['acc_tower']
        clip_acc = metrics_with_mail['acc_clip']

        if no_mail_acc >= clip_acc * 0.95:
            print(f"SUCCESS: Tower matches CLIP without mail ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        elif no_mail_acc >= clip_acc * 0.85:
            print(f"PARTIAL: Tower approaches CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        else:
            print(f"NEEDS WORK: Tower below CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        print("=" * 60)

        return {
            'with_mail': metrics_with_mail,
            'without_mail': metrics_without_mail,
            'no_text': metrics_no_text,
        }

    def run_side_by_side(
        self,
        epochs: int = 15,
        use_text: bool = True,
    ):
        """
        Train two towers side-by-side:
        - Tower A: WITH mail (indoctrinated)
        - Tower B: WITHOUT mail (blind)

        Compare learning curves directly.
        """
        print("\n" + "=" * 60)
        print("SIDE-BY-SIDE TRAINING: Mail vs No Mail")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Use text conditioning: {use_text}")
        print()

        # Create second tower (no mail)
        tower_no_mail = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        inception_no_mail = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        classifier_no_mail = nn.Linear(self.clip_dim, self.num_classes).to(self.device)

        opt_tower_no_mail = torch.optim.AdamW(tower_no_mail.parameters(), lr=1e-4)
        opt_inc_no_mail = torch.optim.AdamW(inception_no_mail.parameters(), lr=1e-4)
        opt_clf_no_mail = torch.optim.AdamW(classifier_no_mail.parameters(), lr=1e-4)

        print(f"Tower A (with mail): {sum(p.numel() for p in self.tower.parameters()):,} params")
        print(f"Tower B (no mail):   {sum(p.numel() for p in tower_no_mail.parameters()):,} params")
        print()

        print("Training...")
        print("-" * 85)
        print(f"{'Epoch':>5} | {'Mail Tower':^20} | {'No Mail Tower':^20} | {'Test Acc (no mail)':^18}")
        print(f"{'':>5} | {'Loss':>8} {'Acc':>8} | {'Loss':>8} {'Acc':>8} | {'Mail':>8} {'NoMail':>8}")
        print("-" * 85)

        history = {
            'mail_acc': [], 'mail_loss': [],
            'no_mail_acc': [], 'no_mail_loss': [],
            'clip_acc': [],
            'test_mail': [], 'test_no_mail': [],
        }

        for epoch in range(epochs):
            # Curriculum for mail tower
            mail_dropout = min(0.8, epoch / epochs)

            # ===== Train both towers =====
            self.inception_enc.train()
            self.tower.train()
            self.classifier.train()
            inception_no_mail.train()
            tower_no_mail.train()
            classifier_no_mail.train()

            mail_loss_sum, mail_correct = 0, 0
            no_mail_loss_sum, no_mail_correct = 0, 0
            clip_correct = 0
            total = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
            for images, labels, clip_features in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)
                clip_features = clip_features.to(self.device)
                t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                # ===== Tower A: With Mail =====
                inc_feat = self.inception_enc(images)
                tower_out = self.tower(
                    inc_feat, text_features=t5_features,
                    mail=clip_features, mail_dropout=mail_dropout
                )
                loss_trans = F.mse_loss(tower_out, clip_features)
                logits = self.classifier(tower_out)
                loss_clf = F.cross_entropy(logits, labels)
                loss_mail = loss_trans + loss_clf

                self.opt_inception.zero_grad()
                self.opt_tower.zero_grad()
                self.opt_clf.zero_grad()
                loss_mail.backward()
                self.opt_inception.step()
                self.opt_tower.step()
                self.opt_clf.step()

                mail_loss_sum += loss_mail.item() * images.size(0)
                mail_correct += (logits.argmax(1) == labels).sum().item()

                # ===== Tower B: No Mail =====
                inc_feat_b = inception_no_mail(images)
                tower_out_b = tower_no_mail(
                    inc_feat_b, text_features=t5_features,
                    mail=None, mail_dropout=0.0  # Never gets mail
                )
                loss_trans_b = F.mse_loss(tower_out_b, clip_features)
                logits_b = classifier_no_mail(tower_out_b)
                loss_clf_b = F.cross_entropy(logits_b, labels)
                loss_no_mail = loss_trans_b + loss_clf_b

                opt_inc_no_mail.zero_grad()
                opt_tower_no_mail.zero_grad()
                opt_clf_no_mail.zero_grad()
                loss_no_mail.backward()
                opt_inc_no_mail.step()
                opt_tower_no_mail.step()
                opt_clf_no_mail.step()

                no_mail_loss_sum += loss_no_mail.item() * images.size(0)
                no_mail_correct += (logits_b.argmax(1) == labels).sum().item()

                # CLIP baseline
                with torch.no_grad():
                    clip_logits = self.classifier(clip_features)
                    clip_correct += (clip_logits.argmax(1) == labels).sum().item()

                total += images.size(0)

                pbar.set_postfix({
                    'mail': f'{mail_correct/total:.3f}',
                    'no_mail': f'{no_mail_correct/total:.3f}',
                })

            # Epoch stats
            mail_acc = mail_correct / total
            mail_loss = mail_loss_sum / total
            no_mail_acc = no_mail_correct / total
            no_mail_loss = no_mail_loss_sum / total
            clip_acc = clip_correct / total

            history['mail_acc'].append(mail_acc)
            history['mail_loss'].append(mail_loss)
            history['no_mail_acc'].append(no_mail_acc)
            history['no_mail_loss'].append(no_mail_loss)
            history['clip_acc'].append(clip_acc)

            # Quick eval on test set (without mail for both)
            def quick_eval(tower, inc_enc, clf):
                tower.eval()
                inc_enc.eval()
                clf.eval()
                correct = 0
                total_eval = 0
                with torch.no_grad():
                    for img, lbl, clip_f in self.test_loader:
                        img, lbl = img.to(self.device), lbl.to(self.device)
                        t5_f = self.get_t5_features_for_labels(lbl).to(self.device) if use_text else None
                        inc_f = inc_enc(img)
                        out = tower(inc_f, text_features=t5_f, mail=None)
                        pred = clf(out).argmax(1)
                        correct += (pred == lbl).sum().item()
                        total_eval += img.size(0)
                tower.train()
                inc_enc.train()
                clf.train()
                return correct / total_eval

            test_acc_mail = quick_eval(self.tower, self.inception_enc, self.classifier)
            test_acc_no_mail = quick_eval(tower_no_mail, inception_no_mail, classifier_no_mail)

            history['test_mail'].append(test_acc_mail)
            history['test_no_mail'].append(test_acc_no_mail)

            print(f"{epoch+1:>5} | {mail_loss:>8.4f} {mail_acc:>8.3f} | {no_mail_loss:>8.4f} {no_mail_acc:>8.3f} | {test_acc_mail:>8.3f} {test_acc_no_mail:>8.3f}")

        print("-" * 85)
        print()

        # ===== Final Evaluation =====
        print("Final Evaluation (test set)")
        print("-" * 40)

        def eval_tower(tower, inc_enc, clf, use_mail_eval):
            tower.eval()
            inc_enc.eval()
            clf.eval()
            correct = 0
            total = 0

            with torch.no_grad():
                for images, labels, clip_features in self.test_loader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    clip_features = clip_features.to(self.device)
                    t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                    inc_feat = inc_enc(images)
                    mail = clip_features if use_mail_eval else None
                    tower_out = tower(inc_feat, text_features=t5_features, mail=mail)
                    pred = clf(tower_out).argmax(1)
                    correct += (pred == labels).sum().item()
                    total += images.size(0)

            return correct / total

        # Mail tower - with and without mail at eval
        mail_tower_with = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=True)
        mail_tower_without = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=False)

        # No-mail tower
        no_mail_tower = eval_tower(tower_no_mail, inception_no_mail, classifier_no_mail, use_mail_eval=False)

        print(f"\nMail Tower (trained with indoctrination):")
        print(f"  With mail at eval:    {mail_tower_with:.3f}")
        print(f"  Without mail at eval: {mail_tower_without:.3f}")

        print(f"\nNo-Mail Tower (trained blind):")
        print(f"  Accuracy:             {no_mail_tower:.3f}")

        print(f"\nBaseline:")
        print(f"  CLIP direct:          {history['clip_acc'][-1]:.3f}")

        print()
        print("=" * 60)
        diff = mail_tower_without - no_mail_tower
        if diff > 0.02:
            print(f"INDOCTRINATION WINS: +{diff:.1%} accuracy from mail training")
        elif diff > -0.02:
            print(f"COMPARABLE: Both towers within 2% ({mail_tower_without:.3f} vs {no_mail_tower:.3f})")
        else:
            print(f"NO-MAIL WINS: Blind tower better by {-diff:.1%}")
        print("=" * 60)

        return history


# =============================================================================
# MAIN
# =============================================================================

if __name__ == '__main__':
    experiment = InceptiveTowerExperimentV2(
        batch_size=64,
        lr=1e-4,
        device='cuda',
        cache_dir='./cache',
    )

    # Side-by-side comparison
    history = experiment.run_side_by_side(
        epochs=15,
        use_text=True,
    )

Loading cached features...
  T5 features: torch.Size([100, 12, 768])
  CLIP train: torch.Size([50000, 768])
  CLIP test: torch.Size([10000, 768])
  Inception encoder: 3,157,440 params
  Inceptive tower: 12,601,088 params

SIDE-BY-SIDE TRAINING: Mail vs No Mail
Device: cuda
Use text conditioning: True

Tower A (with mail): 12,601,088 params
Tower B (no mail):   12,601,088 params

Training...
-------------------------------------------------------------------------------------
Epoch |      Mail Tower      |    No Mail Tower     | Test Acc (no mail)
      |     Loss      Acc |     Loss      Acc |     Mail   NoMail
-------------------------------------------------------------------------------------




    1 |   1.0201    0.922 |   1.0827    0.904 |    1.000    1.000




    2 |   0.3474    1.000 |   0.3888    1.000 |    1.000    1.000




    3 |   0.2649    1.000 |   0.3193    1.000 |    1.000    1.000




    4 |   0.2215    1.000 |   0.2798    1.000 |    1.000    1.000




    5 |   0.1976    1.000 |   0.2543    1.000 |    1.000    1.000




    6 |   0.1835    1.000 |   0.2363    1.000 |    1.000    1.000




    7 |   0.1774    1.000 |   0.2228    1.000 |    1.000    1.000




    8 |   0.1732    1.000 |   0.2127    1.000 |    1.000    1.000




    9 |   0.1721    1.000 |   0.2044    1.000 |    1.000    1.000




   10 |   0.1728    1.000 |   0.1980    1.000 |    1.000    1.000




   11 |   0.1750    1.000 |   0.1921    1.000 |    1.000    1.000




   12 |   0.1756    1.000 |   0.1873    1.000 |    1.000    1.000




   13 |   0.1868    1.000 |   0.1822    1.000 |    1.000    1.000




   14 |   0.1773    1.000 |   0.1813    1.000 |    1.000    1.000




   15 |   0.1802    1.000 |   0.1751    1.000 |    1.000    1.000
-------------------------------------------------------------------------------------

Final Evaluation (test set)
----------------------------------------

Mail Tower (trained with indoctrination):
  With mail at eval:    0.999
  Without mail at eval: 1.000

No-Mail Tower (trained blind):
  Accuracy:             1.000

Baseline:
  CLIP direct:          0.762

COMPARABLE: Both towers within 2% (1.000 vs 1.000)


# non-corrupted

In [None]:
"""
INCEPTIVE TOWER EXPERIMENT V2
==============================

Scaling up with real pretrained experts:
- CLIP ViT-base-patch16: Vision expert (provides mail)
- T5-base: Text expert (provides sequence features)
- Inception encoder: Inceptive view (sees differently)

Tower learns to translate inceptive view → expert space via indoctrination.

Features are cached to .pt files for speed.

Author: AbstractPhil + Claude
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from typing import Dict, Optional, Tuple
from transformers import (
    CLIPVisionModel,
    CLIPProcessor,
    T5EncoderModel,
    T5Tokenizer,
)
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import math


# =============================================================================
# INCEPTION ENCODER (The Inceptive View)
# =============================================================================

class InceptionBlock(nn.Module):
    """Multi-branch convolution block."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        branch_ch = out_channels // 4
        remainder = out_channels % 4

        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 3, padding=1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(in_channels, branch_ch, 1),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_ch, branch_ch, 5, padding=2),
            nn.BatchNorm2d(branch_ch),
            nn.ReLU(inplace=True),
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, branch_ch + remainder, 1),
            nn.BatchNorm2d(branch_ch + remainder),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([
            self.branch_1x1(x),
            self.branch_3x3(x),
            self.branch_5x5(x),
            self.branch_pool(x),
        ], dim=1)


class InceptionEncoder(nn.Module):
    """
    Inception-style encoder that outputs CLIP-compatible dimension.

    Processes 224x224 RGB images → [B, 768] features
    """

    def __init__(self, out_features: int = 768):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),  # 224 -> 112
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),  # 112 -> 56
        )

        self.inception1 = InceptionBlock(64, 128)
        self.pool1 = nn.MaxPool2d(2, 2)  # 56 -> 28

        self.inception2 = InceptionBlock(128, 256)
        self.pool2 = nn.MaxPool2d(2, 2)  # 28 -> 14

        self.inception3 = InceptionBlock(256, 512)
        self.pool3 = nn.MaxPool2d(2, 2)  # 14 -> 7

        self.inception4 = InceptionBlock(512, 768)
        self.pool4 = nn.AdaptiveAvgPool2d(1)  # 7 -> 1

        self.fc = nn.Linear(768, out_features)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.pool1(self.inception1(x))
        x = self.pool2(self.inception2(x))
        x = self.pool3(self.inception3(x))
        x = self.pool4(self.inception4(x))
        x = x.flatten(1)
        return self.fc(x)


# =============================================================================
# INCEPTIVE TOWER
# =============================================================================

class InceptiveTower(nn.Module):
    """
    Tower with inceptive view, indoctrinated to expert space.

    - Receives: Inception features (its inceptive view)
    - Optional: Text features from T5 for conditioning
    - Must output: predictions in CLIP expert space
    - Mail: CLIP's actual output (indoctrination signal)
    """

    def __init__(
        self,
        vision_dim: int = 768,
        text_dim: int = 768,
        hidden_dim: int = 1024,
        num_layers: int = 3,
        use_text: bool = True,
    ):
        super().__init__()

        self.vision_dim = vision_dim
        self.text_dim = text_dim
        self.use_text = use_text

        # Vision input projection
        self.vision_proj = nn.Linear(vision_dim, hidden_dim)

        # Text conditioning (optional)
        if use_text:
            self.text_proj = nn.Linear(text_dim, hidden_dim)
            self.cross_attn = nn.MultiheadAttention(
                hidden_dim, num_heads=8, batch_first=True
            )

        # Mail integration (indoctrination)
        self.mail_proj = nn.Linear(vision_dim, hidden_dim)
        self.mail_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid(),
        )

        # Processing layers
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            ])
        self.layers = nn.Sequential(*layers)

        # Output projection (to CLIP space)
        self.output_proj = nn.Linear(hidden_dim, vision_dim)

    def forward(
        self,
        inception_features: Tensor,
        text_features: Optional[Tensor] = None,
        mail: Optional[Tensor] = None,
        mail_dropout: float = 0.0,
    ) -> Tensor:
        """
        Args:
            inception_features: [B, vision_dim] from InceptionEncoder
            text_features: [B, seq_len, text_dim] from T5 (optional)
            mail: [B, vision_dim] from CLIP (ground truth)
            mail_dropout: probability of dropping mail
        """
        B = inception_features.shape[0]
        device = inception_features.device

        # Project vision
        h = self.vision_proj(inception_features)  # [B, hidden]

        # Condition on text if available
        if self.use_text and text_features is not None:
            text_h = self.text_proj(text_features)  # [B, seq_len, hidden]
            h_query = h.unsqueeze(1)  # [B, 1, hidden]
            h_attn, _ = self.cross_attn(h_query, text_h, text_h)
            h = h + h_attn.squeeze(1)  # [B, hidden]

        # Integrate mail (indoctrination)
        if mail is not None:
            if mail_dropout > 0 and self.training:
                mask = (torch.rand(B, 1, device=device) > mail_dropout).float()
                mail = mail * mask

            mail_h = self.mail_proj(mail)
            combined = torch.cat([h, mail_h], dim=-1)
            gate = self.mail_gate(combined)
            h = h * (1 - gate) + mail_h * gate

        # Process
        h = self.layers(h)

        # Output in CLIP space
        return self.output_proj(h)


# =============================================================================
# EXPERIMENT
# =============================================================================

class InceptiveTowerExperimentV2:
    """
    Experiment with real pretrained experts:
    - CLIP ViT-base-patch16 (vision expert)
    - T5-base (text expert)
    - Inception encoder (inceptive view)

    Features cached to .pt files for speed.
    """

    def __init__(
        self,
        batch_size: int = 32,
        lr: float = 1e-4,
        device: str = 'cuda',
        cache_dir: str = './cache',
    ):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

        # CIFAR-100 class names
        self.class_names = [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]
        self.num_classes = 100

        # Dims
        self.clip_dim = 768
        self.t5_dim = 768

        # Cache paths (CIFAR-100)
        self.t5_cache_path = self.cache_dir / 't5_class_features_cifar100.pt'
        self.clip_train_cache_path = self.cache_dir / 'clip_train_features_cifar100.pt'
        self.clip_test_cache_path = self.cache_dir / 'clip_test_features_cifar100.pt'
        self.images_train_cache_path = self.cache_dir / 'images_train_cifar100.pt'
        self.images_test_cache_path = self.cache_dir / 'images_test_cifar100.pt'
        self.labels_train_cache_path = self.cache_dir / 'labels_train_cifar100.pt'
        self.labels_test_cache_path = self.cache_dir / 'labels_test_cifar100.pt'

        # Load or create caches
        self._setup_caches()

        # Create trainable modules
        self.inception_enc = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        self.tower = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        # Classifier head
        self.classifier = nn.Linear(self.clip_dim, self.num_classes).to(self.device)

        print(f"  Inception encoder: {sum(p.numel() for p in self.inception_enc.parameters()):,} params")
        print(f"  Inceptive tower: {sum(p.numel() for p in self.tower.parameters()):,} params")

        # Optimizers
        self.opt_inception = torch.optim.AdamW(self.inception_enc.parameters(), lr=lr)
        self.opt_tower = torch.optim.AdamW(self.tower.parameters(), lr=lr)
        self.opt_clf = torch.optim.AdamW(self.classifier.parameters(), lr=lr)

        # Create dataloaders from cached tensors
        train_dataset = TensorDataset(
            self.images_train, self.labels_train, self.clip_train_features
        )
        test_dataset = TensorDataset(
            self.images_test, self.labels_test, self.clip_test_features
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
        )

    def _setup_caches(self):
        """Load or create all cached features."""

        # Data transforms
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ])

        # Check if all caches exist
        all_cached = all([
            self.t5_cache_path.exists(),
            self.clip_train_cache_path.exists(),
            self.clip_test_cache_path.exists(),
            self.images_train_cache_path.exists(),
            self.images_test_cache_path.exists(),
            self.labels_train_cache_path.exists(),
            self.labels_test_cache_path.exists(),
        ])

        if all_cached:
            print("Loading cached features...")
            self.t5_features = torch.load(self.t5_cache_path)
            self.clip_train_features = torch.load(self.clip_train_cache_path)
            self.clip_test_features = torch.load(self.clip_test_cache_path)
            self.images_train = torch.load(self.images_train_cache_path)
            self.images_test = torch.load(self.images_test_cache_path)
            self.labels_train = torch.load(self.labels_train_cache_path)
            self.labels_test = torch.load(self.labels_test_cache_path)
            print(f"  T5 features: {self.t5_features.shape}")
            print(f"  CLIP train: {self.clip_train_features.shape}")
            print(f"  CLIP test: {self.clip_test_features.shape}")
            return

        print("Creating feature caches (one-time)...")

        # Load raw datasets (CIFAR-100)
        train_data = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
        test_data = datasets.CIFAR100('./data', train=False, download=True, transform=transform)

        # ===== Cache T5 features (10 classes) =====
        print("  Caching T5 features for 10 classes...")
        t5_model = T5EncoderModel.from_pretrained("t5-base").to(self.device)
        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
        t5_model.eval()

        texts = [f"a photo of a {name}" for name in self.class_names]
        tokens = t5_tokenizer(
            texts, return_tensors='pt', padding=True, truncation=True, max_length=32
        ).to(self.device)

        with torch.no_grad():
            t5_out = t5_model(**tokens)
            self.t5_features = t5_out.last_hidden_state.cpu()  # [10, seq_len, 768]

        torch.save(self.t5_features, self.t5_cache_path)
        print(f"    Saved: {self.t5_features.shape}")

        del t5_model, t5_tokenizer
        torch.cuda.empty_cache()

        # ===== Cache CLIP features =====
        print("  Loading CLIP...")
        clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16").to(self.device)
        clip_model.eval()

        def extract_clip_and_images(dataset, desc):
            loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
            all_clip = []
            all_images = []
            all_labels = []

            with torch.no_grad():
                for images, labels in tqdm(loader, desc=desc):
                    images = images.to(self.device)
                    clip_out = clip_model(pixel_values=images).pooler_output
                    all_clip.append(clip_out.cpu())
                    all_images.append(images.cpu())
                    all_labels.append(labels)

            return (
                torch.cat(all_clip, dim=0),
                torch.cat(all_images, dim=0),
                torch.cat(all_labels, dim=0),
            )

        # Train set
        self.clip_train_features, self.images_train, self.labels_train = extract_clip_and_images(
            train_data, "  Caching CLIP train"
        )
        torch.save(self.clip_train_features, self.clip_train_cache_path)
        torch.save(self.images_train, self.images_train_cache_path)
        torch.save(self.labels_train, self.labels_train_cache_path)
        print(f"    Train: {self.clip_train_features.shape}")

        # Test set
        self.clip_test_features, self.images_test, self.labels_test = extract_clip_and_images(
            test_data, "  Caching CLIP test"
        )
        torch.save(self.clip_test_features, self.clip_test_cache_path)
        torch.save(self.images_test, self.images_test_cache_path)
        torch.save(self.labels_test, self.labels_test_cache_path)
        print(f"    Test: {self.clip_test_features.shape}")

        del clip_model
        torch.cuda.empty_cache()
        print("  Caching complete!")

    def get_t5_features_for_labels(self, labels: Tensor) -> Tensor:
        """Look up cached T5 features by label index."""
        # self.t5_features: [10, seq_len, 768]
        # labels: [B]
        # return: [B, seq_len, 768]
        return self.t5_features[labels.cpu()]

    def train_epoch(
        self,
        mail_dropout: float = 0.0,
        use_text: bool = True,
    ) -> Dict[str, float]:
        """Train one epoch."""
        self.inception_enc.train()
        self.tower.train()
        self.classifier.train()

        total_loss_trans = 0
        total_loss_clf = 0
        correct_tower = 0
        correct_clip = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels, clip_features in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # Get T5 features from cache
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Get inceptive view
            inception_features = self.inception_enc(images)

            # Tower predicts CLIP space from inception view
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=clip_features,
                mail_dropout=mail_dropout,
            )

            # Losses
            loss_trans = F.mse_loss(tower_output, clip_features)

            # Classification from tower output
            logits = self.classifier(tower_output)
            loss_clf = F.cross_entropy(logits, labels)

            loss = loss_trans + loss_clf

            # Optimize
            self.opt_inception.zero_grad()
            self.opt_tower.zero_grad()
            self.opt_clf.zero_grad()
            loss.backward()
            self.opt_inception.step()
            self.opt_tower.step()
            self.opt_clf.step()

            # Track metrics
            total_loss_trans += loss_trans.item() * images.size(0)
            total_loss_clf += loss_clf.item() * images.size(0)
            correct_tower += (logits.argmax(1) == labels).sum().item()

            # CLIP baseline accuracy
            with torch.no_grad():
                clip_logits = self.classifier(clip_features)
                correct_clip += (clip_logits.argmax(1) == labels).sum().item()

            total += images.size(0)

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.3f}',
                'acc': f'{correct_tower/total:.3f}',
            })

        return {
            'loss_translation': total_loss_trans / total,
            'loss_clf': total_loss_clf / total,
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
        }

    @torch.no_grad()
    def evaluate(self, use_mail: bool = True, use_text: bool = True) -> Dict[str, float]:
        """Evaluate all paths."""
        self.inception_enc.eval()
        self.tower.eval()
        self.classifier.eval()

        correct_tower = 0
        correct_clip = 0
        correct_inception = 0
        trans_error = 0
        total = 0

        desc = f"Eval (mail={use_mail}, text={use_text})"
        for images, labels, clip_features in tqdm(self.test_loader, desc=desc, leave=False):
            images = images.to(self.device)
            labels = labels.to(self.device)
            clip_features = clip_features.to(self.device)

            # T5 features
            t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

            # Inception features
            inception_features = self.inception_enc(images)

            # Tower output
            mail = clip_features if use_mail else None
            tower_output = self.tower(
                inception_features,
                text_features=t5_features,
                mail=mail,
                mail_dropout=0.0,
            )

            # Classifications
            pred_tower = self.classifier(tower_output).argmax(1)
            pred_clip = self.classifier(clip_features).argmax(1)
            pred_inception = self.classifier(inception_features).argmax(1)

            correct_tower += (pred_tower == labels).sum().item()
            correct_clip += (pred_clip == labels).sum().item()
            correct_inception += (pred_inception == labels).sum().item()
            trans_error += F.mse_loss(tower_output, clip_features, reduction='sum').item()
            total += images.size(0)

        return {
            'acc_tower': correct_tower / total,
            'acc_clip': correct_clip / total,
            'acc_inception': correct_inception / total,
            'translation_mse': trans_error / total / self.clip_dim,
        }

    def run(
        self,
        epochs: int = 10,
        mail_curriculum: bool = True,
        use_text: bool = True,
    ):
        """Run full experiment."""
        print("\n" + "=" * 60)
        print("INCEPTIVE TOWER EXPERIMENT V2")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Vision expert: CLIP ViT-base-patch16 (768-dim)")
        print(f"Text expert: T5-base (768-dim)")
        print(f"Use text conditioning: {use_text}")
        print()

        # Training with indoctrination
        print("Training with indoctrination...")
        print("-" * 40)

        for epoch in range(epochs):
            if mail_curriculum:
                mail_dropout = min(0.8, epoch / epochs)
            else:
                mail_dropout = 0.0

            metrics = self.train_epoch(mail_dropout=mail_dropout, use_text=use_text)
            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Trans: {metrics['loss_translation']:.4f} | "
                  f"Tower: {metrics['acc_tower']:.3f} | "
                  f"CLIP: {metrics['acc_clip']:.3f} | "
                  f"Mail drop: {mail_dropout:.1%}")

        print()

        # Evaluation
        print("Evaluation")
        print("-" * 40)

        metrics_with_mail = self.evaluate(use_mail=True, use_text=use_text)
        metrics_without_mail = self.evaluate(use_mail=False, use_text=use_text)
        metrics_no_text = self.evaluate(use_mail=False, use_text=False)

        print(f"\nWith mail + text:")
        print(f"  Tower accuracy:    {metrics_with_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_with_mail['translation_mse']:.4f}")

        print(f"\nWithout mail (+ text):")
        print(f"  Tower accuracy:    {metrics_without_mail['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_without_mail['translation_mse']:.4f}")

        print(f"\nWithout mail, without text:")
        print(f"  Tower accuracy:    {metrics_no_text['acc_tower']:.3f}")
        print(f"  Translation MSE:   {metrics_no_text['translation_mse']:.4f}")

        print(f"\nBaselines:")
        print(f"  CLIP direct:       {metrics_with_mail['acc_clip']:.3f}")
        print(f"  Inception direct:  {metrics_with_mail['acc_inception']:.3f}")

        # Success criteria
        print()
        print("=" * 60)
        no_mail_acc = metrics_without_mail['acc_tower']
        clip_acc = metrics_with_mail['acc_clip']

        if no_mail_acc >= clip_acc * 0.95:
            print(f"SUCCESS: Tower matches CLIP without mail ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        elif no_mail_acc >= clip_acc * 0.85:
            print(f"PARTIAL: Tower approaches CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        else:
            print(f"NEEDS WORK: Tower below CLIP ({no_mail_acc:.3f} vs {clip_acc:.3f})")
        print("=" * 60)

        return {
            'with_mail': metrics_with_mail,
            'without_mail': metrics_without_mail,
            'no_text': metrics_no_text,
        }

    def run_side_by_side(
        self,
        epochs: int = 15,
        use_text: bool = True,
    ):
        """
        Train two towers side-by-side:
        - Tower A: WITH mail (indoctrinated)
        - Tower B: WITHOUT mail (blind)

        Compare learning curves directly.
        """
        print("\n" + "=" * 60)
        print("SIDE-BY-SIDE TRAINING: Mail vs No Mail")
        print("=" * 60)
        print(f"Device: {self.device}")
        print(f"Use text conditioning: {use_text}")
        print()

        # Create second tower (no mail)
        tower_no_mail = InceptiveTower(
            vision_dim=self.clip_dim,
            text_dim=self.t5_dim,
            hidden_dim=1024,
            num_layers=3,
            use_text=True,
        ).to(self.device)

        inception_no_mail = InceptionEncoder(out_features=self.clip_dim).to(self.device)
        classifier_no_mail = nn.Linear(self.clip_dim, self.num_classes).to(self.device)

        opt_tower_no_mail = torch.optim.AdamW(tower_no_mail.parameters(), lr=1e-4)
        opt_inc_no_mail = torch.optim.AdamW(inception_no_mail.parameters(), lr=1e-4)
        opt_clf_no_mail = torch.optim.AdamW(classifier_no_mail.parameters(), lr=1e-4)

        print(f"Tower A (with mail): {sum(p.numel() for p in self.tower.parameters()):,} params")
        print(f"Tower B (no mail):   {sum(p.numel() for p in tower_no_mail.parameters()):,} params")
        print()

        print("Training...")
        print("-" * 85)
        print(f"{'Epoch':>5} | {'Mail Tower':^20} | {'No Mail Tower':^20} | {'Test Acc (no mail)':^18}")
        print(f"{'':>5} | {'Loss':>8} {'Acc':>8} | {'Loss':>8} {'Acc':>8} | {'Mail':>8} {'NoMail':>8}")
        print("-" * 85)

        history = {
            'mail_acc': [], 'mail_loss': [],
            'no_mail_acc': [], 'no_mail_loss': [],
            'clip_acc': [],
            'test_mail': [], 'test_no_mail': [],
        }

        for epoch in range(epochs):
            # Curriculum for mail tower
            mail_dropout = min(0.8, epoch / epochs)

            # ===== Train both towers =====
            self.inception_enc.train()
            self.tower.train()
            self.classifier.train()
            inception_no_mail.train()
            tower_no_mail.train()
            classifier_no_mail.train()

            mail_loss_sum, mail_correct = 0, 0
            no_mail_loss_sum, no_mail_correct = 0, 0
            clip_correct = 0
            total = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
            for images, labels, clip_features in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)
                clip_features = clip_features.to(self.device)
                t5_features = self.get_t5_features_for_labels(labels).to(self.device) if use_text else None

                # ===== Tower A: With Mail =====
                inc_feat = self.inception_enc(images)
                tower_out = self.tower(
                    inc_feat, text_features=t5_features,
                    mail=clip_features, mail_dropout=mail_dropout
                )
                loss_trans = F.mse_loss(tower_out, clip_features)
                logits = self.classifier(tower_out)
                loss_clf = F.cross_entropy(logits, labels)
                loss_mail = loss_trans + loss_clf

                self.opt_inception.zero_grad()
                self.opt_tower.zero_grad()
                self.opt_clf.zero_grad()
                loss_mail.backward()
                self.opt_inception.step()
                self.opt_tower.step()
                self.opt_clf.step()

                mail_loss_sum += loss_mail.item() * images.size(0)
                mail_correct += (logits.argmax(1) == labels).sum().item()

                # ===== Tower B: No Mail =====
                inc_feat_b = inception_no_mail(images)
                tower_out_b = tower_no_mail(
                    inc_feat_b, text_features=t5_features,
                    mail=None, mail_dropout=0.0  # Never gets mail
                )
                loss_trans_b = F.mse_loss(tower_out_b, clip_features)
                logits_b = classifier_no_mail(tower_out_b)
                loss_clf_b = F.cross_entropy(logits_b, labels)
                loss_no_mail = loss_trans_b + loss_clf_b

                opt_inc_no_mail.zero_grad()
                opt_tower_no_mail.zero_grad()
                opt_clf_no_mail.zero_grad()
                loss_no_mail.backward()
                opt_inc_no_mail.step()
                opt_tower_no_mail.step()
                opt_clf_no_mail.step()

                no_mail_loss_sum += loss_no_mail.item() * images.size(0)
                no_mail_correct += (logits_b.argmax(1) == labels).sum().item()

                # CLIP baseline
                with torch.no_grad():
                    clip_logits = self.classifier(clip_features)
                    clip_correct += (clip_logits.argmax(1) == labels).sum().item()

                total += images.size(0)

                pbar.set_postfix({
                    'mail': f'{mail_correct/total:.3f}',
                    'no_mail': f'{no_mail_correct/total:.3f}',
                })

            # Epoch stats
            mail_acc = mail_correct / total
            mail_loss = mail_loss_sum / total
            no_mail_acc = no_mail_correct / total
            no_mail_loss = no_mail_loss_sum / total
            clip_acc = clip_correct / total

            history['mail_acc'].append(mail_acc)
            history['mail_loss'].append(mail_loss)
            history['no_mail_acc'].append(no_mail_acc)
            history['no_mail_loss'].append(no_mail_loss)
            history['clip_acc'].append(clip_acc)

            # Quick eval on test set (without mail, without T5 to avoid label leak)
            def quick_eval(tower, inc_enc, clf):
                tower.eval()
                inc_enc.eval()
                clf.eval()
                correct = 0
                total_eval = 0
                with torch.no_grad():
                    for img, lbl, clip_f in self.test_loader:
                        img, lbl = img.to(self.device), lbl.to(self.device)
                        inc_f = inc_enc(img)
                        # No T5 at eval - would leak labels
                        out = tower(inc_f, text_features=None, mail=None)
                        pred = clf(out).argmax(1)
                        correct += (pred == lbl).sum().item()
                        total_eval += img.size(0)
                tower.train()
                inc_enc.train()
                clf.train()
                return correct / total_eval

            test_acc_mail = quick_eval(self.tower, self.inception_enc, self.classifier)
            test_acc_no_mail = quick_eval(tower_no_mail, inception_no_mail, classifier_no_mail)

            history['test_mail'].append(test_acc_mail)
            history['test_no_mail'].append(test_acc_no_mail)

            print(f"{epoch+1:>5} | {mail_loss:>8.4f} {mail_acc:>8.3f} | {no_mail_loss:>8.4f} {no_mail_acc:>8.3f} | {test_acc_mail:>8.3f} {test_acc_no_mail:>8.3f}")

        print("-" * 85)
        print()

        # ===== Final Evaluation =====
        print("Final Evaluation (test set)")
        print("-" * 40)

        def eval_tower(tower, inc_enc, clf, use_mail_eval):
            tower.eval()
            inc_enc.eval()
            clf.eval()
            correct = 0
            total = 0

            with torch.no_grad():
                for images, labels, clip_features in self.test_loader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)
                    clip_features = clip_features.to(self.device)

                    inc_feat = inc_enc(images)
                    mail = clip_features if use_mail_eval else None
                    # No T5 at eval - would leak labels
                    tower_out = tower(inc_feat, text_features=None, mail=mail)
                    pred = clf(tower_out).argmax(1)
                    correct += (pred == labels).sum().item()
                    total += images.size(0)

            return correct / total

        # Mail tower - with and without mail at eval
        mail_tower_with = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=True)
        mail_tower_without = eval_tower(self.tower, self.inception_enc, self.classifier, use_mail_eval=False)

        # No-mail tower
        no_mail_tower = eval_tower(tower_no_mail, inception_no_mail, classifier_no_mail, use_mail_eval=False)

        print(f"\nMail Tower (trained with indoctrination):")
        print(f"  With mail at eval:    {mail_tower_with:.3f}")
        print(f"  Without mail at eval: {mail_tower_without:.3f}")

        print(f"\nNo-Mail Tower (trained blind):")
        print(f"  Accuracy:             {no_mail_tower:.3f}")

        print(f"\nBaseline:")
        print(f"  CLIP direct:          {history['clip_acc'][-1]:.3f}")

        print()
        print("=" * 60)
        diff = mail_tower_without - no_mail_tower
        if diff > 0.02:
            print(f"INDOCTRINATION WINS: +{diff:.1%} accuracy from mail training")
        elif diff > -0.02:
            print(f"COMPARABLE: Both towers within 2% ({mail_tower_without:.3f} vs {no_mail_tower:.3f})")
        else:
            print(f"NO-MAIL WINS: Blind tower better by {-diff:.1%}")
        print("=" * 60)

        return history


# =============================================================================
# MAIN
# =============================================================================

if __name__ == '__main__':
    experiment = InceptiveTowerExperimentV2(
        batch_size=64,
        lr=1e-4,
        device='cuda',
        cache_dir='./cache',
    )

    # Side-by-side comparison
    history = experiment.run_side_by_side(
        epochs=15,
        use_text=True,
    )

Loading cached features...
  T5 features: torch.Size([100, 12, 768])
  CLIP train: torch.Size([50000, 768])
  CLIP test: torch.Size([10000, 768])
  Inception encoder: 3,157,440 params
  Inceptive tower: 12,601,088 params

SIDE-BY-SIDE TRAINING: Mail vs No Mail
Device: cuda
Use text conditioning: True

Tower A (with mail): 12,601,088 params
Tower B (no mail):   12,601,088 params

Training...
-------------------------------------------------------------------------------------
Epoch |      Mail Tower      |    No Mail Tower     | Test Acc (no mail)
      |     Loss      Acc |     Loss      Acc |     Mail   NoMail
-------------------------------------------------------------------------------------




    1 |   1.0114    0.924 |   1.0900    0.903 |    0.015    0.061




    2 |   0.3441    1.000 |   0.3893    1.000 |    0.045    0.082




    3 |   0.2627    1.000 |   0.3196    1.000 |    0.061    0.119




    4 |   0.2200    1.000 |   0.2804    1.000 |    0.084    0.121




    5 |   0.1967    1.000 |   0.2551    1.000 |    0.093    0.155




    6 |   0.1824    1.000 |   0.2371    1.000 |    0.114    0.170




    7 |   0.1814    1.000 |   0.2235    1.000 |    0.108    0.179




    8 |   0.1731    1.000 |   0.2129    1.000 |    0.134    0.174




    9 |   0.1745    1.000 |   0.2048    1.000 |    0.159    0.236




   10 |   0.1742    1.000 |   0.1982    1.000 |    0.171    0.212




   11 |   0.1798    1.000 |   0.1922    1.000 |    0.181    0.224




   12 |   0.1792    1.000 |   0.1870    1.000 |    0.201    0.233




   13 |   0.1799    1.000 |   0.1830    1.000 |    0.209    0.221




   14 |   0.1831    1.000 |   0.1789    1.000 |    0.195    0.252




   15 |   0.1797    1.000 |   0.1754    1.000 |    0.204    0.259
-------------------------------------------------------------------------------------

Final Evaluation (test set)
----------------------------------------

Mail Tower (trained with indoctrination):
  With mail at eval:    0.708
  Without mail at eval: 0.204

No-Mail Tower (trained blind):
  Accuracy:             0.259

Baseline:
  CLIP direct:          0.759

NO-MAIL WINS: Blind tower better by 5.5%


# GPT mnist classifier

In [None]:
import torch
import torch.nn as nn
from torch import Tensor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import AdaptiveFusion

# ----------------------------
# Components
# ----------------------------

class FFNBlock(TorchComponent):
    def __init__(self, name: str, dim: int):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


# ----------------------------
# Tower (expert)
# ----------------------------

class MNISTTower(BaseTower):
    def __init__(self, name: str, dim: int, depth: int = 2):
        super().__init__(name, strict=False)

        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))

        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x: Tensor) -> Tensor:
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)


# ----------------------------
# Collective classifier
# ----------------------------

class MNISTCollective(WideRouter):
    def __init__(self, name: str, dim: int, num_towers: int = 8):
        super().__init__(name, auto_discover=True)

        # Input projection
        self.attach("in_proj", nn.Linear(28 * 28, dim))

        # Towers
        for i in range(num_towers):
            self.attach(f"tower_{i}", MNISTTower(f"tower_{i}", dim))

        self.discover_towers()

        # Fusion + classifier head
        self.attach("fusion", AdaptiveFusion("fusion", num_towers, dim))
        self.attach("head", nn.Linear(dim, 10))

    def forward(self, x: Tensor) -> Tensor:
        x = x.view(x.size(0), -1)          # [B, 784]
        x = self["in_proj"](x)             # [B, D]

        opinions = self.wide_forward(x)    # Dict[name, [B, D]]
        fused = self["fusion"](*opinions.values())

        return self["head"](fused)


# ----------------------------
# Training script
# ----------------------------

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")

    # Data
    transform = transforms.ToTensor()
    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_ds  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)
    test_loader  = DataLoader(test_ds, batch_size=256, shuffle=False)

    # Model
    model = MNISTCollective("mnist", dim=256, num_towers=8)
    model.network_to(device)
    model = model.prepare_and_compile()

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Train
    for epoch in range(5):
        model.train()
        total_loss = 0.0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            opt.step()

            total_loss += loss.item()

        # Eval
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        acc = 100.0 * correct / total
        print(f"Epoch {epoch+1} | loss {total_loss:.3f} | test acc {acc:.2f}%")

        model.reset()  # paranoid-safe; no-op if no cache


if __name__ == "__main__":
    main()


  _C._set_float32_matmul_precision(precision)
100%|██████████| 9.91M/9.91M [00:00<00:00, 12.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 338kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.20MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.32MB/s]


Epoch 1 | loss 133.711 | test acc 95.41%
Epoch 2 | loss 59.568 | test acc 96.17%
Epoch 3 | loss 43.994 | test acc 96.69%
Epoch 4 | loss 35.057 | test acc 97.20%
Epoch 5 | loss 27.777 | test acc 96.78%


In [None]:
# =========================
# CIFAR-100 GeoFractal Classifier
# DINOv3 ConvNeXt-S or DINOv3 ViT-B/16 encoder
# Unfreeze LAST STAGE only
# Single-cell Colab script + progress bars
# =========================

import torch
import torch.nn as nn
from torch import Tensor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import timm

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import AdaptiveFusion

# -------------------------
# Settings
# -------------------------
ENCODER_NAME = "convnext_small.dinov3_lvd1689m"  # or: "vit_base_patch16_dinov3.lvd1689m"
EPOCHS = 10
BATCH_TRAIN = 64
BATCH_EVAL  = 128

HEAD_DIM = 512
NUM_TOWERS = 8
TOWER_DEPTH = 2

LR_HEAD = 3e-4
LR_ENCODER = 1e-5
WEIGHT_DECAY = 1e-4

device = "cuda" if torch.cuda.is_available() else "cpu"

# Inductor / TF32 churn has been unstable across recent torch builds.
# This avoids mixed TF32 API pitfalls and keeps compile stable.
torch.backends.cuda.matmul.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "ieee"

torch.set_float32_matmul_precision("high")  # Inductor likes this (safe even w/ fp32_precision='ieee')

# -------------------------
# DINOv3 encoder (timm)
# -------------------------
encoder = timm.create_model(
    ENCODER_NAME,
    pretrained=True,
    num_classes=0,      # feature extractor
    global_pool="avg",  # [B, C]
)
encoder.to(device)

# Determine encoder output dim robustly
with torch.no_grad():
    dummy = torch.zeros(2, 3, 224, 224, device=device)
    out = encoder(dummy)
ENC_DIM = out.shape[-1]

# Freeze everything
for p in encoder.parameters():
    p.requires_grad = False

def unfreeze_last_stage_only(m: nn.Module, encoder_name: str):
    """
    ConvNeXt: unfreeze final stage (stages[-1]) + final norm if present.
    ViT: unfreeze last 2 blocks + final norm (closest equivalent to "last stage").
    """
    enc_lower = encoder_name.lower()

    if "convnext" in enc_lower:
        if hasattr(m, "stages") and isinstance(m.stages, nn.ModuleList) and len(m.stages) > 0:
            for p in m.stages[-1].parameters():
                p.requires_grad = True
        # final norm in timm ConvNeXt is often `norm` or `head.norm` depending on variant
        if hasattr(m, "norm") and isinstance(m.norm, nn.Module):
            for p in m.norm.parameters():
                p.requires_grad = True

    elif "vit" in enc_lower:
        # timm ViT typically has `blocks` (ModuleList) and `norm`
        if hasattr(m, "blocks") and isinstance(m.blocks, nn.ModuleList) and len(m.blocks) > 0:
            for blk in m.blocks[-2:]:
                for p in blk.parameters():
                    p.requires_grad = True
        if hasattr(m, "norm") and isinstance(m.norm, nn.Module):
            for p in m.norm.parameters():
                p.requires_grad = True

    else:
        raise ValueError(f"Unknown encoder family for last-stage unfreeze: {encoder_name}")

unfreeze_last_stage_only(encoder, ENCODER_NAME)

# -------------------------
# GeoFractal head
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name: str, dim: int):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

class CIFARTower(BaseTower):
    def __init__(self, name: str, dim: int, depth: int = 2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x: Tensor) -> Tensor:
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)

class CIFAR100Collective(WideRouter):
    def __init__(self, enc_dim: int, head_dim: int, num_towers: int, tower_depth: int):
        super().__init__("cifar100", auto_discover=True)

        self.attach("proj", nn.Linear(enc_dim, head_dim))

        for i in range(num_towers):
            self.attach(f"tower_{i}", CIFARTower(f"tower_{i}", head_dim, depth=tower_depth))

        self.discover_towers()

        self.attach("fusion", AdaptiveFusion("fusion", num_towers, head_dim))
        self.attach("head", nn.Linear(head_dim, 100))

    def forward(self, feats: Tensor) -> Tensor:
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        fused = self["fusion"](*opinions.values())
        return self["head"](fused)

model = CIFAR100Collective(ENC_DIM, HEAD_DIM, NUM_TOWERS, TOWER_DEPTH)
model.network_to(device)
model = model.prepare_and_compile()

# -------------------------
# Data (CIFAR-100)
# (Simple, reliable aug: CIFAR-native crop/flip, then resize to 224)
# -------------------------
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

test_tf = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=train_tf)
test_ds  = datasets.CIFAR100("./data", train=False, download=True, transform=test_tf)

train_loader = DataLoader(train_ds, batch_size=BATCH_TRAIN, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_EVAL, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

# -------------------------
# Optimizer (two param groups)
# -------------------------
head_params = list(model.parameters())
enc_params  = [p for p in encoder.parameters() if p.requires_grad]

opt = torch.optim.AdamW(
    [
        {"params": head_params, "lr": LR_HEAD, "weight_decay": WEIGHT_DECAY},
        {"params": enc_params,  "lr": LR_ENCODER, "weight_decay": WEIGHT_DECAY},
    ]
)

loss_fn = nn.CrossEntropyLoss()

# -------------------------
# Train / Eval
# -------------------------
for epoch in range(EPOCHS):
    model.train()
    encoder.train()  # only last stage has grads; others are frozen anyway

    running_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [train]", leave=False)
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        feats = encoder(imgs)

        opt.zero_grad(set_to_none=True)
        logits = model(feats)
        loss = loss_fn(logits, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    # Eval
    model.eval()
    encoder.eval()
    correct = 0
    total = 0

    pbar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [eval]", leave=False)
    with torch.no_grad():
        for imgs, labels in pbar:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            feats = encoder(imgs)
            preds = model(feats).argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)
            pbar.set_postfix(acc=f"{100.0 * correct / total:.2f}%")

    acc = 100.0 * correct / total
    print(f"Epoch {epoch+1:02d} | loss {running_loss:.2f} | test acc {acc:.2f}%")

    model.reset()


  _C._set_float32_matmul_precision(precision)


Epoch 1/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 01 | loss 619.13 | test acc 83.19%


Epoch 2/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 02 | loss 410.93 | test acc 84.01%


Epoch 3/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 03 | loss 361.12 | test acc 84.40%


Epoch 4/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 04 | loss 323.39 | test acc 84.24%


Epoch 5/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 05 | loss 294.00 | test acc 84.24%


Epoch 6/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 6/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 06 | loss 263.00 | test acc 84.81%


Epoch 7/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 7/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 07 | loss 237.17 | test acc 84.71%


Epoch 8/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 8/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 08 | loss 215.57 | test acc 84.99%


Epoch 9/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 9/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 09 | loss 190.68 | test acc 84.76%


Epoch 10/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 10/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 | loss 164.09 | test acc 85.08%


# more models

In [None]:
# =========================
# CIFAR-100 Dual-Collective GeoFractal Classifier
# DINOv3 ConvNeXt-S  +  DINOv3 ViT-B/16
# Single Colab cell
# =========================

!pip -q install timm tqdm

import torch
import torch.nn as nn
from torch import Tensor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import timm

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import AdaptiveFusion

# -------------------------
# Stability first
# -------------------------
torch.backends.cuda.matmul.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.set_float32_matmul_precision("high")

device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# Encoders (DINOv3)
# -------------------------
ENC_A = "convnext_small.dinov3_lvd1689m"
ENC_B = "vit_base_patch16_dinov3.lvd1689m"

def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    for p in enc.parameters():
        p.requires_grad = False
    return enc

enc_a = load_encoder(ENC_A)
enc_b = load_encoder(ENC_B)

with torch.no_grad():
    dummy = torch.zeros(2, 3, 224, 224, device=device)
    DIM_A = enc_a(dummy).shape[-1]
    DIM_B = enc_b(dummy).shape[-1]

# -------------------------
# Components
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class Tower(BaseTower):
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)

# -------------------------
# Sub-Collective
# -------------------------
class SubCollective(WideRouter):
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", Tower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        self.attach("fusion", AdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())

# -------------------------
# Meta-Collective
# -------------------------
class DualCollective(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_a = SubCollective("convnext", DIM_A, 512, 8)
        self.sub_b = SubCollective("vit",       DIM_B, 512, 8)

        self.meta_fusion = AdaptiveFusion("meta_fusion", 2, 512)
        self.head = nn.Linear(512, 100)

    def forward(self, fa, fb):
        oa = self.sub_a(fa)
        ob = self.sub_b(fb)
        fused = self.meta_fusion(oa, ob)
        return self.head(fused)

model = DualCollective().to(device)
model.sub_a.prepare_and_compile()
model.sub_b.prepare_and_compile()

opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

# -------------------------
# Data
# -------------------------
tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),
])

train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=tf)
test_ds  = datasets.CIFAR100("./data", train=False, download=True, transform=tf)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)
test_loader  = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2, persistent_workers=True)

# -------------------------
# Train / Eval
# -------------------------
EPOCHS = 10

for epoch in range(EPOCHS):
    model.train()
    running = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [train]", leave=False)
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.no_grad():
            fa = enc_a(imgs)
            fb = enc_b(imgs)

        opt.zero_grad()
        logits = model(fa, fb)
        loss = loss_fn(logits, labels)
        loss.backward()
        opt.step()

        running += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    model.eval()
    correct = total = 0

    pbar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [eval]", leave=False)
    with torch.no_grad():
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            fa = enc_a(imgs)
            fb = enc_b(imgs)
            preds = model(fa, fb).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            pbar.set_postfix(acc=f"{100*correct/total:.2f}%")

    print(f"Epoch {epoch+1:02d} | loss {running:.2f} | acc {100*correct/total:.2f}%")


  _C._set_float32_matmul_precision(precision)


model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

Epoch 1/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 01 | loss 610.63 | acc 83.41%


Epoch 2/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 02 | loss 354.78 | acc 86.56%


Epoch 3/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 03 | loss 266.78 | acc 87.40%


Epoch 4/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 04 | loss 222.49 | acc 87.71%


Epoch 5/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 05 | loss 194.73 | acc 87.41%


Epoch 6/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 6/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 06 | loss 171.57 | acc 87.66%


Epoch 7/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 7/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 07 | loss 153.34 | acc 87.73%


Epoch 8/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 8/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 08 | loss 135.76 | acc 87.01%


Epoch 9/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 9/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 09 | loss 124.58 | acc 87.37%


Epoch 10/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 10/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 | loss 110.23 | acc 87.68%


# aggregation inception begins

In [None]:
# =========================
# CIFAR-100 Dual-Collective GeoFractal Classifier
# DINOv3 ConvNeXt-S  +  DINOv3 ViT-B/16
# With FieldWalkerFusion ablation test
# Single Colab cell
# =========================

!pip -q install timm tqdm

import torch
import torch.nn as nn
from torch import Tensor
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import timm

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import AdaptiveFusion
from geofractal.router.components.aggregation_component import (
    FieldWalkerFusion, from_preset, WALKER_PRESETS
)

# -------------------------
# Config - Toggle for ablation
# -------------------------
USE_FIELD_WALKER = True          # False = baseline AdaptiveFusion
WALKER_PRESET = 'slerp'          # 'alucard', 'slerp', 'slip', 'min_p', 'learnable'
WALKER_STEPS = 4                 # Interpolation steps
LEARNABLE_WALKER = False         # Learn schedule + aggregation weights

# -------------------------
# Stability first
# -------------------------
torch.backends.cuda.matmul.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.set_float32_matmul_precision("high")

device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# Encoders (DINOv3)
# -------------------------
ENC_A = "convnext_small.dinov3_lvd1689m"
ENC_B = "vit_base_patch16_dinov3.lvd1689m"

def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    for p in enc.parameters():
        p.requires_grad = False
    return enc

enc_a = load_encoder(ENC_A)
enc_b = load_encoder(ENC_B)

with torch.no_grad():
    dummy = torch.zeros(2, 3, 224, 224, device=device)
    DIM_A = enc_a(dummy).shape[-1]
    DIM_B = enc_b(dummy).shape[-1]

print(f"Encoder dims: ConvNeXt={DIM_A}, ViT={DIM_B}")
print(f"Using FieldWalkerFusion: {USE_FIELD_WALKER}")
if USE_FIELD_WALKER:
    print(f"  Preset: {WALKER_PRESET}, Steps: {WALKER_STEPS}, Learnable: {LEARNABLE_WALKER}")

# -------------------------
# Components
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class Tower(BaseTower):
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)

# -------------------------
# Sub-Collective
# -------------------------
class SubCollective(WideRouter):
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", Tower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        self.attach("fusion", AdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())

# -------------------------
# Meta-Collective with FieldWalker option
# -------------------------
class DualCollective(nn.Module):
    def __init__(self, use_walker=False, walker_preset='slerp',
                 walker_steps=4, learnable=False):
        super().__init__()
        self.sub_a = SubCollective("convnext", DIM_A, 512, 8)
        self.sub_b = SubCollective("vit",       DIM_B, 512, 8)

        self.use_walker = use_walker

        if use_walker:
            # FieldWalkerFusion: walks from ConvNeXt opinion to ViT opinion
            # This explores the interpolation manifold between modalities
            self.meta_fusion = FieldWalkerFusion(
                "meta_walker",
                in_features=512,
                num_steps=walker_steps,
                blend_mode=WALKER_PRESETS[walker_preset]['blend'],
                schedule=WALKER_PRESETS[walker_preset]['schedule'],
                aggregation=WALKER_PRESETS[walker_preset]['aggregation'],
                learnable_steps=learnable,
                learnable_agg=learnable,
            )
        else:
            # Baseline: simple adaptive fusion
            self.meta_fusion = AdaptiveFusion("meta_fusion", 2, 512)

        self.head = nn.Linear(512, 100)

    def forward(self, fa, fb):
        oa = self.sub_a(fa)
        ob = self.sub_b(fb)

        if self.use_walker:
            # FieldWalker takes (source, target) - walks from oa to ob
            fused = self.meta_fusion(oa, ob)
        else:
            fused = self.meta_fusion(oa, ob)

        return self.head(fused)

model = DualCollective(
    use_walker=USE_FIELD_WALKER,
    walker_preset=WALKER_PRESET,
    walker_steps=WALKER_STEPS,
    learnable=LEARNABLE_WALKER,
).to(device)

model.sub_a.prepare_and_compile()
model.sub_b.prepare_and_compile()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {trainable_params:,} trainable / {total_params:,} total")

opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

# -------------------------
# Data
# -------------------------
tf_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),
])

tf_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),
])

train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=tf_train)
test_ds  = datasets.CIFAR100("./data", train=False, download=True, transform=tf_test)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)
test_loader  = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2, persistent_workers=True)

# -------------------------
# Train / Eval
# -------------------------
EPOCHS = 10

results = []

for epoch in range(EPOCHS):
    model.train()
    running = 0.0
    n_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [train]", leave=False)
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)

        with torch.no_grad():
            fa = enc_a(imgs)
            fb = enc_b(imgs)

        opt.zero_grad()
        logits = model(fa, fb)
        loss = loss_fn(logits, labels)
        loss.backward()
        opt.step()

        running += loss.item()
        n_batches += 1
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    avg_loss = running / n_batches

    model.eval()
    correct = total = 0

    pbar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [eval]", leave=False)
    with torch.no_grad():
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            fa = enc_a(imgs)
            fb = enc_b(imgs)
            preds = model(fa, fb).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            pbar.set_postfix(acc=f"{100*correct/total:.2f}%")

    acc = 100*correct/total
    results.append({'epoch': epoch+1, 'loss': avg_loss, 'acc': acc})
    print(f"Epoch {epoch+1:02d} | loss {avg_loss:.4f} | acc {acc:.2f}%")

# -------------------------
# Summary
# -------------------------
print("\n" + "="*50)
print("ABLATION SUMMARY")
print("="*50)
print(f"Config: {'FieldWalkerFusion' if USE_FIELD_WALKER else 'AdaptiveFusion'}")
if USE_FIELD_WALKER:
    print(f"  Preset: {WALKER_PRESET}")
    print(f"  Steps: {WALKER_STEPS}")
    print(f"  Learnable: {LEARNABLE_WALKER}")
print(f"Best accuracy: {max(r['acc'] for r in results):.2f}%")
print(f"Final accuracy: {results[-1]['acc']:.2f}%")

# If using learnable walker, show learned values
if USE_FIELD_WALKER and LEARNABLE_WALKER:
    walker = model.meta_fusion
    if hasattr(walker.schedule, 'get_alphas'):
        print(f"Learned alphas: {walker.schedule.get_alphas().tolist()}")
    if hasattr(walker.aggregation, 'get_weights'):
        print(f"Learned agg weights: {walker.aggregation.get_weights().tolist()}")

  _C._set_float32_matmul_precision(precision)


model.safetensors:   0%|          | 0.00/198M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

Encoder dims: ConvNeXt=768, ViT=768
Using FieldWalkerFusion: True
  Preset: slerp, Steps: 4, Learnable: False
Parameters: 17,796,710 trainable / 17,796,710 total


100%|██████████| 169M/169M [00:19<00:00, 8.80MB/s]


Epoch 1/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 01 | loss 0.6368 | acc 87.32%


Epoch 2/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 02 | loss 0.3670 | acc 88.00%


Epoch 3/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 03 | loss 0.3129 | acc 88.10%


Epoch 4/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 04 | loss 0.2750 | acc 88.11%


Epoch 5/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 05 | loss 0.2498 | acc 87.99%


Epoch 6/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 6/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 06 | loss 0.2236 | acc 87.84%


Epoch 7/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 7/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 07 | loss 0.2017 | acc 87.64%


Epoch 8/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 8/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 08 | loss 0.1802 | acc 88.76%


Epoch 9/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 9/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 09 | loss 0.1687 | acc 88.05%


Epoch 10/10 [train]:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 10/10 [eval]:   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 | loss 0.1585 | acc 88.43%

ABLATION SUMMARY
Config: FieldWalkerFusion
  Preset: slerp
  Steps: 4
  Learnable: False
Best accuracy: 88.76%
Final accuracy: 88.43%


# FieldWalker ablation

In [None]:
# =========================
# FieldWalkerFusion Forward Test Suite
# Tests all blend/schedule/aggregation combinations
# Run this first to catch any crashes before full ablation
# =========================

!pip -q install timm

import torch
import torch.nn as nn
import traceback
from typing import List, Tuple, Dict, Any
from tqdm.auto import tqdm

# Test if geofractal is available
try:
    from geofractal.router.components.aggregation_component import (
        FieldWalkerFusion, from_preset,
        BLEND_MODES, SCHEDULES, AGGREGATIONS, WALKER_PRESETS,
        get_blend_mode, get_aggregation,
        LerpBlend, SlerpBlend, SlipBlend, ZeusBlend, HeliosBlend,
        SurgeBlend, RippleBlend, GilgameshBlend, ShivaBlend, IfritBlend, MinPBlend,
        LinearSchedule, CosineSchedule, SigmoidSchedule, TauSchedule,
        WaveSchedule, LearnableSchedule, AdaptiveSchedule,
        MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
        TopKAggregation, BottomKAggregation, SoftmaxAggregation, SoftminAggregation,
        MinPAggregation, WeightedMeanAggregation, LastStepAggregation, FirstStepAggregation,
        TriangularAggregation, SimilarityAggregation, CrossSimilarityAggregation,
        SimilarityTreeAggregation, SlerpAggregation, AttentionAggregation, LearnableAggregation,
    )
    GEOFRACTAL_AVAILABLE = True
except ImportError as e:
    print(f"Import error: {e}")
    GEOFRACTAL_AVAILABLE = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print(f"GeoFractal available: {GEOFRACTAL_AVAILABLE}")

if not GEOFRACTAL_AVAILABLE:
    raise RuntimeError("Cannot run tests without geofractal package")

# -------------------------
# Test Configuration
# -------------------------
BATCH_SIZE = 2
SEQ_LEN = 16
FEATURES = 256
NUM_STEPS = 8

# All component names
BLEND_NAMES = list(BLEND_MODES.keys())
SCHEDULE_NAMES = ['linear', 'cosine', 'sigmoid', 'tau', 'wave', 'learnable']  # Skip adaptive for simple test
AGGREGATION_NAMES = list(AGGREGATIONS.keys())
PRESET_NAMES = list(WALKER_PRESETS.keys())

print(f"\nBlend modes ({len(BLEND_NAMES)}): {BLEND_NAMES}")
print(f"Schedules ({len(SCHEDULE_NAMES)}): {SCHEDULE_NAMES}")
print(f"Aggregations ({len(AGGREGATION_NAMES)}): {AGGREGATION_NAMES}")
print(f"Presets ({len(PRESET_NAMES)}): {PRESET_NAMES}")

# -------------------------
# Test Functions
# -------------------------
def test_blend_mode(name: str) -> Tuple[bool, str]:
    """Test a single blend mode."""
    try:
        blend = get_blend_mode(name)
        a = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)
        b = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)
        alpha = torch.tensor(0.5, device=device).expand(BATCH_SIZE, SEQ_LEN)

        # Test without mask
        out = blend(a, b, alpha)
        assert out.shape == a.shape, f"Shape mismatch: {out.shape} vs {a.shape}"

        # Test with mask
        mask = torch.zeros(BATCH_SIZE, SEQ_LEN, device=device)
        mask[:, SEQ_LEN//2:] = 1.0
        out_masked = blend(a, b, alpha, mask=mask)
        assert out_masked.shape == a.shape, f"Masked shape mismatch"

        # Verify fingerprint preservation (mask=0 should preserve a)
        diff = (out_masked[:, :SEQ_LEN//2, :] - a[:, :SEQ_LEN//2, :]).abs().max().item()
        assert diff < 1e-5, f"Fingerprint not preserved: diff={diff}"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def test_schedule(name: str) -> Tuple[bool, str]:
    """Test a single schedule."""
    try:
        if name == 'learnable':
            sched = LearnableSchedule(NUM_STEPS)
        elif name == 'adaptive':
            sched = AdaptiveSchedule(FEATURES, NUM_STEPS)
        else:
            sched_cls = SCHEDULES[name]
            sched = sched_cls()

        sched = sched.to(device)
        alphas = sched(NUM_STEPS)

        assert alphas.shape == (NUM_STEPS,), f"Shape mismatch: {alphas.shape}"
        assert alphas.min() >= 0, f"Alpha below 0: {alphas.min()}"
        assert alphas.max() <= 1, f"Alpha above 1: {alphas.max()}"

        return True, f"range=[{alphas.min().item():.3f}, {alphas.max().item():.3f}]"
    except Exception as e:
        return False, str(e)


def test_aggregation(name: str) -> Tuple[bool, str]:
    """Test a single aggregation."""
    try:
        agg = get_aggregation(name, features=FEATURES, num_steps=NUM_STEPS, k=3)
        if hasattr(agg, 'to'):
            agg = agg.to(device)

        stepped = torch.randn(BATCH_SIZE, NUM_STEPS, SEQ_LEN, FEATURES, device=device)
        alphas = torch.linspace(0, 1, NUM_STEPS, device=device)
        original = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)

        # Test without mask
        out = agg(stepped, alphas)
        assert out.shape == (BATCH_SIZE, SEQ_LEN, FEATURES), f"Shape: {out.shape}"

        # Test with mask
        mask = torch.zeros(BATCH_SIZE, SEQ_LEN, device=device)
        mask[:, SEQ_LEN//2:] = 1.0
        out_masked = agg(stepped, alphas, mask=mask, original=original)
        assert out_masked.shape == (BATCH_SIZE, SEQ_LEN, FEATURES), f"Masked shape: {out_masked.shape}"

        # Verify fingerprint preservation
        diff = (out_masked[:, :SEQ_LEN//2, :] - original[:, :SEQ_LEN//2, :]).abs().max().item()
        assert diff < 1e-5, f"Fingerprint not preserved: diff={diff}"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def test_preset(name: str) -> Tuple[bool, str]:
    """Test a preset configuration."""
    try:
        walker = from_preset(f"{name}_test", name, in_features=FEATURES, num_steps=NUM_STEPS)
        walker = walker.to(device)

        a = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)
        b = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)

        out = walker(a, b)
        assert out.shape == a.shape, f"Shape mismatch: {out.shape}"

        # Check stepped outputs
        stepped = walker.get_stepped_outputs()
        assert stepped is not None, "No stepped outputs"
        assert stepped.shape == (BATCH_SIZE, NUM_STEPS, SEQ_LEN, FEATURES), f"Stepped shape: {stepped.shape}"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def test_full_walker(blend: str, schedule: str, aggregation: str) -> Tuple[bool, str]:
    """Test a full walker configuration."""
    try:
        walker = FieldWalkerFusion(
            name=f"test_{blend}_{schedule}_{aggregation}",
            in_features=FEATURES,
            num_steps=NUM_STEPS,
            blend_mode=blend,
            schedule=schedule,
            aggregation=aggregation,
        )
        walker = walker.to(device)

        a = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)
        b = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device)

        out = walker(a, b)
        assert out.shape == a.shape, f"Shape: {out.shape}"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def test_gradient_flow(blend: str, schedule: str, aggregation: str) -> Tuple[bool, str]:
    """Test gradient flow through walker."""
    try:
        learnable_sched = schedule == 'learnable'
        learnable_agg = aggregation == 'learnable'

        walker = FieldWalkerFusion(
            name=f"grad_test",
            in_features=FEATURES,
            num_steps=NUM_STEPS,
            blend_mode=blend,
            schedule=schedule,
            aggregation=aggregation,
            learnable_steps=learnable_sched,
            learnable_agg=learnable_agg,
        )
        walker = walker.to(device)

        a = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device, requires_grad=True)
        b = torch.randn(BATCH_SIZE, SEQ_LEN, FEATURES, device=device, requires_grad=True)

        out = walker(a, b)
        loss = out.sum()
        loss.backward()

        assert a.grad is not None, "No grad on input a"
        assert b.grad is not None, "No grad on input b"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def test_2d_input() -> Tuple[bool, str]:
    """Test 2D input (no sequence dimension)."""
    try:
        walker = from_preset("2d_test", "slerp", in_features=FEATURES, num_steps=NUM_STEPS)
        walker = walker.to(device)

        a = torch.randn(BATCH_SIZE, FEATURES, device=device)
        b = torch.randn(BATCH_SIZE, FEATURES, device=device)

        out = walker(a, b)
        assert out.shape == a.shape, f"Shape: {out.shape} vs {a.shape}"

        return True, "OK"
    except Exception as e:
        return False, str(e)


# -------------------------
# Run All Tests
# -------------------------
def run_all_tests():
    results = {
        'blend_modes': {},
        'schedules': {},
        'aggregations': {},
        'presets': {},
        'combinations': {},
        'gradients': {},
        'special': {},
    }

    failed = []

    # Test blend modes
    print("\n" + "="*60)
    print("BLEND MODES")
    print("="*60)
    for name in BLEND_NAMES:
        ok, msg = test_blend_mode(name)
        status = "✓" if ok else "✗"
        print(f"  {status} {name:12s} : {msg}")
        results['blend_modes'][name] = (ok, msg)
        if not ok:
            failed.append(f"blend:{name}")

    # Test schedules
    print("\n" + "="*60)
    print("SCHEDULES")
    print("="*60)
    for name in SCHEDULE_NAMES:
        ok, msg = test_schedule(name)
        status = "✓" if ok else "✗"
        print(f"  {status} {name:12s} : {msg}")
        results['schedules'][name] = (ok, msg)
        if not ok:
            failed.append(f"schedule:{name}")

    # Test aggregations
    print("\n" + "="*60)
    print("AGGREGATIONS")
    print("="*60)
    for name in AGGREGATION_NAMES:
        ok, msg = test_aggregation(name)
        status = "✓" if ok else "✗"
        print(f"  {status} {name:18s} : {msg}")
        results['aggregations'][name] = (ok, msg)
        if not ok:
            failed.append(f"aggregation:{name}")

    # Test presets
    print("\n" + "="*60)
    print("PRESETS")
    print("="*60)
    for name in PRESET_NAMES:
        ok, msg = test_preset(name)
        status = "✓" if ok else "✗"
        print(f"  {status} {name:12s} : {msg}")
        results['presets'][name] = (ok, msg)
        if not ok:
            failed.append(f"preset:{name}")

    # Test key combinations (subset - full matrix would be 11*6*19 = 1254 tests)
    print("\n" + "="*60)
    print("KEY COMBINATIONS")
    print("="*60)

    key_blends = ['lerp', 'slerp', 'slip', 'min_p']
    key_schedules = ['linear', 'cosine', 'learnable']
    key_aggs = ['mean', 'similarity', 'min_p', 'softmax', 'similarity_tree']

    combo_count = 0
    combo_pass = 0
    for blend in key_blends:
        for sched in key_schedules:
            for agg in key_aggs:
                combo_name = f"{blend}+{sched}+{agg}"
                ok, msg = test_full_walker(blend, sched, agg)
                status = "✓" if ok else "✗"
                if not ok:
                    print(f"  {status} {combo_name}: {msg}")
                    failed.append(f"combo:{combo_name}")
                else:
                    combo_pass += 1
                combo_count += 1
                results['combinations'][combo_name] = (ok, msg)

    print(f"  Tested {combo_count} combinations: {combo_pass}/{combo_count} passed")

    # Test gradient flow
    print("\n" + "="*60)
    print("GRADIENT FLOW")
    print("="*60)
    grad_tests = [
        ('lerp', 'linear', 'mean'),
        ('slerp', 'cosine', 'softmax'),
        ('slip', 'learnable', 'learnable'),
        ('min_p', 'linear', 'min_p'),
    ]
    for blend, sched, agg in grad_tests:
        name = f"{blend}+{sched}+{agg}"
        ok, msg = test_gradient_flow(blend, sched, agg)
        status = "✓" if ok else "✗"
        print(f"  {status} {name}: {msg}")
        results['gradients'][name] = (ok, msg)
        if not ok:
            failed.append(f"gradient:{name}")

    # Special tests
    print("\n" + "="*60)
    print("SPECIAL CASES")
    print("="*60)

    ok, msg = test_2d_input()
    status = "✓" if ok else "✗"
    print(f"  {status} 2D input (no seq dim): {msg}")
    results['special']['2d_input'] = (ok, msg)
    if not ok:
        failed.append("special:2d_input")

    # Summary
    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)

    total_tests = (
        len(results['blend_modes']) +
        len(results['schedules']) +
        len(results['aggregations']) +
        len(results['presets']) +
        len(results['combinations']) +
        len(results['gradients']) +
        len(results['special'])
    )

    total_passed = total_tests - len(failed)

    print(f"Total tests: {total_tests}")
    print(f"Passed: {total_passed}")
    print(f"Failed: {len(failed)}")

    if failed:
        print(f"\nFailed tests:")
        for f in failed:
            print(f"  - {f}")
    else:
        print("\n✓ ALL TESTS PASSED - Ready for ablation barrage")

    return results, failed


# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    results, failed = run_all_tests()

    if not failed:
        print("\n" + "="*60)
        print("ABLATION MATRIX PREVIEW")
        print("="*60)

        # Count total ablation runs
        n_blends = len(BLEND_NAMES)
        n_schedules = len(SCHEDULE_NAMES)
        n_aggs = len(AGGREGATION_NAMES)
        n_steps = [2, 4, 8, 16]
        n_learnable = 2  # True/False

        # Full matrix
        full_matrix = n_blends * n_schedules * n_aggs
        print(f"Full matrix (blend × schedule × agg): {n_blends} × {n_schedules} × {n_aggs} = {full_matrix}")

        # Practical subset
        practical_blends = ['lerp', 'slerp', 'slip', 'zeus', 'min_p']
        practical_scheds = ['linear', 'cosine', 'tau', 'learnable']
        practical_aggs = ['mean', 'weighted', 'softmax', 'min_p', 'similarity', 'similarity_tree', 'learnable']

        practical_matrix = len(practical_blends) * len(practical_scheds) * len(practical_aggs)
        print(f"Practical matrix: {len(practical_blends)} × {len(practical_scheds)} × {len(practical_aggs)} = {practical_matrix}")

        # With steps variation
        with_steps = practical_matrix * len(n_steps)
        print(f"With step variation (×{len(n_steps)}): {with_steps}")

        print(f"\nRecommended: Run {practical_matrix} base configs × 10 epochs each")

Device: cuda
GeoFractal available: True

Blend modes (11): ['lerp', 'slerp', 'slip', 'zeus', 'helios', 'surge', 'ripple', 'gilgamesh', 'shiva', 'ifrit', 'min_p']
Schedules (6): ['linear', 'cosine', 'sigmoid', 'tau', 'wave', 'learnable']
Aggregations (19): ['mean', 'sum', 'max', 'min', 'top_k', 'bottom_k', 'softmax', 'softmin', 'min_p', 'weighted', 'last', 'first', 'triangular', 'similarity', 'cross_similarity', 'similarity_tree', 'slerp', 'attention', 'learnable']
Presets (10): ['alucard', 'slerp', 'slip', 'zeus', 'gilgamesh', 'shiva', 'ifrit', 'learnable', 'fingerprint', 'min_p']

BLEND MODES
  ✓ lerp         : OK
  ✓ slerp        : OK
  ✓ slip         : OK
  ✓ zeus         : OK
  ✓ helios       : OK
  ✓ surge        : OK
  ✓ ripple       : OK
  ✓ gilgamesh    : OK
  ✓ shiva        : OK
  ✓ ifrit        : OK
  ✓ min_p        : OK

SCHEDULES
  ✓ linear       : range=[0.000, 1.000]
  ✓ cosine       : range=[0.000, 1.000]
  ✓ sigmoid      : range=[0.007, 0.993]
  ✓ tau          : range=[

In [None]:
# =========================
# CIFAR-10 FieldWalkerFusion Ablation Barrage
# Cached latents for fast iteration
# 100 configuration tests
# =========================

# !pip -q install timm tqdm pandas  # Uncomment if needed

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm.auto import tqdm
import timm
import pandas as pd
import os
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional
import gc

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import AdaptiveFusion
from geofractal.router.components.aggregation_component import (
    FieldWalkerFusion, from_preset, WALKER_PRESETS
)

# -------------------------
# Config
# -------------------------
CACHE_DIR = "./latent_cache"
RESULTS_FILE = f"ablation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
EPOCHS_PER_RUN = 10
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 128

# Stability
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# -------------------------
# Encoder Setup
# -------------------------
ENC_A = "convnext_small.dinov3_lvd1689m"
ENC_B = "vit_base_patch16_dinov3.lvd1689m"

def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    enc.eval()
    for p in enc.parameters():
        p.requires_grad = False
    return enc

# -------------------------
# Latent Caching
# -------------------------
def cache_latents(encoder, dataloader, cache_path: str, desc: str):
    """Extract and cache latents from encoder."""
    if os.path.exists(cache_path):
        print(f"Loading cached latents from {cache_path}")
        data = torch.load(cache_path)
        return data['latents'], data['labels']

    print(f"Extracting latents: {desc}")
    all_latents = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc=desc):
            imgs = imgs.to(device)
            latents = encoder(imgs)
            all_latents.append(latents.cpu())
            all_labels.append(labels)

    latents = torch.cat(all_latents, dim=0)
    labels = torch.cat(all_labels, dim=0)

    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    torch.save({'latents': latents, 'labels': labels}, cache_path)
    print(f"Cached to {cache_path}: {latents.shape}")

    return latents, labels


def prepare_cached_data():
    """Prepare all cached latents."""

    # Data transforms (no augmentation for caching - we'll augment latents or skip)
    # Actually for fair comparison, cache without augmentation
    tf = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=tf)
    test_ds = datasets.CIFAR10("./data", train=False, download=True, transform=tf)

    # Use larger batch for extraction (faster)
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4)

    # Load encoders
    print("Loading encoders...")
    enc_a = load_encoder(ENC_A)
    enc_b = load_encoder(ENC_B)

    # Get dimensions
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224, device=device)
        DIM_A = enc_a(dummy).shape[-1]
        DIM_B = enc_b(dummy).shape[-1]
    print(f"Encoder dims: ConvNeXt={DIM_A}, ViT={DIM_B}")

    # Cache all latents
    train_a, train_labels = cache_latents(
        enc_a, train_loader, f"{CACHE_DIR}/train_convnext.pt", "Train ConvNeXt"
    )
    train_b, _ = cache_latents(
        enc_b, train_loader, f"{CACHE_DIR}/train_vit.pt", "Train ViT"
    )
    test_a, test_labels = cache_latents(
        enc_a, test_loader, f"{CACHE_DIR}/test_convnext.pt", "Test ConvNeXt"
    )
    test_b, _ = cache_latents(
        enc_b, test_loader, f"{CACHE_DIR}/test_vit.pt", "Test ViT"
    )

    # Free encoder memory
    del enc_a, enc_b
    torch.cuda.empty_cache()
    gc.collect()

    # Create tensor datasets
    train_dataset = TensorDataset(train_a, train_b, train_labels)
    test_dataset = TensorDataset(test_a, test_b, test_labels)

    return train_dataset, test_dataset, DIM_A, DIM_B


# -------------------------
# Model Components
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Tower(BaseTower):
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)


class SubCollective(WideRouter):
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", Tower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        self.attach("fusion", AdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())


class DualCollective(nn.Module):
    def __init__(self, dim_a, dim_b, config: Dict[str, Any]):
        super().__init__()
        self.sub_a = SubCollective("convnext", dim_a, 512, 8)
        self.sub_b = SubCollective("vit", dim_b, 512, 8)

        self.config = config
        self.use_walker = config.get('use_walker', False)

        if self.use_walker:
            self.meta_fusion = FieldWalkerFusion(
                "meta_walker",
                in_features=512,
                num_steps=config.get('num_steps', 4),
                blend_mode=config.get('blend', 'lerp'),
                schedule=config.get('schedule', 'linear'),
                aggregation=config.get('aggregation', 'mean'),
                learnable_steps=config.get('learnable_steps', False),
                learnable_agg=config.get('learnable_agg', False),
            )
        else:
            self.meta_fusion = AdaptiveFusion("meta_fusion", 2, 512)

        self.head = nn.Linear(512, 10)

    def forward(self, fa, fb):
        oa = self.sub_a(fa)
        ob = self.sub_b(fb)

        if self.use_walker:
            fused = self.meta_fusion(oa, ob)
        else:
            fused = self.meta_fusion(oa, ob)

        return self.head(fused)


# -------------------------
# Training
# -------------------------
def train_single_config(
    config: Dict[str, Any],
    train_dataset: TensorDataset,
    test_dataset: TensorDataset,
    dim_a: int,
    dim_b: int,
    epochs: int = EPOCHS_PER_RUN,
) -> Dict[str, Any]:
    """Train a single configuration and return results."""

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
        num_workers=0, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False,
        num_workers=0, pin_memory=True
    )

    # Build model
    model = DualCollective(dim_a, dim_b, config).to(device)
    model.sub_a.prepare_and_compile()
    model.sub_b.prepare_and_compile()

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Track metrics
    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        # Train
        model.train()
        train_loss = 0.0
        n_batches = 0

        for fa, fb, labels in train_loader:
            fa, fb, labels = fa.to(device), fb.to(device), labels.to(device)

            opt.zero_grad()
            logits = model(fa, fb)
            loss = loss_fn(logits, labels)
            loss.backward()
            opt.step()

            train_loss += loss.item()
            n_batches += 1

        avg_loss = train_loss / n_batches

        # Eval
        model.eval()
        correct = total = 0

        with torch.no_grad():
            for fa, fb, labels in test_loader:
                fa, fb, labels = fa.to(device), fb.to(device), labels.to(device)
                preds = model(fa, fb).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100 * correct / total
        best_acc = max(best_acc, acc)

        history.append({'epoch': epoch + 1, 'loss': avg_loss, 'acc': acc})

    # Cleanup
    del model, opt
    torch.cuda.empty_cache()
    gc.collect()

    return {
        'config': config,
        'best_acc': best_acc,
        'final_acc': history[-1]['acc'],
        'final_loss': history[-1]['loss'],
        'history': history,
    }


# -------------------------
# Ablation Matrix
# -------------------------
def generate_ablation_configs() -> List[Dict[str, Any]]:
    """Generate all configurations to test."""
    configs = []

    # Baseline: AdaptiveFusion (no walker)
    configs.append({
        'name': 'baseline_adaptive',
        'use_walker': False,
    })

    # All presets
    for preset_name in WALKER_PRESETS.keys():
        preset = WALKER_PRESETS[preset_name]
        configs.append({
            'name': f'preset_{preset_name}',
            'use_walker': True,
            'blend': preset['blend'],
            'schedule': preset['schedule'],
            'aggregation': preset['aggregation'],
            'num_steps': 4,
        })

    # Blend mode sweep (with fixed schedule/agg)
    blends = ['lerp', 'slerp', 'slip', 'zeus', 'helios', 'surge', 'ripple', 'gilgamesh', 'shiva', 'ifrit', 'min_p']
    for blend in blends:
        configs.append({
            'name': f'blend_{blend}',
            'use_walker': True,
            'blend': blend,
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': 4,
        })

    # Schedule sweep (with fixed blend/agg)
    schedules = ['linear', 'cosine', 'sigmoid', 'tau', 'wave', 'learnable']
    for sched in schedules:
        configs.append({
            'name': f'schedule_{sched}',
            'use_walker': True,
            'blend': 'slerp',
            'schedule': sched,
            'aggregation': 'mean',
            'num_steps': 4,
            'learnable_steps': sched == 'learnable',
        })

    # Aggregation sweep (with fixed blend/schedule)
    aggregations = ['mean', 'sum', 'max', 'min', 'top_k', 'bottom_k', 'softmax', 'softmin',
                    'min_p', 'weighted', 'last', 'first', 'triangular', 'similarity',
                    'cross_similarity', 'similarity_tree', 'slerp', 'attention', 'learnable']
    for agg in aggregations:
        configs.append({
            'name': f'agg_{agg}',
            'use_walker': True,
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': agg,
            'num_steps': 4,
            'learnable_agg': agg == 'learnable',
        })

    # Step count sweep
    for steps in [2, 4, 8, 16]:
        configs.append({
            'name': f'steps_{steps}',
            'use_walker': True,
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': steps,
        })

    # Step count with different blends
    for steps in [2, 8, 16]:
        for blend in ['lerp', 'slip', 'min_p']:
            configs.append({
                'name': f'steps_{steps}_{blend}',
                'use_walker': True,
                'blend': blend,
                'schedule': 'cosine',
                'aggregation': 'mean',
                'num_steps': steps,
            })

    # Best combo candidates (informed guesses)
    combos = [
        ('slerp', 'cosine', 'similarity', 4),
        ('slerp', 'cosine', 'min_p', 4),
        ('slerp', 'tau', 'softmax', 4),
        ('slip', 'cosine', 'similarity', 4),
        ('slip', 'cosine', 'similarity_tree', 4),
        ('min_p', 'linear', 'min_p', 4),
        ('min_p', 'cosine', 'softmax', 4),
        ('lerp', 'learnable', 'learnable', 4),
        ('slerp', 'learnable', 'learnable', 4),
        ('slerp', 'cosine', 'similarity', 8),
        ('slip', 'cosine', 'similarity', 8),
        ('slerp', 'cosine', 'cross_similarity', 4),
        ('zeus', 'sigmoid', 'last', 4),
        ('gilgamesh', 'linear', 'triangular', 8),
        ('shiva', 'cosine', 'similarity_tree', 4),
        # Additional combos
        ('lerp', 'cosine', 'softmax', 4),
        ('lerp', 'cosine', 'min_p', 4),
        ('lerp', 'tau', 'similarity', 4),
        ('slerp', 'linear', 'attention', 4),
        ('slerp', 'sigmoid', 'softmax', 4),
        ('slerp', 'tau', 'similarity_tree', 4),
        ('slip', 'linear', 'mean', 4),
        ('slip', 'tau', 'softmax', 4),
        ('slip', 'cosine', 'min_p', 4),
        ('slip', 'cosine', 'attention', 4),
        ('min_p', 'cosine', 'similarity', 4),
        ('min_p', 'tau', 'similarity_tree', 4),
        ('zeus', 'linear', 'softmax', 4),
        ('helios', 'cosine', 'triangular', 4),
        ('surge', 'linear', 'weighted', 4),
        ('ripple', 'wave', 'mean', 4),
        ('ifrit', 'wave', 'softmax', 4),
        # Deeper step variants
        ('slerp', 'cosine', 'mean', 16),
        ('slip', 'cosine', 'similarity', 16),
        ('min_p', 'linear', 'min_p', 8),
        ('slerp', 'tau', 'similarity_tree', 8),
        ('lerp', 'cosine', 'attention', 8),
        # Learnable variants
        ('slip', 'learnable', 'learnable', 4),
        ('min_p', 'learnable', 'min_p', 4),
        ('slerp', 'learnable', 'similarity', 4),
    ]
    for blend, sched, agg, steps in combos:
        configs.append({
            'name': f'combo_{blend}_{sched}_{agg}_{steps}',
            'use_walker': True,
            'blend': blend,
            'schedule': sched,
            'aggregation': agg,
            'num_steps': steps,
            'learnable_steps': sched == 'learnable',
            'learnable_agg': agg == 'learnable',
        })

    return configs


# -------------------------
# Main
# -------------------------
def run_ablation_barrage():
    """Run the full ablation barrage."""

    # Prepare cached data
    print("="*60)
    print("PREPARING CACHED LATENTS")
    print("="*60)
    train_dataset, test_dataset, dim_a, dim_b = prepare_cached_data()
    print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    # Generate configs
    configs = generate_ablation_configs()
    print(f"\n{'='*60}")
    print(f"ABLATION BARRAGE: {len(configs)} configurations")
    print("="*60)

    # Results storage
    results = []

    # Run all configs
    for i, config in enumerate(configs):
        print(f"\n[{i+1}/{len(configs)}] {config['name']}")
        print("-" * 40)

        try:
            result = train_single_config(
                config, train_dataset, test_dataset, dim_a, dim_b
            )

            # Log result
            row = {
                'name': config['name'],
                'use_walker': config.get('use_walker', False),
                'blend': config.get('blend', 'N/A'),
                'schedule': config.get('schedule', 'N/A'),
                'aggregation': config.get('aggregation', 'N/A'),
                'num_steps': config.get('num_steps', 'N/A'),
                'learnable_steps': config.get('learnable_steps', False),
                'learnable_agg': config.get('learnable_agg', False),
                'best_acc': result['best_acc'],
                'final_acc': result['final_acc'],
                'final_loss': result['final_loss'],
            }
            results.append(row)

            print(f"  Best: {result['best_acc']:.2f}% | Final: {result['final_acc']:.2f}%")

            # Save intermediate results
            df = pd.DataFrame(results)
            df.to_csv(RESULTS_FILE, index=False)

        except Exception as e:
            print(f"  ERROR: {e}")
            results.append({
                'name': config['name'],
                'use_walker': config.get('use_walker', False),
                'blend': config.get('blend', 'N/A'),
                'schedule': config.get('schedule', 'N/A'),
                'aggregation': config.get('aggregation', 'N/A'),
                'num_steps': config.get('num_steps', 'N/A'),
                'best_acc': -1,
                'final_acc': -1,
                'final_loss': -1,
                'error': str(e),
            })

    # Final summary
    print("\n" + "="*60)
    print("ABLATION COMPLETE")
    print("="*60)

    df = pd.DataFrame(results)
    df = df.sort_values('best_acc', ascending=False)
    df.to_csv(RESULTS_FILE, index=False)

    print(f"\nResults saved to: {RESULTS_FILE}")
    print(f"\nTop 10 configurations:")
    print(df[['name', 'best_acc', 'final_acc']].head(10).to_string(index=False))

    print(f"\nBaseline (AdaptiveFusion):")
    baseline = df[df['name'] == 'baseline_adaptive']
    if len(baseline) > 0:
        print(f"  Best: {baseline['best_acc'].values[0]:.2f}%")

    print(f"\nBest Walker config:")
    walker_df = df[df['use_walker'] == True]
    if len(walker_df) > 0:
        best = walker_df.iloc[0]
        print(f"  {best['name']}: {best['best_acc']:.2f}%")
        print(f"  Blend: {best['blend']}, Schedule: {best['schedule']}, Agg: {best['aggregation']}")

    return df


# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
    results_df = run_ablation_barrage()

  self.setter(val)


Device: cuda
PREPARING CACHED LATENTS


100%|██████████| 170M/170M [00:13<00:00, 12.2MB/s]


Loading encoders...
Encoder dims: ConvNeXt=768, ViT=768
Extracting latents: Train ConvNeXt


Train ConvNeXt:   0%|          | 0/196 [00:00<?, ?it/s]

Cached to ./latent_cache/train_convnext.pt: torch.Size([50000, 768])
Extracting latents: Train ViT


Train ViT:   0%|          | 0/196 [00:00<?, ?it/s]

Cached to ./latent_cache/train_vit.pt: torch.Size([50000, 768])
Extracting latents: Test ConvNeXt


Test ConvNeXt:   0%|          | 0/40 [00:00<?, ?it/s]

Cached to ./latent_cache/test_convnext.pt: torch.Size([10000, 768])
Extracting latents: Test ViT


Test ViT:   0%|          | 0/40 [00:00<?, ?it/s]

Cached to ./latent_cache/test_vit.pt: torch.Size([10000, 768])
Train samples: 50000, Test samples: 10000

ABLATION BARRAGE: 100 configurations

[1/100] baseline_adaptive
----------------------------------------
  Best: 98.02% | Final: 97.75%

[2/100] preset_alucard
----------------------------------------
  Best: 97.59% | Final: 97.52%

[3/100] preset_slerp
----------------------------------------
  Best: 98.07% | Final: 97.59%

[4/100] preset_slip
----------------------------------------
  Best: 98.10% | Final: 98.04%

[5/100] preset_zeus
----------------------------------------
  Best: 97.61% | Final: 97.37%

[6/100] preset_gilgamesh
----------------------------------------
  Best: 98.15% | Final: 97.92%

[7/100] preset_shiva
----------------------------------------
  Best: 98.21% | Final: 97.95%

[8/100] preset_ifrit
----------------------------------------
  Best: 97.98% | Final: 97.96%

[9/100] preset_learnable
----------------------------------------
  Best: 97.85% | Final: 97.85

# Fusion With Fieldwalker Ablation

## test

In [None]:
# =========================
# Fusion Factory Forward Test Suite
# Tests all fusion strategies before ablation
# Includes stress test for performance profiling
# =========================

# !pip -q install timm  # Uncomment for Colab

import torch
import torch.nn as nn
import traceback
import time
from typing import List, Tuple, Dict, Any
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# -------------------------
# Test Configuration
# -------------------------
BATCH_SIZE = 2
HEAD_DIM = 512
NUM_INPUTS = 2

# Stress test config
STRESS_BATCH_SIZE = 64
STRESS_ITERATIONS = 100
STRESS_WARMUP = 10

# -------------------------
# Import all fusion components
# -------------------------
print("\nImporting fusion components...")

try:
    from geofractal.router.components.fusion_component import (
        # Base
        FusionComponent,
        # Basic fusions
        ConcatFusion,
        SumFusion,
        BilinearFusion,
        ResidualFusion,
        # Adaptive fusions
        AdaptiveFusion,
        GatedFusion,
        AttentionFusion,
        # Geometric fusions (from David)
        GeometricAttentionGate,
        CantorScaleFusion,
        HierarchicalTreeGating,
        # Binding fusions (from Lyra)
        AdaptiveBindingFusion,
    )
    FUSION_IMPORTS_OK = True
    print("  ✓ fusion_component imports OK")
except ImportError as e:
    print(f"  ✗ fusion_component import error: {e}")
    FUSION_IMPORTS_OK = False

try:
    from geofractal.router.components.aggregation_component import (
        FieldWalkerFusion, from_preset, WALKER_PRESETS
    )
    WALKER_IMPORTS_OK = True
    print("  ✓ aggregation_component imports OK")
except ImportError as e:
    print(f"  ✗ aggregation_component import error: {e}")
    WALKER_IMPORTS_OK = False

if not (FUSION_IMPORTS_OK and WALKER_IMPORTS_OK):
    raise RuntimeError("Cannot run tests without all imports")


# -------------------------
# Wrapper for AttentionFusion (outputs [B, 1, D] -> [B, D])
# -------------------------
class _AttentionFusionWrapper(nn.Module):
    def __init__(self, fusion):
        super().__init__()
        self.fusion = fusion

    def forward(self, *args):
        out = self.fusion(*args)
        if out.dim() == 3 and out.shape[1] == 1:
            out = out.squeeze(1)
        return out

# -------------------------
# Test Functions
# -------------------------
def test_fusion(name: str, fusion: nn.Module, inputs: List[torch.Tensor],
                expected_shape: Tuple[int, ...]) -> Tuple[bool, str]:
    """Test a single fusion module."""
    try:
        fusion = fusion.to(device)
        inputs = [x.to(device) for x in inputs]

        # Forward pass
        out = fusion(*inputs)

        # Check output shape
        if out.shape != expected_shape:
            return False, f"Shape mismatch: {out.shape} vs expected {expected_shape}"

        # Check for NaN/Inf
        if torch.isnan(out).any():
            return False, "Output contains NaN"
        if torch.isinf(out).any():
            return False, "Output contains Inf"

        return True, f"shape={list(out.shape)}"
    except Exception as e:
        return False, str(e)


def test_fusion_gradient(name: str, fusion: nn.Module,
                         inputs: List[torch.Tensor]) -> Tuple[bool, str]:
    """Test gradient flow through fusion."""
    try:
        fusion = fusion.to(device)
        inputs = [x.clone().to(device).requires_grad_(True) for x in inputs]

        out = fusion(*inputs)
        loss = out.sum()
        loss.backward()

        # Check all inputs got gradients
        for i, inp in enumerate(inputs):
            if inp.grad is None:
                return False, f"No gradient on input {i}"
            if torch.isnan(inp.grad).any():
                return False, f"NaN gradient on input {i}"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def test_walker(name: str, walker: nn.Module,
                a: torch.Tensor, b: torch.Tensor) -> Tuple[bool, str]:
    """Test a FieldWalker module."""
    try:
        walker = walker.to(device)
        a = a.to(device)
        b = b.to(device)

        out = walker(a, b)

        if out.shape != a.shape:
            return False, f"Shape mismatch: {out.shape} vs {a.shape}"

        if torch.isnan(out).any():
            return False, "Output contains NaN"

        return True, f"shape={list(out.shape)}"
    except Exception as e:
        return False, str(e)


def test_walker_gradient(name: str, walker: nn.Module,
                         a: torch.Tensor, b: torch.Tensor) -> Tuple[bool, str]:
    """Test gradient flow through walker."""
    try:
        walker = walker.to(device)
        a = a.clone().to(device).requires_grad_(True)
        b = b.clone().to(device).requires_grad_(True)

        out = walker(a, b)
        loss = out.sum()
        loss.backward()

        if a.grad is None:
            return False, "No gradient on input a"
        if b.grad is None:
            return False, "No gradient on input b"

        return True, "OK"
    except Exception as e:
        return False, str(e)


def get_gpu_memory_mb() -> float:
    """Get current GPU memory usage in MB."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024 / 1024
    return 0.0


def stress_test_fusion(name: str, factory, mode: str = 'fusion') -> Dict[str, Any]:
    """
    Stress test a fusion module for performance.

    Returns timing and memory stats.
    """
    try:
        # Create module
        module = factory().to(device)

        # Create stress inputs
        if mode == 'fusion':
            inputs = [
                torch.randn(STRESS_BATCH_SIZE, HEAD_DIM, device=device)
                for _ in range(NUM_INPUTS)
            ]
        else:  # walker
            a = torch.randn(STRESS_BATCH_SIZE, HEAD_DIM, device=device)
            b = torch.randn(STRESS_BATCH_SIZE, HEAD_DIM, device=device)

        # Warmup
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        for _ in range(STRESS_WARMUP):
            if mode == 'fusion':
                _ = module(*inputs)
            else:
                _ = module(a, b)

        # Clear memory tracking
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

        mem_before = get_gpu_memory_mb()

        # Timed forward passes
        start = time.perf_counter()
        for _ in range(STRESS_ITERATIONS):
            if mode == 'fusion':
                out = module(*inputs)
            else:
                out = module(a, b)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        elapsed = time.perf_counter() - start

        mem_after = get_gpu_memory_mb()
        peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0

        # Timed backward passes
        if mode == 'fusion':
            inputs = [x.requires_grad_(True) for x in inputs]
        else:
            a = a.requires_grad_(True)
            b = b.requires_grad_(True)

        start_bwd = time.perf_counter()
        for _ in range(STRESS_ITERATIONS):
            if mode == 'fusion':
                out = module(*inputs)
            else:
                out = module(a, b)
            loss = out.sum()
            loss.backward()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        elapsed_bwd = time.perf_counter() - start_bwd

        # Compute stats
        ms_per_fwd = (elapsed / STRESS_ITERATIONS) * 1000
        ms_per_bwd = (elapsed_bwd / STRESS_ITERATIONS) * 1000
        throughput = (STRESS_ITERATIONS * STRESS_BATCH_SIZE) / elapsed

        # Count params
        n_params = sum(p.numel() for p in module.parameters())

        # Cleanup
        del module, out
        if mode == 'fusion':
            del inputs
        else:
            del a, b
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return {
            'success': True,
            'ms_per_fwd': ms_per_fwd,
            'ms_per_bwd': ms_per_bwd,
            'throughput': throughput,
            'peak_mem_mb': peak_mem,
            'params': n_params,
        }

    except Exception as e:
        return {
            'success': False,
            'error': str(e),
        }


# -------------------------
# Run All Tests
# -------------------------
def run_all_tests():
    results = {
        'basic': {},
        'adaptive': {},
        'geometric': {},
        'binding': {},
        'cantor': {},
        'walker_preset': {},
        'walker_custom': {},
        'gradients': {},
    }

    failed = []

    # Test inputs
    x1 = torch.randn(BATCH_SIZE, HEAD_DIM)
    x2 = torch.randn(BATCH_SIZE, HEAD_DIM)
    inputs = [x1, x2]
    expected = (BATCH_SIZE, HEAD_DIM)

    # =========================================================================
    # BASIC FUSIONS
    # =========================================================================
    print("\n" + "="*60)
    print("BASIC FUSIONS")
    print("="*60)

    basic_fusions = [
        ('concat', lambda: ConcatFusion("test_concat", NUM_INPUTS, HEAD_DIM, HEAD_DIM)),
        ('sum', lambda: SumFusion("test_sum", NUM_INPUTS, HEAD_DIM)),
        ('bilinear', lambda: BilinearFusion("test_bilinear", HEAD_DIM)),
        ('residual', lambda: ResidualFusion("test_residual", NUM_INPUTS, HEAD_DIM)),
    ]

    for name, factory in basic_fusions:
        try:
            fusion = factory()
            ok, msg = test_fusion(name, fusion, inputs, expected)
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['basic'][name] = (ok, msg)
            if not ok:
                failed.append(f"basic:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : FACTORY ERROR: {e}")
            failed.append(f"basic:{name}")

    # =========================================================================
    # ADAPTIVE FUSIONS
    # =========================================================================
    print("\n" + "="*60)
    print("ADAPTIVE FUSIONS")
    print("="*60)

    adaptive_fusions = [
        ('adaptive', lambda: AdaptiveFusion("test_adaptive", NUM_INPUTS, HEAD_DIM)),
        ('gated', lambda: GatedFusion("test_gated", NUM_INPUTS, HEAD_DIM)),
        ('attention', lambda: _AttentionFusionWrapper(AttentionFusion("test_attention", NUM_INPUTS, HEAD_DIM))),
    ]

    for name, factory in adaptive_fusions:
        try:
            fusion = factory()
            ok, msg = test_fusion(name, fusion, inputs, expected)
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['adaptive'][name] = (ok, msg)
            if not ok:
                failed.append(f"adaptive:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : FACTORY ERROR: {e}")
            failed.append(f"adaptive:{name}")

    # =========================================================================
    # GEOMETRIC FUSIONS
    # =========================================================================
    print("\n" + "="*60)
    print("GEOMETRIC FUSIONS (from David)")
    print("="*60)

    geometric_fusions = [
        ('geometric_attention', lambda: GeometricAttentionGate(
            "test_geo", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        )),
        ('cantor_scale', lambda: CantorScaleFusion(
            "test_cantor", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        )),
        ('hierarchical_tree', lambda: HierarchicalTreeGating(
            "test_tree", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        )),
    ]

    for name, factory in geometric_fusions:
        try:
            fusion = factory()
            ok, msg = test_fusion(name, fusion, inputs, expected)
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['geometric'][name] = (ok, msg)
            if not ok:
                failed.append(f"geometric:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : FACTORY ERROR: {e}")
            traceback.print_exc()
            failed.append(f"geometric:{name}")

    # =========================================================================
    # BINDING FUSIONS
    # =========================================================================
    print("\n" + "="*60)
    print("BINDING FUSIONS (from Lyra)")
    print("="*60)

    binding_fusions = [
        ('adaptive_binding', lambda: AdaptiveBindingFusion(
            "test_binding", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        )),
    ]

    for name, factory in binding_fusions:
        try:
            fusion = factory()
            ok, msg = test_fusion(name, fusion, inputs, expected)
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['binding'][name] = (ok, msg)
            if not ok:
                failed.append(f"binding:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : FACTORY ERROR: {e}")
            traceback.print_exc()
            failed.append(f"binding:{name}")

    # =========================================================================
    # CANTOR-DERIVED FUSIONS
    # =========================================================================
    print("\n" + "="*60)
    print("CANTOR-DERIVED FUSIONS")
    print("="*60)

    cantor_fusions = [
        # InceptiveFusion requires aux_features input - different API
        # Skip for standard 2-input fusion test
        # ('inceptive', lambda: InceptiveFusion(...))
    ]

    if not cantor_fusions:
        print("  (InceptiveFusion skipped - requires aux_features input)")

    for name, factory in cantor_fusions:
        try:
            fusion = factory()
            ok, msg = test_fusion(name, fusion, inputs, expected)
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['cantor'][name] = (ok, msg)
            if not ok:
                failed.append(f"cantor:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : FACTORY ERROR: {e}")
            traceback.print_exc()
            failed.append(f"cantor:{name}")

    # =========================================================================
    # WALKER PRESETS
    # =========================================================================
    print("\n" + "="*60)
    print("FIELDWALKER PRESETS")
    print("="*60)

    a = torch.randn(BATCH_SIZE, HEAD_DIM)
    b = torch.randn(BATCH_SIZE, HEAD_DIM)

    for preset_name in WALKER_PRESETS.keys():
        try:
            walker = from_preset(f"test_{preset_name}", preset_name,
                                in_features=HEAD_DIM, num_steps=4)
            ok, msg = test_walker(preset_name, walker, a, b)
            status = "✓" if ok else "✗"
            print(f"  {status} walker_{preset_name:15s} : {msg}")
            results['walker_preset'][preset_name] = (ok, msg)
            if not ok:
                failed.append(f"walker_preset:{preset_name}")
        except Exception as e:
            print(f"  ✗ walker_{preset_name:15s} : {e}")
            failed.append(f"walker_preset:{preset_name}")

    # =========================================================================
    # WALKER CUSTOM CONFIGS (from CIFAR-10 winners)
    # =========================================================================
    print("\n" + "="*60)
    print("FIELDWALKER CUSTOM CONFIGS")
    print("="*60)

    custom_walkers = [
        ('shiva_best', {'blend': 'shiva', 'schedule': 'cosine', 'aggregation': 'similarity_tree', 'num_steps': 4}),
        ('slerp_sigmoid', {'blend': 'slerp', 'schedule': 'sigmoid', 'aggregation': 'softmax', 'num_steps': 4}),
        ('gilgamesh_tri', {'blend': 'gilgamesh', 'schedule': 'linear', 'aggregation': 'triangular', 'num_steps': 8}),
        ('min_p_double', {'blend': 'min_p', 'schedule': 'linear', 'aggregation': 'min_p', 'num_steps': 8}),
        ('zeus_sharp', {'blend': 'zeus', 'schedule': 'sigmoid', 'aggregation': 'last', 'num_steps': 4}),
    ]

    for name, config in custom_walkers:
        try:
            walker = FieldWalkerFusion(
                f"test_{name}",
                in_features=HEAD_DIM,
                num_steps=config['num_steps'],
                blend_mode=config['blend'],
                schedule=config['schedule'],
                aggregation=config['aggregation'],
            )
            ok, msg = test_walker(name, walker, a, b)
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['walker_custom'][name] = (ok, msg)
            if not ok:
                failed.append(f"walker_custom:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : {e}")
            failed.append(f"walker_custom:{name}")

    # =========================================================================
    # GRADIENT FLOW TESTS
    # =========================================================================
    print("\n" + "="*60)
    print("GRADIENT FLOW")
    print("="*60)

    gradient_tests = [
        ('adaptive', lambda: AdaptiveFusion("grad_adaptive", NUM_INPUTS, HEAD_DIM), 'fusion'),
        ('geometric_attention', lambda: GeometricAttentionGate(
            "grad_geo", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        ), 'fusion'),
        ('adaptive_binding', lambda: AdaptiveBindingFusion(
            "grad_binding", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        ), 'fusion'),
        # InceptiveFusion skipped - different input API
        ('walker_slerp', lambda: from_preset("grad_slerp", "slerp", in_features=HEAD_DIM), 'walker'),
        ('walker_shiva', lambda: from_preset("grad_shiva", "shiva", in_features=HEAD_DIM), 'walker'),
        ('walker_learnable', lambda: from_preset("grad_learnable", "learnable", in_features=HEAD_DIM), 'walker'),
    ]

    for name, factory, mode in gradient_tests:
        try:
            module = factory()
            if mode == 'fusion':
                ok, msg = test_fusion_gradient(name, module, [x1.clone(), x2.clone()])
            else:
                ok, msg = test_walker_gradient(name, module, a.clone(), b.clone())
            status = "✓" if ok else "✗"
            print(f"  {status} {name:20s} : {msg}")
            results['gradients'][name] = (ok, msg)
            if not ok:
                failed.append(f"gradient:{name}")
        except Exception as e:
            print(f"  ✗ {name:20s} : {e}")
            failed.append(f"gradient:{name}")

    # =========================================================================
    # STRESS TESTS
    # =========================================================================
    print("\n" + "="*60)
    print(f"STRESS TESTS (batch={STRESS_BATCH_SIZE}, iters={STRESS_ITERATIONS})")
    print("="*60)

    stress_results = []

    # Key fusions to stress test
    stress_configs = [
        # Basic
        ('concat+proj', lambda: ConcatFusion("stress_concat", NUM_INPUTS, HEAD_DIM, HEAD_DIM), 'fusion'),
        ('sum', lambda: SumFusion("stress_sum", NUM_INPUTS, HEAD_DIM), 'fusion'),

        # Adaptive
        ('adaptive', lambda: AdaptiveFusion("stress_adaptive", NUM_INPUTS, HEAD_DIM), 'fusion'),
        ('gated', lambda: GatedFusion("stress_gated", NUM_INPUTS, HEAD_DIM), 'fusion'),
        ('attention', lambda: _AttentionFusionWrapper(
            AttentionFusion("stress_attention", NUM_INPUTS, HEAD_DIM)
        ), 'fusion'),

        # Geometric
        ('geometric_attn', lambda: GeometricAttentionGate(
            "stress_geo", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        ), 'fusion'),
        ('cantor_scale', lambda: CantorScaleFusion(
            "stress_cantor", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        ), 'fusion'),
        ('hier_tree', lambda: HierarchicalTreeGating(
            "stress_tree", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        ), 'fusion'),

        # Binding
        ('adaptive_bind', lambda: AdaptiveBindingFusion(
            "stress_binding", in_features=HEAD_DIM, out_features=HEAD_DIM, num_inputs=NUM_INPUTS
        ), 'fusion'),

        # Walkers - key presets
        ('walker_lerp', lambda: from_preset("stress_lerp", "alucard", in_features=HEAD_DIM, num_steps=4), 'walker'),
        ('walker_slerp', lambda: from_preset("stress_slerp", "slerp", in_features=HEAD_DIM, num_steps=4), 'walker'),
        ('walker_slip', lambda: from_preset("stress_slip", "slip", in_features=HEAD_DIM, num_steps=4), 'walker'),
        ('walker_shiva', lambda: from_preset("stress_shiva", "shiva", in_features=HEAD_DIM, num_steps=4), 'walker'),
        ('walker_min_p', lambda: from_preset("stress_min_p", "min_p", in_features=HEAD_DIM, num_steps=4), 'walker'),

        # Walker with more steps
        ('walker_8step', lambda: FieldWalkerFusion(
            "stress_8step", in_features=HEAD_DIM, num_steps=8,
            blend_mode='slerp', schedule='cosine', aggregation='mean'
        ), 'walker'),
        ('walker_16step', lambda: FieldWalkerFusion(
            "stress_16step", in_features=HEAD_DIM, num_steps=16,
            blend_mode='slerp', schedule='cosine', aggregation='mean'
        ), 'walker'),
    ]

    print(f"\n  {'Name':<18s} {'Fwd(ms)':<10s} {'Bwd(ms)':<10s} {'Throughput':<12s} {'Mem(MB)':<10s} {'Params':<10s}")
    print("  " + "-"*70)

    for name, factory, mode in stress_configs:
        stats = stress_test_fusion(name, factory, mode)

        if stats['success']:
            print(f"  {name:<18s} {stats['ms_per_fwd']:<10.3f} {stats['ms_per_bwd']:<10.3f} "
                  f"{stats['throughput']:<12.0f} {stats['peak_mem_mb']:<10.1f} {stats['params']:<10d}")
            stress_results.append({
                'name': name,
                'ms_per_fwd': stats['ms_per_fwd'],
                'ms_per_bwd': stats['ms_per_bwd'],
                'throughput': stats['throughput'],
                'peak_mem_mb': stats['peak_mem_mb'],
                'params': stats['params'],
            })
        else:
            print(f"  {name:<18s} FAILED: {stats.get('error', 'unknown')}")
            failed.append(f"stress:{name}")

    # Flag any suspiciously slow fusions
    if stress_results:
        avg_fwd = sum(r['ms_per_fwd'] for r in stress_results) / len(stress_results)
        avg_bwd = sum(r['ms_per_bwd'] for r in stress_results) / len(stress_results)

        print(f"\n  Averages: fwd={avg_fwd:.3f}ms, bwd={avg_bwd:.3f}ms")

        slow_threshold = avg_fwd * 3  # Flag if 3x slower than average
        slow_fusions = [r for r in stress_results if r['ms_per_fwd'] > slow_threshold]

        if slow_fusions:
            print(f"\n  ⚠ SLOW FUSIONS (>{slow_threshold:.2f}ms):")
            for r in slow_fusions:
                print(f"    - {r['name']}: {r['ms_per_fwd']:.3f}ms ({r['ms_per_fwd']/avg_fwd:.1f}x avg)")
        else:
            print(f"\n  ✓ No unusually slow fusions detected")

    # =========================================================================
    # SUMMARY
    # =========================================================================
    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)

    total_tests = sum(len(v) for v in results.values()) + len(stress_results)
    total_passed = total_tests - len(failed)

    print(f"Total tests: {total_tests}")
    print(f"Passed: {total_passed}")
    print(f"Failed: {len(failed)}")

    if failed:
        print(f"\nFailed tests:")
        for f in failed:
            print(f"  - {f}")
        return False, stress_results
    else:
        print("\n✓ ALL TESTS PASSED - Ready for CIFAR-100 ablation")
        return True, stress_results


# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    success, stress_results = run_all_tests()

    if success:
        print("\n" + "="*60)
        print("ABLATION PREVIEW")
        print("="*60)
        print(f"  Basic fusions:      4")
        print(f"  Adaptive fusions:   3")
        print(f"  Geometric fusions:  3")
        print(f"  Binding fusions:    1")
        print(f"  Walker presets:     {len(WALKER_PRESETS)}")
        print(f"  Walker custom:      5")
        print(f"  -------------------------")
        print(f"  Total configs:      {4+3+3+1+len(WALKER_PRESETS)+5}")

        if stress_results:
            print("\n" + "="*60)
            print("PERFORMANCE RANKING (by throughput)")
            print("="*60)
            sorted_results = sorted(stress_results, key=lambda x: x['throughput'], reverse=True)
            for i, r in enumerate(sorted_results[:10], 1):
                print(f"  {i:2d}. {r['name']:<18s} {r['throughput']:>8.0f} samples/sec")

Device: cuda

Importing fusion components...
  ✓ fusion_component imports OK
  ✓ aggregation_component imports OK

BASIC FUSIONS
  ✓ concat               : shape=[2, 512]
  ✓ sum                  : shape=[2, 512]
  ✓ bilinear             : shape=[2, 512]
  ✓ residual             : shape=[2, 512]

ADAPTIVE FUSIONS
  ✓ adaptive             : shape=[2, 512]
  ✓ gated                : shape=[2, 512]
  ✓ attention            : shape=[2, 512]

GEOMETRIC FUSIONS (from David)
  ✓ geometric_attention  : shape=[2, 512]
  ✓ cantor_scale         : shape=[2, 512]
  ✓ hierarchical_tree    : shape=[2, 512]

BINDING FUSIONS (from Lyra)
  ✓ adaptive_binding     : shape=[2, 512]

CANTOR-DERIVED FUSIONS
  (InceptiveFusion skipped - requires aux_features input)

FIELDWALKER PRESETS
  ✓ walker_alucard         : shape=[2, 512]
  ✓ walker_slerp           : shape=[2, 512]
  ✓ walker_slip            : shape=[2, 512]
  ✓ walker_zeus            : shape=[2, 512]
  ✓ walker_gilgamesh       : shape=[2, 512]
  ✓ wal

## run

In [None]:
# =========================
# CIFAR-100 Fusion Factory Ablation
# Tests all FusionComponent strategies + top FieldWalker configs
# Cached latents for fast iteration
# =========================

# !pip -q install timm tqdm pandas  # Uncomment for Colab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm.auto import tqdm
import timm
import pandas as pd
import os
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional
from enum import Enum
import gc

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent

# Fusion components
from geofractal.router.components.fusion_component import (
    # Base
    FusionComponent,
    # Basic fusions
    ConcatFusion,
    SumFusion,
    BilinearFusion,
    ResidualFusion,
    # Adaptive fusions
    AdaptiveFusion,
    GatedFusion,
    AttentionFusion,
    # Geometric fusions (from David)
    GeometricAttentionGate,
    CantorScaleFusion,
    HierarchicalTreeGating,
    # Binding fusions (from Lyra)
    AdaptiveBindingFusion,
)

# FieldWalker (our new system)
from geofractal.router.components.aggregation_component import (
    FieldWalkerFusion, from_preset, WALKER_PRESETS
)

# -------------------------
# Config
# -------------------------
CACHE_DIR = "./latent_cache_cifar100"
RESULTS_FILE = f"cifar100_fusion_ablation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
EPOCHS_PER_RUN = 10
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 128
HEAD_DIM = 512
NUM_TOWERS = 8

# Stability
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# -------------------------
# Encoder Setup
# -------------------------
ENC_A = "convnext_small.dinov3_lvd1689m"
ENC_B = "vit_base_patch16_dinov3.lvd1689m"

def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    enc.eval()
    for p in enc.parameters():
        p.requires_grad = False
    return enc

# -------------------------
# Latent Caching
# -------------------------
def cache_latents(encoder, dataloader, cache_path: str, desc: str):
    """Extract and cache latents from encoder."""
    if os.path.exists(cache_path):
        print(f"Loading cached latents from {cache_path}")
        data = torch.load(cache_path)
        return data['latents'], data['labels']

    print(f"Extracting latents: {desc}")
    all_latents = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc=desc):
            imgs = imgs.to(device)
            latents = encoder(imgs)
            all_latents.append(latents.cpu())
            all_labels.append(labels)

    latents = torch.cat(all_latents, dim=0)
    labels = torch.cat(all_labels, dim=0)

    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    torch.save({'latents': latents, 'labels': labels}, cache_path)
    print(f"Cached to {cache_path}: {latents.shape}")

    return latents, labels


def prepare_cached_data():
    """Prepare all cached latents for CIFAR-100."""

    tf = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=tf)
    test_ds = datasets.CIFAR100("./data", train=False, download=True, transform=tf)

    train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

    print("Loading encoders...")
    enc_a = load_encoder(ENC_A)
    enc_b = load_encoder(ENC_B)

    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224, device=device)
        DIM_A = enc_a(dummy).shape[-1]
        DIM_B = enc_b(dummy).shape[-1]
    print(f"Encoder dims: ConvNeXt={DIM_A}, ViT={DIM_B}")

    train_a, train_labels = cache_latents(
        enc_a, train_loader, f"{CACHE_DIR}/train_convnext.pt", "Train ConvNeXt"
    )
    train_b, _ = cache_latents(
        enc_b, train_loader, f"{CACHE_DIR}/train_vit.pt", "Train ViT"
    )
    test_a, test_labels = cache_latents(
        enc_a, test_loader, f"{CACHE_DIR}/test_convnext.pt", "Test ConvNeXt"
    )
    test_b, _ = cache_latents(
        enc_b, test_loader, f"{CACHE_DIR}/test_vit.pt", "Test ViT"
    )

    del enc_a, enc_b
    torch.cuda.empty_cache()
    gc.collect()

    train_dataset = TensorDataset(train_a, train_b, train_labels)
    test_dataset = TensorDataset(test_a, test_b, test_labels)

    return train_dataset, test_dataset, DIM_A, DIM_B


# -------------------------
# Model Components
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Tower(BaseTower):
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)


class SubCollective(WideRouter):
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", Tower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        self.attach("fusion", AdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())


# -------------------------
# Fusion Factory
# -------------------------
def create_meta_fusion(fusion_type: str, config: Dict[str, Any]) -> nn.Module:
    """
    Factory for creating meta-fusion modules.

    All fusions take 2 inputs of shape [B, HEAD_DIM] and output [B, HEAD_DIM].
    """
    dim = HEAD_DIM
    n_inputs = 2

    # === Basic Fusions ===
    if fusion_type == 'concat':
        # ConcatFusion(name, num_inputs, in_features, out_features) - has built-in projection
        return ConcatFusion("meta_concat", n_inputs, dim, dim)

    elif fusion_type == 'sum':
        return SumFusion("meta_sum", n_inputs, dim)

    elif fusion_type == 'bilinear':
        return BilinearFusion("meta_bilinear", dim)

    elif fusion_type == 'residual':
        return ResidualFusion("meta_residual", n_inputs, dim)

    # SlotFusion has different API - skip
    # elif fusion_type == 'slot':
    #     return SlotFusion("meta_slot", n_inputs, dim)

    # === Adaptive Fusions ===
    elif fusion_type == 'adaptive':
        return AdaptiveFusion("meta_adaptive", n_inputs, dim)

    elif fusion_type == 'gated':
        return GatedFusion("meta_gated", n_inputs, dim)

    elif fusion_type == 'attention':
        # AttentionFusion outputs [B, 1, D], need squeeze
        class AttentionWrapper(nn.Module):
            def __init__(self):
                super().__init__()
                self.fusion = AttentionFusion("meta_attention", n_inputs, dim)
            def forward(self, *args):
                out = self.fusion(*args)
                return out.squeeze(1) if out.dim() == 3 else out
        return AttentionWrapper()

    # === Geometric Fusions (from David) ===
    elif fusion_type == 'geometric_attention':
        return GeometricAttentionGate(
            "meta_geometric",
            in_features=dim,
            out_features=dim,
            num_inputs=n_inputs,
        )

    elif fusion_type == 'cantor_scale':
        return CantorScaleFusion(
            "meta_cantor",
            in_features=dim,
            out_features=dim,
            num_inputs=n_inputs,
        )

    elif fusion_type == 'hierarchical_tree':
        return HierarchicalTreeGating(
            "meta_tree",
            in_features=dim,
            out_features=dim,
            num_inputs=n_inputs,
        )

    # === Binding Fusions (from Lyra) ===
    elif fusion_type == 'adaptive_binding':
        return AdaptiveBindingFusion(
            "meta_binding",
            in_features=dim,
            out_features=dim,
            num_inputs=n_inputs,
        )

    # === Cantor-derived ===
    # InceptiveFusion requires aux_features - different API, skip
    # elif fusion_type == 'inceptive':
    #     ...

    # === FieldWalker Fusions ===
    elif fusion_type.startswith('walker_'):
        preset_name = fusion_type.replace('walker_', '').replace('custom_', '')
        if preset_name in WALKER_PRESETS:
            return from_preset(
                f"meta_{preset_name}",
                preset_name,
                in_features=dim,
                num_steps=config.get('num_steps', 4),
            )
        else:
            # Custom walker config
            return FieldWalkerFusion(
                "meta_walker",
                in_features=dim,
                num_steps=config.get('num_steps', 4),
                blend_mode=config.get('blend', 'slerp'),
                schedule=config.get('schedule', 'cosine'),
                aggregation=config.get('aggregation', 'mean'),
            )

    else:
        raise ValueError(f"Unknown fusion type: {fusion_type}")


class DualCollective(nn.Module):
    def __init__(self, dim_a, dim_b, fusion_type: str, config: Dict[str, Any]):
        super().__init__()
        self.sub_a = SubCollective("convnext", dim_a, HEAD_DIM, NUM_TOWERS)
        self.sub_b = SubCollective("vit", dim_b, HEAD_DIM, NUM_TOWERS)

        self.fusion_type = fusion_type
        self.meta_fusion = create_meta_fusion(fusion_type, config)

        # Check if fusion needs special handling
        self.is_walker = fusion_type.startswith('walker_')

        self.head = nn.Linear(HEAD_DIM, 100)  # CIFAR-100

    def forward(self, fa, fb):
        oa = self.sub_a(fa)
        ob = self.sub_b(fb)

        if self.is_walker:
            # FieldWalker takes (source, target)
            fused = self.meta_fusion(oa, ob)
        else:
            # Standard fusions take *args
            fused = self.meta_fusion(oa, ob)

        return self.head(fused)


# -------------------------
# Training
# -------------------------
def train_single_config(
    fusion_type: str,
    config: Dict[str, Any],
    train_dataset: TensorDataset,
    test_dataset: TensorDataset,
    dim_a: int,
    dim_b: int,
    epochs: int = EPOCHS_PER_RUN,
) -> Dict[str, Any]:
    """Train a single fusion configuration."""

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
        num_workers=0, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False,
        num_workers=0, pin_memory=True
    )

    model = DualCollective(dim_a, dim_b, fusion_type, config).to(device)
    model.sub_a.prepare_and_compile()
    model.sub_b.prepare_and_compile()

    # Count params
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_batches = 0

        for fa, fb, labels in train_loader:
            fa, fb, labels = fa.to(device), fb.to(device), labels.to(device)

            opt.zero_grad()
            logits = model(fa, fb)
            loss = loss_fn(logits, labels)
            loss.backward()
            opt.step()

            train_loss += loss.item()
            n_batches += 1

        avg_loss = train_loss / n_batches

        model.eval()
        correct = total = 0

        with torch.no_grad():
            for fa, fb, labels in test_loader:
                fa, fb, labels = fa.to(device), fb.to(device), labels.to(device)
                preds = model(fa, fb).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100 * correct / total
        best_acc = max(best_acc, acc)

        history.append({'epoch': epoch + 1, 'loss': avg_loss, 'acc': acc})
        print(f"    Epoch {epoch+1:2d}: loss={avg_loss:.4f}, acc={acc:.2f}%")

    del model, opt
    torch.cuda.empty_cache()
    gc.collect()

    return {
        'fusion_type': fusion_type,
        'config': config,
        'best_acc': best_acc,
        'final_acc': history[-1]['acc'],
        'final_loss': history[-1]['loss'],
        'trainable_params': trainable_params,
        'history': history,
    }


# -------------------------
# Ablation Matrix
# -------------------------
def generate_fusion_configs() -> List[Tuple[str, Dict[str, Any]]]:
    """Generate all fusion configurations to test."""
    configs = []

    # === Basic Fusions ===
    configs.append(('concat', {}))
    configs.append(('sum', {}))
    configs.append(('bilinear', {}))
    configs.append(('residual', {}))
    # SlotFusion skipped - different API

    # === Adaptive Fusions ===
    configs.append(('adaptive', {}))
    configs.append(('gated', {}))
    configs.append(('attention', {}))

    # === Geometric Fusions ===
    configs.append(('geometric_attention', {}))
    configs.append(('cantor_scale', {}))
    configs.append(('hierarchical_tree', {}))

    # === Binding Fusions ===
    configs.append(('adaptive_binding', {}))

    # === Cantor-derived ===
    # InceptiveFusion skipped - requires aux_features

    # === FieldWalker Presets (all 10) ===
    for preset in WALKER_PRESETS.keys():
        configs.append((f'walker_{preset}', {'num_steps': 4}))

    # === Top Walker configs from CIFAR-10 ablation ===
    # Winner: shiva + cosine + similarity_tree
    configs.append(('walker_custom_shiva_best', {
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
    }))

    # Runner up: slerp + sigmoid + softmax
    configs.append(('walker_custom_slerp_sigmoid', {
        'blend': 'slerp',
        'schedule': 'sigmoid',
        'aggregation': 'softmax',
        'num_steps': 4,
    }))

    # Gilgamesh variant
    configs.append(('walker_custom_gilgamesh', {
        'blend': 'gilgamesh',
        'schedule': 'linear',
        'aggregation': 'triangular',
        'num_steps': 8,
    }))

    # Min-P double
    configs.append(('walker_custom_min_p_double', {
        'blend': 'min_p',
        'schedule': 'linear',
        'aggregation': 'min_p',
        'num_steps': 8,
    }))

    # Zeus sharp
    configs.append(('walker_custom_zeus_sharp', {
        'blend': 'zeus',
        'schedule': 'sigmoid',
        'aggregation': 'last',
        'num_steps': 4,
    }))

    return configs


# -------------------------
# Main
# -------------------------
def run_fusion_ablation():
    """Run the fusion factory ablation."""

    print("="*60)
    print("PREPARING CACHED LATENTS (CIFAR-100)")
    print("="*60)
    train_dataset, test_dataset, dim_a, dim_b = prepare_cached_data()
    print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    configs = generate_fusion_configs()
    print(f"\n{'='*60}")
    print(f"FUSION FACTORY ABLATION: {len(configs)} configurations")
    print("="*60)

    results = []

    for i, (fusion_type, config) in enumerate(configs):
        print(f"\n[{i+1}/{len(configs)}] {fusion_type}")
        print("-" * 40)

        try:
            result = train_single_config(
                fusion_type, config, train_dataset, test_dataset, dim_a, dim_b
            )

            row = {
                'name': fusion_type,
                'category': get_fusion_category(fusion_type),
                'best_acc': result['best_acc'],
                'final_acc': result['final_acc'],
                'final_loss': result['final_loss'],
                'params': result['trainable_params'],
            }
            results.append(row)

            print(f"  Best: {result['best_acc']:.2f}% | Final: {result['final_acc']:.2f}%")

            df = pd.DataFrame(results)
            df.to_csv(RESULTS_FILE, index=False)

        except Exception as e:
            print(f"  ERROR: {e}")
            import traceback
            traceback.print_exc()
            results.append({
                'name': fusion_type,
                'category': get_fusion_category(fusion_type),
                'best_acc': -1,
                'final_acc': -1,
                'final_loss': -1,
                'error': str(e),
            })

    # Summary
    print("\n" + "="*60)
    print("ABLATION COMPLETE")
    print("="*60)

    df = pd.DataFrame(results)
    df = df.sort_values('best_acc', ascending=False)
    df.to_csv(RESULTS_FILE, index=False)

    print(f"\nResults saved to: {RESULTS_FILE}")

    print(f"\n{'='*60}")
    print("RESULTS BY CATEGORY")
    print("="*60)

    for cat in ['basic', 'adaptive', 'geometric', 'binding', 'walker']:
        cat_df = df[df['category'] == cat].sort_values('best_acc', ascending=False)
        if len(cat_df) > 0:
            print(f"\n{cat.upper()}:")
            for _, row in cat_df.head(5).iterrows():
                print(f"  {row['name']:35s} | Best: {row['best_acc']:5.2f}%")

    print(f"\n{'='*60}")
    print("TOP 10 OVERALL")
    print("="*60)
    print(df[['name', 'category', 'best_acc', 'final_acc']].head(10).to_string(index=False))

    return df


def get_fusion_category(fusion_type: str) -> str:
    """Categorize fusion type."""
    if fusion_type in ['concat', 'sum', 'bilinear', 'residual']:
        return 'basic'
    elif fusion_type in ['adaptive', 'gated', 'attention']:
        return 'adaptive'
    elif fusion_type in ['geometric_attention', 'cantor_scale', 'hierarchical_tree']:
        return 'geometric'
    elif fusion_type in ['adaptive_binding']:
        return 'binding'
    elif fusion_type.startswith('walker_'):
        return 'walker'
    else:
        return 'unknown'


# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
    results_df = run_fusion_ablation()

Device: cuda
PREPARING CACHED LATENTS (CIFAR-100)
Loading encoders...
Encoder dims: ConvNeXt=768, ViT=768
Loading cached latents from ./latent_cache_cifar100/train_convnext.pt
Loading cached latents from ./latent_cache_cifar100/train_vit.pt
Loading cached latents from ./latent_cache_cifar100/test_convnext.pt
Loading cached latents from ./latent_cache_cifar100/test_vit.pt
Train samples: 50000, Test samples: 10000

FUSION FACTORY ABLATION: 26 configurations

[1/26] concat
----------------------------------------
    Epoch  1: loss=0.6086, acc=86.03%
    Epoch  2: loss=0.3446, acc=86.97%
    Epoch  3: loss=0.2623, acc=86.01%
    Epoch  4: loss=0.2094, acc=86.78%
    Epoch  5: loss=0.1693, acc=87.05%
    Epoch  6: loss=0.1391, acc=86.96%
    Epoch  7: loss=0.1145, acc=86.60%
    Epoch  8: loss=0.0927, acc=86.46%
    Epoch  9: loss=0.0841, acc=86.71%
    Epoch 10: loss=0.0858, acc=86.26%
  Best: 87.05% | Final: 86.26%

[2/26] sum
----------------------------------------
    Epoch  1: loss=0

# triple vision encoder walker sweep

In [None]:
# =========================
# CIFAR-100 Triple Encoder Walker Ablation
# 3 encoders: ConvNeXt + DINOv3 ViT + CLIP ViT
# Massive parameter sweep with crash protection
# =========================

# !pip -q install timm tqdm pandas  # Uncomment for Colab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from tqdm.auto import tqdm
import timm
import pandas as pd
import os
import traceback
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional
from itertools import product
import gc
import json

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import (
    AdaptiveFusion, SumFusion, ConcatFusion, GatedFusion,
)
from geofractal.router.components.aggregation_component import (
    FieldWalkerFusion, from_preset, WALKER_PRESETS,
    BLEND_MODES, SCHEDULES, AGGREGATIONS,
)

# -------------------------
# Config
# -------------------------
CACHE_DIR = "./latent_cache_triple"
RESULTS_FILE = f"triple_walker_ablation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
CHECKPOINT_FILE = "ablation_checkpoint.json"
EPOCHS_PER_RUN = 10
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 128
HEAD_DIM = 512
NUM_TOWERS = 8

# Stability
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# -------------------------
# Encoders
# -------------------------
ENCODERS = {
    'convnext': "convnext_small.dinov3_lvd1689m",
    'dino_vit': "vit_base_patch16_dinov3.lvd1689m",
    'clip_vit': "vit_base_patch16_clip_224.laion2b_ft_in12k_in1k",
}

def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    enc.eval()
    for p in enc.parameters():
        p.requires_grad = False
    return enc

# -------------------------
# Latent Caching
# -------------------------
def cache_latents(encoder, dataloader, cache_path: str, desc: str):
    """Extract and cache latents from encoder."""
    if os.path.exists(cache_path):
        print(f"Loading cached: {cache_path}")
        data = torch.load(cache_path)
        return data['latents'], data['labels']

    print(f"Extracting: {desc}")
    all_latents = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc=desc):
            imgs = imgs.to(device)
            latents = encoder(imgs)
            all_latents.append(latents.cpu())
            all_labels.append(labels)

    latents = torch.cat(all_latents, dim=0)
    labels = torch.cat(all_labels, dim=0)

    os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else '.', exist_ok=True)
    torch.save({'latents': latents, 'labels': labels}, cache_path)
    print(f"Cached: {cache_path} {latents.shape}")

    return latents, labels


def prepare_cached_data():
    """Prepare cached latents for all 3 encoders."""

    tf = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # CLIP needs 384 but we'll use 224 for consistency - it still works
    train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=tf)
    test_ds = datasets.CIFAR100("./data", train=False, download=True, transform=tf)

    train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4)

    dims = {}
    train_latents = {}
    test_latents = {}
    train_labels = None
    test_labels = None

    for enc_name, enc_model in ENCODERS.items():
        print(f"\nLoading {enc_name}...")
        enc = load_encoder(enc_model)

        # Get dim
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 224, 224, device=device)
            dims[enc_name] = enc(dummy).shape[-1]
        print(f"  {enc_name} dim: {dims[enc_name]}")

        # Cache
        train_latents[enc_name], train_labels = cache_latents(
            enc, train_loader, f"{CACHE_DIR}/train_{enc_name}.pt", f"Train {enc_name}"
        )
        test_latents[enc_name], test_labels = cache_latents(
            enc, test_loader, f"{CACHE_DIR}/test_{enc_name}.pt", f"Test {enc_name}"
        )

        # Free encoder
        del enc
        torch.cuda.empty_cache()

    gc.collect()

    # Create datasets
    train_dataset = TensorDataset(
        train_latents['convnext'],
        train_latents['dino_vit'],
        train_latents['clip_vit'],
        train_labels
    )
    test_dataset = TensorDataset(
        test_latents['convnext'],
        test_latents['dino_vit'],
        test_latents['clip_vit'],
        test_labels
    )

    return train_dataset, test_dataset, dims


# -------------------------
# Model Components
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Tower(BaseTower):
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)


class SubCollective(WideRouter):
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", Tower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        self.attach("fusion", AdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())


# -------------------------
# 3-Way Fusion Strategies
# -------------------------
class HierarchicalWalkerFusion(nn.Module):
    """
    Hierarchical 3-way fusion: (A, B) -> AB, then (AB, C) -> final

    Uses two walkers in sequence.
    """
    def __init__(self, dim: int, config: Dict[str, Any]):
        super().__init__()

        # First walker: A + B
        self.walker1 = FieldWalkerFusion(
            "walker_ab",
            in_features=dim,
            num_steps=config.get('num_steps', 4),
            blend_mode=config.get('blend', 'slerp'),
            schedule=config.get('schedule', 'cosine'),
            aggregation=config.get('aggregation', 'mean'),
        )

        # Second walker: AB + C (can use different config)
        self.walker2 = FieldWalkerFusion(
            "walker_abc",
            in_features=dim,
            num_steps=config.get('num_steps_2', config.get('num_steps', 4)),
            blend_mode=config.get('blend_2', config.get('blend', 'slerp')),
            schedule=config.get('schedule_2', config.get('schedule', 'cosine')),
            aggregation=config.get('aggregation_2', config.get('aggregation', 'mean')),
        )

    def forward(self, a, b, c):
        ab = self.walker1(a, b)
        return self.walker2(ab, c)


class PairwiseWalkerFusion(nn.Module):
    """
    Pairwise 3-way fusion: compute (A,B), (B,C), (A,C) then combine.

    More expensive but captures all relationships.
    """
    def __init__(self, dim: int, config: Dict[str, Any]):
        super().__init__()

        walker_kwargs = {
            'in_features': dim,
            'num_steps': config.get('num_steps', 4),
            'blend_mode': config.get('blend', 'slerp'),
            'schedule': config.get('schedule', 'cosine'),
            'aggregation': config.get('aggregation', 'mean'),
        }

        self.walker_ab = FieldWalkerFusion("walker_ab", **walker_kwargs)
        self.walker_bc = FieldWalkerFusion("walker_bc", **walker_kwargs)
        self.walker_ac = FieldWalkerFusion("walker_ac", **walker_kwargs)

        # Combine the 3 pairwise outputs
        combine_mode = config.get('combine_mode', 'sum')
        if combine_mode == 'sum':
            self.combine = SumFusion("combine", 3, dim)
        elif combine_mode == 'adaptive':
            self.combine = AdaptiveFusion("combine", 3, dim)
        elif combine_mode == 'concat':
            self.combine = nn.Sequential(
                ConcatFusion("combine", 3, dim, dim),
            )
        elif combine_mode == 'gated':
            self.combine = GatedFusion("combine", 3, dim)
        else:
            self.combine = SumFusion("combine", 3, dim)

    def forward(self, a, b, c):
        ab = self.walker_ab(a, b)
        bc = self.walker_bc(b, c)
        ac = self.walker_ac(a, c)
        return self.combine(ab, bc, ac)


class HubWalkerFusion(nn.Module):
    """
    Hub-and-spoke fusion: A is hub, walk to B and C, combine.

    Useful when one encoder is "primary".
    """
    def __init__(self, dim: int, config: Dict[str, Any]):
        super().__init__()

        walker_kwargs = {
            'in_features': dim,
            'num_steps': config.get('num_steps', 4),
            'blend_mode': config.get('blend', 'slerp'),
            'schedule': config.get('schedule', 'cosine'),
            'aggregation': config.get('aggregation', 'mean'),
        }

        self.walker_ab = FieldWalkerFusion("walker_ab", **walker_kwargs)
        self.walker_ac = FieldWalkerFusion("walker_ac", **walker_kwargs)

        # Combine
        combine_mode = config.get('combine_mode', 'sum')
        if combine_mode == 'sum':
            self.combine = SumFusion("combine", 2, dim)
        elif combine_mode == 'adaptive':
            self.combine = AdaptiveFusion("combine", 2, dim)
        else:
            self.combine = SumFusion("combine", 2, dim)

    def forward(self, a, b, c):
        ab = self.walker_ab(a, b)
        ac = self.walker_ac(a, c)
        return self.combine(ab, ac)


class ChainWalkerFusion(nn.Module):
    """
    Chain fusion: A -> B -> C (sequential walk through all)
    """
    def __init__(self, dim: int, config: Dict[str, Any]):
        super().__init__()

        walker_kwargs = {
            'in_features': dim,
            'num_steps': config.get('num_steps', 4),
            'blend_mode': config.get('blend', 'slerp'),
            'schedule': config.get('schedule', 'cosine'),
            'aggregation': config.get('aggregation', 'mean'),
        }

        self.walker_ab = FieldWalkerFusion("walker_ab", **walker_kwargs)
        self.walker_bc = FieldWalkerFusion("walker_bc", **walker_kwargs)

    def forward(self, a, b, c):
        ab = self.walker_ab(a, b)
        return self.walker_bc(ab, c)


def create_triple_fusion(strategy: str, dim: int, config: Dict[str, Any]) -> nn.Module:
    """Factory for 3-way fusion strategies."""
    if strategy == 'hierarchical':
        return HierarchicalWalkerFusion(dim, config)
    elif strategy == 'pairwise':
        return PairwiseWalkerFusion(dim, config)
    elif strategy == 'hub':
        return HubWalkerFusion(dim, config)
    elif strategy == 'chain':
        return ChainWalkerFusion(dim, config)
    elif strategy == 'sum':
        return SumFusion("meta_sum", 3, dim)
    elif strategy == 'adaptive':
        return AdaptiveFusion("meta_adaptive", 3, dim)
    elif strategy == 'concat':
        return ConcatFusion("meta_concat", 3, dim, dim)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")


# -------------------------
# Triple Collective Model
# -------------------------
class TripleCollective(nn.Module):
    def __init__(self, dims: Dict[str, int], strategy: str, config: Dict[str, Any]):
        super().__init__()

        self.sub_convnext = SubCollective("convnext", dims['convnext'], HEAD_DIM, NUM_TOWERS)
        self.sub_dino = SubCollective("dino_vit", dims['dino_vit'], HEAD_DIM, NUM_TOWERS)
        self.sub_clip = SubCollective("clip_vit", dims['clip_vit'], HEAD_DIM, NUM_TOWERS)

        self.strategy = strategy
        self.meta_fusion = create_triple_fusion(strategy, HEAD_DIM, config)

        self.head = nn.Linear(HEAD_DIM, 100)  # CIFAR-100

    def forward(self, f_convnext, f_dino, f_clip):
        o_convnext = self.sub_convnext(f_convnext)
        o_dino = self.sub_dino(f_dino)
        o_clip = self.sub_clip(f_clip)

        fused = self.meta_fusion(o_convnext, o_dino, o_clip)
        return self.head(fused)


# -------------------------
# Training
# -------------------------
def train_single_config(
    strategy: str,
    config: Dict[str, Any],
    train_dataset: TensorDataset,
    test_dataset: TensorDataset,
    dims: Dict[str, int],
    epochs: int = EPOCHS_PER_RUN,
) -> Dict[str, Any]:
    """Train a single configuration."""

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
        num_workers=0, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False,
        num_workers=0, pin_memory=True
    )

    model = TripleCollective(dims, strategy, config).to(device)
    model.sub_convnext.prepare_and_compile()
    model.sub_dino.prepare_and_compile()
    model.sub_clip.prepare_and_compile()

    # Count params
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_batches = 0

        for f_conv, f_dino, f_clip, labels in train_loader:
            f_conv = f_conv.to(device)
            f_dino = f_dino.to(device)
            f_clip = f_clip.to(device)
            labels = labels.to(device)

            opt.zero_grad()
            logits = model(f_conv, f_dino, f_clip)
            loss = loss_fn(logits, labels)
            loss.backward()
            opt.step()

            train_loss += loss.item()
            n_batches += 1

        avg_loss = train_loss / n_batches

        model.eval()
        correct = total = 0

        with torch.no_grad():
            for f_conv, f_dino, f_clip, labels in test_loader:
                f_conv = f_conv.to(device)
                f_dino = f_dino.to(device)
                f_clip = f_clip.to(device)
                labels = labels.to(device)

                preds = model(f_conv, f_dino, f_clip).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100 * correct / total
        best_acc = max(best_acc, acc)

        history.append({'epoch': epoch + 1, 'loss': avg_loss, 'acc': acc})
        print(f"    Epoch {epoch+1:2d}: loss={avg_loss:.4f}, acc={acc:.2f}%")

    del model, opt
    torch.cuda.empty_cache()
    gc.collect()

    return {
        'strategy': strategy,
        'config': config,
        'best_acc': best_acc,
        'final_acc': history[-1]['acc'],
        'final_loss': history[-1]['loss'],
        'trainable_params': trainable_params,
        'history': history,
    }


# -------------------------
# Configuration Generator
# -------------------------
def generate_sweep_configs() -> List[Tuple[str, Dict[str, Any]]]:
    """Generate all configurations to sweep."""
    configs = []

    # === Baselines ===
    configs.append(('sum', {'name': 'baseline_sum'}))
    configs.append(('adaptive', {'name': 'baseline_adaptive'}))
    configs.append(('concat', {'name': 'baseline_concat'}))

    # === Strategy sweep with default walker settings ===
    for strategy in ['hierarchical', 'pairwise', 'hub', 'chain']:
        configs.append((strategy, {
            'name': f'{strategy}_default',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': 4,
        }))

    # === Blend mode sweep (hierarchical) ===
    blends = ['lerp', 'slerp', 'slip', 'shiva', 'min_p', 'zeus', 'gilgamesh']
    for blend in blends:
        configs.append(('hierarchical', {
            'name': f'hier_blend_{blend}',
            'blend': blend,
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': 4,
        }))

    # === Schedule sweep (hierarchical + slerp) ===
    schedules = ['linear', 'cosine', 'sigmoid', 'tau', 'learnable']
    for sched in schedules:
        configs.append(('hierarchical', {
            'name': f'hier_sched_{sched}',
            'blend': 'slerp',
            'schedule': sched,
            'aggregation': 'mean',
            'num_steps': 4,
        }))

    # === Aggregation sweep (hierarchical + slerp + cosine) ===
    aggregations = ['mean', 'softmax', 'similarity', 'similarity_tree', 'min_p', 'attention', 'weighted']
    for agg in aggregations:
        configs.append(('hierarchical', {
            'name': f'hier_agg_{agg}',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': agg,
            'num_steps': 4,
        }))

    # === Step count sweep ===
    for steps in [2, 4, 8, 16]:
        configs.append(('hierarchical', {
            'name': f'hier_steps_{steps}',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': steps,
        }))

    # === Shiva decay rate sweep ===
    # Note: shiva blend uses decay_rate parameter
    for decay in [2.0, 4.0, 6.0, 8.0]:
        configs.append(('hierarchical', {
            'name': f'hier_shiva_decay_{decay}',
            'blend': 'shiva',
            'schedule': 'cosine',
            'aggregation': 'similarity_tree',
            'num_steps': 4,
            'shiva_decay': decay,
        }))

    # === Pairwise combine modes ===
    for combine in ['sum', 'adaptive', 'gated']:
        configs.append(('pairwise', {
            'name': f'pairwise_combine_{combine}',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': 4,
            'combine_mode': combine,
        }))

    # === Top CIFAR-100 winners from 2-encoder test ===
    # walker_learnable: 88.36%
    configs.append(('hierarchical', {
        'name': 'hier_learnable_full',
        'blend': 'slerp',
        'schedule': 'learnable',
        'aggregation': 'learnable',
        'num_steps': 4,
    }))

    # walker_custom_shiva_best: 88.31%
    configs.append(('hierarchical', {
        'name': 'hier_shiva_similarity_tree',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
    }))

    # walker_slip: 88.20%
    configs.append(('hierarchical', {
        'name': 'hier_slip_best',
        'blend': 'slip',
        'schedule': 'cosine',
        'aggregation': 'mean',
        'num_steps': 4,
    }))

    # === Heterogeneous hierarchical (different configs per stage) ===
    configs.append(('hierarchical', {
        'name': 'hier_hetero_shiva_slerp',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
        'blend_2': 'slerp',
        'schedule_2': 'cosine',
        'aggregation_2': 'softmax',
        'num_steps_2': 4,
    }))

    configs.append(('hierarchical', {
        'name': 'hier_hetero_slip_shiva',
        'blend': 'slip',
        'schedule': 'cosine',
        'aggregation': 'similarity',
        'num_steps': 4,
        'blend_2': 'shiva',
        'schedule_2': 'cosine',
        'aggregation_2': 'similarity_tree',
        'num_steps_2': 4,
    }))

    # === Deep steps ===
    configs.append(('hierarchical', {
        'name': 'hier_deep_16_8',
        'blend': 'slerp',
        'schedule': 'cosine',
        'aggregation': 'mean',
        'num_steps': 16,
        'num_steps_2': 8,
    }))

    # === Pairwise with best walker configs ===
    configs.append(('pairwise', {
        'name': 'pairwise_shiva_adaptive',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
        'combine_mode': 'adaptive',
    }))

    # === Hub variants (ConvNeXt as hub) ===
    configs.append(('hub', {
        'name': 'hub_shiva',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
        'combine_mode': 'adaptive',
    }))

    # === Chain with best settings ===
    configs.append(('chain', {
        'name': 'chain_shiva',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
    }))

    configs.append(('chain', {
        'name': 'chain_learnable',
        'blend': 'slerp',
        'schedule': 'learnable',
        'aggregation': 'learnable',
        'num_steps': 4,
    }))

    return configs


# -------------------------
# Checkpoint Management
# -------------------------
def load_checkpoint() -> Tuple[List[Dict], int]:
    """Load checkpoint if exists."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, 'r') as f:
            data = json.load(f)
        return data.get('results', []), data.get('next_idx', 0)
    return [], 0


def save_checkpoint(results: List[Dict], next_idx: int):
    """Save checkpoint."""
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump({'results': results, 'next_idx': next_idx}, f)


# -------------------------
# Main
# -------------------------
def run_triple_ablation():
    """Run the triple encoder ablation."""

    print("="*60)
    print("PREPARING CACHED LATENTS (TRIPLE ENCODER)")
    print("="*60)
    train_dataset, test_dataset, dims = prepare_cached_data()
    print(f"\nTrain: {len(train_dataset)}, Test: {len(test_dataset)}")
    print(f"Dims: {dims}")

    # Generate configs
    configs = generate_sweep_configs()
    print(f"\n{'='*60}")
    print(f"TRIPLE ENCODER ABLATION: {len(configs)} configurations")
    print("="*60)

    # Load checkpoint
    results, start_idx = load_checkpoint()
    if start_idx > 0:
        print(f"\nResuming from config {start_idx}/{len(configs)}")

    # Run configs with crash protection
    for i, (strategy, config) in enumerate(configs):
        if i < start_idx:
            continue

        name = config.get('name', f'{strategy}_{i}')
        print(f"\n[{i+1}/{len(configs)}] {name}")
        print("-" * 40)

        try:
            result = train_single_config(
                strategy, config, train_dataset, test_dataset, dims
            )

            row = {
                'idx': i,
                'name': name,
                'strategy': strategy,
                'blend': config.get('blend', 'N/A'),
                'schedule': config.get('schedule', 'N/A'),
                'aggregation': config.get('aggregation', 'N/A'),
                'num_steps': config.get('num_steps', 'N/A'),
                'combine_mode': config.get('combine_mode', 'N/A'),
                'best_acc': result['best_acc'],
                'final_acc': result['final_acc'],
                'final_loss': result['final_loss'],
                'params': result['trainable_params'],
            }
            results.append(row)

            print(f"  Best: {result['best_acc']:.2f}% | Final: {result['final_acc']:.2f}%")

            # Save checkpoint
            save_checkpoint(results, i + 1)

            # Save CSV
            df = pd.DataFrame(results)
            df.to_csv(RESULTS_FILE, index=False)

        except Exception as e:
            print(f"  ❌ ERROR: {e}")
            traceback.print_exc()

            # Log error but continue
            results.append({
                'idx': i,
                'name': name,
                'strategy': strategy,
                'best_acc': -1,
                'final_acc': -1,
                'error': str(e),
            })
            save_checkpoint(results, i + 1)

    # Final summary
    print("\n" + "="*60)
    print("ABLATION COMPLETE")
    print("="*60)

    df = pd.DataFrame(results)
    df_valid = df[df['best_acc'] > 0].sort_values('best_acc', ascending=False)
    df_valid.to_csv(RESULTS_FILE, index=False)

    print(f"\nResults saved to: {RESULTS_FILE}")

    print(f"\n{'='*60}")
    print("TOP 15 CONFIGURATIONS")
    print("="*60)
    print(df_valid[['name', 'strategy', 'best_acc', 'final_acc']].head(15).to_string(index=False))

    print(f"\n{'='*60}")
    print("RESULTS BY STRATEGY")
    print("="*60)
    for strat in ['hierarchical', 'pairwise', 'hub', 'chain', 'sum', 'adaptive', 'concat']:
        strat_df = df_valid[df_valid['strategy'] == strat]
        if len(strat_df) > 0:
            best = strat_df.iloc[0]
            print(f"  {strat:15s}: {best['best_acc']:.2f}% ({best['name']})")

    # Cleanup checkpoint
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)

    return df_valid


# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
    results_df = run_triple_ablation()

Device: cuda
PREPARING CACHED LATENTS (TRIPLE ENCODER)

Loading convnext...
  convnext dim: 768
Loading cached: ./latent_cache_triple/train_convnext.pt
Loading cached: ./latent_cache_triple/test_convnext.pt

Loading dino_vit...
  dino_vit dim: 768
Loading cached: ./latent_cache_triple/train_dino_vit.pt
Loading cached: ./latent_cache_triple/test_dino_vit.pt

Loading clip_vit...


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



  clip_vit dim: 768
Extracting: Train clip_vit


Train clip_vit:   0%|          | 0/196 [00:00<?, ?it/s]

Cached: ./latent_cache_triple/train_clip_vit.pt torch.Size([50000, 768])
Extracting: Test clip_vit


Test clip_vit:   0%|          | 0/40 [00:00<?, ?it/s]

Cached: ./latent_cache_triple/test_clip_vit.pt torch.Size([10000, 768])

Train: 50000, Test: 10000
Dims: {'convnext': 768, 'dino_vit': 768, 'clip_vit': 768}

TRIPLE ENCODER ABLATION: 47 configurations

[1/47] baseline_sum
----------------------------------------
    Epoch  1: loss=0.6091, acc=87.80%
    Epoch  2: loss=0.2890, acc=88.43%
    Epoch  3: loss=0.2022, acc=88.78%
    Epoch  4: loss=0.1492, acc=89.06%
    Epoch  5: loss=0.1120, acc=88.95%
    Epoch  6: loss=0.0838, acc=88.42%
    Epoch  7: loss=0.0705, acc=88.40%
    Epoch  8: loss=0.0591, acc=88.20%
    Epoch  9: loss=0.0507, acc=88.54%
    Epoch 10: loss=0.0419, acc=88.88%
  Best: 89.06% | Final: 88.88%

[2/47] baseline_adaptive
----------------------------------------
    Epoch  1: loss=0.6608, acc=87.52%
    Epoch  2: loss=0.3362, acc=87.26%
    Epoch  3: loss=0.2505, acc=87.65%
    Epoch  4: loss=0.1882, acc=88.04%
    Epoch  5: loss=0.1443, acc=87.33%
    Epoch  6: loss=0.1099, acc=87.30%
    Epoch  7: loss=0.0916, acc=

# text encoder

In [4]:
# =========================
# Text Encoder Triple Fusion Ablation
# 3 encoders: CLIP ViT-B (text) + T5-base + BERT-large
# Sequence-level walker interpolation (77 tokens)
# =========================

# !pip -q install transformers datasets tqdm pandas  # Uncomment for Colab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import pandas as pd
import os
import traceback
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional
import gc
import json

# Transformers
from transformers import (
    CLIPTokenizer, CLIPTextModel,
    T5Tokenizer, T5EncoderModel,
    BertTokenizer, BertModel,
)
from datasets import load_dataset

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import (
    AdaptiveFusion, SumFusion, ConcatFusion, GatedFusion,
)
from geofractal.router.components.aggregation_component import (
    FieldWalkerFusion, from_preset, WALKER_PRESETS,
)

# -------------------------
# Config
# -------------------------
CACHE_DIR = "./latent_cache_text"
RESULTS_FILE = f"text_walker_ablation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
CHECKPOINT_FILE = "text_ablation_checkpoint.json"
EPOCHS_PER_RUN = 10
BATCH_SIZE_TRAIN = 32
BATCH_SIZE_TEST = 64
MAX_LENGTH = 77  # CLIP's native window
HEAD_DIM = 512
NUM_TOWERS = 4  # Smaller for text (less data typically)

# Dataset config
DATASET_NAME = "ag_news"  # 4-class news classification
NUM_CLASSES = 4

# Stability
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# -------------------------
# Encoder Configs
# -------------------------
ENCODERS = {
    'clip': {
        'model': 'openai/clip-vit-base-patch32',
        'tokenizer': 'openai/clip-vit-base-patch32',
        'type': 'clip',
    },
    't5': {
        'model': 't5-base',
        'tokenizer': 't5-base',
        'type': 't5',
    },
    'bert': {
        'model': 'bert-large-uncased',
        'tokenizer': 'bert-large-uncased',
        'type': 'bert',
    },
}


# -------------------------
# Encoder Loading
# -------------------------
def load_text_encoder(name: str, config: Dict):
    """Load text encoder and tokenizer."""
    enc_type = config['type']

    if enc_type == 'clip':
        tokenizer = CLIPTokenizer.from_pretrained(config['tokenizer'])
        model = CLIPTextModel.from_pretrained(config['model'])
    elif enc_type == 't5':
        tokenizer = T5Tokenizer.from_pretrained(config['tokenizer'])
        model = T5EncoderModel.from_pretrained(config['model'])
    elif enc_type == 'bert':
        tokenizer = BertTokenizer.from_pretrained(config['tokenizer'])
        model = BertModel.from_pretrained(config['model'])
    else:
        raise ValueError(f"Unknown encoder type: {enc_type}")

    model.to(device)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    return tokenizer, model


def encode_batch(model, input_ids, attention_mask, enc_type: str) -> torch.Tensor:
    """Encode a batch and return sequence outputs [B, L, D]."""
    with torch.no_grad():
        if enc_type == 'clip':
            # CLIP text model returns last_hidden_state
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            return outputs.last_hidden_state  # [B, L, 512]
        elif enc_type == 't5':
            # T5 encoder
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            return outputs.last_hidden_state  # [B, L, 768]
        elif enc_type == 'bert':
            # BERT
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            return outputs.last_hidden_state  # [B, L, 1024]


# -------------------------
# Dataset Preparation
# -------------------------
def prepare_dataset():
    """Load and prepare AG News dataset."""
    print(f"Loading {DATASET_NAME} dataset...")
    dataset = load_dataset(DATASET_NAME)

    train_texts = dataset['train']['text']
    train_labels = dataset['train']['label']
    test_texts = dataset['test']['text']
    test_labels = dataset['test']['label']

    print(f"Train: {len(train_texts)}, Test: {len(test_texts)}")
    print(f"Classes: {NUM_CLASSES}")

    return train_texts, train_labels, test_texts, test_labels


def tokenize_texts(texts: List[str], tokenizer, enc_type: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenize texts to fixed length."""
    if enc_type == 'clip':
        # CLIP tokenizer
        encoded = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors='pt',
        )
    elif enc_type == 't5':
        # T5 tokenizer
        encoded = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors='pt',
        )
    elif enc_type == 'bert':
        # BERT tokenizer
        encoded = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors='pt',
        )

    return encoded['input_ids'], encoded['attention_mask']


# -------------------------
# Latent Caching
# -------------------------
def cache_text_latents(
    texts: List[str],
    labels: List[int],
    tokenizer,
    model,
    enc_type: str,
    cache_path: str,
    desc: str,
    batch_size: int = 64,
):
    """Extract and cache text latents."""
    if os.path.exists(cache_path):
        print(f"Loading cached: {cache_path}")
        data = torch.load(cache_path)
        return data['latents'], data['attention_masks'], data['labels']

    print(f"Extracting: {desc}")

    all_latents = []
    all_masks = []

    # Process in batches
    for i in tqdm(range(0, len(texts), batch_size), desc=desc):
        batch_texts = texts[i:i+batch_size]

        input_ids, attention_mask = tokenize_texts(batch_texts, tokenizer, enc_type)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        latents = encode_batch(model, input_ids, attention_mask, enc_type)

        all_latents.append(latents.cpu())
        all_masks.append(attention_mask.cpu())

    latents = torch.cat(all_latents, dim=0)
    masks = torch.cat(all_masks, dim=0)
    labels_tensor = torch.tensor(labels, dtype=torch.long)

    os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else '.', exist_ok=True)
    torch.save({
        'latents': latents,
        'attention_masks': masks,
        'labels': labels_tensor,
    }, cache_path)
    print(f"Cached: {cache_path} {latents.shape}")

    return latents, masks, labels_tensor


def prepare_cached_data():
    """Prepare cached latents for all 3 text encoders."""

    # Load dataset
    train_texts, train_labels, test_texts, test_labels = prepare_dataset()

    dims = {}
    train_latents = {}
    train_masks = {}
    test_latents = {}
    test_masks = {}
    train_labels_tensor = None
    test_labels_tensor = None

    for enc_name, enc_config in ENCODERS.items():
        print(f"\n{'='*40}")
        print(f"Loading {enc_name}...")
        print(f"{'='*40}")

        tokenizer, model = load_text_encoder(enc_name, enc_config)

        # Get dim from model config
        if enc_config['type'] == 'clip':
            dims[enc_name] = model.config.hidden_size
        elif enc_config['type'] == 't5':
            dims[enc_name] = model.config.d_model
        elif enc_config['type'] == 'bert':
            dims[enc_name] = model.config.hidden_size

        print(f"  {enc_name} dim: {dims[enc_name]}")

        # Cache train
        train_latents[enc_name], train_masks[enc_name], train_labels_tensor = cache_text_latents(
            train_texts, train_labels, tokenizer, model, enc_config['type'],
            f"{CACHE_DIR}/train_{enc_name}.pt", f"Train {enc_name}"
        )

        # Cache test
        test_latents[enc_name], test_masks[enc_name], test_labels_tensor = cache_text_latents(
            test_texts, test_labels, tokenizer, model, enc_config['type'],
            f"{CACHE_DIR}/test_{enc_name}.pt", f"Test {enc_name}"
        )

        # Free encoder
        del tokenizer, model
        torch.cuda.empty_cache()

    gc.collect()

    # Create datasets - include attention masks for sequence handling
    train_dataset = TensorDataset(
        train_latents['clip'],
        train_latents['t5'],
        train_latents['bert'],
        train_masks['clip'],  # Use CLIP mask as reference (all same length)
        train_labels_tensor
    )
    test_dataset = TensorDataset(
        test_latents['clip'],
        test_latents['t5'],
        test_latents['bert'],
        test_masks['clip'],
        test_labels_tensor
    )

    return train_dataset, test_dataset, dims


# -------------------------
# Model Components
# -------------------------
class SeqFFNBlock(TorchComponent):
    """FFN block for sequence processing."""
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        # x: [B, L, D]
        return self.net(x)


class SeqTower(BaseTower):
    """Tower for sequence processing."""
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(SeqFFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        # x: [B, L, D]
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)


class SeqSubCollective(WideRouter):
    """SubCollective for sequence inputs."""
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", SeqTower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        # Use mean pooling across towers, keep sequence
        self.attach("fusion", SeqAdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        # feats: [B, L, D_in]
        x = self["proj"](feats)  # [B, L, head_dim]
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())


class SeqAdaptiveFusion(nn.Module):
    """Adaptive fusion for sequences - fuses tower outputs while preserving sequence."""
    def __init__(self, name, num_inputs, dim):
        super().__init__()
        self.name = name
        self.num_inputs = num_inputs
        self.weight_net = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, 1),
        )

    def forward(self, *inputs):
        # inputs: each [B, L, D]
        stacked = torch.stack(inputs, dim=0)  # [N, B, L, D]
        weights = self.weight_net(stacked)  # [N, B, L, 1]
        weights = F.softmax(weights, dim=0)
        return (stacked * weights).sum(dim=0)  # [B, L, D]


# -------------------------
# Sequence Walker Fusion
# -------------------------
class SeqWalkerFusion(nn.Module):
    """
    Walker fusion for sequences.

    Projects all encoders to same dim, then walks along sequence.
    """
    def __init__(self, dims: Dict[str, int], head_dim: int, config: Dict[str, Any]):
        super().__init__()

        # Project each encoder to head_dim
        self.proj_clip = nn.Linear(dims['clip'], head_dim)
        self.proj_t5 = nn.Linear(dims['t5'], head_dim)
        self.proj_bert = nn.Linear(dims['bert'], head_dim)

        # Walker for sequence fusion
        self.walker1 = FieldWalkerFusion(
            "walker_seq_1",
            in_features=head_dim,
            num_steps=config.get('num_steps', 4),
            blend_mode=config.get('blend', 'slerp'),
            schedule=config.get('schedule', 'cosine'),
            aggregation=config.get('aggregation', 'mean'),
        )

        self.walker2 = FieldWalkerFusion(
            "walker_seq_2",
            in_features=head_dim,
            num_steps=config.get('num_steps_2', config.get('num_steps', 4)),
            blend_mode=config.get('blend_2', config.get('blend', 'slerp')),
            schedule=config.get('schedule_2', config.get('schedule', 'cosine')),
            aggregation=config.get('aggregation_2', config.get('aggregation', 'mean')),
        )

    def forward(self, clip_seq, t5_seq, bert_seq, mask=None):
        # Project to common dim: [B, L, head_dim]
        clip_proj = self.proj_clip(clip_seq)
        t5_proj = self.proj_t5(t5_seq)
        bert_proj = self.proj_bert(bert_seq)

        # Flatten for walker: [B*L, D]
        B, L, D = clip_proj.shape
        clip_flat = clip_proj.view(B * L, D)
        t5_flat = t5_proj.view(B * L, D)
        bert_flat = bert_proj.view(B * L, D)

        # Walk: CLIP + T5 -> fused1
        fused1 = self.walker1(clip_flat, t5_flat)
        # Walk: fused1 + BERT -> final
        fused = self.walker2(fused1, bert_flat)

        # Reshape back: [B, L, D]
        return fused.view(B, L, D)


class SeqChainWalker(nn.Module):
    """Chain walker: CLIP -> T5 -> BERT sequential."""
    def __init__(self, dims: Dict[str, int], head_dim: int, config: Dict[str, Any]):
        super().__init__()

        self.proj_clip = nn.Linear(dims['clip'], head_dim)
        self.proj_t5 = nn.Linear(dims['t5'], head_dim)
        self.proj_bert = nn.Linear(dims['bert'], head_dim)

        walker_kwargs = {
            'in_features': head_dim,
            'num_steps': config.get('num_steps', 4),
            'blend_mode': config.get('blend', 'slerp'),
            'schedule': config.get('schedule', 'cosine'),
            'aggregation': config.get('aggregation', 'mean'),
        }

        self.walker1 = FieldWalkerFusion("chain_1", **walker_kwargs)
        self.walker2 = FieldWalkerFusion("chain_2", **walker_kwargs)

    def forward(self, clip_seq, t5_seq, bert_seq, mask=None):
        B, L, _ = clip_seq.shape

        clip_proj = self.proj_clip(clip_seq).view(B * L, -1)
        t5_proj = self.proj_t5(t5_seq).view(B * L, -1)
        bert_proj = self.proj_bert(bert_seq).view(B * L, -1)

        # Chain: CLIP -> T5 -> BERT
        ct = self.walker1(clip_proj, t5_proj)
        ctb = self.walker2(ct, bert_proj)

        return ctb.view(B, L, -1)


class SeqPairwiseWalker(nn.Module):
    """Pairwise walker: (CLIP,T5), (T5,BERT), (CLIP,BERT) -> combine."""
    def __init__(self, dims: Dict[str, int], head_dim: int, config: Dict[str, Any]):
        super().__init__()

        self.proj_clip = nn.Linear(dims['clip'], head_dim)
        self.proj_t5 = nn.Linear(dims['t5'], head_dim)
        self.proj_bert = nn.Linear(dims['bert'], head_dim)

        walker_kwargs = {
            'in_features': head_dim,
            'num_steps': config.get('num_steps', 4),
            'blend_mode': config.get('blend', 'slerp'),
            'schedule': config.get('schedule', 'cosine'),
            'aggregation': config.get('aggregation', 'mean'),
        }

        self.walker_ct = FieldWalkerFusion("pair_ct", **walker_kwargs)
        self.walker_tb = FieldWalkerFusion("pair_tb", **walker_kwargs)
        self.walker_cb = FieldWalkerFusion("pair_cb", **walker_kwargs)

        # Combine 3 pairwise outputs
        self.combine = nn.Linear(head_dim * 3, head_dim)

    def forward(self, clip_seq, t5_seq, bert_seq, mask=None):
        B, L, _ = clip_seq.shape
        D = self.combine.out_features

        clip_proj = self.proj_clip(clip_seq).view(B * L, -1)
        t5_proj = self.proj_t5(t5_seq).view(B * L, -1)
        bert_proj = self.proj_bert(bert_seq).view(B * L, -1)

        ct = self.walker_ct(clip_proj, t5_proj)
        tb = self.walker_tb(t5_proj, bert_proj)
        cb = self.walker_cb(clip_proj, bert_proj)

        combined = torch.cat([ct, tb, cb], dim=-1)
        out = self.combine(combined)

        return out.view(B, L, D)


class SeqSumFusion(nn.Module):
    """Simple sum fusion for sequences."""
    def __init__(self, dims: Dict[str, int], head_dim: int, config: Dict[str, Any]):
        super().__init__()
        self.proj_clip = nn.Linear(dims['clip'], head_dim)
        self.proj_t5 = nn.Linear(dims['t5'], head_dim)
        self.proj_bert = nn.Linear(dims['bert'], head_dim)
        self.weights = nn.Parameter(torch.ones(3))

    def forward(self, clip_seq, t5_seq, bert_seq, mask=None):
        clip_proj = self.proj_clip(clip_seq)
        t5_proj = self.proj_t5(t5_seq)
        bert_proj = self.proj_bert(bert_seq)

        w = F.softmax(self.weights, dim=0)
        return w[0] * clip_proj + w[1] * t5_proj + w[2] * bert_proj


class SeqConcatFusion(nn.Module):
    """Concat fusion for sequences."""
    def __init__(self, dims: Dict[str, int], head_dim: int, config: Dict[str, Any]):
        super().__init__()
        total_dim = dims['clip'] + dims['t5'] + dims['bert']
        self.proj = nn.Linear(total_dim, head_dim)

    def forward(self, clip_seq, t5_seq, bert_seq, mask=None):
        concat = torch.cat([clip_seq, t5_seq, bert_seq], dim=-1)
        return self.proj(concat)


def create_seq_fusion(strategy: str, dims: Dict[str, int], head_dim: int, config: Dict[str, Any]) -> nn.Module:
    """Factory for sequence fusion strategies."""
    if strategy == 'hierarchical':
        return SeqWalkerFusion(dims, head_dim, config)
    elif strategy == 'chain':
        return SeqChainWalker(dims, head_dim, config)
    elif strategy == 'pairwise':
        return SeqPairwiseWalker(dims, head_dim, config)
    elif strategy == 'sum':
        return SeqSumFusion(dims, head_dim, config)
    elif strategy == 'concat':
        return SeqConcatFusion(dims, head_dim, config)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")


# -------------------------
# Full Model
# -------------------------
class TextTripleModel(nn.Module):
    """
    Triple text encoder model with sequence-level fusion.

    CLIP + T5 + BERT -> Walker Fusion -> Pool -> Classifier
    """
    def __init__(self, dims: Dict[str, int], strategy: str, config: Dict[str, Any]):
        super().__init__()

        self.strategy = strategy
        self.fusion = create_seq_fusion(strategy, dims, HEAD_DIM, config)

        # Sequence pooling
        self.pool_type = config.get('pool', 'cls')  # 'cls', 'mean', 'max', 'attention'

        if self.pool_type == 'attention':
            self.pool_attn = nn.Linear(HEAD_DIM, 1)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(HEAD_DIM, HEAD_DIM),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(HEAD_DIM, NUM_CLASSES),
        )

    def pool_sequence(self, seq, mask=None):
        """Pool sequence to single vector."""
        # seq: [B, L, D], mask: [B, L]
        if self.pool_type == 'cls':
            return seq[:, 0, :]  # First token
        elif self.pool_type == 'mean':
            if mask is not None:
                mask_expanded = mask.unsqueeze(-1).float()
                return (seq * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1)
            return seq.mean(dim=1)
        elif self.pool_type == 'max':
            if mask is not None:
                seq = seq.masked_fill(~mask.unsqueeze(-1).bool(), float('-inf'))
            return seq.max(dim=1)[0]
        elif self.pool_type == 'attention':
            scores = self.pool_attn(seq).squeeze(-1)  # [B, L]
            if mask is not None:
                scores = scores.masked_fill(~mask.bool(), float('-inf'))
            weights = F.softmax(scores, dim=-1).unsqueeze(-1)
            return (seq * weights).sum(dim=1)

    def forward(self, clip_seq, t5_seq, bert_seq, mask=None):
        # Fuse sequences: [B, L, D]
        fused = self.fusion(clip_seq, t5_seq, bert_seq, mask)

        # Pool to [B, D]
        pooled = self.pool_sequence(fused, mask)

        # Classify
        return self.classifier(pooled)


# -------------------------
# Training
# -------------------------
def train_single_config(
    strategy: str,
    config: Dict[str, Any],
    train_dataset: TensorDataset,
    test_dataset: TensorDataset,
    dims: Dict[str, int],
    epochs: int = EPOCHS_PER_RUN,
) -> Dict[str, Any]:
    """Train a single configuration."""

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
        num_workers=0, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False,
        num_workers=0, pin_memory=True
    )

    model = TextTripleModel(dims, strategy, config).to(device)

    # Count params
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()

    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_batches = 0

        for batch in train_loader:
            clip_seq, t5_seq, bert_seq, mask, labels = [x.to(device) for x in batch]

            opt.zero_grad()
            logits = model(clip_seq, t5_seq, bert_seq, mask)
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            train_loss += loss.item()
            n_batches += 1

        scheduler.step()
        avg_loss = train_loss / n_batches

        model.eval()
        correct = total = 0

        with torch.no_grad():
            for batch in test_loader:
                clip_seq, t5_seq, bert_seq, mask, labels = [x.to(device) for x in batch]
                preds = model(clip_seq, t5_seq, bert_seq, mask).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100 * correct / total
        best_acc = max(best_acc, acc)

        history.append({'epoch': epoch + 1, 'loss': avg_loss, 'acc': acc})
        print(f"    Epoch {epoch+1:2d}: loss={avg_loss:.4f}, acc={acc:.2f}%")

    del model, opt, scheduler
    torch.cuda.empty_cache()
    gc.collect()

    return {
        'strategy': strategy,
        'config': config,
        'best_acc': best_acc,
        'final_acc': history[-1]['acc'],
        'final_loss': history[-1]['loss'],
        'trainable_params': trainable_params,
        'history': history,
    }


# -------------------------
# Configuration Generator
# -------------------------
def generate_sweep_configs() -> List[Tuple[str, Dict[str, Any]]]:
    """Generate all configurations to sweep."""
    configs = []

    # === Baselines ===
    configs.append(('sum', {'name': 'baseline_sum', 'pool': 'mean'}))
    configs.append(('concat', {'name': 'baseline_concat', 'pool': 'mean'}))

    # === Pooling sweep with sum ===
    for pool in ['cls', 'mean', 'max', 'attention']:
        configs.append(('sum', {'name': f'sum_pool_{pool}', 'pool': pool}))

    # === Strategy sweep with defaults ===
    for strategy in ['hierarchical', 'chain', 'pairwise']:
        configs.append((strategy, {
            'name': f'{strategy}_default',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': 4,
            'pool': 'mean',
        }))

    # === Blend mode sweep (hierarchical) ===
    blends = ['lerp', 'slerp', 'slip', 'shiva', 'min_p', 'gilgamesh']
    for blend in blends:
        configs.append(('hierarchical', {
            'name': f'hier_blend_{blend}',
            'blend': blend,
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': 4,
            'pool': 'mean',
        }))

    # === Schedule sweep ===
    schedules = ['linear', 'cosine', 'sigmoid', 'learnable']
    for sched in schedules:
        configs.append(('hierarchical', {
            'name': f'hier_sched_{sched}',
            'blend': 'slerp',
            'schedule': sched,
            'aggregation': 'mean',
            'num_steps': 4,
            'pool': 'mean',
        }))

    # === Aggregation sweep ===
    aggregations = ['mean', 'softmax', 'similarity', 'similarity_tree', 'min_p']
    for agg in aggregations:
        configs.append(('hierarchical', {
            'name': f'hier_agg_{agg}',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': agg,
            'num_steps': 4,
            'pool': 'mean',
        }))

    # === Steps sweep ===
    for steps in [2, 4, 8, 16]:
        configs.append(('hierarchical', {
            'name': f'hier_steps_{steps}',
            'blend': 'slerp',
            'schedule': 'cosine',
            'aggregation': 'mean',
            'num_steps': steps,
            'pool': 'mean',
        }))

    # === Shiva variants (top performer from vision) ===
    configs.append(('hierarchical', {
        'name': 'hier_shiva_similarity_tree',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
        'pool': 'mean',
    }))

    for decay in [2.0, 4.0, 6.0]:
        configs.append(('hierarchical', {
            'name': f'hier_shiva_decay_{decay}',
            'blend': 'shiva',
            'schedule': 'cosine',
            'aggregation': 'similarity_tree',
            'num_steps': 4,
            'pool': 'mean',
            'shiva_decay': decay,
        }))

    # === Learnable (top performer from vision) ===
    configs.append(('hierarchical', {
        'name': 'hier_learnable_full',
        'blend': 'slerp',
        'schedule': 'learnable',
        'aggregation': 'learnable',
        'num_steps': 4,
        'pool': 'mean',
    }))

    # === Chain variants ===
    configs.append(('chain', {
        'name': 'chain_shiva',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
        'pool': 'mean',
    }))

    configs.append(('chain', {
        'name': 'chain_learnable',
        'blend': 'slerp',
        'schedule': 'learnable',
        'aggregation': 'learnable',
        'num_steps': 4,
        'pool': 'mean',
    }))

    # === Attention pooling with best walker ===
    configs.append(('hierarchical', {
        'name': 'hier_shiva_attn_pool',
        'blend': 'shiva',
        'schedule': 'cosine',
        'aggregation': 'similarity_tree',
        'num_steps': 4,
        'pool': 'attention',
    }))

    # === 8 steps with learnable (more steps for text?) ===
    configs.append(('hierarchical', {
        'name': 'hier_8step_learnable',
        'blend': 'slerp',
        'schedule': 'learnable',
        'aggregation': 'learnable',
        'num_steps': 8,
        'pool': 'mean',
    }))

    return configs


# -------------------------
# Checkpoint Management
# -------------------------
def load_checkpoint() -> Tuple[List[Dict], int]:
    """Load checkpoint if exists."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, 'r') as f:
            data = json.load(f)
        return data.get('results', []), data.get('next_idx', 0)
    return [], 0


def save_checkpoint(results: List[Dict], next_idx: int):
    """Save checkpoint."""
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump({'results': results, 'next_idx': next_idx}, f)


# -------------------------
# Main
# -------------------------
def run_text_ablation():
    """Run the text encoder ablation."""

    print("="*60)
    print("PREPARING CACHED LATENTS (TEXT ENCODERS)")
    print("="*60)
    train_dataset, test_dataset, dims = prepare_cached_data()
    print(f"\nTrain: {len(train_dataset)}, Test: {len(test_dataset)}")
    print(f"Dims: {dims}")
    print(f"Sequence length: {MAX_LENGTH}")

    # Generate configs
    configs = generate_sweep_configs()
    print(f"\n{'='*60}")
    print(f"TEXT ENCODER ABLATION: {len(configs)} configurations")
    print("="*60)

    # Load checkpoint
    results, start_idx = load_checkpoint()
    if start_idx > 0:
        print(f"\nResuming from config {start_idx}/{len(configs)}")

    # Run configs with crash protection
    for i, (strategy, config) in enumerate(configs):
        if i < start_idx:
            continue

        name = config.get('name', f'{strategy}_{i}')
        print(f"\n[{i+1}/{len(configs)}] {name}")
        print("-" * 40)

        try:
            result = train_single_config(
                strategy, config, train_dataset, test_dataset, dims
            )

            row = {
                'idx': i,
                'name': name,
                'strategy': strategy,
                'blend': config.get('blend', 'N/A'),
                'schedule': config.get('schedule', 'N/A'),
                'aggregation': config.get('aggregation', 'N/A'),
                'num_steps': config.get('num_steps', 'N/A'),
                'pool': config.get('pool', 'N/A'),
                'best_acc': result['best_acc'],
                'final_acc': result['final_acc'],
                'final_loss': result['final_loss'],
                'params': result['trainable_params'],
            }
            results.append(row)

            print(f"  Best: {result['best_acc']:.2f}% | Final: {result['final_acc']:.2f}%")

            # Save checkpoint
            save_checkpoint(results, i + 1)

            # Save CSV
            df = pd.DataFrame(results)
            df.to_csv(RESULTS_FILE, index=False)

        except Exception as e:
            print(f"  ❌ ERROR: {e}")
            traceback.print_exc()

            results.append({
                'idx': i,
                'name': name,
                'strategy': strategy,
                'best_acc': -1,
                'final_acc': -1,
                'error': str(e),
            })
            save_checkpoint(results, i + 1)

    # Final summary
    print("\n" + "="*60)
    print("ABLATION COMPLETE")
    print("="*60)

    df = pd.DataFrame(results)
    df_valid = df[df['best_acc'] > 0].sort_values('best_acc', ascending=False)
    df_valid.to_csv(RESULTS_FILE, index=False)

    print(f"\nResults saved to: {RESULTS_FILE}")

    print(f"\n{'='*60}")
    print("TOP 15 CONFIGURATIONS")
    print("="*60)
    print(df_valid[['name', 'strategy', 'best_acc', 'final_acc', 'pool']].head(15).to_string(index=False))

    print(f"\n{'='*60}")
    print("RESULTS BY STRATEGY")
    print("="*60)
    for strat in ['hierarchical', 'chain', 'pairwise', 'sum', 'concat']:
        strat_df = df_valid[df_valid['strategy'] == strat]
        if len(strat_df) > 0:
            best = strat_df.iloc[0]
            print(f"  {strat:15s}: {best['best_acc']:.2f}% ({best['name']})")

    # Cleanup checkpoint
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)

    return df_valid


# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
    results_df = run_text_ablation()

Device: cuda
PREPARING CACHED LATENTS (TEXT ENCODERS)
Loading ag_news dataset...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Train: 120000, Test: 7600
Classes: 4

Loading clip...


tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

  clip dim: 512
Extracting: Train clip


Train clip:   0%|          | 0/1875 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Cached: ./latent_cache_text/train_clip.pt torch.Size([120000, 77, 512])
Extracting: Test clip


Test clip:   0%|          | 0/119 [00:00<?, ?it/s]

Cached: ./latent_cache_text/test_clip.pt torch.Size([7600, 77, 512])

Loading t5...


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

  t5 dim: 768
Extracting: Train t5


Train t5:   0%|          | 0/1875 [00:00<?, ?it/s]

Cached: ./latent_cache_text/train_t5.pt torch.Size([120000, 77, 768])
Extracting: Test t5


Test t5:   0%|          | 0/119 [00:00<?, ?it/s]

Cached: ./latent_cache_text/test_t5.pt torch.Size([7600, 77, 768])

Loading bert...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

  bert dim: 1024
Extracting: Train bert


Train bert:   0%|          | 0/1875 [00:00<?, ?it/s]

Cached: ./latent_cache_text/train_bert.pt torch.Size([120000, 77, 1024])
Extracting: Test bert


Test bert:   0%|          | 0/119 [00:00<?, ?it/s]

Cached: ./latent_cache_text/test_bert.pt torch.Size([7600, 77, 1024])

Train: 120000, Test: 7600
Dims: {'clip': 512, 't5': 768, 'bert': 1024}
Sequence length: 77

TEXT ENCODER ABLATION: 37 configurations

[1/37] baseline_sum
----------------------------------------


KeyboardInterrupt: 

# inception fusion

In [3]:
# =========================
# InceptiveFusion Ablation
# Testing different auxiliary feature strategies
# Auxiliary features influence fusion weights without being fused into output
# =========================

# !pip -q install timm tqdm pandas  # Uncomment for Colab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import timm
import pandas as pd
import os
import traceback
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional
import gc
import json
import math

from geofractal.router.base_tower import BaseTower
from geofractal.router.wide_router import WideRouter
from geofractal.router.components.torch_component import TorchComponent
from geofractal.router.components.fusion_component import (
    InceptiveFusion, AdaptiveFusion, SumFusion,
)

# -------------------------
# Config
# -------------------------
CACHE_DIR = "./latent_cache_triple"  # Reuse triple encoder cache
RESULTS_FILE = f"inceptive_ablation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
CHECKPOINT_FILE = "inceptive_checkpoint.json"
EPOCHS_PER_RUN = 10
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 128
HEAD_DIM = 512
NUM_TOWERS = 8
NUM_CLASSES = 100  # CIFAR-100

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


# -------------------------
# Auxiliary Feature Generators
# -------------------------
class AuxiliaryFeatureGenerator(nn.Module):
    """Base class for auxiliary feature generators."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__()
        self.num_inputs = num_inputs
        self.in_features = in_features
        self.aux_dim = aux_dim

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """Generate auxiliary features from inputs.

        Args:
            *inputs: N tensors of shape [B, D]

        Returns:
            Auxiliary features [B, aux_dim]
        """
        raise NotImplementedError


class ZeroAuxiliary(AuxiliaryFeatureGenerator):
    """Zero auxiliary features - baseline (degrades to learned pairwise)."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)

    def forward(self, *inputs):
        B = inputs[0].shape[0]
        return torch.zeros(B, self.aux_dim, device=inputs[0].device, dtype=inputs[0].dtype)


class CosineSimilarityAuxiliary(AuxiliaryFeatureGenerator):
    """
    Pairwise cosine similarities between all inputs.

    For N inputs: N*(N-1)/2 pairs -> project to aux_dim
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        num_pairs = num_inputs * (num_inputs - 1) // 2
        self.proj = nn.Linear(num_pairs, aux_dim)

    def forward(self, *inputs):
        # Compute all pairwise cosine similarities
        cosines = []
        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                cos = F.cosine_similarity(inputs[i], inputs[j], dim=-1, eps=1e-8)
                cosines.append(cos)

        # Stack: [B, num_pairs]
        cosines = torch.stack(cosines, dim=-1)
        return self.proj(cosines)


class MagnitudeRatioAuxiliary(AuxiliaryFeatureGenerator):
    """
    Magnitude ratios between inputs.

    Captures relative "confidence" or scale of each encoder.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        # N magnitudes + N*(N-1)/2 ratios
        num_features = num_inputs + num_inputs * (num_inputs - 1) // 2
        self.proj = nn.Linear(num_features, aux_dim)

    def forward(self, *inputs):
        # Compute magnitudes
        mags = [x.norm(dim=-1, keepdim=True) for x in inputs]

        # Compute ratios (log scale for stability)
        ratios = []
        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                ratio = torch.log((mags[i] + 1e-8) / (mags[j] + 1e-8))
                ratios.append(ratio)

        # Combine: [B, N + N*(N-1)/2]
        features = torch.cat(mags + ratios, dim=-1)
        return self.proj(features)


class CantorStaircaseAuxiliary(AuxiliaryFeatureGenerator):
    """
    Cantor staircase / fractal coordinate features.

    Maps input statistics to fractal coordinate space.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int, depth: int = 5):
        super().__init__(num_inputs, in_features, aux_dim)
        self.depth = depth
        # Per-input: depth cantor levels
        num_features = num_inputs * depth
        self.proj = nn.Linear(num_features, aux_dim)

    def cantor_value(self, x: torch.Tensor, depth: int) -> torch.Tensor:
        """Compute Cantor staircase value for x in [0, 1]."""
        # Normalize to [0, 1]
        x = torch.sigmoid(x)

        result = torch.zeros_like(x)
        for d in range(depth):
            # Which third are we in?
            third = (x * 3).floor().clamp(0, 2)

            # Cantor function: 0->0, 1->0.5, 2->1 (scaled by 2^-d)
            contribution = (third / 2) * (0.5 ** d)
            result = result + contribution

            # Recurse into the third
            x = (x * 3) - third

        return result

    def forward(self, *inputs):
        features = []
        for inp in inputs:
            # Use mean activation as base signal
            mean_act = inp.mean(dim=-1, keepdim=True)

            # Compute Cantor values at different depths
            for d in range(1, self.depth + 1):
                cantor_val = self.cantor_value(mean_act, d)
                features.append(cantor_val)

        # [B, N * depth]
        features = torch.cat(features, dim=-1)
        return self.proj(features)


class LearnedEmbeddingAuxiliary(AuxiliaryFeatureGenerator):
    """
    Learned per-input embeddings.

    Simple learnable auxiliary features - one embedding per input.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        self.embeddings = nn.Parameter(torch.randn(num_inputs, aux_dim) * 0.02)

    def forward(self, *inputs):
        B = inputs[0].shape[0]
        # Return mean of embeddings (same for all samples in batch)
        # Could also use input-dependent attention over embeddings
        return self.embeddings.mean(dim=0, keepdim=True).expand(B, -1)


class InputDependentEmbeddingAuxiliary(AuxiliaryFeatureGenerator):
    """
    Input-dependent learned embeddings.

    Uses attention over per-input embeddings weighted by input content.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        self.embeddings = nn.Parameter(torch.randn(num_inputs, aux_dim) * 0.02)
        self.query_proj = nn.Linear(in_features, aux_dim)

    def forward(self, *inputs):
        # Stack inputs: [B, N, D]
        stacked = torch.stack(inputs, dim=1)

        # Project to query space: [B, N, aux_dim]
        queries = self.query_proj(stacked)

        # Attention over embeddings: [B, N, aux_dim] x [N, aux_dim] -> [B, N]
        attn = torch.einsum('bnd,nd->bn', queries, self.embeddings)
        attn = F.softmax(attn, dim=-1)

        # Weighted sum of embeddings: [B, N] x [N, aux_dim] -> [B, aux_dim]
        return torch.einsum('bn,nd->bd', attn, self.embeddings)


class CrossVarianceAuxiliary(AuxiliaryFeatureGenerator):
    """
    Cross-input variance features.

    Measures agreement/disagreement between encoders.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        # Per-dim variance + mean + min + max across inputs
        self.proj = nn.Linear(4, aux_dim)

    def forward(self, *inputs):
        # Stack: [B, N, D]
        stacked = torch.stack(inputs, dim=1)

        # Compute statistics across inputs (dim=1)
        var = stacked.var(dim=1).mean(dim=-1, keepdim=True)  # [B, 1]
        mean = stacked.mean(dim=1).mean(dim=-1, keepdim=True)  # [B, 1]
        min_val = stacked.min(dim=1)[0].mean(dim=-1, keepdim=True)  # [B, 1]
        max_val = stacked.max(dim=1)[0].mean(dim=-1, keepdim=True)  # [B, 1]

        features = torch.cat([var, mean, min_val, max_val], dim=-1)
        return self.proj(features)


class DotProductAuxiliary(AuxiliaryFeatureGenerator):
    """
    Raw dot products between inputs.

    Captures correlation structure.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        num_pairs = num_inputs * (num_inputs - 1) // 2
        self.proj = nn.Linear(num_pairs, aux_dim)
        self.scale = 1.0 / math.sqrt(in_features)

    def forward(self, *inputs):
        dots = []
        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                dot = (inputs[i] * inputs[j]).sum(dim=-1, keepdim=True) * self.scale
                dots.append(dot)

        features = torch.cat(dots, dim=-1)
        return self.proj(features)


class EntropyAuxiliary(AuxiliaryFeatureGenerator):
    """
    Per-input feature entropy.

    Measures "uncertainty" or "spread" of each encoder's representation.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        self.proj = nn.Linear(num_inputs, aux_dim)

    def forward(self, *inputs):
        entropies = []
        for inp in inputs:
            # Softmax to get pseudo-probabilities
            probs = F.softmax(inp, dim=-1)
            # Entropy: -sum(p * log(p))
            entropy = -(probs * (probs + 1e-8).log()).sum(dim=-1, keepdim=True)
            entropies.append(entropy)

        features = torch.cat(entropies, dim=-1)
        return self.proj(features)


class GeometricAuxiliary(AuxiliaryFeatureGenerator):
    """
    Geometric features: Cayley-Menger determinant + angular differences.

    From David's geometric attention gate.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int, sample_dim: int = 32):
        super().__init__(num_inputs, in_features, aux_dim)
        self.sample_dim = sample_dim

        # Project to sample space for efficient Cayley-Menger
        self.sample_proj = nn.Linear(in_features, sample_dim)

        # Output projection
        # Features: 1 (CM volume) + N*(N-1)/2 (angular)
        num_pairs = num_inputs * (num_inputs - 1) // 2
        self.proj = nn.Linear(1 + num_pairs, aux_dim)

    def cayley_menger_volume(self, points: torch.Tensor) -> torch.Tensor:
        """
        Compute Cayley-Menger determinant for volume of simplex.

        Args:
            points: [B, N, D] - N points in D dimensions

        Returns:
            Volume proxy [B, 1]
        """
        B, N, D = points.shape

        # Distance matrix: [B, N, N]
        dist_sq = torch.cdist(points, points, p=2).pow(2)

        # Build Cayley-Menger matrix
        # [1, d01^2, d02^2, ...]
        # [d01^2, 0, d12^2, ...]
        # ...
        cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = dist_sq

        # Determinant as volume proxy (avoid actual det for stability)
        # Use trace of cm @ cm as proxy
        volume = (cm @ cm).diagonal(dim1=-2, dim2=-1).sum(dim=-1, keepdim=True)

        return volume

    def angular_features(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Compute pairwise angular differences."""
        angles = []
        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                # Cosine -> angle
                cos = F.cosine_similarity(inputs[i], inputs[j], dim=-1, eps=1e-8)
                angle = torch.acos(cos.clamp(-1 + 1e-7, 1 - 1e-7))
                angles.append(angle.unsqueeze(-1))

        return torch.cat(angles, dim=-1)

    def forward(self, *inputs):
        # Project to sample space
        sampled = [self.sample_proj(x) for x in inputs]

        # Stack: [B, N, sample_dim]
        stacked = torch.stack(sampled, dim=1)

        # Cayley-Menger volume
        volume = self.cayley_menger_volume(stacked)

        # Angular features
        angular = self.angular_features(list(inputs))

        # Combine
        features = torch.cat([volume, angular], dim=-1)
        return self.proj(features)


class CombinedAuxiliary(AuxiliaryFeatureGenerator):
    """
    Combination of multiple auxiliary feature types.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)

        # Individual generators (each outputs aux_dim // 4)
        sub_dim = aux_dim // 4
        self.cosine = CosineSimilarityAuxiliary(num_inputs, in_features, sub_dim)
        self.magnitude = MagnitudeRatioAuxiliary(num_inputs, in_features, sub_dim)
        self.variance = CrossVarianceAuxiliary(num_inputs, in_features, sub_dim)
        self.geometric = GeometricAuxiliary(num_inputs, in_features, sub_dim)

        # Final projection
        self.out_proj = nn.Linear(sub_dim * 4, aux_dim)

    def forward(self, *inputs):
        cos_feat = self.cosine(*inputs)
        mag_feat = self.magnitude(*inputs)
        var_feat = self.variance(*inputs)
        geo_feat = self.geometric(*inputs)

        combined = torch.cat([cos_feat, mag_feat, var_feat, geo_feat], dim=-1)
        return self.out_proj(combined)


class WalkerInspiredAuxiliary(AuxiliaryFeatureGenerator):
    """
    Walker-inspired auxiliary features.

    Computes features that capture the "path" between encoders,
    similar to how walkers interpolate.
    """
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int, num_steps: int = 4):
        super().__init__(num_inputs, in_features, aux_dim)
        self.num_steps = num_steps

        # Features: step-wise similarities along interpolation path
        # For N inputs, N-1 paths, each with num_steps similarities
        num_features = (num_inputs - 1) * num_steps
        self.proj = nn.Linear(num_features, aux_dim)

    def forward(self, *inputs):
        features = []

        # Walk between consecutive inputs
        for i in range(len(inputs) - 1):
            start = inputs[i]
            end = inputs[i + 1]

            for step in range(self.num_steps):
                alpha = (step + 1) / (self.num_steps + 1)

                # Interpolated point
                interp = (1 - alpha) * start + alpha * end

                # Similarity to both endpoints
                sim_start = F.cosine_similarity(interp, start, dim=-1, eps=1e-8)
                sim_end = F.cosine_similarity(interp, end, dim=-1, eps=1e-8)

                # Path feature: balance between endpoints
                path_feat = (sim_start - sim_end).unsqueeze(-1)
                features.append(path_feat)

        features = torch.cat(features, dim=-1)
        return self.proj(features)


# -------------------------
# Auxiliary Feature Factory
# -------------------------
AUXILIARY_GENERATORS = {
    'zero': ZeroAuxiliary,
    'cosine': CosineSimilarityAuxiliary,
    'magnitude': MagnitudeRatioAuxiliary,
    'cantor': CantorStaircaseAuxiliary,
    'learned': LearnedEmbeddingAuxiliary,
    'input_dependent': InputDependentEmbeddingAuxiliary,
    'variance': CrossVarianceAuxiliary,
    'dot_product': DotProductAuxiliary,
    'entropy': EntropyAuxiliary,
    'geometric': GeometricAuxiliary,
    'combined': CombinedAuxiliary,
    'walker': WalkerInspiredAuxiliary,
}


def create_auxiliary_generator(name: str, num_inputs: int, in_features: int, aux_dim: int) -> AuxiliaryFeatureGenerator:
    if name not in AUXILIARY_GENERATORS:
        raise ValueError(f"Unknown auxiliary generator: {name}")
    return AUXILIARY_GENERATORS[name](num_inputs, in_features, aux_dim)


# -------------------------
# Model Components
# -------------------------
class FFNBlock(TorchComponent):
    def __init__(self, name, dim):
        super().__init__(name)
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Tower(BaseTower):
    def __init__(self, name, dim, depth=2):
        super().__init__(name, strict=False)
        for i in range(depth):
            self.append(FFNBlock(f"{name}_ffn_{i}", dim))
        self.attach("norm", nn.LayerNorm(dim))

    def forward(self, x):
        for stage in self.stages:
            x = x + stage(x)
        return self["norm"](x)


class SubCollective(WideRouter):
    def __init__(self, name, in_dim, head_dim, num_towers):
        super().__init__(name, auto_discover=True)

        self.attach("proj", nn.Linear(in_dim, head_dim))
        for i in range(num_towers):
            self.attach(f"tower_{i}", Tower(f"{name}_tower_{i}", head_dim))

        self.discover_towers()
        self.attach("fusion", AdaptiveFusion(f"{name}_fusion", num_towers, head_dim))

    def forward(self, feats):
        x = self["proj"](feats)
        opinions = self.wide_forward(x)
        return self["fusion"](*opinions.values())


# -------------------------
# InceptiveCollective Model
# -------------------------
class InceptiveCollective(nn.Module):
    """
    Triple encoder collective with InceptiveFusion and configurable auxiliary features.
    """
    def __init__(
        self,
        dims: Dict[str, int],
        aux_type: str,
        aux_dim: int = 64,
        num_heads: int = 8,
    ):
        super().__init__()

        self.sub_convnext = SubCollective("convnext", dims['convnext'], HEAD_DIM, NUM_TOWERS)
        self.sub_dino = SubCollective("dino_vit", dims['dino_vit'], HEAD_DIM, NUM_TOWERS)
        self.sub_clip = SubCollective("clip_vit", dims['clip_vit'], HEAD_DIM, NUM_TOWERS)

        # Auxiliary feature generator
        self.aux_type = aux_type
        self.aux_generator = create_auxiliary_generator(aux_type, 3, HEAD_DIM, aux_dim)

        # InceptiveFusion with auxiliary features
        self.meta_fusion = InceptiveFusion(
            "inceptive_meta",
            num_inputs=3,
            in_features=HEAD_DIM,
            aux_features=aux_dim,
            out_features=HEAD_DIM,
            num_heads=num_heads,
        )

        self.head = nn.Linear(HEAD_DIM, NUM_CLASSES)

    def forward(self, f_convnext, f_dino, f_clip):
        # Get sub-collective outputs
        o_convnext = self.sub_convnext(f_convnext)
        o_dino = self.sub_dino(f_dino)
        o_clip = self.sub_clip(f_clip)

        # Generate auxiliary features
        aux_features = self.aux_generator(o_convnext, o_dino, o_clip)

        # Fuse with auxiliary injection
        fused = self.meta_fusion.fuse(o_convnext, o_dino, o_clip, auxiliary_features=aux_features)

        return self.head(fused)


# -------------------------
# Data Loading
# -------------------------
ENCODERS = {
    'convnext': "convnext_small.dinov3_lvd1689m",
    'dino_vit': "vit_base_patch16_dinov3.lvd1689m",
    'clip_vit': "vit_base_patch16_clip_224.laion2b_ft_in12k_in1k",
}


def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    enc.eval()
    for p in enc.parameters():
        p.requires_grad = False
    return enc


def cache_latents(encoder, dataloader, cache_path: str, desc: str):
    """Extract and cache latents from encoder."""
    if os.path.exists(cache_path):
        print(f"Loading cached: {cache_path}")
        data = torch.load(cache_path)
        return data['latents'], data['labels']

    print(f"Extracting: {desc}")
    all_latents = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc=desc):
            imgs = imgs.to(device)
            latents = encoder(imgs)
            all_latents.append(latents.cpu())
            all_labels.append(labels)

    latents = torch.cat(all_latents, dim=0)
    labels = torch.cat(all_labels, dim=0)

    os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else '.', exist_ok=True)
    torch.save({'latents': latents, 'labels': labels}, cache_path)
    print(f"Cached: {cache_path} {latents.shape}")

    return latents, labels


def load_cached_data():
    """Load or create cached latents for all 3 encoders."""
    from torchvision import datasets, transforms

    # Check if all caches exist
    required_files = [
        f"{CACHE_DIR}/train_convnext.pt",
        f"{CACHE_DIR}/train_dino_vit.pt",
        f"{CACHE_DIR}/train_clip_vit.pt",
        f"{CACHE_DIR}/test_convnext.pt",
        f"{CACHE_DIR}/test_dino_vit.pt",
        f"{CACHE_DIR}/test_clip_vit.pt",
    ]

    need_extraction = any(not os.path.exists(f) for f in required_files)

    if need_extraction:
        print("Cache not found - extracting latents from encoders...")

        tf = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=tf)
        test_ds = datasets.CIFAR100("./data", train=False, download=True, transform=tf)

        train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4)

        os.makedirs(CACHE_DIR, exist_ok=True)

        for enc_name, enc_model in ENCODERS.items():
            if (os.path.exists(f"{CACHE_DIR}/train_{enc_name}.pt") and
                os.path.exists(f"{CACHE_DIR}/test_{enc_name}.pt")):
                continue

            print(f"\nLoading {enc_name}...")
            enc = load_encoder(enc_model)

            cache_latents(enc, train_loader, f"{CACHE_DIR}/train_{enc_name}.pt", f"Train {enc_name}")
            cache_latents(enc, test_loader, f"{CACHE_DIR}/test_{enc_name}.pt", f"Test {enc_name}")

            del enc
            torch.cuda.empty_cache()

        gc.collect()

    # Load cached data
    train_convnext = torch.load(f"{CACHE_DIR}/train_convnext.pt")
    train_dino = torch.load(f"{CACHE_DIR}/train_dino_vit.pt")
    train_clip = torch.load(f"{CACHE_DIR}/train_clip_vit.pt")
    test_convnext = torch.load(f"{CACHE_DIR}/test_convnext.pt")
    test_dino = torch.load(f"{CACHE_DIR}/test_dino_vit.pt")
    test_clip = torch.load(f"{CACHE_DIR}/test_clip_vit.pt")

    dims = {
        'convnext': train_convnext['latents'].shape[-1],
        'dino_vit': train_dino['latents'].shape[-1],
        'clip_vit': train_clip['latents'].shape[-1],
    }

    train_dataset = TensorDataset(
        train_convnext['latents'],
        train_dino['latents'],
        train_clip['latents'],
        train_convnext['labels'],
    )
    test_dataset = TensorDataset(
        test_convnext['latents'],
        test_dino['latents'],
        test_clip['latents'],
        test_convnext['labels'],
    )

    return train_dataset, test_dataset, dims


# -------------------------
# Training
# -------------------------
def train_single_config(
    aux_type: str,
    config: Dict[str, Any],
    train_dataset: TensorDataset,
    test_dataset: TensorDataset,
    dims: Dict[str, int],
    epochs: int = EPOCHS_PER_RUN,
) -> Dict[str, Any]:
    """Train a single configuration."""

    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
        num_workers=0, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False,
        num_workers=0, pin_memory=True
    )

    model = InceptiveCollective(
        dims=dims,
        aux_type=aux_type,
        aux_dim=config.get('aux_dim', 64),
        num_heads=config.get('num_heads', 8),
    ).to(device)

    model.sub_convnext.prepare_and_compile()
    model.sub_dino.prepare_and_compile()
    model.sub_clip.prepare_and_compile()

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    history = []
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_batches = 0

        for f_conv, f_dino, f_clip, labels in train_loader:
            f_conv = f_conv.to(device)
            f_dino = f_dino.to(device)
            f_clip = f_clip.to(device)
            labels = labels.to(device)

            opt.zero_grad()
            logits = model(f_conv, f_dino, f_clip)
            loss = loss_fn(logits, labels)
            loss.backward()
            opt.step()

            train_loss += loss.item()
            n_batches += 1

        avg_loss = train_loss / n_batches

        model.eval()
        correct = total = 0

        with torch.no_grad():
            for f_conv, f_dino, f_clip, labels in test_loader:
                f_conv = f_conv.to(device)
                f_dino = f_dino.to(device)
                f_clip = f_clip.to(device)
                labels = labels.to(device)

                preds = model(f_conv, f_dino, f_clip).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100 * correct / total
        best_acc = max(best_acc, acc)

        history.append({'epoch': epoch + 1, 'loss': avg_loss, 'acc': acc})
        print(f"    Epoch {epoch+1:2d}: loss={avg_loss:.4f}, acc={acc:.2f}%")

    del model, opt
    torch.cuda.empty_cache()
    gc.collect()

    return {
        'aux_type': aux_type,
        'config': config,
        'best_acc': best_acc,
        'final_acc': history[-1]['acc'],
        'final_loss': history[-1]['loss'],
        'trainable_params': trainable_params,
        'history': history,
    }


# -------------------------
# Configuration Generator
# -------------------------
def generate_sweep_configs() -> List[Tuple[str, Dict[str, Any]]]:
    """Generate all configurations to sweep."""
    configs = []

    # === Baseline: Zero auxiliary ===
    configs.append(('zero', {'name': 'baseline_zero', 'aux_dim': 64}))

    # === All auxiliary types at default aux_dim ===
    aux_types = [
        'cosine', 'magnitude', 'cantor', 'learned', 'input_dependent',
        'variance', 'dot_product', 'entropy', 'geometric', 'combined', 'walker'
    ]

    for aux in aux_types:
        configs.append((aux, {'name': f'aux_{aux}', 'aux_dim': 64}))

    # === Aux dim sweep for best performers ===
    for aux_dim in [16, 32, 128, 256]:
        configs.append(('combined', {'name': f'combined_dim_{aux_dim}', 'aux_dim': aux_dim}))
        configs.append(('geometric', {'name': f'geometric_dim_{aux_dim}', 'aux_dim': aux_dim}))

    # === Num heads sweep ===
    for heads in [4, 16]:
        configs.append(('combined', {'name': f'combined_heads_{heads}', 'aux_dim': 64, 'num_heads': heads}))

    # === Walker auxiliary with different step counts ===
    # Note: WalkerInspiredAuxiliary uses num_steps internally, we'd need to extend the class

    return configs


# -------------------------
# Checkpoint Management
# -------------------------
def load_checkpoint() -> Tuple[List[Dict], int]:
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, 'r') as f:
            data = json.load(f)
        return data.get('results', []), data.get('next_idx', 0)
    return [], 0


def save_checkpoint(results: List[Dict], next_idx: int):
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump({'results': results, 'next_idx': next_idx}, f)


# -------------------------
# Main
# -------------------------
def run_inceptive_ablation():
    """Run the InceptiveFusion ablation."""

    print("="*60)
    print("LOADING CACHED LATENTS")
    print("="*60)
    train_dataset, test_dataset, dims = load_cached_data()
    print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
    print(f"Dims: {dims}")

    configs = generate_sweep_configs()
    print(f"\n{'='*60}")
    print(f"INCEPTIVE FUSION ABLATION: {len(configs)} configurations")
    print("="*60)

    results, start_idx = load_checkpoint()
    if start_idx > 0:
        print(f"\nResuming from config {start_idx}/{len(configs)}")

    for i, (aux_type, config) in enumerate(configs):
        if i < start_idx:
            continue

        name = config.get('name', f'{aux_type}_{i}')
        print(f"\n[{i+1}/{len(configs)}] {name}")
        print("-" * 40)

        try:
            result = train_single_config(
                aux_type, config, train_dataset, test_dataset, dims
            )

            row = {
                'idx': i,
                'name': name,
                'aux_type': aux_type,
                'aux_dim': config.get('aux_dim', 64),
                'num_heads': config.get('num_heads', 8),
                'best_acc': result['best_acc'],
                'final_acc': result['final_acc'],
                'final_loss': result['final_loss'],
                'params': result['trainable_params'],
            }
            results.append(row)

            print(f"  Best: {result['best_acc']:.2f}% | Final: {result['final_acc']:.2f}%")

            save_checkpoint(results, i + 1)

            df = pd.DataFrame(results)
            df.to_csv(RESULTS_FILE, index=False)

        except Exception as e:
            print(f"  ❌ ERROR: {e}")
            traceback.print_exc()

            results.append({
                'idx': i,
                'name': name,
                'aux_type': aux_type,
                'best_acc': -1,
                'final_acc': -1,
                'error': str(e),
            })
            save_checkpoint(results, i + 1)

    # Final summary
    print("\n" + "="*60)
    print("ABLATION COMPLETE")
    print("="*60)

    df = pd.DataFrame(results)
    df_valid = df[df['best_acc'] > 0].sort_values('best_acc', ascending=False)
    df_valid.to_csv(RESULTS_FILE, index=False)

    print(f"\nResults saved to: {RESULTS_FILE}")

    print(f"\n{'='*60}")
    print("TOP 10 CONFIGURATIONS")
    print("="*60)
    print(df_valid[['name', 'aux_type', 'aux_dim', 'best_acc', 'final_acc']].head(10).to_string(index=False))

    print(f"\n{'='*60}")
    print("RESULTS BY AUXILIARY TYPE")
    print("="*60)
    for aux in ['zero'] + list(AUXILIARY_GENERATORS.keys()):
        aux_df = df_valid[df_valid['aux_type'] == aux]
        if len(aux_df) > 0:
            best = aux_df.iloc[0]
            print(f"  {aux:20s}: {best['best_acc']:.2f}% (dim={best['aux_dim']})")

    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)

    return df_valid


if __name__ == "__main__":
    results_df = run_inceptive_ablation()

Device: cuda
LOADING CACHED LATENTS
Cache not found - extracting latents from encoders...

Loading clip_vit...


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



Extracting: Train clip_vit


Train clip_vit:   0%|          | 0/196 [00:00<?, ?it/s]

Cached: ./latent_cache_triple/train_clip_vit.pt torch.Size([50000, 768])
Extracting: Test clip_vit


Test clip_vit:   0%|          | 0/40 [00:00<?, ?it/s]

Cached: ./latent_cache_triple/test_clip_vit.pt torch.Size([10000, 768])
Train: 50000, Test: 10000
Dims: {'convnext': 768, 'dino_vit': 768, 'clip_vit': 768}

INCEPTIVE FUSION ABLATION: 22 configurations

[1/22] baseline_zero
----------------------------------------
    Epoch  1: loss=0.6797, acc=85.76%
    Epoch  2: loss=0.3707, acc=86.53%
    Epoch  3: loss=0.2858, acc=86.20%
    Epoch  4: loss=0.2238, acc=86.59%
    Epoch  5: loss=0.1783, acc=86.36%
    Epoch  6: loss=0.1433, acc=86.11%
    Epoch  7: loss=0.1226, acc=86.00%
    Epoch  8: loss=0.1019, acc=86.41%
    Epoch  9: loss=0.0909, acc=85.51%
    Epoch 10: loss=0.0841, acc=85.68%
  Best: 86.59% | Final: 85.68%

[2/22] aux_cosine
----------------------------------------
    Epoch  1: loss=0.6850, acc=86.19%
    Epoch  2: loss=0.3670, acc=87.54%
    Epoch  3: loss=0.2696, acc=87.71%
    Epoch  4: loss=0.2034, acc=86.84%
    Epoch  5: loss=0.1606, acc=87.48%
    Epoch  6: loss=0.1184, acc=86.47%
    Epoch  7: loss=0.1052, acc=87.13

# inception aux infused walker scheduled blend selection

In [4]:
# =========================
# Combo Walker Ablation
# Walker fusion + Auxiliary features informing schedule/blend
# The best of both worlds?
# =========================

# !pip -q install timm tqdm pandas  # Uncomment for Colab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import timm
import pandas as pd
import os
import traceback
from datetime import datetime
from typing import Dict, List, Tuple, Any, Optional
import gc
import json
import math

# -------------------------
# Config
# -------------------------
CACHE_DIR = "./latent_cache_triple"
RESULTS_FILE = f"combo_walker_ablation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
CHECKPOINT_FILE = "combo_checkpoint.json"
EPOCHS_PER_RUN = 10
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 128
HEAD_DIM = 512
NUM_TOWERS = 8
NUM_CLASSES = 100

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


# -------------------------
# Auxiliary Feature Generators
# -------------------------
class AuxiliaryFeatureGenerator(nn.Module):
    """Base class for auxiliary feature generators."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__()
        self.num_inputs = num_inputs
        self.in_features = in_features
        self.aux_dim = aux_dim

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ZeroAuxiliary(AuxiliaryFeatureGenerator):
    """Zero auxiliary - pure walker baseline."""
    def forward(self, *inputs):
        B = inputs[0].shape[0]
        return torch.zeros(B, self.aux_dim, device=inputs[0].device, dtype=inputs[0].dtype)


class CosineSimilarityAuxiliary(AuxiliaryFeatureGenerator):
    """Pairwise cosine similarities."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        num_pairs = num_inputs * (num_inputs - 1) // 2
        self.proj = nn.Linear(num_pairs, aux_dim)

    def forward(self, *inputs):
        cosines = []
        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                cos = F.cosine_similarity(inputs[i], inputs[j], dim=-1, eps=1e-8)
                cosines.append(cos)
        cosines = torch.stack(cosines, dim=-1)
        return self.proj(cosines)


class LearnedAuxiliary(AuxiliaryFeatureGenerator):
    """Fixed learned embeddings per input."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        self.embeddings = nn.Parameter(torch.randn(num_inputs, aux_dim) * 0.02)
        self.out_proj = nn.Linear(num_inputs * aux_dim, aux_dim)

    def forward(self, *inputs):
        B = inputs[0].shape[0]
        expanded = self.embeddings.unsqueeze(0).expand(B, -1, -1)
        return self.out_proj(expanded.reshape(B, -1))


class InputDependentAuxiliary(AuxiliaryFeatureGenerator):
    """Attention over per-input embeddings."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        self.query_proj = nn.Linear(in_features, aux_dim)
        self.embeddings = nn.Parameter(torch.randn(num_inputs, aux_dim) * 0.02)

    def forward(self, *inputs):
        stacked = torch.stack(inputs, dim=1)  # [B, N, D]
        queries = self.query_proj(stacked)  # [B, N, aux_dim]
        attn = torch.softmax(queries @ self.embeddings.T / (self.aux_dim ** 0.5), dim=-1)
        return (attn @ self.embeddings).mean(dim=1)


class GeometricAuxiliary(AuxiliaryFeatureGenerator):
    """Cayley-Menger inspired geometric features."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int):
        super().__init__(num_inputs, in_features, aux_dim)
        # Distance matrix features + angular features
        num_pairs = num_inputs * (num_inputs - 1) // 2
        raw_dim = num_pairs * 2 + num_inputs  # distances, angles, norms
        self.proj = nn.Linear(raw_dim, aux_dim)

    def forward(self, *inputs):
        features = []
        norms = [x.norm(dim=-1, keepdim=True) for x in inputs]

        # Pairwise distances and angles
        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                dist = (inputs[i] - inputs[j]).norm(dim=-1, keepdim=True)
                cos = F.cosine_similarity(inputs[i], inputs[j], dim=-1, eps=1e-8).unsqueeze(-1)
                features.extend([dist, cos])

        # Norms
        features.extend(norms)

        combined = torch.cat(features, dim=-1)
        return self.proj(combined)


class WalkerPathAuxiliary(AuxiliaryFeatureGenerator):
    """Path similarities along interpolation - meta-walker features."""
    def __init__(self, num_inputs: int, in_features: int, aux_dim: int, num_steps: int = 4):
        super().__init__(num_inputs, in_features, aux_dim)
        self.num_steps = num_steps
        # For each pair, we sample num_steps points and compute similarities
        num_pairs = num_inputs * (num_inputs - 1) // 2
        raw_dim = num_pairs * (num_steps + 1)  # similarity at each step + endpoints
        self.proj = nn.Linear(raw_dim, aux_dim)

    def forward(self, *inputs):
        features = []
        ts = torch.linspace(0, 1, self.num_steps, device=inputs[0].device)

        for i in range(len(inputs)):
            for j in range(i + 1, len(inputs)):
                a, b = inputs[i], inputs[j]
                # Endpoint similarity
                cos_ab = F.cosine_similarity(a, b, dim=-1, eps=1e-8)
                features.append(cos_ab.unsqueeze(-1))

                # Similarities along path
                for t in ts:
                    interp = (1 - t) * a + t * b
                    cos_interp = F.cosine_similarity(interp, (a + b) / 2, dim=-1, eps=1e-8)
                    features.append(cos_interp.unsqueeze(-1))

        combined = torch.cat(features, dim=-1)
        return self.proj(combined)


AUX_GENERATORS = {
    'zero': ZeroAuxiliary,
    'cosine': CosineSimilarityAuxiliary,
    'learned': LearnedAuxiliary,
    'input_dependent': InputDependentAuxiliary,
    'geometric': GeometricAuxiliary,
    'walker_path': WalkerPathAuxiliary,
}


# -------------------------
# Combo Walker Fusion
# Walker + Auxiliary-informed schedule modulation
# -------------------------
class ComboWalkerFusion(nn.Module):
    """
    FieldWalker-style interpolation with auxiliary features
    modulating the schedule and blend weights.

    Core idea:
    - Base walker does slerp/shiva interpolation
    - Auxiliary features predict schedule offsets and blend adjustments
    - Combines "walking the path" with "knowing the terrain"
    """

    def __init__(
        self,
        in_features: int,
        num_inputs: int = 3,
        hidden_dim: int = 512,
        num_steps: int = 8,
        aux_type: str = 'cosine',
        aux_dim: int = 64,
        base_blend: str = 'shiva',  # shiva, slerp, lerp
        schedule_mode: str = 'aux_modulated',  # fixed, learnable, aux_modulated
    ):
        super().__init__()
        self.in_features = in_features
        self.num_inputs = num_inputs
        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.base_blend = base_blend
        self.schedule_mode = schedule_mode
        self.aux_dim = aux_dim

        # Auxiliary feature generator
        aux_cls = AUX_GENERATORS.get(aux_type, ZeroAuxiliary)
        self.aux_gen = aux_cls(num_inputs, in_features, aux_dim)

        # Project inputs to common dim
        self.input_projs = nn.ModuleList([
            nn.Linear(in_features, hidden_dim) for _ in range(num_inputs)
        ])

        # Base schedule (learnable)
        self.base_schedule = nn.Parameter(torch.linspace(0, 1, num_steps))

        # Auxiliary -> schedule modulation
        if schedule_mode == 'aux_modulated':
            self.schedule_modulator = nn.Sequential(
                nn.Linear(aux_dim, aux_dim),
                nn.GELU(),
                nn.Linear(aux_dim, num_steps),
                nn.Tanh(),  # Bounded modulation
            )
            self.modulation_scale = nn.Parameter(torch.tensor(0.1))

        # Per-step refinement (small MLP)
        self.step_refine = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Final aggregation weights (aux-informed)
        self.agg_weights = nn.Sequential(
            nn.Linear(aux_dim + hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, num_steps),
        )

        # Shiva decay parameter
        self.shiva_decay = nn.Parameter(torch.tensor(4.0))

    def _slerp(self, a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
        """Spherical linear interpolation."""
        a_norm = F.normalize(a, dim=-1)
        b_norm = F.normalize(b, dim=-1)
        dot = (a_norm * b_norm).sum(dim=-1, keepdim=True).clamp(-1 + 1e-7, 1 - 1e-7)
        omega = torch.acos(dot)
        sin_omega = torch.sin(omega).clamp(min=1e-7)

        # Fall back to lerp for nearly parallel vectors
        mask = (omega.abs() < 1e-4).squeeze(-1)

        s0 = torch.sin((1 - t) * omega) / sin_omega
        s1 = torch.sin(t * omega) / sin_omega

        result = s0 * a_norm + s1 * b_norm

        # Lerp fallback
        if mask.any():
            lerp_result = (1 - t) * a_norm + t * b_norm
            result = torch.where(mask.unsqueeze(-1), lerp_result, result)

        # Restore magnitude (use interpolated magnitude)
        a_mag = a.norm(dim=-1, keepdim=True)
        b_mag = b.norm(dim=-1, keepdim=True)
        mag = (1 - t) * a_mag + t * b_mag

        return result * mag

    def _shiva_blend(self, a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
        """Shiva blend: exponential decay weighting."""
        decay = self.shiva_decay.abs() + 0.1
        w = torch.exp(-decay * t)
        return w * a + (1 - w) * b

    def _blend(self, a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
        """Apply selected blend mode."""
        if self.base_blend == 'slerp':
            return self._slerp(a, b, t)
        elif self.base_blend == 'shiva':
            return self._shiva_blend(a, b, t)
        else:  # lerp
            return (1 - t) * a + t * b

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            *inputs: N tensors of shape [B, in_features]

        Returns:
            Fused tensor [B, hidden_dim]
        """
        B = inputs[0].shape[0]

        # Project all inputs
        projected = [proj(x) for proj, x in zip(self.input_projs, inputs)]

        # Generate auxiliary features
        aux_feats = self.aux_gen(*inputs)  # [B, aux_dim]

        # Compute schedule
        if self.schedule_mode == 'aux_modulated':
            # Auxiliary features modulate the base schedule
            modulation = self.schedule_modulator(aux_feats)  # [B, num_steps]
            schedule = self.base_schedule.unsqueeze(0) + self.modulation_scale * modulation
            schedule = schedule.clamp(0, 1)
        else:
            schedule = self.base_schedule.unsqueeze(0).expand(B, -1)

        # Hierarchical walking: ((A,B),C) style
        # First: walk between first two
        steps_ab = []
        for i in range(self.num_steps):
            t = schedule[:, i].unsqueeze(-1)  # [B, 1]
            # Use mean t for blend (per-sample schedule)
            blended = self._blend(projected[0], projected[1], t.mean().item())
            steps_ab.append(blended)

        # Aggregate first walk
        steps_ab = torch.stack(steps_ab, dim=1)  # [B, num_steps, hidden_dim]

        # Refined steps
        refined_ab = self.step_refine(steps_ab)  # [B, num_steps, hidden_dim]

        # If we have 3 inputs, walk from (A,B) result to C
        if self.num_inputs >= 3:
            # Aggregate intermediate result
            ab_mean = refined_ab.mean(dim=1)  # [B, hidden_dim]

            # Second walk: (AB) -> C
            steps_abc = []
            for i in range(self.num_steps):
                t = schedule[:, i].unsqueeze(-1)
                blended = self._blend(ab_mean, projected[2], t.mean().item())
                steps_abc.append(blended)

            steps_abc = torch.stack(steps_abc, dim=1)
            refined_abc = self.step_refine(steps_abc)

            # Final aggregation with aux-informed weights
            # Combine aux features with mean hidden state
            agg_input = torch.cat([aux_feats, refined_abc.mean(dim=1)], dim=-1)
            agg_weights = F.softmax(self.agg_weights(agg_input), dim=-1)  # [B, num_steps]

            output = (refined_abc * agg_weights.unsqueeze(-1)).sum(dim=1)
        else:
            # Just 2 inputs
            agg_input = torch.cat([aux_feats, refined_ab.mean(dim=1)], dim=-1)
            agg_weights = F.softmax(self.agg_weights(agg_input), dim=-1)
            output = (refined_ab * agg_weights.unsqueeze(-1)).sum(dim=1)

        return output


# -------------------------
# Model
# -------------------------
class ComboFusionModel(nn.Module):
    """Triple encoder -> ComboWalkerFusion -> Classifier."""

    def __init__(
        self,
        dims: Dict[str, int],
        num_classes: int = 100,
        hidden_dim: int = 512,
        fusion_config: Dict = None,
    ):
        super().__init__()
        self.dims = dims
        fusion_config = fusion_config or {}

        # Get max dim for projection
        in_features = max(dims.values())

        self.fusion = ComboWalkerFusion(
            in_features=in_features,
            num_inputs=3,
            hidden_dim=hidden_dim,
            **fusion_config
        )

        # Classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, convnext: torch.Tensor, dino: torch.Tensor, clip: torch.Tensor):
        fused = self.fusion(convnext, dino, clip)
        return self.classifier(fused)


# -------------------------
# Data Loading
# -------------------------
ENCODERS = {
    'convnext': "convnext_small.dinov3_lvd1689m",
    'dino_vit': "vit_base_patch16_dinov3.lvd1689m",
    'clip_vit': "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k",
}


def load_encoder(name):
    enc = timm.create_model(name, pretrained=True, num_classes=0, global_pool="avg")
    enc.to(device)
    enc.eval()
    for p in enc.parameters():
        p.requires_grad = False
    return enc


def cache_latents(encoder, dataloader, cache_path: str, desc: str):
    if os.path.exists(cache_path):
        print(f"Loading cached: {cache_path}")
        data = torch.load(cache_path)
        return data['latents'], data['labels']

    print(f"Extracting: {desc}")
    all_latents = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc=desc):
            imgs = imgs.to(device)
            latents = encoder(imgs)
            all_latents.append(latents.cpu())
            all_labels.append(labels)

    latents = torch.cat(all_latents, dim=0)
    labels = torch.cat(all_labels, dim=0)

    os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else '.', exist_ok=True)
    torch.save({'latents': latents, 'labels': labels}, cache_path)
    print(f"Cached: {cache_path} {latents.shape}")

    return latents, labels


def load_cached_data():
    from torchvision import datasets, transforms

    required_files = [
        f"{CACHE_DIR}/train_convnext.pt",
        f"{CACHE_DIR}/train_dino_vit.pt",
        f"{CACHE_DIR}/train_clip_vit.pt",
        f"{CACHE_DIR}/test_convnext.pt",
        f"{CACHE_DIR}/test_dino_vit.pt",
        f"{CACHE_DIR}/test_clip_vit.pt",
    ]

    need_extraction = any(not os.path.exists(f) for f in required_files)

    if need_extraction:
        print("Cache not found - extracting latents from encoders...")

        tf = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        train_ds = datasets.CIFAR100("./data", train=True, download=True, transform=tf)
        test_ds = datasets.CIFAR100("./data", train=False, download=True, transform=tf)

        train_loader = DataLoader(train_ds, batch_size=256, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=4)

        os.makedirs(CACHE_DIR, exist_ok=True)

        for enc_name, enc_model in ENCODERS.items():
            if (os.path.exists(f"{CACHE_DIR}/train_{enc_name}.pt") and
                os.path.exists(f"{CACHE_DIR}/test_{enc_name}.pt")):
                continue

            print(f"\nLoading {enc_name}...")
            enc = load_encoder(enc_model)

            cache_latents(enc, train_loader, f"{CACHE_DIR}/train_{enc_name}.pt", f"Train {enc_name}")
            cache_latents(enc, test_loader, f"{CACHE_DIR}/test_{enc_name}.pt", f"Test {enc_name}")

            del enc
            torch.cuda.empty_cache()

        gc.collect()

    train_convnext = torch.load(f"{CACHE_DIR}/train_convnext.pt")
    train_dino = torch.load(f"{CACHE_DIR}/train_dino_vit.pt")
    train_clip = torch.load(f"{CACHE_DIR}/train_clip_vit.pt")
    test_convnext = torch.load(f"{CACHE_DIR}/test_convnext.pt")
    test_dino = torch.load(f"{CACHE_DIR}/test_dino_vit.pt")
    test_clip = torch.load(f"{CACHE_DIR}/test_clip_vit.pt")

    dims = {
        'convnext': train_convnext['latents'].shape[-1],
        'dino_vit': train_dino['latents'].shape[-1],
        'clip_vit': train_clip['latents'].shape[-1],
    }

    train_dataset = TensorDataset(
        train_convnext['latents'],
        train_dino['latents'],
        train_clip['latents'],
        train_convnext['labels'],
    )
    test_dataset = TensorDataset(
        test_convnext['latents'],
        test_dino['latents'],
        test_clip['latents'],
        test_convnext['labels'],
    )

    return train_dataset, test_dataset, dims


# -------------------------
# Training
# -------------------------
NUM_RUNS = 2  # Run each config twice for consistency check


def train_single_run(
    train_dataset,
    test_dataset,
    dims: Dict[str, int],
    config: Dict,
    seed: int,
) -> Tuple[float, float]:
    """Single training run with specific seed."""

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False)

    model = ComboFusionModel(
        dims=dims,
        num_classes=NUM_CLASSES,
        hidden_dim=HEAD_DIM,
        fusion_config=config,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_PER_RUN)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    final_acc = 0.0

    for epoch in range(EPOCHS_PER_RUN):
        model.train()
        total_loss = 0.0
        for convnext, dino, clip, labels in train_loader:
            convnext, dino, clip, labels = (
                convnext.to(device),
                dino.to(device),
                clip.to(device),
                labels.to(device),
            )

            optimizer.zero_grad()
            logits = model(convnext, dino, clip)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        avg_loss = total_loss / len(train_loader)

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for convnext, dino, clip, labels in test_loader:
                convnext, dino, clip, labels = (
                    convnext.to(device),
                    dino.to(device),
                    clip.to(device),
                    labels.to(device),
                )
                logits = model(convnext, dino, clip)
                preds = logits.argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = 100.0 * correct / total
        final_acc = acc
        if acc > best_acc:
            best_acc = acc

        print(f"      Epoch {epoch+1:2d}: loss={avg_loss:.4f}, acc={acc:.2f}%")

    del model, optimizer, scheduler
    torch.cuda.empty_cache()
    gc.collect()

    return best_acc, final_acc


def train_one_config(
    train_dataset,
    test_dataset,
    dims: Dict[str, int],
    config: Dict,
    config_name: str,
) -> Dict:
    """Train config multiple times and compute consistency."""

    run_bests = []
    run_finals = []

    for run_idx in range(NUM_RUNS):
        seed = 42 + run_idx * 1000
        print(f"    Run {run_idx+1}/{NUM_RUNS} (seed={seed})")
        best_acc, final_acc = train_single_run(
            train_dataset, test_dataset, dims, config, seed
        )
        run_bests.append(best_acc)
        run_finals.append(final_acc)
        print(f"    -> Best: {best_acc:.2f}%")

    # Consistency ratio: min/max (>0.95 = consistent, per OverMeta)
    best_min, best_max = min(run_bests), max(run_bests)
    consistency = best_min / best_max if best_max > 0 else 0

    # Stats
    best_mean = sum(run_bests) / len(run_bests)
    best_std = (sum((x - best_mean)**2 for x in run_bests) / len(run_bests)) ** 0.5

    return {
        'name': config_name,
        'best_mean': best_mean,
        'best_std': best_std,
        'best_min': best_min,
        'best_max': best_max,
        'consistency': consistency,
        'run1_best': run_bests[0],
        'run2_best': run_bests[1] if len(run_bests) > 1 else run_bests[0],
        'final_mean': sum(run_finals) / len(run_finals),
        **{k: str(v) if not isinstance(v, (int, float)) else v for k, v in config.items()}
    }


# -------------------------
# Ablation Configs
# -------------------------
def get_ablation_configs() -> List[Tuple[str, Dict]]:
    """Generate all configurations to test."""
    configs = []

    # === Baseline: Pure walker (zero aux) ===
    configs.append(("baseline_walker", {
        'aux_type': 'zero',
        'aux_dim': 64,
        'num_steps': 8,
        'base_blend': 'shiva',
        'schedule_mode': 'learnable',
    }))

    # === Aux type sweep with shiva blend ===
    for aux_type in ['cosine', 'learned', 'input_dependent', 'geometric', 'walker_path']:
        configs.append((f"combo_shiva_{aux_type}", {
            'aux_type': aux_type,
            'aux_dim': 64,
            'num_steps': 8,
            'base_blend': 'shiva',
            'schedule_mode': 'aux_modulated',
        }))

    # === Aux type sweep with slerp blend ===
    for aux_type in ['cosine', 'learned', 'input_dependent', 'geometric']:
        configs.append((f"combo_slerp_{aux_type}", {
            'aux_type': aux_type,
            'aux_dim': 64,
            'num_steps': 8,
            'base_blend': 'slerp',
            'schedule_mode': 'aux_modulated',
        }))

    # === Steps sweep with best aux types ===
    for num_steps in [4, 12, 16]:
        configs.append((f"combo_steps_{num_steps}_learned", {
            'aux_type': 'learned',
            'aux_dim': 64,
            'num_steps': num_steps,
            'base_blend': 'shiva',
            'schedule_mode': 'aux_modulated',
        }))

    # === Aux dim sweep ===
    for aux_dim in [32, 128]:
        configs.append((f"combo_auxdim_{aux_dim}", {
            'aux_type': 'learned',
            'aux_dim': aux_dim,
            'num_steps': 8,
            'base_blend': 'shiva',
            'schedule_mode': 'aux_modulated',
        }))

    # === Schedule mode comparison ===
    configs.append(("combo_fixed_schedule", {
        'aux_type': 'learned',
        'aux_dim': 64,
        'num_steps': 8,
        'base_blend': 'shiva',
        'schedule_mode': 'fixed',
    }))

    configs.append(("combo_learnable_no_aux_mod", {
        'aux_type': 'learned',
        'aux_dim': 64,
        'num_steps': 8,
        'base_blend': 'shiva',
        'schedule_mode': 'learnable',
    }))

    # === Best combo candidates ===
    configs.append(("combo_best_v1", {
        'aux_type': 'input_dependent',
        'aux_dim': 64,
        'num_steps': 8,
        'base_blend': 'shiva',
        'schedule_mode': 'aux_modulated',
    }))

    configs.append(("combo_best_v2", {
        'aux_type': 'geometric',
        'aux_dim': 32,
        'num_steps': 8,
        'base_blend': 'slerp',
        'schedule_mode': 'aux_modulated',
    }))

    configs.append(("combo_best_v3", {
        'aux_type': 'walker_path',
        'aux_dim': 64,
        'num_steps': 12,
        'base_blend': 'shiva',
        'schedule_mode': 'aux_modulated',
    }))

    return configs


# -------------------------
# Main
# -------------------------
def run_combo_ablation():
    print("=" * 60)
    print("LOADING CACHED LATENTS")
    print("=" * 60)

    train_dataset, test_dataset, dims = load_cached_data()
    print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
    print(f"Dims: {dims}")

    configs = get_ablation_configs()
    print(f"\n{'='*60}")
    print(f"COMBO WALKER ABLATION: {len(configs)} configurations")
    print("=" * 60)

    # Load checkpoint
    completed = set()
    results = []
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, 'r') as f:
            checkpoint = json.load(f)
            completed = set(checkpoint.get('completed', []))
            results = checkpoint.get('results', [])
        print(f"Resuming from checkpoint: {len(completed)} completed")

    for i, (name, config) in enumerate(configs):
        if name in completed:
            print(f"\n[{i+1}/{len(configs)}] {name} - SKIPPED (already done)")
            continue

        print(f"\n[{i+1}/{len(configs)}] {name}")
        print("-" * 40)

        try:
            result = train_one_config(
                train_dataset, test_dataset, dims, config, name
            )
            results.append(result)
            completed.add(name)

            print(f"  Mean: {result['best_mean']:.2f}% ± {result['best_std']:.2f}% | Consistency: {result['consistency']:.3f}")

            # Checkpoint
            with open(CHECKPOINT_FILE, 'w') as f:
                json.dump({'completed': list(completed), 'results': results}, f)

        except Exception as e:
            print(f"  ERROR: {e}")
            traceback.print_exc()
            results.append({
                'name': name,
                'best_mean': 0.0,
                'best_std': 0.0,
                'best_min': 0.0,
                'best_max': 0.0,
                'consistency': 0.0,
                'run1_best': 0.0,
                'run2_best': 0.0,
                'final_mean': 0.0,
                'error': str(e),
            })

    # Save results
    df = pd.DataFrame(results)
    df.to_csv(RESULTS_FILE, index=False)

    print(f"\n{'='*60}")
    print("ABLATION COMPLETE")
    print("=" * 60)
    print(f"\nResults saved to: {RESULTS_FILE}")

    # Print top 10
    print(f"\n{'='*60}")
    print("TOP 10 CONFIGURATIONS")
    print("=" * 60)

    df_sorted = df.sort_values('best_mean', ascending=False).head(10)
    print(df_sorted[['name', 'aux_type', 'base_blend', 'best_mean', 'best_std', 'consistency']].to_string(index=False))

    # Flag inconsistent configs
    inconsistent = df[df['consistency'] < 0.95]
    if len(inconsistent) > 0:
        print(f"\n⚠️  INCONSISTENT CONFIGS (ratio < 0.95):")
        print(inconsistent[['name', 'run1_best', 'run2_best', 'consistency']].to_string(index=False))

    return df


if __name__ == "__main__":
    results_df = run_combo_ablation()

Device: cuda
LOADING CACHED LATENTS
Train: 50000, Test: 10000
Dims: {'convnext': 768, 'dino_vit': 768, 'clip_vit': 768}

COMBO WALKER ABLATION: 20 configurations

[1/20] baseline_walker
----------------------------------------
    Run 1/2 (seed=42)
      Epoch  1: loss=0.8585, acc=82.54%
      Epoch  2: loss=0.5005, acc=85.06%
      Epoch  3: loss=0.3868, acc=85.85%
      Epoch  4: loss=0.2987, acc=86.54%
      Epoch  5: loss=0.2222, acc=86.87%
      Epoch  6: loss=0.1513, acc=88.19%
      Epoch  7: loss=0.0956, acc=87.71%
      Epoch  8: loss=0.0526, acc=88.11%
      Epoch  9: loss=0.0292, acc=88.49%
      Epoch 10: loss=0.0192, acc=88.66%
    -> Best: 88.66%
    Run 2/2 (seed=1042)
      Epoch  1: loss=0.9067, acc=81.53%
      Epoch  2: loss=0.5472, acc=84.73%
      Epoch  3: loss=0.4245, acc=85.49%
      Epoch  4: loss=0.3377, acc=85.15%
      Epoch  5: loss=0.2564, acc=86.44%
      Epoch  6: loss=0.1814, acc=86.40%
      Epoch  7: loss=0.1229, acc=87.23%
      Epoch  8: loss=0.0735