-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
47 lines (37 loc) · 1.29 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import numpy as np
from mindspore import Model
from mindspore import context
from mindspore.common import set_seed
from src.args import args
from src.tools.cell import cast_amp
from src.tools.get_misc import get_dataset, set_device, get_model, pretrained
# Seed MindSpore with the configured seed. This runs at module import time,
# before main() is invoked.
set_seed(args.seed)
def _configure_context():
    """Configure the MindSpore execution context from the module-level ``args``.

    ``args.graph_mode`` selects GRAPH (0) or PYNATIVE (1) mode and
    ``args.device_target`` selects the device. Graph-kernel fusion is
    disabled, automatic mixed precision is enabled on Ascend, and finally
    ``set_device(args)`` is called.

    Raises:
        ValueError: if ``args.graph_mode`` is neither 0 nor 1.
    """
    modes = {
        0: context.GRAPH_MODE,
        1: context.PYNATIVE_MODE,
    }
    try:
        run_mode = modes[args.graph_mode]
    except KeyError:
        raise ValueError(
            f"graph_mode must be 0 (GRAPH_MODE) or 1 (PYNATIVE_MODE), got {args.graph_mode!r}"
        ) from None
    context.set_context(mode=run_mode, device_target=args.device_target)
    context.set_context(enable_graph_kernel=False)
    if args.device_target == "Ascend":
        context.set_context(enable_auto_mixed_precision=True)
    set_device(args)


def _build_network():
    """Create the network via ``get_model``, apply ``cast_amp``, and load
    pretrained weights when ``args.pretrain_url`` is set. Returns the network."""
    net = get_model(args)
    cast_amp(net)
    if args.pretrain_url:
        pretrained(args, net)
    return net


def _predict_labels(model, dataset):
    """Return the arg-max class index of every prediction, in dataset order.

    Args:
        model: object whose ``predict(images)`` result has ``.asnumpy()``.
        dataset: object providing ``create_dict_iterator()`` that yields
            dicts with an ``"image"`` entry.

    Returns:
        list[int]: one predicted label per sample.
    """
    labels = []
    for batch in dataset.create_dict_iterator():
        logits = model.predict(batch["image"]).asnumpy()
        # assumes logits are (batch, num_classes) -- TODO confirm against get_model
        labels.extend(np.argmax(logits, axis=1).tolist())
    return labels


def _write_results(path, labels):
    """Write one label per line to *path*, creating its parent directory if needed."""
    directory = os.path.dirname(path)
    if directory:
        # The original script assumed the output directory already existed.
        os.makedirs(directory, exist_ok=True)
    # NOTE(review): UTF-16LE (no BOM) is unusual for a plain label file; kept
    # unchanged because the downstream consumer may depend on it -- confirm.
    with open(path, 'w', encoding='utf-16le') as f:
        f.write(''.join(f"{label}\n" for label in labels))


def main():
    """Run inference on the test split and save the predicted labels.

    All settings come from the module-level ``args``. The predictions
    (arg-max class index, one per line) are written to
    ``<args.output_path>/result.txt``.
    """
    _configure_context()
    net = _build_network()
    data = get_dataset(args, training=False)
    model = Model(net)
    print("begin predict")
    result_path = os.path.join(args.output_path, 'result.txt')
    # Collect labels in a list and join once, instead of growing a string per batch.
    labels = _predict_labels(model, data.test_dataset)
    _write_results(result_path, labels)
    print("File saved")
# Run inference only when executed as a script, not when imported.
if __name__ == '__main__':
    main()