diff --git a/pyproject.toml b/pyproject.toml index 845f811de..61127e5f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ lint.extend-select = [ "SIM", # see: https://pypi.org/project/flake8-simplify "RET", # see: https://pypi.org/project/flake8-return "PT", # see: https://pypi.org/project/flake8-pytest-style + "NPY201", # see: https://docs.astral.sh/ruff/rules/numpy2-deprecation "RUF100" # yesqa ] lint.ignore = [ diff --git a/requirements.txt b/requirements.txt index 90f04087e..b4920f3ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch filelock -numpy < 2.0.0 +numpy boto3 requests diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 6caa75201..b001507bc 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -13,6 +13,7 @@ import os from pathlib import Path +from typing import Dict import numpy as np import torch @@ -59,8 +60,27 @@ 19: torch.bool, } -_NUMPY_SCTYPES = [v for values in np.sctypes.values() for v in values] -_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)} +_NUMPY_SCTYPES = [ # All NumPy scalar types from np.core.sctypes.values() + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + np.complex64, + np.complex128, + bool, + object, + bytes, + str, + np.void, +] +_NUMPY_DTYPES_MAPPING: Dict[int, np.dtype] = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)} _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" _IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None)) diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py index 746e41c08..31789c833 100644 --- a/tests/streaming/test_serializer.py +++ b/tests/streaming/test_serializer.py @@ -204,10 +204,13 @@ def test_assert_no_header_tensor_serializer(): def test_assert_no_header_numpy_serializer(): serializer = NoHeaderNumpySerializer() - t = np.ones((10,)) + t = np.ones((10,), dtype=np.float64) assert serializer.can_serialize(t) data, name = serializer.serialize(t) - assert name == "no_header_numpy:10" + try: + assert name == "no_header_numpy:10" + except AssertionError as e: # debug what np.core.sctypes looks like on Windows + raise ValueError(np.core.sctypes) from e assert serializer._dtype is None serializer.setup(name) assert serializer._dtype == np.dtype("float64")