Fix fsspec local file protocol checks for new fsspec version (#19023)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit b8a96fe)
awaelchli authored and Borda committed Dec 19, 2023
1 parent 3d060bd commit 285e784
Showing 9 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion requirements/app/app.txt
@@ -3,7 +3,7 @@ packaging
typing-extensions >=4.0.0, <4.8.0
deepdiff >=5.7.0, <6.6.0
starsessions >=1.2.1, <2.0 # strict
-fsspec >=2022.5.0, <2023.10.0
+fsspec >=2022.5.0, <2023.11.0
croniter >=1.3.0, <1.5.0 # strict; TODO: for now until we find something more robust.
traitlets >=5.3.0, <5.10.0
arrow >=1.2.0, <1.3.0
2 changes: 1 addition & 1 deletion requirements/data/cloud.txt
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

-fsspec[http] >2021.06.0, <2023.10.0
+fsspec[http] >2021.06.0, <2023.11.0
s3fs >=2022.5.0, <2023.7.0
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
@@ -3,7 +3,7 @@

numpy >=1.17.2, <1.27.0
torch >=1.12.0, <2.2.0
-fsspec[http]>2021.06.0, <2023.10.0
+fsspec[http]>2021.06.0, <2023.11.0
packaging >=20.0, <=23.1
typing-extensions >=4.0.0, <4.8.0
lightning-utilities >=0.8.0, <0.10.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
@@ -5,7 +5,7 @@ numpy >=1.17.2, <1.27.0
torch >=1.12.0, <2.2.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
-fsspec[http] >2021.06.0, <2023.10.0
+fsspec[http] >2021.06.0, <2023.11.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.0.0, <4.8.0
5 changes: 5 additions & 0 deletions src/lightning/fabric/utilities/cloud_io.py
@@ -18,6 +18,7 @@
from typing import IO, Any, Dict, Union

import fsspec
+import fsspec.utils
import torch
from fsspec.core import url_to_fs
from fsspec.implementations.local import AbstractFileSystem
@@ -128,3 +129,7 @@ def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False
return not fs.isfile(path)

return fs.isdir(path)
+
+
+def _is_local_file_protocol(path: _PATH) -> bool:
+    return fsspec.utils.get_protocol(str(path)) == "file"
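
For context: the helper added above replaces equality checks against a filesystem object's `protocol` attribute. The apparent reason is that in fsspec 2023.10.0 the local filesystem started reporting its protocol as a tuple (something like `("file", "local")`) rather than the plain string `"file"`, so `fs.protocol == "file"` stopped matching. A minimal sketch of the difference, under that assumption:

```python
# Minimal sketch (assumes fsspec >= 2023.10.0 is installed): the old checks
# compared the filesystem object's protocol attribute, which is now a tuple,
# while the new helper parses the protocol from the path string itself.
import fsspec
import fsspec.utils

fs = fsspec.filesystem("file")
print(fs.protocol)            # typically ("file", "local") on recent fsspec
print(fs.protocol == "file")  # False -> the old check silently broke

# _is_local_file_protocol takes the path-based route instead:
print(fsspec.utils.get_protocol("checkpoints/last.ckpt"))  # "file"
print(fsspec.utils.get_protocol("s3://bucket/last.ckpt"))  # "s3"
```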
4 changes: 2 additions & 2 deletions src/lightning/fabric/utilities/distributed.py
@@ -6,13 +6,13 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union

-import fsspec.utils
import torch
import torch.nn.functional as F
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler

+from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp
@@ -48,7 +48,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim
"""
# Fast path: Any non-local filesystem is considered shared (e.g., S3)
-if path is not None and fsspec.utils.get_protocol(str(path)) != "file":
+if path is not None and not _is_local_file_protocol(path):
return True

path = Path(Path.cwd() if path is None else path).resolve()
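
The distributed.py change above only swaps the fast-path check over to the shared helper. An illustrative sketch of that fast path alone (not Lightning's full `is_shared_filesystem`, which performs further checks):

```python
# Illustrative sketch of the fast path only: a path whose protocol is not
# "file" (e.g. s3:// or gs://) is assumed to live on storage all ranks can see.
from fsspec.utils import get_protocol


def _looks_shared(path: str) -> bool:  # hypothetical helper, for illustration
    return get_protocol(str(path)) != "file"


print(_looks_shared("s3://bucket/checkpoints"))  # True
print(_looks_shared("/tmp/checkpoints"))         # False
```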
4 changes: 3 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -14,7 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
+- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))



## [2.1.2] - 2023-11-15
@@ -25,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))
- Fixed an issue preventing the user to `Trainer.save_checkpoint()` an FSDP model when `Trainer.test/validate/predict()` ran after `Trainer.fit()` ([#18992](https://github.com/Lightning-AI/lightning/issues/18992))


## [2.1.1] - 2023-11-06

### Fixed
8 changes: 4 additions & 4 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -34,7 +34,7 @@
from torch import Tensor

import lightning.pytorch as pl
-from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
+from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -457,7 +457,7 @@ def __validate_init_configuration(self) -> None:
def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
self._fs = get_filesystem(dirpath if dirpath else "")

-if dirpath and self._fs.protocol == "file":
+if dirpath and _is_local_file_protocol(dirpath if dirpath else ""):
dirpath = os.path.realpath(dirpath)

self.dirpath = dirpath
@@ -675,7 +675,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[

# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
-if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_top_k != 0:
+if _is_local_file_protocol(filepath) and self._last_checkpoint_saved and self.save_top_k != 0:
self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
self._save_checkpoint(trainer, filepath)
@@ -771,7 +771,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
"""
if previous == current:
return False
-if self._fs.protocol != "file":
+if not _is_local_file_protocol(previous):
return True
previous = Path(previous).absolute()
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None
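
In `ModelCheckpoint`, the checks now look at the individual path rather than `self._fs.protocol`; for example, a "last" checkpoint is only hard-linked to an existing save when its own path is local. A simplified sketch of that decision, not the callback's actual code:

```python
# Simplified sketch (not ModelCheckpoint's actual code): decide per path
# whether the "last" checkpoint can reuse an existing save via a link or must
# be written out again, based on the path's protocol.
from fsspec.utils import get_protocol


def last_ckpt_strategy(filepath: str, previously_saved: bool) -> str:  # hypothetical helper
    if previously_saved and get_protocol(str(filepath)) == "file":
        return "link"  # local path: point "last" at the checkpoint already on disk
    return "save"      # remote path (e.g. s3://...) or nothing saved yet: write a copy


print(last_ckpt_strategy("checkpoints/last.ckpt", previously_saved=True))  # "link"
print(last_ckpt_strategy("s3://bucket/last.ckpt", previously_saved=True))  # "save"
```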
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/trainer.py
@@ -33,7 +33,7 @@

import lightning.pytorch as pl
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
-from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import Accelerator
@@ -1286,7 +1286,7 @@ def default_root_dir(self) -> str:
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
-if get_filesystem(self._default_root_dir).protocol == "file":
+if _is_local_file_protocol(self._default_root_dir):
return os.path.normpath(self._default_root_dir)
return self._default_root_dir

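
Likewise, `default_root_dir` above only normalizes the path when it is local; a remote URL is returned as-is so its scheme survives. A small sketch under that assumption:

```python
# Small sketch of the behavior above: os.path.normpath is only applied to
# local paths; a remote URL like "s3://bucket/logs" is returned unchanged,
# since normpath would collapse "s3://" into "s3:/".
import os

from fsspec.utils import get_protocol


def _normalized_root(root: str) -> str:  # hypothetical helper, for illustration
    if get_protocol(root) == "file":
        return os.path.normpath(root)
    return root


print(_normalized_root("./lightning_logs/../lightning_logs"))  # "lightning_logs"
print(_normalized_root("s3://bucket/logs"))                    # "s3://bucket/logs"
```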
