Merged
18 changes: 0 additions & 18 deletions README.md
@@ -341,21 +341,11 @@ storage_options = {

dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)

# s5cmd compatible storage options for a custom S3-compatible endpoint
# Note: If s5cmd is installed, it will be used by default for S3 operations. If you prefer not to use s5cmd, you can disable it by setting the environment variable: `DISABLE_S5CMD=1`
storage_options = {
"AWS_ACCESS_KEY_ID": "your_access_key_id",
"AWS_SECRET_ACCESS_KEY": "your_secret_access_key",
"S3_ENDPOINT_URL": "your_endpoint_url", # Required only for custom endpoints
}


dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
```

Alternative: Using `s5cmd` for S3 Operations


Also, you can specify a custom cache directory when initializing your dataset. This is useful when you want to store the cache in a specific location.
```python
from litdata import StreamingDataset
@@ -543,21 +533,13 @@ aws_storage_options={
}
dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)


# Read data from AWS S3 using s5cmd
# Note: If s5cmd is installed, it will be used by default for S3 operations. If you prefer not to use s5cmd, you can disable it by setting the environment variable: `DISABLE_S5CMD=1`
aws_storage_options={
"AWS_ACCESS_KEY_ID": os.environ['AWS_ACCESS_KEY_ID'],
"AWS_SECRET_ACCESS_KEY": os.environ['AWS_SECRET_ACCESS_KEY'],
"S3_ENDPOINT_URL": os.environ['AWS_ENDPOINT_URL'], # Required only for custom endpoints
}
dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)

# Read Data from AWS S3 with Unsigned Request using s5cmd
aws_storage_options={
"AWS_NO_SIGN_REQUEST": "Yes" # Required for unsigned requests
"S3_ENDPOINT_URL": os.environ['AWS_ENDPOINT_URL'], # Required only for custom endpoints
}
dataset = ld.StreamingDataset("s3://my-bucket/my-data", storage_options=aws_storage_options)


30 changes: 17 additions & 13 deletions pyproject.toml
@@ -25,6 +25,7 @@ line-length = 120
exclude = [
".git",
"docs",
"src/litdata/debugger.py",
"src/litdata/utilities/_pytree.py",
]
# Enable Pyflakes `E` and `F` codes by default.
@@ -65,20 +66,22 @@ lint.per-file-ignores."examples/**" = [
]
lint.per-file-ignores."setup.py" = [ "D100", "SIM115" ]
lint.per-file-ignores."src/**" = [
"D100", # Missing docstring in public module
"D101", # todo: Missing docstring in public class
"D102", # todo: Missing docstring in public method
"D103", # todo: Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # todo: Missing docstring in magic method
"D107", # todo: Missing docstring in __init__
"D205", # todo: 1 blank line required between summary line and description
"D100", # Missing docstring in public module
"D101", # todo: Missing docstring in public class
"D102", # todo: Missing docstring in public method
"D103", # todo: Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # todo: Missing docstring in magic method
"D107", # todo: Missing docstring in __init__
"D205", # todo: 1 blank line required between summary line and description
"D401",
"D404", # todo: First line should be in imperative mood; try rephrasing
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected.
"S602", # todo: `subprocess` call with `shell=True` identified, security issue
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"D404", # todo: First line should be in imperative mood; try rephrasing
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected.
"S602", # todo: `subprocess` call with `shell=True` identified, security issue
"S605", # todo: Starting a process with a shell: seems safe, but may be changed in the future; consider rewriting without `shell`
"S607", # todo: Starting a process with a partial executable path
"UP006", # UP006 Use `list` instead of `List` for type annotation
"UP035", # UP035 `typing.Tuple` is deprecated, use `tuple` instead
]
lint.per-file-ignores."tests/**" = [
"D100",
@@ -166,6 +169,7 @@ exclude = [
"src/litdata/imports.py",
"src/litdata/imports.py",
"src/litdata/processing/data_processor.py",
"src/litdata/debugger.py",
]
install_types = "True"
non_interactive = "True"
1 change: 0 additions & 1 deletion requirements/test.txt
@@ -16,5 +16,4 @@ polars >1.0.0
lightning
transformers <4.53.0
zstd
s5cmd >=0.2.0
soundfile >=0.13.0 # required for torchaudio backend
1 change: 0 additions & 1 deletion src/litdata/constants.py
@@ -55,7 +55,6 @@

_MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
_FORCE_DOWNLOAD_TIME = int(os.getenv("FORCE_DOWNLOAD_TIME", "30"))
_DISABLE_S5CMD = bool(int(os.getenv("DISABLE_S5CMD", "0")))

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
168 changes: 114 additions & 54 deletions src/litdata/debugger.py
@@ -1,6 +1,6 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
@@ -13,92 +13,148 @@

import logging
import os
import sys
import re
import threading
import time
from functools import lru_cache

from litdata.constants import _PRINT_DEBUG_LOGS
from litdata.utilities.env import _DistributedEnv, _WorkerEnv
from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv

# Create the root logger for the library
root_logger = logging.getLogger("litdata")

class TimedFlushFileHandler(logging.FileHandler):
"""FileHandler that flushes every N seconds in a background thread."""

def __init__(self, filename, mode="a", flush_interval=2):
super().__init__(filename, mode)
self.flush_interval = flush_interval
self._stop_event = threading.Event()
t = threading.Thread(target=self._flusher, daemon=True, name="TimedFlushFileHandler._flusher")
t.start()

def _flusher(self):
while not self._stop_event.is_set():
time.sleep(self.flush_interval)
self.flush()

def close(self):
self._stop_event.set()
self.flush()
super().close()


class EnvConfigFilter(logging.Filter):
"""A logging filter that reads its configuration from environment variables."""

def __init__(self):
super().__init__()
self.name_re = re.compile(r"name:\s*([^;]+);")

def _get_name_from_msg(self, msg):
match = self.name_re.search(msg)
return match.group(1).strip() if match else None

def filter(self, record):
"""Determine if a log record should be processed by checking env vars."""
is_iterating_dataset_enabled = os.getenv("LITDATA_LOG_ITERATING_DATASET", "True").lower() == "true"
is_getitem_enabled = os.getenv("LITDATA_LOG_GETITEM", "True").lower() == "true"
is_item_loader_enabled = os.getenv("LITDATA_LOG_ITEM_LOADER", "True").lower() == "true"

log_name = self._get_name_from_msg(record.getMessage())

if log_name:
if not is_iterating_dataset_enabled and log_name.startswith("iterating_dataset"):
return False
if not is_getitem_enabled and log_name.startswith("getitem_dataset_for_chunk_index"):
return False
if not is_item_loader_enabled and log_name.startswith("item_loader"):
return False

return True
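
# Example (sketch): the filter above is configured purely through environment
# variables, so a record category can be muted before a run without touching code.
# The variable names match the ones read in `filter()`; any value other than
# "true" (case-insensitive) disables that category.
#
#   os.environ["LITDATA_LOG_ITEM_LOADER"] = "False"   # drop records named "item_loader*"
#   os.environ["LITDATA_LOG_GETITEM"] = "True"        # keep "getitem_dataset_for_chunk_index*"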


def get_logger_level(level: str) -> int:
"""Get the log level from the level string."""
level = level.upper()
if level in logging._nameToLevel:
return logging._nameToLevel[level]
raise ValueError(f"Invalid log level: {level}. Valid levels: {list(logging._nameToLevel.keys())}.")
raise ValueError(f"Invalid log level: {level}")


class LitDataLogger:
def __init__(self, name: str):
_instance = None
_lock = threading.Lock()

def __new__(cls, *args, **kwargs):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self, name="litdata", flush_interval=2):
if hasattr(self, "logger"):
return # Already initialized

self.logger = logging.getLogger(name)
self.logger.propagate = False
self.log_file, self.log_level = self.get_log_file_and_level()
self.setup_logger()
self.flush_interval = flush_interval
self._setup_logger()

@staticmethod
def get_log_file_and_level() -> tuple[str, int]:
def get_log_file_and_level():
log_file = os.getenv("LITDATA_LOG_FILE", "litdata_debug.log")
log_lvl = os.getenv("LITDATA_LOG_LEVEL", "DEBUG")
return log_file, get_logger_level(log_lvl)

log_lvl = get_logger_level(log_lvl)

return log_file, log_lvl

def setup_logger(self) -> None:
"""Configures logging by adding handlers and formatting."""
if len(self.logger.handlers) > 0: # Avoid duplicate handlers
def _setup_logger(self):
if self.logger.handlers:
return

self.logger.setLevel(self.log_level)
formatter = logging.Formatter("ts:%(created)s;PID:%(process)d; TID:%(thread)d; %(message)s")
handler = TimedFlushFileHandler(self.log_file, flush_interval=self.flush_interval)
handler.setFormatter(formatter)
handler.setLevel(self.log_level)
self.logger.addHandler(handler)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(self.log_level)

# File handler
file_handler = logging.FileHandler(self.log_file)
file_handler.setLevel(self.log_level)
self.logger.filters = [f for f in self.logger.filters if not isinstance(f, EnvConfigFilter)]
self.logger.addFilter(EnvConfigFilter())

# Log format
formatter = logging.Formatter(
"ts:%(created)s; logger_name:%(name)s; level:%(levelname)s; PID:%(process)d; TID:%(thread)d; %(message)s"
)
# ENV - f"{WORLD_SIZE, GLOBAL_RANK, NNODES, LOCAL_RANK, NODE_RANK}"
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
def get_logger(self):
return self.logger

# Attach handlers
if _PRINT_DEBUG_LOGS:
self.logger.addHandler(console_handler)
self.logger.addHandler(file_handler)


def enable_tracer() -> None:
def enable_tracer(
flush_interval: int = 5, item_loader=True, iterating_dataset=True, getitem_dataset_for_chunk_index=True
) -> logging.Logger:
"""Convenience function to enable and configure litdata logging.
This function SETS the environment variables that control the logging behavior.
"""
os.environ["LITDATA_LOG_FILE"] = "litdata_debug.log"
LitDataLogger("litdata")
os.environ["LITDATA_LOG_ITEM_LOADER"] = str(item_loader)
os.environ["LITDATA_LOG_ITERATING_DATASET"] = str(iterating_dataset)
os.environ["LITDATA_LOG_GETITEM"] = str(getitem_dataset_for_chunk_index)

master_logger = LitDataLogger(flush_interval=flush_interval).get_logger()
return master_logger
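
# Example (sketch): a typical call site, assuming `enable_tracer` is imported from
# `litdata.debugger`. Passing `item_loader=False` writes "False" into
# LITDATA_LOG_ITEM_LOADER, which EnvConfigFilter then treats as disabled.
#
#   logger = enable_tracer(flush_interval=2, item_loader=False)
#   logger.debug("tracer enabled")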


def _get_log_msg(data: dict) -> str:
log_msg = ""

if "name" not in data or "ph" not in data:
raise ValueError(f"Missing required keys in data dictionary. Required keys: 'name', 'ph'. Received: {data}")

env_info_data = env_info()
data.update(env_info_data)

for key, value in data.items():
log_msg += f"{key}: {value};"
return log_msg
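
# Example (sketch): callers pass at least "name" and "ph" (a Chrome-trace phase,
# "B" to begin a span and "E" to end it); the environment info from `env_info()`
# is appended to the message, e.g.
#
#   _get_log_msg({"name": "decompress", "ph": "B"})
#   # -> "name: decompress;ph: B;dist_world_size: 1;dist_global_rank: 0;..."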


@lru_cache(maxsize=1)
def env_info() -> dict:
dist_env = _DistributedEnv.detect()
worker_env = _WorkerEnv.detect() # will all threads read the same value if decorate this function with `@cache`
if _is_in_dataloader_worker():
return _cached_env_info()

dist_env = _DistributedEnv.detect()
worker_env = _WorkerEnv.detect()
return {
"dist_world_size": dist_env.world_size,
"dist_global_rank": dist_env.global_rank,
@@ -108,16 +164,20 @@ def env_info() -> dict:
}


# -> Chrome tracing colors
# url: https://chromium.googlesource.com/external/trace-viewer/+/bf55211014397cf0ebcd9e7090de1c4f84fc3ac0/tracing/tracing/ui/base/color_scheme.html

# # ------
@lru_cache(maxsize=1)
def _cached_env_info() -> dict:
dist_env = _DistributedEnv.detect()
worker_env = _WorkerEnv.detect()
return {
"dist_world_size": dist_env.world_size,
"dist_global_rank": dist_env.global_rank,
"dist_num_nodes": dist_env.num_nodes,
"worker_world_size": worker_env.world_size,
"worker_rank": worker_env.rank,
}


# thread_state_iowait: {r: 182, g: 125, b: 143},
# thread_state_running: {r: 126, g: 200, b: 148},
# thread_state_runnable: {r: 133, g: 160, b: 210},
# ....
# Chrome trace colors
class ChromeTraceColors:
PINK = "thread_state_iowait"
GREEN = "thread_state_running"
4 changes: 2 additions & 2 deletions src/litdata/streaming/compression.py
@@ -62,9 +62,9 @@ def compress(self, data: bytes) -> bytes:
def decompress(self, data: bytes) -> bytes:
import zstd

logger.debug(_get_log_msg({"name": "Decompressing data", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
logger.debug(_get_log_msg({"name": "decompress", "ph": "B", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
decompressed_data = zstd.decompress(data)
logger.debug(_get_log_msg({"name": "Decompressed data", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
logger.debug(_get_log_msg({"name": "decompress", "ph": "E", "cname": ChromeTraceColors.MUSTARD_YELLOW}))
return decompressed_data

@classmethod