## Resolution search with ENOT

This notebook describes the basic steps you need to optimize an architecture under a fixed latency constraint and to search for the best input resolution using the ENOT framework.

### Main chapters of this notebook:
1. Setup the environment
1. Prepare dataset and create dataloaders
1. Create model and move it into search space
1. Pretrain constructed search space on different resolutions
1. Search best architecture and resolution
1. Tune model with the best architecture on the best resolution

## Setup the environment
First, let's set up the environment and make some common imports.

In [None]:
import os

# Order CUDA devices by PCI bus ID and expose a single GPU to this process.
# You may need to change the 'CUDA_VISIBLE_DEVICES' value to match a free GPU index.
os.environ.update(
    {
        'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
        'CUDA_VISIBLE_DEVICES': '0',
    }
)

In [None]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from torch.optim import Adam
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_optimizer import RAdam

from enot.models import SearchSpaceModel
from enot.models.mobilenet import build_mobilenet
from enot.optimize import EnotPretrainOptimizer
from enot.optimize import EnotSearchOptimizer
from enot.latency import best_arch_latency

from enot.experimental.resolution_search import ConstantResolutionStrategy
from enot.experimental.resolution_search import PretrainResolutionStrategy
from enot.experimental.resolution_search import SearchResolutionWithFixedLatencyIterator

from tutorial_utils.train_utils import accuracy
from tutorial_utils.train_utils import WarmupScheduler

from tutorial_utils.checkpoints import download_resolution_search_pretrain_checkpoint
from tutorial_utils.dataset import create_imagenette_dataloaders

from tutorial_utils.phases import tutorial_pretrain_loop
from tutorial_utils.phases import tutorial_search_loop
from tutorial_utils.phases import tutorial_train_loop

### In the following cell we set up all necessary directories

* `ENOT_HOME_DIR` - ENOT framework home directory
* `ENOT_DATASETS_DIR` - root directory for datasets (imagenette2, ...)
* `PROJECT_DIR` - project directory to save training logs, checkpoints, ...

In [None]:
# Directory layout of the tutorial: everything lives under the ENOT home
# directory, which is placed inside the current user's home directory.
ENOT_HOME_DIR = Path.home() / '.enot'
ENOT_DATASETS_DIR = ENOT_HOME_DIR / 'datasets'
PROJECT_DIR = ENOT_HOME_DIR / 'classification_resolution_search'

# The parent is created before its children; directories that already
# exist are left untouched.
for _directory in (ENOT_HOME_DIR, ENOT_DATASETS_DIR, PROJECT_DIR):
    _directory.mkdir(exist_ok=True)

## Prepare dataset and create dataloaders

In [None]:
# Create all dataloaders of the tutorial with a single call. The result is
# indexed by dataloader name in the cells below
# (e.g. 'pretrain_train_dataloader').
dataloader_options = dict(
    dataset_root_dir=ENOT_DATASETS_DIR,
    project_dir=PROJECT_DIR,
    input_size=(224, 224),
    batch_size=32,
)
dataloaders = create_imagenette_dataloaders(**dataloader_options)

## Create model and move it into search space

In [None]:
# Candidate operations the search space can choose from in each layer.
# The short format of an operation is 'Name_param1=value1_param2=value2...'.
# MIB is a MNv2 inverted bottleneck, k is the kernel size of its depthwise
# convolution, and t is its expansion ratio coefficient.
# See more in-depth info in "Tutorial - adding custom operations".
SEARCH_OPS = [
    'MIB_k=3_t=6',
    'MIB_k=5_t=6',
    'MIB_k=7_t=6',
    'conv1x1-skip',
]

# Base model configuration: one entry per block in each of the lists.
mobilenet_options = dict(
    search_ops=SEARCH_OPS,
    num_classes=10,
    blocks_out_channels=[24, 32, 64, 96, 160, 320],
    blocks_count=[2, 2, 2, 1, 2, 1],
    blocks_stride=[2, 2, 2, 1, 2, 1],
)
model = build_mobilenet(**mobilenet_options)

