# Adding custom operations to ENOT

In this notebook we describe how you can implement your own operations to use with neural architecture search.

### This notebook consists of the following main stages:
1. Setup the environment
1. Add a custom operation to use with search space
1. Build model with custom operation
1. Check the pretrain, search, and tune phases

## 1. Setup the environment
First, let's set up the environment and common imports.

In [None]:
import os

# Order CUDA devices by PCI bus ID and expose a single device to this process.
# Change the value of CUDA_VISIBLE_DEVICES to the index of a free GPU.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})

In [None]:
from pathlib import Path

import torch
import torch.nn as nn

from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_optimizer import RAdam

from enot.models import BaseSearchableOperation
from enot.models import register_searchable_op
from enot.models import SearchSpaceModel
from enot.models.mobilenet import build_mobilenet
from enot.phases import pretrain
from enot.phases import search
from enot.phases import train

from enot_utils.metric_utils import accuracy
from enot_utils.schedulers import WarmupScheduler

from tutorial_utils.dataset import create_imagenette_dataloaders

## 2. Add a custom operation to use with search space

We have provided a bunch of pre-defined operations, but you can also implement your own.

Follow these steps to create a new searchable operation:
1. Create a new operation class that inherits from `BaseSearchableOperation`
    * The last layer of your operation should be Batch Normalization
1. Initialize the base class in your `__init__`. `BaseSearchableOperation` takes 3 required arguments:
    * in_channels - number of channels in the input data
    * out_channels - number of channels produced by the operation
    * use_skip_connection - set to `True` to apply a skip connection, and to `False` otherwise
1. Implement the abstract method `get_last_batch_norm`, which returns the last Batch Normalization layer
1. Implement the abstract method `replace_last_batch_norm`, which replaces the last Batch Normalization layer with a new one
1. Implement the abstract method `operation_forward`, which defines the computation performed at every call, like the usual `forward`, but without the skip connection logic (that logic is implemented in the `forward` method of the base class)

If you want to use a custom operation with our model builders (e.g. `build_mobilenet`):
1. The operation must be registered in our framework using `@register_searchable_op(name)`
1. The operation must accept 4 required arguments: in_channels, out_channels, stride, use_skip_connection. All other arguments must have a default value or must be added to the short config (see the next item).
1. To use the short config format you must provide parameter parsing rules. E.g., if you want to write `MyOp_k=3_t=6` instead of `{"op_type": "MyOp", "kernel_size": 3, "expand_ratio": 6.0}`, you need the following rules:
```
    {
      'k': ('kernel_size', int),
      't': ('expand_ratio', float)
    }
```

In [None]:
# Activation layer classes selectable from the short config, keyed by short name.
activations = dict(relu=nn.ReLU, sigmoid=nn.Sigmoid)


def _parse_activation(name):
    """Return the activation class stored in `activations` under `name`."""
    return activations[name]


# Short parameter parsing rules.
# Format: {short_param_name: (original_param_name, parser)}
short_args = dict(
    k=('kernel_size', int),
    a=('activation', _parse_activation),
)


