Skip to content

Commit

Permalink
Feature - ArrayDataset (#872)
Browse files Browse the repository at this point in the history
Co-authored-by: otaj <ota@lightning.ai>
  • Loading branch information
Ce11an and otaj committed Aug 26, 2022
1 parent 9619d5f commit bcbbf6a
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 87 deletions.
3 changes: 1 addition & 2 deletions pl_bolts/datamodules/__init__.py
Expand Up @@ -9,7 +9,7 @@
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset
from pl_bolts.datamodules.sr_datamodule import TVTDataModule
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
Expand All @@ -31,7 +31,6 @@
"MNISTDataModule",
"SklearnDataModule",
"SklearnDataset",
"TensorDataset",
"TVTDataModule",
"SSLImagenetDataModule",
"STL10DataModule",
Expand Down
83 changes: 0 additions & 83 deletions pl_bolts/datamodules/sklearn_datamodule.py
Expand Up @@ -2,9 +2,7 @@
from typing import Any, Tuple

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils import _SKLEARN_AVAILABLE
Expand Down Expand Up @@ -65,50 +63,6 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
return x, y


@under_review()
class TensorDataset(Dataset):
    """Dataset wrapping a pair of PyTorch tensors for use with data loaders.

    Example:
        >>> from pl_bolts.datamodules import TensorDataset
        ...
        >>> X = torch.rand(10, 3)
        >>> y = torch.rand(10)
        >>> dataset = TensorDataset(X, y)
        >>> len(dataset)
        10
    """

    def __init__(self, X: Tensor, y: Tensor, X_transform: Any = None, y_transform: Any = None) -> None:
        """
        Args:
            X: feature tensor, indexed along the first dimension
            y: target tensor, indexed along the first dimension
            X_transform: optional callable applied to each feature sample
            y_transform: optional callable applied to each target sample
        """
        super().__init__()
        # NOTE: the target attribute is capitalised ``Y`` (historical naming) — kept for compatibility.
        self.X = X
        self.Y = y
        self.X_transform = X_transform
        self.y_transform = y_transform

    def __len__(self) -> int:
        # The number of samples is the length of the first dimension of X.
        return len(self.X)

    def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
        # Features are cast to float before any transform runs; targets keep their dtype.
        sample_x = self.X[idx].float()
        sample_y = self.Y[idx]

        if self.X_transform:
            sample_x = self.X_transform(sample_x)

        if self.y_transform:
            sample_y = self.y_transform(sample_y)

        return sample_x, sample_y


@under_review()
class SklearnDataModule(LightningDataModule):
"""Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as
Expand Down Expand Up @@ -241,40 +195,3 @@ def test_dataloader(self) -> DataLoader:
pin_memory=self.pin_memory,
)
return loader


# TODO: this seems to be wrong, something missing here, another inherit class?
# class TensorDataModule(SklearnDataModule):
# """
# Automatically generates the train, validation and test splits for a PyTorch tensor dataset. They are set up as
# dataloaders for convenience. Optionally, you can pass in your own validation and test splits.
#
# Example:
#
# >>> from pl_bolts.datamodules import TensorDataModule
# >>> import torch
# ...
# >>> # create dataset
# >>> X = torch.rand(100, 3)
# >>> y = torch.rand(100)
# >>> loaders = TensorDataModule(X, y)
# ...
# >>> # train set
# >>> train_loader = loaders.train_dataloader(batch_size=10)
# >>> len(train_loader.dataset)
# 70
# >>> len(train_loader)
# 7
# >>> # validation set
# >>> val_loader = loaders.val_dataloader(batch_size=10)
# >>> len(val_loader.dataset)
# 20
# >>> len(val_loader)
# 2
# >>> # test set
# >>> test_loader = loaders.test_dataloader(batch_size=10)
# >>> len(test_loader.dataset)
# 10
# >>> len(test_loader)
# 1
# """
5 changes: 4 additions & 1 deletion pl_bolts/datasets/__init__.py
@@ -1,6 +1,7 @@
import urllib

from pl_bolts.datasets.base_dataset import LightDataset
from pl_bolts.datasets.array_dataset import ArrayDataset
from pl_bolts.datasets.base_dataset import DataModel, LightDataset
from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10
from pl_bolts.datasets.concat_dataset import ConcatDataset
from pl_bolts.datasets.dummy_dataset import (
Expand All @@ -17,6 +18,8 @@
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin

__all__ = [
"ArrayDataset",
"DataModel",
"LightDataset",
"CIFAR10",
"TrialCIFAR10",
Expand Down
52 changes: 52 additions & 0 deletions pl_bolts/datasets/array_dataset.py
@@ -0,0 +1,52 @@
from typing import Tuple, Union

from pytorch_lightning.utilities import exceptions
from torch.utils.data import Dataset

from pl_bolts.datasets.base_dataset import DataModel, TArrays


class ArrayDataset(Dataset):
    """Dataset wrapping tensors, lists and numpy arrays.

    Any number of arrays can be passed to the dataset. Each array's transform is applied on every
    ``__getitem__``. When transforming, please refrain from changing the shape of the arrays in the
    first dimension.

    Attributes:
        data_models: Sequence of data models.

    Raises:
        MisconfigurationException: if there is a shape mismatch between arrays in the first dimension.

    Example:
        >>> from pl_bolts.datasets import ArrayDataset, DataModel
        >>> from pl_bolts.datasets.utils import to_tensor
        >>> features = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3]], transform=to_tensor)
        >>> target = DataModel(data=[1, 0, 0], transform=to_tensor)
        >>> ds = ArrayDataset(features, target)
        >>> len(ds)
        3
    """

    def __init__(self, *data_models: DataModel) -> None:
        """Initialises the dataset and checks the arrays share their size in the first dimension.

        Args:
            data_models: one or more ``DataModel`` instances tying data to an optional transform.

        Raises:
            MisconfigurationException: if the arrays differ in length along the first dimension.
        """
        self.data_models = data_models

        if not self._equal_size():
            raise exceptions.MisconfigurationException("Shape mismatch between arrays in the first dimension")

    def __len__(self) -> int:
        # All arrays were validated to have equal first-dimension length, so any one of them works.
        return len(self.data_models[0].data)

    def __getitem__(self, idx: int) -> Tuple[Union[TArrays, float], ...]:
        # Apply each model's transform (if any) to its sample at ``idx``, preserving model order.
        return tuple(data_model.process(data_model.data[idx]) for data_model in self.data_models)

    def _equal_size(self) -> bool:
        """Checks the size of the data_models are equal in the first dimension.

        Returns:
            bool: True if size of data_models are equal in the first dimension. False, if not.
        """
        return len({len(data_model.data) for data_model in self.data_models}) == 1
32 changes: 31 additions & 1 deletion pl_bolts/datasets/base_dataset.py
Expand Up @@ -2,13 +2,15 @@
import os
import urllib.request
from abc import ABC
from typing import Sequence, Tuple
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple, Union
from urllib.error import HTTPError

from torch import Tensor
from torch.utils.data import Dataset

from pl_bolts.utils.stability import under_review
from pl_bolts.utils.types import TArrays


@under_review()
Expand Down Expand Up @@ -58,3 +60,31 @@ def _download_from_url(self, base_url: str, data_folder: str, file_name: str):
urllib.request.urlretrieve(url, fpath)
except HTTPError as err:
raise RuntimeError(f"Failed download from {url}") from err


@dataclass
class DataModel:
    """Data model dataclass coupling raw data with an optional transform.

    Attributes:
        data: Sequence of indexables.
        transform: Optional callable applied to a subset of ``data`` by :meth:`process`.
    """

    data: TArrays
    transform: Optional[Callable[[TArrays], TArrays]] = None

    def process(self, subset: Union[TArrays, float]) -> Union[TArrays, float]:
        """Transforms a subset of data.

        Args:
            subset: Sequence of indexables.

        Returns:
            The transformed subset when a transform was provided, otherwise the subset unchanged.
        """
        if self.transform is None:
            return subset
        return self.transform(subset)
16 changes: 16 additions & 0 deletions pl_bolts/datasets/utils.py
@@ -1,9 +1,11 @@
import torch
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.sr_celeba_dataset import SRCelebA
from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
from pl_bolts.datasets.sr_stl10_dataset import SRSTL10
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.types import TArrays


@under_review()
Expand Down Expand Up @@ -39,3 +41,17 @@ def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str):
dataset_test = dataset_cls(scale_factor, root=data_dir, split="test", download=True)

return (dataset_train, dataset_val, dataset_test)


def to_tensor(arrays: "TArrays") -> torch.Tensor:
    """Converts a sequence of type ``TArrays`` (nested lists, numpy arrays or tensors) to a tensor.

    This function serves as a ready-made transform for the ArrayDataset.

    Args:
        arrays: Sequence of type ``TArrays``.

    Returns:
        A ``torch.Tensor`` holding the data of ``arrays``; the dtype is inferred by ``torch.tensor``.
    """
    # NOTE(review): the annotation is a string forward reference, so the ``TArrays`` alias is not
    # evaluated at import time; ``typing.get_type_hints`` still resolves it against module globals.
    return torch.tensor(arrays)
6 changes: 6 additions & 0 deletions pl_bolts/utils/types.py
@@ -0,0 +1,6 @@
from typing import Sequence, Union

import numpy as np
import torch

# Recursive alias for anything the datasets accept as array data: a tensor, a numpy array,
# a flat sequence of floats, or an arbitrarily nested sequence of the same.
TArrays = Union[torch.Tensor, np.ndarray, Sequence[float], Sequence["TArrays"]] # type: ignore
52 changes: 52 additions & 0 deletions tests/datasets/test_array_dataset.py
@@ -0,0 +1,52 @@
import numpy as np
import pytest
import torch
from pytorch_lightning.utilities import exceptions

from pl_bolts.datasets import ArrayDataset, DataModel
from pl_bolts.datasets.utils import to_tensor


class TestArrayDataset:
    @pytest.fixture
    def array_dataset(self):
        # Two transformed models (list -> tensor) alongside two untransformed ones.
        features_1 = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=to_tensor)
        target_1 = DataModel(data=[1, 0, 0, 1], transform=to_tensor)

        features_2 = DataModel(data=np.array([[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]]))
        target_2 = DataModel(data=[1, 0, 1, 1])
        return ArrayDataset(features_1, target_1, features_2, target_2)

    def test_len(self, array_dataset):
        assert len(array_dataset) == 4

    def test_getitem_with_transforms(self, array_dataset):
        # Expected values per model, indexed by sample position.
        expected_features_1 = [[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]]
        expected_target_1 = [1, 0, 0, 1]
        expected_features_2 = [[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]]
        expected_target_2 = [1, 0, 1, 1]
        for idx in range(4):
            sample = array_dataset[idx]
            assert len(sample) == 4
            torch.testing.assert_close(sample[0], torch.tensor(expected_features_1[idx]))
            torch.testing.assert_close(sample[1], torch.tensor(expected_target_1[idx]))
            np.testing.assert_array_equal(sample[2], np.array(expected_features_2[idx]))
            assert sample[3] == expected_target_2[idx]

    def test__equal_size_true(self, array_dataset):
        assert array_dataset._equal_size() is True

    def test__equal_size_false(self):
        # One sample of features vs. three targets: lengths differ in the first dimension.
        features = DataModel(data=[[1, 0, 1]])
        target = DataModel(data=[1, 0, 1])
        with pytest.raises(exceptions.MisconfigurationException):
            ArrayDataset(features, target)
22 changes: 22 additions & 0 deletions tests/datasets/test_base_dataset.py
@@ -0,0 +1,22 @@
import numpy as np
import pytest
import torch

from pl_bolts.datasets.base_dataset import DataModel
from pl_bolts.datasets.utils import to_tensor


class TestDataModel:
    @pytest.fixture
    def data(self):
        return np.array([[1, 0, 0, 1], [0, 1, 1, 0]])

    def test_process_transform_is_none(self, data):
        # Without a transform, process must hand back each subset untouched.
        dm = DataModel(data=data)
        for row in data:
            np.testing.assert_array_equal(dm.process(row), row)

    def test_process_transform_is_not_none(self, data):
        # With to_tensor attached, each processed row comes back as a tensor.
        dm = DataModel(data=data, transform=to_tensor)
        torch.testing.assert_close(dm.process(data[0]), torch.tensor([1, 0, 0, 1]))
        torch.testing.assert_close(dm.process(data[1]), torch.tensor([0, 1, 1, 0]))
18 changes: 18 additions & 0 deletions tests/datasets/test_utils.py
@@ -0,0 +1,18 @@
import numpy as np
import torch.testing

from pl_bolts.datasets.utils import to_tensor


class TestToTensor:
    def test_to_tensor_list(self):
        values = [1, 2, 3]
        torch.testing.assert_close(to_tensor(values), torch.tensor(values))

    def test_to_tensor_array(self):
        values = np.array([1, 2, 3])
        torch.testing.assert_close(to_tensor(values), torch.tensor(values))

    def test_to_tensor_sequence_(self):
        # Nested sequence exercises the recursive branch of the TArrays alias.
        nested = [[1.0, 2.0, 3.0]]
        torch.testing.assert_close(to_tensor(nested), torch.tensor(nested))

0 comments on commit bcbbf6a

Please sign in to comment.