In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split

In [2]:
# Simulated LayerSelect. We can choose whether we want a layer to be active
# or not by toggling the active parameter. 
class LayerSelect(nn.Module):
    """Wrapper that either applies a wrapped layer or skips it entirely.

    Attributes are plain Python values, so ``active`` can be flipped at any
    time between forward passes.

    Args:
        layer: The module to run when the wrapper is active.
        active: When True the input is passed through ``layer``; when False
            the input is returned untouched.
    """

    def __init__(self, layer, active=True):
        super().__init__()
        self.layer = layer
        self.active = active

    def forward(self, x):
        """Return ``layer(x)`` if active, otherwise ``x`` itself (same object)."""
        # Guard clause: a disabled layer is a pure pass-through.
        if not self.active:
            return x
        return self.layer(x)

In [3]:
# Simulated SubnetNorm
class SubnetNorm2d(nn.Module):
    """Per-configuration normalisation using precomputed channel statistics.

    Each sub-network configuration owns its own ``{"mean", "var"}`` entry in
    ``self.stats``. The layer has no learnable parameters: when statistics
    exist for the active configuration the input is standardised per channel,
    otherwise the input is returned unchanged.

    Args:
        num_channels: Nominal channel count of the tensor being normalised
            (stored for reference only; the forward pass slices the stored
            statistics to the actual channel count of the input).
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        # config_id -> {"mean": tensor, "var": tensor}
        # NOTE: this is a plain dict, not a registered buffer, so Module.to()
        # and state_dict() do not see these tensors.
        self.stats = {}
        self.active_config = None

    def set_active(self, config_id):
        """Select which configuration's statistics ``forward`` should use."""
        self.active_config = config_id

    def forward(self, x):
        """Standardise ``x`` along dim 1 with the active configuration's stats.

        Returns ``x`` itself when no statistics are stored for the active
        configuration.
        """
        stats = self.stats.get(self.active_config)
        if stats is None:
            # No stats yet for this config: behave as identity.
            return x

        channels = x.shape[1]
        # The stats live outside the module's parameter/buffer registry, so
        # they stay on whatever device they were computed on. Move them to the
        # input's device here; otherwise calibrating on one device and running
        # inference on another raises a device-mismatch error.
        mean = stats["mean"][:channels].to(x.device)
        var = stats["var"][:channels].to(x.device)
        return (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)

