## Using an ArcFace head to search for the best architecture on a metric learning task

### The notebook consists of the following main stages:
1. Set up the environment
1. Define the ArcfaceHead and MetricLearningModel classes
1. Prepare dataloaders
1. Create the model and move it to the search space
1. Pretrain the search_space
1. Check the ArcFace benefit
1. Search for the best architecture
1. Tune the model with the best architecture

## 1. Set up the environment

In [None]:
# Copy your license file to $HOME/.enot/enot.lic or set the full path to the license file
# via the ENOT_LIC_FILE environment variable.
#
# Important note: do not put quotes around the path.
# %env ENOT_LIC_FILE=/FULL/PATH/TO/your_company.lic

In [None]:
import os

# Number GPUs in PCI bus order and make only the selected one visible to this
# process. Change the index below to a GPU that is free on your machine.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})

In [None]:
from pathlib import Path

import torch
import torch.nn as nn

from torch.nn import functional as F

from torch.optim import SGD
from torch_optimizer import RAdam

from torchvision.models.mobilenet import ConvBNReLU

from enot.models import build_simple_block_model
from enot.models import SearchSpaceModel
from enot.phases import pretrain
from enot.phases import search
from enot.phases import train
from enot.utils.data.csv_annotation_dataset import CsvAnnotationDataset
from enot.utils.data.dataloaders import create_data_loader
from enot.utils.data.dataloaders import create_data_loader_from_csv_annotation

from tutorial_utils.checkpoints import download_metric_learning_arcface_pretrain_checkpoint
from tutorial_utils.checkpoints import download_metric_learning_regular_pretrain_checkpoint
from tutorial_utils.dataset import create_imagenette_annotation
from tutorial_utils.dataset import download_imagenette
from tutorial_utils.dataset import create_imagenette_train_transform
from tutorial_utils.dataset import create_imagenette_validation_transform

from enot_utils.metric_utils import accuracy

## 2. Define the ArcfaceHead and MetricLearningModel classes

In [None]:
class ArcfaceHead(nn.Linear):
    """ArcFace (additive angular margin) classification head.

    The incoming feature map is turned into an embedding by ``vectorizer``
    (1x1 conv block -> global average pool -> flatten) and L2-normalized.
    The head returns the cosine similarity between that embedding and every
    L2-normalized class weight vector, multiplied by ``radius``. When labels
    are given, the angular margin is added to the angle of the target class
    only, so its logit becomes ``radius * cos(theta + angle_margin)``.

    The class weight matrix is the ``weight`` inherited from ``nn.Linear``
    (``nn.Linear(embedding_channels, num_classes, bias=False)``).

    Args:
        in_channels: number of channels of the incoming feature map.
        radius: scale applied to all output cosines.
        angle_margin: additive angular margin in radians, applied to the
            target class only when labels are passed to ``forward``.
        embedding_channels: size of the embedding vector.
        num_classes: number of classes (rows of the weight matrix).
    """

    # acos() is NaN for |x| > 1 and has an infinite gradient at |x| == 1.
    # Cosines of normalized vectors can hit or slightly exceed 1 through
    # rounding, so the input to acos is kept strictly inside (-1, 1).
    _COS_EPS = 1e-7

    def __init__(
            self,
            in_channels: int,
            *,
            radius=64.0,
            angle_margin=0.5,
            embedding_channels=512,
            num_classes=10,
    ):
        super().__init__(embedding_channels, num_classes, bias=False)

        self.radius = radius
        self.angle_margin = angle_margin
        self.num_classes = num_classes

        self.vectorizer = nn.Sequential(
            # Positional 1, 1, 1 are presumably kernel_size, stride, groups.
            # NOTE(review): whether norm_layer=None means "BatchNorm2d" or
            # "no normalization" depends on the torchvision version — confirm.
            ConvBNReLU(
                in_channels, 
                embedding_channels,
                1, 1, 1, norm_layer=None,
            ),
            nn.AdaptiveAvgPool2d([1, 1]),
            nn.Flatten(),
        )

    def forward(self, inputs, labels=None):
        """Return the scaled cosine logits.

        Args:
            inputs: feature map fed to ``vectorizer`` (presumably
                ``(batch, in_channels, H, W)`` — TODO confirm).
            labels: optional integer class indices, one per sample. When
                given, the angular margin is applied to each sample's target
                class.

        Returns:
            Tensor ``radius * cos`` with one column per class.
        """
        features = F.normalize(self.vectorizer(inputs), 2, -1)

        weights = F.normalize(self.weight, 2, -1)
        cosine = F.linear(features, weights)

        if labels is not None:
            # Clamp only the acos input; non-target logits keep the exact
            # cosine values.
            safe_cosine = cosine.clamp(-1.0 + self._COS_EPS, 1.0 - self._COS_EPS)
            angle = torch.acos(safe_cosine)
            # NOTE(review): for angle + margin > pi, cos() stops decreasing in
            # the angle; reference ArcFace implementations add a fallback for
            # that regime. With small margins it is rarely reached.
            modified_cos = torch.cos(angle + self.angle_margin)

            # zeros_like keeps cosine's dtype and device.
            one_hot = torch.zeros_like(cosine)
            one_hot.scatter_(1, labels.view(-1, 1).long(), 1)
            cosine = cosine * (1 - one_hot) + modified_cos * one_hot

        return cosine * self.radius

In [None]:
class MetricLearningModel(nn.Module):
    """Metric-learning network: stem -> searchable ENOT body -> ArcFace head.

    Args:
        search_ops: ENOT search operation names forwarded to
            ``build_simple_block_model`` (e.g. ``'MIB_k=3_t=3'``).
        embedding_channels: embedding size of the ArcFace head.
        num_classes: number of output classes.
        radius: ArcFace scale.
        angle_margin: ArcFace angular margin, used by the head only when
            labels reach it (see ``forward``).
    """

    def __init__(
        self,
        search_ops,
        embedding_channels=512,
        num_classes=10,
        radius=64.0,
        angle_margin=0.0,
    ):
        super().__init__()

        stem_channels = 32
        body_channels = 320

        # Stride-2 stem for 3-channel images.
        self.stem = ConvBNReLU(3, stem_channels, stride=2)
        self.body = build_simple_block_model(
            in_channels=stem_channels,
            search_ops=search_ops,
            blocks_out_channels=[16, 24, 32, 64, 96, 160, body_channels],
            blocks_count=[1, 2, 2, 2, 1, 2, 1],
            blocks_stride=[1, 2, 2, 2, 1, 2, 1],
            width_multiplier=1.0,
            min_channels=8,
            init_weights_kn=True,
        )
        self.head = ArcfaceHead(
            body_channels,
            radius=radius,
            angle_margin=angle_margin,
            embedding_channels=embedding_channels,
            num_classes=num_classes,
        )

    def update_angle_margin(self, value):
        """Set the angular margin used by the ArcFace head."""
        self.head.angle_margin = value

    def forward(self, inputs):
        """Run the network.

        ``inputs`` is either an image tensor (no labels, so no margin) or an
        ``(images, labels)`` pair whose labels are handed to the head.
        """
        if isinstance(inputs, torch.Tensor):
            images, labels = inputs, None
        else:
            images, labels = inputs

        features = self.body(self.stem(images))
        return self.head(features, labels)

## 3. Prepare dataloaders
For ArcFace training we need to pass labels to the head of the model, so the ArcFace train `Dataset` must return samples in the format `((image, label), label)`: the inner tuple is the model input, and the outer `label` is the loss target.

In [None]:
# Working directories: ~/.enot holds datasets and this project's outputs.
ENOT_HOME_DIR = Path.home() / '.enot'
ENOT_DATASETS_DIR, PROJECT_DIR = (
    ENOT_HOME_DIR / 'datasets',
    ENOT_HOME_DIR / 'metric_learning',
)

# The parent is created first because mkdir() is called without parents=True.
for _directory in (ENOT_HOME_DIR, ENOT_DATASETS_DIR, PROJECT_DIR):
    _directory.mkdir(exist_ok=True)

In [None]:
class ArcfaceTrainCsvAnnotationDataset(CsvAnnotationDataset):
    """CSV annotation dataset whose samples also carry the label as model input.

    Each item is ``((image, label), label)``. The inner pair is what the model
    receives (``MetricLearningModel.forward`` unpacks it and forwards the label
    to the ArcFace head), and the outer ``label`` is the loss target.
    """

    def __init__(self, csv_annotation_path, root_dir=None, transform=None):
        super().__init__(
            csv_annotation_path,
            root_dir=root_dir,
            transform=transform,
        )

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        image, target = sample
        model_input = (image, target)
        return model_input, target


num_workers = 4
input_size = (224, 224)
batch_size = 32
dataset_dir = download_imagenette(
    dataset_root_dir=ENOT_DATASETS_DIR, imagenette_kind='imagenette2-320',
)
annotations = create_imagenette_annotation(
    dataset_dir=dataset_dir, project_dir=PROJECT_DIR, random_seed=42,
)
train_transform = create_imagenette_train_transform(input_size)
validation_transform = create_imagenette_validation_transform(input_size)

# Regular train loader: yields (image, label), so the head gets no labels.
pretrain_and_tune_train_dataloader = create_data_loader_from_csv_annotation(
    csv_annotation_path=annotations['train'],
    dataset_transform=train_transform,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)
# ArcFace train loader: yields ((image, label), label), so the head receives
# the labels and applies its angular margin.
pretrain_and_tune_arcface_train_dataloader = create_data_loader(
    dataset=ArcfaceTrainCsvAnnotationDataset(annotations['train'], transform=train_transform),
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    collate_fn=None,
)
# Validation loader (no labels fed to the head). It uses the same worker
# count as every other loader in this notebook.
pretrain_and_tune_validation_dataloader = create_data_loader_from_csv_annotation(
    csv_annotation_path=annotations['validation'],
    dataset_transform=validation_transform,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

# NOTE(review): both search loaders read annotations['search'] and differ only
# in transform and shuffling — confirm this split is intended.
search_train_dataloader = create_data_loader_from_csv_annotation(
    csv_annotation_path=annotations['search'],
    dataset_transform=train_transform,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)
search_validation_dataloader = create_data_loader_from_csv_annotation(
    csv_annotation_path=annotations['search'],
    dataset_transform=validation_transform,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

## 4. Create the model and move it to search space

In [None]:
# Candidate operations searched in every block of the body.
SEARCH_OPS = ['MIB_k=3_t=3', 'MIB_k=3_t=6']

# Build the model. angle_margin only takes effect when labels reach the head.
model = MetricLearningModel(search_ops=SEARCH_OPS, angle_margin=0.1)

# Wrap the model into an ENOT search space and move it to the GPU.
search_space = SearchSpaceModel(model, train_loader=pretrain_and_tune_train_dataloader)
search_space = search_space.cuda()

## 5. Pretrain search_space

Pretrain procedure stages:
1. Pretrain the regular search_space (using the regular train dataloader `pretrain_and_tune_train_dataloader`).
1. Calculate the average cosine distance to the class centroid for every class.
1. Pretrain the ArcFace search_space with angle_margin=0.1 (using the ArcFace train dataloader `pretrain_and_tune_arcface_train_dataloader`, which also returns labels).
1. Calculate the average cosine distance to the class centroid for every class.

**IMPORTANT:**<br>
Training the ArcFace search_space from a randomly initialized search_space is inefficient. Instead, ArcFace pretraining starts from a pretrained regular model and only fine-tunes it. For a fair comparison, both models get the same total number of epochs, following this training procedure:<br>
```
1-2-3-4-...-N_EPOCHS-...-N_EPOCHS+N_TUNE_EPOCHS <- Regular model
               |
         load checkpoint
               |
               1-...N_TUNE_EPOCHS <- Arcface model
```

In [None]:
# Helper: per-class average cosine distance from samples to their class centroid.
def calc_avg_cosine_dist_to_centroids(search_space, dataloader, num_classes=10, forward_indices=None):
    """Return, for every class, the mean cosine distance of its samples to the class centroid.

    The "centroid" of class ``c`` is the normalized ``c``-th class weight
    vector of the ArcFace head. Without labels, the head outputs
    ``radius * cos(embedding, w_c)``, so ``1 - logit[c] / radius`` is the
    cosine distance of a sample to it.

    Side effects: ``search_space`` is switched to eval mode and left sampled at
    ``forward_indices``.

    Args:
        search_space: ENOT search space wrapping a ``MetricLearningModel``.
        dataloader: iterable of ``(inputs, labels)`` batches. ``inputs`` should
            be a plain tensor so that no margin is applied.
            NOTE(review): inputs are passed as-is; this assumes the loader
            already yields tensors on the model's device.
        num_classes: number of classes; labels must lie in ``[0, num_classes)``.
        forward_indices: architecture to evaluate. It defaults to operation 0 in
            every search variants container. It can be any architecture, but
            it must be the same when comparing the regular and ArcFace models.

    Returns:
        List of ``num_classes`` floats; classes without samples get ``0.0``.
    """
    dist_sums = [0.0] * num_classes
    counts = [0] * num_classes

    if forward_indices is None:
        n_containers = len(search_space.search_variants_containers)
        forward_indices = [[0] for _ in range(n_containers)]

    search_space.eval()
    search_space.sample(forward_indices=forward_indices)

    radius = search_space.original_model.head.radius
    # Pure evaluation: skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = search_space(inputs).detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()

            for row, label in zip(logits, labels):
                dist_sums[label] += 1.0 - float(row[label]) / radius
                counts[label] += 1

    return [total / count if count != 0 else 0.0 for total, count in zip(dist_sums, counts)]

**IMPORTANT:**<br>
For a good pretrain, (N_EPOCHS + N_TUNE_EPOCHS) should be about (100 + 30). This tutorial uses a much shorter schedule, (N_EPOCHS + N_TUNE_EPOCHS) = (10 + 1), and downloads checkpoints of pretrained models from Google Drive instead.

In [None]:
# define directories for text logs and tensorboard logs
pretrain_regular_dir = PROJECT_DIR / 'pretrain_regular'
pretrain_regular_dir.mkdir(exist_ok=True)
pretrain_arcface_dir = PROJECT_DIR / 'pretrain_arcface'
pretrain_arcface_dir.mkdir(exist_ok=True)

N_EPOCHS = 10
N_TUNE_EPOCHS = 1
USE_CHECKPOINTS_FROM_GOOGLE_DRIVE = True


def _make_pretrain_optimizer(space):
    """Create the SGD optimizer used by both pretrain stages.

    Both stages share these hyperparameters so their results stay comparable.
    """
    return SGD(
        params=space.model_parameters(),
        lr=0.01,
        momentum=0.9,
        weight_decay=1e-4,
    )


loss_function = nn.CrossEntropyLoss().cuda()

# pretrain regular search_space for N_EPOCHS + N_TUNE_EPOCHS epochs
optimizer = _make_pretrain_optimizer(search_space)
pretrain(
    search_space=search_space,
    exp_dir=pretrain_regular_dir,
    train_loader=pretrain_and_tune_train_dataloader,
    valid_loader=pretrain_and_tune_validation_dataloader,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss_function,
    epochs=N_EPOCHS + N_TUNE_EPOCHS,
)
if USE_CHECKPOINTS_FROM_GOOGLE_DRIVE:
    # download regular pretrain checkpoint from google drive
    checkpoint_path = pretrain_regular_dir / 'regular_checkpoint.pth'
    download_metric_learning_regular_pretrain_checkpoint(checkpoint_path)
    search_space.load_state_dict(
        torch.load(checkpoint_path, map_location='cuda')['model'],
    )
# average cosine distance to the class centroids for the regular model
avg_cosine_dist_to_centroids_reg = calc_avg_cosine_dist_to_centroids(
    search_space=search_space, 
    dataloader=pretrain_and_tune_validation_dataloader, 
)

# ArcFace tuning starts from the regular checkpoint after N_EPOCHS epochs
# (presumably written by `pretrain` under this name — see the diagram above).
checkpoint_path = pretrain_regular_dir / f'checkpoint-{N_EPOCHS}.pth'
search_space.load_state_dict(
    torch.load(checkpoint_path, map_location='cuda')['model'],
)
# Use a fresh optimizer. The previous one holds SGD momentum accumulated up to
# epoch N_EPOCHS + N_TUNE_EPOCHS, which does not belong to the weights that
# were just rolled back to epoch N_EPOCHS.
optimizer = _make_pretrain_optimizer(search_space)
# pretrain arcface search_space (the loader feeds labels to the head)
pretrain(
    search_space=search_space,
    exp_dir=pretrain_arcface_dir,
    train_loader=pretrain_and_tune_arcface_train_dataloader,
    valid_loader=pretrain_and_tune_validation_dataloader,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss_function,
    epochs=N_TUNE_EPOCHS,
)
if USE_CHECKPOINTS_FROM_GOOGLE_DRIVE:
    # download arcface pretrain checkpoint from google drive
    checkpoint_path = pretrain_arcface_dir / 'arcface_checkpoint.pth'
    download_metric_learning_arcface_pretrain_checkpoint(checkpoint_path)
    search_space.load_state_dict(
        torch.load(checkpoint_path, map_location='cuda')['model'],
    )
# average cosine distance to the class centroids for the arcface model
avg_cosine_dist_to_centroids_af = calc_avg_cosine_dist_to_centroids(
    search_space=search_space, 
    dataloader=pretrain_and_tune_validation_dataloader, 
)

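## 6. Check the ArcFace benefit
The average cosine distance to the class centroids (the class weight vectors of the head) should be larger for the regular model than for the ArcFace model, i.e. ArcFace should make the classes more compact.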

In [None]:
# One row per class: regular vs. ArcFace average distance to the class centroid.
print('#  REGULAR  ARCFACE')
distance_pairs = zip(avg_cosine_dist_to_centroids_reg, avg_cosine_dist_to_centroids_af)
for class_idx, (regular_dist, arcface_dist) in enumerate(distance_pairs):
    print(f'{class_idx}  {regular_dist:.4f}  {arcface_dist:.4f}')

## 7. Search best architecture
At inference time the final network is used without any angle_margin (no labels are passed to the head), so we search it as a regular model, i.e. with dataloaders that do not return labels as model input.

In [None]:
# directory for text logs and tensorboard logs of the search phase
search_dir = PROJECT_DIR / 'search'
search_dir.mkdir(exist_ok=True)

# This optimizer only receives the architecture parameters of the search space.
optimizer = RAdam(search_space.architecture_parameters(), lr=0.01)

# Regular loaders: labels are not fed to the model, so no margin is applied.
search(
    search_space=search_space,
    exp_dir=search_dir,
    search_loader=search_train_dataloader,
    valid_loader=search_validation_dataloader,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss_function,
    epochs=5,
)

## 8. Tune model with best architecture
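Tune the final model with angle_margin=0.1. The ArcFace train dataloader passes labels to the model, so the margin is applied during training; validation uses the regular dataloader without labels.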

In [None]:
# Build a network with the best architecture found by the search and move it to the GPU.
best_model = search_space.get_network_with_best_arch()
best_model = best_model.cuda()

In [None]:
# directory for text logs and tensorboard logs of the tuning phase
tune_dir = PROJECT_DIR / 'tune'
tune_dir.mkdir(exist_ok=True)

# Every parameter of the best model is tuned.
optimizer = RAdam(best_model.parameters(), lr=5e-3, weight_decay=4e-5)

# The train loader yields ((image, label), label), so labels reach the model;
# the validation loader yields plain images.
train(
    model=best_model,
    exp_dir=tune_dir,
    train_loader=pretrain_and_tune_arcface_train_dataloader,
    valid_loader=pretrain_and_tune_validation_dataloader,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss_function,
    epochs=5,
)