Dev flow.utils.data part1 #5406

Merged
merged 65 commits into from
Jul 14, 2021
Changes from 59 commits
eb694ab
refine and add test case
Flowingsun007 Jun 10, 2021
813f379
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
0aae06b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
4a38ccd
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 10, 2021
b58b849
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 11, 2021
b5e151a
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 11, 2021
0ff1651
support ellipsis type slice
Flowingsun007 Jun 11, 2021
1276e65
refine
Flowingsun007 Jun 11, 2021
b9066f9
refine
Flowingsun007 Jun 11, 2021
8f81967
support slice assign ellipsis type
Flowingsun007 Jun 11, 2021
8f8cee2
refine
Flowingsun007 Jun 11, 2021
a79dcdf
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 12, 2021
81a400e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 12, 2021
e9e2fa4
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 13, 2021
dbdcf18
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 13, 2021
b11243f
Merge branch 'master' into dev_fix_slice_bug
Flowingsun007 Jun 13, 2021
9c7185b
Merge branch 'master' into dev_fix_slice_bug
oneflow-ci-bot Jun 13, 2021
c8c78fb
register fn to localtensor
Flowingsun007 Jun 13, 2021
f565929
Merge branch 'dev_fix_slice_bug' of https://github.com/Oneflow-Inc/on…
Flowingsun007 Jun 13, 2021
ebc25f0
Merge branch 'dev_fix_slice_bug'
Flowingsun007 Jun 13, 2021
475bbff
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 13, 2021
b69f554
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 15, 2021
e8cd9e3
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 16, 2021
a500a6d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
5387b8f
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
756b0ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
34e9fd5
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 17, 2021
a5d67ac
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 21, 2021
e547b4b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 21, 2021
756e537
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 24, 2021
a39271b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 25, 2021
d5ecb51
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 27, 2021
db1b536
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 28, 2021
75cc02b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 28, 2021
634b968
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 29, 2021
6be7d0b
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 29, 2021
d1eaabe
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jun 30, 2021
2ffb4ed
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jul 1, 2021
eca3dd6
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jul 5, 2021
9da3134
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jul 5, 2021
467bc42
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow
Flowingsun007 Jul 6, 2021
d1a674b
basic implemetation of dataloader
Flowingsun007 Jul 6, 2021
2a7e9d5
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 6, 2021
a2bc0e1
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 6, 2021
c64a100
merge master
Flowingsun007 Jul 6, 2021
f363956
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 6, 2021
2a884bb
format
Flowingsun007 Jul 6, 2021
d2be60a
Merge branch 'dev_flow.utils.data_part1' of https://github.com/Oneflo…
Flowingsun007 Jul 6, 2021
d4de446
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 9, 2021
dfb3ba9
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 12, 2021
9c829e8
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 13, 2021
a2dfcdd
refine as comments
Flowingsun007 Jul 13, 2021
0b1418b
fix comments
Flowingsun007 Jul 13, 2021
28875db
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 13, 2021
f2600b4
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 14, 2021
e0d79be
remove useless code
Flowingsun007 Jul 14, 2021
923782b
Merge branch 'dev_flow.utils.data_part1' of https://github.com.cnpmjs…
Flowingsun007 Jul 14, 2021
fede595
add test case into unnitest
Flowingsun007 Jul 14, 2021
1f1307f
refine ascomments
Flowingsun007 Jul 14, 2021
44c7d58
refine as comments
Flowingsun007 Jul 14, 2021
c62c7dc
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 14, 2021
05a5b51
Merge branch 'master' into dev_flow.utils.data_part1
Flowingsun007 Jul 14, 2021
b8277be
Merge branch 'master' into dev_flow.utils.data_part1
oneflow-ci-bot Jul 14, 2021
723e41a
Merge branch 'master' into dev_flow.utils.data_part1
oneflow-ci-bot Jul 14, 2021
3154b81
Merge branch 'master' into dev_flow.utils.data_part1
oneflow-ci-bot Jul 14, 2021
51 changes: 51 additions & 0 deletions oneflow/python/test/dataloader/test_numpy_dataset.py
@@ -0,0 +1,51 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np

import oneflow.experimental as flow
import oneflow.python.utils.data as Data


class ScpDataset(Data.Dataset):
    def __init__(self, chunksize=200, dim=81, length=2000):
Contributor review comment: Could `length` be made smaller, e.g. something on the order of 10?

        self.chunksize = chunksize
        self.dim = dim
        self.length = length

    def __getitem__(self, index):
        np.random.seed(index)
        return np.random.randn(self.chunksize, self.dim)

    def __len__(self):
        return self.length


@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestNumpyDataset(flow.unittest.TestCase):
    def test_numpy_dataset(test_case):
        dataset = ScpDataset()
        dataloader = Data.DataLoader(dataset, batch_size=32, shuffle=True)
        for X in dataloader:
            print(X.shape)
wyg1997 marked this conversation as resolved.


if __name__ == "__main__":
    unittest.main()
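
Review note on the batch count in this test: with the default `length=2000` and `batch_size=32`, the loop should run 63 times, and the last batch holds only 16 samples. This assumes `DataLoader` keeps the final partial batch by default, as `torch.utils.data.DataLoader` does; the diff shown here does not confirm that. The following is a minimal pure-Python sketch of the arithmetic, not OneFlow's implementation. `ToyDataset` and `iter_batches` are hypothetical names introduced only for illustration.

```python
import random


class ToyDataset:
    """Hypothetical map-style dataset: only __getitem__ and __len__ are needed."""

    def __init__(self, length=2000):
        self.length = length

    def __getitem__(self, index):
        # Return the index itself so batches are easy to inspect.
        return index

    def __len__(self):
        return self.length


def iter_batches(dataset, batch_size, shuffle=False, seed=0):
    """Yield lists of samples; the last list may be shorter than batch_size."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


batches = list(iter_batches(ToyDataset(), batch_size=32, shuffle=True))
print(len(batches), len(batches[-1]))  # 63 16
```

Shrinking `length` to around 10, as suggested in the review comment above, would cut the test to a single partial batch while still exercising the same code path.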
82 changes: 82 additions & 0 deletions oneflow/python/test/dataloader/test_tensor_dataset.py
@@ -0,0 +1,82 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np

import oneflow.experimental as flow
import oneflow.python.utils.data as Data
import oneflow.experimental.nn as nn


class LinearNet(nn.Module):
    def __init__(self, n_feature):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(n_feature, 1)

    def forward(self, x):
        y = self.linear(x)
        return y


@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestTensorDataset(flow.unittest.TestCase):
    def test_tensor_dataset(test_case):

        num_inputs = 2
        num_examples = 1000
        true_w = [2, -3.4]
        true_b = 4.2

        net = LinearNet(num_inputs)
        print(net)
        flow.nn.init.normal_(net.linear.weight, mean=0, std=0.01)
        flow.nn.init.constant_(net.linear.bias, val=0)

        loss = nn.MSELoss()
        optimizer = flow.optim.SGD(net.parameters(), lr=0.03)
        print(optimizer)

        features = flow.tensor(
            np.random.normal(0, 1, (num_examples, num_inputs)), dtype=flow.float
        )
        labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
        labels += flow.tensor(
            np.random.normal(0, 0.01, size=labels.size()), dtype=flow.float
        )

        batch_size = 10
        # Combine the features and labels of the training data
        dataset = Data.TensorDataset(features, labels)
        # Read random mini-batches
wyg1997 marked this conversation as resolved.
        data_iter = Data.DataLoader(dataset, batch_size, shuffle=True, num_workers=0)

        num_epochs = 10
        for epoch in range(1, num_epochs + 1):
            for X, y in data_iter:
                output = net(X)
                l = loss(output, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
            print("epoch %d, loss: %f" % (epoch, l.numpy()))


if __name__ == "__main__":
    unittest.main()
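
Review note: the test above only prints the loss, so it may help to see what the training loop is expected to recover. Below is a hedged, framework-free sketch of the same experiment (synthetic data from `y = 2*x1 - 3.4*x2 + 4.2` plus noise, mini-batch SGD with `lr=0.03`, `batch_size=10`, 10 epochs). It uses plain Python lists instead of OneFlow tensors and a hand-written MSE gradient, so it illustrates the idea only and is not the code under test.

```python
import random

rng = random.Random(0)
true_w, true_b = [2.0, -3.4], 4.2

# Synthetic data: y = w . x + b + small Gaussian noise
features = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(1000)]
labels = [
    true_w[0] * x[0] + true_w[1] * x[1] + true_b + rng.gauss(0, 0.01)
    for x in features
]

w, b = [0.0, 0.0], 0.0
lr, batch_size, num_epochs = 0.03, 10, 10

for epoch in range(num_epochs):
    order = list(range(len(features)))
    rng.shuffle(order)  # reshuffle every epoch, like shuffle=True
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        grad_w, grad_b = [0.0, 0.0], 0.0
        for i in batch:
            x, y = features[i], labels[i]
            err = w[0] * x[0] + w[1] * x[1] + b - y
            # Gradient of the mean squared error over the batch
            grad_w[0] += 2 * err * x[0] / len(batch)
            grad_w[1] += 2 * err * x[1] / len(batch)
            grad_b += 2 * err / len(batch)
        w = [w[0] - lr * grad_w[0], w[1] - lr * grad_w[1]]
        b -= lr * grad_b

print(w, b)  # should land close to true_w and true_b
```

Asserting that the learned parameters (or the final loss) fall within a tolerance would turn the printed output into an actual check.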
Empty file.
58 changes: 58 additions & 0 deletions oneflow/python/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.python.utils.data.sampler import (
    Sampler,
    SequentialSampler,
    RandomSampler,
    SubsetRandomSampler,
    BatchSampler,
)
from oneflow.python.utils.data.dataset import (
    Dataset,
    IterableDataset,
    TensorDataset,
    ConcatDataset,
    Subset,
    random_split,
)
from oneflow.python.utils.data.dataset import IterableDataset as IterDataPipe
from oneflow.python.utils.data.dataloader import DataLoader, _DatasetKind
from oneflow.python.utils.data.decorator import (
    functional_datapipe,
    guaranteed_datapipes_determinism,
    non_deterministic,
)


__all__ = [
    "Sampler",
    "SequentialSampler",
    "RandomSampler",
    "SubsetRandomSampler",
    "BatchSampler",
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "ConcatDataset",
    "Subset",
    "random_split",
    "DataLoader",
    "_DatasetKind",
    "IterDataPipe",
    "functional_datapipe",
    "guaranteed_datapipes_determinism",
    "non_deterministic",
]
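
Review note: the sampler names exported here match those in `torch.utils.data`. Assuming the OneFlow versions follow the same semantics (the sampler source is not part of this excerpt), the relationship between an index sampler and a batch sampler can be sketched in plain Python as follows. The function names are hypothetical stand-ins for the exported classes, not their implementation.

```python
import random


def sequential_sampler(n):
    """Yield indices 0..n-1 in order (stand-in for SequentialSampler)."""
    return iter(range(n))


def random_sampler(n, seed=None):
    """Yield a random permutation of 0..n-1 (stand-in for RandomSampler)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return iter(indices)


def batch_sampler(sampler, batch_size, drop_last):
    """Group indices from `sampler` into lists of at most `batch_size`."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the trailing partial batch unless the caller asked to drop it.
    if batch and not drop_last:
        yield batch


print(list(batch_sampler(sequential_sampler(10), 3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch_sampler(sequential_sampler(10), 3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

Under this reading, `shuffle=True` on the loader amounts to swapping the sequential index source for the random one before batching.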
57 changes: 57 additions & 0 deletions oneflow/python/utils/data/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
r"""Utility classes & functions for data loading. Code in this folder is mostly
used by ../dataloader.py.

A lot of multiprocessing is used in data loading, which only supports running
functions defined in global environment (py2 can't serialize static methods).
Therefore, for code tidiness we put these functions into different files in this
folder.
"""

import sys
import atexit


IS_WINDOWS = sys.platform == "win32"

MP_STATUS_CHECK_INTERVAL = 5.0
r"""Interval (in seconds) to check status of processes to avoid hanging in
multiprocessing data loading. This is mainly used in getting data from
another process, in which case we need to periodically check whether the
sender is alive to prevent hanging."""


python_exit_status = False
r"""Whether Python is shutting down. This flag is guaranteed to be set before
the Python core library resources are freed, but Python may already be exiting
for some time when this is set.

Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
hook in Python 3.7 multiprocessing library:
https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
"""


def _set_python_exit_flag():
    global python_exit_status
    python_exit_status = True


atexit.register(_set_python_exit_flag)


from . import collate, fetch
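
Review note: the docstring for `MP_STATUS_CHECK_INTERVAL` describes waking up periodically while waiting for data to check that the sender is still alive. The sketch below illustrates that pattern with the standard library only. It uses a thread and `queue.Queue` as a stand-in for a worker process and its result queue, and a much shorter interval than 5.0 seconds; `get_with_liveness_check` and `CHECK_INTERVAL` are hypothetical names, not part of this PR.

```python
import queue
import threading
import time

CHECK_INTERVAL = 0.05  # hypothetical, shortened stand-in for MP_STATUS_CHECK_INTERVAL


def get_with_liveness_check(q, producer, interval=CHECK_INTERVAL):
    """Block on q.get(), waking every `interval` seconds to check the producer."""
    while True:
        try:
            return q.get(timeout=interval)
        except queue.Empty:
            if producer.is_alive():
                continue  # sender still running, keep waiting
            # The sender may have put an item just before exiting.
            try:
                return q.get_nowait()
            except queue.Empty:
                raise RuntimeError("producer exited without sending data") from None


def slow_producer(q):
    time.sleep(0.12)  # longer than two check intervals
    q.put(42)


results = queue.Queue()
worker = threading.Thread(target=slow_producer, args=(results,))
worker.start()
value = get_with_liveness_check(results, worker)
worker.join()
print(value)  # 42
```

Without the timeout, a consumer blocked on `get()` would hang forever if the sender died before producing anything.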
114 changes: 114 additions & 0 deletions oneflow/python/utils/data/_utils/collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import oneflow as flow
import re
import collections
import oneflow.python.utils as utils

string_classes = (str, bytes)

np_str_obj_array_pattern = re.compile(r"[SaUO]")


def default_convert(data):
    r"""Converts each NumPy array data field into a tensor"""
    elem_type = type(data)
    if isinstance(data, (flow.Tensor, flow._oneflow_internal.Tensor)):
        return data
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        # array of string classes and object
        if (
            elem_type.__name__ == "ndarray"
            and np_str_obj_array_pattern.search(data.dtype.str) is not None
        ):
            return data
        return flow.tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, collections.abc.Sequence) and not isinstance(
        data, string_classes
    ):
        return [default_convert(d) for d in data]
    else:
        # NOTE: torch just returns data here and does not raise an exception.
        raise TypeError(default_convert_err_msg_format.format(elem_type))


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)

default_convert_err_msg_format = (
    "default_convert: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, (flow.Tensor, flow._oneflow_internal.Tensor)):
        # TODO: tensor.storage()._new_shared(numel)
        return flow.experimental.stack(batch, dim=0)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([flow.Tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return flow.Tensor(batch)
    elif isinstance(elem, float):
        return flow.tensor(batch, dtype=flow.float64)
    elif isinstance(elem, int):
        return flow.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
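
Review note: the recursive structure of `default_collate` is easier to follow on plain Python values. The sketch below mirrors the same dispatch order (numbers, strings, mappings, sequences) but returns Python lists where the real function would build a tensor, so it runs without OneFlow. `collate_sketch` is a hypothetical name used only for illustration; it omits the tensor, NumPy and namedtuple branches.

```python
import collections.abc


def collate_sketch(batch):
    """Transpose a list of samples into per-field batches (lists stand in for tensors)."""
    elem = batch[0]
    if isinstance(elem, (int, float)):
        return list(batch)  # the real code would call flow.tensor(batch) here
    if isinstance(elem, (str, bytes)):
        return batch  # strings are passed through unchanged
    if isinstance(elem, collections.abc.Mapping):
        # Collate each key independently across all samples.
        return {key: collate_sketch([d[key] for d in batch]) for key in elem}
    if isinstance(elem, collections.abc.Sequence):
        size = len(elem)
        if not all(len(sample) == size for sample in batch):
            raise RuntimeError("each element in list of batch should be of equal size")
        # zip(*batch) turns [(a1, b1), (a2, b2)] into [(a1, a2), (b1, b2)].
        return [collate_sketch(list(samples)) for samples in zip(*batch)]
    raise TypeError("unsupported element type: {}".format(type(elem)))


print(collate_sketch([(1.0, "a"), (2.0, "b")]))
# [[1.0, 2.0], ['a', 'b']]
print(collate_sketch([{"x": 1, "y": [1, 2]}, {"x": 3, "y": [4, 5]}]))
# {'x': [1, 3], 'y': [[1, 4], [2, 5]]}
```

The string check has to come before the sequence check because `str` is itself a `Sequence`; the original function relies on the same ordering.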