@register_searchable_op('MyOp', short_args)
class MyOperation(BaseSearchableOperation):
    """Searchable operation: convolution -> activation -> batch normalization.

    Registered under the name 'MyOp' together with the `short_args` parsing rules.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_size=3,
        activation=nn.ReLU,
        padding=None,
        use_skip_connection=True
    ):
        """Build the layers of the operation.

        Args:
            in_channels: number of channels in the input data.
            out_channels: number of channels produced by the operation.
            stride: stride of the convolution.
            kernel_size: kernel size of the convolution.
            activation: zero-argument callable that creates the activation layer.
            padding: padding of the convolution; when None,
                `(kernel_size - 1) // 2` is used.
            use_skip_connection: forwarded to `BaseSearchableOperation`.
        """
        super().__init__(in_channels, out_channels, use_skip_connection)

        conv_padding = (kernel_size - 1) // 2 if padding is None else padding

        convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=conv_padding,
        )
        # Batch normalization is kept as the last layer: the two accessors
        # below address it as `self.operation[-1]`.
        self.operation = nn.Sequential(
            convolution,
            activation(),
            nn.BatchNorm2d(out_channels),
        )

    def get_last_batch_norm(self) -> nn.BatchNorm2d:
        """Return the last layer of `self.operation` (the batch normalization)."""
        last_layer = self.operation[-1]
        return last_layer

    def replace_last_batch_norm(self, new_last_batch_norm: nn.BatchNorm2d) -> None:
        """Put `new_last_batch_norm` in place of the last layer of `self.operation`."""
        self.operation[-1] = new_last_batch_norm

    def operation_forward(self, x):
        """Apply the convolution, activation and batch normalization to `x`."""
        output = self.operation(x)
        return output

## 3. Build model with custom operation

In [None]:
# Candidate operations: three MIB variants plus two variants of the custom operation.
SEARCH_OPS = [f'MIB_k={kernel_size}_t=6' for kernel_size in (3, 5, 7)]
SEARCH_OPS += [
    'MyOp_k=3',  # Notice that you can omit parameters with default values
    'MyOp_k=3_a=sigmoid',
]

# Per-block configuration passed to the model builder.
BLOCKS_OUT_CHANNELS = [24, 32, 64, 96, 160, 320]
BLOCKS_COUNT = [2, 2, 2, 1, 2, 1]
BLOCKS_STRIDE = [2, 2, 2, 1, 2, 1]

# build model
model = build_mobilenet(
    search_ops=SEARCH_OPS,
    num_classes=10,
    blocks_out_channels=BLOCKS_OUT_CHANNELS,
    blocks_count=BLOCKS_COUNT,
    blocks_stride=BLOCKS_STRIDE,
)

# move model to search space and call .cuda() on the result
search_space = SearchSpaceModel(model).cuda()

## 4. Check the pretrain, search, and tune phases

Let's check that everything works.

In [None]:
# Every artifact of this tutorial is stored under the `.enot` directory in the user's home.
ENOT_HOME_DIR = Path.home() / '.enot'
ENOT_DATASETS_DIR = ENOT_HOME_DIR / 'datasets'
PROJECT_DIR = ENOT_HOME_DIR / 'custom_operation'

# The parent comes first so that it exists before its children are created;
# directories that already exist are left as they are.
for directory in (ENOT_HOME_DIR, ENOT_DATASETS_DIR, PROJECT_DIR):
    directory.mkdir(exist_ok=True)

In [None]:
# Settings of the data pipeline; `dataloaders` is indexed by the phase cells below.
INPUT_SIZE = (224, 224)
BATCH_SIZE = 32

dataloaders = create_imagenette_dataloaders(
    dataset_root_dir=ENOT_DATASETS_DIR,
    project_dir=PROJECT_DIR,
    input_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    imagenette_kind='imagenette2-320',
)

In [None]:
# Directory for the text logs and tensorboard logs of the pretrain phase
pretrain_dir = PROJECT_DIR / 'pretrain'
pretrain_dir.mkdir(exist_ok=True)

N_EPOCHS = 3
N_WARMUP_EPOCHS = 1

pretrain_train_loader = dataloaders['pretrain_train_dataloader']
pretrain_valid_loader = dataloaders['pretrain_validation_dataloader']
# The scheduler lengths below are expressed as multiples of the train dataloader length.
len_train = len(pretrain_train_loader)

optimizer = SGD(
    params=search_space.model_parameters(),
    lr=0.06,
    momentum=0.9,
    weight_decay=1e-4,
)
# Cosine annealing wrapped into a warmup scheduler; a separate name avoids rebinding `scheduler`.
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=len_train * N_EPOCHS, eta_min=1e-8)
scheduler = WarmupScheduler(cosine_scheduler, warmup_steps=len_train * N_WARMUP_EPOCHS)
loss_function = nn.CrossEntropyLoss().cuda()

pretrain(
    search_space=search_space,
    exp_dir=pretrain_dir,
    train_loader=pretrain_train_loader,
    valid_loader=pretrain_valid_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    metric_function=accuracy,
    loss_function=loss_function,
    epochs=N_EPOCHS,
)

In [None]:
# Directory for the text logs and tensorboard logs of the search phase
search_dir = PROJECT_DIR / 'search'
search_dir.mkdir(exist_ok=True)

search_train_loader = dataloaders['search_train_dataloader']
search_valid_loader = dataloaders['search_validation_dataloader']

# This optimizer is given the architecture parameters of the search space,
# whereas the pretrain optimizer was given its model parameters.
optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)

search(
    search_space=search_space,
    exp_dir=search_dir,
    search_loader=search_train_loader,
    valid_loader=search_valid_loader,
    optimizer=optimizer,
    loss_function=loss_function,
    metric_function=accuracy,
    epochs=3,
)

In [None]:
# get regular model with best architecture, then call .cuda() on it
best_arch_model = search_space.get_network_with_best_arch()
best_model = best_arch_model.cuda()

In [None]:
# Directory for the text logs and tensorboard logs of the tune phase
tune_dir = PROJECT_DIR / 'tune'
tune_dir.mkdir(exist_ok=True)

tune_train_loader = dataloaders['tune_train_dataloader']
tune_valid_loader = dataloaders['tune_validation_dataloader']

# The optimizer is given all parameters of the extracted model.
optimizer = RAdam(best_model.parameters(), lr=5e-3, weight_decay=4e-5)

train(
    model=best_model,
    exp_dir=tune_dir,
    train_loader=tune_train_loader,
    valid_loader=tune_valid_loader,
    optimizer=optimizer,
    loss_function=loss_function,
    metric_function=accuracy,
    epochs=3,
)