Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fsspec local file protocol checks for new fsspec version #19023

Merged
Merged 12 commits into the base branch on Nov 18, 2023
Merged
2 changes: 1 addition & 1 deletion requirements/app/app.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ packaging
typing-extensions >=4.4.0, <4.8.0
deepdiff >=5.7.0, <6.6.0
starsessions >=1.2.1, <2.0 # strict
fsspec >=2022.5.0, <2023.10.0
fsspec >=2022.5.0, <2023.11.0
croniter >=1.3.0, <1.5.0 # strict; TODO: for now until we find something more robust.
traitlets >=5.3.0, <5.10.0
arrow >=1.2.0, <1.3.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/data/cloud.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

fsspec[http] >2021.06.0, <2023.10.0
fsspec[http] >2021.06.0, <2023.11.0
s3fs >=2022.5.0, <2023.7.0
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

numpy >=1.17.2, <1.27.0
torch >=1.12.0, <2.2.0
fsspec[http]>2021.06.0, <2023.10.0
fsspec[http]>2021.06.0, <2023.11.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.8.0
lightning-utilities >=0.8.0, <0.10.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ numpy >=1.17.2, <1.27.0
torch >=1.12.0, <2.2.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >2021.06.0, <2023.10.0
fsspec[http] >2021.06.0, <2023.11.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.8.0
Expand Down
5 changes: 5 additions & 0 deletions src/lightning/fabric/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import IO, Any, Dict, Union

import fsspec
import fsspec.utils
import torch
from fsspec.core import url_to_fs
from fsspec.implementations.local import AbstractFileSystem
Expand Down Expand Up @@ -128,3 +129,7 @@ def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False
return not fs.isfile(path)

return fs.isdir(path)


def _is_local_file_protocol(path: _PATH) -> bool:
    """Return ``True`` if ``path`` uses the local ``file`` protocol.

    The path is stringified and its protocol resolved via
    ``fsspec.utils.get_protocol`` (e.g. ``"s3://bucket/x"`` -> ``"s3"``,
    a plain ``"/tmp/x"`` -> ``"file"``), so both ``str`` and ``Path``
    inputs are handled uniformly.
    """
    protocol = fsspec.utils.get_protocol(str(path))
    return protocol == "file"
4 changes: 2 additions & 2 deletions src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union

import fsspec.utils
import torch
import torch.nn.functional as F
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler

from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp
Expand Down Expand Up @@ -48,7 +48,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim

"""
# Fast path: Any non-local filesystem is considered shared (e.g., S3)
if path is not None and fsspec.utils.get_protocol(str(path)) != "file":
if path is not None and not _is_local_file_protocol(path):
return True

path = Path(Path.cwd() if path is None else path).resolve()
Expand Down
4 changes: 3 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))



## [2.1.2] - 2023-11-15
Expand All @@ -52,6 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))
- Fixed an issue preventing the user to `Trainer.save_checkpoint()` an FSDP model when `Trainer.test/validate/predict()` ran after `Trainer.fit()` ([#18992](https://github.com/Lightning-AI/lightning/issues/18992))


## [2.1.1] - 2023-11-06

### Fixed
Expand Down
8 changes: 4 additions & 4 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import Checkpoint
from lightning.pytorch.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -457,7 +457,7 @@ def __validate_init_configuration(self) -> None:
def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
self._fs = get_filesystem(dirpath if dirpath else "")

if dirpath and self._fs.protocol == "file":
if dirpath and _is_local_file_protocol(dirpath if dirpath else ""):
dirpath = os.path.realpath(dirpath)

self.dirpath = dirpath
Expand Down Expand Up @@ -675,7 +675,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[

# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_top_k != 0:
if _is_local_file_protocol(filepath) and self._last_checkpoint_saved and self.save_top_k != 0:
self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
else:
self._save_checkpoint(trainer, filepath)
Expand Down Expand Up @@ -771,7 +771,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
"""
if previous == current:
return False
if self._fs.protocol != "file":
if not _is_local_file_protocol(previous):
return True
previous = Path(previous).absolute()
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import lightning.pytorch as pl
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import Accelerator
Expand Down Expand Up @@ -1285,7 +1285,7 @@ def default_root_dir(self) -> str:
It is used as a fallback if logger or checkpoint callback do not define specific save paths.

"""
if get_filesystem(self._default_root_dir).protocol == "file":
if _is_local_file_protocol(self._default_root_dir):
return os.path.normpath(self._default_root_dir)
return self._default_root_dir

Expand Down