Skip to content

Commit

Permalink
Added test for locally stored hub Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
DebadityaPal committed Jan 29, 2021
1 parent b986d17 commit 9b60667
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions benchmarks/benchmark_tiledb_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@

def time_tiledb(dataset, batch_size=1):
ds = hub.Dataset(dataset)
if os.path.exists("./test/" + dataset.split("/")[1]):
ds_tldb = tiledb.open("./test/" + dataset.split("/")[1])
if os.path.exists(dataset.split("/")[1] + "_tileDB"):
ds_tldb = tiledb.open(dataset.split("/")[1] + "_tileDB")
else:
if not os.path.exists("./test"):
os.makedirs("test")
if not os.path.exists(dataset.split("/")[1] + "_tileDB"):
os.makedirs(dataset.split("/")[1] + "_tileDB")
ds_numpy = np.concatenate(
(
ds["image"].compute().reshape(ds.shape[0], -1),
ds["label"].compute().reshape(ds.shape[0], -1),
),
axis=1,
)
ds_tldb = tiledb.from_numpy("./test/" + dataset.split("/")[1], ds_numpy)
ds_tldb = tiledb.from_numpy(dataset.split("/")[1] + "_tileDB", ds_numpy)

assert type(ds_tldb) == tiledb.array.DenseArray

Expand Down Expand Up @@ -63,9 +63,14 @@ def time_hub(dataset, batch_size=1):

if __name__ == "__main__":
    # Benchmark driver: for every dataset and batch size, time TileDB, then
    # Hub read from its original tag, then Hub read from a locally stored copy.
    for dataset in datasets:
        # Presumably tags have the form "<user>/<name>"; the second path
        # component is used both as the TFDS dataset name and as the stem of
        # the local store directory -- confirm against `datasets`.
        name = dataset.split("/")[1]
        local_path = "./" + name + "_hub"

        # Build the dataset from TFDS and store it next to the script so the
        # "Stored Locally" timing below has something to read.
        data = hub.Dataset.from_tfds(name)
        data.store(local_path)

        for batch_size in batch_sizes:
            print("Dataset: ", dataset, "with Batch Size: ", batch_size)
            print("Performance of TileDB")
            time_tiledb(dataset, batch_size)
            # The pre-commit header `print("Performance of Hub")` is dropped:
            # it was replaced by the two more specific headers below.
            print("Performance of Hub (Stored on the Cloud):")
            time_hub(dataset, batch_size)
            print("Performance of Hub (Stored Locally):")
            time_hub(local_path, batch_size)

0 comments on commit 9b60667

Please sign in to comment.