In [None]:
# Load the IPython autoreload extension and switch it to mode 2, so edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import math
import sys
import yaml
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from enot.logging import prepare_log
from enot.models import SearchSpaceModel
from enot.models import SearchVariantsContainer

In [None]:
# Report the PyTorch version, the CUDA version PyTorch was built with, and the NumPy version.
print(
    torch.__version__,
    torch.version.cuda,
    np.__version__,
)

In [None]:
from models.experimental import attempt_load
from models.yolo import Model
from models import yolo
from utils.torch_utils import intersect_dicts

# Replace the logger used by models.yolo with one built by ENOT's prepare_log;
# the format string keeps only the message text (no timestamp / level prefix).
yolo.LOGGER = prepare_log(log_format='%(message)s')

In [None]:
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Shell escape: show the output of nvidia-smi before a device is selected below.
!nvidia-smi

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without a usable GPU instead of crashing in set_device().
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

if 'cuda' in device:
    # Make the selected GPU the default CUDA device for this process.
    torch.cuda.set_device(device)
else:
    print('CUDA is not available; falling back to CPU.')

In [None]:
# Model / search-space definition used throughout the notebook.
# Other configs that can be swapped in:
#   'models/hub/yolov5l6.yaml'
#   'models/hub/yolov5l6ss_v1.yaml'
#   'models/yolov5lss_v2.yaml'
model_config = 'models/yolov5sss_v1.yaml'

# The config path is resolved against the current working directory, so the
# notebook is expected to be started from the repository root.
root_directory = Path.cwd()
model_config_file = root_directory.joinpath(model_config)

In [None]:
# Path (relative to the working directory) of the checkpoint that later cells load with torch.load.
weights = 'weights/best.pt'

In [None]:
# Indices passed to get_network_by_indexes() / sample() to pick one architecture
# from the search space (presumably one variant index per searchable block — confirm).
# Alternative selection: [0] * 8
architecture_indices = list(range(8))

In [None]:
# Build the network described by the YAML config and move it to the target device.
model = Model(model_config_file).to(device)

# Wrap the model in a SearchSpaceModel only when it actually contains searchable
# containers; otherwise there is no search space and the plain model is used as is.
search_space = (
    SearchSpaceModel(model)
    if any(isinstance(module, SearchVariantsContainer) for module in model.modules())
    else None
)

# main_model is the object whose state_dict receives the checkpoint weights.
main_model = model if search_space is None else search_space

In [None]:
# Load the trained checkpoint and copy its weights into the freshly built model.
# NOTE(security): torch.load unpickles arbitrary Python objects — only load
# checkpoints that come from a trusted source.
ckpt = torch.load(weights, map_location=device)  # load checkpoint

# Prefer the EMA weights, but fall back to the raw model when the checkpoint has
# no usable 'ema' entry (key missing, or set to None in a finalized checkpoint).
ckpt_model = ckpt.get('ema')
if ckpt_model is None:
    ckpt_model = ckpt['model']

state_dict = ckpt_model.float().state_dict()  # checkpoint state_dict as FP32
# Filter the checkpoint entries against the target model's state_dict.
state_dict = intersect_dicts(state_dict, main_model.state_dict())
# strict=True makes this raise if, after filtering, any entry of main_model is not covered.
main_model.load_state_dict(state_dict, strict=True)
yolo.LOGGER.info(f'Transferred {len(state_dict)} / {len(main_model.state_dict())} items from {weights}')

if search_space is not None:
    # Extract the selected architecture as a standalone network and drop the
    # notebook's reference to the search space.
    model = search_space.get_network_by_indexes(architecture_indices)
    search_space = None

In [None]:
from enot.utils.batch_norm import tune_bn_stats
from utils.general import check_dataset
from utils.datasets import create_dataloader
import yaml

def preprocess_data(x):
    """Convert one dataloader sample into ``(args, kwargs)`` for a model call.

    Only the first element of ``x`` is used (presumably the image batch as
    0-255 values — confirm against the dataloader). It is moved to the global
    ``device``, cast to float and divided by 255.

    Returns:
        A pair ``((scaled_images,), {})``: a one-element tuple of positional
        arguments and an empty dict of keyword arguments.
    """
    images = x[0]
    scaled_images = images.to(device).float() / 255.0
    return (scaled_images,), {}

# Read the training hyperparameters from YAML into a dict and log them.
with open('data/hyps/hyp.scratch.yaml') as f:
    hyp = yaml.safe_load(f)  # load hyps dict

yolo.LOGGER.info(f"hyperparameters: {', '.join(f'{k}={v}' for k, v in hyp.items())}")

# Passed as the 4th positional argument of create_dataloader
# (presumably the model stride / grid size — confirm against its signature).
gs = 32

# Resolve the dataset description and pick the splits used in this notebook.
data = 'data/enot_coco.yaml'
data_dict = check_dataset(data)
train_path, val_path = (data_dict[split] for split in ('pretrain', 'val'))

