-
Notifications
You must be signed in to change notification settings - Fork 125
/
iterator.py
60 lines (49 loc) · 1.94 KB
/
iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import division
import logging
from torchtext import data
from ..batch import MatchingBatch
logger = logging.getLogger(__name__)
class MatchingIterator(data.BucketIterator):
    """Iterator over a matching dataset that yields :class:`MatchingBatch` objects.

    Thin wrapper around torchtext's ``BucketIterator`` that (a) optionally
    sorts examples within buckets and (b) converts each raw torchtext batch
    into a ``MatchingBatch`` using metadata taken from the training split.
    """

    def __init__(self,
                 dataset,
                 train_info,
                 train,
                 batch_size,
                 sort_in_buckets=None,
                 **kwargs):
        # By default, sort within buckets only when iterating the train split.
        self.sort_in_buckets = train if sort_in_buckets is None else sort_in_buckets
        # Metadata source (the train dataset) used to build each MatchingBatch.
        self.train_info = train_info
        super(MatchingIterator, self).__init__(
            dataset, batch_size, train=train, repeat=False, sort=False, **kwargs)

    @classmethod
    def splits(cls, datasets, batch_sizes=None, **kwargs):
        """Create Iterator objects for multiple splits of a dataset.

        Args:
            datasets: Tuple of Dataset objects corresponding to the splits. The
                first such object should be the train set.
            batch_sizes: Tuple of batch sizes to use for the different splits,
                or None to use the same batch_size for all splits.
            Remaining keyword arguments: Passed to the constructor of the
                iterator class being used.
        """
        if batch_sizes is None:
            # Replicate the single batch_size kwarg across every split.
            batch_sizes = [kwargs.pop('batch_size')] * len(datasets)
        # The first split (train) supplies metadata for all iterators; only
        # that iterator is constructed with train=True.
        return tuple(
            cls(split,
                train_info=datasets[0],
                train=(position == 0),
                batch_size=size,
                **kwargs)
            for position, (split, size) in enumerate(zip(datasets, batch_sizes)))

    def __iter__(self):
        # Wrap every raw torchtext batch in a MatchingBatch before yielding.
        for raw_batch in super(MatchingIterator, self).__iter__():
            yield MatchingBatch(raw_batch, self.train_info)

    def create_batches(self):
        # Bucketed (sorted-within-bucket) batching comes from BucketIterator;
        # otherwise fall back to the plain Iterator batching strategy.
        if self.sort_in_buckets:
            return data.BucketIterator.create_batches(self)
        return data.Iterator.create_batches(self)