forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.pyi
28 lines (20 loc) · 886 Bytes
/
dataset.pyi
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import TypeVar, Generic, Iterable, Sequence, List, Tuple
from ... import Tensor
# Plain invariant type variable; used by the generic function `random_split`.
T = TypeVar('T')
# Item type of a `Dataset`. Declared covariant, so e.g. a Dataset[int] is
# accepted where a Dataset[object] is expected.
T_co = TypeVar('T_co', covariant=True)
class Dataset(Generic[T_co]):
    """Generic, int-indexable, sized collection whose items have type ``T_co``.

    ``T_co`` is covariant, so a ``Dataset[Derived]`` is accepted wherever a
    ``Dataset[Base]`` is expected.
    """

    def __getitem__(self, index: int) -> T_co: ...

    def __len__(self) -> int: ...

    # The right operand is another dataset of the same item type (not a single
    # item): the result is a ConcatDataset[T_co]. A bare covariant TypeVar is
    # also not permitted as a parameter type, so 'Dataset[T_co]' is required.
    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]': ...
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    """Dataset whose items are typed as variable-length tuples of ``Tensor``.

    Built from any number of tensors passed positionally to ``__init__``.
    """

    # The tensors the dataset was constructed from.
    # NOTE(review): ``*tensors`` reaches ``__init__`` as a tuple; if the
    # implementation stores it unchanged, ``Tuple[Tensor, ...]`` would be a
    # more accurate annotation than ``List[Tensor]`` -- confirm.
    tensors: List[Tensor]

    def __init__(self, *tensors: Tensor) -> None:
        ...
class ConcatDataset(Dataset[T_co]):
    """Dataset assembled from several datasets that share the item type ``T_co``."""

    # The component datasets, held as a list.
    datasets: List[Dataset[T_co]]
    # Presumably running totals of the component datasets' lengths (used to
    # locate which component an index falls into) -- confirm against the
    # implementation.
    cumulative_sizes: List[int]

    # Parameterized with T_co so the type checker can infer the item type of
    # the resulting ConcatDataset from its argument; the bare ``Dataset``
    # (== Dataset[Any]) left T_co unsolved at call sites.
    def __init__(self, datasets: Iterable[Dataset[T_co]]) -> None: ...
class Subset(Dataset[T_co]):
    """Dataset typed over a parent ``dataset`` and a sequence of int ``indices``.

    Presumably exposes only the parent's items at ``indices`` -- confirm
    against the implementation.
    """

    dataset: Dataset[T_co]  # parent dataset; shares the item type T_co
    indices: Sequence[int]  # int positions, presumably into ``dataset``

    def __init__(
        self,
        dataset: Dataset[T_co],
        indices: Sequence[int],
    ) -> None:
        ...
def random_split(
    dataset: Dataset[T],
    lengths: Sequence[int],
) -> List[Subset[T]]:
    """Return a list of ``Subset`` objects over ``dataset``.

    Each ``Subset`` keeps the item type ``T`` of ``dataset``. Presumably one
    subset is produced per entry of ``lengths``, with indices assigned at
    random -- confirm against the implementation.
    """
    ...