import random
from argparse import ArgumentParser

import torch
from torch import nn
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForTokenClassification

from dataset_loaders import load_LinCE
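
# Fine-tune a pretrained token-classification model (MARBERT by default) on
# the LinCE language-identification (LID) task: every word in a code-switched
# sentence is tagged with one of the labels in TAGS below. Each sample from
# load_LinCE is a (tokens, tags) pair, e.g. (hypothetical example):
#   (["we", "bought", "الطعام"], ["lang1", "lang1", "lang2"])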
TAGS = ["ambiguous", "lang1", "lang2", "mixed", "ne", "other"]
class LIDataset:
def __init__(
self,
tokenizer,
dataset,
max_seq_len,
):
# The dataset as a list of sentences
# [(tokens_1, tags_1), (tokens_2, tags_2), ....]
self.texts = [s[0] for s in dataset]
self.tags = [s[1] for s in dataset]
# Encode the BIL tags
self.label_list = sorted(
set([tag for tag_list in self.tags for tag in tag_list])
)
self.label_map = {label: i for i, label in enumerate(self.label_list)}
# A wordpiece tokenizer
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
# Use cross entropy ignore_index as padding label id so that only
# real label ids contribute to the loss later.
self.pad_token_label_id = nn.CrossEntropyLoss().ignore_index
def __len__(self):
return len(self.texts)
def __getitem__(self, index):
# TODO: This is a redundant computation
# Is it worth refactoring?
words = self.texts[index]
labels = self.tags[index]
# Tokens and Labels of current example
tokens, label_ids = [], []
for word, label in zip(words, labels):
word_tokens = self.tokenizer.tokenize(word)
if len(word_tokens) > 0:
# Append the list of subtokens
tokens.extend(word_tokens)
# Ignore the subwords
subwords_label = self.pad_token_label_id
label_ids.extend(
# Label for first subtoken
[self.label_map[label]]
+
# Padding for the rest of subtokens
[subwords_label] * (len(word_tokens) - 1)
)
# Truncate the sample while reserving two tokens for [CLS] and [SEP]
if len(tokens) > self.max_seq_len - 2:
# TODO: Add a debugging message
tokens = tokens[: (self.max_seq_len - 2)]
label_ids = label_ids[: (self.max_seq_len - 2)]
# Add the [CLS] and [SEP] tokens
tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
label_ids = [self.pad_token_label_id] + label_ids + [self.pad_token_label_id]
segment_ids = [0] * len(tokens)
# Encode the subwords to indecies
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens
# Only real tokens are attended to
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
padding_length = self.max_seq_len - len(input_ids)
input_ids += [self.tokenizer.pad_token_id] * padding_length
# Don't attend to the padding
input_mask += [0] * padding_length
segment_ids += [0] * padding_length
# Don't use padding on computing loss
label_ids += [self.pad_token_label_id] * padding_length
assert len(input_ids) == self.max_seq_len
assert len(input_mask) == self.max_seq_len
assert len(segment_ids) == self.max_seq_len
assert len(label_ids) == self.max_seq_len
return {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(input_mask, dtype=torch.long),
"token_type_ids": torch.tensor(segment_ids, dtype=torch.long),
"labels": torch.tensor(label_ids, dtype=torch.long),
}
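
# Quick shape check (illustrative only, not part of the training flow):
#   ds = LIDataset(AutoTokenizer.from_pretrained("UBC-NLP/MARBERT"),
#                  load_LinCE("dev"), max_seq_len=512)
#   print({k: v.shape for k, v in ds[0].items()})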

def model_init(model_name):
    model = AutoModelForTokenClassification.from_pretrained(
        model_name, num_labels=len(TAGS)
    )
    return model
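
# Illustrative evaluation metric (an addition, not in the original script):
# token-level accuracy over positions that carry a real label, skipping the
# ignore_index (-100) used for padding and non-first subwords. It could be
# wired in via Trainer(..., compute_metrics=compute_token_accuracy).
def compute_token_accuracy(eval_pred):
    # eval_pred.predictions: (batch, seq_len, num_labels) logits as numpy
    # eval_pred.label_ids: (batch, seq_len) gold label ids as numpy
    predictions = eval_pred.predictions.argmax(axis=-1)
    labels = eval_pred.label_ids
    # Keep only positions with a real label
    mask = labels != nn.CrossEntropyLoss().ignore_index
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"token_accuracy": float(accuracy)}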

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-s", "--seed", type=int, default=42)
    parser.add_argument(
        "--model_name",
        "-m",
        default="UBC-NLP/MARBERT",
        help="The model name.",
    )
    parser.add_argument(
        "-o",
        required=True,
        help="The output directory.",
    )
    args = parser.parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    train_dataset = LIDataset(
        tokenizer=tokenizer, dataset=load_LinCE("train"), max_seq_len=512
    )
    eval_dataset = LIDataset(
        tokenizer=tokenizer, dataset=load_LinCE("dev"), max_seq_len=512
    )

    EVAL_STEPS = 100
    training_args = TrainingArguments(
        output_dir=args.o,
        save_strategy="epoch",
        eval_steps=EVAL_STEPS,
        evaluation_strategy="steps",
        seed=args.seed,
    )
    # Trainer defaults to the AdamW optimizer with a linear learning-rate
    # schedule; make sure this is the intended optimization setup.
    trainer = Trainer(
        model_init=lambda: model_init(model_name),
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
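
# Example invocation (the output path is a placeholder):
#   python finetune_BERT_for_tagging.py -m UBC-NLP/MARBERT -o ./outputs --seed 42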