-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
enc_dec_nmt.py
131 lines (107 loc) · 5.4 KB
/
enc_dec_nmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional

from pytorch_lightning import Trainer

from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import AAYNBaseConfig
from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel
from nemo.core.config import hydra_runner
from nemo.core.config.modelPT import NemoConfig
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging
from nemo.utils.config_utils import update_model_config
from nemo.utils.exp_manager import ExpManagerConfig, exp_manager
"""
Usage:
1. If needed, start docker and install NeMo (otherwise skip this step):
a. ```docker run --gpus all -it --rm -v /home/okuchaiev/repos/NeMo/:/NeMo -p 6006:6006 -v /mnt:/mnt --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:20.11-py3```
b. ```cd /NeMo```
c. ```./reinstall.sh```
2. Train a new tokenizer (or use pre-trained one):
```yttm bpe --data /mnt/D1/Data/NMT/wmt16_de_en/train.clean.en-de.shuffled.common --model tokenizer.BPE.8192.model --vocab_size 8192```
(To use WANDB, optionally, do login first)
``wandb login [YOUR WANDB login]``
3. Start training:
(This example for "base" model on 2 GPUs for 150000 steps with batch size of 12500 tokens per GPU)
python enc_dec_nmt.py \
--config-path=conf \
--config-name=aayn_base \
trainer.gpus=[0,1] \
~trainer.max_epochs \
+trainer.max_steps=150000 \
model.beam_size=4 \
model.max_generation_delta=5 \
model.label_smoothing=0.1 \
model.encoder_tokenizer.tokenizer_model=tokenizer.BPE.8192.model \
model.decoder_tokenizer.tokenizer_model=tokenizer.BPE.8192.model \
model.encoder.num_layers=6 \
model.encoder.hidden_size=512 \
model.encoder.inner_size=2048 \
model.encoder.num_attention_heads=8 \
model.encoder.ffn_dropout=0.1 \
model.decoder.num_layers=6 \
model.decoder.hidden_size=512 \
model.decoder.inner_size=2048 \
model.decoder.num_attention_heads=8 \
model.decoder.ffn_dropout=0.1 \
model.train_ds.src_file_name=/mnt/D1/Data/NMT/wmt16_de_en/train.clean.de.shuffled \
model.train_ds.tgt_file_name=/mnt/D1/Data/NMT/wmt16_de_en/train.clean.en.shuffled \
model.train_ds.tokens_in_batch=12500 \
model.validation_ds.src_file_name=/mnt/D1/Data/NMT/wmt16_de_en/wmt14-en-de.ref \
model.validation_ds.tgt_file_name=/mnt/D1/Data/NMT/wmt16_de_en/wmt14-en-de.src \
model.validation_ds.tokens_in_batch=8192 \
model.test_ds.src_file_name=/mnt/D1/Data/NMT/wmt16_de_en/wmt14-en-de.ref \
model.test_ds.tgt_file_name=/mnt/D1/Data/NMT/wmt16_de_en/wmt14-en-de.src \
model.optim.lr=0.001 \
model.optim.sched.warmup_ratio=0.05 \
+exp_manager.create_wandb_logger=True \
+exp_manager.wandb_logger_kwargs.name=TEST-nmt-base \
+exp_manager.wandb_logger_kwargs.project=nmt-de-en \
+exp_manager.create_checkpoint_callback=True \
+exp_manager.checkpoint_callback_params.monitor=val_sacreBLEU \
+exp_manager.exp_dir=nmt_base \
+exp_manager.checkpoint_callback_params.mode=max
"""
@dataclass
class MTEncDecConfig(NemoConfig):
    """Top-level structured config for NMT encoder-decoder training.

    Combines the model config, the PyTorch Lightning trainer config, and the
    experiment-manager config; Hydra merges user overrides on top of these
    defaults in ``main``.
    """

    name: Optional[str] = 'MTEncDec'
    do_training: bool = True  # when False, only data preprocessing is performed
    # Dataclass-instance defaults are mutable: a plain class-level default would
    # be shared between every MTEncDecConfig instance (and is rejected outright
    # by `dataclasses` on newer Python versions), so build a fresh instance per
    # config via default_factory.
    model: AAYNBaseConfig = field(default_factory=AAYNBaseConfig)
    trainer: Optional[TrainerConfig] = field(default_factory=TrainerConfig)
    exp_manager: Optional[ExpManagerConfig] = field(
        default_factory=lambda: ExpManagerConfig(name='MTEncDec', files_to_copy=[])
    )
@hydra_runner(config_path="conf", config_name="aayn_base")
def main(cfg: MTEncDecConfig) -> None:
    """Preprocess data and (optionally) train an NMT encoder-decoder model.

    Args:
        cfg: Hydra-assembled config; merged on top of the ``MTEncDecConfig``
            defaults before use. Training is skipped when ``cfg.do_training``
            is False.
    """
    # merge default config with user specified config
    default_cfg = MTEncDecConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {cfg.pretty()}')

    # training is managed by PyTorch Lightning
    trainer = Trainer(**cfg.trainer)

    # tokenizers will be trained and tarred training data will be created if needed;
    # the model config (cfg.model) is updated as a side effect
    MTDataPreproc(cfg=cfg.model, trainer=trainer)

    if cfg.do_training:
        # experiment logs, checkpoints, and auto-resume are managed by exp_manager and PyTorch Lightning
        exp_manager(trainer, cfg.exp_manager)

        # everything needed to train translation models is encapsulated in the NeMo MTEncDecModel
        mt_model = MTEncDecModel(cfg.model, trainer=trainer)

        logging.info("\n\n************** Model parameters and their sizes ***********")
        for name, param in mt_model.named_parameters():
            # use the script's logger instead of a bare print() so the parameter
            # dump lands in the same log stream as everything else
            logging.info(f'{name} {param.size()}')
        logging.info("***********************************************************\n\n")

        trainer.fit(mt_model)
# script entry point: run training only when executed directly, not on import
if __name__ == '__main__':
    main()