Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions tensorstore/driver/zarr3/chunk_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ ZarrChunkCache::~ZarrChunkCache() = default;

ZarrLeafChunkCache::ZarrLeafChunkCache(
kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/)
ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/,
bool open_as_void)
: Base(std::move(store)),
codec_state_(std::move(codec_state)),
dtype_(std::move(dtype)) {}
dtype_(std::move(dtype)),
open_as_void_(open_as_void) {}

void ZarrLeafChunkCache::Read(ZarrChunkCache::ReadRequest request,
AnyFlowReceiver<absl::Status, internal::ReadChunk,
Expand Down Expand Up @@ -157,7 +159,7 @@ ZarrLeafChunkCache::DecodeChunk(span<const Index> chunk_indices,
absl::InlinedVector<SharedArray<const void>, 1> field_arrays(num_fields);

// Special case: void access - return raw bytes directly
if (num_fields == 1 && dtype_.fields[0].name == "<void>") {
if (open_as_void_) {
TENSORSTORE_ASSIGN_OR_RETURN(
field_arrays[0], codec_state_->DecodeArray(grid().components[0].shape(),
std::move(data)));
Expand Down Expand Up @@ -221,11 +223,13 @@ kvstore::Driver* ZarrLeafChunkCache::GetKvStoreDriver() {

ZarrShardedChunkCache::ZarrShardedChunkCache(
kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool)
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
bool open_as_void)
: base_kvstore_(std::move(store)),
codec_state_(std::move(codec_state)),
dtype_(std::move(dtype)),
data_cache_pool_(std::move(data_cache_pool)) {}
data_cache_pool_(std::move(data_cache_pool)),
open_as_void_(open_as_void) {}

Result<IndexTransform<>> TranslateCellToSourceTransformForShard(
IndexTransform<> transform, span<const Index> grid_cell_indices,
Expand Down Expand Up @@ -534,7 +538,7 @@ void ZarrShardedChunkCache::Entry::DoInitialize() {
*sharding_state.sub_chunk_codec_chain,
std::move(sharding_kvstore), cache.executor(),
ZarrShardingCodec::PreparedState::Ptr(&sharding_state),
cache.dtype_, cache.data_cache_pool_);
cache.dtype_, cache.data_cache_pool_, cache.open_as_void_);
zarr_chunk_cache = new_cache.release();
return std::unique_ptr<internal::Cache>(&zarr_chunk_cache->cache());
})
Expand Down
14 changes: 10 additions & 4 deletions tensorstore/driver/zarr3/chunk_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,
explicit ZarrLeafChunkCache(kvstore::DriverPtr store,
ZarrCodecChain::PreparedState::Ptr codec_state,
ZarrDType dtype,
internal::CachePool::WeakPtr data_cache_pool);
internal::CachePool::WeakPtr data_cache_pool,
bool open_as_void = false);

void Read(ZarrChunkCache::ReadRequest request,
AnyFlowReceiver<absl::Status, internal::ReadChunk,
Expand Down Expand Up @@ -186,6 +187,7 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,

ZarrCodecChain::PreparedState::Ptr codec_state_;
ZarrDType dtype_;
bool open_as_void_;
};

/// Chunk cache for a Zarr array where each chunk is a shard.
Expand All @@ -196,7 +198,8 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
explicit ZarrShardedChunkCache(kvstore::DriverPtr store,
ZarrCodecChain::PreparedState::Ptr codec_state,
ZarrDType dtype,
internal::CachePool::WeakPtr data_cache_pool);
internal::CachePool::WeakPtr data_cache_pool,
bool open_as_void = false);

