-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
29 lines (20 loc) · 901 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import NamedTuple, Iterator
from tensor import Tensor
import numpy as np
# A single mini-batch: a slice of the inputs paired with the slice of the
# targets taken over the same index range.
class Batch(NamedTuple):
    inputs: Tensor
    targets: Tensor
class DataIterator:
    """Interface for callables that turn a dataset into a stream of batches.

    Subclasses are expected to override ``__call__``; the version defined
    here does nothing except raise.
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Iterator[Batch]:
        """Produce ``Batch`` objects built from ``inputs`` and ``targets``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class BatchIterator(DataIterator):
    """Yield consecutive ``batch_size``-long slices of a dataset as batches.

    The dataset is cut at offsets ``0, batch_size, 2 * batch_size, ...``.
    When ``shuffle`` is true the *order of those offsets* is randomised with
    NumPy's global RNG; the samples inside each slice keep their original
    order. For containers whose slicing clamps at the end (e.g. NumPy
    arrays, lists) the final batch is shorter when the dataset length is not
    a multiple of ``batch_size``.
    """

    def __init__(self, batch_size: int = 32, shuffle: bool = True) -> None:
        """Store the batching configuration.

        Args:
            batch_size: Number of leading-axis items per batch; must be > 0.
            shuffle: Whether to visit the batches in random order.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        # A non-positive step would make np.arange either fail with an
        # unrelated error (0) or silently produce no batches (negative).
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be a positive integer, got {!r}".format(batch_size)
            )
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __call__(self, inputs: Tensor, targets: Tensor) -> Iterator[Batch]:
        """Iterate over ``(inputs, targets)`` in aligned slices.

        Args:
            inputs: Sliceable container of samples supporting ``len``.
            targets: Sliceable container with the same length as ``inputs``.

        Yields:
            ``Batch(inputs[start:end], targets[start:end])`` for each offset.

        Raises:
            ValueError: if ``inputs`` and ``targets`` differ in length. As
                this method is a generator, the error surfaces when the
                first batch is requested, not when the method is called.
        """
        n_inputs = len(inputs)
        n_targets = len(targets)
        # Slicing both containers with the same bounds only pairs samples
        # correctly when they have the same number of items.
        if n_inputs != n_targets:
            raise ValueError(
                "inputs and targets must have the same length, got {} and {}".format(
                    n_inputs, n_targets
                )
            )

        starts = np.arange(0, n_inputs, self.batch_size)
        if self.shuffle:
            # In-place permutation of the batch offsets (not of the samples).
            np.random.shuffle(starts)

        for start in starts:
            end = start + self.batch_size
            yield Batch(inputs[start:end], targets[start:end])