Prevent dataset creation from breaking if it already exists (#19491)
tchaton committed Feb 16, 2024
1 parent ddf2ac4 commit 53ea76a
Showing 2 changed files with 71 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/lightning/data/processing/data_processor.py
@@ -29,6 +29,7 @@
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.processing.readers import BaseReader
from lightning.data.processing.utilities import _create_dataset
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
@@ -41,7 +42,6 @@

if _LIGHTNING_CLOUD_LATEST:
from lightning_cloud.openapi import V1DatasetType
from lightning_cloud.utils.dataset import _create_dataset


if _BOTO3_AVAILABLE:
@@ -973,7 +973,8 @@ def run(self, data_recipe: DataRecipe) -> None:
print("Workers are finished.")
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

if num_nodes == node_rank + 1 and self.output_dir.url:
if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO:
assert self.output_dir.path
_create_dataset(
input_dir=self.input_dir.path,
storage_dir=self.output_dir.path,
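For readers skimming the hunk above: dataset registration now also requires `_IS_IN_STUDIO`, in addition to running on the last node with an output URL, and `output_dir.path` is asserted before it is passed to `_create_dataset`. Below is a minimal sketch of that guard in isolation; the `Dir` dataclass and `should_register_dataset` are illustrative stand-ins, not the project's API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Dir:
    # Stand-in mirroring lightning.data.streaming.cache.Dir: a local path plus an optional remote URL.
    path: Optional[str] = None
    url: Optional[str] = None


def should_register_dataset(num_nodes: int, node_rank: int, output_dir: Dir, is_in_studio: bool) -> bool:
    # Mirror of the new condition: only the last node, only with a remote output, only inside a Studio.
    if num_nodes == node_rank + 1 and output_dir.url and is_in_studio:
        assert output_dir.path  # the diff asserts the path exists before calling _create_dataset
        return True
    return False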
70 changes: 68 additions & 2 deletions src/lightning/data/processing/utilities.py
@@ -3,9 +3,75 @@
import urllib
from contextlib import contextmanager
from subprocess import Popen
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union

from lightning.data.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST

if _LIGHTNING_CLOUD_LATEST:
from lightning_cloud.openapi import (
ProjectIdDatasetsBody,
V1DatasetType,
)
from lightning_cloud.openapi.rest import ApiException
from lightning_cloud.rest_client import LightningClient


def _create_dataset(
input_dir: Optional[str],
storage_dir: str,
dataset_type: V1DatasetType,
empty: Optional[bool] = None,
size: Optional[int] = None,
num_bytes: Optional[str] = None,
data_format: Optional[Union[str, Tuple[str]]] = None,
compression: Optional[str] = None,
num_chunks: Optional[int] = None,
num_bytes_per_chunk: Optional[List[int]] = None,
name: Optional[str] = None,
version: Optional[int] = None,
) -> None:
"""Create a dataset with metadata information about its source and destination."""
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
user_id = os.getenv("LIGHTNING_USER_ID", None)
cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)

if project_id is None:
return

if not storage_dir:
raise ValueError("The storage_dir should be defined.")

client = LightningClient(retry=False)

try:
client.dataset_service_create_dataset(
body=ProjectIdDatasetsBody(
cloud_space_id=cloud_space_id if lightning_app_id is None else None,
cluster_id=cluster_id,
creator_id=user_id,
empty=empty,
input_dir=input_dir,
lightning_app_id=lightning_app_id,
name=name,
size=size,
num_bytes=num_bytes,
data_format=str(data_format) if data_format else data_format,
compression=compression,
num_chunks=num_chunks,
num_bytes_per_chunk=num_bytes_per_chunk,
storage_dir=storage_dir,
type=dataset_type,
version=version,
),
project_id=project_id,
)
except ApiException as ex:
if "already exists" in str(ex.body):
pass
else:
raise ex


def get_worker_rank() -> Optional[str]:
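The heart of the fix is the except-block at the end of `_create_dataset`: an "already exists" failure from the dataset service is swallowed so that re-running a processing job does not crash, while any other API error still propagates. A self-contained sketch of that idempotent-create pattern follows; `ApiError` and `create_fn` are hypothetical stand-ins for the lightning_cloud `ApiException` and the `dataset_service_create_dataset` call.

from typing import Any, Callable


class ApiError(Exception):
    # Hypothetical stand-in for lightning_cloud.openapi.rest.ApiException.
    def __init__(self, body: str) -> None:
        super().__init__(body)
        self.body = body


def create_or_ignore_existing(create_fn: Callable[..., Any], **kwargs: Any) -> None:
    # Call the creation endpoint and treat "already exists" as success.
    try:
        create_fn(**kwargs)
    except ApiError as ex:
        if "already exists" in str(ex.body):
            return  # the dataset was registered by a previous run; nothing to do
        raise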
