Skip to content

Commit

Permalink
fixed an issue when overwriting transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
AbhinavTuli committed Mar 29, 2021
1 parent f688ec3 commit fc864eb
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
4 changes: 3 additions & 1 deletion hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ def __init__(
self._version_node = VersionNode(self._commit_id, self._branch)
self._branch_node_map = {self._branch: self._version_node}
self._commit_node_map = {self._commit_id: self._version_node}
self._chunk_commit_map = {
path: defaultdict(set) for schema, path in self._flat_tensors
}
self._tensors = dict(self._generate_storage_tensors())
self._chunk_commit_map = {key: defaultdict(set) for key in self.keys}
except Exception as e:
try:
self.close()
Expand Down
26 changes: 25 additions & 1 deletion hub/compute/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import numpy as np
import zarr

import os
from hub.cli.auth import login_fn
import hub
from hub.schema import Tensor, Image, Text
from hub.utils import Timer
Expand Down Expand Up @@ -364,6 +365,29 @@ def my_transform(sample):
out_ds.store(f"./data/test/test_pipeline_basic_output_{name}")


def test_transform_overwrite():
password = os.getenv("ACTIVELOOP_HUB_PASSWORD")
login_fn("testingacc", password)

schema = {
"image": hub.schema.Image(
shape=(None, None, 1), dtype="uint8", max_shape=(1, 1, 1)
),
}

@hub.transform(schema=schema)
def create_image(value):
return {"image": np.ones((1, 1, 1), dtype="uint8") * value}

ds1 = create_image(range(5))
ds = ds1.store("testingacc/ds_transform", public=False)
for i in range(5):
assert (ds["image", i].compute() == i * np.ones((1, 1, 1))).all()
ds = ds1.store("testingacc/ds_transform", public=False)
for i in range(5):
assert (ds["image", i].compute() == i * np.ones((1, 1, 1))).all()


if __name__ == "__main__":
with Timer("Test Transform"):
with Timer("test threaded"):
Expand Down
10 changes: 7 additions & 3 deletions hub/store/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""
from collections import defaultdict
import json
from collections.abc import MutableMapping
import posixpath
Expand Down Expand Up @@ -125,8 +124,13 @@ def __delitem__(self, k: str):
if self._ds._commit_id:
k = self.find_chunk(k) or f"{k}:{self._ds._commit_id}"
commit_id = k.split(":")[-1]
self._ds._chunk_commit_map[self._path][chunk_key].remove(commit_id)
del self._fs_map[k]
try:
self._ds._chunk_commit_map[self._path][chunk_key].remove(commit_id)
except Exception:
try:
del self._fs_map[k]
except Exception:
pass

def flush(self):
self._meta.flush()
Expand Down

0 comments on commit fc864eb

Please sign in to comment.