## Using latency optimization with custom operations

This notebook describes the additional steps required to enable latency optimization for custom models.

### Main chapters of this notebook:
1. Set up the environment
1. Prepare dataset and create dataloaders
1. Adding custom modules (head/stem/custom_operation) with latency to use them in search space
1. Build model with custom operations
1. Check pretrain, search and tune phases
    1. Pretrain search space
    1. Search without latency loss
    1. Search with latency loss
    1. Tune the found architecture

## Set up the environment

First, let's set up the environment and make some common imports.

In [None]:
import os

# Enumerate CUDA devices in PCI bus ID order so GPU indices are predictable.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID')
# To pin the notebook to a specific free GPU, uncomment the line below and set its index:
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
from pathlib import Path

import torch
import torch.nn as nn

from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_optimizer import RAdam

from pkg_resources import parse_version, get_distribution

if parse_version(get_distribution('torchvision').version) < parse_version('0.9'):
    from torchvision.models.mobilenet import ConvBNReLU
else:
    from torchvision.models.mobilenetv2 import ConvBNReLU

from enot.models import SearchSpaceModel
from enot.models.operations import SearchableMobileInvertedBottleneck
from enot.models.operations import SearchVariantsContainer
from enot.optimize import EnotPretrainOptimizer
from enot.optimize import EnotSearchOptimizer
from enot.latency import conv_mac_count
from enot.latency import LatencyMixin
from enot.latency import initialize_latency
from enot.latency import best_arch_latency
from enot.utils.common import input_hw

from tutorial_utils.train import accuracy
from tutorial_utils.train import WarmupScheduler

from tutorial_utils.dataset import create_imagenette_dataloaders
from tutorial_utils.phases import tutorial_pretrain_loop
from tutorial_utils.phases import tutorial_search_loop
from tutorial_utils.phases import tutorial_train_loop

## Prepare dataset and create dataloaders

In [None]:
ENOT_HOME_DIR = Path.home() / '.enot'
ENOT_DATASETS_DIR = ENOT_HOME_DIR / 'datasets'
PROJECT_DIR = ENOT_HOME_DIR / 'using_latency_optimization'

# Parent first, then its children; existing directories are left untouched.
for _directory in (ENOT_HOME_DIR, ENOT_DATASETS_DIR, PROJECT_DIR):
    _directory.mkdir(exist_ok=True)

In [None]:
# Imagenette ('imagenette2-320' variant) resized to 224x224, 32 images per batch.
dataloaders = create_imagenette_dataloaders(
    dataset_root_dir=ENOT_DATASETS_DIR,
    project_dir=PROJECT_DIR,
    imagenette_kind='imagenette2-320',
    batch_size=32,
    input_size=(224, 224),
)

## Adding custom modules (head/stem/custom_operation) with latency to use them in search space

To add latency support to your module, you need to implement its latency calculation.

Adding latency calculation is done in two steps:
1. Make your class a subclass of `enot.latency.LatencyMixin` (list it after your PyTorch base class, as in the examples below). `LatencyMixin` is part of the latency calculation mechanism in `SearchSpaceModel`; with this mixin, you can define latency calculation methods for custom modules (see the next step).
2. Add a method with the signature `latency_<name>(self, inputs) -> float` that calculates the latency of the module for the given inputs. Here `<name>` is the latency type; for example, `latency_mmac` implements the `'mmac'` type.

Currently, only the `'mmac'` (millions of multiply-accumulates) latency type is supported.

#### Define custom operation with latency

