# Single Path One-Shot Neural Architecture Search using Random Search


Your task is to implement SPOS (Single Path One-Shot) + Random Search to find the optimal ResNet-18-like architecture
for image classification on the CIFAR-10 dataset.

### Proposed search space
The same search space as in the lecture, but without kernel size search.
This means that each search block should only have 3 operations.
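
To get a feel for how large this search space is, the sketch below enumerates it. It assumes that all eight residual blocks of ResNet-18 become search blocks with three candidate operations each; `N_SEARCH_BLOCKS` and `N_OPS` are illustrative names, not part of the provided code.

```python
from itertools import product

# Assumption: ResNet-18 has 8 residual blocks (2 per stage, 4 stages),
# and each of them is replaced by a search block with 3 candidate operations.
N_SEARCH_BLOCKS = 8
N_OPS = 3

# An architecture is a tuple of operation indices, one per search block.
all_architectures = list(product(range(N_OPS), repeat=N_SEARCH_BLOCKS))

print(len(all_architectures))  # 3 ** 8 = 6561
print(all_architectures[0])    # (0, 0, 0, 0, 0, 0, 0, 0)
```

Under this assumption the space holds 6561 architectures, so 100-1000 random-search iterations already cover a noticeable fraction of it.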

### Latency awareness
While searching, you should enforce a hard constraint on model latency. As a proxy for
latency, use the number of multiply-accumulate operations (MACs). The baseline ResNet-18 has 37M MACs for images of size 32x32.
The final selected architecture should have at most 30M MACs.
You can find an example computation for the number of MACs in this notebook.
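
As a rough illustration of where the MACs come from, a dense convolution without bias costs `C_in * C_out * k * k * H_out * W_out` multiply-accumulates. The helper `conv2d_macs` below is a hypothetical name used only for this sketch; the feature-map sizes follow from the ResNet-18 stem (stride-2 convolution, then stride-2 max pooling) applied to 32x32 images.

```python
def conv2d_macs(c_in: int, c_out: int, kernel: int, h_out: int, w_out: int) -> int:
    """MACs of a dense 2D convolution without bias."""
    return c_in * c_out * kernel * kernel * h_out * w_out


# Stem of ResNet-18 on a 32x32 image: 7x7 conv with stride 2 -> 16x16 output.
stem = conv2d_macs(3, 64, 7, 16, 16)
# One 3x3 conv of the first stage: 64 -> 64 channels on an 8x8 feature map.
stage1_conv = conv2d_macs(64, 64, 3, 8, 8)

print(stem)         # 2408448
print(stage1_conv)  # 2359296
```

Each of these two layers alone accounts for roughly 2.4M of the 37M total, which shows why narrowing the inner channels of a block is an effective way to move toward the 30M budget.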

### How to complete the task
Plan for your experiments:
1. Train the baseline ResNet-18 model.  **(2 pts)**
2. Finalize the supernet code. **(15 pts)**
3. Train the supernet with the same hyperparameters as the baseline, except for the number of epochs,
   which you should increase by a factor of 3-6. Save the supernet weights.  **(10 pts)**
4. Implement Random Search.  **(10 pts)**
5. Run it for 100-1000 iterations, measuring the accuracy of models sampled from the supernet.
   Keep track of each model's accuracy and latency.  **(5 pts)**
6. Train the best architecture from scratch. For this, you can build the whole supernet and
   select this architecture for all forward passes during training.  **(6 pts)**
7. Compare with the model from step 1. Does your model have better accuracy?  **(2 pts)**

### Code structure

Main.ipynb - Train your neural networks here.

resnet.py - ResNet-18 implementation from torchvision, simplified for this project.

supernet.py - Unfinished Supernet implementation.

cifar10.py - Transforms for CIFAR-10 training.


## Setup

Before you begin, make sure to install the torch, torchvision, numpy, thop, and tqdm libraries.

In [1]:
!pip install torch torchvision numpy thop tqdm

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


### Imports

In [1]:
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from thop import profile
from torch.utils.data import DataLoader
from tqdm import tqdm

from cifar10 import get_train_transform, get_val_transform
from resnet import resnet18

### Make everything a bit faster

In [2]:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

### Build datasets and dataloaders for CIFAR-10

In [3]:
# Change this value if needed.
batch_size = 2048

In [4]:
train_transform = get_train_transform()
val_transform = get_val_transform()

train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=val_transform,
)

train_dataloader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
test_dataloader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

Files already downloaded and verified
Files already downloaded and verified


## Create your baseline model

In [5]:
# Select a suitable device.
# You should probably use either the cuda (NVIDIA GPU) or the mps (Apple) backend.
device = torch.device('cuda:0')

In [6]:
model = resnet18(num_classes=10, zero_init_residual=True)
model.to(device=device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### Compute the number of MACs and parameters for the model

In [7]:
macs, params = profile(model, inputs=(torch.zeros(1, 3, 32, 32, device=device),))
print(f'Number of macs: {macs / 1e6:.2f}M, number of parameters: {params / 1e6:.2f}M')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Number of macs: 37.22M, number of parameters: 11.18M


## Train baseline model

### Define loss function

In [8]:
criterion = nn.CrossEntropyLoss()

### Select hyperparameters

In [10]:
lr = 0.25
weight_decay = 5e-4
momentum = 0.9
n_epochs = 20  # Longer training gives better results, but let's keep baseline model epochs to 20.

### Build optimizer and scheduler

In [11]:
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * n_epochs)
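
Because the scheduler is stepped once per batch, `T_max` is the total number of optimizer steps rather than the number of epochs. The sketch below reproduces the schedule with the closed-form cosine annealing formula (with a minimum learning rate of 0), assuming the 50,000 CIFAR-10 training images and `drop_last=True`; `cosine_lr` is an illustrative helper, not a PyTorch function.

```python
import math

n_train, batch_size, n_epochs, base_lr = 50_000, 2048, 20, 0.25

# drop_last=True discards the incomplete final batch.
steps_per_epoch = n_train // batch_size  # 24
t_max = steps_per_epoch * n_epochs       # 480


def cosine_lr(step: int) -> float:
    """Closed form of cosine annealing with a minimum learning rate of 0."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * step / t_max))


# Learning rate at the end of the first, tenth, and last epoch.
for epoch in (0, 9, 19):
    step = (epoch + 1) * steps_per_epoch
    print(epoch, f"{cosine_lr(step):.4f}")
```

The printed values (0.2485, 0.1250, 0.0000) should match the `lr=` field shown in the progress bars of the training log for the corresponding epochs.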

### Define training and evaluation functions

In [12]:
def train_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        optimizer: optim.Optimizer,
        scheduler,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    model.train()

    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    wrapped_dataloader = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, labels) in wrapped_dataloader:
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        optimizer.zero_grad()

        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            _, predicted_labels = torch.max(logits, 1)
            total_loss += loss.item()
            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.shape[0]

        wrapped_dataloader.set_description(
            f'(train) Epoch={epoch}, lr={scheduler.get_last_lr()[0]:.4f} loss={total_loss / (i + 1):.3f}'
        )

    return total_loss / len(dataloader), total_correct / total_samples


@torch.no_grad()
def validate_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    model.eval()

    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    wrapped_dataloader = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, labels) in wrapped_dataloader:
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        logits = model(inputs)
        loss = criterion(logits, labels)
        _, predicted_labels = torch.max(logits, 1)
        total_loss += loss.item()
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += labels.shape[0]

        wrapped_dataloader.set_description(f'(val) Epoch={epoch}, loss={total_loss / (i + 1):.3f}')

    return total_loss / len(dataloader), total_correct / total_samples
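
One detail of the reported loss: both functions average the per-batch mean losses, so a smaller final batch gets the same weight as a full one. With 10,000 test images and a batch size of 2048, the last batch holds only 1808 samples. The sketch below uses made-up per-batch losses to show how this differs from a sample-weighted mean; the accuracy is unaffected because it is already computed per sample.

```python
# Hypothetical per-batch mean losses; the batch sizes follow from
# 10,000 test images with batch_size=2048 and drop_last=False.
batch_sizes = [2048, 2048, 2048, 2048, 1808]
batch_mean_losses = [0.70, 0.68, 0.69, 0.71, 0.60]

# What the functions above report: every batch counts equally.
mean_of_batch_means = sum(batch_mean_losses) / len(batch_mean_losses)

# Exact per-sample mean: every batch is weighted by its size.
sample_weighted_mean = (
    sum(loss * n for loss, n in zip(batch_mean_losses, batch_sizes))
    / sum(batch_sizes)
)

print(f"{mean_of_batch_means:.4f} vs {sample_weighted_mean:.4f}")
```

The gap is small here, so the simpler average is usually acceptable for monitoring; it matters more when the last batch is much smaller than the others.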

### Run training

In [13]:
for epoch in range(n_epochs):
    print(f'Epoch: {epoch}')
    loss, accuracy = train_one_epoch(model, criterion, train_dataloader, optimizer, scheduler, device, epoch)
    print(f'train_loss={loss:.4f}, train_accuracy={accuracy:.3%}')
    loss, accuracy = validate_one_epoch(model, criterion, test_dataloader, device, epoch)
    print(f'test_loss={loss:.4f}, test_accuracy={accuracy:.3%}')

Epoch: 0


(train) Epoch=0, lr=0.2485 loss=2.438: 100%|██████████| 24/24 [00:04<00:00,  6.00it/s]

train_loss=2.4383, train_accuracy=15.466%



(val) Epoch=0, loss=2.250: 100%|██████████| 5/5 [00:00<00:00,  5.21it/s]

test_loss=2.2498, test_accuracy=18.220%
Epoch: 1



(train) Epoch=1, lr=0.2439 loss=1.961: 100%|██████████| 24/24 [00:03<00:00,  7.85it/s]

train_loss=1.9612, train_accuracy=27.061%



(val) Epoch=1, loss=1.677: 100%|██████████| 5/5 [00:00<00:00, 10.04it/s]

test_loss=1.6774, test_accuracy=38.130%
Epoch: 2



(train) Epoch=2, lr=0.2364 loss=1.610: 100%|██████████| 24/24 [00:03<00:00,  7.84it/s]

train_loss=1.6102, train_accuracy=40.249%



(val) Epoch=2, loss=1.457: 100%|██████████| 5/5 [00:00<00:00,  9.12it/s]

test_loss=1.4567, test_accuracy=46.390%
Epoch: 3



(train) Epoch=3, lr=0.2261 loss=1.441: 100%|██████████| 24/24 [00:03<00:00,  7.94it/s]

train_loss=1.4414, train_accuracy=47.078%



(val) Epoch=3, loss=1.314: 100%|██████████| 5/5 [00:00<00:00,  9.54it/s]

test_loss=1.3144, test_accuracy=51.600%
Epoch: 4



(train) Epoch=4, lr=0.2134 loss=1.320: 100%|██████████| 24/24 [00:03<00:00,  7.64it/s]

train_loss=1.3205, train_accuracy=52.028%



(val) Epoch=4, loss=1.261: 100%|██████████| 5/5 [00:00<00:00,  9.70it/s]

test_loss=1.2611, test_accuracy=53.290%
Epoch: 5



(train) Epoch=5, lr=0.1985 loss=1.236: 100%|██████████| 24/24 [00:03<00:00,  7.84it/s]

train_loss=1.2356, train_accuracy=55.290%



(val) Epoch=5, loss=1.302: 100%|██████████| 5/5 [00:00<00:00,  9.99it/s]

test_loss=1.3016, test_accuracy=55.680%
Epoch: 6



(train) Epoch=6, lr=0.1817 loss=1.148: 100%|██████████| 24/24 [00:03<00:00,  7.94it/s]

train_loss=1.1481, train_accuracy=58.677%



(val) Epoch=6, loss=1.050: 100%|██████████| 5/5 [00:00<00:00,  9.47it/s]

test_loss=1.0504, test_accuracy=63.070%
Epoch: 7



(train) Epoch=7, lr=0.1636 loss=1.059: 100%|██████████| 24/24 [00:02<00:00,  8.04it/s]

train_loss=1.0587, train_accuracy=62.115%



(val) Epoch=7, loss=1.064: 100%|██████████| 5/5 [00:00<00:00,  8.85it/s]

test_loss=1.0636, test_accuracy=62.430%
Epoch: 8



(train) Epoch=8, lr=0.1446 loss=0.989: 100%|██████████| 24/24 [00:03<00:00,  7.39it/s]

train_loss=0.9885, train_accuracy=64.705%



(val) Epoch=8, loss=1.028: 100%|██████████| 5/5 [00:00<00:00,  9.87it/s]

test_loss=1.0279, test_accuracy=63.020%
Epoch: 9



(train) Epoch=9, lr=0.1250 loss=0.929: 100%|██████████| 24/24 [00:03<00:00,  7.68it/s]

train_loss=0.9289, train_accuracy=66.980%



(val) Epoch=9, loss=0.895: 100%|██████████| 5/5 [00:00<00:00,  9.68it/s]

test_loss=0.8946, test_accuracy=68.200%
Epoch: 10



(train) Epoch=10, lr=0.1054 loss=0.876: 100%|██████████| 24/24 [00:03<00:00,  7.63it/s]

train_loss=0.8756, train_accuracy=68.852%



(val) Epoch=10, loss=0.838: 100%|██████████| 5/5 [00:00<00:00,  9.85it/s]

test_loss=0.8383, test_accuracy=70.780%
Epoch: 11



(train) Epoch=11, lr=0.0864 loss=0.838: 100%|██████████| 24/24 [00:03<00:00,  7.54it/s]

train_loss=0.8384, train_accuracy=70.160%



(val) Epoch=11, loss=0.828: 100%|██████████| 5/5 [00:00<00:00,  9.60it/s]

test_loss=0.8276, test_accuracy=71.090%
Epoch: 12



(train) Epoch=12, lr=0.0683 loss=0.801: 100%|██████████| 24/24 [00:02<00:00,  8.21it/s]

train_loss=0.8011, train_accuracy=71.619%



(val) Epoch=12, loss=0.788: 100%|██████████| 5/5 [00:00<00:00,  9.66it/s]

test_loss=0.7877, test_accuracy=72.220%
Epoch: 13



(train) Epoch=13, lr=0.0515 loss=0.770: 100%|██████████| 24/24 [00:02<00:00,  8.22it/s]

train_loss=0.7699, train_accuracy=72.752%



(val) Epoch=13, loss=0.769: 100%|██████████| 5/5 [00:00<00:00,  9.77it/s]


test_loss=0.7694, test_accuracy=72.890%
Epoch: 14


(train) Epoch=14, lr=0.0366 loss=0.741: 100%|██████████| 24/24 [00:02<00:00,  8.06it/s]

train_loss=0.7413, train_accuracy=73.653%



(val) Epoch=14, loss=0.763: 100%|██████████| 5/5 [00:00<00:00,  9.79it/s]


test_loss=0.7628, test_accuracy=73.760%
Epoch: 15


(train) Epoch=15, lr=0.0239 loss=0.718: 100%|██████████| 24/24 [00:02<00:00,  8.16it/s]

train_loss=0.7179, train_accuracy=74.388%



(val) Epoch=15, loss=0.725: 100%|██████████| 5/5 [00:00<00:00,  9.88it/s]

test_loss=0.7250, test_accuracy=74.170%
Epoch: 16



(train) Epoch=16, lr=0.0136 loss=0.693: 100%|██████████| 24/24 [00:03<00:00,  7.87it/s]

train_loss=0.6928, train_accuracy=75.399%



(val) Epoch=16, loss=0.706: 100%|██████████| 5/5 [00:00<00:00,  9.57it/s]

test_loss=0.7055, test_accuracy=74.880%
Epoch: 17



(train) Epoch=17, lr=0.0061 loss=0.683: 100%|██████████| 24/24 [00:03<00:00,  7.82it/s]

train_loss=0.6828, train_accuracy=75.680%



(val) Epoch=17, loss=0.698: 100%|██████████| 5/5 [00:00<00:00,  9.53it/s]


test_loss=0.6975, test_accuracy=75.430%
Epoch: 18


(train) Epoch=18, lr=0.0015 loss=0.672: 100%|██████████| 24/24 [00:02<00:00,  8.07it/s]

train_loss=0.6725, train_accuracy=76.166%



(val) Epoch=18, loss=0.688: 100%|██████████| 5/5 [00:00<00:00,  9.12it/s]

test_loss=0.6878, test_accuracy=75.750%
Epoch: 19



(train) Epoch=19, lr=0.0000 loss=0.668: 100%|██████████| 24/24 [00:03<00:00,  7.89it/s]

train_loss=0.6677, train_accuracy=76.308%



(val) Epoch=19, loss=0.687: 100%|██████████| 5/5 [00:00<00:00,  9.26it/s]


test_loss=0.6868, test_accuracy=75.850%


### Save trained model weights

In [14]:
torch.save(model.state_dict(), 'baseline_model.pth')

## Neural Architecture Search - Supernet training

### Create supernet

**Before running the code in this section, you need to finish the supernet implementation.**

Please go to the `supernet.py` file and inspect the current implementation of the SearchBlock and Supernet classes.
Pay attention to the TODOs; you need to implement all of them.

The Supernet and BasicBlock classes are modified versions of the ResNet and BasicBlock classes from `resnet.py`.

    Tip: to understand how the Supernet is constructed, compare the implementations of the Supernet and ResNet classes, for example with the diff tool in your IDE.

Task: briefly describe the changes made to construct the supernet.

In [15]:
import importlib
import speed_up_nn.hw_04.supernet as supernet_lib
importlib.reload(supernet_lib)
from speed_up_nn.hw_04.supernet import supernet18

In [16]:
# Define inner channel multipliers as in lecture.
channel_multipliers = [0.5, 1.0, 2.0]
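
A back-of-the-envelope MAC count can show how the channel multipliers move an architecture relative to the 30M budget. The sketch below is an estimate under stated assumptions, not the thop measurement: it counts only convolutions and the final linear layer, assumes the multiplier scales the inner channel count of each block (as in the printed supernet, where 0.5 gives 64 -> 32 -> 64), and assumes the stride and shortcut layout of ResNet-18. `BLOCKS`, `conv_macs`, and `architecture_macs` are illustrative names.

```python
from typing import Sequence


def conv_macs(c_in: int, c_out: int, k: int, hw: int) -> int:
    """MACs of a dense k x k convolution producing `hw` output positions."""
    return c_in * c_out * k * k * hw


# (in_channels, out_channels, output positions H*W) of the 8 residual
# blocks for 32x32 inputs, assuming the standard ResNet-18 layout.
BLOCKS = [
    (64, 64, 64), (64, 64, 64),
    (64, 128, 16), (128, 128, 16),
    (128, 256, 4), (256, 256, 4),
    (256, 512, 1), (512, 512, 1),
]


def architecture_macs(multipliers: Sequence[float]) -> int:
    """Approximate MACs given one inner-channel multiplier per block."""
    total = conv_macs(3, 64, 7, 16 * 16)  # 7x7 stem convolution
    total += 512 * 10                     # final linear classifier
    for (c_in, c_out, hw), m in zip(BLOCKS, multipliers):
        inner = int(c_out * m)
        total += conv_macs(c_in, inner, 3, hw) + conv_macs(inner, c_out, 3, hw)
        if c_in != c_out:
            total += conv_macs(c_in, c_out, 1, hw)  # 1x1 shortcut convolution
    return total


for m in (0.5, 1.0, 2.0):
    print(m, f"{architecture_macs([m] * 8) / 1e6:.2f}M")
```

Under these assumptions the all-1.0 architecture lands at about 37.0M MACs, close to the 37.22M reported by thop for the baseline (thop also counts normalization and pooling layers). The narrowest architecture is near 19.9M and the widest near 71.2M, so the 30M budget excludes part of the search space but not all of it.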

In [17]:
supernet = supernet18(num_classes=10, zero_init_residual=True, channel_multipliers=channel_multipliers)
supernet.to(device=device)

Supernet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): SearchBlock(
      (ops): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05,

### Select hyperparameters

In [18]:
# define hyperparameters for supernet training.
lr = 0.25
weight_decay = 5e-4
momentum = 0.9
n_epochs = 50

### Build optimizer and scheduler

In [19]:
# build optimizer and scheduler for supernet training.
optimizer = optim.SGD(supernet.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_dataloader) * n_epochs)

### Define training function

In [20]:
def pretrain_one_epoch(
        model: nn.Module,
        criterion: nn.Module,
        dataloader: DataLoader,
        optimizer: optim.Optimizer,
        scheduler,
        device: torch.device,
        epoch: int,
) -> Tuple[float, float]:
    model.train()

    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    wrapped_dataloader = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, (inputs, labels) in wrapped_dataloader:
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        optimizer.zero_grad()
        model.sample_random_architecture()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            _, predicted_labels = torch.max(logits, 1)
            total_loss += loss.item()
            total_correct += (predicted_labels == labels).sum().item()
            total_samples += labels.shape[0]

        wrapped_dataloader.set_description(
            f'(train) Epoch={epoch}, lr={scheduler.get_last_lr()[0]:.4f} loss={total_loss / (i + 1):.3f}'
        )

    return total_loss / len(dataloader), total_correct / total_samples
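
The only difference from regular training is the call to `model.sample_random_architecture()` before each forward pass, so every optimizer step updates a single path through the supernet. The sketch below is a plain-Python stand-in for what such uniform sampling presumably does; it assumes 8 search blocks with 3 operations each, and `sample_architecture` is a hypothetical helper rather than the method from `supernet.py`.

```python
import random
from collections import Counter

N_SEARCH_BLOCKS, N_OPS = 8, 3  # assumed sizes of the search space


def sample_architecture(rng: random.Random) -> list:
    """Pick one operation index per search block, uniformly at random."""
    return [rng.randrange(N_OPS) for _ in range(N_SEARCH_BLOCKS)]


rng = random.Random(0)
steps = 24 * 50  # batches per epoch times supernet epochs

# Count how often each operation of the first block would be trained.
usage = Counter()
for _ in range(steps):
    arch = sample_architecture(rng)
    usage[arch[0]] += 1

print(sorted(usage.items()))
```

Each operation receives roughly a third of the updates, which is one reason the supernet needs several times more epochs than a single fixed architecture.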

In [21]:
for epoch in range(n_epochs):
    print(f'Epoch: {epoch}')
    loss, accuracy = pretrain_one_epoch(supernet, criterion, train_dataloader, optimizer, scheduler, device, epoch)
    print(f'train_loss={loss:.4f}, train_accuracy={accuracy:.3%}')
    loss, accuracy = validate_one_epoch(supernet, criterion, test_dataloader, device, epoch)
    print(f'test_loss={loss:.4f}, test_accuracy={accuracy:.3%}')

Epoch: 0


(train) Epoch=0, lr=0.2498 loss=3.113: 100%|██████████| 24/24 [00:05<00:00,  4.08it/s]

train_loss=3.1134, train_accuracy=11.292%



(val) Epoch=0, loss=3.655: 100%|██████████| 5/5 [00:01<00:00,  4.10it/s]

test_loss=3.6555, test_accuracy=9.420%
Epoch: 1



(train) Epoch=1, lr=0.2490 loss=5.228: 100%|██████████| 24/24 [00:02<00:00,  8.06it/s]

train_loss=5.2279, train_accuracy=11.414%



(val) Epoch=1, loss=93.874: 100%|██████████| 5/5 [00:00<00:00,  5.79it/s]

test_loss=93.8742, test_accuracy=11.310%
Epoch: 2



(train) Epoch=2, lr=0.2478 loss=4.236: 100%|██████████| 24/24 [00:03<00:00,  7.96it/s]

train_loss=4.2360, train_accuracy=10.870%



(val) Epoch=2, loss=75.687: 100%|██████████| 5/5 [00:00<00:00,  7.26it/s]

test_loss=75.6869, test_accuracy=10.080%
Epoch: 3



(train) Epoch=3, lr=0.2461 loss=2.893: 100%|██████████| 24/24 [00:02<00:00,  8.06it/s]

train_loss=2.8930, train_accuracy=10.775%



(val) Epoch=3, loss=35.303: 100%|██████████| 5/5 [00:00<00:00,  7.96it/s]

test_loss=35.3034, test_accuracy=9.970%
Epoch: 4



(train) Epoch=4, lr=0.2439 loss=2.573: 100%|██████████| 24/24 [00:03<00:00,  7.54it/s]

train_loss=2.5732, train_accuracy=10.868%



(val) Epoch=4, loss=21.741: 100%|██████████| 5/5 [00:00<00:00,  6.74it/s]

test_loss=21.7409, test_accuracy=7.630%
Epoch: 5



(train) Epoch=5, lr=0.2412 loss=2.342: 100%|██████████| 24/24 [00:03<00:00,  7.72it/s]


train_loss=2.3422, train_accuracy=11.491%


(val) Epoch=5, loss=4.704: 100%|██████████| 5/5 [00:00<00:00,  9.43it/s]

test_loss=4.7044, test_accuracy=13.160%
Epoch: 6



(train) Epoch=6, lr=0.2381 loss=2.290: 100%|██████████| 24/24 [00:03<00:00,  7.76it/s]

train_loss=2.2896, train_accuracy=12.423%



(val) Epoch=6, loss=2.315: 100%|██████████| 5/5 [00:00<00:00,  8.55it/s]

test_loss=2.3152, test_accuracy=15.340%
Epoch: 7



(train) Epoch=7, lr=0.2345 loss=2.252: 100%|██████████| 24/24 [00:03<00:00,  7.56it/s]

train_loss=2.2522, train_accuracy=14.419%



(val) Epoch=7, loss=2.523: 100%|██████████| 5/5 [00:00<00:00,  9.36it/s]

test_loss=2.5226, test_accuracy=16.340%
Epoch: 8



(train) Epoch=8, lr=0.2305 loss=2.193: 100%|██████████| 24/24 [00:03<00:00,  7.62it/s]

train_loss=2.1928, train_accuracy=17.120%



(val) Epoch=8, loss=2.136: 100%|██████████| 5/5 [00:00<00:00,  9.56it/s]


test_loss=2.1356, test_accuracy=20.180%
Epoch: 9


(train) Epoch=9, lr=0.2261 loss=2.149: 100%|██████████| 24/24 [00:03<00:00,  7.76it/s]

train_loss=2.1491, train_accuracy=18.620%



(val) Epoch=9, loss=2.108: 100%|██████████| 5/5 [00:00<00:00,  9.72it/s]

test_loss=2.1080, test_accuracy=20.020%
Epoch: 10



(train) Epoch=10, lr=0.2213 loss=2.091: 100%|██████████| 24/24 [00:03<00:00,  7.69it/s]


train_loss=2.0906, train_accuracy=20.746%


(val) Epoch=10, loss=2.135: 100%|██████████| 5/5 [00:00<00:00,  9.40it/s]


test_loss=2.1354, test_accuracy=19.130%
Epoch: 11


(train) Epoch=11, lr=0.2161 loss=2.049: 100%|██████████| 24/24 [00:02<00:00,  8.07it/s]

train_loss=2.0494, train_accuracy=22.870%



(val) Epoch=11, loss=2.016: 100%|██████████| 5/5 [00:00<00:00,  9.38it/s]


test_loss=2.0164, test_accuracy=22.540%
Epoch: 12


(train) Epoch=12, lr=0.2106 loss=2.005: 100%|██████████| 24/24 [00:03<00:00,  7.86it/s]


train_loss=2.0051, train_accuracy=24.384%


(val) Epoch=12, loss=1.972: 100%|██████████| 5/5 [00:00<00:00,  9.49it/s]


test_loss=1.9721, test_accuracy=25.310%
Epoch: 13


(train) Epoch=13, lr=0.2047 loss=1.967: 100%|██████████| 24/24 [00:03<00:00,  7.73it/s]


train_loss=1.9672, train_accuracy=25.098%


(val) Epoch=13, loss=2.042: 100%|██████████| 5/5 [00:00<00:00,  9.60it/s]

test_loss=2.0419, test_accuracy=22.060%
Epoch: 14



(train) Epoch=14, lr=0.1985 loss=1.933: 100%|██████████| 24/24 [00:02<00:00,  8.07it/s]

train_loss=1.9335, train_accuracy=26.310%



(val) Epoch=14, loss=1.896: 100%|██████████| 5/5 [00:00<00:00,  9.61it/s]

test_loss=1.8958, test_accuracy=26.300%
Epoch: 15



(train) Epoch=15, lr=0.1920 loss=1.887: 100%|██████████| 24/24 [00:03<00:00,  7.85it/s]

train_loss=1.8873, train_accuracy=28.125%



(val) Epoch=15, loss=1.854: 100%|██████████| 5/5 [00:00<00:00,  9.26it/s]

test_loss=1.8537, test_accuracy=30.620%
Epoch: 16



(train) Epoch=16, lr=0.1852 loss=1.855: 100%|██████████| 24/24 [00:03<00:00,  7.62it/s]

train_loss=1.8553, train_accuracy=29.415%



(val) Epoch=16, loss=1.925: 100%|██████████| 5/5 [00:00<00:00,  9.36it/s]

test_loss=1.9252, test_accuracy=26.470%
Epoch: 17



(train) Epoch=17, lr=0.1782 loss=1.812: 100%|██████████| 24/24 [00:03<00:00,  7.85it/s]

train_loss=1.8119, train_accuracy=31.022%



(val) Epoch=17, loss=1.844: 100%|██████████| 5/5 [00:00<00:00,  9.71it/s]

test_loss=1.8440, test_accuracy=30.970%
Epoch: 18



(train) Epoch=18, lr=0.1710 loss=1.785: 100%|██████████| 24/24 [00:03<00:00,  7.51it/s]

train_loss=1.7850, train_accuracy=32.660%



(val) Epoch=18, loss=1.790: 100%|██████████| 5/5 [00:00<00:00,  8.26it/s]

test_loss=1.7902, test_accuracy=33.110%
Epoch: 19



(train) Epoch=19, lr=0.1636 loss=1.751: 100%|██████████| 24/24 [00:03<00:00,  7.65it/s]

train_loss=1.7512, train_accuracy=33.462%



(val) Epoch=19, loss=1.710: 100%|██████████| 5/5 [00:00<00:00,  9.07it/s]

test_loss=1.7100, test_accuracy=31.450%
Epoch: 20



(train) Epoch=20, lr=0.1561 loss=1.724: 100%|██████████| 24/24 [00:03<00:00,  7.88it/s]


train_loss=1.7242, train_accuracy=34.576%


(val) Epoch=20, loss=1.966: 100%|██████████| 5/5 [00:00<00:00,  9.63it/s]

test_loss=1.9656, test_accuracy=29.040%
Epoch: 21



(train) Epoch=21, lr=0.1484 loss=1.686: 100%|██████████| 24/24 [00:03<00:00,  7.78it/s]

train_loss=1.6863, train_accuracy=36.314%



(val) Epoch=21, loss=1.724: 100%|██████████| 5/5 [00:00<00:00,  9.34it/s]

test_loss=1.7239, test_accuracy=34.630%
Epoch: 22



(train) Epoch=22, lr=0.1407 loss=1.668: 100%|██████████| 24/24 [00:03<00:00,  7.82it/s]


train_loss=1.6684, train_accuracy=37.016%


(val) Epoch=22, loss=1.691: 100%|██████████| 5/5 [00:00<00:00,  9.35it/s]

test_loss=1.6914, test_accuracy=35.640%
Epoch: 23



(train) Epoch=23, lr=0.1328 loss=1.641: 100%|██████████| 24/24 [00:03<00:00,  7.47it/s]

train_loss=1.6409, train_accuracy=38.269%



(val) Epoch=23, loss=1.754: 100%|██████████| 5/5 [00:00<00:00,  9.47it/s]

test_loss=1.7538, test_accuracy=35.070%
Epoch: 24



(train) Epoch=24, lr=0.1250 loss=1.619: 100%|██████████| 24/24 [00:03<00:00,  7.70it/s]

train_loss=1.6189, train_accuracy=38.932%



(val) Epoch=24, loss=1.679: 100%|██████████| 5/5 [00:00<00:00,  9.43it/s]

test_loss=1.6785, test_accuracy=36.540%
Epoch: 25



(train) Epoch=25, lr=0.1172 loss=1.598: 100%|██████████| 24/24 [00:02<00:00,  8.27it/s]

train_loss=1.5983, train_accuracy=39.943%



(val) Epoch=25, loss=1.746: 100%|██████████| 5/5 [00:00<00:00,  9.53it/s]

test_loss=1.7455, test_accuracy=34.800%
Epoch: 26



(train) Epoch=26, lr=0.1093 loss=1.585: 100%|██████████| 24/24 [00:03<00:00,  7.91it/s]

train_loss=1.5849, train_accuracy=40.397%



(val) Epoch=26, loss=1.642: 100%|██████████| 5/5 [00:00<00:00,  9.47it/s]

test_loss=1.6423, test_accuracy=38.250%
Epoch: 27



(train) Epoch=27, lr=0.1016 loss=1.567: 100%|██████████| 24/24 [00:02<00:00,  8.05it/s]

train_loss=1.5667, train_accuracy=41.036%



(val) Epoch=27, loss=1.551: 100%|██████████| 5/5 [00:00<00:00,  9.28it/s]


test_loss=1.5509, test_accuracy=41.650%
Epoch: 28


(train) Epoch=28, lr=0.0939 loss=1.544: 100%|██████████| 24/24 [00:03<00:00,  7.79it/s]

train_loss=1.5439, train_accuracy=41.996%



(val) Epoch=28, loss=1.919: 100%|██████████| 5/5 [00:00<00:00,  9.49it/s]

test_loss=1.9195, test_accuracy=30.610%
Epoch: 29



(train) Epoch=29, lr=0.0864 loss=1.533: 100%|██████████| 24/24 [00:03<00:00,  7.90it/s]

train_loss=1.5327, train_accuracy=42.389%



(val) Epoch=29, loss=1.558: 100%|██████████| 5/5 [00:00<00:00,  9.36it/s]

test_loss=1.5583, test_accuracy=42.080%
Epoch: 30



(train) Epoch=30, lr=0.0790 loss=1.528: 100%|██████████| 24/24 [00:02<00:00,  8.01it/s]

train_loss=1.5284, train_accuracy=43.113%



(val) Epoch=30, loss=1.678: 100%|██████████| 5/5 [00:00<00:00,  9.48it/s]


test_loss=1.6780, test_accuracy=38.120%
Epoch: 31


(train) Epoch=31, lr=0.0718 loss=1.496: 100%|██████████| 24/24 [00:02<00:00,  8.18it/s]


train_loss=1.4958, train_accuracy=44.035%


(val) Epoch=31, loss=1.528: 100%|██████████| 5/5 [00:00<00:00,  9.16it/s]

test_loss=1.5282, test_accuracy=44.240%
Epoch: 32



(train) Epoch=32, lr=0.0648 loss=1.489: 100%|██████████| 24/24 [00:03<00:00,  7.73it/s]

train_loss=1.4894, train_accuracy=44.442%



(val) Epoch=32, loss=1.540: 100%|██████████| 5/5 [00:00<00:00,  8.78it/s]

test_loss=1.5397, test_accuracy=42.190%
Epoch: 33



(train) Epoch=33, lr=0.0580 loss=1.458: 100%|██████████| 24/24 [00:03<00:00,  7.60it/s]

train_loss=1.4581, train_accuracy=45.740%



(val) Epoch=33, loss=1.453: 100%|██████████| 5/5 [00:00<00:00,  9.17it/s]

test_loss=1.4530, test_accuracy=45.570%
Epoch: 34



(train) Epoch=34, lr=0.0515 loss=1.438: 100%|██████████| 24/24 [00:03<00:00,  7.90it/s]

train_loss=1.4381, train_accuracy=46.417%



(val) Epoch=34, loss=1.534: 100%|██████████| 5/5 [00:00<00:00,  9.26it/s]

test_loss=1.5340, test_accuracy=44.210%
Epoch: 35



(train) Epoch=35, lr=0.0453 loss=1.427: 100%|██████████| 24/24 [00:03<00:00,  7.33it/s]

train_loss=1.4266, train_accuracy=47.070%



(val) Epoch=35, loss=1.410: 100%|██████████| 5/5 [00:00<00:00,  9.56it/s]

test_loss=1.4100, test_accuracy=46.620%
Epoch: 36



(train) Epoch=36, lr=0.0394 loss=1.410: 100%|██████████| 24/24 [00:03<00:00,  7.75it/s]

train_loss=1.4104, train_accuracy=47.618%



(val) Epoch=36, loss=1.423: 100%|██████████| 5/5 [00:00<00:00,  9.28it/s]

test_loss=1.4229, test_accuracy=46.570%
Epoch: 37



(train) Epoch=37, lr=0.0339 loss=1.402: 100%|██████████| 24/24 [00:03<00:00,  7.99it/s]

train_loss=1.4022, train_accuracy=47.764%



(val) Epoch=37, loss=1.443: 100%|██████████| 5/5 [00:00<00:00,  9.18it/s]

test_loss=1.4428, test_accuracy=46.710%
Epoch: 38



(train) Epoch=38, lr=0.0287 loss=1.384: 100%|██████████| 24/24 [00:02<00:00,  8.09it/s]

train_loss=1.3844, train_accuracy=48.710%



(val) Epoch=38, loss=1.437: 100%|██████████| 5/5 [00:00<00:00,  9.27it/s]

test_loss=1.4368, test_accuracy=47.680%
Epoch: 39



(train) Epoch=39, lr=0.0239 loss=1.376: 100%|██████████| 24/24 [00:03<00:00,  7.60it/s]

train_loss=1.3757, train_accuracy=49.184%



(val) Epoch=39, loss=1.402: 100%|██████████| 5/5 [00:00<00:00,  8.87it/s]

test_loss=1.4021, test_accuracy=47.900%
Epoch: 40



(train) Epoch=40, lr=0.0195 loss=1.358: 100%|██████████| 24/24 [00:02<00:00,  8.02it/s]

train_loss=1.3579, train_accuracy=49.522%



(val) Epoch=40, loss=1.321: 100%|██████████| 5/5 [00:00<00:00,  9.68it/s]

test_loss=1.3211, test_accuracy=51.530%
Epoch: 41



(train) Epoch=41, lr=0.0155 loss=1.349: 100%|██████████| 24/24 [00:03<00:00,  7.92it/s]

train_loss=1.3488, train_accuracy=49.986%



(val) Epoch=41, loss=1.434: 100%|██████████| 5/5 [00:00<00:00,  9.61it/s]

test_loss=1.4337, test_accuracy=47.850%
Epoch: 42



(train) Epoch=42, lr=0.0119 loss=1.351: 100%|██████████| 24/24 [00:02<00:00,  8.03it/s]

train_loss=1.3510, train_accuracy=49.797%



(val) Epoch=42, loss=1.467: 100%|██████████| 5/5 [00:00<00:00,  9.61it/s]

test_loss=1.4670, test_accuracy=46.680%
Epoch: 43



(train) Epoch=43, lr=0.0088 loss=1.342: 100%|██████████| 24/24 [00:03<00:00,  7.95it/s]


train_loss=1.3425, train_accuracy=50.350%


(val) Epoch=43, loss=1.435: 100%|██████████| 5/5 [00:00<00:00,  9.40it/s]

test_loss=1.4354, test_accuracy=46.600%
Epoch: 44



(train) Epoch=44, lr=0.0061 loss=1.340: 100%|██████████| 24/24 [00:03<00:00,  7.92it/s]


train_loss=1.3401, train_accuracy=50.346%


(val) Epoch=44, loss=1.312: 100%|██████████| 5/5 [00:00<00:00,  9.50it/s]

test_loss=1.3122, test_accuracy=51.200%
Epoch: 45



(train) Epoch=45, lr=0.0039 loss=1.333: 100%|██████████| 24/24 [00:03<00:00,  7.92it/s]

train_loss=1.3334, train_accuracy=50.732%



(val) Epoch=45, loss=1.427: 100%|██████████| 5/5 [00:00<00:00,  9.31it/s]

test_loss=1.4265, test_accuracy=48.110%
Epoch: 46



(train) Epoch=46, lr=0.0022 loss=1.329: 100%|██████████| 24/24 [00:03<00:00,  7.95it/s]

train_loss=1.3291, train_accuracy=51.196%



(val) Epoch=46, loss=1.344: 100%|██████████| 5/5 [00:00<00:00,  9.31it/s]


test_loss=1.3441, test_accuracy=50.620%
Epoch: 47


(train) Epoch=47, lr=0.0010 loss=1.332: 100%|██████████| 24/24 [00:03<00:00,  7.99it/s]

train_loss=1.3317, train_accuracy=51.180%



(val) Epoch=47, loss=1.420: 100%|██████████| 5/5 [00:00<00:00,  9.69it/s]

test_loss=1.4198, test_accuracy=47.000%
Epoch: 48



(train) Epoch=48, lr=0.0002 loss=1.328: 100%|██████████| 24/24 [00:03<00:00,  8.00it/s]

train_loss=1.3276, train_accuracy=50.918%



(val) Epoch=48, loss=1.355: 100%|██████████| 5/5 [00:00<00:00,  9.61it/s]

test_loss=1.3547, test_accuracy=49.960%
Epoch: 49



(train) Epoch=49, lr=0.0000 loss=1.322: 100%|██████████| 24/24 [00:03<00:00,  7.95it/s]

train_loss=1.3220, train_accuracy=51.404%



(val) Epoch=49, loss=1.297: 100%|██████████| 5/5 [00:00<00:00,  9.53it/s]

test_loss=1.2969, test_accuracy=51.580%





### Save supernet weights

In [22]:
torch.save(supernet.state_dict(), 'supernet.pth')

## Neural Architecture Search - Random Search

In [None]:
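def random_search(
        trained_supernet: nn.Module,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        device: torch.device,
        n_architectures_to_test: int,
        target_latency: float,
):
    # NOTE: `evaluate_accuracy` and `estimate_latency` are helpers that are not
    # shown in this notebook; they must be defined before running this cell.
    # Each entry is (accuracy, latency, architecture).
    results = []

    for _ in tqdm(range(n_architectures_to_test), desc="Random Search"):
        trained_supernet.sample_random_architecture()

        # Check the hard latency constraint first, so that architectures
        # over the budget are rejected before the costly accuracy evaluation.
        latency = estimate_latency(trained_supernet)
        if latency > target_latency:
            continue

        accuracy = evaluate_accuracy(trained_supernet, val_dataloader, device)
        current_architecture = [
            block.active_index.item()
            for block in trained_supernet.search_blocks
        ]
        results.append((accuracy, latency, current_architecture))

    # No sampled architecture satisfied the latency constraint.
    if not results:
        return 0.0, []

    # Select the feasible architecture with the highest accuracy.
    best_accuracy, _, best_architecture = max(results, key=lambda x: x[0])

    return best_accuracy, best_architecture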
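
The search logic can be tried out without a trained supernet by swapping in toy cost and score functions. The sketch below is such a stand-in under stated assumptions: `constrained_random_search`, `toy_cost`, and `toy_score` are hypothetical names, the cost is simply the sum of the chosen channel multipliers, and the score pretends that wider blocks are always more accurate.

```python
import random
from typing import Callable, List, Optional, Tuple

Arch = Tuple[int, ...]


def constrained_random_search(
    n_iterations: int,
    n_blocks: int,
    n_ops: int,
    cost_fn: Callable[[Arch], float],
    score_fn: Callable[[Arch], float],
    max_cost: float,
    seed: int = 0,
) -> Tuple[Optional[Arch], List[Tuple[Arch, float, float]]]:
    """Random search that keeps only architectures within the cost budget."""
    rng = random.Random(seed)
    history: List[Tuple[Arch, float, float]] = []
    best: Optional[Arch] = None
    best_score = float("-inf")
    for _ in range(n_iterations):
        arch = tuple(rng.randrange(n_ops) for _ in range(n_blocks))
        cost = cost_fn(arch)
        if cost > max_cost:
            continue  # hard constraint: skip before the expensive evaluation
        score = score_fn(arch)
        history.append((arch, score, cost))
        if score > best_score:
            best, best_score = arch, score
    return best, history


MULTIPLIERS = (0.5, 1.0, 2.0)


def toy_cost(arch: Arch) -> float:
    """Toy latency proxy: sum of the selected channel multipliers."""
    return sum(MULTIPLIERS[i] for i in arch)


def toy_score(arch: Arch) -> float:
    """Toy accuracy proxy in [0, 1] that favors wider blocks."""
    return sum(arch) / (2 * len(arch))


best, history = constrained_random_search(200, 8, 3, toy_cost, toy_score, max_cost=8.0)
print(best, len(history))
```

Keeping the full `history` of (architecture, score, cost) triples makes it easy to plot accuracy against latency afterwards, instead of retaining only the single best result.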

In [None]:
n_architectures_to_test = 100
target_latency = 30 * 1e6

accuracy, best_architecture = random_search(
    supernet, train_dataloader, test_dataloader, device, n_architectures_to_test, target_latency,
)
print(f'best architecture: {best_architecture} (test_accuracy={accuracy:.3%})')

## Train found architecture from scratch

Go to the `Create your baseline model` section.

Replace the model definition with the following to build the uninitialized best architecture from the search space:

```python
model = supernet18(num_classes=10, zero_init_residual=True, channel_multipliers=channel_multipliers)
model.to(device=device)
model.sample(best_architecture)
```

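After this, rerun the cells from that section through the end of the `Train baseline model` section. Change the file name in the save cell first so that the baseline weights are not overwritten.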