In [1]:
from copy import deepcopy
from collections import deque
from functools import singledispatch

from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Set
from typing import Union

import numpy as np
import torch
import torch.nn as nn

from enot.models.mobilenet import MobileNetBaseHead
from enot.models.operations import SearchableMobileInvertedBottleneck
from enot.models.operations import SearchVariantsContainer
from torchvision.models.mobilenet import InvertedResidual

In [2]:
# A channel mapping is a list of (source_channel_index, destination_channel_index) pairs.
TChannelsMapping = List[Tuple[int, int]]


def conv2d_like(
    src_conv: nn.Conv2d,
    dst_in_channels: int,
    dst_out_channels: int,
    init_w_const: Optional[float] = 0.0,
) -> nn.Conv2d:
    """Build a fresh ``nn.Conv2d`` shaped like ``src_conv`` but with new channel counts.

    Kernel size, stride, padding, dilation, padding mode and the presence of a
    bias are taken from ``src_conv``. No weights are copied.

    Args:
        src_conv: Convolution whose hyper-parameters are reused.
        dst_in_channels: Number of input channels of the new convolution.
        dst_out_channels: Number of output channels of the new convolution.
        init_w_const: Constant used to fill the new weight tensor; ``None``
            keeps PyTorch's default initialisation. The bias is never touched.

    Returns:
        The newly created convolution.
    """
    # A source with groups == in_channels is treated as depthwise and stays
    # depthwise after resizing; otherwise the group count is kept as is.
    is_depthwise = src_conv.groups == src_conv.in_channels
    dst_groups = dst_in_channels if is_depthwise else src_conv.groups

    inherited = dict(
        kernel_size=src_conv.kernel_size,
        stride=src_conv.stride,
        padding=src_conv.padding,
        dilation=src_conv.dilation,
        groups=dst_groups,
        bias=src_conv.bias is not None,
        padding_mode=src_conv.padding_mode,
    )
    new_conv = nn.Conv2d(dst_in_channels, dst_out_channels, **inherited)

    if init_w_const is None:
        return new_conv

    nn.init.constant_(new_conv.weight, init_w_const)
    return new_conv


def bn_like(
    src_bn: nn.BatchNorm2d,
    dst_num_features: int,
) -> nn.BatchNorm2d:
    """Build a fresh ``nn.BatchNorm2d`` with ``src_bn``'s settings and a new width.

    Args:
        src_bn: Batch norm whose ``eps``, ``momentum``, ``affine`` and
            ``track_running_stats`` settings are reused.
        dst_num_features: Number of features of the new layer.

    Returns:
        The newly created batch norm; no parameters or statistics are copied.
    """
    inherited = {
        'eps': src_bn.eps,
        'momentum': src_bn.momentum,
        'affine': src_bn.affine,
        'track_running_stats': src_bn.track_running_stats,
    }
    return nn.BatchNorm2d(dst_num_features, **inherited)


def conv2d_bn_like(
    src_conv: nn.Conv2d,
    src_bn: nn.BatchNorm2d,
    dst_in_channels: int,
    dst_out_channels: int,
    init_w_const: Optional[float] = None,
) -> Tuple[nn.Conv2d, nn.BatchNorm2d]:
    """Create a conv + batch-norm pair shaped like the given source pair.

    Args:
        src_conv: Convolution whose hyper-parameters are reused.
        src_bn: Batch norm whose settings are reused.
        dst_in_channels: Input channels of the new convolution.
        dst_out_channels: Output channels of the new convolution; also the
            number of features of the new batch norm.
        init_w_const: Forwarded to ``conv2d_like``; ``None`` (the default here)
            leaves the convolution weights at their default initialisation.

    Returns:
        ``(conv, bn)`` — the two newly created layers. Nothing is copied from
        the source layers.
    """
    conv2d = conv2d_like(
        src_conv,
        dst_in_channels=dst_in_channels,
        dst_out_channels=dst_out_channels,
        init_w_const=init_w_const,
    )
    bn = bn_like(
        src_bn=src_bn,
        dst_num_features=dst_out_channels,
    )

    return conv2d, bn


