Skip to content

Commit

Permalink
fix(demo): --fp16 not working in demo.py (#648)
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Sep 9, 2021
1 parent 16ac853 commit 6819425
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tools/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
trt_file=None,
decoder=None,
device="cpu",
fp16=False,
legacy=False,
):
self.model = model
Expand All @@ -116,6 +117,7 @@ def __init__(
self.nmsthre = exp.nmsthre
self.test_size = exp.test_size
self.device = device
self.fp16 = fp16
self.preproc = ValTransform(legacy=legacy)
if trt_file is not None:
from torch2trt import TRTModule
Expand Down Expand Up @@ -145,8 +147,11 @@ def inference(self, img):

img, _ = self.preproc(img, None, self.test_size)
img = torch.from_numpy(img).unsqueeze(0)
img = img.float()
if self.device == "gpu":
img = img.cuda()
if self.fp16:
img = img.half() # to FP16

with torch.no_grad():
t0 = time.time()
Expand Down Expand Up @@ -261,6 +266,8 @@ def main(exp, args):

if args.device == "gpu":
model.cuda()
if args.fp16:
model.half() # to FP16
model.eval()

if not args.trt:
Expand Down Expand Up @@ -291,7 +298,7 @@ def main(exp, args):
trt_file = None
decoder = None

predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.legacy)
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.fp16, args.legacy)
current_time = time.localtime()
if args.demo == "image":
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
Expand Down

0 comments on commit 6819425

Please sign in to comment.