Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes shuffle + collate #1552

Merged
merged 8 commits into from
Mar 28, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions hub/integrations/pytorch/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Iterable, Optional, Sequence, List, Union
from hub.constants import MB
from hub.integrations.pytorch.common import PytorchTransformFunction
from hub.integrations.pytorch.common import PytorchTransformFunction, collate_fn
from hub.util.compute import get_compute_provider

from hub.util.iterable_ordered_dict import IterableOrderedDict
Expand Down Expand Up @@ -443,8 +443,6 @@ def __init__(
transform: PytorchTransformFunction = PytorchTransformFunction(),
num_workers: int = 1,
buffer_size: int = 512,
batch_size: int = 0,
collate_fn: Optional[Callable] = None,
) -> None:
super().__init__()

Expand All @@ -458,8 +456,6 @@ def __init__(
)

self.num_workers = num_workers
self.batch_size = batch_size
self.collate_fn = collate_fn
self.buffer_size = buffer_size * MB

if self.buffer_size == 0:
Expand All @@ -470,32 +466,42 @@ def __iter__(self):

sub_loader = DataLoader(
self.torch_datset,
batch_size=self.batch_size,
batch_size=1,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
collate_fn=collate_fn,
)

it = iter(sub_loader)

try:
while True:
next_batch = next(it)
batch_keys = list(next_batch.keys())

for i in range(len(next_batch[batch_keys[0]])):
val = IterableOrderedDict(
{k: next_batch[k][i].clone().detach() for k in batch_keys}
)
if isinstance(next_batch, dict):
d = {}
for k, v in next_batch.items():
current_val = v[0]
if isinstance(current_val, torch.Tensor):
current_val = current_val.clone().detach()
d[k] = current_val
val = IterableOrderedDict(d)
elif isinstance(next_batch, Sequence):
val = []
for item in next_batch:
if isinstance(item, torch.Tensor):
item = item.clone().detach()
val.append(item)
else:
val = next_batch

if buffer is not None:
result = buffer.exchange(val)
if buffer is not None:
result = buffer.exchange(val)

if result:
yield result
else:
yield val
if result:
yield result
else:
yield val

del next_batch, batch_keys
del next_batch

except StopIteration:
pass
Expand Down
2 changes: 0 additions & 2 deletions hub/integrations/pytorch/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def create_dataloader_nesteddataloader(
transform=transform,
num_workers=num_workers,
buffer_size=buffer_size,
batch_size=batch_size,
collate_fn=collate_fn,
),
batch_size=batch_size,
collate_fn=collate_fn,
Expand Down
22 changes: 15 additions & 7 deletions hub/integrations/pytorch/shuffle_buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Any
from typing import List, Any, Sequence
from random import randrange
from functools import reduce
from operator import mul
Expand Down Expand Up @@ -76,12 +76,20 @@ def emtpy(self) -> bool:
return len(self.buffer) == 0

def _sample_size(self, sample):
return sum(
[
tensor.storage().element_size() * reduce(mul, tensor.shape, 1)
for _, tensor in sample.items()
]
)
if isinstance(sample, dict):
return sum(
[
tensor.storage().element_size() * reduce(mul, tensor.shape, 1)
for _, tensor in sample.items()
]
)
elif isinstance(sample, Sequence):
return sum(
[
tensor.storage().element_size() * reduce(mul, tensor.shape, 1)
for tensor in sample
]
)

def __len__(self):
return len(self.buffer)
Expand Down
29 changes: 29 additions & 0 deletions hub/integrations/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

from hub.tests.dataset_fixtures import enabled_datasets

try:
from torch.utils.data._utils.collate import default_collate
except ImportError:
pass

# ensure tests have multiple chunks without a ton of data
PYTORCH_TESTS_MAX_CHUNK_SIZE = 5 * KB
Expand All @@ -25,6 +29,11 @@ def to_tuple(sample):
return sample["image"], sample["image2"]


def my_collate(batch):
x = [((x["a"], x["b"]), x["c"]) for x in batch]
return default_collate(x)


def pytorch_small_shuffle_helper(start, end, dataloader):
for _ in range(2):
all_values = []
Expand Down Expand Up @@ -454,6 +463,26 @@ def test_pytorch_large(local_ds):
np.testing.assert_array_equal(batch["label"][0], idx)


@requires_torch
@pytest.mark.parametrize("shuffle", [True, False])
def test_pytorch_collate(local_ds, shuffle):
local_ds.create_tensor("a")
local_ds.create_tensor("b")
local_ds.create_tensor("c")
for _ in range(100):
local_ds.a.append(0)
local_ds.b.append(1)
local_ds.c.append(2)

ptds = local_ds.pytorch(batch_size=4, shuffle=shuffle, collate_fn=my_collate)
for batch in ptds:
assert len(batch) == 2
assert len(batch[0]) == 2
np.testing.assert_array_equal(batch[0][0], np.array([0, 0, 0, 0]).reshape(4, 1))
np.testing.assert_array_equal(batch[0][1], np.array([1, 1, 1, 1]).reshape(4, 1))
np.testing.assert_array_equal(batch[1], np.array([2, 2, 2, 2]).reshape(4, 1))


def run_ddp(rank, size, ds, q, backend="gloo"):
import torch.distributed as dist
import os
Expand Down
2 changes: 0 additions & 2 deletions hub/integrations/tests/test_pytorch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def test_method2(ds):
mock_dataset(ds),
use_local_cache=False,
num_workers=2,
batch_size=2,
collate_fn=default_collate_fn,
tensors=tensors,
)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=default_collate_fn)
Expand Down