# Training script and command-line interface for Joint NLU.
# last edited: 10.2.2021
# SP
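"""Train and evaluate a joint intent/slot NLU classifier on CONLLU-formatted data.

Example invocation (a sketch; the data paths below are illustrative, not shipped
with the repository):

    python train_eval.py \
        --train_file data/en/train.conllu \
        --test_file data/en/test.conllu \
        --cross_train_file data/de/train.conllu \
        --cross_test_file data/de/test.conllu \
        --num_epochs 10 --training_bs 32 --run_name xlmr_en_de
"""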
import argparse
import os
import sys
import warnings
import logging
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import numpy as np
from sklearn_crfsuite import metrics
from transformers import logging as trans_log
from transformers import Trainer, TrainingArguments, DataCollatorForTokenClassification
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    ProgressCallback,
)
from transformers.integrations import TensorBoardCallback, WandbCallback
from transformers.trainer_callback import PrinterCallback
from joint_nlu_models import JointClassifier, config_init
from sklearn.metrics import classification_report
from preprocessing.conll_loader import ConLLLoader, intent_labels_list, slot_labels_list
from joint_metrics import running_metrics, joint_classification_report, exact_match
# configure loggers
trans_log.set_verbosity_warning()
logging.basicConfig(level=logging.INFO)
# silence tokenizer-parallelism warnings and disable wandb logging
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
# suppress remaining Python warnings
if not sys.warnoptions:
    warnings.simplefilter("ignore")
# NOTE: wandb is disabled for memory efficiency, but setting WANDB_DISABLED does
# not seem to take effect; an issue will be filed on the Hugging Face GitHub.
# detect and enable CUDA backend
if torch.cuda.is_available():
    device = torch.device("cuda")
    logging.info("using CUDA backend")
else:
    device = torch.device("cpu")
    logging.info("using CPU backend")
# set seeds for reproducibility
torch.manual_seed(136)
np.random.seed(136)
parser = argparse.ArgumentParser(description="run cross lingual NLU experiment")
# train-test arguments
parser.add_argument(
    "--train_file",
    metavar="train",
    type=str,
    help="CONLLU training file",
    required=True,
)
parser.add_argument(
    "--cross_train_file",
    type=str,
    help="CONLLU cross-training file; concatenated with the main training set or trained on sequentially, depending on the --sequential_train flag",
)
parser.add_argument(
    "--test_file",
    type=str,
    metavar="test",
    required=True,
    help="CONLLU test file to run after training",
)
parser.add_argument(
    "--cross_test_file",
    type=str,
    metavar="xtest",
    help="CONLLU cross-lingual test file to run after training",
)
parser.add_argument(
    "--eval_file",
    metavar="eval",
    type=str,
    help="CONLLU evaluation file for evaluation between epochs",
)
parser.add_argument(
    "--num_epochs",
    type=int,
    default=10,
    help="number of training epochs. default is 10",
)
parser.add_argument(
    "--training_bs", type=int, default=32, help="training mini-batch size"
)
parser.add_argument(
    "--lr", type=float, default=5e-5, help="initial learning rate for the Adam optimizer"
)
parser.add_argument(
    "--save_dir",
    type=str,
    default="./models",
    help="save directory for the final model",
)
parser.add_argument("--chkpt_dir", type=str, help="checkpoint directory (currently unused)")
parser.add_argument(
    "--no_fp16",
    action="store_true",
    help="disable mixed-precision training. Set this flag if NVIDIA Apex/AMP is unavailable",
)
parser.add_argument(
    "--run_name", type=str, default="test_model", help="name of the experiment"
)
parser.add_argument(
    "--sequential_train",
    action="store_true",
    help="""train on the cross-lingual dataset sequentially, after the main dataset.
    Otherwise the second dataset is concatenated with the first and shuffled""",
)
args = parser.parse_args()
train_path, test_path, eval_path, save_dir = (
    args.train_file,
    args.test_file,
    args.eval_file,
    args.save_dir,
)
# check that the provided paths are valid
# TODO: check every provided path
if not os.path.isfile(train_path):
    sys.exit(f"The training file does not exist: {train_path}")
if not os.path.isfile(test_path):
    sys.exit(f"The test file does not exist: {test_path}")
if eval_path is not None and not os.path.isfile(eval_path):
    sys.exit(f"The eval file does not exist: {eval_path}")
if save_dir is not None and not os.path.exists(save_dir):
    os.makedirs(save_dir)
# dataset paths
path2train_en = train_path
path2test_en = test_path
path2xtrain_en = args.cross_train_file
path2xtest_en = args.cross_test_file
path2eval_en = eval_path
# Name of the pretrained transformer to use. Only tested with XLM-RoBERTa base,
# but theoretically extensible to other models that share similar dimensions.
# TODO: expose as a CLI argument, but make sure it works with the joint model.
pretrained_name = "xlm-roberta-base"
# BUG: AutoTokenizer sometimes misses pretrained models in the cache and re-downloads
# them; harmless apart from the wasted download. See issue #4197 on the Hugging Face GitHub.
# Load default tokenizer for XLM-R
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
# default data collator for sequence-labeling (token classification) data
data_collator = DataCollatorForTokenClassification(tokenizer)
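# Note: DataCollatorForTokenClassification dynamically pads the standard tokenizer
# fields per batch (and a "labels" key, with -100); the custom "intent_labels"/
# "slot_labels" fields used here are presumably already fixed-length from ConLLLoader,
# since the collator only pads the standard label key.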
# Parse and tokenize the data. ConLLLoader subclasses torch.utils.data.Dataset
# and shares most of its functionality.
train_set = ConLLLoader(path2train_en, tokenizer, intent_labels_list, slot_labels_list)
test_set = ConLLLoader(path2test_en, tokenizer, intent_labels_list, slot_labels_list)
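# Each dataset item is expected to provide the tokenized inputs plus
# "intent_labels" and "slot_labels" ids (the label_names configured for the
# Trainer below); see preprocessing/conll_loader.py for the exact fields.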
# Temporary workaround for a bug: uncomment to force sequential training.
# args.sequential_train = True
# load the cross-training set if one was provided
if path2xtrain_en is not None:
    logging.info("using cross-training")
    extrain_set = ConLLLoader(
        path2xtrain_en, tokenizer, intent_labels_list, slot_labels_list
    )
    # if not training sequentially, mix the two datasets
    # (shuffling happens in the Trainer's dataloader)
    if not args.sequential_train:
        logging.info("using combined data strategy")
        train_set = ConcatDataset([train_set, extrain_set])
else:
    extrain_set = None
# load the cross-test set if one was provided
if path2xtest_en is not None:
    logging.info("using cross-testing")
    extest_set = ConLLLoader(
        path2xtest_en, tokenizer, intent_labels_list, slot_labels_list
    )
else:
    extest_set = None
# load the eval set if one was provided
if path2eval_en is not None:
    eval_set = ConLLLoader(
        path2eval_en, tokenizer, intent_labels_list, slot_labels_list
    )
else:
    eval_set = None
if args.no_fp16:
    fp16 = False
    logging.info("mixed precision training disabled")
else:
    fp16 = True
training_args = TrainingArguments(
    output_dir="./results",  # output directory for checkpoints
    num_train_epochs=args.num_epochs,  # total number of training epochs
    per_device_train_batch_size=args.training_bs,  # batch size per device during training
    warmup_steps=500,  # number of warmup steps for the learning rate scheduler
    weight_decay=0.01,  # strength of weight decay
    logging_dir="./logs",
    learning_rate=args.lr,
    save_steps=5000,
    fp16=fp16,
    label_names=["intent_labels", "slot_labels"],
    evaluation_strategy="epoch",
    run_name=args.run_name,
    save_total_limit=3,  # keep only the 3 most recent checkpoints so they don't fill the disk
    # load_best_model_at_end=True,
)
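# NOTE: evaluation_strategy="epoch" triggers an evaluate() call every epoch; if no
# --eval_file was given, eval_set is None and the Trainer may raise at evaluation
# time, so an eval file is effectively expected when training for multiple epochs.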
# default config for XLM-R
conf = config_init(pretrained_name)
# initialize the joint classifier model
model = JointClassifier(conf, num_intents=12, num_slots=31)
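# NOTE: num_intents and num_slots are hard-coded here; they should match
# len(intent_labels_list) and len(slot_labels_list) from the CONLL loader.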
# Optionally freeze the encoder (TODO: expose as a CLI argument):
# for param in model.roberta.parameters():
#     param.requires_grad = False
# set up the trainer for the first training run
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=eval_set,
    data_collator=data_collator,
    compute_metrics=running_metrics,
    # callbacks=[PrinterCallback]
)
# attempt to remove overly verbose output from Trainer API
trainer.remove_callback(PrinterCallback)
trainer.remove_callback(WandbCallback)
trainer.remove_callback(TensorBoardCallback)
logging.info("starting training session")
# train on the first file using the transformers.Trainer API
trainer.train()
# continue with the cross-lingual training data if the sequential strategy is used
if extrain_set is not None and args.sequential_train:
    logging.info("performing sequential training")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=extrain_set,
        eval_dataset=eval_set,
        data_collator=data_collator,
        compute_metrics=running_metrics,
        # callbacks=[PrinterCallback]
    )
    # sequential training pass
    trainer.train()
# save model
trainer.save_model(args.save_dir)
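# NOTE: the tokenizer is not saved alongside the model here; calling
# tokenizer.save_pretrained(args.save_dir) as well would make the save
# directory self-contained for later reloading.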
# standard precision/recall/F1 classification report for both intents and slots
print("trained on: ", train_path)
# predict on the first test file
print("*" * 15, "test results for", args.test_file, "*" * 15)
preds = trainer.predict(test_set)
result_dict = joint_classification_report(preds, intent_labels_list, slot_labels_list)
print("exact matches", exact_match(preds))
# predict on the second (cross-lingual) test file
if extest_set is not None:
    print("*" * 15, "cross test results for", args.cross_test_file, "*" * 15)
    preds = trainer.predict(extest_set)
    result_dict = joint_classification_report(preds, intent_labels_list, slot_labels_list)
    print("exact matches", exact_match(preds))