# Wrap the model into a search space, then move the search space to the GPU.
search_space = SearchSpaceModel(model)
search_space = search_space.cuda()

## Pretrain constructed search space
The pretrain phase for resolution search is similar to the regular pretrain procedure. You can find a detailed description of the regular
pretrain procedure in 'Tutorial - getting started'.

Extra steps for resolution search:
1. You must create `PretrainResolutionStrategy` object.
1. You must use dataloader iterator generated by PretrainResolutionStrategy object.

**IMPORTANT: `PretrainResolutionStrategy.__call__(...)` returns an iterator, so you must apply the strategy on every epoch, passing train_loader and validation_loader as its argument**

In [None]:
N_EPOCHS = 3
N_WARMUP_EPOCHS = 1

# Create PretrainResolutionStrategy for the resolution range [min_resolution, max_resolution].
pretrain_resolution_strategy = PretrainResolutionStrategy(
    min_resolution=100,
    max_resolution=300,
)

train_loader = dataloaders['pretrain_train_dataloader']
validation_loader = dataloaders['pretrain_validation_dataloader']

metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

# Only `search_space.model_parameters()` are optimized during pretrain.
optimizer = SGD(params=search_space.model_parameters(), lr=0.06, momentum=0.9, weight_decay=1e-4)
enot_optimizer = EnotPretrainOptimizer(search_space=search_space, optimizer=optimizer)

# The scheduler is stepped once per batch (see the training loop below),
# so both the cosine period and the warmup length are expressed in batches.
len_train_loader = len(train_loader)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train_loader*N_EPOCHS, eta_min=1e-8)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train_loader*N_WARMUP_EPOCHS)

for epoch in range(N_EPOCHS):

    print(f'EPOCH #{epoch}')

    search_space.train()
    train_metrics_acc = {
        'loss': 0.0,
        'accuracy': 0.0,
        'n': 0,
    }
    # Apply the resolution strategy anew on every epoch and iterate over its batches.
    for inputs, labels in pretrain_resolution_strategy(train_loader):
        if not search_space.output_distribution_optimization_enabled:
            search_space.initialize_output_distribution_optimization(inputs)

        enot_optimizer.zero_grad()

        # NOTE(review): metrics are accumulated inside the closure, so 'n'
        # counts closure invocations rather than batches. Presumably the ENOT
        # optimizer may evaluate the closure more than once per step -
        # confirm against the ENOT documentation.
        def closure():
            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_loss.backward()
            batch_metric = metric_function(pred_labels, labels)

            train_metrics_acc['loss'] += batch_loss.item()
            train_metrics_acc['accuracy'] += batch_metric.item()
            train_metrics_acc['n'] += 1

        enot_optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()

    train_loss = train_metrics_acc['loss'] / train_metrics_acc['n']
    train_accuracy = train_metrics_acc['accuracy'] / train_metrics_acc['n']

    print('train metrics:')
    print('  loss:', train_loss)
    print('  accuracy:', train_accuracy)

    search_space.eval()
    validation_loss = 0
    validation_accuracy = 0
    # No backward pass is performed during validation, so gradient tracking
    # is disabled to avoid building autograd graphs for every batch.
    with torch.no_grad():
        # Apply the resolution strategy anew on every epoch and iterate over its batches.
        for inputs, labels in pretrain_resolution_strategy(validation_loader):
            # Each validation batch is evaluated on a freshly sampled random architecture.
            search_space.sample_random_arch()

            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    # Average over the number of batches in the validation dataloader.
    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    print('validation metrics:')
    print('  loss:', validation_loss)
    print('  accuracy:', validation_accuracy)

    print()

In [None]:
# We pretrained search space for 3 epochs in this example. In this cell, we are downloading
# search space checkpoint, pretrained for 100 epochs (for demonstration purposes).

