From f09c422bed163da4eb3d5056d234a8cae477f81e Mon Sep 17 00:00:00 2001
From: Unbinilium <15633984+Unbinilium@users.noreply.github.com>
Date: Fri, 27 May 2022 19:23:13 +0800
Subject: [PATCH] Add NMS to CoreML model output

https://github.com/ultralytics/yolov5/pull/7263
---
Review changes relative to the original submission:
- take the image size from `im` instead of `opt`, which is not passed to export_coreml
- unpack the image size as (height, width) so non-square inputs are normalized correctly
- create the box scale tensor on the same device/dtype as the model prediction
- size the pipeline "image" input from `im` instead of a hard-coded 3x640x640
The commit id above and the post-image blob id on the index line are those of the original submission.

 export.py | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 127 insertions(+), 12 deletions(-)

diff --git a/export.py b/export.py
index c9ad158c5f41..0ba9464a2d1f 100644
--- a/export.py
+++ b/export.py
@@ -189,8 +189,31 @@ def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
         LOGGER.info(f'\n{prefix} export failure: {e}')
 
 
-def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
+class CoreMLExportModel(torch.nn.Module):
+    # Wraps a detection model so tracing yields (class confidences, normalized boxes) for the Core ML NMS stage
+
+    def __init__(self, base_model, img_size):
+        # base_model: wrapped detection model; img_size: model input size as (height, width)
+        super().__init__()
+        self.base_model = base_model
+        self.img_size = img_size
+
+    def forward(self, x):
+        # Returns (class_probs, boxes): per-box class confidences and boxes normalized to [0 ... 1]
+        h, w = self.img_size
+        x = self.base_model(x)[0].squeeze(0)  # first model output, squeezed on dim 0
+        objectness = x[:, 4:5]
+        class_probs = x[:, 5:] * objectness
+        # Build the scale on the prediction's device/dtype so tracing also works when the model is not on CPU
+        scale = torch.tensor([1. / w, 1. / h, 1. / w, 1. / h], dtype=x.dtype, device=x.device)
+        return class_probs, x[:, :4] * scale
+
+
+def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, iou_thres, prefix=colorstr('CoreML:')):
     # YOLOv5 CoreML export
+    # Saves a Core ML pipeline (detector followed by non-maximum suppression) as `file` with suffix .mlmodel
+    # num_boxes, num_classes: declared shape of the raw detector outputs; labels: class label strings for NMS
+    # conf_thres, iou_thres: default NMS thresholds, overridable through the pipeline's optional inputs
     try:
         check_requirements(('coremltools',))
         import coremltools as ct
@@ -198,16 +221,107 @@ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
         LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
         f = file.with_suffix('.mlmodel')
 