# Positional arguments after the path: 640, 10 and gs
# (presumably image size, batch size and stride — confirm against create_dataloader).
train_loader, dataset = create_dataloader(
    train_path, 640, 10, gs,
    hyp=hyp,
    augment=True,
    pad=0.0,
    rect=False,
)

In [None]:
# Re-tune the batch-norm statistics of the in-memory `model` on training batches.
# The reset_bns / set_momentums_none flags are ENOT options — see the ENOT
# documentation for their exact semantics.
tune_bn_stats(
    model,
    train_loader,
    sample_to_model_inputs=preprocess_data,  # turns a dataloader sample into (args, kwargs)
    n_steps=256,
    reset_bns=True,
    set_momentums_none=True,
)

In [None]:
# Reload the checkpoint from disk, replace its search-space models with the
# selected architecture (eval mode, on CPU) and save it as a new checkpoint.
# NOTE(security): torch.load unpickles arbitrary Python objects — only load
# checkpoints that come from a trusted source.
# NOTE(review): the networks saved here are extracted from the file on disk, so
# batch-norm statistics tuned on the in-memory `model` are not part of
# 'weights/extract.pt' — confirm this is intended.
ckpt = torch.load(weights, map_location=device)
ckpt['model'] = ckpt['model'].get_network_by_indexes(architecture_indices).eval().cpu()
# The EMA entry is optional: it may be missing or None, in which case it is left as is.
if ckpt.get('ema') is not None:
    ckpt['ema'] = ckpt['ema'].get_network_by_indexes(architecture_indices).eval().cpu()
torch.save(ckpt, 'weights/extract.pt')

In [None]:
def _load_tune_and_snapshot(model, source, input_path, is_search_space, n_steps):
    """Load ``source`` weights into ``model``, re-tune BN statistics, return an FP16 CPU copy.

    Args:
        model: Target network (plain model or search space); modified in place.
        source: Checkpoint module whose weights are transferred (cast to FP32 in place).
        input_path: Checkpoint path, used only in the log message.
        is_search_space: Whether ``model`` is a ``SearchSpaceModel``.
        n_steps: Number of steps passed to ``tune_bn_stats``.

    Returns:
        A deep copy of the tuned ``model`` moved to CPU and converted to half precision.
    """
    state_dict = source.float().state_dict()
    state_dict = intersect_dicts(state_dict, model.state_dict())
    # strict=True makes this raise if, after filtering, any entry of model is not covered.
    model.load_state_dict(state_dict, strict=True)
    yolo.LOGGER.info(f'Transferred {len(state_dict)} / {len(model.state_dict())} items from {input_path}')

    if is_search_space:
        # sample() is only called on a SearchSpaceModel wrapper; a plain Model is
        # not expected to provide it, so the call is skipped in that case.
        model.sample([[x] for x in architecture_indices])

    tune_bn_stats(
        model,
        train_loader,
        reset_bns=True,
        set_momentums_none=True,
        n_steps=n_steps,
        sample_to_model_inputs=preprocess_data,
    )

    # Copy before returning: the caller reuses `model` for the next set of weights.
    return deepcopy(model).cpu().half()


def tune_bn_and_save(input_path, output_path, n_steps=25):
    """Re-tune batch-norm statistics for a checkpoint's models and save the result.

    A fresh model is built from ``model_config_file`` (wrapped in a
    ``SearchSpaceModel`` when it contains ``SearchVariantsContainer`` layers).
    The weights of ``ckpt['model']`` and, when present and not ``None``,
    ``ckpt['ema']`` are loaded into it in turn; for each, the batch-norm
    statistics are re-tuned on ``train_loader`` and the checkpoint entry is
    replaced by an FP16 CPU copy of the tuned model.

    Relies on the notebook globals ``model_config_file``, ``device``,
    ``architecture_indices``, ``train_loader`` and ``preprocess_data``.

    Args:
        input_path: Path of the checkpoint to read.
        output_path: Path the updated checkpoint is written to.
        n_steps: Number of ``tune_bn_stats`` steps per model (default 25).
    """
    model = Model(model_config_file).to(device)
    is_search_space = any(isinstance(layer, SearchVariantsContainer) for layer in model.modules())
    if is_search_space:
        model = SearchSpaceModel(model)

    # NOTE(security): torch.load unpickles arbitrary Python objects — only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(input_path, map_location=device)

    ckpt['model'] = _load_tune_and_snapshot(model, ckpt['model'], input_path, is_search_space, n_steps)

    # The EMA entry is optional: it may be missing or None.
    if ckpt.get('ema') is not None:
        ckpt['ema'] = _load_tune_and_snapshot(model, ckpt['ema'], input_path, is_search_space, n_steps)

    torch.save(ckpt, output_path)

# Tune the batch-norm statistics of the trained checkpoint and write the result to a new file.
tune_bn_and_save(input_path='weights/best.pt', output_path='weights/new.pt')