Merged
3 changes: 2 additions & 1 deletion tests/readers/test_tensor_schema.py
@@ -71,7 +71,8 @@ def sparse_uri(tmp_path_factory, non_zero_per_row=3):
         d2 = np.random.uniform(domains[2][0], domains[2][1], num_cells)
         d3_days = (domains[3][1] - domains[3][0]).astype(int)
         d3 = domains[3][0] + np.random.randint(0, d3_days, num_cells)
-        d4 = np.random.choice(list(string.ascii_lowercase), num_cells)
+        # TODO: increase 8 to something larger after sh-19349 is fixed
+        d4 = np.random.choice([c * 8 for c in string.ascii_lowercase], num_cells)
         a[d0, d1, d2, d3, d4] = {
             "af8": np.random.rand(num_cells),
             "af4": np.random.rand(num_cells).astype(np.float32),
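For reference, a minimal standalone sketch (not part of the test) of what the d4 change produces: the old call drew single lowercase letters, while the new call draws 8-character strings, so the var-length string dimension now has multi-byte cells, which is presumably what the size estimation below needs to exercise. num_cells is illustrative only.

import string

import numpy as np

num_cells = 5  # illustrative; the real fixture uses many more cells

# Old behaviour: one-character values such as "a", "q", "z".
old_d4 = np.random.choice(list(string.ascii_lowercase), num_cells)

# New behaviour: eight-character values such as "aaaaaaaa", "qqqqqqqq".
new_d4 = np.random.choice([c * 8 for c in string.ascii_lowercase], num_cells)

print(old_d4, new_d4)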
27 changes: 21 additions & 6 deletions tiledb/ml/readers/_tensor_schema.py
@@ -247,12 +247,27 @@ def max_partition_weight(self) -> int:
         except KeyError:
             memory_budget = 10 * 1024**2
 
-        # The memory budget should be large enough to read the cells of the largest field
-        bytes_per_cell = max(
-            self._array.schema.attr_or_dim_dtype(field).itemsize
-            for field in self._query_kwargs["dims"] + self._query_kwargs["attrs"]
-        )
-        return max(1, memory_budget // int(bytes_per_cell))
+        # Determine the bytes per (non-empty) cell for each field.
+        # - For fixed size fields, this is just the `dtype.itemsize` of the field.
+        # - For variable size fields, the best we can do is to estimate the average bytes
+        #   size. We also need to take into account the (fixed) offset buffer size per
+        #   cell (=8 bytes).
+        offset_itemsize = np.dtype(np.uint64).itemsize
+        attr_or_dim_dtype = self._array.schema.attr_or_dim_dtype
+        query = self._array.query(return_incomplete=True, **self._query_kwargs)
+        est_sizes = query.multi_index[:].estimated_result_sizes()
+        bytes_per_cell = []
+        for field, est_result_size in est_sizes.items():
+            if est_result_size.offsets_bytes == 0:
+                bytes_per_cell.append(attr_or_dim_dtype(field).itemsize)
+            else:
+                num_cells = est_result_size.offsets_bytes / offset_itemsize
+                avg_itemsize = est_result_size.data_bytes / num_cells
+                bytes_per_cell.append(max(avg_itemsize, offset_itemsize))
+
+        # Finally, the number of cells that can fit in the memory_budget depends on the
+        # maximum bytes_per_cell
+        return max(1, memory_budget // ceil(max(bytes_per_cell)))
 
 
 class KeyDimQuery:
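To make the new arithmetic concrete, here is a small self-contained sketch of the per-cell estimate and the returned partition weight. The numbers (one float64 attribute plus one var-length field reporting 1,000 offsets and 25,000 data bytes, together with the 10 MiB default budget) are made up for illustration and are not taken from the PR.

from math import ceil

import numpy as np

memory_budget = 10 * 1024**2                    # the default budget from the diff
offset_itemsize = np.dtype(np.uint64).itemsize  # 8 bytes per offset

bytes_per_cell = []

# A fixed-size field (offsets_bytes == 0): just its dtype itemsize.
bytes_per_cell.append(np.dtype(np.float64).itemsize)         # 8

# A var-length field: average data bytes per cell, never below the offset size.
offsets_bytes, data_bytes = 1_000 * offset_itemsize, 25_000  # hypothetical estimates
num_cells = offsets_bytes / offset_itemsize                  # 1000.0 cells
avg_itemsize = data_bytes / num_cells                        # 25.0 bytes per cell
bytes_per_cell.append(max(avg_itemsize, offset_itemsize))    # 25.0

# The partition weight is the number of worst-case cells that fit in the budget.
print(max(1, memory_budget // ceil(max(bytes_per_cell))))    # 419430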