checkpoint_path = PROJECT_DIR / 'resolution_search_pretrain_checkpoint.pth'
download_resolution_search_pretrain_checkpoint(checkpoint_path)

# Deserialize the checkpoint onto the CPU: without `map_location`, tensors are
# restored to the device they were saved from, which fails when that GPU index
# is not visible to this process. `load_state_dict` then copies the values
# into the existing parameters of `search_space`.
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints
# obtained from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
search_space.load_state_dict(
    checkpoint['model'],
)

## Search best architecture and resolution
The search phase for resolution search is similar to the regular search procedure. You can find a detailed description of the regular
search procedure in 'Tutorial - getting started'.

Extra steps for resolution search:
1. You must create a `SearchResolutionWithFixedLatencyIterator` iterator object for a fixed resolution range and a fixed target latency.
1. On every iteration, `SearchResolutionWithFixedLatencyIterator`:
    1. Resets the search space state and initializes latency for the current resolution
    1. Resets the optimizer state
    1. Returns a `SearchResolutionStepData` object, which contains enot_optimizer, latency_loss_weight, resolution_strategy, and metric_latency_cb for the current search step
1. You must follow the regular search phase steps using the parameters returned by the `SearchResolutionWithFixedLatencyIterator` object
1. Before the next iteration, you must send the final values of validation accuracy and latency of the best architecture using `metric_latency_cb`

**IMPORTANT: `resolution_strategy.__call__(...)` returns an iterator, so you must apply the strategy on every epoch, passing train_loader and validation_loader as its argument**

In [None]:
N_EPOCHS = 5

metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

train_loader = dataloaders['search_train_dataloader']
validation_loader = dataloaders['search_validation_dataloader']

# Only `search_space.architecture_parameters()` are optimized during search.
optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)

# Create the SearchResolutionWithFixedLatencyIterator iterator object
# for a fixed resolution range and a fixed target latency.
search_resolution_iter = SearchResolutionWithFixedLatencyIterator(
    search_space=search_space,
    optimizer=optimizer,
    data_loader=train_loader,
    target_latency=250.0,
    latency_tol=5.0,
    latency_type='mmac',
    min_resolution=100,
    max_resolution=300,
    resolution_tol=32,
    max_latency_loss_weight=0.01,
)

# Every item produced by the iterator describes one resolution search step:
# its optimizer, latency loss weight, resolution strategy and result callback.
for r_step, search_step_data in enumerate(search_resolution_iter):
    print(f'RESOLUTION_SEARCH_STEP #{r_step}')
    validation_accuracy, latency = None, None
    for epoch in range(N_EPOCHS):
        print(f'EPOCH #{epoch}')

        search_space.train()
        train_metrics_acc = {
            'loss': 0.0,
            'accuracy': 0.0,
            'n': 0,
        }
        # Apply the resolution strategy anew on every epoch and iterate over its batches.
        for inputs, labels in search_step_data.resolution_strategy(train_loader):
            search_step_data.enot_optimizer.zero_grad()

            # NOTE(review): metrics are accumulated inside the closure, so 'n'
            # counts closure invocations rather than batches. Presumably the
            # ENOT optimizer may evaluate the closure more than once per
            # step - confirm against the ENOT documentation.
            def closure():
                pred_labels = search_space(inputs)
                batch_loss = loss_function(pred_labels, labels)
                # Add the latency term, weighted by the current step's latency loss weight.
                batch_loss += search_space.loss_latency_expectation * search_step_data.latency_loss_weight

                batch_loss.backward()
                batch_metric = metric_function(pred_labels, labels)

                train_metrics_acc['loss'] += batch_loss.item()
                train_metrics_acc['accuracy'] += batch_metric.item()
                train_metrics_acc['n'] += 1

            search_step_data.enot_optimizer.step(closure)

        train_loss = train_metrics_acc['loss'] / train_metrics_acc['n']
        train_accuracy = train_metrics_acc['accuracy'] / train_metrics_acc['n']

        search_space.eval()
        # Select the best architecture for validation.
        search_space.sample_best_arch()

        validation_loss = 0
        validation_accuracy = 0
        # No backward pass is performed during validation, so gradient tracking
        # is disabled to avoid building autograd graphs for every batch.
        with torch.no_grad():
            # Apply the resolution strategy anew on every epoch and iterate over its batches.
            for inputs, labels in search_step_data.resolution_strategy(validation_loader):
                pred_labels = search_space(inputs)
                batch_loss = loss_function(pred_labels, labels)
                batch_metric = metric_function(pred_labels, labels)

                validation_loss += batch_loss.item()
                validation_accuracy += batch_metric.item()

        # Average over the number of batches in the validation dataloader.
        n = len(validation_loader)
        validation_loss /= n
        validation_accuracy /= n

        # Get the latency of the best architecture.
        latency = best_arch_latency(search_space)

    # Metrics of the last epoch of this resolution search step.
    print('train accuracy:', train_accuracy)
    print('train loss:', train_loss)
    print('validation accuracy:', validation_accuracy)
    print('validation loss:', validation_loss)
    print('latency:', latency)
    print()

    # Send the final values of validation accuracy and latency of the best
    # architecture back to the iterator before it produces the next step.
    search_step_data.metric_latency_cb(
        metric=validation_accuracy,
        latency=latency,
    )

