In [None]:
import json
import os
import os.path as osp

import torch
from torch import Tensor

from models.ofa.networks import CompositeSubNet
from ofa.data_providers import DataProvidersRegistry
from ofa.utils.common_tools import build_config_from_file
from utils import convert_with_all_tensors
from ofa.training.strategies import get_strategy_class

# Reproducibility switches for cuDNN: deterministic kernels only, no autotuner
# benchmarking, and cuDNN disabled entirely.
_cudnn = torch.backends.cudnn
_cudnn.deterministic, _cudnn.benchmark, _cudnn.enabled = True, False, False


In [None]:
# Experiment directory produced by training (must contain config.yaml,
# result_model_config.json and result_model.pt). Set the OFA_CONVERT_ROOT
# environment variable to point at another run instead of editing this cell.
ROOT = os.environ.get(
    "OFA_CONVERT_ROOT",
    "/workspace/proj/output/detection/mbnet/15.08.23_14.43.35.153107",
)
config = build_config_from_file(osp.join(ROOT, "config.yaml"))

# Batches of `num_samples` images; the cells below consume a single batch.
num_samples = 1
config.common.dataset.test_batch_size = num_samples
config.common.dataset.train_batch_size = num_samples

ProviderCLS = DataProvidersRegistry.get_provider_by_name(config.common.dataset.type)
provider = ProviderCLS(config.common.dataset)

CLS = get_strategy_class(config.common.strategy)
strategy = CLS(config.common)
print("strategy inited!")

# NOTE(review): the test batch size is set on the config above and again on the
# provider here — presumably because the provider keeps its own copy after
# construction; confirm before removing either assignment.
provider.n_worker = 1
provider.test_batch_size = num_samples
loader = provider.test_loader_builder()
print("loader inited!")


In [None]:
# Roll through the test loader until we get a good target (checked in the next cell).
itr = iter(loader)

In [None]:
# Use this output to check that the image contains objects, ideally of two
# classes; once it does, move on to the next cells.
data = next(itr)

data['target']

In [None]:
# Conversion switches.
# debug: replace the dataset image with a synthetic tensor (channel k filled
# with k + 1) and set the tag `i` to 123.
debug = False  # 123 tensor use
# These two flags were needed for RoboDeus; in theory they can now be dropped.
# fp16 implies cuda (see the assert below).
fp16 = False
cuda = False or fp16

# Output prefix for the converted artifacts; device/precision suffixes are appended below.
pat = osp.join(ROOT, "convert")

model_config_path = osp.join(ROOT, "result_model_config.json")
# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
with open(model_config_path, encoding="utf-8") as fin:
    model_config = json.load(fin)
model = CompositeSubNet.build_from_config(model_config)
state_path = osp.join(ROOT, "result_model.pt")
# NOTE(review): torch.load unpickles the checkpoint — only load files you trust.
state = torch.load(state_path, map_location="cpu")
model.load_state_dict(state)
model.eval()

print("model inited")

image = data["image"]

# NOTE(review): `i` is not read anywhere else in the visible notebook; it looks
# like a tag for the debug input — confirm before removing.
i = 0
if debug:
    i = 123
    # Constant image whose values along dim 1 are 1, 2, 3
    # (assumes dim 1 is the channel axis, i.e. NCHW — TODO confirm).
    image = torch.ones_like(image)
    image[:, 1] *= 2
    image[:, 2] *= 3

if cuda:
    strategy.device = torch.device("cuda")
    model.cuda()
    image = image.cuda()
    pat += "_cuda"
else:
    strategy.device = torch.device("cpu")

if fp16:
    assert cuda  # half precision is only used together with cuda
    image = image.half()
    pat += "_fp16"
    model.half()


In [None]:
from contextlib import contextmanager
from models.ofa.heads.detection.yolo_v4 import PostprocessMode, YoloV4DetectionHead
from models.ofa.heads.detection.yolo_v4.postprocess import yolo_postprocessing_last

@contextmanager
def specific(model: "CompositeSubNet"):
    """Temporarily switch a YOLOv4 detection head to platform post-processing.

    While the ``with`` block runs, ``model.head.postprocess`` is set to
    ``PostprocessMode.PLATFORM``. On exit, including exit by exception, the
    mode the head had before entering is restored. Heads that are not
    ``YoloV4DetectionHead`` instances are left untouched.

    Args:
        model: Network whose ``head`` attribute may be a ``YoloV4DetectionHead``.

    Yields:
        None.
    """
    head = model.head
    if not isinstance(head, YoloV4DetectionHead):
        yield
        return

    previous_mode = head.postprocess
    head.postprocess = PostprocessMode.PLATFORM
    try:
        yield
    finally:
        # Restore the caller's mode rather than a hard-coded NMS, so entering
        # this context never clobbers a mode selected by earlier code.
        head.postprocess = previous_mode


In [None]:
# Export the model with its head switched to platform post-processing.
with specific(model):
    convert_with_all_tensors(
        model,
        image,
        pat,
        cuda=cuda,
        fp16=fp16,
        check=False,
        preservation_of_intermediate_tensor=False,
    )

 # Detection stuff

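 Manually check that the platform post-processing (`yolo_postprocessing_last`) matches the reference, i.e. the model's own output in `PostprocessMode.DECODE` mode (before NMS).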
 TODO: REFACTOR/DELETE

In [None]:
# Prepare the arguments for the platform post-processing.
head_config = config.common.supernet_config.head
conf = head_config.conf_thresholds
n_classes = head_config.n_classes
classes = provider.class_names
anchors = head_config.anchors  # original anchors from the config
anchors_num = [len(level_anchors) for level_anchors in anchors]
imagesize = head_config.image_size
nms_iou_threshold = head_config.nms_iou_threshold
nms_top_k = head_config.nms_top_k

plat_inputs = [
    "",
]
plat_output = "yolo_postprocessing"

# The same confidence threshold is replicated for every class.
conf_thresholds = [conf] * n_classes
classmap = dict(enumerate(classes))

# Collect model outputs in the different post-processing modes.
head: YoloV4DetectionHead = model.head
plat_anchors = head.yolo_layer.prepare_anchors()
# Remember the current mode so this cell does not leave the head in NONE mode.
previous_mode = head.postprocess
try:
    with torch.no_grad():
        # Original forward (with NMS).
        head.postprocess = PostprocessMode.NMS
        out_with_nms: Tensor = model(image)[1]

        # Output before NMS: the reference for the platform post-processing.
        head.postprocess = PostprocessMode.DECODE
        out_without_nms = model(image)[1]

        # Raw per-level outputs; apply each level's sigmoid before the
        # platform post-processing. A dedicated loop name avoids clobbering
        # the global `i` set in the model-loading cell.
        head.postprocess = PostprocessMode.NONE
        out_to_plat = model(image)
        sigmoid_outputs = [
            head.sigmoids[level](out_to_plat[level]) for level in range(head.levels)
        ]

        last_out = yolo_postprocessing_last(
            sigmoid_outputs, plat_anchors,
            conf_thresholds,
            imagesize[1], imagesize[0],
        )
finally:
    head.postprocess = previous_mode

torch.allclose(last_out, out_without_nms[0])