-        ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
-        ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
-        bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
-        if bits < 32:
-            if platform.system() == 'Darwin':  # quantization only supported on macOS
-                with warnings.catch_warnings():
-                    warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
-                    ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
-            else:
-                print(f'{prefix} quantization only supported on macOS, skipping...')
+        export_model = CoreMLExportModel(model, img_size=tuple(im.shape[2:]))  # input (height, width)
+
+        ts = torch.jit.trace(export_model, im, strict=False)  # TorchScript model
+        orig_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
+
+        spec = orig_model.get_spec()
+        old_box_output_name = spec.description.output[1].name
+        old_scores_output_name = spec.description.output[0].name
+        ct.utils.rename_feature(spec, old_scores_output_name, "raw_confidence")
+        ct.utils.rename_feature(spec, old_box_output_name, "raw_coordinates")
+        spec.description.output[0].type.multiArrayType.shape.extend([num_boxes, num_classes])
+        spec.description.output[1].type.multiArrayType.shape.extend([num_boxes, 4])
+        spec.description.output[0].type.multiArrayType.dataType = ct.proto.FeatureTypes_pb2.ArrayFeatureType.DOUBLE
+        spec.description.output[1].type.multiArrayType.dataType = ct.proto.FeatureTypes_pb2.ArrayFeatureType.DOUBLE
+
+        yolo_model = ct.models.MLModel(spec)
+
+        # Build Non Maximum Suppression model
+        nms_spec = ct.proto.Model_pb2.Model()
+        nms_spec.specificationVersion = 3
+
+        for i in range(2):
+            decoder_output = spec.description.output[i].SerializeToString()
+
+            nms_spec.description.input.add()
+            nms_spec.description.input[i].ParseFromString(decoder_output)
+
+            nms_spec.description.output.add()
+            nms_spec.description.output[i].ParseFromString(decoder_output)
+
+        nms_spec.description.output[0].name = "confidence"
+        nms_spec.description.output[1].name = "coordinates"
+
+        output_sizes = [num_classes, 4]
+        for i in range(2):
+            ma_type = nms_spec.description.output[i].type.multiArrayType
+            ma_type.shapeRange.sizeRanges.add()
+            ma_type.shapeRange.sizeRanges[0].lowerBound = 0
+            ma_type.shapeRange.sizeRanges[0].upperBound = -1
+            ma_type.shapeRange.sizeRanges.add()
+            ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
+            ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
+            del ma_type.shape[:]
+
+        nms = nms_spec.nonMaximumSuppression
+        nms.confidenceInputFeatureName = "raw_confidence"
+        nms.coordinatesInputFeatureName = "raw_coordinates"
+        nms.confidenceOutputFeatureName = "confidence"
+        nms.coordinatesOutputFeatureName = "coordinates"
+        nms.iouThresholdInputFeatureName = "iouThreshold"
+        nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
+
+        nms.iouThreshold = iou_thres
+        nms.confidenceThreshold = conf_thres
+        nms.pickTop.perClass = True
+        nms.stringClassLabels.vector.extend(labels)
+
+        nms_model = ct.models.MLModel(nms_spec)
+
+        # Assembling a pipeline model from the two models
+        input_features = [("image", ct.models.datatypes.Array(*im.shape[1:])),
+                          ("iouThreshold", ct.models.datatypes.Double()),
+                          ("confidenceThreshold", ct.models.datatypes.Double())]
+
+        output_features = ["confidence", "coordinates"]
+
+        pipeline = ct.models.pipeline.Pipeline(input_features, output_features)
+
+        pipeline.add_model(yolo_model)
+        pipeline.add_model(nms_model)
+
+        # The "image" input should really be an image, not a multi-array
+        pipeline.spec.description.input[0].ParseFromString(spec.description.input[0].SerializeToString())
+
+        # Copy the declarations of the "confidence" and "coordinates" outputs
+        # The Pipeline makes these strings by default
+        pipeline.spec.description.output[0].ParseFromString(nms_spec.description.output[0].SerializeToString())
+        pipeline.spec.description.output[1].ParseFromString(nms_spec.description.output[1].SerializeToString())
+
+        # Add descriptions to the inputs and outputs
+        pipeline.spec.description.input[1].shortDescription = "(optional) IOU Threshold override"
+        pipeline.spec.description.input[2].shortDescription = "(optional) Confidence Threshold override"
+        pipeline.spec.description.output[0].shortDescription = "Boxes Class confidence"
+        pipeline.spec.description.output[1].shortDescription = "Boxes [x, y, width, height] (normalized to [0...1])"
+
+        # Add metadata to the model
+        pipeline.spec.description.metadata.shortDescription = "YOLOv5 object detector"
+        pipeline.spec.description.metadata.author = "Ultralytics"
+
+        # Add the default threshold values and list of class labels
+        user_defined_metadata = {
+            "iou_threshold": str(iou_thres),
+            "confidence_threshold": str(conf_thres),
+            "classes": ", ".join(labels)}
+        pipeline.spec.description.metadata.userDefined.update(user_defined_metadata)
+
+        # Don't forget this or Core ML might attempt to run the model on an unsupported operating system version!
+        pipeline.spec.specificationVersion = 3
+
+        ct_model = ct.models.MLModel(pipeline.spec)
+
         ct_model.save(f)
 
         LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
@@ -525,7 +639,8 @@ def run(
     if xml:  # OpenVINO
         f[3] = export_openvino(model, file, half)
     if coreml:
-        _, f[4] = export_coreml(model, im, file, int8, half)
+        nb = shape[1]
+        _, f[4] = export_coreml(model, im, file, nb, nc, names, conf_thres, iou_thres)
 
     # TensorFlow Exports
     if any((saved_model, pb, tflite, edgetpu, tfjs)):