## Custom stats collector (phase callbacks)
This notebook describes how you can implement your own stats collector.

### This notebook consists of the following main stages:
1. Set up the environment
1. Prepare dataset and create dataloaders
1. Create the model and move it to search space
1. Define class for custom stats collector
1. Run Pretrain, Search, Tune phases with custom stats collector 

## 1. Set up the environment
First, let's set up the environment and common imports.

In [None]:
import os

# Number GPUs by PCI bus ID so the index below is stable across runs.
# Replace '0' with the index of a free GPU on your machine.
os.environ.update(
    CUDA_DEVICE_ORDER='PCI_BUS_ID',
    CUDA_VISIBLE_DEVICES='0',
)

In [None]:
from pathlib import Path

import torch
import torch.nn as nn

from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_optimizer import RAdam

from enot.models import SearchSpaceModel
from enot.models.mobilenet import build_mobilenet
from enot.phases import pretrain
from enot.phases import search
from enot.phases import train
from enot.utils.data.csv_annotation_dataset import CsvAnnotationDataset
from enot.utils.data.dataloaders import create_data_loader
from enot.utils.data.dataloaders import create_data_loader_from_csv_annotation
from enot.utils.stats_collector import StatsCollector

from enot_utils.metric_utils import accuracy
from enot_utils.schedulers import WarmupScheduler

from tutorial_utils.checkpoints import download_getting_started_pretrain_checkpoint
from tutorial_utils.dataset import create_imagenette_annotation
from tutorial_utils.dataset import download_imagenette
from tutorial_utils.dataset import create_imagenette_train_transform
from tutorial_utils.dataset import create_imagenette_validation_transform

### In the next cell we set up all required directories

* `ENOT_HOME_DIR` - the root directory for all other directories
* `ENOT_DATASETS_DIR` - the root directory for datasets (imagenette2)
* `PROJECT_DIR` - the root directory for the output data (checkpoints, logs, ...) of this tutorial

In [None]:
ENOT_HOME_DIR = Path.home() / '.enot'
ENOT_DATASETS_DIR = ENOT_HOME_DIR / 'datasets'
PROJECT_DIR = ENOT_HOME_DIR / 'stats_collector'

# Order matters: the home dir is created first so its children can be created
# without parents=True.
for _required_dir in (ENOT_HOME_DIR, ENOT_DATASETS_DIR, PROJECT_DIR):
    _required_dir.mkdir(exist_ok=True)
del _required_dir

## 2. Prepare dataset and create dataloaders
For this tutorial we add a `limit` parameter to the default `Dataset`. `limit` defines the maximum number of images the `Dataset` exposes. Small datasets let the phases (pretrain, search, tune) finish quickly and produce a short, clear log. The pretrain, search and tune dataloaders each contain 2 batches with 2 samples per batch.

In [None]:
class CsvAnnotationDatasetWithLimit(CsvAnnotationDataset):
    """``CsvAnnotationDataset`` whose reported length is capped at ``limit`` samples.

    Only ``__len__`` is overridden. ``__getitem__`` is inherited unchanged, so the cap
    only affects consumers that size themselves via ``len()``.

    Args:
        csv_annotation_path: Path to the CSV annotation, forwarded to the base class.
        root_dir: Forwarded to the base class.
        transform: Forwarded to the base class.
        limit: Maximum number of samples to expose, or ``None`` for no cap.
            Must be non-negative.

    Raises:
        ValueError: If ``limit`` is negative.
    """

    def __init__(self, csv_annotation_path, root_dir=None, transform=None, limit=None):
        # A negative cap would make __len__ return a negative value. Python's len()
        # rejects that only later, deep inside the dataloader, so fail fast here.
        if limit is not None and limit < 0:
            raise ValueError(f'limit must be non-negative or None, got {limit}')
        self._limit = limit
        super().__init__(csv_annotation_path, root_dir=root_dir, transform=transform)

    def __len__(self):
        total = super().__len__()
        return total if self._limit is None else min(self._limit, total)


input_size = (224, 224)
batch_size = 2
# Every loader exposes at most two batches, so each phase produces a short log.
samples_limit = 2 * batch_size

dataset_dir = download_imagenette(
    dataset_root_dir=ENOT_DATASETS_DIR, imagenette_kind='imagenette2-320',
)
annotations = create_imagenette_annotation(
    dataset_dir=dataset_dir, project_dir=PROJECT_DIR, random_seed=42,
)
train_transform = create_imagenette_train_transform(input_size)
validation_transform = create_imagenette_validation_transform(input_size)


