Skip to content

Commit

Permalink
Fix MNIST 503 error by changing URL to AWS S3 (#633)
Browse files Browse the repository at this point in the history
* Use s3 url for mnist

* flake8

* Update changelog

* formatting

* Port _compare_version from PL

* .

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests for trainer.fit returning None

* Remove unused refs

* Update tests for trainer.fit returning None

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
  • Loading branch information
5 people committed May 11, 2021
1 parent b236f21 commit 31f0f51
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed the MNIST download giving HTTP 503 ([#633](https://github.com/PyTorchLightning/lightning-bolts/pull/633))


## [0.3.3] - 2021-04-17
Expand Down
3 changes: 1 addition & 2 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Any, Callable, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST
else: # pragma: no cover
warn_missing_pkg('torchvision')
MNIST = None


class MNISTDataModule(VisionDataModule):
Expand Down
3 changes: 2 additions & 1 deletion pl_bolts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from pl_bolts.datasets.imagenet_dataset import extract_archive, parse_devkit_archive, UnlabeledImagenet
from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.datasets.mnist_dataset import BinaryMNIST, MNIST
from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin

__all__ = [
Expand All @@ -22,6 +22,7 @@
"ConcatDataset",
"DummyDataset",
"DummyDetectionDataset",
"MNIST",
"RandomDataset",
"RandomDictDataset",
"RandomDictStringDataset",
Expand Down
21 changes: 20 additions & 1 deletion pl_bolts/datasets/mnist_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -12,6 +12,25 @@
else: # pragma: no cover
warn_missing_pkg('PIL', pypi_name='Pillow')

# TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset
# from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`.
# See https://github.com/pytorch/vision/issues/3549 for details.
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1:
MNIST.resources = [
(
"https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz",
"f68b3c2dcbeaaa9fbdd348bbdeb94873"
), # noqa: E501
(
"https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz",
"d53e105ee54ea40749a09fcbcd1e9432"
), # noqa: E501
("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz",
"9fb629c4189551a2d022fa330f9573f3"), # noqa: E501
("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz",
"ec29112dd5afa0611ce80d1b7f02629c"), # noqa: E501
]


class BinaryMNIST(MNIST):

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/mnist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.datasets import MNIST
else: # pragma: no cover
warn_missing_pkg('torchvision')

Expand Down
26 changes: 26 additions & 0 deletions pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,33 @@
import importlib
import operator

import torch
from packaging.version import Version
from pkg_resources import DistributionNotFound
from pytorch_lightning.utilities import _module_available

from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore


# Ported from https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py
def _compare_version(package: str, op, version) -> bool:
"""
Compare package version with some requirements
>>> _compare_version("torch", operator.ge, "0.1")
True
"""
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
pkg_version = Version(pkg.__version__)
except TypeError:
# this is mock by sphinx, so it shall return True ro generate all summaries
return True
return op(pkg_version, Version(version))


_NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

_TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
Expand All @@ -12,5 +37,6 @@
_OPENCV_AVAILABLE: bool = _module_available("cv2")
_WANDB_AVAILABLE: bool = _module_available("wandb")
_MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
_TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.ge, "0.9.1")

__all__ = ["BatchGradientVerification"]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ torch>=1.6
torchmetrics>=0.2.0
pytorch-lightning>=1.1.1
dataclasses ; python_version <= "3.6"
packaging

0 comments on commit 31f0f51

Please sign in to comment.