diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index b6f66b3c..ca415ecc 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -19,11 +19,10 @@ jobs: fail-fast: false matrix: os: ["ubuntu-22.04", "macos-14", "windows-2022"] - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] exclude: - { os: "windows-2022", python-version: "3.13" } - - { os: "macos-14", python-version: "3.12" } - - { os: "macos-14", python-version: "3.13" } + - { os: "windows-2022", python-version: "3.14" } # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 60 diff --git a/setup.py b/setup.py index 5d92d9ed..e320d57d 100644 --- a/setup.py +++ b/setup.py @@ -99,5 +99,7 @@ def _prepare_extras(requirements_dir: str = _PATH_REQUIRES, skip_files: tuple = "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ], ) diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 27205cea..fa47b346 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -624,7 +624,7 @@ def __init__( self._num_samples_yielded_wrapper: dict[int, list[int]] = {} self._num_cycles: dict[int, list[int]] = {} self.rng_state: Optional[Any] = None - self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) + self._worker_idx: Optional[Any] = None # Lazily initialized in __iter__ self._worker_idx_iter: Optional[Any] = None self._latest_worker_idx = 0 self.restore = False @@ -767,6 +767,9 @@ def load_state_dict(self, obj: dict[str, Any]) -> None: # Used to restart on the next DataLoader worker from the previous run. self._latest_worker_idx = obj["latest_worker_idx"] + 1 + # Initialize _worker_idx if not already set (e.g., when loading state before first iteration) + if self._worker_idx is None: + self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) self._worker_idx_iter = iter(self._worker_idx) for _ in range(self._latest_worker_idx): next(self._worker_idx_iter)