const ZarrShardingCodec::PreparedState& sharding_codec_state() const {
return static_cast<const ZarrShardingCodec::PreparedState&>(
Expand Down Expand Up @@ -246,6 +249,7 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
kvstore::DriverPtr base_kvstore_;
ZarrCodecChain::PreparedState::Ptr codec_state_;
ZarrDType dtype_;
bool open_as_void_;

// Data cache pool, if it differs from `this->pool()` (which is equal to the
// metadata cache pool).
Expand All @@ -260,11 +264,13 @@ class ZarrShardSubChunkCache : public ChunkCacheImpl {
explicit ZarrShardSubChunkCache(
kvstore::DriverPtr store, Executor executor,
ZarrShardingCodec::PreparedState::Ptr sharding_state,
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool)
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
bool open_as_void = false)
: ChunkCacheImpl(std::move(store),
ZarrCodecChain::PreparedState::Ptr(
sharding_state->sub_chunk_codec_state),
std::move(dtype), std::move(data_cache_pool)),
std::move(dtype), std::move(data_cache_pool),
open_as_void),
sharding_state_(std::move(sharding_state)),
executor_(std::move(executor)) {}

Expand Down
38 changes: 12 additions & 26 deletions tensorstore/driver/zarr3/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,9 @@ class ZarrDriverSpec
jb::Member("field", jb::Projection<&ZarrDriverSpec::selected_field>(
jb::DefaultValue<jb::kNeverIncludeDefaults>(
[](auto* obj) { *obj = std::string{}; }))),

// NEW: wrap the open_as_void projection in a Validate
jb::Member("open_as_void",
jb::Validate(
[](const auto& options, ZarrDriverSpec* obj) -> absl::Status {
// At this point, Projection has already set obj->open_as_void
if (obj->open_as_void) {
obj->selected_field = "<void>";
}
return absl::OkStatus();
},
jb::Projection<&ZarrDriverSpec::open_as_void>(
jb::Member("open_as_void", jb::Projection<&ZarrDriverSpec::open_as_void>(
jb::DefaultValue<jb::kNeverIncludeDefaults>(
[](auto* v) { *v = false; })))));
[](auto* v) { *v = false; }))));



Expand Down Expand Up @@ -592,10 +581,7 @@ class ZarrDataCache : public ChunkCacheImpl, public DataCacheBase {
grid_(DataCacheBase::GetChunkGridSpecification(
metadata(),
// Check if this is void access by examining the dtype
(ChunkCacheImpl::dtype_.fields.size() == 1 &&
ChunkCacheImpl::dtype_.fields[0].name == "<void>")
? kVoidFieldIndex
: 0)) {}
ChunkCacheImpl::open_as_void_ ? kVoidFieldIndex : false)) {}

