-
Notifications
You must be signed in to change notification settings - Fork 662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev flow.utils.data part1 #5406
Merged
Merged
Changes from all commits
Commits
Show all changes
65 commits
Select commit
Hold shift + click to select a range
eb694ab
refine and add test case
Flowingsun007 813f379
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 0aae06b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 4a38ccd
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 b58b849
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 b5e151a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 0ff1651
support ellipsis type slice
Flowingsun007 1276e65
refine
Flowingsun007 b9066f9
refine
Flowingsun007 8f81967
support slice assign ellipsis type
Flowingsun007 8f8cee2
refine
Flowingsun007 a79dcdf
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 81a400e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 e9e2fa4
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 dbdcf18
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 b11243f
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 9c7185b
Merge branch 'master' into dev_fix_slice_bug
oneflow-ci-bot c8c78fb
register fn to localtensor
Flowingsun007 f565929
Merge branch 'dev_fix_slice_bug' of https://github.com/Oneflow-Inc/on…
Flowingsun007 ebc25f0
Merge branch 'dev_fix_slice_bug'
Flowingsun007 475bbff
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 b69f554
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 e8cd9e3
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 a500a6d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 5387b8f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 756b0ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 34e9fd5
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 a5d67ac
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 e547b4b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 756e537
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 a39271b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 d5ecb51
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 db1b536
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 75cc02b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 634b968
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 6be7d0b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 d1eaabe
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 2ffb4ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 eca3dd6
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 9da3134
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 467bc42
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 d1a674b
basic implemetation of dataloader
Flowingsun007 2a7e9d5
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 a2bc0e1
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 c64a100
merge master
Flowingsun007 f363956
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 2a884bb
format
Flowingsun007 d2be60a
Merge branch 'dev_flow.utils.data_part1' of https://github.com/Oneflo…
Flowingsun007 d4de446
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 dfb3ba9
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 9c829e8
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 a2dfcdd
refine as comments
Flowingsun007 0b1418b
fix comments
Flowingsun007 28875db
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 f2600b4
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 e0d79be
remove useless code
Flowingsun007 923782b
Merge branch 'dev_flow.utils.data_part1' of https://github.com.cnpmjs…
Flowingsun007 fede595
add test case into unnitest
Flowingsun007 1f1307f
refine ascomments
Flowingsun007 44c7d58
refine as comments
Flowingsun007 c62c7dc
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 05a5b51
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 b8277be
Merge branch 'master' into dev_flow.utils.data_part1
oneflow-ci-bot 723e41a
Merge branch 'master' into dev_flow.utils.data_part1
oneflow-ci-bot 3154b81
Merge branch 'master' into dev_flow.utils.data_part1
oneflow-ci-bot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import unittest | ||
import numpy as np | ||
|
||
import oneflow.experimental as flow | ||
import oneflow.python.utils.data as Data | ||
|
||
|
||
class ScpDataset(Data.Dataset): | ||
def __init__(self, chunksize=200, dim=81, length=2000): | ||
self.chunksize = chunksize | ||
self.dim = dim | ||
self.length = length | ||
|
||
def __getitem__(self, index): | ||
np.random.seed(index) | ||
return np.random.randn(self.chunksize, self.dim) | ||
|
||
def __len__(self): | ||
return self.length | ||
|
||
|
||
@flow.unittest.skip_unless_1n1d() | ||
@unittest.skipIf( | ||
not flow.unittest.env.eager_execution_enabled(), | ||
".numpy() doesn't work in lazy mode", | ||
) | ||
class TestNumpyDataset(flow.unittest.TestCase): | ||
def test_numpy_dataset(test_case): | ||
dataset = ScpDataset() | ||
dataloader = Data.DataLoader(dataset, batch_size=16, shuffle=True) | ||
for X in dataloader: | ||
test_case.assertEqual(X.shape, flow.Size([16, 200, 81])) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import unittest | ||
import numpy as np | ||
|
||
import oneflow.experimental as flow | ||
import oneflow.python.utils.data as Data | ||
import oneflow.experimental.nn as nn | ||
|
||
|
||
class LinearNet(nn.Module): | ||
def __init__(self, n_feature): | ||
super(LinearNet, self).__init__() | ||
self.linear = nn.Linear(n_feature, 1) | ||
|
||
def forward(self, x): | ||
y = self.linear(x) | ||
return y | ||
|
||
|
||
@flow.unittest.skip_unless_1n1d() | ||
@unittest.skipIf( | ||
not flow.unittest.env.eager_execution_enabled(), | ||
".numpy() doesn't work in lazy mode", | ||
) | ||
class TestTensorDataset(flow.unittest.TestCase): | ||
def test_tensor_dataset(test_case): | ||
|
||
num_inputs = 2 | ||
num_examples = 1000 | ||
true_w = [2, -3.4] | ||
true_b = 4.2 | ||
|
||
net = LinearNet(num_inputs) | ||
flow.nn.init.normal_(net.linear.weight, mean=0, std=0.01) | ||
flow.nn.init.constant_(net.linear.bias, val=0) | ||
|
||
loss = nn.MSELoss() | ||
optimizer = flow.optim.SGD(net.parameters(), lr=0.03) | ||
|
||
features = flow.tensor( | ||
np.random.normal(0, 1, (num_examples, num_inputs)), dtype=flow.float | ||
) | ||
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b | ||
labels += flow.tensor( | ||
np.random.normal(0, 0.01, size=labels.size()), dtype=flow.float | ||
) | ||
|
||
batch_size = 10 | ||
# combine features and labels | ||
dataset = Data.TensorDataset(features, labels) | ||
# random get small batch | ||
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, num_workers=0) | ||
|
||
num_epochs = 10 | ||
for epoch in range(1, num_epochs + 1): | ||
for X, y in data_iter: | ||
output = net(X) | ||
l = loss(output, y) | ||
optimizer.zero_grad() | ||
l.backward() | ||
optimizer.step() | ||
if epoch == num_epochs: | ||
test_case.assertLess(l.numpy(), 0.00019) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
from oneflow.python.utils.data.sampler import ( | ||
Sampler, | ||
SequentialSampler, | ||
RandomSampler, | ||
SubsetRandomSampler, | ||
BatchSampler, | ||
) | ||
from oneflow.python.utils.data.dataset import ( | ||
Dataset, | ||
IterableDataset, | ||
TensorDataset, | ||
ConcatDataset, | ||
Subset, | ||
random_split, | ||
) | ||
from oneflow.python.utils.data.dataset import IterableDataset as IterDataPipe | ||
from oneflow.python.utils.data.dataloader import DataLoader, _DatasetKind | ||
from oneflow.python.utils.data.decorator import ( | ||
functional_datapipe, | ||
guaranteed_datapipes_determinism, | ||
non_deterministic, | ||
) | ||
|
||
|
||
__all__ = [ | ||
"Sampler", | ||
"SequentialSampler", | ||
"RandomSampler", | ||
"SubsetRandomSampler", | ||
"BatchSampler", | ||
"Dataset", | ||
"IterableDataset", | ||
"TensorDataset", | ||
"ConcatDataset", | ||
"Subset", | ||
"random_split", | ||
"DataLoader", | ||
"_DatasetKind", | ||
"IterDataPipe", | ||
"functional_datapipe", | ||
"guaranteed_datapipes_determinism", | ||
"non_deterministic", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
r"""Utility classes & functions for data loading. Code in this folder is mostly | ||
used by ../dataloder.py. | ||
|
||
A lot of multiprocessing is used in data loading, which only supports running | ||
functions defined in global environment (py2 can't serialize static methods). | ||
Therefore, for code tidiness we put these functions into different files in this | ||
folder. | ||
""" | ||
|
||
import sys | ||
import atexit | ||
|
||
|
||
IS_WINDOWS = sys.platform == "win32" | ||
|
||
MP_STATUS_CHECK_INTERVAL = 5.0 | ||
r"""Interval (in seconds) to check status of processes to avoid hanging in | ||
multiprocessing data loading. This is mainly used in getting data from | ||
another process, in which case we need to periodically check whether the | ||
sender is alive to prevent hanging.""" | ||
|
||
|
||
python_exit_status = False | ||
r"""Whether Python is shutting down. This flag is guaranteed to be set before | ||
the Python core library resources are freed, but Python may already be exiting | ||
for some time when this is set. | ||
|
||
Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar | ||
hook in Python 3.7 multiprocessing library: | ||
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 | ||
""" | ||
|
||
|
||
def _set_python_exit_flag(): | ||
global python_exit_status | ||
python_exit_status = True | ||
|
||
|
||
atexit.register(_set_python_exit_flag) | ||
|
||
|
||
from . import collate, fetch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to | ||
collate samples fetched from dataset into Tensor(s). | ||
|
||
These **needs** to be in global scope since Py2 doesn't support serializing | ||
static methods. | ||
""" | ||
|
||
import oneflow as flow | ||
import re | ||
import collections | ||
import oneflow.python.utils as utils | ||
|
||
string_classes = (str, bytes) | ||
|
||
np_str_obj_array_pattern = re.compile(r"[SaUO]") | ||
|
||
|
||
def default_convert(data): | ||
r"""Converts each NumPy array data field into a tensor""" | ||
elem_type = type(data) | ||
if isinstance(data, (flow.Tensor, flow._oneflow_internal.Tensor)): | ||
return data | ||
elif ( | ||
elem_type.__module__ == "numpy" | ||
and elem_type.__name__ != "str_" | ||
and elem_type.__name__ != "string_" | ||
): | ||
# array of string classes and object | ||
if ( | ||
elem_type.__name__ == "ndarray" | ||
and np_str_obj_array_pattern.search(data.dtype.str) is not None | ||
): | ||
return data | ||
return flow.tensor(data) | ||
elif isinstance(data, collections.abc.Mapping): | ||
return {key: default_convert(data[key]) for key in data} | ||
elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple | ||
return elem_type(*(default_convert(d) for d in data)) | ||
elif isinstance(data, collections.abc.Sequence) and not isinstance( | ||
data, string_classes | ||
): | ||
return [default_convert(d) for d in data] | ||
else: | ||
# NOTE: torch just return data here, and not raise any exception! | ||
raise TypeError(default_convert_err_msg_format.format(elem_type)) | ||
|
||
|
||
default_collate_err_msg_format = ( | ||
"default_collate: batch must contain tensors, numpy arrays, numbers, " | ||
"dicts or lists; found {}" | ||
) | ||
|
||
default_convert_err_msg_format = ( | ||
"default_convert: batch must contain tensors, numpy arrays, numbers, " | ||
"dicts or lists; found {}" | ||
) | ||
|
||
|
||
def default_collate(batch): | ||
r"""Puts each data field into a tensor with outer dimension batch size""" | ||
|
||
elem = batch[0] | ||
elem_type = type(elem) | ||
if isinstance(elem, (flow.Tensor, flow._oneflow_internal.Tensor)): | ||
# TODO: tensor.storage()._new_shared(numel) | ||
return flow.experimental.stack(batch, dim=0) | ||
elif ( | ||
elem_type.__module__ == "numpy" | ||
and elem_type.__name__ != "str_" | ||
and elem_type.__name__ != "string_" | ||
): | ||
if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": | ||
# array of string classes and object | ||
if np_str_obj_array_pattern.search(elem.dtype.str) is not None: | ||
raise TypeError(default_collate_err_msg_format.format(elem.dtype)) | ||
|
||
return default_collate([flow.Tensor(b) for b in batch]) | ||
elif elem.shape == (): # scalars | ||
return flow.Tensor(batch) | ||
elif isinstance(elem, float): | ||
return flow.tensor(batch, dtype=flow.float64) | ||
elif isinstance(elem, int): | ||
return flow.tensor(batch) | ||
elif isinstance(elem, string_classes): | ||
return batch | ||
elif isinstance(elem, collections.abc.Mapping): | ||
return {key: default_collate([d[key] for d in batch]) for key in elem} | ||
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple | ||
return elem_type(*(default_collate(samples) for samples in zip(*batch))) | ||
elif isinstance(elem, collections.abc.Sequence): | ||
# check to make sure that the elements in batch have consistent size | ||
it = iter(batch) | ||
elem_size = len(next(it)) | ||
if not all(len(elem) == elem_size for elem in it): | ||
raise RuntimeError("each element in list of batch should be of equal size") | ||
transposed = zip(*batch) | ||
return [default_collate(samples) for samples in transposed] | ||
|
||
raise TypeError(default_collate_err_msg_format.format(elem_type)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
length弄少一点?设置成10这种数量级的?