In [None]:
# Load the best architecture parameters recorded by the resolution search
# iterator into the search space.
# NOTE(review): presumably "best" refers to the search step with the best
# reported metric - confirm against the ENOT documentation.
search_resolution_iter.load_best_arch_parameters(search_space)
# Resolution chosen by the search; the tuning cell below builds a
# ConstantResolutionStrategy from it.
best_resolution = search_resolution_iter.best_resolution

print(f'Best resolution is {best_resolution}')

## Tune model with the best architecture
Now we take our best architecture from search space, and create a regular model using it. Then we run finetune procedure (usual training loop).

In [None]:
# Extract a regular model with the best architecture from the search space,
# then move it to the GPU.
best_model = search_space.get_network_with_best_arch()
best_model = best_model.cuda()

In [None]:
N_EPOCHS = 1

optimizer = SGD(best_model.parameters(), lr=2e-4)
# No learning rate schedule is used for tuning; assign a scheduler here to enable one.
scheduler = None
metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

train_loader = dataloaders['tune_train_dataloader']
validation_loader = dataloaders['tune_validation_dataloader']

# You can use ConstantResolutionStrategy for reproducing data processing pipeline from resolution search.
to_best_resolution = ConstantResolutionStrategy(
    resolution=best_resolution,
)

for epoch in range(N_EPOCHS):

    print(f'EPOCH #{epoch}')

    best_model.train()
    train_loss = 0.0
    train_accuracy = 0.0
    # Apply the resolution strategy anew on every epoch and iterate over its batches.
    for inputs, labels in to_best_resolution(train_loader):
        optimizer.zero_grad()

        pred_labels = best_model(inputs)
        batch_loss = loss_function(pred_labels, labels)
        batch_loss.backward()
        batch_metric = metric_function(pred_labels, labels)

        train_loss += batch_loss.item()
        train_accuracy += batch_metric.item()

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

    # Average over the number of batches in the train dataloader.
    n = len(train_loader)
    train_loss /= n
    train_accuracy /= n

    print('train metrics:')
    print('  loss:', train_loss)
    print('  accuracy:', train_accuracy)

    best_model.eval()
    validation_loss = 0
    validation_accuracy = 0
    # No backward pass is performed during validation, so gradient tracking
    # is disabled to avoid building autograd graphs for every batch.
    with torch.no_grad():
        for inputs, labels in to_best_resolution(validation_loader):
            pred_labels = best_model(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    # Average over the number of batches in the validation dataloader.
    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    print('validation metrics:')
    print('  loss:', validation_loss)
    print('  accuracy:', validation_accuracy)

    print()