This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
multitask.py
91 lines (74 loc) · 3.73 KB
/
multitask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import collections
from typing import Type, List, Dict
from allennlp.common import JsonDict
from allennlp.data import Instance
from allennlp.models.multitask import MultiTaskModel
from allennlp.predictors.predictor import Predictor
from allennlp.common.util import sanitize
from allennlp.data.fields import MetadataField
from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers import MultiTaskDatasetReader
@Predictor.register("multitask")
class MultiTaskPredictor(Predictor):
    """
    Predictor for multitask models.

    Registered as a `Predictor` with name "multitask".

    This predictor is tightly coupled to `MultiTaskDatasetReader` and `MultiTaskModel`,
    and will not work if used with other readers or models.

    Each incoming instance must carry a `MetadataField` named "task" whose value is
    the name of the head that should process it; per-head predictors (built in
    `__init__`) handle JSON-to-instance conversion for their task.
    """

    _WRONG_READER_ERROR = (
        "MultitaskPredictor is designed to work with MultiTaskDatasetReader. "
        + "If you have a different DatasetReader, you have to write your own "
        + "Predictor, but you can use MultiTaskPredictor as a starting point."
    )

    _WRONG_FIELD_ERROR = (
        "MultiTaskPredictor expects instances that have a MetadataField "
        + "with the name 'task', containing the name of the task the instance is for."
    )

    def __init__(self, model: MultiTaskModel, dataset_reader: MultiTaskDatasetReader) -> None:
        """
        # Parameters

        model : `MultiTaskModel`
            The multitask model; its heads determine which per-task predictors exist.
        dataset_reader : `MultiTaskDatasetReader`
            The multitask reader; its inner per-task readers are handed to the
            per-task predictors.

        # Raises

        `ConfigurationError`
            If `model` or `dataset_reader` is not of the required type.
        """
        if not isinstance(dataset_reader, MultiTaskDatasetReader):
            raise ConfigurationError(self._WRONG_READER_ERROR)
        if not isinstance(model, MultiTaskModel):
            raise ConfigurationError(
                "MultiTaskPredictor is designed to work only with MultiTaskModel."
            )
        super().__init__(model, dataset_reader)

        # One predictor per head, used for JSON -> Instance conversion. A head may
        # declare a default predictor; otherwise fall back to the base Predictor.
        self.predictors: Dict[str, Predictor] = {}
        for name, head in model._heads.items():
            predictor_name = head.default_predictor
            predictor_class: Type[Predictor] = (
                Predictor.by_name(predictor_name) if predictor_name is not None else Predictor  # type: ignore
            )
            self.predictors[name] = predictor_class(model, dataset_reader.readers[name].inner)

    def _prepare_instance(self, instance: Instance) -> str:
        """
        Validate the instance's "task" field, apply the matching reader's token
        indexers in place, and return the task name.

        # Raises

        `ValueError`
            If the instance has no `MetadataField` named "task".
        `ConfigurationError`
            If the dataset reader is not a `MultiTaskDatasetReader`.
        """
        task_field = instance["task"]
        if not isinstance(task_field, MetadataField):
            raise ValueError(self._WRONG_FIELD_ERROR)
        task: str = task_field.metadata
        # Re-check the reader type here (not just in __init__) because subclasses
        # or deserialization could have swapped it out.
        if not isinstance(self._dataset_reader, MultiTaskDatasetReader):
            raise ConfigurationError(self._WRONG_READER_ERROR)
        self._dataset_reader.readers[task].apply_token_indexers(instance)
        return task

    def predict_instance(self, instance: Instance) -> JsonDict:
        """Run the model on a single task-tagged instance and return sanitized output."""
        self._prepare_instance(instance)
        outputs = self._model.forward_on_instance(instance)
        return sanitize(outputs)

    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Convert a JSON dict to an `Instance` by delegating to the predictor for the
        task named in the dict's "task" key, then re-attach the task as metadata.
        """
        # Copy before popping so we never mutate the caller's dict.
        json_dict = dict(json_dict)
        task = json_dict.pop("task")
        predictor = self.predictors[task]
        instance = predictor._json_to_instance(json_dict)
        instance.add_field("task", MetadataField(task))
        return instance

    def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]:
        """
        Predict on a batch of task-tagged instances.

        Instances are grouped by task so each task's batch goes through the model
        together, but outputs are returned in the same order as the input
        `instances`, so callers can safely zip inputs with outputs.
        """
        task_to_indices: Dict[str, List[int]] = collections.defaultdict(list)
        for index, instance in enumerate(instances):
            task = self._prepare_instance(instance)
            # Remember each instance's original position so we can restore order.
            task_to_indices[task].append(index)

        outputs: List[JsonDict] = [None] * len(instances)  # type: ignore[list-item]
        for task, indices in task_to_indices.items():
            batch = [instances[i] for i in indices]
            for i, output in zip(indices, super().predict_batch_instance(batch)):
                outputs[i] = output
        return outputs