This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
data_collator.py
71 lines (56 loc) · 2.47 KB
/
data_collator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import List, Optional

from transformers.data.data_collator import DataCollatorForLanguageModeling

from allennlp.common import Registrable
from allennlp.data.batch import Batch
from allennlp.data.data_loaders.data_loader import TensorDict
from allennlp.data.instance import Instance
def allennlp_collate(instances: List[Instance]) -> TensorDict:
    """
    Default collation function: wrap the given `Instance`s in a `Batch` and
    convert that batch into a single `TensorDict`.
    """
    return Batch(instances).as_tensor_dict()
class DataCollator(Registrable):
    """
    Analogue of `DataCollator` in [Transformers]
    (https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py).

    A `DataCollator` converts a `List[Instance]` into a `TensorDict`. Because it
    runs every time a batch is assembled, subclasses can apply dynamic,
    per-batch transformations to the tensors (e.g. re-sampled token masking).
    """

    # Used when a config does not name a specific implementation.
    default_implementation = "allennlp"

    def __call__(self, instances: List[Instance]) -> TensorDict:
        raise NotImplementedError
@DataCollator.register("allennlp")
class DefaultDataCollator(DataCollator):
    """Collator that simply delegates to `allennlp_collate` with no extra processing."""

    def __call__(self, instances: List[Instance]) -> TensorDict:
        return allennlp_collate(instances)
@DataCollator.register("language_model")
class LanguageModelingDataCollator(DataCollator):
    """
    Registered as a `DataCollator` with name `"language_model"`.

    Used for language modeling: token masking is re-sampled for every batch by
    delegating to HuggingFace's `DataCollatorForLanguageModeling`.

    # Parameters

    model_name : `str`
        Name of the pretrained transformer whose tokenizer supplies the mask
        token and special-token ids.
    mlm : `bool`, optional (default = `True`)
        Whether to use masked language modeling (as opposed to causal LM).
    mlm_probability : `float`, optional (default = `0.15`)
        Probability with which each token is selected for masking.
    filed_name : `str`, optional (default = `"source"`)
        Deprecated misspelling of `field_name`, kept for backward compatibility
        with existing configs.
    namespace : `str`, optional (default = `"tokens"`)
        Vocabulary namespace under which the token ids are stored.
    field_name : `Optional[str]`, optional (default = `None`)
        Name of the text field whose tokens are masked. Takes precedence over
        `filed_name` when given.
    """

    def __init__(
        self,
        model_name: str,
        mlm: bool = True,
        mlm_probability: float = 0.15,
        filed_name: str = "source",
        namespace: str = "tokens",
        field_name: Optional[str] = None,
    ):
        # The correctly-spelled `field_name` wins over the legacy `filed_name`.
        self._field_name = field_name if field_name is not None else filed_name
        self._namespace = namespace
        from allennlp.common import cached_transformers

        tokenizer = cached_transformers.get_tokenizer(model_name)
        self._collator = DataCollatorForLanguageModeling(tokenizer, mlm, mlm_probability)
        # transformers renamed `mask_tokens` to `torch_mask_tokens` in v4.2;
        # resolve whichever attribute exists so both versions work.
        self._mask_tokens = getattr(self._collator, "mask_tokens", None) or getattr(
            self._collator, "torch_mask_tokens"
        )

    def __call__(self, instances: List[Instance]) -> TensorDict:
        tensor_dicts = allennlp_collate(instances)
        tensor_dicts = self.process_tokens(tensor_dicts)
        return tensor_dicts

    def process_tokens(self, tensor_dicts: TensorDict) -> TensorDict:
        """
        Replace `token_ids` in the configured field with masked inputs and add
        the corresponding language-modeling `labels` tensor in place.
        """
        inputs = tensor_dicts[self._field_name][self._namespace]["token_ids"]
        inputs, labels = self._mask_tokens(inputs)
        tensor_dicts[self._field_name][self._namespace]["token_ids"] = inputs
        tensor_dicts[self._field_name][self._namespace]["labels"] = labels
        return tensor_dicts