Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
ShownX committed Jan 25, 2018
1 parent 1f927b5 commit 387b8e5
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions E2FAR.py
Expand Up @@ -97,7 +97,7 @@ def initialize_inference(inference, pretrained, start_epoch):
print('Done')
elif start_epoch > 0:
print('Loading the weights from [%d] epoch' % start_epoch)
inference.load_params(os.path.join(args.ckpt_dir, args.prefix, '%s-%d.params' % (args.prefix, start_epoch)))
inference.load_params(os.path.join(args.ckpt_dir, args.prefix, '%s-%d.params' % (args.prefix, start_epoch)), ctx)
else:
inference.collect_params().initialize(ctx=ctx)
return inference
Expand Down Expand Up @@ -197,7 +197,18 @@ def train():


def test():
print('Test')

testset = SupervisedDataset(args.test_list)

test_loader = gluon.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True)

inference = E2FAR(freeze=args.freeze)
# initialize
initialize_inference(inference, args.pretrained, args.start_epoch)

for i, batch in enumerate(test_loader):
data = batch[0].as_in_context(ctx)
preds_shape, preds_exp = inference(data)


if __name__ == '__main__':
Expand Down Expand Up @@ -227,7 +238,7 @@ def test():
parser.add_argument('--val_list', default='', type=str, help='validation record')
# test
parser.add_argument('--testing', dest='training', action='store_false', help='testing flag')
parser.add_argument('--test_dir', default='', type=str, help='test record')
parser.add_argument('--test_list', default='', type=str, help='test record')

parser.set_defaults(training=True)

Expand Down

0 comments on commit 387b8e5

Please sign in to comment.