def create_limited_dataloader(csv_annotation_path, transform, shuffle):
    """Build a dataloader over at most ``samples_limit`` rows of a CSV annotation.

    Uses the notebook-level ``batch_size`` and ``samples_limit``.

    Args:
        csv_annotation_path: Annotation file (one of the values of ``annotations``).
        transform: Image transform applied by the dataset.
        shuffle: Passed through to ``create_data_loader``.

    Returns:
        The result of ``create_data_loader`` for the limited dataset.
    """
    limited_dataset = CsvAnnotationDatasetWithLimit(
        csv_annotation_path,
        transform=transform,
        limit=samples_limit,
    )
    return create_data_loader(
        dataset=limited_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
    )


# NOTE(review): the pretrain/tune *validation* loader reads annotations['train'], so it
# validates on training images. This only affects the reported metrics, not which
# callbacks fire; use a held-out split if you reuse this code for real training.
pretrain_and_tune_train_dataloader = create_limited_dataloader(
    annotations['train'], train_transform, shuffle=True,
)
pretrain_and_tune_validation_dataloader = create_limited_dataloader(
    annotations['train'], validation_transform, shuffle=False,
)

search_train_dataloader = create_limited_dataloader(
    annotations['search'], train_transform, shuffle=True,
)
search_validation_dataloader = create_limited_dataloader(
    annotations['search'], validation_transform, shuffle=False,
)

## 3. Create the model and move it to search space

In [None]:
# Candidate operations offered to the search in every searchable layer.
# Short op format: 'Name_param1=value1_param2=value2...'.
# MIB is the MobileNetV2 inverted bottleneck; k is the kernel size, t is the expansion ratio.
# For more in-depth info see "Tutorial - adding custom operations".
SEARCH_OPS = ['MIB_k=3_t=6', 'MIB_k=5_t=6', 'MIB_k=7_t=6']

# Backbone layout: the three lists hold one entry per block group.
mobilenet_kwargs = dict(
    search_ops=SEARCH_OPS,
    num_classes=10,
    blocks_out_channels=[24, 32, 64, 96, 160, 320],
    blocks_count=[2, 2, 2, 1, 2, 1],
    blocks_stride=[2, 2, 2, 1, 2, 1],
)
model = build_mobilenet(**mobilenet_kwargs)

# Wrap the model into a search space and place it on the GPU.
search_space = SearchSpaceModel(model)
search_space = search_space.cuda()

## 4. Define class for custom stats collector
All custom stats collectors subclass `StatsCollector`. We define a new custom stats collector that prints the name of each callback, plus its arguments, on every call. To enable the `on_train_batch_result` and `on_validation_batch_result` callbacks, we set `need_train_batch_result` and `need_validation_batch_result` to `True`.

In [None]:
class MyStatsCollector(StatsCollector):
    """Stats collector that prints every callback it receives.

    Each callback prints its own name and its arguments. This makes the callback
    order visible in the cell output when the collector is passed to a phase. The
    ``predicted``/``original`` arguments of the ``*_batch_result`` callbacks are
    not printed, which keeps the log short.
    """

    @property
    def need_train_batch_result(self):
        """
        If this property is True then on_train_batch_result will be called in phase loop,
        otherwise on_train_batch_result will be ignored.
        """
        return True

    @property
    def need_validation_batch_result(self):
        """
        If this property is True then on_validation_batch_result will be called in phase loop,
        otherwise on_validation_batch_result will be ignored.
        """
        return True

    def on_phase_start(self, phase_name):
        print(f'on_phase_start: {phase_name}')

    def on_phase_end(self, phase_name):
        print(f'on_phase_end: {phase_name}')

    def on_epoch_start(self, epoch):
        print(f'on_epoch_start: #{epoch}')

    def on_epoch_end(self, epoch, stats):
        print(f'on_epoch_end: #{epoch}\n{stats}')

    def on_train_start(self):
        print('on_train_start')

    def on_train_end(self, stats):
        print(f'on_train_end:\n{stats}')

    def on_validation_start(self):
        print('on_validation_start')

    def on_validation_end(self, stats):
        print(f'on_validation_end:\n{stats}')

    def on_train_batch_start(self, batch):
        print(f'on_train_batch_start: #{batch}')

    def on_train_batch_result(self, batch, predicted, original, process_index, sample_index):
        # predicted/original are intentionally not printed (presumably the model
        # output and the target; TODO confirm against the StatsCollector docs).
        print(f'on_train_batch_result: #{batch} (process_index = {process_index}, sample_index = {sample_index})')

    def on_train_batch_end(self, batch, stats):
        print(f'on_train_batch_end: #{batch}\n{stats}')

    def on_validation_batch_start(self, batch):
        print(f'on_validation_batch_start: #{batch}')

    def on_validation_batch_result(self, batch, predicted, original, process_index, sample_index):
        # predicted/original are intentionally not printed (see on_train_batch_result).
        print(f'on_validation_batch_result: #{batch} (process_index = {process_index}, sample_index = {sample_index})')

    def on_validation_batch_end(self, batch, stats):
        print(f'on_validation_batch_end: #{batch}\n{stats}')