In [None]:
class MyOperation(nn.Module, LatencyMixin):
    """Custom searchable operation: Conv2d -> ReLU -> BatchNorm2d, plus an optional residual.

    The residual connection is kept only when it is shape-preserving, i.e. when
    ``in_channels == out_channels`` and ``stride == 1``.
    Latency is reported in millions of multiply-accumulates by ``latency_mmac``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_size=3,
        padding=None,
        use_skip_connection=True,
    ):
        super().__init__()

        # Default padding (k - 1) // 2 keeps the spatial size for odd kernels at stride 1.
        padding = (kernel_size - 1) // 2 if padding is None else padding

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # A residual add only works when input and output tensors have the same shape.
        self.use_skip_connection = use_skip_connection and in_channels == out_channels and stride == 1

        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.operation = nn.Sequential(conv, nn.ReLU(), nn.BatchNorm2d(out_channels))

    def forward(self, x):
        result = self.operation(x)
        return result + x if self.use_skip_connection else result

    def latency_mmac(self, inputs):
        """Return the convolution cost in millions of multiply-accumulates (MMACs).

        Only the convolution is counted; ReLU, batch norm and the residual add are ignored.
        """
        height, width = input_hw(inputs)  # spatial size of the operation input

        def output_length(length):
            # Convolution output length: floor((L + 2 * padding - kernel) / stride) + 1.
            return (length + 2 * self.padding - self.kernel_size) // self.stride + 1

        macs = (
            output_length(height) * output_length(width)
            * self.kernel_size ** 2 * self.in_channels * self.out_channels
        )
        return macs / 10 ** 6

#### Define head and stem for your model

In [None]:
class MyStem(ConvBNReLU, LatencyMixin):
    """Stem block: torchvision's ``ConvBNReLU`` extended with an MMAC latency estimate.

    ``ConvBNReLU`` pads its convolution with ``(kernel_size - 1) // 2``, so the
    latency estimate uses that same padding; otherwise the counted output size
    (and hence the MAC count) would be smaller than the real one.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            groups=1,
            norm_layer=None,
    ):
        # Everything after the channel counts is passed by keyword. In some
        # torchvision versions ConvBNReLU forwards to a constructor whose
        # positional order is (..., stride, padding, groups, norm_layer, ...),
        # so positional ``groups``/``norm_layer`` would land in the wrong slots.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            norm_layer=norm_layer,
        )

        # Convolution hyperparameters used by ``latency_mmac``.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        # Same padding rule as ConvBNReLU applies to its Conv2d.
        self.padding = (kernel_size - 1) // 2

    def latency_mmac(self, inputs):
        """Return the stem convolution cost in millions of multiply-accumulates.

        Only the spatial size of ``inputs`` is used.
        """
        spatial_size = input_hw(inputs)  # input_hw simply unpacks the spatial size of inputs
        mmac, _ = conv_mac_count(
            spatial_size=spatial_size,
            kernel_size=self.kernel_size,
            stride=self.stride,
            in_channels=self.in_channels,
            padding=self.padding,
            out_channels=self.out_channels,
            groups=self.groups,
        )

        return mmac


