diff --git a/src/litdata/__about__.py b/src/litdata/__about__.py index f1cdb23b2..b6bb069f9 100644 --- a/src/litdata/__about__.py +++ b/src/litdata/__about__.py @@ -14,7 +14,7 @@ import time -__version__ = "0.2.30" +__version__ = "0.2.31" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 98ce5fefe..528c9bfd0 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -346,6 +346,7 @@ def _execute( num_nodes: int, machine: Optional["Machine"] = None, command: Optional[str] = None, + interruptible: bool = False, ) -> None: """Remotely execute the current operator.""" if not _LIGHTNING_SDK_AVAILABLE: @@ -370,6 +371,7 @@ def _execute( teamspace_id=studio._teamspace.id, cluster_id=studio._studio.cluster_id, machine=machine or studio._studio_api.get_machine(studio._studio.id, studio._teamspace.id), + interruptible=interruptible, ) has_printed = False diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 90729ffbd..8c774c235 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -291,6 +291,7 @@ def print_fn(msg, file=None): "teamspace_id": "teamspace_id", "cluster_id": "cluster_id", "machine": "cpu", + "interruptible": False, } generated_kwargs = (