-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
67 lines (54 loc) · 2.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Module preamble: project imports plus logging configuration.
# NOTE: `logging` below is transformers' logging wrapper, NOT the stdlib module;
# the stdlib module is imported later under the alias `local_logging`.
import models
import torch
from transformers import logging
from config.parse_args import *  # star-import: expected to provide parse_args — TODO confirm
from data.data_reader import *  # star-import: expected to provide input_builder and identifier — TODO confirm
import trainer.Trainer as Trainer
# Configure transformers-side logging: INFO verbosity, explicit formatter.
logging.set_verbosity_info()
logging.enable_explicit_format()
import logging as local_logging  # stdlib logging, aliased to avoid clashing with transformers' wrapper
logger = logging.get_logger(__name__)  # transformers-style module logger
logger.setLevel('INFO')
# Configure stdlib logging format; transformers re-exports stdlib level
# constants, so `logging.INFO` resolves even though `logging` is the wrapper.
local_logging.basicConfig(format="[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s",level=logging.INFO)
from data.tokenizer_utils import prepare_tokenizer
# NOTE(review): these executor names are not referenced anywhere in this file —
# likely leftover; verify before removing.
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED, as_completed, ALL_COMPLETED
# import time
def main():
    """Wire together tokenizer, datasets, model, and trainer from the parsed
    configuration objects, then run the requested phases (train / predict).
    """
    _base_args, train_args, model_args, task_args = parse_args()

    auto_tokenizer = prepare_tokenizer(
        model_args._name_or_path,
        train_args.cache_dir,
        special_tokens=train_args.special_tokens,
    )
    train_input, eval_input, predict_input = input_builder(
        model_args._name_or_path, train_args, task_args, auto_tokenizer
    )

    # Resolve the model class: an explicit per-task override wins; otherwise
    # look it up in the task registry keyed by the model identifier.
    if hasattr(task_args, 'auto_model'):
        auto_model = task_args.auto_model
    else:
        auto_model = getattr(models, train_args.task)[identifier(model_args)]

    # Model classes exposing `set_cfg` accept the extra config objects at load.
    extra_kwargs = {}
    if hasattr(auto_model, 'set_cfg'):
        extra_kwargs = {"customize_cfg": task_args, "train_cfg": train_args}

    model = auto_model.from_pretrained(
        model_args._name_or_path,
        from_tf=train_args.from_tf,
        config=model_args,
        cache_dir=train_args.cache_dir,
        **extra_kwargs,
    )

    trainer_cls = getattr(Trainer, train_args.trainer)
    trainer = trainer_cls(
        model=model,
        args=train_args,
        model_args=model_args,
        train_dataset=train_input,
        # In predict mode the "eval" slot carries the prediction dataset.
        eval_dataset=predict_input if train_args.do_predict else eval_input,
        task_args=task_args,
        auto_tokenizer=auto_tokenizer,
    )

    if train_args.do_train:
        trainer.train()
    if train_args.do_predict:
        trainer.predict()
if __name__ == "__main__":
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
logger.info("checking GPU")
if not torch.cuda.is_available():
logger.warning("torch.cuda.is_available() Fail")
else:
logger.info("torch.cuda.is_available() Succeed")
main()