def get_mapping_from_bn(bn: nn.BatchNorm2d, dst_channels: int) -> TChannelsMapping:
    """Choose which channels of ``bn`` survive when shrinking to ``dst_channels``.

    Channels are ranked by ``bn.running_var`` (largest first). The
    ``dst_channels`` best ones are kept and re-numbered ``0..dst_channels-1``
    in ascending order of their source index.

    Args:
        bn: Batch norm that tracks running statistics.
        dst_channels: Number of channels in the destination layer; must be > 0.

    Returns:
        A list of ``(src_channel, dst_channel)`` pairs of plain Python ints.
        When ``bn`` has no more than ``dst_channels`` features, the identity
        mapping over all source channels is returned.

    Raises:
        ValueError: If ``dst_channels`` is not positive or ``bn`` does not
            track running statistics.
    """
    if dst_channels <= 0:
        raise ValueError('Number of destination channels must be > 0')

    if not bn.track_running_stats:
        raise ValueError('track_running_stats must be True')

    src_channels = bn.num_features
    if src_channels <= dst_channels:
        return [(i, i) for i in range(src_channels)]

    score = bn.running_var.detach().cpu().numpy()
    # Negate so that argsort yields the highest-variance channels first.
    top_channels = np.argsort(-score)[:dst_channels]
    # Cast to int so both branches return the same element type as declared
    # by TChannelsMapping (argsort yields numpy.int64).
    kept_channels = sorted(int(ch) for ch in top_channels)

    return [
        (src_ch, dst_ch)
        for dst_ch, src_ch in enumerate(kept_channels)
    ]


def _unpack_mapping(
    mapping: TChannelsMapping,
) -> Tuple[List[int], List[int]]:
    """Split a channel mapping into parallel source and destination index lists.

    Args:
        mapping: Iterable of ``(src_channel, dst_channel)`` pairs. ``None`` is
            not accepted here; callers handle the "no mapping" case themselves.

    Returns:
        ``(src_indices, dst_indices)`` as two lists of equal length. An empty
        mapping yields two empty lists.

    Raises:
        RuntimeError: If a destination index occurs more than once.
    """
    # Build the lists explicitly: ``tuple(zip(*mapping))`` cannot be unpacked
    # into two names when the mapping is empty.
    src = [src_ch for src_ch, _ in mapping]
    dst = [dst_ch for _, dst_ch in mapping]
    if len(dst) != len(set(dst)):
        raise RuntimeError('destination indexes must be unique')

    return src, dst


def _apply_mapping(
    dst: torch.Tensor,
    src: torch.Tensor,
    mapping: Optional[TChannelsMapping],
) -> None:
    """Copy entries of ``src`` into ``dst`` along the first dimension.

    With ``mapping is None`` the whole tensor is copied via ``dst.copy_(src)``.
    Otherwise only the mapped source indices are written to their destination
    indices; other entries of ``dst`` are left untouched.
    """
    if mapping is not None:
        src_indices, dst_indices = _unpack_mapping(mapping)
        dst[dst_indices] = src[src_indices]
    else:
        dst.copy_(src)

    
def _apply_mapping_out_in(
    dst: torch.Tensor,
    src: torch.Tensor,
    out_mapping: Optional[TChannelsMapping],
    in_mapping: Optional[TChannelsMapping],
) -> None:
    """Copy ``src`` into ``dst`` along the first two dimensions.

    ``out_mapping`` selects indices on dimension 0 and ``in_mapping`` on
    dimension 1. A ``None`` mapping means "take that dimension as a whole";
    when both are ``None`` the tensor is copied with ``dst.copy_(src)``.
    """
    has_out = out_mapping is not None
    has_in = in_mapping is not None

    if not (has_out or has_in):
        dst.copy_(src)
        return

    if has_out:
        src_rows, dst_rows = _unpack_mapping(out_mapping)
    else:
        src_rows = dst_rows = slice(None, None)

    if has_in:
        src_cols, dst_cols = _unpack_mapping(in_mapping)
        if has_out:
            # Turn the row indices into column vectors so that they broadcast
            # against the column indices and address the full (out, in) grid.
            src_rows = np.asarray(src_rows)[:, np.newaxis]
            dst_rows = np.asarray(dst_rows)[:, np.newaxis]
    else:
        src_cols = dst_cols = Ellipsis

    dst[dst_rows, dst_cols] = src[src_rows, src_cols]


def conv2d_apply_mapping(
    dst_conv: nn.Conv2d,
    src_conv: nn.Conv2d,
    out_mapping: Optional[TChannelsMapping],
    in_mapping: Optional[TChannelsMapping],
) -> None:
    """Copy weights (and bias, if present) from ``src_conv`` into ``dst_conv``.

    Args:
        dst_conv: Convolution that is modified in place.
        src_conv: Convolution providing the values.
        out_mapping: Mapping for output channels (weight dim 0 and bias), or
            ``None`` to take all of them.
        in_mapping: Mapping for input channels (weight dim 1), or ``None`` to
            take all of them.

    Raises:
        ValueError: If exactly one of the two convolutions has a bias.
    """
    src_has_bias = src_conv.bias is not None
    dst_has_bias = dst_conv.bias is not None
    if src_has_bias != dst_has_bias:
        raise ValueError('(src_conv.bias is None) != (dst_conv.bias is None)')

    # The copies are in-place writes to parameters, so keep autograd out of it.
    with torch.no_grad():
        _apply_mapping_out_in(dst_conv.weight, src_conv.weight, out_mapping, in_mapping)
        if dst_has_bias:
            _apply_mapping(dst_conv.bias, src_conv.bias, out_mapping)

            
