-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
57 lines (45 loc) · 1.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import warnings
warnings.simplefilter("ignore", UserWarning)
import logging
import torch
import numpy as np
from modules.tokenizers import Tokenizer
from modules.dataloaders import LADataLoader
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer import Trainer
from modules.loss import compute_loss
import models
from config import opts
def main():
    """Entry point: parse options, fix RNG seeds, build the data pipeline,
    model, optimizer and trainer, then run training.

    All configuration comes from command-line options via ``opts.parse_opt``;
    this function takes no arguments and returns ``None``.
    """
    # Parse command-line arguments into a namespace.
    args = opts.parse_opt()
    # NOTE(review): logging is never configured in this file (no basicConfig);
    # with the default WARNING root level these INFO messages are dropped
    # unless the config module sets up logging — verify.
    logging.info(str(args))

    # Fix random seeds for reproducibility; the cudnn flags force
    # deterministic kernel selection at some speed cost.
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # Tokenizer is shared by the dataloaders and the model.
    tokenizer = Tokenizer(args)

    # One dataloader per split; only the training split is shuffled.
    train_dataloader = LADataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = LADataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = LADataLoader(args, tokenizer, split='test', shuffle=False)

    # Resolve the model class by version string, e.g. models.LAMRGModel_v2.
    model_name = f"LAMRGModel_v{args.version}"
    logging.info(f"Model name: {model_name} \tModel Layers:{args.num_layers}")
    model = getattr(models, model_name)(args, tokenizer)

    # Function handles for the loss and the evaluation metrics.
    criterion = compute_loss
    metrics = compute_scores

    # Optimizer and LR scheduler are built from the parsed options.
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # Trainer owns the full train/val/test loop.
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler,
                      train_dataloader, val_dataloader, test_dataloader)
    trainer.train()
    logging.info(str(args))
# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()