In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

In [3]:
# Simulated LayerSelect: wraps a layer so it can be switched on or off by
# toggling the `active` attribute. When inactive, the input passes through unchanged.
class LayerSelect(nn.Module):
    """Optionally apply a wrapped layer.

    Args:
        layer: Module that ``forward`` runs while this wrapper is active.
        active: When True (the default), ``forward`` returns ``layer(x)``;
            when False, ``forward`` returns ``x`` itself, untouched.
    """

    def __init__(self, layer, active=True):
        super().__init__()
        self.layer = layer
        self.active = active

    def forward(self, x):
        # Guard clause: an inactive wrapper behaves as the identity.
        if not self.active:
            return x  # Skip layer
        return self.layer(x)

In [4]:
# Simulated SubnetNorm: normalizes with precomputed, per-configuration statistics.
class SubnetNorm2d(nn.Module):
    """Normalize activations with precomputed, per-configuration statistics.

    No statistics are computed from the batch. Each subnet configuration,
    identified by ``config_id``, has its own ``{"mean": tensor, "var": tensor}``
    entry in ``self.stats``. ``set_active`` selects the entry that ``forward``
    uses. The stored vectors may be longer than the input's channel count;
    ``forward`` uses only the first ``C`` entries, so a width-reduced subnet can
    reuse the stats of a wider layer.

    Args:
        num_channels: Channel count this layer was built for.
        eps: Added to the variance before the square root, for numerical stability.
    """

    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        # config_id -> {"mean": tensor, "var": tensor}. This is a plain dict, so
        # these tensors are not registered buffers: ``.to(device)`` does not move
        # them and they are not part of ``state_dict()``.
        self.stats = {}
        self.active_config = None

    def set_active(self, config_id):
        """Select which entry of ``self.stats`` ``forward`` will use."""
        self.active_config = config_id

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + eps)`` using the active config's stats.

        Args:
            x: Input tensor; dim 1 is the channel dim, and stats are broadcast
                as ``[None, :, None, None]`` (so a 4-D NCHW-style input).

        Raises:
            ValueError: If no configuration is active, or the stored stats
                cover fewer channels than the input has.
            KeyError: If the active configuration has no stored stats.
        """
        if self.active_config is None:
            raise ValueError("Active configuration not set for SubnetNorm2d")
        if self.active_config not in self.stats:
            raise KeyError(
                f"No stats stored for config {self.active_config!r} in SubnetNorm2d"
            )

        stats = self.stats[self.active_config]
        C = x.shape[1]
        available = min(stats["mean"].shape[0], stats["var"].shape[0])
        if available < C:
            # Without this check, a short vector is sliced silently. A length-1
            # vector would even broadcast across every channel with no error.
            raise ValueError(
                f"Stats for config {self.active_config!r} cover {available} "
                f"channels, but input has {C}"
            )

        # The stats live in a plain dict, so they stay on whatever device they
        # were created on. Move them to the input's device here.
        mean = stats["mean"][:C].to(x.device)
        var = stats["var"][:C].to(x.device)

        return (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)

In [5]:
# Simulated ResNet basic block: conv -> SubnetNorm -> ReLU -> conv -> SubnetNorm.
# The result is added to a (possibly downsampled) shortcut and passed through a final ReLU.
class BasicBlock(nn.Module):
    """ResNet-style basic residual block with a width-scalable middle layer.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + downsample(x))``.
    ``conv1`` produces ``int(out_channels * width_mult)`` channels, and
    ``conv2`` maps them back to ``out_channels`` so the residual sum lines up
    with the shortcut.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the block output.
        stride: Stride of ``conv1`` and of the projection shortcut.
        width_mult: Multiplier for the hidden (middle) channel count.

    Raises:
        ValueError: If ``width_mult`` yields fewer than one hidden channel.
    """

    def __init__(self, in_channels, out_channels, stride=1, width_mult=1.0):
        super().__init__()

        # mid_channels is conv1's width-scaled output (simulating WeightSlice).
        mid_channels = int(out_channels * width_mult)
        if mid_channels < 1:
            raise ValueError(
                f"width_mult={width_mult} gives {mid_channels} hidden channels "
                f"for out_channels={out_channels}; need at least 1"
            )
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        # bn1 normalizes conv1's output, so it must be sized by mid_channels.
        # Sizing it by out_channels breaks as soon as width_mult > 1.
        self.bn1 = SubnetNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = SubnetNorm2d(out_channels)

        # Projection shortcut when the spatial size or channel count changes.
        # Otherwise an empty Sequential acts as the identity.
        # NOTE(review): this uses nn.BatchNorm2d, not SubnetNorm2d, so its
        # statistics are shared across all subnet configs. Confirm intended.
        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the residual block to ``x``."""
        identity = self.downsample(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + identity)