def bn_apply_mapping(
    dst_bn: nn.BatchNorm2d,
    src_bn: nn.BatchNorm2d,
    mapping: Optional[TChannelsMapping],
) -> None:
    """Copy affine parameters and running statistics from ``src_bn`` to ``dst_bn``.

    Args:
        dst_bn: Batch norm that is modified in place.
        src_bn: Batch norm providing the values.
        mapping: Channel mapping, or ``None`` to copy every channel.

    Raises:
        ValueError: If the two layers disagree on ``affine`` or on
            ``track_running_stats``.
    """
    if dst_bn.affine != src_bn.affine:
        raise ValueError('src_bn.affine != dst_bn.affine')

    if dst_bn.track_running_stats != src_bn.track_running_stats:
        raise ValueError('src_bn.track_running_stats != dst_bn.track_running_stats')

    with torch.no_grad():
        if dst_bn.affine:
            for name in ('weight', 'bias'):
                _apply_mapping(getattr(dst_bn, name), getattr(src_bn, name), mapping)

        if dst_bn.track_running_stats:
            for name in ('running_mean', 'running_var'):
                _apply_mapping(getattr(dst_bn, name), getattr(src_bn, name), mapping)
            # The batch counter is a scalar, so it is copied without a mapping.
            dst_bn.num_batches_tracked.copy_(src_bn.num_batches_tracked)
            
            
def conv2d_bn_apply_mapping(
    dst_conv: nn.Conv2d,
    dst_bn: nn.BatchNorm2d,
    src_conv: nn.Conv2d,
    src_bn: nn.BatchNorm2d,
    out_mapping: Optional[TChannelsMapping],
    in_mapping: Optional[TChannelsMapping],
) -> None:
    """Copy a conv + batch-norm pair from the source layers into the destination layers.

    The batch norm follows the convolution's output, so it is copied with
    ``out_mapping``; ``in_mapping`` only affects the convolution.
    """
    conv2d_apply_mapping(dst_conv, src_conv, out_mapping, in_mapping)
    bn_apply_mapping(dst_bn, src_bn, out_mapping)
    

def create_conv2d_and_apply_mapping(
    src_conv: nn.Conv2d,
    dst_in_channels: int,
    dst_out_channels: int,
    out_mapping: Optional[TChannelsMapping],
    in_mapping: Optional[TChannelsMapping],
) -> nn.Conv2d:
    """Create a resized copy of ``src_conv`` and fill it from the source weights.

    The new convolution starts with all-zero weights, so every position that
    the mappings do not cover stays at zero.

    Returns:
        The new convolution.
    """
    new_conv = conv2d_like(
        src_conv,
        dst_in_channels,
        dst_out_channels,
        init_w_const=0.0,
    )
    conv2d_apply_mapping(new_conv, src_conv, out_mapping=out_mapping, in_mapping=in_mapping)
    return new_conv


def create_bn_and_apply_mapping(
    src_bn: nn.BatchNorm2d,
    dst_num_features: int,
    mapping: Optional[TChannelsMapping],
) -> nn.BatchNorm2d:
    """Create a resized copy of ``src_bn`` and fill it from the source layer.

    Returns:
        A new batch norm with ``dst_num_features`` features whose mapped
        channels carry the parameters and statistics of ``src_bn``.
    """
    new_bn = bn_like(src_bn, dst_num_features)
    bn_apply_mapping(new_bn, src_bn, mapping)
    return new_bn


def create_conv2d_bn_and_apply_mapping(
    src_conv: nn.Conv2d,
    src_bn: nn.BatchNorm2d,
    dst_in_channels: int,
    dst_out_channels: int,
    out_mapping: Optional[TChannelsMapping],
    in_mapping: Optional[TChannelsMapping],
) -> Tuple[nn.Conv2d, nn.BatchNorm2d]:
    """Create a resized conv + batch-norm pair filled from the source pair.

    Returns:
        ``(conv, bn)`` where the batch norm has ``dst_out_channels`` features
        and is mapped with ``out_mapping``.
    """
    new_conv = create_conv2d_and_apply_mapping(
        src_conv,
        dst_in_channels,
        dst_out_channels,
        out_mapping,
        in_mapping,
    )
    new_bn = create_bn_and_apply_mapping(src_bn, dst_out_channels, out_mapping)
    return new_conv, new_bn

