## Resolution search

This notebook describes the basic steps needed to optimize an architecture under a fixed latency constraint and to search for the best input resolution using the NAS framework.

### Main chapters of this notebook:
1. Set up the environment
1. Prepare dataset and create dataloaders
1. Create model and move it into search space
1. Pretrain constructed search space on different resolutions
1. Search the best architecture and resolution
1. Tune model with the best architecture on the best resolution

## Set up the environment
First, let's set up the environment and make some common imports.

In [None]:
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# You may need to uncomment and change this variable to match free GPU index
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
import sys

sys.path.append('../')

from pathlib import Path

import torch
import torch.nn as nn

from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import RAdam

from enot.latency import current_latency
from enot.models import SearchSpaceModel
from enot.models.mobilenet import build_mobilenet
from enot.optimize import FixedLatencySearchOptimizer
from enot.optimize import PretrainOptimizer

from enot.experimental.resolution_search import ConstantResolutionStrategy
from enot.experimental.resolution_search import PretrainResolutionStrategy
from enot.experimental.resolution_search import ResolutionSearcherWithFixedLatencyIterator

from tutorial_utils.train import accuracy
from tutorial_utils.train import WarmupScheduler

from tutorial_utils.checkpoints import download_resolution_search_pretrain_checkpoint
from tutorial_utils.dataset import create_imagenette_dataloaders

### In the following cell we set up all necessary directories

* `HOME_DIR` - experiments home directory
* `DATASETS_DIR` - root directory for datasets (imagenette2, ...)
* `PROJECT_DIR` - project directory to save training logs, checkpoints, ...

In [None]:
HOME_DIR = Path.home() / '.optimization_experiments'
DATASETS_DIR = HOME_DIR / 'datasets'
PROJECT_DIR = HOME_DIR / 'classification_resolution_search'

HOME_DIR.mkdir(exist_ok=True)
DATASETS_DIR.mkdir(exist_ok=True)
PROJECT_DIR.mkdir(exist_ok=True)

## Prepare dataset and create dataloaders

In [None]:
dataloaders = create_imagenette_dataloaders(
    dataset_root_dir=DATASETS_DIR,
    project_dir=PROJECT_DIR,
    input_size=(224, 224),
    batch_size=32,
)

## Create model and move it into search space

In [None]:
# The search space offers these ops as the candidate choices in each layer.
# Short format for operations is 'Name_param1=value1_param2=value2...'.
# MIB is a MNv2 inverted bottleneck, k is a kernel size for depthwise
# convolution, and t is an expansion ratio coefficient.
# See more in-depth info in "Tutorial - adding custom operations".
SEARCH_OPS = [
    'MIB_k=3_t=6',
    'MIB_k=5_t=6',
    'MIB_k=7_t=6',
    'conv1x1-skip',
]

# build model
model = build_mobilenet(
    search_ops=SEARCH_OPS,
    num_classes=10,
    blocks_out_channels=[24, 32, 64, 96, 160, 320],
    blocks_count=[2, 2, 2, 1, 2, 1],
    blocks_stride=[2, 2, 2, 1, 2, 1],
)

# move model to search space
search_space = SearchSpaceModel(model).cuda()
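
The short operation format mentioned in the comments above (`'Name_param1=value1_param2=value2...'`) can be unpacked by hand to see what each entry of `SEARCH_OPS` encodes. The sketch below is only an illustration: `parse_op_string` is a hypothetical helper, not part of ENOT, and it assumes that every parameter value is an integer, which holds for the ops listed here.

```python
def parse_op_string(op):
    """Split a short op description into (name, params).

    Hypothetical helper for illustration only; ENOT has its own parser.
    Assumes the 'Name_param1=value1_param2=value2' layout with integer values.
    """
    name, *raw_params = op.split('_')
    params = {}
    for raw in raw_params:
        key, value = raw.split('=', 1)
        params[key] = int(value)
    return name, params


SEARCH_OPS = [
    'MIB_k=3_t=6',
    'MIB_k=5_t=6',
    'MIB_k=7_t=6',
    'conv1x1-skip',
]

parsed_ops = [parse_op_string(op) for op in SEARCH_OPS]
for name, params in parsed_ops:
    print(name, params)
```

An op without parameters, such as `'conv1x1-skip'`, simply yields an empty parameter dictionary.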

## Pretrain constructed search space
The pretrain phase for resolution search is similar to the regular pretrain procedure. You can find the detailed description of the regular pretrain procedure in <span style="color:green;white-space:nowrap">***1. Tutorial - getting started***</span>.

These are the required extra steps for resolution search:
1. Create `PretrainResolutionStrategy` iterator object.
1. Wrap your train and validation dataloaders with this iterator to generate resized images.

This makes your pretrain procedure resilient to different resolution values.

**IMPORTANT: `PretrainResolutionStrategy.__call__(...)` returns an iterator, so you must apply the strategy to train_loader and validation_loader again on every epoch.**
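
The reason for re-applying the strategy on every epoch is that a one-shot iterator is exhausted after a single pass. The sketch below shows this with a toy class. `ToyResolutionStrategy` is a hypothetical stand-in, not the ENOT implementation; drawing a random resolution per batch is an assumption about how such a strategy could behave, and the actual image resizing is left out.

```python
import random


class ToyResolutionStrategy:
    """Hypothetical stand-in for a resolution strategy (not the ENOT class)."""

    def __init__(self, min_resolution, max_resolution, seed=0):
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self._rng = random.Random(seed)

    def __call__(self, dataloader):
        # Generator method: every call creates a fresh one-shot iterator.
        for inputs, labels in dataloader:
            resolution = self._rng.randint(self.min_resolution, self.max_resolution)
            # A real strategy would resize `inputs` to `resolution` here;
            # the toy version only tags the batch with the chosen resolution.
            yield (resolution, inputs), labels


# A plain list stands in for a dataloader with three batches.
toy_loader = [('batch0', 0), ('batch1', 1), ('batch2', 2)]
strategy = ToyResolutionStrategy(min_resolution=100, max_resolution=300)

wrapped_once = strategy(toy_loader)
first_epoch = list(wrapped_once)
second_epoch_reused = list(wrapped_once)        # the iterator is already exhausted
second_epoch_fresh = list(strategy(toy_loader))  # re-applying the strategy works

print(len(first_epoch), len(second_epoch_reused), len(second_epoch_fresh))
```

Reusing the wrapped object silently produces an empty epoch, which is why the training loops call the strategy inside the `for epoch in ...` loop.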

In [None]:
N_EPOCHS = 3
N_WARMUP_EPOCHS = 1

# Create PretrainResolutionStrategy for resolution range [min_resolution, max_resolution].
pretrain_resolution_strategy = PretrainResolutionStrategy(
    min_resolution=100,
    max_resolution=300,
)

constant_resolution_strategy = ConstantResolutionStrategy((300 + 100) // 2)

train_loader = dataloaders['pretrain_train_dataloader']
validation_loader = dataloaders['pretrain_validation_dataloader']

metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

# Using `search_space.model_parameters()` as optimizable variables.
optimizer = SGD(params=search_space.model_parameters(), lr=0.06, momentum=0.9, weight_decay=1e-4)
pretrain_optimizer = PretrainOptimizer(search_space=search_space, optimizer=optimizer)

len_train_loader = len(train_loader)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train_loader * N_EPOCHS, eta_min=1e-8)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train_loader * N_WARMUP_EPOCHS)

for epoch in range(N_EPOCHS):
    print(f'EPOCH #{epoch}')

    search_space.train()
    train_metrics_acc = {
        'loss': 0.0,
        'accuracy': 0.0,
        'n': 0,
    }
    # apply resolution strategy and start iteration
    for inputs, labels in pretrain_resolution_strategy(train_loader):
        if not search_space.output_distribution_optimization_enabled:
            search_space.initialize_output_distribution_optimization(inputs)

        pretrain_optimizer.zero_grad()

        def closure():
            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_loss.backward()
            batch_metric = metric_function(pred_labels, labels)

            train_metrics_acc['loss'] += batch_loss.item()
            train_metrics_acc['accuracy'] += batch_metric.item()
            train_metrics_acc['n'] += 1

        pretrain_optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()

    train_loss = train_metrics_acc['loss'] / train_metrics_acc['n']
    train_accuracy = train_metrics_acc['accuracy'] / train_metrics_acc['n']

    print('train metrics:')
    print('  loss:', train_loss)
    print('  accuracy:', train_accuracy)

    arch_to_test = [0] * len(search_space.search_variants_containers)
    test_model = search_space.get_network_by_indexes(arch_to_test)
    test_model.eval()

    validation_loss = 0
    validation_accuracy = 0
    with torch.no_grad():
        # apply resolution strategy and start iteration
        for inputs, labels in constant_resolution_strategy(validation_loader):
            pred_labels = test_model(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    print('validation metrics:')
    print('  loss:', validation_loss)
    print('  accuracy:', validation_accuracy)

    print()
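
The learning-rate schedule in the cell above combines `CosineAnnealingLR` with a warmup wrapper. The sketch below reproduces the closed-form cosine annealing formula from the PyTorch documentation in plain Python. The warmup part is an assumption: `WarmupScheduler` comes from the tutorial utilities, and the linear ramp used here is only one plausible way it could behave. `steps_per_epoch` is a made-up value.

```python
import math


def cosine_annealing_lr(step, base_lr, t_max, eta_min=0.0):
    # Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


def warmup_cosine_lr(step, base_lr, t_max, warmup_steps, eta_min=0.0):
    # ASSUMPTION: linear ramp over the first `warmup_steps` steps,
    # applied on top of the cosine value for the same step.
    lr = cosine_annealing_lr(step, base_lr, t_max, eta_min)
    if step < warmup_steps:
        lr *= (step + 1) / warmup_steps
    return lr


steps_per_epoch = 100            # hypothetical; the notebook uses len(train_loader)
t_max = steps_per_epoch * 3      # N_EPOCHS = 3
warmup_steps = steps_per_epoch   # N_WARMUP_EPOCHS = 1

for step in (0, 50, 100, 150, 300):
    lr = warmup_cosine_lr(step, base_lr=0.06, t_max=t_max,
                          warmup_steps=warmup_steps, eta_min=1e-8)
    print(f'step {step:3d}: lr = {lr:.6f}')
```

With these settings the rate starts near zero, reaches the cosine curve at the end of the first epoch, and decays to `eta_min` at the last step.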

In [None]:
# We pretrained search space for 3 epochs in this example. In this cell, we are downloading
# search space checkpoint, pretrained for 100 epochs (for demonstration purposes).

checkpoint_path = PROJECT_DIR / 'resolution_search_pretrain_checkpoint.pth'
download_resolution_search_pretrain_checkpoint(checkpoint_path)

search_space.load_state_dict(
    torch.load(checkpoint_path)['model'],
)

## Search the best architecture and resolution
The search phase for resolution search is similar to the regular search procedure. You can find the detailed description of the regular search procedure in <span style="color:green;white-space:nowrap">***1. Tutorial - getting started***</span>.

These are the required extra steps for resolution search:
1. Create `FixedLatencySearchOptimizer` instance as described in <span style="color:green;white-space:nowrap">***6. Tutorial - search with the specified latency***</span>.
1. Create `ResolutionSearcherWithFixedLatencyIterator` iterator object and provide latency search bounds.
1. Iterate over `ResolutionSearcherWithFixedLatencyIterator` with a `for` loop. On each step, the iterator:
    1. Resets your search space state and initializes latency for the current resolution.
    1. Resets your search optimizer state.
    1. Returns a `(resolution, resolution_strategy)` pair, where the first item is the resolution for the current search step and the second is the transform that you should apply to your dataloaders.
1. You must follow the regular search phase steps for each iteration over `ResolutionSearcherWithFixedLatencyIterator`.
1. Before each next iteration, you should pass your last validation target metric to the `set_resolution_target_metric` method of `ResolutionSearcherWithFixedLatencyIterator`.

**IMPORTANT: `resolution_strategy.__call__(...)` returns an iterator, so you must apply the strategy to train_loader and validation_loader again on every epoch.**
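
The "iterate, evaluate, report the metric" protocol from the steps above can be illustrated with a toy searcher. Everything in this sketch is hypothetical: `ToyResolutionSearcher` is not the ENOT class, the ternary-search update rule is an assumption chosen only to make the example concrete, and `toy_validation_metric` replaces a real search run with a made-up function.

```python
class ToyResolutionSearcher:
    """Hypothetical stand-in that mimics the iterate-then-report protocol.

    This is NOT the ENOT implementation; the ternary-search rule below
    is an assumption used for illustration.
    """

    def __init__(self, min_resolution, max_resolution, resolution_tolerance):
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.resolution_tolerance = resolution_tolerance
        self.best_resolution = None
        self.best_metric = float('-inf')
        self._last_metric = None

    def set_resolution_target_metric(self, metric):
        # The caller reports the metric for the resolution it just received.
        self._last_metric = metric

    def __iter__(self):
        low, high = self.min_resolution, self.max_resolution
        while high - low > self.resolution_tolerance:
            third = max(1, (high - low) // 3)
            left, right = low + third, high - third
            metrics = []
            for resolution in (left, right):
                self._last_metric = None
                yield resolution  # control returns to the caller's loop body
                if self._last_metric is None:
                    raise RuntimeError('Report a metric before requesting the next resolution.')
                metrics.append(self._last_metric)
                if self._last_metric > self.best_metric:
                    self.best_metric = self._last_metric
                    self.best_resolution = resolution
            # Keep the part of the interval that contains the better probe.
            if metrics[0] < metrics[1]:
                low = left
            else:
                high = right


def toy_validation_metric(resolution):
    # Made-up metric that peaks at resolution 200.
    return -((resolution - 200) ** 2)


searcher = ToyResolutionSearcher(min_resolution=100, max_resolution=300, resolution_tolerance=32)
visited = []
for resolution in searcher:
    visited.append(resolution)
    searcher.set_resolution_target_metric(toy_validation_metric(resolution))

print('visited resolutions:', visited)
print('best resolution:', searcher.best_resolution)
```

Skipping the `set_resolution_target_metric` call makes the toy searcher raise an error on the next step, which mirrors why the metric must be reported before each next iteration.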

In [None]:
# These parameters are chosen for illustrative purposes.
# Uncomment the lines below to get a better-performing searched model.
N_EPOCHS = 2
# Note: WARMUP_STEPS is measured in epochs (it is multiplied by len(train_loader) below).
WARMUP_STEPS = 1
resolution_tolerance = 32

# N_EPOCHS = 50
# WARMUP_STEPS = 5
# resolution_tolerance = 8

target_latency = 100.0

metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

train_loader = dataloaders['search_train_dataloader']
validation_loader = dataloaders['search_validation_dataloader']

optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)
len_train_loader = len(train_loader)

search_optimizer = FixedLatencySearchOptimizer(
    search_space,
    optimizer,
    max_latency_value=target_latency,
)

# Create ResolutionSearcherWithFixedLatencyIterator object
# for a fixed resolution range and a fixed target latency.
search_resolution_iter = ResolutionSearcherWithFixedLatencyIterator(
    search_optimizer=search_optimizer,
    dataloader=train_loader,
    latency_type='mmac.fvcore',
    min_resolution=100,
    max_resolution=300,
    resolution_tolerance=resolution_tolerance,
)

for r_step, (resolution, resolution_strategy) in enumerate(search_resolution_iter):
    print(f'RESOLUTION_SEARCH_STEP #{r_step}')
    print(f'CURRENT RESOLUTION: {resolution}')

    # We should not re-create the search optimizer, as its state is controlled by the resolution searcher.

    # We should re-create scheduler as it is not updated by the resolution searcher.
    scheduler = CosineAnnealingLR(optimizer, T_max=len_train_loader * N_EPOCHS, eta_min=1e-8)
    scheduler = WarmupScheduler(scheduler, warmup_steps=len_train_loader * WARMUP_STEPS)

    validation_accuracy, latency = None, None
    for epoch in range(N_EPOCHS):
        print(f'EPOCH #{epoch}')

        search_space.train()
        train_metrics_acc = {
            'loss': 0.0,
            'accuracy': 0.0,
            'n': 0,
        }
        # Apply resolution strategy and iterate over samples.
        for inputs, labels in resolution_strategy(train_loader):
            search_optimizer.zero_grad()

            def closure():
                pred_labels = search_space(inputs)
                batch_loss = loss_function(pred_labels, labels)
                batch_loss = search_optimizer.modify_loss(batch_loss)

                batch_loss.backward()
                batch_metric = metric_function(pred_labels, labels)

                train_metrics_acc['loss'] += batch_loss.item()
                train_metrics_acc['accuracy'] += batch_metric.item()
                train_metrics_acc['n'] += 1

            search_optimizer.step(closure)
            scheduler.step()

        train_loss = train_metrics_acc['loss'] / train_metrics_acc['n']
        train_accuracy = train_metrics_acc['accuracy'] / train_metrics_acc['n']

    search_space.eval()
    # The prepare_validation_model function samples an optimal architecture
    # in the search space and prepares it for the validation procedure.
    search_optimizer.prepare_validation_model()

    validation_loss = 0
    validation_accuracy = 0
    with torch.no_grad():
        # Apply resolution strategy and iterate over samples.
        for inputs, labels in resolution_strategy(validation_loader):
            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    # Get the latency of the currently sampled architecture.
    # The current architecture is the best one found for this resolution step,
    # as it is sampled by the prepare_validation_model function.
    latency = current_latency(search_space)

    print('train accuracy:', train_accuracy)
    print('train loss:', train_loss)
    print('validation accuracy:', validation_accuracy)
    print('validation loss:', validation_loss)
    print('latency:', latency)
    print()

    # Apply the obtained target metric (validation accuracy) to the resolution search object.
    search_resolution_iter.set_resolution_target_metric(validation_accuracy)
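
The search above keeps latency fixed while the resolution changes, so it helps to see how strongly resolution affects the cost of a layer. The sketch below counts multiply-accumulate operations (MACs) for a single convolution, assuming that `'mmac.fvcore'` denotes millions of MACs counted with fvcore (an assumption based on the name). The layer parameters are made up, and `conv2d_macs` is a hypothetical helper that assumes "same"-style padding.

```python
import math


def conv2d_macs(height, width, in_channels, out_channels, kernel_size, stride=1, groups=1):
    """MACs of one 2D convolution (hypothetical helper, 'same'-style padding assumed)."""
    out_height = math.ceil(height / stride)
    out_width = math.ceil(width / stride)
    macs_per_output = (in_channels // groups) * kernel_size ** 2
    return out_height * out_width * out_channels * macs_per_output


# Made-up stem layer: 3x3 convolution, stride 2, 3 -> 32 channels.
macs_at_100 = conv2d_macs(100, 100, in_channels=3, out_channels=32, kernel_size=3, stride=2)
macs_at_300 = conv2d_macs(300, 300, in_channels=3, out_channels=32, kernel_size=3, stride=2)

print('MMACs at 100x100:', macs_at_100 / 1e6)
print('MMACs at 300x300:', macs_at_300 / 1e6)
print('ratio:', macs_at_300 / macs_at_100)
```

Because the cost grows with the number of output pixels, tripling the side length costs about nine times more for this layer. Under a fixed latency budget, a lower resolution therefore leaves room for heavier operations, and a higher resolution forces lighter ones.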

In [None]:
# Now our search space is initialized with the best state
# (see ResolutionSearcherWithFixedLatencyIterator documentation).

best_resolution = search_resolution_iter.best_resolution
best_architecture = search_resolution_iter.best_architecture

print(f'Best resolution is {best_resolution}')
print(f'Best architecture is {best_architecture}')
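
To make the printed architecture easier to read, the indexes can be mapped back to operation names. This sketch assumes that `best_architecture` is a list with one index per searchable layer and that each index refers to the corresponding position in `SEARCH_OPS`; both points are assumptions, and `example_architecture` is made up.

```python
SEARCH_OPS = [
    'MIB_k=3_t=6',
    'MIB_k=5_t=6',
    'MIB_k=7_t=6',
    'conv1x1-skip',
]

# Made-up example; in the notebook this would be `best_architecture`.
example_architecture = [0, 2, 1, 3, 0, 0, 1, 2, 3, 0]


def describe_architecture(architecture, search_ops):
    # ASSUMPTION: each index selects the op at the same position in `search_ops`.
    return [search_ops[index] for index in architecture]


described = describe_architecture(example_architecture, SEARCH_OPS)
for layer_number, op_name in enumerate(described):
    print(f'layer {layer_number}: {op_name}')
```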

## Tune model with the best architecture
Now we take the best architecture from the search space and build a regular model from it. Then we run the fine-tuning procedure (a usual training loop).

In [None]:
# get regular model with the best architecture
best_model = search_space.get_network_by_indexes(best_architecture).cuda()

In [None]:
N_EPOCHS = 5

optimizer = SGD(best_model.parameters(), lr=2e-4, momentum=0.9, weight_decay=1e-4)
scheduler = None
metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

train_loader = dataloaders['tune_train_dataloader']
validation_loader = dataloaders['tune_validation_dataloader']

# ConstantResolutionStrategy is a wrapper over your dataloader that produces images of a fixed resolution.
# You can use it to reproduce the data processing pipeline from the resolution search procedure.
to_best_resolution = ConstantResolutionStrategy(resolution=best_resolution)

for epoch in range(N_EPOCHS):
    print(f'EPOCH #{epoch}')

    best_model.train()
    train_loss = 0.0
    train_accuracy = 0.0
    for inputs, labels in to_best_resolution(train_loader):
        optimizer.zero_grad()

        pred_labels = best_model(inputs)
        batch_loss = loss_function(pred_labels, labels)
        batch_loss.backward()
        batch_metric = metric_function(pred_labels, labels)

        train_loss += batch_loss.item()
        train_accuracy += batch_metric.item()

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

    n = len(train_loader)
    train_loss /= n
    train_accuracy /= n

    print('train metrics:')
    print('  loss:', train_loss)
    print('  accuracy:', train_accuracy)

    best_model.eval()
    validation_loss = 0
    validation_accuracy = 0
    with torch.no_grad():
        for inputs, labels in to_best_resolution(validation_loader):
            pred_labels = best_model(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    print('validation metrics:')
    print('  loss:', validation_loss)
    print('  accuracy:', validation_accuracy)

    print()
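
The loops in this notebook average per-batch metrics over the number of batches. If the last batch is smaller than the others, that average differs slightly from the per-sample average. The sketch below shows a sample-weighted alternative. `RunningAverage` is a hypothetical helper, and it assumes that the metric function returns the mean over the batch.

```python
class RunningAverage:
    """Sample-weighted running mean (hypothetical helper for illustration)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_mean, batch_size):
        # Convert the batch mean back to a sum so every sample weighs the same.
        self.total += batch_mean * batch_size
        self.count += batch_size

    @property
    def value(self):
        return self.total / self.count if self.count else 0.0


# Made-up example: a full batch of 32 samples and a final batch of 8 samples.
batches = [(0.5, 32), (1.0, 8)]

meter = RunningAverage()
for batch_accuracy, batch_size in batches:
    meter.update(batch_accuracy, batch_size)

batch_level_mean = sum(acc for acc, _ in batches) / len(batches)

print('mean over batches:', batch_level_mean)
print('mean over samples:', meter.value)
```

In a training loop, `batch_size` could be taken from `labels.shape[0]`. When all batches have the same size, both averages agree.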