-
Notifications
You must be signed in to change notification settings - Fork 5
/
train.py
executable file
·83 lines (66 loc) · 2.66 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# -*- coding: utf-8 -*-
# @Time : 2018/8/23 22:20
# @Author : zhoujun
import os
def main(config):
    """Build the model, data loaders and trainer from ``config``, then train.

    ``config`` is a nested mapping and is modified in place:

    * if ``config['dataset']['alphabet']`` is the path of an existing file,
      it is replaced by the joined result of ``utils.load`` on that file;
    * ``config['lr_scheduler']['args']['step']`` is multiplied by
      ``len(train_loader)`` (presumably converting epochs to iterations --
      confirm against the scheduler in use).

    MXNet and the project modules are imported inside the function rather
    than at module level (presumably because the script entry point extends
    ``sys.path`` before calling ``main`` -- confirm before moving them).

    Raises:
        NotImplementedError: if the configured prediction type is not 'CTC'.
        AssertionError: if ``get_dataloader`` returns ``None`` for the
            training set.
    """
    from mxnet import nd
    from mxnet.gluon.loss import CTCLoss
    from models import get_model
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import get_ctx, load

    dataset_cfg = config['dataset']

    # The alphabet entry may name a file; if so, swap in the file's contents.
    if os.path.isfile(dataset_cfg['alphabet']):
        dataset_cfg['alphabet'] = ''.join(load(dataset_cfg['alphabet']))
    alphabet = dataset_cfg['alphabet']

    arch_args = config['arch']['args']
    prediction_type = arch_args['prediction']['type']
    num_class = len(alphabet)

    # Loss setup: CTC is the only prediction type handled here.
    if prediction_type != 'CTC':
        raise NotImplementedError
    criterion = CTCLoss()

    ctx = get_ctx(config['trainer']['gpus'])
    model = get_model(num_class, ctx, arch_args)
    model.hybridize()
    model.initialize(ctx=ctx)

    train_cfg = dataset_cfg['train']
    train_set_args = train_cfg['dataset']['args']

    # Input size comes from the first "Resize" pre-process; 32x100 otherwise.
    img_h, img_w = 32, 100
    for step in train_set_args['pre_processes']:
        if step['type'] != "Resize":
            continue
        img_h, img_w = step['args']['img_h'], step['args']['img_w']
        break
    img_channel = 1 if train_set_args['img_mode'] == 'GRAY' else 3

    # A dummy batch of two zero images on the first context; the model uses
    # it to report the label length that the data loaders are built with.
    sample_input = nd.zeros((2, img_channel, img_h, img_w), ctx[0])
    num_label = model.get_batch_max_length(sample_input)

    train_loader = get_dataloader(train_cfg, num_label, alphabet)
    assert train_loader is not None

    # Validation is optional.
    validate_loader = (
        get_dataloader(dataset_cfg['validate'], num_label, alphabet)
        if 'validate' in dataset_cfg
        else None
    )

    config['lr_scheduler']['args']['step'] *= len(train_loader)

    Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        validate_loader=validate_loader,
        sample_input=sample_input,
        ctx=ctx,
    ).train()
def init_args():
    """Parse the command-line arguments of the training script.

    Reads ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace: its ``config_file`` attribute holds the value of
        ``--config_file`` (default ``'config/icdar2015.yaml'``).
    """
    import argparse

    # The help banner names this project; the script's project root is
    # 'crnn.gluon'.
    parser = argparse.ArgumentParser(description='crnn.gluon')
    parser.add_argument(
        '--config_file',
        default='config/icdar2015.yaml',
        type=str,
        help='path to the training configuration file',
    )
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: make the project importable, load the configuration
    # file given on the command line and start training.
    import sys

    import anyconfig

    # Append everything in the current working directory path up to and
    # including the first occurrence of the project name to sys.path, so the
    # project modules can be imported from any sub-directory.
    # NOTE(review): when the working directory does not contain the project
    # name, this appends cwd with the name glued on (no path separator).
    project = 'crnn.gluon'  # name of the project root directory
    sys.path.append(os.getcwd().split(project)[0] + project)

    from utils import parse_config

    args = init_args()
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError(
            'config file not found: {}'.format(args.config_file))
    # Close the file handle deterministically once the config is parsed.
    with open(args.config_file, 'rb') as config_stream:
        config = anyconfig.load(config_stream)
    if 'base' in config:
        config = parse_config(config)
    main(config)