## Additional starting points for OptimalPruningLabelSelector

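This notebook demonstrates how to add additional starting points to `OptimalPruningLabelSelector`.

Note that the notebook is not self-contained. The latency calculation function below reads its input batch from `train_dataloader`, which must yield `(inputs, labels)` pairs. That dataloader is not defined here, so define it before the label selectors are used.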

### Main chapters of this notebook:
1. Set up the environment
1. Prepare latency calculation function
1. Prepare starting points and starting points generator
1. Create label selector with additional starting points

In [None]:
import os

# Number CUDA devices by PCI bus ID. This is set before `torch` is imported (next cell).
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID')
# To pin the notebook to a particular free GPU, uncomment the line below and adjust the index:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Common:
import torch
from fvcore.nn import FlopCountAnalysis

# Pruning:
from enot.pruning import GlobalPruningLabelSelectorByChannels
from enot.pruning import KnapsackPruningLabelSelector
from enot.pruning import OptimalPruningLabelSelector
from enot.pruning.label_selector.starting_points_generator import LabelSelectorsStartingPointsGenerator

### Prepare latency calculation function

In [None]:
# Budget passed as `target_latency` to every label selector below. Those selectors measure
# "latency" with `mmac_calculation_function`, so this value is in the units that function
# returns (millions of MACs, per its name).
TARGET_LATENCY = 100


def mmac_calculation_function(model, dataloader=None):
    """Count the operations of one forward pass of ``model``, in millions.

    The input is the first batch of ``dataloader``. When ``dataloader`` is omitted,
    the global ``train_dataloader`` is used; it must be defined before this function
    is called. Operation counting is done by fvcore's ``FlopCountAnalysis``.

    The model is put into eval mode for the count. Its previous train/eval mode is
    restored afterwards, so calling this function has no lasting effect on ``model``.

    Args:
        model: The ``torch.nn.Module`` to analyse.
        dataloader: Optional iterable of batches. Defaults to the global ``train_dataloader``.

    Returns:
        ``FlopCountAnalysis(...).total() / 1e6`` as a float.
    """
    if dataloader is None:
        dataloader = train_dataloader

    # Assumes each batch is an (inputs, labels) pair. Only the inputs are needed.
    inputs, _ = next(iter(dataloader))

    was_training = model.training
    try:
        fca = FlopCountAnalysis(
            model=model.eval(),
            inputs=inputs,
        )
        fca.unsupported_ops_warnings(False)
        fca.uncalled_modules_warnings(False)
        # `total()` performs the actual analysis, so it must run while the model is in eval mode.
        return fca.total() / 1e6
    finally:
        # Undo the `eval()` switch so callers get the model back in the mode they passed it in.
        model.train(was_training)

### Prepare starting points and starting points generator

In [None]:
# These label selectors provide the additional starting points. There are two global
# channel-based selectors (0.5 and 0.3, presumably channel ratios; confirm in the ENOT
# docs) and a knapsack selector that uses the same budget and latency function.
starting_point_label_selector_0, starting_point_label_selector_1 = (
    GlobalPruningLabelSelectorByChannels(n_channels_or_ratio=ratio)
    for ratio in (0.5, 0.3)
)
starting_point_label_selector_2 = KnapsackPruningLabelSelector(
    latency_calculation_function=mmac_calculation_function,
    target_latency=TARGET_LATENCY,
)

# Collect all three selectors, in order 0, 1, 2, into a single starting-points generator.
additional_starting_points_generator = LabelSelectorsStartingPointsGenerator(
    starting_point_label_selector_0,
    starting_point_label_selector_1,
    starting_point_label_selector_2,
)

### Create label selector with additional starting points

In [None]:
# The optimal selector targets the same budget and uses the same latency function as the
# knapsack starting point. The generator's selections are supplied as extra starting points.
label_selector = OptimalPruningLabelSelector(
    latency_calculation_function=mmac_calculation_function,
    target_latency=TARGET_LATENCY,
    additional_starting_points_generator=additional_starting_points_generator,
    n_search_steps=200,  # search budget; raise for a more thorough (slower) search
)