## 5. Run Pretrain, Search, Tune phases with custom stats collector
We run all phases with our custom stats collector.

**IMPORTANT:**<br>
The default logging and tensorboard stats collectors are replaced by the user-defined stats collectors. If you need the functionality of the default stats collectors, you have to add them manually.

In [None]:
# Experiment directory for the pretrain phase. MyStatsCollector replaces the default
# text/tensorboard loggers here.
# NOTE(review): check what the phase still writes to this directory (presumably checkpoints).
pretrain_dir = PROJECT_DIR / 'pretrain'
pretrain_dir.mkdir(exist_ok=True)

N_EPOCHS = 1

# The search and tune cells below reuse this loss.
loss_function = nn.CrossEntropyLoss().cuda()
# Only the search space's model parameters (not its architecture parameters) are optimized.
optimizer = SGD(
    params=search_space.model_parameters(),
    lr=0.06,
    momentum=0.9,
    weight_decay=1e-4,
)

pretrain(
    search_space=search_space,
    exp_dir=pretrain_dir,
    train_loader=pretrain_and_tune_train_dataloader,
    valid_loader=pretrain_and_tune_validation_dataloader,
    loss_function=loss_function,
    metric_function=accuracy,
    optimizer=optimizer,
    epochs=N_EPOCHS,
    stats_collectors=[MyStatsCollector()],
)

In [None]:
# Experiment directory for the search phase. MyStatsCollector replaces the default
# text/tensorboard loggers here.
# NOTE(review): check what the phase still writes to this directory (presumably checkpoints).
search_dir = PROJECT_DIR / 'search'
search_dir.mkdir(exist_ok=True)

# Named instead of inline magic numbers so they are easy to find and tune.
SEARCH_EPOCHS = 5
# Presumably scales the latency penalty in the search objective; TODO confirm in ENOT docs.
LATENCY_LOSS_WEIGHT = 2.0e-3

# The optimizer only receives the search space's architecture parameters.
optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)

search(
    search_space=search_space,
    exp_dir=search_dir,
    search_loader=search_train_dataloader,
    valid_loader=search_validation_dataloader,
    optimizer=optimizer,
    loss_function=loss_function,
    metric_function=accuracy,
    latency_loss_weight=LATENCY_LOSS_WEIGHT,
    epochs=SEARCH_EPOCHS,
    stats_collectors=[MyStatsCollector()],
)

In [None]:
# Extract a regular model built with the best architecture found by the search,
# then move it to the GPU for tuning.
best_arch_network = search_space.get_network_with_best_arch()
best_model = best_arch_network.cuda()

In [None]:
# Experiment directory for the tune phase. MyStatsCollector replaces the default
# text/tensorboard loggers here.
# NOTE(review): check what the phase still writes to this directory (presumably checkpoints).
tune_dir = PROJECT_DIR / 'tune'
tune_dir.mkdir(exist_ok=True)

# Named instead of an inline magic number so it is easy to find and tune.
TUNE_EPOCHS = 5

# Tune all parameters of the extracted best-architecture model.
optimizer = RAdam(best_model.parameters(), lr=5e-3, weight_decay=4e-5)

train(
    model=best_model,
    exp_dir=tune_dir,
    train_loader=pretrain_and_tune_train_dataloader,
    valid_loader=pretrain_and_tune_validation_dataloader,
    optimizer=optimizer,
    loss_function=loss_function,
    metric_function=accuracy,
    epochs=TUNE_EPOCHS,
    stats_collectors=[MyStatsCollector()],
)