In [3]:
def create_MNv2_block_and_apply_mapping(
    src_block: InvertedResidual,
    dst_channels: int,
) -> Optional[SearchableMobileInvertedBottleneck]:
    """Build a narrower searchable block from a MobileNetV2 inverted residual.

    The source block is expected to hold four entries in ``src_block.conv``:
    an expand stage and a depthwise stage (each indexable as ``[0]`` = conv,
    ``[1]`` = batch norm), followed by the projection conv and its batch norm.
    The channels to keep are chosen from the batch norm after the depthwise
    conv, and the same mapping is applied to the expand output, the depthwise
    conv and the projection input.

    Args:
        src_block: Source inverted residual block.
        dst_channels: Number of expanded/depthwise channels in the new block.

    Returns:
        The new block, or ``None`` when ``src_block.conv`` does not have
        exactly four entries.
    """
    # Local import: ``warnings`` is not imported at the top of the notebook,
    # so the warning below used to fail with a NameError.
    import warnings

    if len(src_block.conv) != 4:
        return None

    if type(src_block) is not InvertedResidual:
        warnings.warn(
            'Got subclass of InvertedResidual. Original InvertedResidual class '
            'should be used to be ensure of correct generation of search variants. '
            'In most of cases using subclass is ok.',
            RuntimeWarning,
        )

    src_expand, src_expand_bn = src_block.conv[0][0], src_block.conv[0][1]
    src_dws, src_dws_bn = src_block.conv[1][0], src_block.conv[1][1]
    src_project, src_project_bn = src_block.conv[2], src_block.conv[3]

    mapping = get_mapping_from_bn(src_dws_bn, dst_channels)

    # NOTE(review): stock torchvision InvertedResidual may not expose
    # `padding` / `dilation` attributes — confirm against the InvertedResidual
    # implementation actually in use.
    dst_block = SearchableMobileInvertedBottleneck(
        in_channels=src_expand.in_channels,
        out_channels=src_project.out_channels,
        dw_channels=dst_channels,
        kernel_size=3,
        affine=True,
        track=True,
        activation='relu6',
        use_skip_connection=src_block.use_res_connect,
        stride=src_block.stride,
        padding=src_block.padding,
        dilation=src_block.dilation,
    )

    # expand: keep all inputs, prune outputs
    dst_expand, dst_expand_bn = create_conv2d_bn_and_apply_mapping(
        src_conv=src_expand,
        src_bn=src_expand_bn,
        dst_in_channels=src_expand.in_channels,
        dst_out_channels=dst_channels,
        out_mapping=mapping,
        in_mapping=None,
    )
    dst_block.expand_op[0].conv = dst_expand
    dst_block.expand_op[0].bn = dst_expand_bn

    # depthwise: prune outputs; no input mapping is applied
    dst_dws, dst_dws_bn = create_conv2d_bn_and_apply_mapping(
        src_conv=src_dws,
        src_bn=src_dws_bn,
        dst_in_channels=dst_channels,
        dst_out_channels=dst_channels,
        out_mapping=mapping,
        in_mapping=None,
    )
    dst_block.depthwise_op[0].conv = dst_dws
    dst_block.depthwise_op[0].bn = dst_dws_bn

    # project: prune inputs, keep all outputs
    dst_project, dst_project_bn = create_conv2d_bn_and_apply_mapping(
        src_conv=src_project,
        src_bn=src_project_bn,
        dst_in_channels=dst_channels,
        dst_out_channels=src_project.out_channels,
        out_mapping=None,
        in_mapping=mapping,
    )
    dst_block.squeeze_op.conv = dst_project
    dst_block.squeeze_op.bn = dst_project_bn

    return dst_block

In [4]:
@singledispatch
def split_module_to_search_variants(src_module, **options) -> Optional[SearchVariantsContainer]:
    """Default overload: a module type with no registered handler cannot be split.

    Type-specific overloads are registered with
    ``split_module_to_search_variants.register`` and dispatched on the type of
    ``src_module``. Returning ``None`` tells the caller to leave the module as
    it is.
    """
    return None


