Skip to content

Commit

Permalink
Fix mypy typing for utilities.cloud_io.py (#8671)
Browse files Browse the repository at this point in the history
Co-authored-by: tchaton <thomas@grid.ai>
  • Loading branch information
stancld and tchaton committed Aug 3, 2021
1 parent 8274183 commit 08ac079
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -64,6 +64,7 @@ module = [
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
"pytorch_lightning.utilities.parsing",
Expand Down
15 changes: 10 additions & 5 deletions pytorch_lightning/utilities/cloud_io.py
Expand Up @@ -14,15 +14,20 @@

import io
from pathlib import Path
from typing import IO, Union
from typing import Any, Callable, Dict, IO, Optional, Union

import fsspec
import torch
from fsspec.implementations.local import LocalFileSystem
from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
from packaging.version import Version


def load(path_or_url: Union[str, IO, Path], map_location=None):
def load(
path_or_url: Union[str, IO, Path],
map_location: Optional[
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
] = None,
) -> Any:
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similiar
return torch.load(path_or_url, map_location=map_location)
Expand All @@ -33,7 +38,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None):
return torch.load(f, map_location=map_location)


def get_filesystem(path: Union[str, Path]):
def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
path = str(path)
if "://" in path:
# use the fileystem from the protocol specified
Expand All @@ -42,7 +47,7 @@ def get_filesystem(path: Union[str, Path]):
return LocalFileSystem()


def atomic_save(checkpoint, filepath: str):
def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
Args:
Expand Down

0 comments on commit 08ac079

Please sign in to comment.