In [1]:
try:
    !pip uninstall -qy geometricvocab geofractal
except:
    pass

!pip install -q git+https://github.com/AbstractEyes/geofractal.git

[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geofractal (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geometricvocab (pyproject.toml) ... [?25l[?25hdone


# ineffective CantorGELU on MNIST

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import triton
import triton.language as tl
import time

# =============================================================================
# FIXED CANTOR GELU KERNELS (numerically stable)
# =============================================================================

@triton.jit
def cantor_gelu_fwd_kernel(x_ptr, out_ptr, n_elements, step, strength, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements

    x = tl.load(x_ptr + offs, mask=mask)

    # GELU with numerically stable tanh
    x3 = x * x * x
    inner = 0.7978845608028654 * (x + 0.044715 * x3)

    # Stable tanh: clamp input to prevent exp overflow
    inner_clamped = tl.minimum(tl.maximum(inner, -10.0), 10.0)
    e2 = tl.exp(2.0 * inner_clamped)
    tanh_val = (e2 - 1.0) / (e2 + 1.0)

    gelu = 0.5 * x * (1.0 + tanh_val)

    # Staircase (also clamp x to prevent issues)
    x_clamped = tl.minimum(tl.maximum(x, -10.0), 10.0)
    snapped = tl.floor(x_clamped / step) * step

    out = strength * snapped + (1.0 - strength) * gelu
    tl.store(out_ptr + offs, out, mask=mask)


@triton.jit
def cantor_gelu_bwd_kernel(grad_out_ptr, x_ptr, grad_x_ptr, n_elements, step, strength, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements

    grad_out = tl.load(grad_out_ptr + offs, mask=mask)
    x = tl.load(x_ptr + offs, mask=mask)

    x3 = x * x * x
    inner = 0.7978845608028654 * (x + 0.044715 * x3)

    # Stable tanh
    inner_clamped = tl.minimum(tl.maximum(inner, -10.0), 10.0)
    e2 = tl.exp(2.0 * inner_clamped)
    tanh_inner = (e2 - 1.0) / (e2 + 1.0)

    sech2 = 1.0 - tanh_inner * tanh_inner
    f_prime = 0.7978845608028654 * (1.0 + 0.134145 * x * x)
    gelu_grad = 0.5 * (1.0 + tanh_inner) + 0.5 * x * sech2 * f_prime

    # Clamp gradient for stability
    gelu_grad = tl.minimum(tl.maximum(gelu_grad, -10.0), 10.0)

    grad_x = grad_out * (strength + (1.0 - strength) * gelu_grad)
    tl.store(grad_x_ptr + offs, grad_x, mask=mask)


class _CantorGELUFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, step, strength):
        out = torch.empty_like(x)
        n = x.numel()
        cantor_gelu_fwd_kernel[(triton.cdiv(n, 1024),)](x, out, n, step, strength, BLOCK=1024)
        ctx.save_for_backward(x)
        ctx.step, ctx.strength = step, strength
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        grad_x = torch.empty_like(x)
        n = x.numel()
        cantor_gelu_bwd_kernel[(triton.cdiv(n, 1024),)](
            grad_out.contiguous(), x, grad_x, n, ctx.step, ctx.strength, BLOCK=1024
        )
        return grad_x, None, None


class CantorGELU(nn.Module):
    def __init__(self, num_stairs: int = 16, value_range: float = 8.0, init_strength: float = 0.1):
        super().__init__()
        self.step = 2 * value_range / num_stairs
        self.strength = nn.Parameter(torch.tensor(init_strength))
        self.num_stairs = num_stairs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return F.gelu(x)
        s = torch.sigmoid(self.strength).item()
        return _CantorGELUFunc.apply(x.contiguous(), self.step, s)


# =============================================================================
# CONFIG
# =============================================================================

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128
EPOCHS = 10
LR = 1e-3
NUM_STAIRS = 8

print(f"Device: {DEVICE}")
print(f"CantorGELU stairs: {NUM_STAIRS}")

# =============================================================================
# DATA
# =============================================================================

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print(f"Train: {len(train_dataset):,} samples")
print(f"Test:  {len(test_dataset):,} samples")

# =============================================================================
# MODEL
# =============================================================================

class MNISTClassifier(nn.Module):
    def __init__(self, activation='cantor', num_stairs=16):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)

        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.bn_fc2 = nn.BatchNorm1d(128)

        if activation == 'cantor':
            self.act = CantorGELU(num_stairs=num_stairs)
            self.act_name = f'CantorGELU({num_stairs})'
        elif activation == 'gelu':
            self.act = nn.GELU()
            self.act_name = 'GELU'
        else:
            raise ValueError(f"Unknown activation: {activation}")

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(self.act(self.bn1(self.conv1(x))))
        x = self.pool(self.act(self.bn2(self.conv2(x))))
        x = self.pool(self.act(self.bn3(self.conv3(x))))

        x = x.view(x.size(0), -1)

        x = self.dropout(self.act(self.bn_fc1(self.fc1(x))))
        x = self.dropout(self.act(self.bn_fc2(self.fc2(x))))
        x = self.fc3(x)
        return x

# =============================================================================
# TRAINING
# =============================================================================

def train_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

    return total_loss / len(loader), 100. * correct / total


def evaluate(model, loader):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            total_loss += F.cross_entropy(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return total_loss / len(loader), 100. * correct / total


def train_model(activation='cantor', num_stairs=16):
    print(f"\n{'='*60}")
    print(f"Training with {activation.upper()}")
    print(f"{'='*60}")

    model = MNISTClassifier(activation=activation, num_stairs=num_stairs).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {num_params:,}")

    best_acc = 0
    history = []
    start_time = time.time()

    for epoch in range(1, EPOCHS + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(model, train_loader, optimizer)
        test_loss, test_acc = evaluate(model, test_loader)
        scheduler.step()

        epoch_time = time.time() - epoch_start
        best_marker = " *" if test_acc > best_acc else ""
        if test_acc > best_acc:
            best_acc = test_acc

        print(f"E{epoch:02d} | Train: {train_loss:.4f} / {train_acc:.2f}% | "
              f"Test: {test_loss:.4f} / {test_acc:.2f}%{best_marker} | {epoch_time:.1f}s")

        history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc,
                       'test_loss': test_loss, 'test_acc': test_acc, 'time': epoch_time})

    total_time = time.time() - start_time
    print(f"Best: {best_acc:.2f}% | Total: {total_time:.1f}s")

    return model, history, best_acc


# =============================================================================
# RUN
# =============================================================================

print("\n" + "="*60)
print("MNIST CLASSIFIER - CANTOR GELU vs GELU")
print("="*60)

cantor_model, cantor_history, cantor_best = train_model('cantor', num_stairs=NUM_STAIRS)
gelu_model, gelu_history, gelu_best = train_model('gelu')

# Summary
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

cantor_time = sum(h['time'] for h in cantor_history)
#gelu_time = sum(h['time'] for h in gelu_history)

print(f"\n{'Activation':<20} {'Best Acc':<12} {'Time':<12}")
print("-"*44)
print(f"{'CantorGELU':<20} {cantor_best:.2f}%{'':<6} {cantor_time:.1f}s")
#print(f"{'GELU':<20} {gelu_best:.2f}%{'':<6} {gelu_time:.1f}s")
#print(f"\nDiff: {cantor_best - gelu_best:+.2f}% | Overhead: {cantor_time/gelu_time:.2f}x")
print(f"Learned strength: {torch.sigmoid(cantor_model.act.strength).item():.3f}")

Device: cuda
CantorGELU stairs: 8
Train: 60,000 samples
Test:  10,000 samples

MNIST CLASSIFIER - CANTOR GELU vs GELU

Training with CANTOR
Parameters: 423,243
E01 | Train: 0.1155 / 96.93% | Test: 0.4026 / 95.63% * | 8.3s
E02 | Train: 0.0466 / 98.62% | Test: 0.2316 / 97.83% * | 8.2s
E03 | Train: 0.0324 / 98.98% | Test: 0.2443 / 95.51% | 8.0s
E04 | Train: 0.0253 / 99.22% | Test: 0.1452 / 98.11% * | 8.5s
E05 | Train: 0.0173 / 99.44% | Test: 0.1666 / 96.55% | 8.5s
E06 | Train: 0.0140 / 99.54% | Test: 0.1389 / 97.29% | 8.6s
E07 | Train: 0.0090 / 99.72% | Test: 0.1081 / 98.00% | 8.1s
E08 | Train: 0.0059 / 99.84% | Test: 0.1277 / 97.44% | 8.5s
E09 | Train: 0.0039 / 99.91% | Test: 0.0920 / 98.34% * | 8.5s
E10 | Train: 0.0032 / 99.94% | Test: 0.1060 / 98.03% | 8.2s
Best: 98.34% | Total: 83.2s

SUMMARY

Activation           Best Acc     Time        
--------------------------------------------
CantorGELU           98.34%       83.2s
Learned strength: 0.525


# dropouts MNIST

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

# =============================================================================
# TOPOLOGICAL DROPOUT VARIANTS
# =============================================================================

class TopologicalDropout(nn.Module):
    """
    Structure-preserving dropout: drops entire routes/channels, not individual neurons.

    For CNNs: treats channels as routes (drops entire feature maps)
    For attention: treats heads/routes as units

    Key insight: Preserves internal structure of surviving routes.
    Standard dropout: Random holes everywhere → broken features
    Topo dropout: Some features fully on, others fully off → intact features
    """
    def __init__(self, drop_prob: float = 0.1, min_keep: int = 1, route_dim: int = 1):
        super().__init__()
        self.drop_prob = drop_prob
        self.min_keep = min_keep
        self.route_dim = route_dim  # Which dim contains "routes" (channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0:
            return x

        num_routes = x.shape[self.route_dim]
        num_keep = max(self.min_keep, int(num_routes * (1 - self.drop_prob)))

        # Random mask
        mask = torch.zeros(num_routes, device=x.device)
        perm = torch.randperm(num_routes, device=x.device)[:num_keep]
        mask[perm] = 1.0

        # Scale to preserve expected value
        mask = mask * (num_routes / num_keep)

        # Reshape for broadcast
        shape = [1] * x.dim()
        shape[self.route_dim] = num_routes

        return x * mask.view(shape)


class ImportanceTopologicalDropout(nn.Module):
    """
    Topo dropout with importance weighting.
    Less important routes more likely to be dropped.
    """
    def __init__(self, drop_prob: float = 0.1, min_keep: int = 1, route_dim: int = 1):
        super().__init__()
        self.drop_prob = drop_prob
        self.min_keep = min_keep
        self.route_dim = route_dim

    def forward(self, x: torch.Tensor, importance: torch.Tensor = None) -> torch.Tensor:
        if not self.training or self.drop_prob == 0:
            return x

        num_routes = x.shape[self.route_dim]
        num_keep = max(self.min_keep, int(num_routes * (1 - self.drop_prob)))

        if importance is None:
            # Use activation magnitude as importance proxy
            # Reduce all dims except route_dim
            reduce_dims = [i for i in range(x.dim()) if i != self.route_dim]
            importance = x.abs().mean(dim=reduce_dims)

        # Add noise to importance for stochasticity
        noise = torch.rand_like(importance) * 0.3
        scores = importance + noise

        # Keep top-k by importance
        _, keep_idx = scores.topk(num_keep)
        mask = torch.zeros(num_routes, device=x.device)
        mask[keep_idx] = 1.0

        # Scale
        mask = mask * (num_routes / num_keep)

        shape = [1] * x.dim()
        shape[self.route_dim] = num_routes

        return x * mask.view(shape)


class ScheduledTopologicalDropout(nn.Module):
    """
    Topo dropout with warmup schedule.
    Starts mild, increases over training.
    """
    def __init__(self, drop_prob: float = 0.2, min_keep: int = 1, route_dim: int = 1,
                 warmup_steps: int = 1000):
        super().__init__()
        self.target_drop_prob = drop_prob
        self.min_keep = min_keep
        self.route_dim = route_dim
        self.warmup_steps = warmup_steps
        self.register_buffer('step', torch.tensor(0))

    @property
    def current_drop_prob(self):
        progress = min(1.0, self.step.item() / self.warmup_steps)
        return self.target_drop_prob * progress

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            self.step += 1

        if not self.training or self.current_drop_prob == 0:
            return x

        num_routes = x.shape[self.route_dim]
        num_keep = max(self.min_keep, int(num_routes * (1 - self.current_drop_prob)))

        mask = torch.zeros(num_routes, device=x.device)
        perm = torch.randperm(num_routes, device=x.device)[:num_keep]
        mask[perm] = 1.0
        mask = mask * (num_routes / num_keep)

        shape = [1] * x.dim()
        shape[self.route_dim] = num_routes

        return x * mask.view(shape)


class SpatialTopologicalDropout(nn.Module):
    """
    For 2D feature maps: drops entire spatial regions (patches).
    Complements channel-wise topo dropout.
    """
    def __init__(self, drop_prob: float = 0.1, patch_size: int = 2):
        super().__init__()
        self.drop_prob = drop_prob
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0:
            return x

        B, C, H, W = x.shape
        pH, pW = H // self.patch_size, W // self.patch_size

        if pH == 0 or pW == 0:
            return x

        # Create patch mask
        mask = (torch.rand(B, 1, pH, pW, device=x.device) > self.drop_prob).float()

        # Scale surviving patches
        keep_ratio = mask.mean()
        if keep_ratio > 0:
            mask = mask / keep_ratio

        # Upsample mask to full resolution
        mask = F.interpolate(mask, size=(H, W), mode='nearest')

        return x * mask


# =============================================================================
# CONFIG
# =============================================================================

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128
EPOCHS = 10
LR = 1e-3

print(f"Device: {DEVICE}")

# =============================================================================
# DATA
# =============================================================================

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print(f"Train: {len(train_dataset):,} | Test: {len(test_dataset):,}")

# =============================================================================
# MODEL
# =============================================================================

class MNISTClassifier(nn.Module):
    """
    CNN with configurable dropout strategy.

    dropout_type options:
    - 'standard': nn.Dropout (element-wise)
    - 'topo': TopologicalDropout (channel-wise)
    - 'topo_importance': ImportanceTopologicalDropout
    - 'topo_scheduled': ScheduledTopologicalDropout
    - 'spatial': SpatialTopologicalDropout
    - 'topo_spatial': Both channel and spatial topo dropout
    - 'none': No dropout
    """
    def __init__(self, dropout_type='topo', drop_prob=0.2):
        super().__init__()
        self.dropout_type = dropout_type

        # Conv layers
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)

        # FC layers
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

        # Batch norms
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.bn_fc2 = nn.BatchNorm1d(128)

        # Dropout variants
        if dropout_type == 'standard':
            self.drop_conv = nn.Dropout2d(drop_prob)
            self.drop_fc = nn.Dropout(drop_prob)
            self.spatial_drop = None
        elif dropout_type == 'topo':
            self.drop_conv = TopologicalDropout(drop_prob, min_keep=4, route_dim=1)
            self.drop_fc = TopologicalDropout(drop_prob, min_keep=16, route_dim=1)
            self.spatial_drop = None
        elif dropout_type == 'topo_importance':
            self.drop_conv = ImportanceTopologicalDropout(drop_prob, min_keep=4, route_dim=1)
            self.drop_fc = ImportanceTopologicalDropout(drop_prob, min_keep=16, route_dim=1)
            self.spatial_drop = None
        elif dropout_type == 'topo_scheduled':
            self.drop_conv = ScheduledTopologicalDropout(drop_prob, min_keep=4, route_dim=1, warmup_steps=500)
            self.drop_fc = ScheduledTopologicalDropout(drop_prob, min_keep=16, route_dim=1, warmup_steps=500)
            self.spatial_drop = None
        elif dropout_type == 'spatial':
            self.drop_conv = SpatialTopologicalDropout(drop_prob, patch_size=2)
            self.drop_fc = nn.Dropout(drop_prob)
            self.spatial_drop = None
        elif dropout_type == 'topo_spatial':
            self.drop_conv = TopologicalDropout(drop_prob * 0.5, min_keep=4, route_dim=1)
            self.drop_fc = TopologicalDropout(drop_prob, min_keep=16, route_dim=1)
            self.spatial_drop = SpatialTopologicalDropout(drop_prob * 0.5, patch_size=2)
        elif dropout_type == 'none':
            self.drop_conv = nn.Identity()
            self.drop_fc = nn.Identity()
            self.spatial_drop = None
        else:
            raise ValueError(f"Unknown dropout_type: {dropout_type}")

        self.pool = nn.MaxPool2d(2, 2)
        self.act = nn.GELU()

    def forward(self, x):
        # Conv block 1: 28x28 -> 14x14
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.drop_conv(x)
        if self.spatial_drop:
            x = self.spatial_drop(x)
        x = self.pool(x)

        # Conv block 2: 14x14 -> 7x7
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)
        x = self.drop_conv(x)
        if self.spatial_drop:
            x = self.spatial_drop(x)
        x = self.pool(x)

        # Conv block 3: 7x7 -> 3x3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act(x)
        x = self.drop_conv(x)
        x = self.pool(x)

        # Flatten
        x = x.view(x.size(0), -1)

        # FC layers
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.act(x)
        x = self.drop_fc(x)

        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.act(x)
        x = self.drop_fc(x)

        x = self.fc3(x)
        return x


# =============================================================================
# TRAINING
# =============================================================================

def train_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

    return total_loss / len(loader), 100. * correct / total


def evaluate(model, loader):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            total_loss += F.cross_entropy(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return total_loss / len(loader), 100. * correct / total


def train_model(dropout_type='topo', drop_prob=0.2, verbose=True):
    model = MNISTClassifier(dropout_type=dropout_type, drop_prob=drop_prob).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Training: {dropout_type} (p={drop_prob})")
        print(f"{'='*60}")

    best_acc = 0
    history = []
    start_time = time.time()

    for epoch in range(1, EPOCHS + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(model, train_loader, optimizer)
        test_loss, test_acc = evaluate(model, test_loader)
        scheduler.step()

        epoch_time = time.time() - epoch_start
        best_marker = " *" if test_acc > best_acc else ""
        if test_acc > best_acc:
            best_acc = test_acc

        if verbose:
            print(f"E{epoch:02d} | Train: {train_loss:.4f} / {train_acc:.2f}% | "
                  f"Test: {test_loss:.4f} / {test_acc:.2f}%{best_marker} | {epoch_time:.1f}s")

        history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc,
                       'test_loss': test_loss, 'test_acc': test_acc, 'time': epoch_time})

    total_time = time.time() - start_time

    if verbose:
        print(f"Best: {best_acc:.2f}% | Total: {total_time:.1f}s")

    return model, history, best_acc, total_time


# =============================================================================
# RUN EXPERIMENTS
# =============================================================================

print("\n" + "="*70)
print("TOPOLOGICAL DROPOUT EXPERIMENT - MNIST")
print("="*70)

results = {}

# Test each dropout variant
dropout_configs = [
    ('none', 0.0),
    ('standard', 0.2),
    ('topo', 0.2),
    ('topo_importance', 0.2),
    ('topo_scheduled', 0.2),
    ('spatial', 0.2),
    ('topo_spatial', 0.2),
]

for dropout_type, drop_prob in dropout_configs:
    model, history, best_acc, total_time = train_model(dropout_type, drop_prob)
    results[dropout_type] = {
        'best_acc': best_acc,
        'final_acc': history[-1]['test_acc'],
        'time': total_time,
        'history': history
    }

# =============================================================================
# SUMMARY
# =============================================================================

print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)

print(f"\n{'Dropout Type':<20} {'Best Acc':<12} {'Final Acc':<12} {'Time':<10} {'Train-Test Gap':<15}")
print("-"*70)

# Sort by best accuracy
sorted_results = sorted(results.items(), key=lambda x: x[1]['best_acc'], reverse=True)

for dropout_type, data in sorted_results:
    final_train = data['history'][-1]['train_acc']
    final_test = data['history'][-1]['test_acc']
    gap = final_train - final_test
    print(f"{dropout_type:<20} {data['best_acc']:<12.2f}% {data['final_acc']:<12.2f}% "
          f"{data['time']:<10.1f}s {gap:<15.2f}%")

# Best vs baseline
baseline = results['standard']['best_acc']
print(f"\n--- vs Standard Dropout ---")
for dropout_type, data in sorted_results:
    if dropout_type != 'standard':
        diff = data['best_acc'] - baseline
        print(f"{dropout_type:<20} {diff:+.2f}%")

# =============================================================================
# GENERALIZATION ANALYSIS
# =============================================================================

print("\n" + "="*70)
print("GENERALIZATION ANALYSIS (Train-Test Gap Over Time)")
print("="*70)

print(f"\n{'Dropout Type':<20} {'E1 Gap':<10} {'E5 Gap':<10} {'E10 Gap':<10} {'Trend':<10}")
print("-"*60)

for dropout_type, data in sorted_results:
    h = data['history']
    gap_1 = h[0]['train_acc'] - h[0]['test_acc']
    gap_5 = h[4]['train_acc'] - h[4]['test_acc']
    gap_10 = h[9]['train_acc'] - h[9]['test_acc']

    if gap_10 < gap_1:
        trend = "↓ good"
    elif gap_10 > gap_1 + 1:
        trend = "↑ overfit"
    else:
        trend = "→ stable"

    print(f"{dropout_type:<20} {gap_1:<10.2f}% {gap_5:<10.2f}% {gap_10:<10.2f}% {trend:<10}")

Device: cuda
Train: 60,000 | Test: 10,000

TOPOLOGICAL DROPOUT EXPERIMENT - MNIST

Training: none (p=0.0)
E01 | Train: 0.1001 / 97.71% | Test: 0.0351 / 98.95% * | 8.2s
E02 | Train: 0.0311 / 99.08% | Test: 0.0281 / 99.14% * | 8.4s
E03 | Train: 0.0191 / 99.41% | Test: 0.0295 / 99.01% | 8.6s
E04 | Train: 0.0112 / 99.67% | Test: 0.0244 / 99.22% * | 8.3s
E05 | Train: 0.0071 / 99.79% | Test: 0.0196 / 99.31% * | 8.4s
E06 | Train: 0.0043 / 99.88% | Test: 0.0252 / 99.23% | 8.2s
E07 | Train: 0.0021 / 99.95% | Test: 0.0183 / 99.40% * | 8.8s
E08 | Train: 0.0009 / 99.99% | Test: 0.0183 / 99.37% | 8.0s
E09 | Train: 0.0005 / 100.00% | Test: 0.0182 / 99.36% | 8.1s
E10 | Train: 0.0004 / 100.00% | Test: 0.0176 / 99.39% | 8.3s
Best: 99.40% | Total: 83.3s

Training: standard (p=0.2)
E01 | Train: 0.1947 / 95.04% | Test: 0.0356 / 98.79% * | 8.1s
E02 | Train: 0.0562 / 98.29% | Test: 0.0245 / 99.12% * | 8.2s
E03 | Train: 0.0408 / 98.79% | Test: 0.0206 / 99.34% * | 8.1s
E04 | Train: 0.0328 / 98.99% | Test: 0.0

# dropouts fashionmnist

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

# =============================================================================
# TOPOLOGICAL DROPOUT VARIANTS
# =============================================================================

class TopologicalDropout(nn.Module):
    """
    Structure-preserving dropout: drops entire routes/channels, not individual neurons.

    For CNNs: treats channels as routes (drops entire feature maps)
    For attention: treats heads/routes as units

    Key insight: Preserves internal structure of surviving routes.
    Standard dropout: Random holes everywhere → broken features
    Topo dropout: Some features fully on, others fully off → intact features
    """
    def __init__(self, drop_prob: float = 0.1, min_keep: int = 1, route_dim: int = 1):
        super().__init__()
        self.drop_prob = drop_prob
        self.min_keep = min_keep
        self.route_dim = route_dim  # Which dim contains "routes" (channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0:
            return x

        num_routes = x.shape[self.route_dim]
        num_keep = max(self.min_keep, int(num_routes * (1 - self.drop_prob)))

        # Random mask
        mask = torch.zeros(num_routes, device=x.device)
        perm = torch.randperm(num_routes, device=x.device)[:num_keep]
        mask[perm] = 1.0

        # Scale to preserve expected value
        mask = mask * (num_routes / num_keep)

        # Reshape for broadcast
        shape = [1] * x.dim()
        shape[self.route_dim] = num_routes

        return x * mask.view(shape)


class ImportanceTopologicalDropout(nn.Module):
    """
    Topo dropout with importance weighting.
    Less important routes more likely to be dropped.
    """
    def __init__(self, drop_prob: float = 0.1, min_keep: int = 1, route_dim: int = 1):
        super().__init__()
        self.drop_prob = drop_prob
        self.min_keep = min_keep
        self.route_dim = route_dim

    def forward(self, x: torch.Tensor, importance: torch.Tensor = None) -> torch.Tensor:
        if not self.training or self.drop_prob == 0:
            return x

        num_routes = x.shape[self.route_dim]
        num_keep = max(self.min_keep, int(num_routes * (1 - self.drop_prob)))

        if importance is None:
            # Use activation magnitude as importance proxy
            # Reduce all dims except route_dim
            reduce_dims = [i for i in range(x.dim()) if i != self.route_dim]
            importance = x.abs().mean(dim=reduce_dims)

        # Add noise to importance for stochasticity
        noise = torch.rand_like(importance) * 0.3
        scores = importance + noise

        # Keep top-k by importance
        _, keep_idx = scores.topk(num_keep)
        mask = torch.zeros(num_routes, device=x.device)
        mask[keep_idx] = 1.0

        # Scale
        mask = mask * (num_routes / num_keep)

        shape = [1] * x.dim()
        shape[self.route_dim] = num_routes

        return x * mask.view(shape)


class ScheduledTopologicalDropout(nn.Module):
    """
    Topo dropout with warmup schedule.
    Starts mild, increases over training.
    """
    def __init__(self, drop_prob: float = 0.2, min_keep: int = 1, route_dim: int = 1,
                 warmup_steps: int = 1000):
        super().__init__()
        self.target_drop_prob = drop_prob
        self.min_keep = min_keep
        self.route_dim = route_dim
        self.warmup_steps = warmup_steps
        self.register_buffer('step', torch.tensor(0))

    @property
    def current_drop_prob(self):
        progress = min(1.0, self.step.item() / self.warmup_steps)
        return self.target_drop_prob * progress

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            self.step += 1

        if not self.training or self.current_drop_prob == 0:
            return x

        num_routes = x.shape[self.route_dim]
        num_keep = max(self.min_keep, int(num_routes * (1 - self.current_drop_prob)))

        mask = torch.zeros(num_routes, device=x.device)
        perm = torch.randperm(num_routes, device=x.device)[:num_keep]
        mask[perm] = 1.0
        mask = mask * (num_routes / num_keep)

        shape = [1] * x.dim()
        shape[self.route_dim] = num_routes

        return x * mask.view(shape)


class SpatialTopologicalDropout(nn.Module):
    """
    For 2D feature maps: drops entire spatial regions (patches).
    Complements channel-wise topo dropout.
    """
    def __init__(self, drop_prob: float = 0.1, patch_size: int = 2):
        super().__init__()
        self.drop_prob = drop_prob
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0:
            return x

        B, C, H, W = x.shape
        pH, pW = H // self.patch_size, W // self.patch_size

        if pH == 0 or pW == 0:
            return x

        # Create patch mask
        mask = (torch.rand(B, 1, pH, pW, device=x.device) > self.drop_prob).float()

        # Scale surviving patches
        keep_ratio = mask.mean()
        if keep_ratio > 0:
            mask = mask / keep_ratio

        # Upsample mask to full resolution
        mask = F.interpolate(mask, size=(H, W), mode='nearest')

        return x * mask


# =============================================================================
# CONFIG
# =============================================================================

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128
EPOCHS = 10
LR = 1e-3

print(f"Device: {DEVICE}")

# =============================================================================
# DATA
# =============================================================================

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print(f"Train: {len(train_dataset):,} | Test: {len(test_dataset):,}")

# =============================================================================
# MODEL
# =============================================================================

class MNISTClassifier(nn.Module):
    """
    CNN with configurable dropout strategy.

    dropout_type options:
    - 'standard': nn.Dropout (element-wise)
    - 'topo': TopologicalDropout (channel-wise)
    - 'topo_importance': ImportanceTopologicalDropout
    - 'topo_scheduled': ScheduledTopologicalDropout
    - 'spatial': SpatialTopologicalDropout
    - 'topo_spatial': Both channel and spatial topo dropout
    - 'none': No dropout
    """
    def __init__(self, dropout_type='topo', drop_prob=0.2):
        super().__init__()
        self.dropout_type = dropout_type

        # Conv layers
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)

        # FC layers
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

        # Batch norms
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn_fc1 = nn.BatchNorm1d(256)
        self.bn_fc2 = nn.BatchNorm1d(128)

        # Dropout variants
        if dropout_type == 'standard':
            self.drop_conv = nn.Dropout2d(drop_prob)
            self.drop_fc = nn.Dropout(drop_prob)
            self.spatial_drop = None
        elif dropout_type == 'topo':
            self.drop_conv = TopologicalDropout(drop_prob, min_keep=4, route_dim=1)
            self.drop_fc = TopologicalDropout(drop_prob, min_keep=16, route_dim=1)
            self.spatial_drop = None
        elif dropout_type == 'topo_importance':
            self.drop_conv = ImportanceTopologicalDropout(drop_prob, min_keep=4, route_dim=1)
            self.drop_fc = ImportanceTopologicalDropout(drop_prob, min_keep=16, route_dim=1)
            self.spatial_drop = None
        elif dropout_type == 'topo_scheduled':
            self.drop_conv = ScheduledTopologicalDropout(drop_prob, min_keep=4, route_dim=1, warmup_steps=500)
            self.drop_fc = ScheduledTopologicalDropout(drop_prob, min_keep=16, route_dim=1, warmup_steps=500)
            self.spatial_drop = None
        elif dropout_type == 'spatial':
            self.drop_conv = SpatialTopologicalDropout(drop_prob, patch_size=2)
            self.drop_fc = nn.Dropout(drop_prob)
            self.spatial_drop = None
        elif dropout_type == 'topo_spatial':
            self.drop_conv = TopologicalDropout(drop_prob * 0.5, min_keep=4, route_dim=1)
            self.drop_fc = TopologicalDropout(drop_prob, min_keep=16, route_dim=1)
            self.spatial_drop = SpatialTopologicalDropout(drop_prob * 0.5, patch_size=2)
        elif dropout_type == 'none':
            self.drop_conv = nn.Identity()
            self.drop_fc = nn.Identity()
            self.spatial_drop = None
        else:
            raise ValueError(f"Unknown dropout_type: {dropout_type}")

        self.pool = nn.MaxPool2d(2, 2)
        self.act = nn.GELU()

    def forward(self, x):
        # Conv block 1: 28x28 -> 14x14
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.drop_conv(x)
        if self.spatial_drop:
            x = self.spatial_drop(x)
        x = self.pool(x)

        # Conv block 2: 14x14 -> 7x7
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)
        x = self.drop_conv(x)
        if self.spatial_drop:
            x = self.spatial_drop(x)
        x = self.pool(x)

        # Conv block 3: 7x7 -> 3x3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act(x)
        x = self.drop_conv(x)
        x = self.pool(x)

        # Flatten
        x = x.view(x.size(0), -1)

        # FC layers
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = self.act(x)
        x = self.drop_fc(x)

        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = self.act(x)
        x = self.drop_fc(x)

        x = self.fc3(x)
        return x


