Skip to content

Commit

Permalink
修复eval和predict加载模型的错误
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jul 10, 2020
1 parent ec68c9e commit 678b2ae
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion predict.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
CUDA_VISIBLE_DEVICES=0 python tools/predict.py --model_path model_best.pth --input_folder ./input --output_folder ./output --thre 0.7 --polygon --show --save_resut
CUDA_VISIBLE_DEVICES=0 python tools/predict.py --model_path model_best.pth --input_folder ./input --output_folder ./output --thre 0.7 --polygon --show --save_result
2 changes: 1 addition & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, model_path, gpu_id=0):

self.validate_loader = get_dataloader(config['dataset']['validate'], config['distributed'])

self.model = build_model(config['arch'].pop('type'), **config['arch'])
self.model = build_model(config['arch'])
self.model.load_state_dict(checkpoint['state_dict'])
self.model.to(self.device)

Expand Down
4 changes: 2 additions & 2 deletions tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):

config = checkpoint['config']
config['arch']['backbone']['pretrained'] = False
self.model = build_model(config['arch'].pop('type'), **config['arch'])
self.model = build_model(config['arch'])
self.post_process = get_post_processing(config['post_processing'])
self.post_process.box_thresh = post_p_thre
self.img_mode = config['dataset']['train']['dataset']['args']['img_mode']
Expand Down Expand Up @@ -119,7 +119,7 @@ def init_args():
parser.add_argument('--model_path', default=r'model_best.pth', type=str)
parser.add_argument('--input_folder', default='./test/input', type=str, help='img path for predict')
parser.add_argument('--output_folder', default='./test/output', type=str, help='img path for output')
parser.add_argument('--thre', default=0.3, help='the thresh of post_processing')
parser.add_argument('--thre', default=0.3,type=float, help='the thresh of post_processing')
parser.add_argument('--polygon', action='store_true', help='output polygon or box')
parser.add_argument('--show', action='store_true', help='show result')
parser.add_argument('--save_resut', action='store_true', help='save box and score to txt file')
Expand Down

1 comment on commit 678b2ae

@summer-1010
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你好,我在训练的过程中遇到了加载数据集的错误,在base_dataset中的load_data()是要我们自己完善代码吗?因为发现是没有作处理的,感谢回答!

Please sign in to comment.