@split_module_to_search_variants.register
def _(src_module: InvertedResidual, **options) -> Optional[SearchVariantsContainer]:
    """Split an ``InvertedResidual`` block into width-scaled search variants.

    Options:
        width_fractions: Fractions of the depthwise channel count to generate
            (default ``[0.25, 0.5, 0.75, 1.0]``). A fraction of exactly ``1.0``
            yields a deep copy of the source block.

    Returns:
        A ``SearchVariantsContainer`` with one variant per fraction, or
        ``None`` when ``src_module.conv`` does not have exactly four entries.
    """
    if len(src_module.conv) != 4:
        return None

    width_fractions = options.pop('width_fractions', [0.25, 0.5, 0.75, 1.0])
    # Width of the block = feature count of the batch norm after the depthwise conv.
    src_channels = src_module.conv[1][1].num_features

    def build_variant(fraction):
        if fraction == 1.0:
            return deepcopy(src_module)
        # Truncate towards zero, but never go below one channel.
        dst_channels = max(int(src_channels*fraction), 1)
        return create_MNv2_block_and_apply_mapping(src_module, dst_channels=dst_channels)

    return SearchVariantsContainer([build_variant(fraction) for fraction in width_fractions])


def build_model_with_search_variants(
    model: nn.Module,
    split_options_by_types: Optional[Dict[type, Dict[str, Any]]] = None,
    split_options_by_modules: Optional[Dict[nn.Module, Dict[str, Any]]] = None,
    excluded_types: Optional[Set[type]] = None,
    excluded_modules: Optional[Set[nn.Module]] = None,
) -> nn.Module:
    """Return a copy of ``model`` whose splittable modules are replaced by search variants.

    The module tree is walked from the top. For each module,
    ``split_module_to_search_variants`` is tried. If it returns a container,
    the module is replaced by it and its children are not visited. Otherwise
    the walk continues into the module's children.

    Args:
        model: Source model; it is deep-copied and never modified.
        split_options_by_types: Keyword options for the splitter, keyed by the
            exact module type.
        split_options_by_modules: Keyword options keyed by module *instances
            of the input ``model``*. Non-empty per-module options take
            precedence over per-type options.
        excluded_types: Exact module types that must not be split.
        excluded_modules: Module instances of the input ``model`` that must
            not be split. Children of an excluded module are still visited.

    Returns:
        The search-variants container itself if ``model`` as a whole is
        splittable, otherwise the modified deep copy of ``model``.
    """
    split_options_by_types = split_options_by_types or {}
    split_options_by_modules = split_options_by_modules or {}
    excluded_types = excluded_types or set()
    excluded_modules = excluded_modules or set()

    def try_split(module, original):
        """Split ``module``; ``original`` is its counterpart in the input model."""
        if type(module) in excluded_types:
            return None

        # Compare the instance, not its type: the set holds modules.
        if original in excluded_modules:
            return None

        options = split_options_by_modules.get(original, {})
        if not options:
            options = split_options_by_types.get(type(module), {})

        return split_module_to_search_variants(module, **options)

    result_model = try_split(model, model)
    if result_model is not None:
        return result_model

    result_model = deepcopy(model)

    # The walk runs on the copy, but callers reference modules of the input
    # model. Both trees have the same structure, so their modules() iterators
    # line up one to one.
    original_of = dict(zip(result_model.modules(), model.modules()))

    # Used as a stack (pop from the right): depth-first traversal.
    queue = deque(((name,), child) for name, child in result_model.named_children())
    while queue:
        attr_path, module = queue.pop()

        container = try_split(module, original_of.get(module, module))
        if container is None:
            queue.extend(
                (attr_path + (name,), child)
                for name, child in module.named_children()
            )
            continue

        # Resolve the parent by attribute path and swap the child in place.
        parent = result_model
        for attr in attr_path[:-1]:
            parent = getattr(parent, attr)
        setattr(parent, attr_path[-1], container)

    return result_model

In [5]:
# Fetch the pretrained MobileNetV2 from the torchvision v0.8.0 hub repository.
model = torch.hub.load('pytorch/vision:v0.8.0', 'mobilenet_v2', pretrained=True)
model.eval()

# Replace the splittable InvertedResidual blocks by search variants at 25% and 50% width.
model_with_search_variants = build_model_with_search_variants(model, split_options_by_types={InvertedResidual: {'width_fractions': [0.25, 0.5]}})
# Swap the final classifier layer for a freshly initialised 10-class linear head.
model_with_search_variants.classifier[1] = nn.Linear(in_features=1280, out_features=10, bias=True)

Using cache found in /home/igor/.cache/torch/hub/pytorch_vision_v0.8.0


In [1]:
import numpy as np

