Multitask example #4898
Changes from 80 commits
```diff
@@ -1,18 +1,16 @@
-from typing import Any, Dict, Iterable, Iterator, List, Union, Optional
+from typing import Any, Dict, Iterable, Iterator, Union, Optional
 import itertools
 import math

 import torch
 from overrides import overrides

 from allennlp.common import util
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader, DatasetReaderInput
 from allennlp.data.batch import Batch
 from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict
 from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader
-from allennlp.data.data_loaders.multitask_scheduler import (
-    MultiTaskScheduler,
-    HomogeneousRoundRobinScheduler,
-)
+from allennlp.data.data_loaders.multitask_scheduler import MultiTaskScheduler
 from allennlp.data.data_loaders.multitask_epoch_sampler import MultiTaskEpochSampler
 from allennlp.data.dataset_readers.multitask import MultiTaskDatasetReader
 from allennlp.data.instance import Instance
```
```diff
@@ -54,10 +52,6 @@ class MultiTaskDataLoader(DataLoader):
     data_path: `Dict[str, str]`
         One file per underlying dataset reader in the `MultiTaskDatasetReader`, which will be passed
         to those readers to construct one `DataLoader` per dataset.
-    batch_size: `int`
-        The number of instances (from any dataset) that should be combined together into a single
-        batch. See also the `batch_size_multiplier` argument for additional control over exactly
-        how batch size is computed.
     scheduler: `MultiTaskScheduler`, optional (default = `HomogeneousRoundRobinScheduler`)
         The `scheduler` determines how instances are ordered within an epoch. By default, we'll
         select one batch of instances from each dataset in turn, trying to ensure as uniform a mix
```
```diff
@@ -71,25 +65,9 @@ class MultiTaskDataLoader(DataLoader):
         for an epoch, this `sampler` will tell us with what proportion we should sample from each
         dataset. For instance, we might want to focus more on datasets that are underperforming in
         some way, by having those datasets contribute more instances this epoch than other datasets.
-    batch_size_multiplier: `Dict[str, float]`, optional (default = `None`)
-        If this is not `None`, it specifies how much of the batch an instance from each dataset
-        takes up. That is, if this is 1 for every dataset (which is the default), then batch size
-        is computed as normal. If dataset "A" has a value of 1.5 in this dictionary, than each
-        instance from dataset "A" counts as 1.5 instances for the purposes of computing batch size.
-        This option is available to you to account for the fact that some operations might be *much*
-        less costly than others (e.g., if you are multitasking a coref model with a simple document
-        classification model). If you use this, you're on your own as far as figuring out how it
-        interacts with optimization behavior.
     instances_per_epoch: `int`, optional (default = `None`)
         If not `None`, we will use this many instances per epoch of training, drawing from the
-        underlying datasets with proportions given by the `scheduler`. Note that this is
-        _instances_, not _batches_, because if you're using batch size multipliers we don't know how
-        many batches the instances specified by the `scheduler` will turn out to be.
-    drop_last: `bool`, optional (default = `False`)
-        If this is `True`, we will not return the last batch if it is smaller than `batch_size`.
-        Note that this is kind of nonsensical to use if you're using `batch_size_multipliers`, as
-        you are not guaranteed to get an optimal packing, so you will likely have batches that don't
-        fill up the `batch_size` in that case, anyway.
+        underlying datasets according to the `sampler`.
     num_workers: `Dict[str, int]`, optional (default = `None`)
         Used when creating one `MultiProcessDataLoader` per dataset. If you want non-default
         behavior for this parameter in the `DataLoader` for a particular dataset, pass the
```
```diff
@@ -126,13 +104,10 @@ def __init__(
         self,
         reader: MultiTaskDatasetReader,
         data_path: Dict[str, str],
-        batch_size: int,
+        scheduler: MultiTaskScheduler,
```

[Review comment] The biggest change here is that now the …
[Review comment] This also means that the scheduler is no longer optional. You have to specify one. Otherwise we don't even know a batch size.

```diff
         *,
-        scheduler: MultiTaskScheduler = None,
         sampler: MultiTaskEpochSampler = None,
         instances_per_epoch: int = None,
-        batch_size_multiplier: Dict[str, float] = None,
-        drop_last: bool = False,
         num_workers: Dict[str, int] = None,
         max_instances_in_memory: Dict[str, int] = None,
         start_method: Dict[str, str] = None,
```
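As the review comments above note, `scheduler` moved out of the keyword-only section and lost its default, so callers now have to construct one explicitly. A hedged sketch of what construction could look like after this change (the module path for `MultiTaskDataLoader` is assumed from the imports above, and the head names, readers, and file paths are made-up placeholders):

```python
from allennlp.data.data_loaders.multitask_data_loader import MultiTaskDataLoader  # path assumed
from allennlp.data.data_loaders.multitask_scheduler import HomogeneousRoundRobinScheduler
from allennlp.data.dataset_readers import TextClassificationJsonReader
from allennlp.data.dataset_readers.multitask import MultiTaskDatasetReader

# Two hypothetical heads; the readers and data paths are placeholders.
reader = MultiTaskDatasetReader(
    readers={"a": TextClassificationJsonReader(), "b": TextClassificationJsonReader()}
)
loader = MultiTaskDataLoader(
    reader=reader,
    data_path={"a": "a_train.jsonl", "b": "b_train.jsonl"},
    # After this PR the scheduler is required, and the batch size lives on it.
    scheduler=HomogeneousRoundRobinScheduler(batch_size=16),
)
```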
```diff
@@ -143,7 +118,7 @@ def __init__(
     ) -> None:
         self.readers = reader.readers
         self.data_paths = data_path
-        self.scheduler = scheduler or HomogeneousRoundRobinScheduler(batch_size=batch_size)
+        self.scheduler = scheduler
         self.sampler = sampler
         self.cuda_device: Optional[torch.device] = None
         if cuda_device is not None:
```
```diff
@@ -152,20 +127,12 @@ def __init__(
         else:
             self.cuda_device = cuda_device

-        self._batch_size = batch_size
         self._instances_per_epoch = instances_per_epoch
-        self._batch_size_multiplier = batch_size_multiplier or {}
-        for multiplier in self._batch_size_multiplier.values():
-            if multiplier > batch_size:
-                raise ValueError(
-                    f"Multiplier value ({multiplier}) is larger than batch size ({batch_size})"
-                )
-        self._drop_last = drop_last
         self._shuffle = shuffle

         if instances_per_epoch is not None and sampler is None:
             raise ValueError(
-                "You must provide an EpochSampler if you want to not use all instances every epoch"
+                "You must provide an EpochSampler if you want to not use all instances every epoch."
             )

         self._num_workers = num_workers or {}
```
```diff
@@ -176,8 +143,8 @@ def __init__(

         if self.readers.keys() != self.data_paths.keys():
             raise ValueError(
-                f"Mismatch between readers ({self.readers.keys()}) and data paths"
-                " ({self.data_paths.keys()})"
+                f"Mismatch between readers ({self.readers.keys()}) and data paths "
+                f"({self.data_paths.keys()})"
             )
         self._loaders = {key: self._make_data_loader(key) for key in self.readers}
```
```diff
@@ -197,7 +164,7 @@ def __init__(
             # for the loader variable, so a _different_ loader gets saved for every iterator.
             # Dictionary comprehensions don't create new scopes in python. If you don't have
             # this loader, you end up with `loader` always referring to the last loader in the
-            # iteration...  mypy also doesn't know what to do with this, for some reason I can't
+            # iteration... mypy also doesn't know what to do with this, for some reason I can't
             # figure out.
             lambda l=loader: maybe_shuffle_instances(l, self._shuffle)  # type: ignore
         )
```
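The comment in this hunk refers to Python's late-binding closures: a lambda created inside a comprehension closes over the loop variable itself, not its current value, so without the `l=loader` default-argument trick every lambda would end up calling the last loader. A standalone illustration of the pitfall, independent of allennlp:

```python
# Every lambda closes over the same loop variable, which ends at 2.
fns = {i: (lambda: i) for i in range(3)}
print([f() for f in fns.values()])  # [2, 2, 2]

# Binding the current value as a default argument captures it per-lambda.
fns = {i: (lambda i=i: i) for i in range(3)}
print([f() for f in fns.values()])  # [0, 1, 2]
```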
```diff
@@ -206,56 +173,27 @@

     @overrides
     def __len__(self) -> int:
-        if self._instances_per_epoch is not None:
-            return self._instances_per_epoch
-
-        # Here we try to estimate the actual length. If you are using varying batch size
-        # multipliers per task, we may get batch packing orders that make this an underestimate, as
-        # this assumes that all batches are full, which may not be true.
-        total_instances = 0.0
-        for key, loader in self._loaders.items():
+        if self._instances_per_epoch is None:
             # This will raise a TypeError if any of the underlying loaders doesn't have a length,
-            # which is actually what we want. If the loader has a length, we set batch_size = 1, so
-            # this will give us the right number of instances.
-            total_instances += self._batch_size_multiplier.get(key, 1.0) * len(loader)
-        if self._drop_last or total_instances % self._batch_size == 0:
-            return int(total_instances) // self._batch_size
+            # which is actually what we want.
+            return self.scheduler.count_batches(
+                {dataset: len(loader) for dataset, loader in self._loaders.items()}
+            )
```

[Review comment on lines +179 to +181] Because schedulers now make batches, only schedulers know how many batches they are making. So the responsibility of counting is devolved to them.

```diff
         else:
-            return int(1 + total_instances) // self._batch_size
+            return self.scheduler.count_batches(
+                {dataset: self._instances_per_epoch for dataset in self._loaders.keys()}
+            )
```
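The review comment above explains the new division of labor: since schedulers now form the batches, only they can count them. As a hedged sketch (not the library's actual code; the real method takes `self` on the scheduler), a scheduler that builds fixed-size batches per dataset could satisfy the `count_batches` contract like this:

```python
import math
from typing import Dict

def count_batches(dataset_counts: Dict[str, int], batch_size: int) -> int:
    # `dataset_counts` maps each dataset name to its number of instances this
    # epoch; a fixed-batch-size scheduler rounds each dataset up to whole batches.
    return sum(math.ceil(count / batch_size) for count in dataset_counts.values())

# e.g. count_batches({"a": 10, "b": 5}, batch_size=4) == 3 + 2 == 5
```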
```diff

     @overrides
     def __iter__(self) -> Iterator[TensorDict]:
-        # Basic outline: first we _sample_ the instances that we're going to be using for this
-        # epoch, which relies on the scheduler if `self._instances_per_epoch` is not None. This is
-        # basically just saying how many instances we should use this epoch for each dataset, and we
-        # grab bounded-length iterators over that many instances for each dataset. Second, we
-        # _schedule_ the epoch's instances into a single list, again relying on the scheduler.
-        # Finally, we take that combined list and yield `batch_size` batches from it.
         epoch_instances = self._get_instances_for_epoch()
-        scheduled_instances = self.scheduler.order_epoch_instances(epoch_instances)
-        batch_instances: List[Instance] = []
-        current_batch_size = 0.0
-        for dataset, instance in scheduled_instances:
-            current_batch_size += self._batch_size_multiplier.get(dataset, 1.0)
-            if current_batch_size > self._batch_size:
-                batch = Batch(batch_instances)
-                tensor_dict = batch.as_tensor_dict()
-                if self.cuda_device is not None:
-                    tensor_dict == nn_util.move_to_device(tensor_dict, self.cuda_device)
-                yield tensor_dict
-                batch_instances = [instance]
-                current_batch_size = self._batch_size_multiplier.get(dataset, 1.0)
-            else:
-                batch_instances.append(instance)
-
-        # Based on how we yield batches above, we are guaranteed to always have leftover instances,
-        # so we don't need a check for that here.
-        if not self._drop_last or current_batch_size == self._batch_size:
-            batch = Batch(batch_instances)
-            tensor_dict = batch.as_tensor_dict()
-            if self.cuda_device is not None:
-                tensor_dict == nn_util.move_to_device(tensor_dict, self.cuda_device)
-            yield tensor_dict
+        return (
+            nn_util.move_to_device(
```

[Review comment] Calling …
[Reply] That sounds correct. I will check.
[Reply] Oh, never mind. This is definitely necessary.
[Reply] No, wait. We read …

```diff
+                Batch(instances).as_tensor_dict(),
+                -1 if self.cuda_device is None else self.cuda_device,
+            )
+            for instances in self.scheduler.batch_instances(epoch_instances)
+        )
```
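`__iter__` now delegates batching entirely to `self.scheduler.batch_instances`, which takes the per-dataset instance iterables and yields one list of instances per batch. A minimal sketch of a round-robin implementation of that contract, assuming a single integer batch size (illustrative only; `RoundRobinBatcher` is a hypothetical stand-in, and the real schedulers live in `multitask_scheduler.py`):

```python
import itertools
from typing import Dict, Iterable, Iterator, List


class RoundRobinBatcher:
    """Hypothetical stand-in for a MultiTaskScheduler, illustrating the
    batch_instances contract: per-dataset iterables in, instance lists out."""

    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def batch_instances(self, epoch_instances: Dict[str, Iterable]) -> Iterator[List]:
        def chunks(instances: Iterable) -> Iterator[List]:
            # Greedily slice the iterable into lists of at most batch_size items.
            iterator = iter(instances)
            while chunk := list(itertools.islice(iterator, self.batch_size)):
                yield chunk

        # One homogeneous batch from each dataset in turn, until all run out.
        per_dataset = [chunks(instances) for instances in epoch_instances.values()]
        for batch_group in itertools.zip_longest(*per_dataset):
            for batch in batch_group:
                if batch is not None:
                    yield batch
```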
```diff

     @overrides
     def iter_instances(self) -> Iterator[Instance]:
@@ -271,7 +209,7 @@ def iter_instances(self) -> Iterator[Instance]:
         # complex, configurable scheduling.
         #
         # The underlying data loaders here could be using multiprocessing; we don't need to worry
-        # about that in this class.  Caching is also handled by the underlying data loaders.
+        # about that in this class. Caching is also handled by the underlying data loaders.
         for loader in self._loaders.values():
             yield from loader.iter_instances()
```
```diff
@@ -292,9 +230,9 @@ def _get_instances_for_epoch(self) -> Dict[str, Iterable[Instance]]:
         }
         if self.sampler is None:
             # We already checked for this in the constructor, so this should never happen unless you
-            # modified the object after creation.  But mypy is complaining, so here's another check.
+            # modified the object after creation. But mypy is complaining, so here's another check.
             raise ValueError(
-                "You must specify an EpochSampler if self._instances_per_epoch is not None"
+                "You must specify an EpochSampler if self._instances_per_epoch is not None."
            )
         dataset_proportions = self.sampler.get_task_proportions(self._loaders)
         proportion_sum = sum(dataset_proportions.values())
```
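For context on `get_task_proportions`: an `EpochSampler` only has to map each data loader to a relative weight, which this method then normalizes by `proportion_sum`. A minimal hedged sketch of such a sampler, a uniform split, illustrative rather than the library's own class:

```python
from typing import Dict

from allennlp.data.data_loaders.data_loader import DataLoader
from allennlp.data.data_loaders.multitask_epoch_sampler import MultiTaskEpochSampler


class EqualSampler(MultiTaskEpochSampler):
    # Hypothetical example: give every dataset the same share of the epoch.
    def get_task_proportions(self, data_loaders: Dict[str, DataLoader]) -> Dict[str, float]:
        return {task: 1.0 / len(data_loaders) for task in data_loaders}
```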
```diff
@@ -309,9 +247,9 @@ def _make_data_loader(self, key: str) -> MultiProcessDataLoader:

     def _make_data_loader(self, key: str) -> MultiProcessDataLoader:
```

[Review comment] Line 259-262 can be removed. Those params to the data loader don't exist anymore.
[Reply] done

```diff
         kwargs: Dict[str, Any] = {}
-        kwargs["reader"] = self.readers[key]
+        kwargs["reader"] = _MultitaskDatasetReaderShim(self.readers[key], key)
         kwargs["data_path"] = self.data_paths[key]
-        kwargs["batch_size"] = 1
+        kwargs["batch_size"] = 1  # So that the loader gives us one instance at a time.
```

[Review comment] We only use it to call …
[Reply] True, but we have to set something. I clarified the comment.

```diff
         if key in self._num_workers:
             kwargs["num_workers"] = self._num_workers[key]
         if key in self._max_instances_in_memory:
```
```diff
@@ -323,3 +261,32 @@ def _make_data_loader(self, key: str) -> MultiProcessDataLoader:
         if key in self._instance_chunk_size:
             kwargs["instance_chunk_size"] = self._instance_chunk_size[key]
         return MultiProcessDataLoader(**kwargs)
+
+
+@DatasetReader.register("multitask_shim")
+class _MultitaskDatasetReaderShim(DatasetReader):
+    """This dataset reader wraps another dataset reader and adds the name of the "task" into
+    each instance as a metadata field. This exists only to support `MultitaskDataLoader`. You
+    should not have to use this yourself."""
+
+    def __init__(self, inner: DatasetReader, head: str, **kwargs):
+        super().__init__(**kwargs)
+        self.inner = inner
+        self.head = head
+
+    def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
+        from allennlp.data.fields import MetadataField
+
+        for instance in self.inner.read(file_path):
+            instance.add_field("task", MetadataField(self.head))
+            yield instance
+
+    def text_to_instance(self, *inputs) -> Instance:
+        from allennlp.data.fields import MetadataField
+
+        instance = self.inner.text_to_instance(*inputs)
+        instance.add_field("task", MetadataField(self.head))
+        return instance
+
+    def apply_token_indexers(self, instance: Instance) -> None:
+        self.inner.apply_token_indexers(instance)
```

[Review comment] Inputs for multitask are dicts. Some other readers do lists.
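To see what the shim buys downstream consumers: every instance that flows through the multitask loader carries a `"task"` metadata field naming the head it came from, presumably so a multitask model can route instances to the right head. A hedged usage sketch (the wrapped reader and file path are placeholders; the shim is normally created by the loader, not by users):

```python
from allennlp.data.dataset_readers import TextClassificationJsonReader

# Hypothetical: wrap any reader and tag its instances with a head name.
reader = _MultitaskDatasetReaderShim(TextClassificationJsonReader(), head="rte")
for instance in reader.read("rte_train.jsonl"):  # placeholder path
    assert instance["task"].metadata == "rte"
```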