#!/usr/bin/env python3
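"""Train and package an Argos Translate model.

The pipeline: select and download parallel data, train a SentencePiece
subword model, build a vocabulary and train with OpenNMT-py, average the
final checkpoints, convert the result to CTranslate2, bundle a Stanza
sentence boundary detection model, and zip everything into a .argosmodel
package.

Example call (illustrative values):

    train("en", "es", "English", "Spanish", "1.0", "1.0", "1.5", False, 30)
"""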
from pathlib import Path
import json
import subprocess
import shutil
import sys

import stanza

import argostrain
import argostrain.opennmtutils
from argostrain import data, settings
from argostrain.dataset import Dataset, get_available_datasets


def train(
    from_code,
    to_code,
    from_name,
    to_name,
    version,
    package_version,
    argos_version,
    data_exists,
    epochs_count,
):
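    """Train a from_code -> to_code translation model and package it.

    package_version and argos_version are recorded in metadata.json.
    If data_exists is True, previously downloaded data at
    settings.SOURCE_PATH/TARGET_PATH is reused instead of re-downloading.
    version and epochs_count are accepted but not referenced in this
    function body.
    """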
    settings.RUN_PATH.mkdir(exist_ok=True)
    settings.CACHE_PATH.mkdir(exist_ok=True)
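
    # Upper bound on the combined size of the datasets used
    # (see the selection loop below)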
    MAX_DATA_SIZE = 5 * (10 ** 7)

    # Check for existing checkpoints
    checkpoints = argostrain.opennmtutils.get_checkpoints()
    if len(checkpoints) > 0:
        input("Warning: Checkpoints exist (enter to continue)")

    # Dataset credits to include in README.md; filled in below when data
    # is downloaded
    dataset_references = []

    if not data_exists:
        # Delete training data if it exists
        settings.SOURCE_PATH.unlink(missing_ok=True)
        settings.TARGET_PATH.unlink(missing_ok=True)

        available_datasets = get_available_datasets()

        from_and_to_codes = [from_code, to_code]
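        # Keep datasets whose two languages are the requested pair, in either
        # direction (reverse-direction data is flipped below)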
        available_datasets = list(
            filter(
                lambda x: x.from_code in from_and_to_codes
                and x.to_code in from_and_to_codes,
                available_datasets,
            )
        )

        # Limit max amount of data used
        limited_datasets = list()
        limited_datasets_size = 0
        available_datasets.sort(key=lambda x: x.size)
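        # Sorting smallest-first keeps as many distinct datasets as possible
        # under the cap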
        for dataset in available_datasets:
            if limited_datasets_size + dataset.size < MAX_DATA_SIZE:
                limited_datasets.append(dataset)
                limited_datasets_size += dataset.size
            else:
                print(f"Excluding data {str(dataset)}, over MAX_DATA_SIZE")
        available_datasets = limited_datasets
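
        # Datasets that already match the requested translation direction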
        datasets = list(
            filter(
                lambda x: x.from_code == from_code and x.to_code == to_code,
                available_datasets,
            )
        )

        # Try to use reverse data
        reverse_datasets = list(
            filter(
                lambda x: x.to_code == from_code and x.from_code == to_code,
                available_datasets,
            )
        )
        for reverse_dataset in reverse_datasets:
            reverse_dataset_data = reverse_dataset.data()
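            # Swap source and target so the pair matches the requested direction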
            dataset = Dataset(reverse_dataset_data[1], reverse_dataset_data[0])
            # Hack to preserve reference metadata
            dataset.reference = reverse_dataset.reference
            dataset.size = reverse_dataset.size
            datasets.append(dataset)

        if len(datasets) == 0:
            print(
                f"No data available for this language pair ({from_code}-{to_code}), check data-index.json"
            )
            sys.exit(1)

        # Capture dataset credits before the download loop empties the list
        dataset_references = [dataset.reference for dataset in datasets]

        # Download and write data source and target
        while len(datasets) > 0:
            dataset = datasets.pop()
            print(str(dataset))
            source, target = dataset.data()
            with open(settings.SOURCE_PATH, "a") as s:
                s.writelines(source)
            with open(settings.TARGET_PATH, "a") as t:
                t.writelines(target)
            del dataset

    # Generate README.md from the template plus dataset credits
    readme = f"# {from_name}-{to_name}\n"
    with open(Path("MODEL_README.md")) as readme_template:
        readme += readme_template.read()
    for reference in dataset_references:
        readme += reference + "\n\n"
    with open(settings.RUN_PATH / "README.md", "w") as readme_file:
        readme_file.write(readme)

    # Generate metadata.json
    metadata = {
        "package_version": package_version,
        "argos_version": argos_version,
        "from_code": from_code,
        "from_name": from_name,
        "to_code": to_code,
        "to_name": to_name,
    }
    metadata_json = json.dumps(metadata, indent=4)
    with open(settings.RUN_PATH / "metadata.json", "w") as metadata_file:
        metadata_file.write(metadata_json)
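
    # prepare_data writes the train/valid splits used below to run/split_data/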
    argostrain.data.prepare_data(settings.SOURCE_PATH, settings.TARGET_PATH)
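
    # Concatenate source and target training text into one file for SentencePiece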
    with open(Path("run/split_data/all.txt"), "w") as combined:
        with open(Path("run/split_data/src-train.txt")) as src:
            for line in src:
                combined.write(line)
        with open(Path("run/split_data/tgt-train.txt")) as tgt:
            for line in tgt:
                combined.write(line)
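
    # Train a single SentencePiece subword model shared by both languages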
    # TODO: Don't hardcode vocab_size and set user_defined_symbols
    subprocess.run(
        [
            "spm_train",
            "--input=run/split_data/all.txt",
            "--model_prefix=run/sentencepiece",
            "--vocab_size=50000",
            "--character_coverage=0.9995",
            "--input_sentence_size=1000000",
            "--shuffle_input_sentence=true",
        ]
    )
subprocess.run(["rm", "run/split_data/all.txt"])
subprocess.run(["onmt_build_vocab", "-config", "config.yml", "-n_sample", "-1"])
subprocess.run(["onmt_train", "-config", "config.yml"])

    # Average checkpoints
    opennmt_checkpoints = argostrain.opennmtutils.get_checkpoints()
    opennmt_checkpoints.sort()
    subprocess.run(
        [
            "./../OpenNMT-py/tools/average_models.py",
            "-m",
            str(opennmt_checkpoints[-2].f),
            str(opennmt_checkpoints[-1].f),
            "-o",
            "run/averaged.pt",
        ]
    )
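
    # Convert the averaged checkpoint to a CTranslate2 model with int8 quantization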
    subprocess.run(
        [
            "ct2-opennmt-py-converter",
            "--model_path",
            "run/averaged.pt",
            "--output_dir",
            "run/model",
            "--quantization",
            "int8",
        ]
    )
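
    # Assemble the directory that will be zipped into the .argosmodel package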
    package_version_code = package_version.replace(".", "_")
    model_dir = f"translate-{from_code}_{to_code}-{package_version_code}"
    model_path = Path("run") / model_dir
    subprocess.run(["mkdir", model_path])
    subprocess.run(["cp", "-r", "run/model", model_path])
    subprocess.run(["cp", "run/sentencepiece.model", model_path])

    # Include a Stanza sentence boundary detection model
    stanza_model_located = False
    stanza_lang_code = from_code
    while not stanza_model_located:
        try:
            stanza.download(stanza_lang_code, dir="run/stanza", processors="tokenize")
            stanza_model_located = True
        except Exception:
            print(f"Could not locate stanza model for lang {stanza_lang_code}")
            print(
                "Enter the code of a different language to attempt to use its stanza model."
            )
            print(
                "This will work best with a language similar to the one you are attempting to translate."
            )
            print(
                "This will require manually editing the Stanza package in the finished model to change its code."
            )
            stanza_lang_code = input("Stanza language code (ISO 639): ")
subprocess.run(["cp", "-r", "run/stanza", model_path])
subprocess.run(["cp", "run/metadata.json", model_path])
subprocess.run(["cp", "run/README.md", model_path])
    package_path = (
        Path("run")
        / f"translate-{from_code}_{to_code}-{package_version_code}.argosmodel"
    )
    shutil.make_archive(model_dir, "zip", root_dir="run", base_dir=model_dir)
    subprocess.run(["mv", model_dir + ".zip", package_path])
    print(f"Package saved to {str(package_path.resolve())}")