def params_amount(model):
    """Return the total number of scalar parameters of ``model`` as a Python ``int``.

    Args:
        model: Any object with a ``parameters()`` method whose items provide
            ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The summed element count; ``0`` for a model without parameters.
    """
    # Built-in sum keeps the result an int (np.sum of an empty list is 0.0).
    return sum(p.numel() for p in model.parameters())

In [None]:
from collections import OrderedDict



In [None]:
# NOTE(review): `MobileNetV2` and `load_state_mnv2_pretrained` are neither defined
# nor imported in the visible cells — confirm where they come from before running.
mnv2 = MobileNetV2()
load_state_mnv2_pretrained(mnv2)

In [None]:
# NOTE(review): `PoseEstimationWithMobileNet` is not defined or imported in the
# visible cells — confirm its origin.
net = PoseEstimationWithMobileNet(backbone = mnv2.model, after_backbone_channels = 96)

In [None]:
# Build search variants for the pose-estimation network: the unchanged block (1.0)
# plus variants at 1/2 and 1/6 of the depthwise width.
model_with_search_variants = build_model_with_search_variants(net, 
                                                              split_options_by_types={
                                                                  InvertedResidual: {'width_fractions': [1.0, 0.5, 1/6]}
                                                              })

In [None]:
# Wrap the model in a search space and store its initial weights on disk.
# NOTE(review): `SearchSpaceModel` is only imported in a later cell — on a fresh
# kernel this cell fails unless that import has been run first.
search_space = SearchSpaceModel(model_with_search_variants)
torch.save(search_space.state_dict(), "search_space_init.pth")

In [None]:
# NOTE(review): the first assignment to `net` is overwritten on the next line, and
# the backbone still comes from `mnv2.model` — confirm whether the
# `SearchableMobileNetV2` instance was meant to be used as the backbone.
# NOTE(review): `SearchableMobileNetV2` is not defined in the visible cells.
net = SearchableMobileNetV2()
net = PoseEstimationWithMobileNet(backbone = mnv2.model, after_backbone_channels = 96)
search_space = SearchSpaceModel(net)
# Restore the weights saved to "search_space_init.pth".
search_space.load_state_dict(torch.load("search_space_init.pth"))

In [6]:
import os

# Enumerate GPUs in PCI bus order so the index below refers to a stable device.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# You may need to change this variable to match free GPU index
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [7]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from torch.optim import Adam
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_optimizer import RAdam

from enot.models import SearchSpaceModel
from enot.models.mobilenet import build_mobilenet
from enot.optimize import EnotPretrainOptimizer
from enot.optimize import EnotSearchOptimizer

from enot_utils.metric_utils import accuracy
from enot_utils.schedulers import WarmupScheduler

from tutorial_utils.checkpoints import download_getting_started_pretrain_checkpoint
from tutorial_utils.dataset import create_imagenette_dataloaders

In [8]:
# Working directories live under ~/.enot: one for datasets, one for this experiment.
ENOT_HOME_DIR = Path.home() / '.enot'
ENOT_DATASETS_DIR = ENOT_HOME_DIR / 'datasets'
PROJECT_DIR = ENOT_HOME_DIR / 'pruning_test'

# Create the parent first; exist_ok makes the cell safe to re-run.
ENOT_HOME_DIR.mkdir(exist_ok=True)
ENOT_DATASETS_DIR.mkdir(exist_ok=True)
PROJECT_DIR.mkdir(exist_ok=True)

In [9]:
# Build the Imagenette dataloaders; the result is indexed by name in the training
# cells (e.g. 'pretrain_train_dataloader', 'pretrain_validation_dataloader').
dataloaders = create_imagenette_dataloaders(
    dataset_root_dir=ENOT_DATASETS_DIR,
    project_dir=PROJECT_DIR,
    input_size=(224, 224),
    batch_size=32,
)

In [10]:
# Wrap the model with search variants in a search space and move it to the GPU.
search_space = SearchSpaceModel(model_with_search_variants).cuda()

In [11]:
N_EPOCHS = 100
N_WARMUP_EPOCHS = 10

train_loader = dataloaders['pretrain_train_dataloader']
validation_loader = dataloaders['pretrain_validation_dataloader']

# using `search_space.model_parameters()` as optimizable variables
optimizer = SGD(params=search_space.model_parameters(), lr=0.06, momentum=0.9, weight_decay=1e-4)
# using `EnotPretrainOptimizer` as a default optimizer
enot_optimizer = EnotPretrainOptimizer(search_space=search_space, optimizer=optimizer)

# The schedulers are stepped once per batch, so their horizons are counted in batches.
len_train_loader = len(train_loader)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train_loader*N_EPOCHS, eta_min=1e-8)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train_loader*N_WARMUP_EPOCHS)

metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

for epoch in range(N_EPOCHS):

    print(f'EPOCH #{epoch}')

    search_space.train()
    # Running sums over the epoch; 'n' counts closure invocations.
    train_metrics_acc = {
        'loss': 0.0,
        'accuracy': 0.0,
        'n': 0,
    }
    for inputs, labels in train_loader:
        # By default, `EnotPretrainOptimizer` requires one batch of train data to initialize optimizations,
        # so you should run `search_space.initialize_output_distribution_optimization(...)` before the first
        # model step. You can disable optimization checking in `EnotPretrainOptimizer` constructor
        # (`check_recommended_optimizations` parameter), but this is not recommended.
        if not search_space.output_distribution_optimization_enabled:
            search_space.initialize_output_distribution_optimization(inputs)

        enot_optimizer.zero_grad()
        # Wrapping model step and backward with closure.
        # Alternatively, here is `enot_optimizer.model_step(...)` example usage for gradient accumulation:
        #
        # enot_optimizer.zero_grad()
        # for inputs, labels in train_loader:
        #
        #     def closure():
        #         ...
        #
        #     enot_optimizer.model_step(closure)
        #     if (n + 1) % 10 == 0:
        #         enot_optimizer.step()
        #         enot_optimizer.zero_grad()
        def closure():
            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_loss.backward()
            batch_metric = metric_function(pred_labels, labels)

            train_metrics_acc['loss'] += batch_loss.item()
            train_metrics_acc['accuracy'] += batch_metric.item()
            train_metrics_acc['n'] += 1

        enot_optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()

    train_loss = train_metrics_acc['loss'] / train_metrics_acc['n']
    train_accuracy = train_metrics_acc['accuracy'] / train_metrics_acc['n']

    print('train metrics:')
    print('  loss:', train_loss)
    print('  accuracy:', train_accuracy)

    search_space.eval()
    validation_loss = 0
    validation_accuracy = 0
    # Validation only reads metrics, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, labels in validation_loader:

            # Sample random architecture from the search space to estimate
            # search space expected metrics.
            search_space.sample_random_arch()

            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    print('validation metrics:')
    print('  loss:', validation_loss)
    print('  accuracy:', validation_accuracy)

    print()

EPOCH #0
train metrics:
  loss: 1.698934056023334
  accuracy: 40.584367605980404
validation metrics:
  loss: 1.6638118374732234
  accuracy: 51.391129032258064

EPOCH #1
train metrics:
  loss: 1.347713968094359
  accuracy: 58.15307328244473
validation metrics:
  loss: 1.489925630390644
  accuracy: 56.95564516129032

EPOCH #2
train metrics:
  loss: 1.1685503467600395
  accuracy: 63.125738760765564
validation metrics:
  loss: 1.4963290958154587
  accuracy: 55.01008064516129

EPOCH #3
train metrics:
  loss: 1.0874943785210873
  accuracy: 65.07535459640178
validation metrics:
  loss: 1.3672184644327048
  accuracy: 57.57056451612903

EPOCH #4
train metrics:
  loss: 1.0874536850350969
  accuracy: 64.60919028748857
validation metrics:
  loss: 1.4683305130850883
  accuracy: 58.346774193548384

EPOCH #5
train metrics:
  loss: 1.0630275870891328
  accuracy: 65.69666075199208
validation metrics:
  loss: 1.1600293024413046
  accuracy: 65.45362903225806

EPOCH #6
train metrics:
  loss: 1.07063086641

In [13]:
# EPOCH #99
# train metrics:
#   loss: 0.29863091677506554
#   accuracy: 90.61391841807264
# validation metrics:
#   loss: 0.32048667211746495
#   accuracy: 89.96975806451613

# torch.save({'model': search_space.state_dict()}, PROJECT_DIR / 'prune_init_result.pth')

In [19]:
# Re-initialise every submodule that provides `reset_parameters()`, so that the
# following training run starts from freshly initialised weights.
for m in search_space.modules():
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()

In [20]:
N_EPOCHS = 100
N_WARMUP_EPOCHS = 10

train_loader = dataloaders['pretrain_train_dataloader']
validation_loader = dataloaders['pretrain_validation_dataloader']

# using `search_space.model_parameters()` as optimizable variables
optimizer = SGD(params=search_space.model_parameters(), lr=0.06, momentum=0.9, weight_decay=1e-4)
# using `EnotPretrainOptimizer` as a default optimizer
enot_optimizer = EnotPretrainOptimizer(search_space=search_space, optimizer=optimizer)

