From 46f224fc91ee3e474376884132869c006108a45e Mon Sep 17 00:00:00 2001 From: BaruchG Date: Mon, 8 Aug 2022 17:02:38 -0400 Subject: [PATCH 01/17] insert torchvision dependency and write tests for cifar10 --- pl_bolts/datasets/cifar10_dataset.py | 157 ++------------------------- tests/datasets/test_datasets.py | 21 ++++ 2 files changed, 31 insertions(+), 147 deletions(-) diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 9e2c68ac6f..3e40c3920d 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -1,160 +1,23 @@ -import os -import pickle -import tarfile +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 +from pl_bolts.utils.stability import under_review +from pl_bolts.utils.warnings import warn_missing_pkg from typing import Callable, Optional, Sequence, Tuple - +import os import torch from torch import Tensor -from pl_bolts.datasets import LightDataset -from pl_bolts.utils import _PIL_AVAILABLE -from pl_bolts.utils.stability import under_review -from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import CIFAR10 +else: # pragma: no cover + warn_missing_pkg("torchvision") + CIFAR10 = object if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg("PIL", pypi_name="Pillow") - -@under_review() -class CIFAR10(LightDataset): - """Customized `CIFAR10 `_ dataset for testing Pytorch Lightning - without the torchvision dependency. - - Part of the code was copied from - https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/ - - Args: - data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` - and ``CIFAR10/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. 
- - Examples: - - >>> from torchvision import transforms - >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization - >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()]) - >>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets") - >>> len(dataset) - 50000 - >>> torch.bincount(dataset.targets) - tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]) - >>> data, label = dataset[0] - >>> data.shape - torch.Size([3, 32, 32]) - >>> label - 6 - - Labels:: - - airplane: 0 - automobile: 1 - bird: 2 - cat: 3 - deer: 4 - dog: 5 - frog: 6 - horse: 7 - ship: 8 - truck: 9 - """ - - BASE_URL = "https://www.cs.toronto.edu/~kriz/" - FILE_NAME = "cifar-10-python.tar.gz" - cache_folder_name = "complete" - TRAIN_FILE_NAME = "training.pt" - TEST_FILE_NAME = "test.pt" - DATASET_NAME = "CIFAR10" - labels = set(range(10)) - relabel = False - - def __init__( - self, data_dir: str = ".", train: bool = True, transform: Optional[Callable] = None, download: bool = True - ): - super().__init__() - self.dir_path = data_dir - self.train = train # training set or test set - self.transform = transform - - if not _PIL_AVAILABLE: - raise ImportError("You want to use PIL.Image for loading but it is not installed yet.") - - os.makedirs(self.cached_folder_path, exist_ok=True) - self.prepare_data(download) - - if not self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)): - raise RuntimeError("Dataset not found.") - - data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME - self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file)) - - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: - img = self.data[idx].reshape(3, 32, 32) - target = int(self.targets[idx]) - - if self.transform is not None: - img = img.numpy().transpose((1, 2, 0)) # convert to HWC - img = self.transform(Image.fromarray(img)) - if self.relabel: - target = list(self.labels).index(target) - return img, target - - @classmethod - def _check_exists(cls, data_folder: str, file_names: Sequence[str]) -> bool: - if isinstance(file_names, str): - file_names = [file_names] - return all(os.path.isfile(os.path.join(data_folder, fname)) for fname in file_names) - - def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]: - with open(os.path.join(path_folder, file_name), "rb") as fo: - pkl = pickle.load(fo, encoding="bytes") - return torch.tensor(pkl[b"data"]), torch.tensor(pkl[b"labels"]) - - def _extract_archive_save_torch(self, download_path): - # extract achieve - with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar: - tar.extractall(path=download_path) - # this is internal path in the archive - path_content = os.path.join(download_path, "cifar-10-batches-py") - - # load Test and save as PT - torch.save( - self._unpickle(path_content, "test_batch"), os.path.join(self.cached_folder_path, self.TEST_FILE_NAME) - ) - # load Train and save as PT - data, labels = [], [] - for i in range(5): - fname = f"data_batch_{i + 1}" - _data, _labels = self._unpickle(path_content, fname) - data.append(_data) - labels.append(_labels) - # stash all to one - data = torch.cat(data, dim=0) - labels = torch.cat(labels, dim=0) - # and save as PT - torch.save((data, labels), os.path.join(self.cached_folder_path, self.TRAIN_FILE_NAME)) - - def prepare_data(self, download: bool): - if self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, 
self.TEST_FILE_NAME)): - return - - base_path = os.path.join(self.dir_path, self.DATASET_NAME) - if download: - self.download(base_path) - self._extract_archive_save_torch(base_path) - - def download(self, data_folder: str) -> None: - """Download the data if it doesn't exist in cached_folder_path already.""" - if self._check_exists(data_folder, self.FILE_NAME): - return - self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME) - - @under_review() class TrialCIFAR10(CIFAR10): """ diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 6d010fe15b..97d2ae0da2 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,9 +1,12 @@ import pytest import torch from torch.utils.data import DataLoader +import torchvision.transforms as transforms + from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset from pl_bolts.datasets.sr_mnist_dataset import SRMNIST +from pl_bolts.datasets.cifar10_dataset import CIFAR10 def test_dummy_ds(): @@ -52,3 +55,21 @@ def test_sr_datasets(datadir, scale_factor): assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol) assert torch.allclose(lr_image.min(), torch.tensor(0.0), atol=atol) assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol) + +def test_cifar10_datasets(datadir): + transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + dl = DataLoader(CIFAR10(root=datadir, download=True, transform=transform)) + hr_image, lr_image = next(iter(dl)) + print("==============================", lr_image.size()) + + hr_image_size = 32 + assert hr_image.size() == torch.Size([1, 3, hr_image_size, hr_image_size]) + assert lr_image.size() == torch.Size([1]) + + atol = 0.3 + assert torch.allclose(hr_image.min(), torch.tensor(-1.0), atol=atol) + assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol) + assert torch.greater_equal(lr_image.min(), torch.tensor(0)) + assert torch.less_equal(lr_image.max(), torch.tensor(9)) \ No newline at end of file From 01d6f1ec5169e89b37660d1788ae3826afd85703 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Mon, 8 Aug 2022 17:17:48 -0400 Subject: [PATCH 02/17] removed print --- tests/datasets/test_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 97d2ae0da2..eca45dfb2b 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -62,7 +62,6 @@ def test_cifar10_datasets(datadir): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) dl = DataLoader(CIFAR10(root=datadir, download=True, transform=transform)) hr_image, lr_image = next(iter(dl)) - print("==============================", lr_image.size()) hr_image_size = 32 assert hr_image.size() == torch.Size([1, 3, hr_image_size, hr_image_size]) From 45c9a9d8525510640c8cc9630e6b8ab801aabb9e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Aug 2022 21:20:53 +0000 Subject: [PATCH 03/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/datasets/cifar10_dataset.py | 10 ++++++---- tests/datasets/test_datasets.py | 12 +++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 3e40c3920d..9f362e8065 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ 
b/pl_bolts/datasets/cifar10_dataset.py @@ -1,11 +1,12 @@ -from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 -from pl_bolts.utils.stability import under_review -from pl_bolts.utils.warnings import warn_missing_pkg -from typing import Callable, Optional, Sequence, Tuple import os +from typing import Callable, Optional, Sequence, Tuple + import torch from torch import Tensor +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 +from pl_bolts.utils.stability import under_review +from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: from torchvision.datasets import CIFAR10 @@ -18,6 +19,7 @@ else: # pragma: no cover warn_missing_pkg("PIL", pypi_name="Pillow") + @under_review() class TrialCIFAR10(CIFAR10): """ diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index eca45dfb2b..d98158b95d 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,12 +1,11 @@ import pytest import torch -from torch.utils.data import DataLoader import torchvision.transforms as transforms - +from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset -from pl_bolts.datasets.sr_mnist_dataset import SRMNIST from pl_bolts.datasets.cifar10_dataset import CIFAR10 +from pl_bolts.datasets.sr_mnist_dataset import SRMNIST def test_dummy_ds(): @@ -56,10 +55,9 @@ def test_sr_datasets(datadir, scale_factor): assert torch.allclose(lr_image.min(), torch.tensor(0.0), atol=atol) assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol) + def test_cifar10_datasets(datadir): - transform = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) dl = DataLoader(CIFAR10(root=datadir, download=True, transform=transform)) hr_image, lr_image = next(iter(dl)) @@ -71,4 +69,4 @@ def test_cifar10_datasets(datadir): assert torch.allclose(hr_image.min(), torch.tensor(-1.0), atol=atol) assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol) assert torch.greater_equal(lr_image.min(), torch.tensor(0)) - assert torch.less_equal(lr_image.max(), torch.tensor(9)) \ No newline at end of file + assert torch.less_equal(lr_image.max(), torch.tensor(9)) From 86482ce665733e980f9aa461f57602ac76381174 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Aug 2022 22:06:54 +0000 Subject: [PATCH 04/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/datasets/test_datasets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 04e9dd2eb8..948b086ab8 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,16 +1,19 @@ import pytest import torch + <<<<<<< HEAD import torchvision.transforms as transforms from torch.utils.data import DataLoader from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset from pl_bolts.datasets.cifar10_dataset import CIFAR10 + ======= from torch.utils.data import DataLoader, Dataset from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset from pl_bolts.datasets.dummy_dataset import 
DummyDetectionDataset + >>>>>>> upstream/master from pl_bolts.datasets.sr_mnist_dataset import SRMNIST From adc57b90c35070ea33e853f05c86dab66ccbe3a5 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 23 Aug 2022 16:20:59 +0200 Subject: [PATCH 05/17] cleanup failed merge --- tests/datasets/test_datasets.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 948b086ab8..b47acbb501 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,20 +1,11 @@ import pytest import torch - -<<<<<<< HEAD import torchvision.transforms as transforms -from torch.utils.data import DataLoader - -from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset -from pl_bolts.datasets.cifar10_dataset import CIFAR10 - -======= from torch.utils.data import DataLoader, Dataset from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset +from pl_bolts.datasets.cifar10_dataset import CIFAR10 from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset - ->>>>>>> upstream/master from pl_bolts.datasets.sr_mnist_dataset import SRMNIST From b1995a7f3435f6d137e1e536aa1532f873605912 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Fri, 16 Sep 2022 15:47:03 -0400 Subject: [PATCH 06/17] Revert "insert torchvision dependency and write tests for cifar10" This reverts commit 46f224fc91ee3e474376884132869c006108a45e. --- pl_bolts/datasets/cifar10_dataset.py | 146 +++++++++++++++++++++++++-- tests/datasets/test_datasets.py | 10 ++ 2 files changed, 147 insertions(+), 9 deletions(-) diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 9f362e8065..e5612e04aa 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -1,32 +1,160 @@ import os +import pickle +import tarfile from typing import Callable, Optional, Sequence, Tuple import torch from torch import Tensor -from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 +from pl_bolts.datasets import LightDataset +from pl_bolts.utils import _PIL_AVAILABLE from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE: - from torchvision.datasets import CIFAR10 -else: # pragma: no cover - warn_missing_pkg("torchvision") - CIFAR10 = object - if _PIL_AVAILABLE: from PIL import Image else: # pragma: no cover warn_missing_pkg("PIL", pypi_name="Pillow") +@under_review() +class CIFAR10(LightDataset): + """Customized `CIFAR10 `_ dataset for testing Pytorch Lightning + without the torchvision dependency. + Part of the code was copied from + https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/ + Args: + data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` + and ``CIFAR10/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ Examples: + >>> from torchvision import transforms + >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization + >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()]) + >>> dataset = CIFAR10(download=True, transform=cf10_transforms, data_dir="datasets") + >>> len(dataset) + 50000 + >>> torch.bincount(dataset.targets) + tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]) + >>> data, label = dataset[0] + >>> data.shape + torch.Size([3, 32, 32]) + >>> label + 6 + Labels:: + airplane: 0 + automobile: 1 + bird: 2 + cat: 3 + deer: 4 + dog: 5 + frog: 6 + horse: 7 + ship: 8 + truck: 9 + """ + + BASE_URL = "https://www.cs.toronto.edu/~kriz/" + FILE_NAME = "cifar-10-python.tar.gz" + cache_folder_name = "complete" + TRAIN_FILE_NAME = "training.pt" + TEST_FILE_NAME = "test.pt" + DATASET_NAME = "CIFAR10" + labels = set(range(10)) + relabel = False + + def __init__( + self, data_dir: str = ".", train: bool = True, transform: Optional[Callable] = None, download: bool = True + ): + super().__init__() + self.dir_path = data_dir + self.train = train # training set or test set + self.transform = transform + + if not _PIL_AVAILABLE: + raise ImportError("You want to use PIL.Image for loading but it is not installed yet.") + + os.makedirs(self.cached_folder_path, exist_ok=True) + self.prepare_data(download) + + if not self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)): + raise RuntimeError("Dataset not found.") + + data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME + self.data, self.targets = torch.load(os.path.join(self.cached_folder_path, data_file)) + + def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + img = self.data[idx].reshape(3, 32, 32) + target = int(self.targets[idx]) + + if self.transform is not None: + img = img.numpy().transpose((1, 2, 0)) # convert to HWC + img = self.transform(Image.fromarray(img)) + if self.relabel: + target = list(self.labels).index(target) + return img, target + + @classmethod + def _check_exists(cls, data_folder: str, file_names: Sequence[str]) -> bool: + if isinstance(file_names, str): + file_names = [file_names] + return all(os.path.isfile(os.path.join(data_folder, fname)) for fname in file_names) + + def _unpickle(self, path_folder: str, file_name: str) -> Tuple[Tensor, Tensor]: + with open(os.path.join(path_folder, file_name), "rb") as fo: + pkl = pickle.load(fo, encoding="bytes") + return torch.tensor(pkl[b"data"]), torch.tensor(pkl[b"labels"]) + + def _extract_archive_save_torch(self, download_path): + # extract achieve + with tarfile.open(os.path.join(download_path, self.FILE_NAME), "r:gz") as tar: + tar.extractall(path=download_path) + # this is internal path in the archive + path_content = os.path.join(download_path, "cifar-10-batches-py") + + # load Test and save as PT + torch.save( + self._unpickle(path_content, "test_batch"), os.path.join(self.cached_folder_path, self.TEST_FILE_NAME) + ) + # load Train and save as PT + data, labels = [], [] + for i in range(5): + fname = f"data_batch_{i + 1}" + _data, _labels = self._unpickle(path_content, fname) + data.append(_data) + labels.append(_labels) + # stash all to one + data = torch.cat(data, dim=0) + labels = torch.cat(labels, dim=0) + # and save as PT + torch.save((data, labels), os.path.join(self.cached_folder_path, self.TRAIN_FILE_NAME)) + + def prepare_data(self, download: bool): + if self._check_exists(self.cached_folder_path, (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME)): 
+ return + + base_path = os.path.join(self.dir_path, self.DATASET_NAME) + if download: + self.download(base_path) + self._extract_archive_save_torch(base_path) + + def download(self, data_folder: str) -> None: + """Download the data if it doesn't exist in cached_folder_path already.""" + if self._check_exists(data_folder, self.FILE_NAME): + return + self._download_from_url(self.BASE_URL, data_folder, self.FILE_NAME) + + @under_review() class TrialCIFAR10(CIFAR10): """ Customized `CIFAR10 `_ dataset for testing Pytorch Lightning without the torchvision dependency. Examples: - >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8), data_dir="datasets") >>> len(dataset) 450 @@ -80,4 +208,4 @@ def prepare_data(self, download: bool) -> None: data, targets = torch.load(path_fname) if self.num_samples or len(self.labels) < 10: data, targets = self._prepare_subset(data, targets, self.num_samples, self.labels) - torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) + torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) \ No newline at end of file diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index b48a8945d6..10f18c4c63 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,5 +1,6 @@ import pytest import torch +<<<<<<< HEAD import torchvision.transforms as transforms from torch.utils.data import DataLoader, Dataset from torchvision import transforms as transform_lib @@ -54,6 +55,12 @@ def test_rand_dict_ds(catch_warnings, batch_size, size, num_samples): x = next(iter(ds)) assert x["a"].shape == torch.Size([size]) assert x["b"].shape == torch.Size([size]) +======= +from torch.utils.data import DataLoader + +from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset +from pl_bolts.datasets.sr_mnist_dataset import SRMNIST +>>>>>>> parent of 46f224f... insert torchvision dependency and write tests for cifar10 batch = next(iter(dl)) assert len(batch["a"]), len(batch["a"][0]) == (batch_size, size) @@ -139,6 +146,7 @@ def test_sr_datasets(datadir, scale_factor): assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol) assert torch.allclose(lr_image.min(), torch.tensor(0.0), atol=atol) assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol) +<<<<<<< HEAD def test_cifar10_datasets(datadir): @@ -182,3 +190,5 @@ def test_binary_emnist_dataset(datadir, split): assert torch.allclose(img.min(), torch.tensor(0.0)) assert torch.allclose(img.max(), torch.tensor(1.0)) assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0])) +======= +>>>>>>> parent of 46f224f... insert torchvision dependency and write tests for cifar10 From a9a5ab8e189a04a4398c34bebab397fc24840b53 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Tue, 20 Sep 2022 11:56:31 -0400 Subject: [PATCH 07/17] ensured tests are present for object detection and removed under review --- pl_bolts/metrics/object_detection.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index d2a93e35af..c6814fb647 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,10 +1,6 @@ import torch from torch import Tensor -from pl_bolts.utils.stability import under_review - - -@under_review() def iou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union. 
@@ -36,8 +32,6 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: iou = torch.true_divide(intersection, union) return iou - -@under_review() def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. From 6ae60a82794b4cb2b70f4dd3a72a877a5b445840 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Tue, 20 Sep 2022 12:10:11 -0400 Subject: [PATCH 08/17] cifar10 revert --- pl_bolts/datasets/cifar10_dataset.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index e5612e04aa..9e2c68ac6f 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -21,8 +21,10 @@ class CIFAR10(LightDataset): """Customized `CIFAR10 `_ dataset for testing Pytorch Lightning without the torchvision dependency. + Part of the code was copied from https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/ + Args: data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` and ``CIFAR10/processed/test.pt`` exist. @@ -31,7 +33,9 @@ class CIFAR10(LightDataset): download: If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. + Examples: + >>> from torchvision import transforms >>> from pl_bolts.transforms.dataset_normalizations import cifar10_normalization >>> cf10_transforms = transforms.Compose([transforms.ToTensor(), cifar10_normalization()]) @@ -45,7 +49,9 @@ class CIFAR10(LightDataset): torch.Size([3, 32, 32]) >>> label 6 + Labels:: + airplane: 0 automobile: 1 bird: 2 @@ -155,6 +161,7 @@ class TrialCIFAR10(CIFAR10): Customized `CIFAR10 `_ dataset for testing Pytorch Lightning without the torchvision dependency. 
Examples: + >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8), data_dir="datasets") >>> len(dataset) 450 @@ -208,4 +215,4 @@ def prepare_data(self, download: bool) -> None: data, targets = torch.load(path_fname) if self.num_samples or len(self.labels) < 10: data, targets = self._prepare_subset(data, targets, self.num_samples, self.labels) - torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) \ No newline at end of file + torch.save((data, targets), os.path.join(self.cached_folder_path, fname)) From bfdae0113ea84d78c0f60229ef5376f1339b7e8a Mon Sep 17 00:00:00 2001 From: BaruchG Date: Tue, 20 Sep 2022 12:12:25 -0400 Subject: [PATCH 09/17] revert cifar10 --- tests/datasets/test_datasets.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 10f18c4c63..d2a26bd64b 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,7 +1,5 @@ import pytest import torch -<<<<<<< HEAD -import torchvision.transforms as transforms from torch.utils.data import DataLoader, Dataset from torchvision import transforms as transform_lib @@ -13,7 +11,6 @@ RandomDictDataset, RandomDictStringDataset, ) -from pl_bolts.datasets.cifar10_dataset import CIFAR10 from pl_bolts.datasets.dummy_dataset import DummyDetectionDataset from pl_bolts.datasets.sr_mnist_dataset import SRMNIST @@ -55,12 +52,6 @@ def test_rand_dict_ds(catch_warnings, batch_size, size, num_samples): x = next(iter(ds)) assert x["a"].shape == torch.Size([size]) assert x["b"].shape == torch.Size([size]) -======= -from torch.utils.data import DataLoader - -from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset -from pl_bolts.datasets.sr_mnist_dataset import SRMNIST ->>>>>>> parent of 46f224f... insert torchvision dependency and write tests for cifar10 batch = next(iter(dl)) assert len(batch["a"]), len(batch["a"][0]) == (batch_size, size) @@ -146,23 +137,6 @@ def test_sr_datasets(datadir, scale_factor): assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol) assert torch.allclose(lr_image.min(), torch.tensor(0.0), atol=atol) assert torch.allclose(lr_image.max(), torch.tensor(1.0), atol=atol) -<<<<<<< HEAD - - -def test_cifar10_datasets(datadir): - transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - dl = DataLoader(CIFAR10(root=datadir, download=True, transform=transform)) - hr_image, lr_image = next(iter(dl)) - - hr_image_size = 32 - assert hr_image.size() == torch.Size([1, 3, hr_image_size, hr_image_size]) - assert lr_image.size() == torch.Size([1]) - - atol = 0.3 - assert torch.allclose(hr_image.min(), torch.tensor(-1.0), atol=atol) - assert torch.allclose(hr_image.max(), torch.tensor(1.0), atol=atol) - assert torch.greater_equal(lr_image.min(), torch.tensor(0)) - assert torch.less_equal(lr_image.max(), torch.tensor(9)) def test_binary_mnist_dataset(datadir): @@ -189,6 +163,4 @@ def test_binary_emnist_dataset(datadir, split): assert torch.allclose(img.min(), torch.tensor(0.0)) assert torch.allclose(img.max(), torch.tensor(1.0)) - assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0])) -======= ->>>>>>> parent of 46f224f... 
insert torchvision dependency and write tests for cifar10 + assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0])) \ No newline at end of file From c6b47527082266f3d0f14da2dc93030b4cbef737 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Tue, 20 Sep 2022 12:13:25 -0400 Subject: [PATCH 10/17] revert cifar10 --- tests/datasets/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index d2a26bd64b..b45e987589 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -163,4 +163,4 @@ def test_binary_emnist_dataset(datadir, split): assert torch.allclose(img.min(), torch.tensor(0.0)) assert torch.allclose(img.max(), torch.tensor(1.0)) - assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0])) \ No newline at end of file + assert torch.equal(torch.unique(img), torch.tensor([0.0, 1.0])) From 92d3297c4c8d0c4568ed717c26d6857944db2b4b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Sep 2022 16:17:11 +0000 Subject: [PATCH 11/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/metrics/object_detection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index c6814fb647..106de2ec0d 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,6 +1,7 @@ import torch from torch import Tensor + def iou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union. @@ -32,6 +33,7 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: iou = torch.true_divide(intersection, union) return iou + def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. From f9954f422e32660d083ae315f08ac5bdebd1f43d Mon Sep 17 00:00:00 2001 From: BaruchG Date: Wed, 21 Sep 2022 16:50:56 -0400 Subject: [PATCH 12/17] renamed variables to conform to specs --- pl_bolts/metrics/object_detection.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index c6814fb647..c4f9fff7c9 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,22 +1,19 @@ import torch from torch import Tensor + def iou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union. 
- Args: preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` - Example: - >>> import torch >>> from pl_bolts.metrics.object_detection import iou >>> preds = torch.tensor([[100, 100, 200, 200]]) >>> target = torch.tensor([[150, 150, 250, 250]]) >>> iou(preds, target) tensor([[0.1429]]) - Returns: IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes @@ -29,28 +26,24 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) union = pred_area[:, None] + target_area - intersection - iou = torch.true_divide(intersection, union) - return iou + iou_value = torch.true_divide(intersection, union) + return iou_value + def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. - It has been proposed in `Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression `_. - Args: preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` - Example: - >>> import torch >>> from pl_bolts.metrics.object_detection import giou >>> preds = torch.tensor([[100, 100, 200, 200]]) >>> target = torch.tensor([[150, 150, 250, 250]]) >>> giou(preds, target) tensor([[-0.0794]]) - Returns: GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes @@ -68,6 +61,6 @@ def giou(preds: Tensor, target: Tensor) -> Tensor: C_x_max = torch.max(preds[:, None, 2], target[:, 2]) C_y_max = torch.max(preds[:, None, 3], target[:, 3]) C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0) - iou = torch.true_divide(intersection, union) - giou = iou - torch.true_divide((C_area - union), C_area) - return giou + iou_value = torch.true_divide(intersection, union) + giou_value = iou_value - torch.true_divide((C_area - union), C_area) + return giou_value \ No newline at end of file From d38a965f7c4c650f8698363f2c2aca599867f247 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Sep 2022 20:52:55 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/metrics/object_detection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 02561cea08..2e6f1a92eb 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -30,7 +30,6 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: return iou_value - def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. 
It has been proposed in `Generalized Intersection over Union: A Metric and A @@ -64,4 +63,4 @@ def giou(preds: Tensor, target: Tensor) -> Tensor: C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0) iou_value = torch.true_divide(intersection, union) giou_value = iou_value - torch.true_divide((C_area - union), C_area) - return giou_value \ No newline at end of file + return giou_value From 6a0eacfa301d7559ab9ba7a133879160bbec45bb Mon Sep 17 00:00:00 2001 From: BaruchG Date: Wed, 21 Sep 2022 16:54:08 -0400 Subject: [PATCH 14/17] added newline --- pl_bolts/metrics/object_detection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 02561cea08..4c9c156fc6 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -64,4 +64,5 @@ def giou(preds: Tensor, target: Tensor) -> Tensor: C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0) iou_value = torch.true_divide(intersection, union) giou_value = iou_value - torch.true_divide((C_area - union), C_area) - return giou_value \ No newline at end of file + return giou_value + \ No newline at end of file From 0d4cc232c5fa715250a3d2ad20495be3996baeb3 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Wed, 21 Sep 2022 17:29:27 -0400 Subject: [PATCH 15/17] upgraded to assert_close and modified tolerance --- tests/metrics/test_object_detection.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py index 011f230858..be20bdd8d8 100644 --- a/tests/metrics/test_object_detection.py +++ b/tests/metrics/test_object_detection.py @@ -11,7 +11,7 @@ [(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([[1.0]]))], ) def test_iou_complete_overlap(preds, target, expected_iou): - torch.testing.assert_allclose(iou(preds, target), expected_iou) + torch.testing.assert_close(iou(preds, target), expected_iou) @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_iou_complete_overlap(preds, target, expected_iou): ], ) def test_iou_no_overlap(preds, target, expected_iou): - torch.testing.assert_allclose(iou(preds, target), expected_iou) + torch.testing.assert_close(iou(preds, target), expected_iou) @pytest.mark.parametrize( @@ -36,7 +36,7 @@ def test_iou_no_overlap(preds, target, expected_iou): ], ) def test_iou_multi(preds, target, expected_iou): - torch.testing.assert_allclose(iou(preds, target), expected_iou) + torch.testing.assert_close(iou(preds, target), expected_iou) @pytest.mark.parametrize( @@ -44,7 +44,7 @@ def test_iou_multi(preds, target, expected_iou): [(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([[1.0]]))], ) def test_complete_overlap(preds, target, expected_giou): - torch.testing.assert_allclose(giou(preds, target), expected_giou) + torch.testing.assert_close(giou(preds, target), expected_giou) @pytest.mark.parametrize( @@ -55,7 +55,7 @@ def test_complete_overlap(preds, target, expected_giou): ], ) def test_no_overlap(preds, target, expected_giou): - torch.testing.assert_allclose(giou(preds, target), expected_giou) + torch.testing.assert_close(giou(preds, target), expected_giou) @pytest.mark.parametrize( @@ -69,4 +69,4 @@ def test_no_overlap(preds, target, expected_giou): ], ) def test_giou_multi(preds, target, expected_giou): - torch.testing.assert_allclose(giou(preds, target), expected_giou) + 
torch.testing.assert_close(giou(preds, target), expected_giou, atol=.0001, rtol=.0001) From c32979a7007eacfd0262db37b7e32f8b7018079b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Sep 2022 21:29:52 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/metrics/test_object_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py index be20bdd8d8..6eb78d7247 100644 --- a/tests/metrics/test_object_detection.py +++ b/tests/metrics/test_object_detection.py @@ -69,4 +69,4 @@ def test_no_overlap(preds, target, expected_giou): ], ) def test_giou_multi(preds, target, expected_giou): - torch.testing.assert_close(giou(preds, target), expected_giou, atol=.0001, rtol=.0001) + torch.testing.assert_close(giou(preds, target), expected_giou, atol=0.0001, rtol=0.0001) From 519c1fec194c2da654ec72bf6a0bbeba96c564c8 Mon Sep 17 00:00:00 2001 From: BaruchG Date: Thu, 22 Sep 2022 12:16:45 -0400 Subject: [PATCH 17/17] modified formatting of docstring --- pl_bolts/metrics/object_detection.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 2e6f1a92eb..55f7582d79 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -4,16 +4,20 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the intersection over union. + Args: preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + Example: + >>> import torch >>> from pl_bolts.metrics.object_detection import iou >>> preds = torch.tensor([[100, 100, 200, 200]]) >>> target = torch.tensor([[150, 150, 250, 250]]) >>> iou(preds, target) tensor([[0.1429]]) + Returns: IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes @@ -32,18 +36,23 @@ def iou(preds: Tensor, target: Tensor) -> Tensor: def giou(preds: Tensor, target: Tensor) -> Tensor: """Calculates the generalized intersection over union. + It has been proposed in `Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression `_. + Args: preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + Example: + >>> import torch >>> from pl_bolts.metrics.object_detection import giou >>> preds = torch.tensor([[100, 100, 200, 200]]) >>> target = torch.tensor([[150, 150, 250, 250]]) >>> giou(preds, target) tensor([[-0.0794]]) + Returns: GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target, where N is the number of prediction bounding boxes and M is the number of target bounding boxes
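As a minimal sketch of how the pieces of the later patches fit together — the renamed `iou`/`giou` helpers and the `torch.testing.assert_close` comparisons with explicit tolerances — a downstream check might look like the following. This assumes `pl_bolts.metrics.object_detection` exposes `iou` and `giou` exactly as shown in the diffs above; the expected values are taken from the doctest examples in those docstrings, and the `atol`/`rtol` of 1e-4 mirrors the tolerance introduced for the multi-box GIoU test in patch 15.

import torch

from pl_bolts.metrics.object_detection import giou, iou

# Both metrics are pairwise: preds is Nx4, target is Mx4, the result is NxM (here 1x1).
preds = torch.tensor([[100, 100, 200, 200]])
target = torch.tensor([[150, 150, 250, 250]])

# Values match the doctest examples: IoU 0.1429, GIoU -0.0794 for these boxes.
torch.testing.assert_close(iou(preds, target), torch.tensor([[0.1429]]), atol=1e-4, rtol=1e-4)
torch.testing.assert_close(giou(preds, target), torch.tensor([[-0.0794]]), atol=1e-4, rtol=1e-4)

Using `assert_close` with an explicit `atol`/`rtol` pair, as patch 15 does, replaces the deprecated `assert_allclose` while keeping the looser comparison needed for the rounded reference values.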