Skip to content

Commit

Permalink
refactor(inferencer): inference and visualize image one by one (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
hujiahao1 committed Nov 8, 2021
1 parent b0ac29a commit 0138d0c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
6 changes: 3 additions & 3 deletions eod/apis/inference.py
Expand Up @@ -185,8 +185,8 @@ def predict(self):
batch = self.fetch([filename])
with torch.no_grad():
output = self.detector(batch)
output_list += self.map_back(output)
return output_list
output = self.map_back(output)
self.vis(output)

def vis(self, outputs):
for img_idx, output in enumerate(outputs):
Expand All @@ -198,7 +198,7 @@ def vis(self, outputs):
filename = os.path.basename(output['image_info'][-1])
if self.vis_type == 'plt':
filename = filename.rsplit('.', 1)[0]
logger.info('visualizing {}:{}'.format(img_idx, filename))
logger.info('visualizing {}'.format(filename))

img_h, img_w = img.shape[:2]
classes = boxes[:, -1].astype(np.int32)
Expand Down
3 changes: 1 addition & 2 deletions eod/commands/inference.py
Expand Up @@ -61,8 +61,7 @@ def main(args):
infer_cfg['kwargs'] = infer_cfg.get('kwargs', {})
cfg['runtime']['inferencer'] = infer_cfg
inferencer = INFERENCER_REGISTRY.get(infer_cfg['type'])(cfg, **infer_cfg['kwargs'])
output_list = inferencer.predict()
inferencer.vis(output_list)
inferencer.predict()


def _main(args):
Expand Down

0 comments on commit 0138d0c

Please sign in to comment.