Fixes shuffle + collate (#1552)
* fixes shuffle + collate

* fix collate

* temp

* fix test

* link fix

* lint fixes

* fix my_collate
AbhinavTuli committed Mar 28, 2022
1 parent ff63116 commit e6023fe
Showing 5 changed files with 70 additions and 31 deletions.
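In short, this commit makes the shuffle path handle non-dict samples and moves the user's batch_size and collate_fn out of the inner shuffle loader, so custom collate functions compose with shuffle=True. A minimal sketch of the user-facing call, reconstructed from the new test_pytorch_collate test below (local_ds is assumed to be a dataset with tensors a, b, c, as in that test):

from torch.utils.data._utils.collate import default_collate

def my_collate(batch):
    # Re-nest each sample dict into ((a, b), c) before default collation.
    return default_collate([((s["a"], s["b"]), s["c"]) for s in batch])

# After this fix the nested structure survives both code paths:
ptds = local_ds.pytorch(batch_size=4, shuffle=True, collate_fn=my_collate)
for batch in ptds:
    (a, b), c = batch  # each leaf has shape (4, 1)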
46 changes: 26 additions & 20 deletions hub/integrations/pytorch/dataset.py
@@ -1,6 +1,6 @@
 from typing import Callable, Iterable, Optional, Sequence, List, Union
 from hub.constants import MB
-from hub.integrations.pytorch.common import PytorchTransformFunction
+from hub.integrations.pytorch.common import PytorchTransformFunction, collate_fn
 from hub.util.compute import get_compute_provider

 from hub.util.iterable_ordered_dict import IterableOrderedDict
@@ -449,8 +449,6 @@ def __init__(
         transform: PytorchTransformFunction = PytorchTransformFunction(),
         num_workers: int = 1,
         buffer_size: int = 512,
-        batch_size: int = 0,
-        collate_fn: Optional[Callable] = None,
     ) -> None:
         super().__init__()

@@ -464,8 +462,6 @@
         )

         self.num_workers = num_workers
-        self.batch_size = batch_size
-        self.collate_fn = collate_fn
         self.buffer_size = buffer_size * MB

         if self.buffer_size == 0:
@@ -476,32 +472,42 @@ def __iter__(self):

         sub_loader = DataLoader(
             self.torch_datset,
-            batch_size=self.batch_size,
+            batch_size=1,
             num_workers=self.num_workers,
-            collate_fn=self.collate_fn,
+            collate_fn=collate_fn,
         )

         it = iter(sub_loader)

         try:
             while True:
                 next_batch = next(it)
-                batch_keys = list(next_batch.keys())
-
-                for i in range(len(next_batch[batch_keys[0]])):
-                    val = IterableOrderedDict(
-                        {k: next_batch[k][i].clone().detach() for k in batch_keys}
-                    )
+                if isinstance(next_batch, dict):
+                    d = {}
+                    for k, v in next_batch.items():
+                        current_val = v[0]
+                        if isinstance(current_val, torch.Tensor):
+                            current_val = current_val.clone().detach()
+                        d[k] = current_val
+                    val = IterableOrderedDict(d)
+                elif isinstance(next_batch, Sequence):
+                    val = []
+                    for item in next_batch:
+                        if isinstance(item, torch.Tensor):
+                            item = item.clone().detach()
+                        val.append(item)
+                else:
+                    val = next_batch

-                    if buffer is not None:
-                        result = buffer.exchange(val)
+                if buffer is not None:
+                    result = buffer.exchange(val)

-                        if result:
-                            yield result
-                    else:
-                        yield val
+                    if result:
+                        yield result
+                else:
+                    yield val

-                del next_batch, batch_keys
+                del next_batch

         except StopIteration:
             pass
2 changes: 0 additions & 2 deletions hub/integrations/pytorch/pytorch.py
@@ -34,8 +34,6 @@ def create_dataloader_nesteddataloader(
             transform=transform,
             num_workers=num_workers,
             buffer_size=buffer_size,
-            batch_size=batch_size,
-            collate_fn=collate_fn,
         ),
         batch_size=batch_size,
         collate_fn=collate_fn,
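Taken together, the two files above split the work: the inner DataLoader now always runs with batch_size=1 and the library's own collate_fn, the loop in __iter__ strips the resulting singleton batch dimension (the v[0] indexing) before samples enter the shuffle buffer, and the user's batch_size and collate_fn apply only in the outer DataLoader. A small sketch of why v[0] recovers the original sample, assuming the inner collate stacks like torch's default_collate:

import torch
from torch.utils.data._utils.collate import default_collate

sample = {"image": torch.zeros(3, 4), "label": torch.tensor([7])}
batched = default_collate([sample])         # what a batch_size=1 loader yields
assert batched["image"].shape == (1, 3, 4)  # singleton batch dim prepended
restored = {k: v[0] for k, v in batched.items()}
assert restored["image"].shape == (3, 4)    # v[0] strips it again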
22 changes: 15 additions & 7 deletions hub/integrations/pytorch/shuffle_buffer.py
@@ -1,4 +1,4 @@
-from typing import List, Any
+from typing import List, Any, Sequence
 from random import randrange
 from functools import reduce
 from operator import mul
@@ -76,12 +76,20 @@ def emtpy(self) -> bool:
         return len(self.buffer) == 0

     def _sample_size(self, sample):
-        return sum(
-            [
-                tensor.storage().element_size() * reduce(mul, tensor.shape, 1)
-                for _, tensor in sample.items()
-            ]
-        )
+        if isinstance(sample, dict):
+            return sum(
+                [
+                    tensor.storage().element_size() * reduce(mul, tensor.shape, 1)
+                    for _, tensor in sample.items()
+                ]
+            )
+        elif isinstance(sample, Sequence):
+            return sum(
+                [
+                    tensor.storage().element_size() * reduce(mul, tensor.shape, 1)
+                    for tensor in sample
+                ]
+            )

     def __len__(self):
         return len(self.buffer)
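In _sample_size, both branches compute a tensor's byte count as its storage element size times the product of its shape, which for contiguous CPU tensors equals element_size() * nelement(); note that a sample that is neither a dict nor a Sequence falls through both branches and the method returns None. A quick check of the arithmetic (the 48-byte figure assumes float32):

import torch
from functools import reduce
from operator import mul

t = torch.zeros(4, 3, dtype=torch.float32)
nbytes = t.storage().element_size() * reduce(mul, t.shape, 1)
assert nbytes == t.element_size() * t.nelement() == 48  # 12 elements * 4 bytes each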
29 changes: 29 additions & 0 deletions hub/integrations/tests/test_pytorch.py
@@ -12,6 +12,10 @@

 from hub.tests.dataset_fixtures import enabled_datasets

+try:
+    from torch.utils.data._utils.collate import default_collate
+except ImportError:
+    pass

 # ensure tests have multiple chunks without a ton of data
 PYTORCH_TESTS_MAX_CHUNK_SIZE = 5 * KB
@@ -25,6 +29,11 @@ def to_tuple(sample):
     return sample["image"], sample["image2"]


+def my_collate(batch):
+    x = [((x["a"], x["b"]), x["c"]) for x in batch]
+    return default_collate(x)
+
+
 def pytorch_small_shuffle_helper(start, end, dataloader):
     for _ in range(2):
         all_values = []
@@ -454,6 +463,26 @@ def test_pytorch_large(local_ds):
         np.testing.assert_array_equal(batch["label"][0], idx)


+@requires_torch
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_pytorch_collate(local_ds, shuffle):
+    local_ds.create_tensor("a")
+    local_ds.create_tensor("b")
+    local_ds.create_tensor("c")
+    for _ in range(100):
+        local_ds.a.append(0)
+        local_ds.b.append(1)
+        local_ds.c.append(2)
+
+    ptds = local_ds.pytorch(batch_size=4, shuffle=shuffle, collate_fn=my_collate)
+    for batch in ptds:
+        assert len(batch) == 2
+        assert len(batch[0]) == 2
+        np.testing.assert_array_equal(batch[0][0], np.array([0, 0, 0, 0]).reshape(4, 1))
+        np.testing.assert_array_equal(batch[0][1], np.array([1, 1, 1, 1]).reshape(4, 1))
+        np.testing.assert_array_equal(batch[1], np.array([2, 2, 2, 2]).reshape(4, 1))
+
+
 def run_ddp(rank, size, ds, q, backend="gloo"):
     import torch.distributed as dist
     import os
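The assertions in test_pytorch_collate rely on default_collate recursing into the ((a, b), c) tuples that my_collate builds: the nesting is preserved while each leaf is stacked along a new batch dimension. A standalone sketch of that behavior, assuming torch's default_collate:

import numpy as np
from torch.utils.data._utils.collate import default_collate

batch = [((np.array([0]), np.array([1])), np.array([2])) for _ in range(4)]
out = default_collate(batch)
assert len(out) == 2 and len(out[0]) == 2  # tuple nesting preserved
assert tuple(out[0][0].shape) == (4, 1)    # leaves stacked along the batch dim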
2 changes: 0 additions & 2 deletions hub/integrations/tests/test_pytorch_dataloader.py
@@ -157,8 +157,6 @@ def test_method2(ds):
         mock_dataset(ds),
         use_local_cache=False,
         num_workers=2,
-        batch_size=2,
-        collate_fn=default_collate_fn,
         tensors=tensors,
     )
     dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_collate_fn)