# =============================================================================
# TRAINING
# =============================================================================

def train_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

    return total_loss / len(loader), 100. * correct / total


def evaluate(model, loader):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            total_loss += F.cross_entropy(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return total_loss / len(loader), 100. * correct / total


def train_model(dropout_type='topo', drop_prob=0.2, verbose=True):
    model = MNISTClassifier(dropout_type=dropout_type, drop_prob=drop_prob).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Training: {dropout_type} (p={drop_prob})")
        print(f"{'='*60}")

    best_acc = 0
    history = []
    start_time = time.time()

    for epoch in range(1, EPOCHS + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(model, train_loader, optimizer)
        test_loss, test_acc = evaluate(model, test_loader)
        scheduler.step()

        epoch_time = time.time() - epoch_start
        best_marker = " *" if test_acc > best_acc else ""
        if test_acc > best_acc:
            best_acc = test_acc

        if verbose:
            print(f"E{epoch:02d} | Train: {train_loss:.4f} / {train_acc:.2f}% | "
                  f"Test: {test_loss:.4f} / {test_acc:.2f}%{best_marker} | {epoch_time:.1f}s")

        history.append({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc,
                       'test_loss': test_loss, 'test_acc': test_acc, 'time': epoch_time})

    total_time = time.time() - start_time

    if verbose:
        print(f"Best: {best_acc:.2f}% | Total: {total_time:.1f}s")

    return model, history, best_acc, total_time


# =============================================================================
# RUN EXPERIMENTS
# =============================================================================

print("\n" + "="*70)
print("TOPOLOGICAL DROPOUT EXPERIMENT 2 - MNIST Fashion")
print("="*70)

results = {}

# Test each dropout variant
dropout_configs = [
    ('none', 0.0),
    ('standard', 0.2),
    ('topo', 0.2),
    ('topo_importance', 0.2),
    ('topo_scheduled', 0.2),
    ('spatial', 0.2),
    ('topo_spatial', 0.2),
]

for dropout_type, drop_prob in dropout_configs:
    model, history, best_acc, total_time = train_model(dropout_type, drop_prob)
    results[dropout_type] = {
        'best_acc': best_acc,
        'final_acc': history[-1]['test_acc'],
        'time': total_time,
        'history': history
    }

# =============================================================================
# SUMMARY
# =============================================================================

print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)

print(f"\n{'Dropout Type':<20} {'Best Acc':<12} {'Final Acc':<12} {'Time':<10} {'Train-Test Gap':<15}")
print("-"*70)

# Sort by best accuracy
sorted_results = sorted(results.items(), key=lambda x: x[1]['best_acc'], reverse=True)

for dropout_type, data in sorted_results:
    final_train = data['history'][-1]['train_acc']
    final_test = data['history'][-1]['test_acc']
    gap = final_train - final_test
    print(f"{dropout_type:<20} {data['best_acc']:<12.2f}% {data['final_acc']:<12.2f}% "
          f"{data['time']:<10.1f}s {gap:<15.2f}%")

# Best vs baseline
baseline = results['standard']['best_acc']
print(f"\n--- vs Standard Dropout ---")
for dropout_type, data in sorted_results:
    if dropout_type != 'standard':
        diff = data['best_acc'] - baseline
        print(f"{dropout_type:<20} {diff:+.2f}%")

# =============================================================================
# GENERALIZATION ANALYSIS
# =============================================================================

print("\n" + "="*70)
print("GENERALIZATION ANALYSIS (Train-Test Gap Over Time)")
print("="*70)

print(f"\n{'Dropout Type':<20} {'E1 Gap':<10} {'E5 Gap':<10} {'E10 Gap':<10} {'Trend':<10}")
print("-"*60)

for dropout_type, data in sorted_results:
    h = data['history']
    gap_1 = h[0]['train_acc'] - h[0]['test_acc']
    gap_5 = h[4]['train_acc'] - h[4]['test_acc']
    gap_10 = h[9]['train_acc'] - h[9]['test_acc']

    if gap_10 < gap_1:
        trend = "↓ good"
    elif gap_10 > gap_1 + 1:
        trend = "↑ overfit"
    else:
        trend = "→ stable"

    print(f"{dropout_type:<20} {gap_1:<10.2f}% {gap_5:<10.2f}% {gap_10:<10.2f}% {trend:<10}")

Device: cuda


100%|██████████| 26.4M/26.4M [00:02<00:00, 9.71MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 185kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.57MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 17.5MB/s]

Train: 60,000 | Test: 10,000

TOPOLOGICAL DROPOUT EXPERIMENT - MNIST

Training: none (p=0.0)





E01 | Train: 0.3570 / 87.82% | Test: 0.2878 / 89.31% * | 8.2s
E02 | Train: 0.2209 / 91.86% | Test: 0.2457 / 91.27% * | 8.4s
E03 | Train: 0.1792 / 93.42% | Test: 0.2370 / 91.34% * | 8.6s
E04 | Train: 0.1427 / 94.78% | Test: 0.2227 / 92.19% * | 8.6s
E05 | Train: 0.1093 / 95.96% | Test: 0.2426 / 92.14% | 8.4s
E06 | Train: 0.0787 / 97.15% | Test: 0.2372 / 92.68% * | 8.3s
E07 | Train: 0.0484 / 98.33% | Test: 0.2590 / 92.75% * | 8.0s
E08 | Train: 0.0272 / 99.15% | Test: 0.2797 / 92.46% | 8.0s
E09 | Train: 0.0141 / 99.65% | Test: 0.2785 / 93.01% * | 8.5s
E10 | Train: 0.0090 / 99.87% | Test: 0.2821 / 93.06% * | 8.1s
Best: 93.06% | Total: 83.1s

Training: standard (p=0.2)
E01 | Train: 0.4817 / 83.29% | Test: 0.3101 / 88.74% * | 8.7s
E02 | Train: 0.3107 / 88.54% | Test: 0.2800 / 89.34% * | 8.2s
E03 | Train: 0.2680 / 90.15% | Test: 0.2490 / 90.74% * | 8.7s
E04 | Train: 0.2445 / 90.94% | Test: 0.2437 / 91.02% * | 8.0s
E05 | Train: 0.2231 / 91.78% | Test: 0.2218 / 91.81% * | 8.2s
E06 | Train: 0.206