# The schedulers are stepped once per batch, so their horizons are counted in batches.
len_train_loader = len(train_loader)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train_loader*N_EPOCHS, eta_min=1e-8)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train_loader*N_WARMUP_EPOCHS)

metric_function = accuracy
loss_function = nn.CrossEntropyLoss().cuda()

for epoch in range(N_EPOCHS):

    print(f'EPOCH #{epoch}')

    search_space.train()
    # Running sums over the epoch; 'n' counts closure invocations.
    train_metrics_acc = {
        'loss': 0.0,
        'accuracy': 0.0,
        'n': 0,
    }
    for inputs, labels in train_loader:
        # By default, `EnotPretrainOptimizer` requires one batch of train data to initialize optimizations,
        # so you should run `search_space.initialize_output_distribution_optimization(...)` before the first
        # model step. You can disable optimization checking in `EnotPretrainOptimizer` constructor
        # (`check_recommended_optimizations` parameter), but this is not recommended.
        if not search_space.output_distribution_optimization_enabled:
            search_space.initialize_output_distribution_optimization(inputs)

        enot_optimizer.zero_grad()
        # Wrapping model step and backward with closure.
        # Alternatively, here is `enot_optimizer.model_step(...)` example usage for gradient accumulation:
        #
        # enot_optimizer.zero_grad()
        # for inputs, labels in train_loader:
        #
        #     def closure():
        #         ...
        #
        #     enot_optimizer.model_step(closure)
        #     if (n + 1) % 10 == 0:
        #         enot_optimizer.step()
        #         enot_optimizer.zero_grad()
        def closure():
            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_loss.backward()
            batch_metric = metric_function(pred_labels, labels)

            train_metrics_acc['loss'] += batch_loss.item()
            train_metrics_acc['accuracy'] += batch_metric.item()
            train_metrics_acc['n'] += 1

        enot_optimizer.step(closure)
        if scheduler is not None:
            scheduler.step()

    train_loss = train_metrics_acc['loss'] / train_metrics_acc['n']
    train_accuracy = train_metrics_acc['accuracy'] / train_metrics_acc['n']

    print('train metrics:')
    print('  loss:', train_loss)
    print('  accuracy:', train_accuracy)

    search_space.eval()
    validation_loss = 0
    validation_accuracy = 0
    # Validation only reads metrics, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, labels in validation_loader:

            # Sample random architecture from the search space to estimate
            # search space expected metrics.
            search_space.sample_random_arch()

            pred_labels = search_space(inputs)
            batch_loss = loss_function(pred_labels, labels)
            batch_metric = metric_function(pred_labels, labels)

            validation_loss += batch_loss.item()
            validation_accuracy += batch_metric.item()

    n = len(validation_loader)
    validation_loss /= n
    validation_accuracy /= n

    print('validation metrics:')
    print('  loss:', validation_loss)
    print('  accuracy:', validation_accuracy)

    print()

EPOCH #0
train metrics:
  loss: 2.307060359893961
  accuracy: 12.71128841562474
validation metrics:
  loss: 3.0902600749846427
  accuracy: 8.921370967741936

EPOCH #1
train metrics:
  loss: 2.30292133270426
  accuracy: 17.18011229088966
validation metrics:
  loss: 2.825388700731339
  accuracy: 19.758064516129032

EPOCH #2
train metrics:
  loss: 2.298670430639957
  accuracy: 20.09086879162078
validation metrics:
  loss: 2.7842245462440673
  accuracy: 19.294354838709676

EPOCH #3
train metrics:
  loss: 2.1483039579492935
  accuracy: 25.425531910835428
validation metrics:
  loss: 2.1391234859343498
  accuracy: 25.866935483870968

EPOCH #4
train metrics:
  loss: 1.9964911884449899
  accuracy: 29.701536633105988
validation metrics:
  loss: 1.9094789489623039
  accuracy: 34.243951612903224

EPOCH #5
train metrics:
  loss: 1.8526982287143139
  accuracy: 36.02171984733419
validation metrics:
  loss: 1.8766882438813486
  accuracy: 33.538306451612904

EPOCH #6
train metrics:
  loss: 1.7686041849

In [21]:
# EPOCH #99
# train metrics:
#   loss: 0.514968086113321
#   accuracy: 83.17966902712558
# validation metrics:
#   loss: 0.4812319797553843
#   accuracy: 84.52620967741936
    
# Store the trained search-space weights under the 'model' key in the project directory.
torch.save({'model': search_space.state_dict()}, PROJECT_DIR / 'common_init_result.pth')