# Super-network Tutorial

NAS has seen rapid growth in the machine learning research community. It automates the discovery of optimal deep neural network architectures in domains like computer vision and natural language processing. While there have been many recent advancements, there is still a significant focus on making the search more efficient to reduce the computational cost incurred when validating discovered architectures.

The computational overhead of evaluating deep neural network architectures during the search process can be costly due to the training and validation cycles. Novel weight-sharing approaches known as one-shot or super-networks offer a way to mitigate the training overhead. These approaches train a task-specific super-network architecture with a weight-sharing mechanism that allows the sub-networks to be treated as unique individual architectures. This enables sub-network model extraction and validation without a separate training cycle.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# CIFAR

In [83]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load the training and test datasets
train_data = datasets.CIFAR10(root='data', train=True,
                                   download=True, transform=transform)
test_data = datasets.CIFAR10(root='data', train=False,
                                  download=True, transform=transform)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

if False:
    
    # convert data to torch.FloatTensor
    transform = transforms.ToTensor()
    
    # Download training data from open datasets.
    train_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )

Files already downloaded and verified
Files already downloaded and verified


In [106]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 3, 32, 32])
Shape of y: torch.Size([64]) torch.int64


In [129]:
search_space_dict = {'layer1' : [10, 512], 'layer2': [10, 512]}


def sample_config(search_space_dict, reset_random_seed=False, seed=0):
    if reset_random_seed:
        random.seed(seed)
    
    config = dict()
    
    for key, value in search_space_dict.items():
        config[key] = random.choice(value)
        
    return config

test = sample_config(search_space_dict)
print(test)

{'layer1': 512, 'layer2': 512}


In [141]:
def _sample_weight(weight, sample_in_dim, sample_out_dim):

        sample_weight = weight[:, :sample_in_dim]
        sample_weight = sample_weight[:sample_out_dim, :]

        return sample_weight

def _sample_bias(bias, sample_out_dim):
    sample_bias = bias[:sample_out_dim]

    return sample_bias

class SuperLinear(nn.Linear):
    
    def __init__(self, super_in_dim, super_out_dim, bias=True):
        super().__init__(super_in_dim, super_out_dim, bias=bias)
        
        # Define SuperNetwork Bounds
        self.super_in_dim = super_in_dim
        self.super_out_dim = super_out_dim

        self.sample_in_dim = None
        self.sample_out_dim = None
        
        self.samples = {}
        super().reset_parameters()
        
        self.profiling = False
        
    def set_sample_config(self, sample_in_dim, sample_out_dim):
        self.sample_in_dim = sample_in_dim
        self.sample_out_dim = sample_out_dim

        self._sample_parameters() 
        
    def _sample_parameters(self):
        self.samples['weight'] = _sample_weight(self.weight, self.sample_in_dim, self.sample_out_dim)
        self.samples['bias'] = self.bias
        if self.bias is not None:
            self.samples['bias'] = _sample_bias(self.bias, self.sample_out_dim)
        return self.samples    
    
    def sample_parameters(self, resample=False):
        if self.profiling or resample:
            return self._sample_parameters()
        return self.samples
    
    def forward(self, x):
        self.sample_parameters()
        return F.linear(x, self.samples['weight'], self.samples['bias'])
    
    
    
    

In [142]:
# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
    
        self.layer1 = SuperLinear(super_in_dim=3*32*32, super_out_dim=512)
        self.act = nn.ReLU()
        self.layer2 = SuperLinear(super_in_dim=512, super_out_dim=512)
        self.layer3 = SuperLinear(super_in_dim=512, super_out_dim=10)

    def forward(self, x):
        x = self.flatten(x)
        
        x = self.layer1(x)
        x = self.act(x)
        x = self.layer2(x)
        x = self.act(x)
        x = self.layer3(x)
        
        return x #logits

    
    def set_sample_config(self, sample):
        
        self.layer1.set_sample_config(sample_in_dim=3*32*32, sample_out_dim=sample['layer1'])
        self.layer2.set_sample_config(sample_in_dim=sample['layer1'], sample_out_dim=sample['layer2'])
        self.layer3.set_sample_config(sample_in_dim=sample['layer2'], sample_out_dim=10)  


In [144]:
model = NeuralNetwork().to(device)
print(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (layer1): SuperLinear(in_features=3072, out_features=512, bias=True)
  (act): ReLU()
  (layer2): SuperLinear(in_features=512, out_features=512, bias=True)
  (layer3): SuperLinear(in_features=512, out_features=10, bias=True)
)


In [145]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        cfg = sample_config(search_space_dict)
        model.set_sample_config(cfg)
        
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [146]:
from itertools import product


keys, values = zip(*search_space_dict.items())
all_subnets = [dict(zip(keys, p)) for p in product(*values)]
print(result)

[{'layer1': 256, 'layer2': 256}, {'layer1': 256, 'layer2': 512}, {'layer1': 512, 'layer2': 256}, {'layer1': 512, 'layer2': 512}]


In [147]:
def test(dataloader, model, loss_fn):
    
    for cfg in all_subnets:
        print(cfg)
        
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)

                model.set_sample_config(cfg)

                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [148]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.309598  [    0/50000]
loss: 2.301352  [ 6400/50000]
loss: 2.296367  [12800/50000]
loss: 2.307145  [19200/50000]
loss: 2.066028  [25600/50000]
loss: 2.059029  [32000/50000]
loss: 2.055593  [38400/50000]
loss: 1.933538  [44800/50000]
{'layer1': 10, 'layer2': 10}
Test Error: 
 Accuracy: 12.3%, Avg loss: 2.302286 

{'layer1': 10, 'layer2': 512}
Test Error: 
 Accuracy: 19.9%, Avg loss: 2.271987 

{'layer1': 512, 'layer2': 10}
Test Error: 
 Accuracy: 19.7%, Avg loss: 2.281700 

{'layer1': 512, 'layer2': 512}
Test Error: 
 Accuracy: 32.0%, Avg loss: 1.951181 

Epoch 2
-------------------------------
loss: 2.288526  [    0/50000]
loss: 2.262601  [ 6400/50000]
loss: 2.297771  [12800/50000]
loss: 2.278017  [19200/50000]
loss: 2.299622  [25600/50000]
loss: 2.300469  [32000/50000]
loss: 1.954387  [38400/50000]
loss: 1.773582  [44800/50000]
{'layer1': 10, 'layer2': 10}
Test Error: 
 Accuracy: 14.4%, Avg loss: 2.299174 

{'layer1': 10, 'layer2': 512}
T

KeyboardInterrupt: 