In [4]:
# Simulated ResNet blocks. Has a convolution layer followed by a batchnorm layer followed by
# another convolution layer followed by another batchnorm layer. 
class BasicBlock(nn.Module):
    """Residual block: conv -> SubnetNorm -> ReLU -> conv -> SubnetNorm, plus skip.

    The hidden width between the two 3x3 convolutions is scaled by
    ``width_mult``; the block's input and output widths are not affected.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the block output.
        stride: Stride of the first convolution (and of the skip projection).
        width_mult: Multiplier applied to ``out_channels`` to obtain the
            hidden channel count.
    """

    def __init__(self, in_channels, out_channels, stride=1, width_mult=1.0):
        super().__init__()

        # Hidden width after scaling; floor at 1 so a very small width_mult
        # cannot truncate to a zero-channel hidden layer.
        mid_channels = max(1, int(out_channels * width_mult))
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        # bn1 normalises conv1's output, which has mid_channels channels
        # (previously declared with out_channels, mismatching conv1).
        self.bn1 = SubnetNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = SubnetNorm2d(out_channels)

        # Skip path: identity when shapes already match, otherwise a 1x1
        # projection so the residual sum has matching shape.
        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the two-conv residual branch and add the (projected) input."""
        identity = self.downsample(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + identity)

In [5]:
# Simulated SuperNet. 
# We represent LayerSelect by choosing which BasicBlocks we want to be active.
class MiniSuperNet(nn.Module):
    """Small ResNet-style super-network with selectable depth and width.

    Four residual blocks are always constructed; only the first ``depth`` of
    them are marked active through their ``LayerSelect`` wrappers.

    Args:
        depth: Number of leading residual blocks that are active.
        width_mult: Hidden-width multiplier forwarded to every ``BasicBlock``.
        num_classes: Size of the classifier output (defaults to 10).
    """

    def __init__(self, depth=3, width_mult=1.0, num_classes=10):
        super().__init__()
        self.width_mult = width_mult
        self.depth = depth

        # initial stem definition
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            SubnetNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # hardcoded 4 residual blocks that we will dynamically activate
        self.blocks = nn.ModuleList()
        in_out_channels = [(64, 64), (64, 128), (128, 128), (128, 256)]
        feature_dim = 64  # stem width; used as-is when no block is active
        for i, (in_c, out_c) in enumerate(in_out_channels):
            is_active = i < depth
            block = BasicBlock(in_c, out_c, stride=2 if i > 0 else 1, width_mult=width_mult)
            self.blocks.append(LayerSelect(block, active=is_active))
            if is_active:
                # The last active block determines the pooled feature width.
                feature_dim = out_c

        # define pooling function
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Build the classifier up front. Creating it lazily inside forward()
        # means an optimizer constructed from model.parameters() before the
        # first forward pass never sees (and never trains) the classifier.
        self.fc = nn.Linear(feature_dim, num_classes)

    # finds the SubnetNorm2d layer and sets the active config to the (config id) config
    def set_active_config(self, config_id):
        """Point every ``SubnetNorm2d`` in the network at ``config_id``."""
        for layer in self.modules():
            if isinstance(layer, SubnetNorm2d):
                layer.set_active(config_id)

    def forward(self, x):
        """Run stem, residual blocks, global pooling and the classifier."""
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)
        x = self.pool(x).flatten(1)

        # Fallback for when block `active` flags were changed after
        # construction and the pooled width no longer matches the classifier.
        # The replacement layer is freshly initialised and is NOT known to any
        # optimizer that was created earlier.
        if self.fc.in_features != x.shape[1]:
            self.fc = nn.Linear(x.shape[1], self.fc.out_features).to(x.device)

        return self.fc(x)

In [11]:
def precompute_stats(model, config_id, dataloader, device="cpu", num_batches=20):
    """Collect per-channel mean/variance for every SubnetNorm2d under one config.

    The model is moved to ``device``, put in eval mode and switched to
    ``config_id``. Forward hooks on each ``SubnetNorm2d`` record the tensor fed
    into that layer for up to ``num_batches`` batches; the concatenated
    activations are then reduced over batch and spatial dims and stored in
    ``layer.stats[config_id]``.

    Args:
        model: Network exposing ``set_active_config`` and containing
            ``SubnetNorm2d`` layers.
        config_id: Key under which the statistics are stored.
        dataloader: Iterable yielding ``(images, labels)`` pairs; labels are
            ignored.
        device: Device on which calibration runs (the model is moved there).
        num_batches: Maximum number of batches to consume.

    Returns:
        None. Results are written into each layer's ``stats`` dict and a line
        is printed per layer.

    Note:
        If a layer already holds stats for ``config_id`` it normalises during
        these forward passes, so downstream layers see normalised inputs and a
        repeated call can yield different statistics than the first one.
    """
    model.to(device)
    model.eval()
    model.set_active_config(config_id)

    hooks = []
    activations = {}

    def make_hook(layer_ref):
        # Bind the layer explicitly so each hook appends to its own list.
        def hook_fn(module, inputs, output):
            # inputs[0] is the tensor entering the SubnetNorm2d layer,
            # i.e. the raw output of the preceding convolution.
            activations[layer_ref].append(inputs[0].detach())
        return hook_fn

    # Register a forward hook on every SubnetNorm2d layer to capture its input.
    for module in model.modules():
        if isinstance(module, SubnetNorm2d):
            activations[module] = []
            hooks.append(module.register_forward_hook(make_hook(module)))

    # Run forward passes; always detach the hooks afterwards, even if a batch
    # raises, so the model is not left with stale hooks attached.
    try:
        with torch.no_grad():
            for i, (images, _) in enumerate(dataloader):
                if i >= num_batches:
                    break
                model(images.to(device))
    finally:
        for h in hooks:
            h.remove()

    # Calculate statistics
    for layer, xs in activations.items():
        if len(xs) > 0:  # Layers inside inactive blocks never fire their hook
            all_x = torch.cat(xs, dim=0)
            mean = all_x.mean(dim=(0, 2, 3))
            var = all_x.var(dim=(0, 2, 3), unbiased=False)
            layer.stats[config_id] = {"mean": mean, "var": var}
            print(f"Precomputed stats for {layer} with {len(xs)} batches")
        else:
            print(f"Warning: No activations captured for layer {layer} - likely inactive for this config")

In [26]:
def train_model(model, train_loader, epochs=10, device='cpu', max_batches=200):
    """Train ``model`` with Adam and cross-entropy for a fixed number of epochs.

    Args:
        model: Network to optimise; it is moved to ``device`` in place.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of passes over (a prefix of) ``train_loader``.
        device: Device on which the model and the batches are placed.
        max_batches: Upper bound on batches consumed per epoch.

    Returns:
        None. Progress is printed every second epoch.
    """
    # The batches are moved to `device` below, so the model has to live there
    # too; otherwise training on CUDA fails with a device mismatch.
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        batches_seen = 0

        for i, (images, labels) in enumerate(train_loader):
            if i >= max_batches:
                break
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_seen += 1

            # Track accuracy during training
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if epoch % 2 == 0:
            # Average over the batches actually processed (the loader may hold
            # fewer than max_batches) and guard against an empty loader.
            avg_loss = running_loss / max(batches_seen, 1)
            train_acc = 100 * correct / max(total, 1)
            print(f'Epoch {epoch}, Loss: {avg_loss:.4f}, Train Acc: {train_acc:.1f}%')


In [18]:
def benchmark_inference(model, dataloader, config_id, device='cpu', num_batches=20, measure_accuracy=False):
    """Time forward passes of ``model`` under ``config_id`` and optionally score them.

    Args:
        model: Network exposing ``set_active_config``; moved to ``device`` and
            put in eval mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        config_id: Configuration activated before benchmarking.
        device: Device on which inference runs.
        num_batches: Maximum number of batches to time.
        measure_accuracy: When True, also compute top-1 accuracy.

    Returns:
        dict with ``"config_id"`` and ``"avg_latency_sec"`` (seconds per
        sample, ``0.0`` if no batch was processed), plus ``"accuracy"`` when
        requested and at least one sample was seen.
    """
    model.eval()
    model.to(device)
    model.set_active_config(config_id)

    target = torch.device(device)
    on_cuda = target.type == "cuda"

    total_time = 0.0
    samples_timed = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for i, (images, labels) in enumerate(dataloader):
            if i >= num_batches:
                break

            images = images.to(device)
            labels = labels.to(device)

            # CUDA kernels launch asynchronously; synchronise around the timed
            # region so the wall-clock interval covers the actual compute.
            if on_cuda:
                torch.cuda.synchronize(target)
            start_time = time.perf_counter()

            logits = model(images)

            if on_cuda:
                torch.cuda.synchronize(target)
            total_time += time.perf_counter() - start_time
            samples_timed += images.size(0)

            if measure_accuracy:
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

    # Divide by the samples actually timed: the loader may be shorter than
    # num_batches and its last batch may be smaller than the others.
    avg_latency = total_time / samples_timed if samples_timed else 0.0

    results = {
        "config_id": config_id,
        "avg_latency_sec": avg_latency
    }

    if measure_accuracy and total > 0:
        results["accuracy"] = correct / total

    return results

In [43]:
# CIFAR-10 training pipeline: light augmentation (random horizontal flip and
# rotation of up to 10 degrees) followed by conversion to tensors.
# No resizing or mean/std normalization is applied here.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# 80% of the training set is used for training; the remaining 20% is held out
# for precomputing SubnetNorm statistics.
train_size = int(0.8 * len(full_train))
precomp_size = len(full_train) - train_size

# Seed the split so the train / calibration partition is the same on every
# Restart & Run All instead of depending on the global RNG state.
train_dataset, precomp_dataset = random_split(
    full_train,
    [train_size, precomp_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
precomp_loader = DataLoader(precomp_dataset, batch_size=64, shuffle=True, num_workers=2)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One independent network per (depth, width) configuration, keyed by config id.
models = {}
models["D1_W50"] = MiniSuperNet(depth=1, width_mult=0.5)
models["D2_W75"] = MiniSuperNet(depth=2, width_mult=0.75)
models["D3_W100"] = MiniSuperNet(depth=3, width_mult=1.0)

# Fit every configuration on the training split.
for config_id, model in models.items():
    print(f"\nTraining {config_id}...")
    train_model(model, train_loader, device=device, epochs=10)
    print(f"Finished training {config_id}")

In [38]:
# Precompute SubnetNorm stats for each model on the held-out calibration split.
calib_device = "cuda" if torch.cuda.is_available() else "cpu"
for config_id, model in models.items():
    # precompute_stats moves the model to `device` itself. Pass the device
    # explicitly: relying on its default ("cpu") would pull a CUDA model back
    # to the CPU and leave the stored statistics on the CPU as well.
    precompute_stats(model, config_id, precomp_loader, device=calib_device, num_batches=10)


Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches
Precomputed stats for SubnetNorm2d() with 10 batches


In [39]:
# Held-out CIFAR-10 test split used for benchmarking: tensor conversion only,
# no augmentation, and a fixed (unshuffled) batch order.
cifar_test = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_loader = DataLoader(
    cifar_test,
    batch_size=64,
    num_workers=2,
    shuffle=False,
)


In [None]:
# Benchmark each model on the test split, collecting one result dict per config.
results = []
for config_id, model in models.items():
    res = benchmark_inference(model, test_loader, config_id=config_id, device=device, num_batches=20, measure_accuracy=True)
    results.append(res)

# Print nicely
for r in results:
    # "accuracy" is only present when it was measured; applying a float format
    # spec to the 'N/A' placeholder would raise ValueError, so format it
    # separately.
    accuracy = r.get('accuracy')
    accuracy_text = f"{accuracy:.3f}" if accuracy is not None else "N/A"
    print(f"{r['config_id']} → Latency: {r['avg_latency_sec']*1e3:.2f} ms/sample, Accuracy: {accuracy_text}")