class MyHead(nn.Sequential, LatencyMixin):
    """Classification head: 1x1 ConvBNReLU -> global average pool -> flatten -> dropout -> linear.

    Latency is reported in millions of multiply-accumulates by ``latency_mmac``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_classes,
        dropout_rate=0.2,
    ):
        layers = (
            ConvBNReLU(in_channels, hidden_channels, kernel_size=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_channels, num_classes),
        )
        super().__init__(*layers)

        # Layer sizes needed by ``latency_mmac``.
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes

    def latency_mmac(self, inputs):
        """Return the head cost in millions of multiply-accumulates (MMACs).

        Sums the 1x1 convolution, one accumulate per element for the global
        average pool, and the final linear layer.
        """
        conv_mmac, (out_h, out_w) = conv_mac_count(
            spatial_size=input_hw(inputs),
            kernel_size=1,
            stride=1,
            in_channels=self.in_channels,
            padding=0,
            out_channels=self.hidden_channels,
        )
        pool_mmac = out_h * out_w * self.hidden_channels / 10 ** 6
        linear_mmac = self.hidden_channels * self.num_classes / 10 ** 6
        return conv_mmac + pool_mmac + linear_mmac

## Build model with custom operations

In [None]:
class MyModel(nn.Module):
    """Backbone for the search space: MyStem -> 10 searchable body blocks -> MyHead (10 classes).

    Each body block is a ``SearchVariantsContainer`` holding three candidate
    operations for the architecture search.
    """

    def __init__(self):
        super().__init__()

        self.stem = MyStem(in_channels=3, out_channels=16, stride=2)

        # (in_channels, out_channels, stride) of every body position, in order.
        body_config = (
            (16, 24, 2),
            (24, 24, 1),
            (24, 32, 2),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 96, 1),
            (96, 160, 2),
            (160, 160, 1),
            (160, 320, 1),
        )
        self.body = nn.ModuleList(
            [self.build_search_variants(c_in, c_out, s) for c_in, c_out, s in body_config]
        )

        self.head = MyHead(in_channels=320, hidden_channels=1280, num_classes=10)

    @staticmethod
    def build_search_variants(in_channels, out_channels, stride):
        """Return a container with the candidate operations for one body position.

        Candidates, in order: inverted bottlenecks with kernel 3 and kernel 5
        (expand ratio 6), then the custom ``MyOperation`` with a 7x7 kernel.
        """
        bottlenecks = [
            SearchableMobileInvertedBottleneck(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                stride=stride,
                expand_ratio=6,
            )
            for kernel in (3, 5)
        ]
        custom_op = MyOperation(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=7,
            stride=stride,
        )
        return SearchVariantsContainer([*bottlenecks, custom_op])

    def forward(self, x):
        features = self.stem(x)
        for search_block in self.body:
            features = search_block(features)
        return self.head(features)

In [None]:
# Regular PyTorch model built from the custom modules defined above.
model = MyModel()

# Wrap it into an ENOT search space, then move that search space to the GPU.
search_space = SearchSpaceModel(model)
search_space = search_space.cuda()

## Check pretrain, search and tune phases

Let's check that everything works.

In this tutorial we use the same pretrain/search/train loops as in <span style="color:green;white-space:nowrap">***1. Tutorial - getting started***</span>.

**IMPORTANT**:<br>
To turn on latency optimization, you must pass a positive `latency_loss_weight` to `tutorial_search_loop`, together with a `latency_type` such as `'mmac'`.

### Pretrain search space

In [None]:
N_EPOCHS = 3
N_WARMUP_EPOCHS = 1
# Schedule lengths below are expressed in batches of the pretraining train loader.
len_train = len(dataloaders['pretrain_train_dataloader'])

# Pretraining optimizes model_parameters(); the search phase optimizes architecture_parameters().
optimizer = SGD(
    params=search_space.model_parameters(),
    lr=0.06,
    momentum=0.9,
    weight_decay=1e-4,
)
enot_optimizer = EnotPretrainOptimizer(search_space=search_space, optimizer=optimizer)

# Cosine annealing over all pretraining batches, wrapped with a warmup of N_WARMUP_EPOCHS epochs.
scheduler = WarmupScheduler(
    CosineAnnealingLR(optimizer, T_max=len_train * N_EPOCHS, eta_min=1e-8),
    warmup_steps=len_train * N_WARMUP_EPOCHS,
)
loss_function = nn.CrossEntropyLoss().cuda()

tutorial_pretrain_loop(
    epochs=N_EPOCHS,
    search_space=search_space,
    enot_optimizer=enot_optimizer,
    scheduler=scheduler,
    loss_function=loss_function,
    metric_function=accuracy,
    train_loader=dataloaders['pretrain_train_dataloader'],
    validation_loader=dataloaders['pretrain_validation_dataloader'],
)

### Search without latency loss

In [None]:
# The search phase optimizes only the architecture parameters of the search space.
optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)
enot_optimizer = EnotSearchOptimizer(search_space=search_space, optimizer=optimizer)

# Back up the search space state so the latency-aware search later in the
# notebook can start from exactly the same point.
backup_path = PROJECT_DIR / 'search_space_backup.pth'
torch.save({'model': search_space.state_dict()}, backup_path)

# Baseline search: latency optimization is disabled (zero weight, no latency type).
tutorial_search_loop(
    epochs=5,
    search_space=search_space,
    enot_optimizer=enot_optimizer,
    scheduler=None,
    loss_function=loss_function,
    metric_function=accuracy,
    train_loader=dataloaders['search_train_dataloader'],
    validation_loader=dataloaders['search_validation_dataloader'],
    latency_loss_weight=0,
    latency_type=None,
)

In [None]:
# Select the best architecture found by the search and switch to evaluation mode.
search_space.sample_best_arch()
search_space.eval()

# Initialize 'mmac' latency from one real input batch, then measure the best architecture.
sample_batch, _ = next(iter(dataloaders['search_train_dataloader']))
initialize_latency('mmac', search_space, (sample_batch,))
latency_0 = best_arch_latency(search_space)

print('best architecture latency (latency_loss_weight == 0) =', latency_0)

### Search with latency loss

In [None]:
# Restore the state saved before the baseline search, so both searches start identically.
checkpoint = torch.load(PROJECT_DIR / 'search_space_backup.pth')
checkpoint_data = checkpoint
search_space.load_state_dict(checkpoint_data['model'])

# Fresh architecture optimizer for the second search.
optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)
enot_optimizer = EnotSearchOptimizer(search_space=search_space, optimizer=optimizer)

# Same search as before, but with a positive latency_loss_weight and 'mmac' latency,
# which turns latency optimization on.
tutorial_search_loop(
    epochs=5,
    search_space=search_space,
    enot_optimizer=enot_optimizer,
    scheduler=None,
    loss_function=loss_function,
    metric_function=accuracy,
    train_loader=dataloaders['search_train_dataloader'],
    validation_loader=dataloaders['search_validation_dataloader'],
    latency_loss_weight=2e-3,
    latency_type='mmac',
)

In [None]:
# Select the best architecture of the latency-aware search and switch to evaluation mode.
search_space.sample_best_arch()
search_space.eval()

# Measure its 'mmac' latency the same way as for the baseline search.
sample_batch, _ = next(iter(dataloaders['search_train_dataloader']))
initialize_latency('mmac', search_space, (sample_batch,))
latency_1 = best_arch_latency(search_space)

# Compare against the baseline search (latency_loss_weight == 0).
print('best architecture latency (latency_loss_weight > 0) =', latency_1)
print('best architecture latency (latency_loss_weight == 0) =', latency_0)

#### The search with latency loss should find an architecture with lower latency than the baseline search

### Tune best architecture

In [None]:
# Extract a regular model with the best architecture from the search space, then move it to the GPU.
best_model = search_space.get_network_with_best_arch()
best_model = best_model.cuda()

In [None]:
# Fine-tune all weights of the extracted model.
optimizer = RAdam(best_model.parameters(), lr=1e-3, weight_decay=1e-4)

tutorial_train_loop(
    epochs=5,
    model=best_model,
    optimizer=optimizer,
    scheduler=None,
    loss_function=loss_function,
    metric_function=accuracy,
    train_loader=dataloaders['tune_train_dataloader'],
    validation_loader=dataloaders['tune_validation_dataloader'],
)