Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,25 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
return item_sizes


def _to_path(element: str) -> str:
    """Normalize a path-like string for use as an indexed path.

    When running inside a Studio and ``element`` starts with ``/teamspace``,
    the string is returned untouched. Otherwise it is converted to its
    resolved form via ``Path.resolve()``.
    """
    if _IS_IN_STUDIO and element.startswith("/teamspace"):
        return element
    return str(Path(element).resolve())


def _is_path(input_dir: Optional[str], element: Any) -> bool:
if not isinstance(element, str):
return False

if _IS_IN_STUDIO and input_dir is not None:
if element.startswith(input_dir):
return True

element = str(Path(element).absolute())
if element.startswith(input_dir):
return True

return os.path.exists(element)


class BaseWorker:
def __init__(
self,
Expand Down Expand Up @@ -381,7 +400,6 @@ def __init__(
self.remove_queue: Queue = Queue()
self.progress_queue: Queue = progress_queue
self.error_queue: Queue = error_queue
self._collected_items = 0
self._counter = 0
self._last_time = time()
self._index_counter = 0
Expand Down Expand Up @@ -504,22 +522,13 @@ def _collect_paths(self) -> None:
for item in self.items:
flattened_item, spec = tree_flatten(item)

def is_path(element: Any) -> bool:
if not isinstance(element, str):
return False

element: str = str(Path(element).resolve())
if _IS_IN_STUDIO and self.input_dir.path is not None:
if self.input_dir.path.startswith("/teamspace/studios/this_studio"):
return os.path.exists(element)
return element.startswith(self.input_dir.path)
return os.path.exists(element)

# For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
# Other alternative would be too slow.
# TODO: Try using dictionary for higher accuracy.
indexed_paths = {
index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
index: _to_path(element)
for index, element in enumerate(flattened_item)
if _is_path(self.input_dir.path, element)
}

if len(indexed_paths) == 0:
Expand All @@ -537,7 +546,6 @@ def is_path(element: Any) -> bool:
self.paths.append(paths)

items.append(tree_unflatten(flattened_item, spec))
self._collected_items += 1

self.items = items

Expand Down
23 changes: 23 additions & 0 deletions tests/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
DataTransformRecipe,
_download_data_target,
_get_item_filesizes,
_is_path,
_map_items_to_workers_sequentially,
_map_items_to_workers_weighted,
_remove_target,
_to_path,
_upload_fn,
_wait_for_disk_usage_higher_than_threshold,
_wait_for_file_to_exist,
Expand Down Expand Up @@ -1136,3 +1138,24 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
tensor = torchaudio.load(sample)
assert tensor[0].shape == torch.Size([1, 16000])
assert tensor[1] == 16000


def test_is_path_valid_in_studio(monkeypatch, tmpdir):
    """With the Studio flag forced on, `_is_path` accepts both a string under
    the input dir and an existing file located elsewhere."""
    studio_root = "/teamspace/studios/this_studio"

    # Create a real file outside the studio root.
    existing_file = os.path.join(tmpdir, "a.png")
    with open(existing_file, "w") as handle:
        handle.write("Hello World")

    monkeypatch.setattr(data_processor_module, "_IS_IN_STUDIO", True)

    # Accepted through the input-dir prefix match.
    assert _is_path(studio_root, "/teamspace/studios/this_studio/a.png")
    # Accepted because the file exists on disk.
    assert _is_path(studio_root, existing_file)


@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
def test_to_path(tmpdir):
    """`_to_path` must return both a `/teamspace` path and a temporary file
    path unchanged."""
    local_file = os.path.join(tmpdir, "a.png")
    with open(local_file, "w") as handle:
        handle.write("Hello World")

    studio_file = "/teamspace/studios/this_studio/a.png"

    converted_studio = _to_path("/teamspace/studios/this_studio/a.png")
    assert converted_studio == studio_file

    converted_local = _to_path(local_file)
    assert converted_local == local_file