const internal::LexicographicalGridIndexKeyParser& GetChunkStorageKeyParser()
final {
Expand Down Expand Up @@ -626,9 +612,8 @@ class ZarrDataCache : public ChunkCacheImpl, public DataCacheBase {
const void* metadata_ptr, size_t component_index) override {
const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);

// Check if this is void access by examining the cache's dtype
const bool is_void_access = (ChunkCacheImpl::dtype_.fields.size() == 1 &&
ChunkCacheImpl::dtype_.fields[0].name == "<void>");
// Check if this is void access by examining the stored flag
const bool is_void_access = ChunkCacheImpl::open_as_void_;

if (is_void_access) {
// For void access, create transform with extra bytes dimension
Expand Down Expand Up @@ -802,7 +787,7 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
TENSORSTORE_ASSIGN_OR_RETURN(
auto metadata,
internal_zarr3::GetNewMetadata(spec().metadata_constraints,
spec().schema),
spec().schema, spec().selected_field, spec().open_as_void),
tensorstore::MaybeAnnotateStatus(
_, "Cannot create using specified \"metadata\" and schema"));
return metadata;
Expand All @@ -819,15 +804,15 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
*static_cast<const ZarrMetadata*>(initializer.metadata.get());
// For void access, modify the dtype to indicate special handling
ZarrDType dtype = metadata.data_type;
if (spec().selected_field == "<void>") {
if (spec().open_as_void) {
// Create a synthetic dtype for void access
dtype = ZarrDType{
/*.has_fields=*/false,
/*.fields=*/{ZarrDType::Field{
ZarrDType::BaseDType{"<void>", dtype_v<tensorstore::dtypes::byte_t>,
ZarrDType::BaseDType{"", dtype_v<tensorstore::dtypes::byte_t>,
{metadata.data_type.bytes_per_outer_element}},
/*.outer_shape=*/{},
/*.name=*/"<void>",
/*.name=*/"",
/*.field_shape=*/{metadata.data_type.bytes_per_outer_element},
/*.num_inner_elements=*/metadata.data_type.bytes_per_outer_element,
/*.byte_offset=*/0,
Expand All @@ -837,7 +822,8 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
return internal_zarr3::MakeZarrChunkCache<DataCacheBase, ZarrDataCache>(
*metadata.codecs, std::move(initializer), spec().store.path,
metadata.codec_state, dtype,
/*data_cache_pool=*/*cache_pool());
/*data_cache_pool=*/*cache_pool(),
spec().open_as_void);
}

Result<size_t> GetComponentIndex(const void* metadata_ptr,
Expand All @@ -847,7 +833,7 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
ValidateMetadata(metadata, spec().metadata_constraints));
TENSORSTORE_ASSIGN_OR_RETURN(
auto field_index,
GetFieldIndex(metadata.data_type, spec().selected_field));
GetFieldIndex(metadata.data_type, spec().selected_field, spec().open_as_void));
// For void access, map to component index 0
if (field_index == kVoidFieldIndex) {
field_index = 0;
Expand Down
14 changes: 8 additions & 6 deletions tensorstore/driver/zarr3/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -799,12 +799,14 @@ std::string GetFieldNames(const ZarrDType& dtype) {
constexpr size_t kVoidFieldIndex = size_t(-1);

Result<size_t> GetFieldIndex(const ZarrDType& dtype,
std::string_view selected_field) {
// Special case: "<void>" requests raw byte access (works for any dtype)
if (selected_field == "<void>") {
std::string_view selected_field,
bool open_as_void) {
// Special case: open_as_void requests raw byte access (works for any dtype)

if (open_as_void) {
if (dtype.fields.empty()) {
return absl::FailedPreconditionError(
"Requested field \"<void>\" but dtype has no fields");
"Requested void access but dtype has no fields");
}
return kVoidFieldIndex;
}
Expand Down Expand Up @@ -1138,7 +1140,7 @@ absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,

Result<std::shared_ptr<const ZarrMetadata>> GetNewMetadata(
const ZarrMetadataConstraints& metadata_constraints, const Schema& schema,
std::string_view selected_field) {
std::string_view selected_field, bool open_as_void) {
auto metadata = std::make_shared<ZarrMetadata>();

metadata->zarr_format = metadata_constraints.zarr_format.value_or(3);
Expand All @@ -1165,7 +1167,7 @@ Result<std::shared_ptr<const ZarrMetadata>> GetNewMetadata(
}

TENSORSTORE_ASSIGN_OR_RETURN(
size_t field_index, GetFieldIndex(metadata->data_type, selected_field));
size_t field_index, GetFieldIndex(metadata->data_type, selected_field, open_as_void));
SpecRankAndFieldInfo info;
info.field = &metadata->data_type.fields[field_index];
info.chunked_rank = metadata_constraints.rank;
Expand Down
6 changes: 4 additions & 2 deletions tensorstore/driver/zarr3/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,14 @@ absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,
/// unspecified.
Result<std::shared_ptr<const ZarrMetadata>> GetNewMetadata(
const ZarrMetadataConstraints& metadata_constraints,
const Schema& schema, std::string_view selected_field = {});
const Schema& schema, std::string_view selected_field = {},
bool open_as_void = false);

absl::Status ValidateDataType(DataType dtype);

Result<size_t> GetFieldIndex(const ZarrDType& dtype,
std::string_view selected_field);
std::string_view selected_field,
bool open_as_void = false);

struct SpecRankAndFieldInfo {
DimensionIndex chunked_rank = dynamic_rank;
Expand Down
2 changes: 1 addition & 1 deletion tensorstore/driver/zarr3/metadata_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ Result<std::shared_ptr<const ZarrMetadata>> TestGetNewMetadata(
TENSORSTORE_RETURN_IF_ERROR(status);
TENSORSTORE_ASSIGN_OR_RETURN(
auto constraints, ZarrMetadataConstraints::FromJson(constraints_json));
return GetNewMetadata(constraints, schema);
return GetNewMetadata(constraints, schema, /*selected_field=*/{}, /*open_as_void=*/false);
}

TEST(GetNewMetadataTest, DuplicateDimensionNames) {
Expand Down