Skip to content

Commit

Permalink
moved some code out of dataset class
Browse files Browse the repository at this point in the history
  • Loading branch information
AbhinavTuli committed Mar 17, 2021
1 parent 7fffed2 commit f7c3e77
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 28 deletions.
32 changes: 4 additions & 28 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
slice_split,
str_to_int,
_copy_helper,
_get_compressor,
_get_dynamic_tensor_dtype,
)

import hub.schema.serialize
Expand Down Expand Up @@ -67,8 +69,6 @@
from hub.store.metastore import MetaStorage
from hub.client.hub_control import HubControlClient
from hub.schema import Audio, BBox, ClassLabel, Image, Sequence, Text, Video
from hub.numcodecs import PngCodec

from hub.utils import norm_cache, norm_shape, _tuple_product
from hub import defaults
import pickle
Expand Down Expand Up @@ -468,30 +468,6 @@ def _check_and_prepare_dir(self):
raise NotHubDatasetToAppendException()
return True

def _get_dynamic_tensor_dtype(self, t_dtype):
if isinstance(t_dtype, Primitive):
return t_dtype.dtype
elif isinstance(t_dtype.dtype, Primitive):
return t_dtype.dtype.dtype
else:
return "object"

def _get_compressor(self, compressor: str):
if compressor is None:
return None
elif compressor.lower() == "lz4":
return numcodecs.LZ4(numcodecs.lz4.DEFAULT_ACCELERATION)
elif compressor.lower() == "zstd":
return numcodecs.Zstd(numcodecs.zstd.DEFAULT_CLEVEL)
elif compressor.lower() == "default":
return "default"
elif compressor.lower() == "png":
return PngCodec(solo_channel=True)
else:
raise ValueError(
f"Wrong compressor: {compressor}, only LZ4 and ZSTD are supported"
)

def _generate_storage_tensors(self):
for t in self._flat_tensors:
t_dtype, t_path = t
Expand All @@ -513,9 +489,9 @@ def _generate_storage_tensors(self):
mode=self._mode,
shape=self._shape + t_dtype.shape,
max_shape=self._shape + t_dtype.max_shape,
dtype=self._get_dynamic_tensor_dtype(t_dtype),
dtype=_get_dynamic_tensor_dtype(t_dtype),
chunks=t_dtype.chunks,
compressor=self._get_compressor(t_dtype.compressor),
compressor=_get_compressor(t_dtype.compressor),
)

def _open_storage_tensors(self):
Expand Down
31 changes: 31 additions & 0 deletions hub/api/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from hub.exceptions import ModuleNotInstalledException, DirectoryNotEmptyException
import hashlib
import time
import numcodecs
import numcodecs.lz4
import numcodecs.zstd
from hub.schema.features import Primitive
from hub.numcodecs import PngCodec


def slice_split(slice_):
Expand Down Expand Up @@ -198,3 +203,29 @@ def _copy_helper(
src_fs=src_fs,
)
return dst_url


def _get_dynamic_tensor_dtype(t_dtype):
if isinstance(t_dtype, Primitive):
return t_dtype.dtype
elif isinstance(t_dtype.dtype, Primitive):
return t_dtype.dtype.dtype
else:
return "object"


def _get_compressor(compressor: str):
if compressor is None:
return None
elif compressor.lower() == "lz4":
return numcodecs.LZ4(numcodecs.lz4.DEFAULT_ACCELERATION)
elif compressor.lower() == "zstd":
return numcodecs.Zstd(numcodecs.zstd.DEFAULT_CLEVEL)
elif compressor.lower() == "default":
return "default"
elif compressor.lower() == "png":
return PngCodec(solo_channel=True)
else:
raise ValueError(
f"Wrong compressor: {compressor}, only LZ4, PNG and ZSTD are supported"
)
16 changes: 16 additions & 0 deletions hub/api/tests/test_dataset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from hub.api.dataset_utils import _get_compressor
import numcodecs
import numcodecs.lz4
import numcodecs.zstd
from hub.numcodecs import PngCodec


def test_get_compression():
assert _get_compressor("lz4") == numcodecs.LZ4(numcodecs.lz4.DEFAULT_ACCELERATION)
assert _get_compressor(None) is None
assert _get_compressor("default") == "default"
assert _get_compressor("zstd") == numcodecs.Zstd(numcodecs.zstd.DEFAULT_CLEVEL)
assert _get_compressor("png") == PngCodec(solo_channel=True)
with pytest.raises(ValueError):
_get_compressor("abcd")

0 comments on commit f7c3e77

Please sign in to comment.