diff --git a/E2FAR.py b/E2FAR.py index 5513ca2..cc379f0 100644 --- a/E2FAR.py +++ b/E2FAR.py @@ -97,7 +97,7 @@ def initialize_inference(inference, pretrained, start_epoch): print('Done') elif start_epoch > 0: print('Loading the weights from [%d] epoch' % start_epoch) - inference.load_params(os.path.join(args.ckpt_dir, args.prefix, '%s-%d.params' % (args.prefix, start_epoch))) + inference.load_params(os.path.join(args.ckpt_dir, args.prefix, '%s-%d.params' % (args.prefix, start_epoch)), ctx) else: inference.collect_params().initialize(ctx=ctx) return inference @@ -197,7 +197,18 @@ def train(): def test(): - print('Test') + + testset = SupervisedDataset(args.test_list) + + test_loader = gluon.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True) + + inference = E2FAR(freeze=args.freeze) + # initialize + initialize_inference(inference, args.pretrained, args.start_epoch) + + for i, batch in enumerate(test_loader): + data = batch[0].as_in_context(ctx) + preds_shape, preds_exp = inference(data) if __name__ == '__main__': @@ -227,7 +238,7 @@ def test(): parser.add_argument('--val_list', default='', type=str, help='validation record') # test parser.add_argument('--testing', dest='training', action='store_false', help='testing flag') - parser.add_argument('--test_dir', default='', type=str, help='test record') + parser.add_argument('--test_list', default='', type=str, help='test record') parser.set_defaults(training=True)