-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodel_training.py
120 lines (103 loc) · 3.65 KB
/
model_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# Copyright (c) 2020-2023 Antmicro <www.antmicro.com>
#
# SPDX-License-Identifier: Apache-2.0
"""
Script for training a model, wrapped in a ModelWrapper object, on a dataset
provided by a Dataset object.
"""
import argparse
import sys
from pathlib import Path
from typing import List, Optional, Tuple
from kenning.cli.command_template import (
GROUP_SCHEMA,
TEST,
TRAIN,
ArgumentsGroups,
CommandTemplate,
ParserHelpException,
)
from kenning.cli.completers import DATASETS, MODEL_WRAPPERS, ClassPathCompleter
from kenning.utils.class_loader import get_command, load_class
class TrainModel(CommandTemplate):
    """
    Command template for training models with ModelWrapper.
    """

    # NOTE(review): presumably tells CommandTemplate not to require that the
    # whole command line is consumed; the leftover arguments arrive in
    # ``run`` as ``not_parsed`` - confirm in CommandTemplate.
    parse_all = False
    # Inside a class body ``__doc__`` resolves to this class's docstring, so
    # this value is derived from the docstring above (``[:-1]`` trims its
    # final character). Editing the docstring changes this string.
    description = __doc__[:-1]

    @staticmethod
    def configure_parser(
        parser: Optional[argparse.ArgumentParser] = None,
        command: Optional[str] = None,
        types: Optional[List[str]] = None,
        groups: Optional[ArgumentsGroups] = None,
    ) -> Tuple[argparse.ArgumentParser, ArgumentsGroups]:
        """
        Adds the training-specific arguments to the parser.

        Parameters
        ----------
        parser : Optional[argparse.ArgumentParser]
            Parser to extend, forwarded to the base implementation.
        command : Optional[str]
            Invoked command, forwarded to the base implementation.
        types : Optional[List[str]]
            Command types configured together with this one. The result of
            ``TEST in types`` is forwarded as the base implementation's
            fifth positional argument. ``None`` is treated as an empty list.
        groups : Optional[ArgumentsGroups]
            Already created argument groups, forwarded to the base
            implementation.

        Returns
        -------
        Tuple[argparse.ArgumentParser, ArgumentsGroups]
            The configured parser and argument groups.
        """
        # Fresh list per call instead of a shared mutable default argument.
        if types is None:
            types = []
        # Explicit two-argument ``super``: a staticmethod has no ``self`` or
        # ``cls`` for the zero-argument form to bind to.
        parser, groups = super(TrainModel, TrainModel).configure_parser(
            parser, command, types, groups, TEST in types
        )

        train_group = parser.add_argument_group(GROUP_SCHEMA.format(TRAIN))
        train_group.add_argument(
            "--modelwrapper-cls",
            help="ModelWrapper-based class with inference implementation to import",  # noqa: E501
            required=True,
        ).completer = ClassPathCompleter(MODEL_WRAPPERS)
        train_group.add_argument(
            "--dataset-cls",
            help="Dataset-based class with dataset to import",
            required=True,
        ).completer = ClassPathCompleter(DATASETS)
        train_group.add_argument(
            "--batch-size",
            help="The batch size for training",
            type=int,
            required=True,
        )
        train_group.add_argument(
            "--learning-rate",
            help="The learning rate for training",
            type=float,
            required=True,
        )
        train_group.add_argument(
            "--num-epochs",
            help="Number of epochs to train for",
            type=int,
            required=True,
        )
        train_group.add_argument(
            "--logdir",
            help="Path to the training logs directory",
            type=Path,
            required=True,
        )
        return parser, groups

    @staticmethod
    def run(
        args: argparse.Namespace,
        not_parsed: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Loads the requested dataset and model wrapper, trains the model and
        saves it.

        Parameters
        ----------
        args : argparse.Namespace
            Arguments parsed with the parser from ``configure_parser``.
        not_parsed : Optional[List[str]]
            Remaining command-line arguments. They are parsed here against
            the arguments exposed by the selected ModelWrapper and Dataset
            classes. ``None`` is treated as an empty list.
        **kwargs
            Unused here.

        Raises
        ------
        ParserHelpException
            When ``args.help`` is set. It carries the parser built from the
            selected classes' arguments.
        """
        # Fresh list per call instead of a shared mutable default argument.
        # ``None`` must not reach ``parse_args``: it would then read
        # ``sys.argv`` instead of parsing an empty argument list.
        if not_parsed is None:
            not_parsed = []

        # NOTE(review): the class paths are guarded because they may be
        # missing when only help is requested - confirm in the CLI driver.
        modelwrappercls = (
            load_class(args.modelwrapper_cls)
            if args.modelwrapper_cls
            else None
        )
        datasetcls = load_class(args.dataset_cls) if args.dataset_cls else None

        # The program name is computed before the parent parsers are built,
        # which keeps the original evaluation order.
        prog = " ".join(part.strip() for part in get_command(with_slash=False))
        parents = []
        if modelwrappercls:
            parents.append(modelwrappercls.form_argparse()[0])
        if datasetcls:
            parents.append(datasetcls.form_argparse()[0])
        parser = argparse.ArgumentParser(
            prog,
            parents=parents,
            add_help=False,
        )

        if args.help:
            raise ParserHelpException(parser)
        # Merge the class-specific arguments into the existing namespace.
        args = parser.parse_args(not_parsed, namespace=args)

        dataset = datasetcls.from_argparse(args)
        # NOTE(review): ``from_file=False`` presumably builds a fresh model
        # instead of loading saved weights - confirm in ModelWrapper.
        model = modelwrappercls.from_argparse(dataset, args, from_file=False)

        # Create the log directory (with parents) before passing it to
        # training.
        args.logdir.mkdir(parents=True, exist_ok=True)
        model.prepare_model()
        model.train_model(
            args.batch_size, args.learning_rate, args.num_epochs, args.logdir
        )
        model.save_model(model.get_path())
if __name__ == "__main__":
    # Run the training command on the raw command line and use its result
    # as the process exit status.
    status = TrainModel.scenario_run(sys.argv)
    sys.exit(status)