In [6]:
# Simulated SuperNet.
# LayerSelect is simulated by choosing which BasicBlocks are active (the first `depth` of them).
class MiniSuperNet(nn.Module):
    """Toy supernet: a stem, four switchable residual blocks, and a linear head.

    The stem maps 3 input channels to 64. The first ``depth`` residual blocks
    are active; the remaining ones are wrapped in inactive ``LayerSelect``
    modules and pass their input through unchanged. ``width_mult`` scales the
    hidden width of every ``BasicBlock``.

    Args:
        depth: Number of leading residual blocks to activate.
        width_mult: Hidden-width multiplier forwarded to each ``BasicBlock``.
        num_classes: Output dimension of the classification head.
    """

    def __init__(self, depth=3, width_mult=1.0, num_classes=10):
        super().__init__()
        self.width_mult = width_mult
        self.depth = depth
        self.num_classes = num_classes

        # initial stem definition
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            SubnetNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # hardcoded 4 residual blocks that we will dynamically activate
        self.blocks = nn.ModuleList()
        in_out_channels = [(64, 64), (64, 128), (128, 128), (128, 256)]
        # Channels produced by the active path: the stem's 64 until an active
        # block changes it. Inactive blocks are identities and leave it alone.
        final_channels = 64
        for i, (in_c, out_c) in enumerate(in_out_channels):
            is_active = i < depth
            block = BasicBlock(in_c, out_c, stride=2 if i > 0 else 1, width_mult=width_mult)
            self.blocks.append(LayerSelect(block, active=is_active))
            if is_active:
                final_channels = out_c

        # define pooling function
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Build the head up front so its parameters exist at construction time.
        # That way an optimizer built before the first forward sees them, and
        # .to(device) moves them.
        self.fc = nn.Linear(final_channels, num_classes)

    def set_active_config(self, config_id):
        """Set ``config_id`` as the active config on every SubnetNorm2d submodule."""
        for layer in self.modules():
            if isinstance(layer, SubnetNorm2d):
                layer.set_active(config_id)

    def forward(self, x):
        """Run stem, blocks, global average pooling and the linear head.

        Returns:
            Logits of shape (N, num_classes).
        """
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)
        # (N, C, 1, 1) -> (N, C)
        x = torch.flatten(self.pool(x), 1)

        # If blocks were toggled after construction and the feature width
        # changed, rebuild the head. The new head has fresh random weights that
        # an existing optimizer does not track.
        if self.fc.in_features != x.shape[1]:
            self.fc = nn.Linear(x.shape[1], self.num_classes).to(device=x.device, dtype=x.dtype)

        return self.fc(x)

In [7]:
# Fake statistics population. In a real setup, per-configuration mean and variance
# statistics would be collected during training and loaded here instead.
def populate_fake_stats(model, config_id, num_channels_per_layer=None, generator=None):
    """Store random mean/var stats under ``config_id`` in every SubnetNorm2d of ``model``.

    This is a placeholder for real calibration, where per-configuration
    statistics would be measured rather than sampled.

    Args:
        model: Module whose submodules are searched for SubnetNorm2d layers.
        config_id: Key under which the stats are stored in each layer's ``stats``.
        num_channels_per_layer: Optional mapping from a SubnetNorm2d module to
            the number of channels to generate for it. Layers not in the
            mapping use their own ``num_channels``. ``None`` means no overrides.
        generator: Optional ``torch.Generator`` that makes the fake stats
            reproducible. ``None`` uses the global RNG.
    """
    overrides = {} if num_channels_per_layer is None else num_channels_per_layer
    for m in model.modules():
        if isinstance(m, SubnetNorm2d):
            C = overrides.get(m, m.num_channels)
            mean = torch.randn(C, generator=generator) * 0.1  # fake mean
            # fake var in [0.1, 0.3): keeps it away from zero
            var = torch.rand(C, generator=generator) * 0.2 + 0.1
            m.stats[config_id] = {
                "mean": mean,
                "var": var
            }

In [8]:
# Seed every RNG in use so weights, fake stats and input are identical on re-run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

model = MiniSuperNet(depth=3, width_mult=0.75)
# TODO: implement a realistic version of statistic population
populate_fake_stats(model, "D3_W75", num_channels_per_layer={})
model.set_active_config("D3_W75")
# Inference demo: eval mode so the shortcut BatchNorm2d uses running stats.
model.eval()

x = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    logits = model(x)
assert logits.shape == (8, 10), logits.shape
logits.shape

torch.Size([8, 10])
