Skip to content

Commit

Permalink
Optimize decode_outputs for OpenVINO (#1535)
Browse files Browse the repository at this point in the history
Avoid ScatterND ops that will cause error in openvino model optimizer.
  • Loading branch information
Sped0n committed Nov 29, 2022
1 parent d942239 commit e80063b
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions yolox/models/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ def decode_outputs(self, outputs, dtype):
grids = torch.cat(grids, dim=1).type(dtype)
strides = torch.cat(strides, dim=1).type(dtype)

outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
outputs = torch.cat([(outputs[..., 0:2] + grids) * strides, torch.exp(outputs[..., 2:4]) * strides, outputs[..., 4:]], dim=-1)
return outputs

def get_losses(
Expand Down

0 comments on commit e80063b

Please sign in to comment.