Dev flow.utils.data part3 #5644
Merged
Commits (98)
ef446dc
add more datasets
Flowingsun007 3499641
add more transform funcs
Flowingsun007 08b262d
export interface
Flowingsun007 b2041fb
Merge branch 'master' of https://github.com.cnpmjs.org/Oneflow-Inc/on…
Flowingsun007 8048814
Merge branch 'dev_flow.utils.data-part3'
Flowingsun007 d88bbad
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 fc08cb5
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 2749006
Merge branch 'dev_flow.utils.data-part3' of https://github.com/Oneflo…
Flowingsun007 099c9d8
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 6048943
export datasets interface
Flowingsun007 e4c3864
merge master
Flowingsun007 2f5eabd
auto format by CI
oneflow-ci-bot b67c286
fix docs
Flowingsun007 2cc9b94
Merge branch 'fix_datasets_export' of https://github.com/Oneflow-Inc/…
Flowingsun007 7d50e8f
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 5c9f2cf
Merge branch 'fix_datasets_export' into dev_flow.utils.data-part3
Flowingsun007 6084d8b
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 e2d0843
skip test
Flowingsun007 26fd2cc
support DistributedSampler
Flowingsun007 272b97e
refine
Flowingsun007 561687d
add more transform function
Flowingsun007 136396e
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 68bffa8
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 a2e6951
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 c79f928
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 c705f48
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 0f48f6b
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 123bbdb
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 77fba7f
fix err import
Flowingsun007 27faabd
fix comment
Flowingsun007 9ac8d98
refine
Flowingsun007 06048c4
add more transform test
Flowingsun007 f4bf241
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 0cbefa3
refactor dataloader test
Flowingsun007 272df42
Merge branch 'dev_flow.utils.data-part3' of https://github.com/Oneflo…
Flowingsun007 305e399
refine
Flowingsun007 e2a935b
add ddp test
Flowingsun007 175b77d
refine
Flowingsun007 50eee85
refine
Flowingsun007 b43b5ce
add ddp test case
Flowingsun007 7f65a68
skil test
Flowingsun007 89ea598
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 bf04b24
add ddp test case
Flowingsun007 8dc0740
Merge branch 'dev_flow.utils.data-part3' of https://github.com/Oneflo…
Flowingsun007 4699460
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 94a9688
add test case
Flowingsun007 ecf6a22
refine
Flowingsun007 4085749
rm ddp test
Flowingsun007 df926ef
remove ddp test
Flowingsun007 1a6fc30
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 03e3052
auto format by CI
oneflow-ci-bot 8bcf3aa
format
Flowingsun007 afc2d56
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 ec8a35b
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 8e847e0
update api docs
Flowingsun007 27861a5
Merge branch 'dev_flow.utils.data-part3' of https://github.com/Oneflo…
Flowingsun007 86edc24
add utils.rst
Flowingsun007 3a469e8
auto format by CI
oneflow-ci-bot 21f05a1
fix ddp grad size
daquexian 55183e9
Merge remote-tracking branch 'origin/master' into fix_ddp_grad_size
daquexian 3b5cf1a
remove print
daquexian 9218800
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 38fe352
refine as comments
Flowingsun007 df07a10
Merge branch 'dev_flow.utils.data-part3' of https://github.com/Oneflo…
Flowingsun007 c25a097
refine
Flowingsun007 e19ea1b
auto format by CI
oneflow-ci-bot 84ec6c7
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 11f49aa
Merge branch 'master' into fix_ddp_grad_size
oneflow-ci-bot cbcc038
auto format by CI
oneflow-ci-bot aa4fcda
Merge remote-tracking branch 'origin/fix_ddp_grad_size' into dev_flow…
Flowingsun007 1d95021
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 6294526
refine
Flowingsun007 e8724b5
add ddp test
Flowingsun007 02a7c7e
Merge branch 'dev_flow.utils.data-part3' of https://github.com/Oneflo…
Flowingsun007 50ab414
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 57506ef
auto format by CI
oneflow-ci-bot 1c6aa4d
rm test case
Flowingsun007 546868a
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 7cbd439
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 423cf08
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 9492712
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 a435c97
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 7d4fc04
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 91a31a9
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 a0cf4ae
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot 089f0cf
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot bd566d1
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 efa9526
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 65fbd06
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 2dceb82
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot 0449fb2
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 52900ed
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 343d5e4
Merge branch 'master' into dev_flow.utils.data-part3
Flowingsun007 477c1e8
fix reshape
Flowingsun007 54f3497
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot c02c419
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot 3fcf059
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot 32d2d3d
Merge branch 'master' into dev_flow.utils.data-part3
oneflow-ci-bot
@@ -0,0 +1,140 @@
import os

import oneflow as flow
import oneflow.utils.vision.transforms as transforms


def load_data_cifar10(
    batch_size,
    data_dir="./data-test/cifar10",
    download=True,
    transform=None,
    source_url=None,
    num_workers=0,
):
    cifar10_train = flow.utils.vision.datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=download,
        transform=transform,
        source_url=source_url,
    )
    cifar10_test = flow.utils.vision.datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=download,
        transform=transform,
        source_url=source_url,
    )

    train_iter = flow.utils.data.DataLoader(
        cifar10_train, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_iter = flow.utils.data.DataLoader(
        cifar10_test, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_iter, test_iter


def load_data_mnist(
    batch_size, resize=None, root="./data/mnist", download=True, source_url=None
):
    """Download the MNIST dataset and then load it into memory."""
    root = os.path.expanduser(root)
    transformer = []
    if resize:
        transformer += [transforms.Resize(resize)]
    transformer += [transforms.ToTensor()]
    transformer = transforms.Compose(transformer)

    mnist_train = flow.utils.vision.datasets.MNIST(
        root=root,
        train=True,
        transform=transformer,
        download=download,
        source_url=source_url,
    )
    mnist_test = flow.utils.vision.datasets.MNIST(
        root=root,
        train=False,
        transform=transformer,
        download=download,
        source_url=source_url,
    )
    train_iter = flow.utils.data.DataLoader(mnist_train, batch_size, shuffle=True)
    test_iter = flow.utils.data.DataLoader(mnist_test, batch_size, shuffle=False)
    return train_iter, test_iter


# review comment: "Same as above"
def get_fashion_mnist_dataset(
    resize=None,
    root="./data-test/fashion-mnist",
    download=True,
    source_url=None,
):
    root = os.path.expanduser(root)
    trans = []
    if resize:
        trans.append(transforms.Resize(resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    mnist_train = flow.utils.vision.datasets.FashionMNIST(
        root=root,
        train=True,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    mnist_test = flow.utils.vision.datasets.FashionMNIST(
        root=root,
        train=False,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    return mnist_train, mnist_test


# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
def load_data_fashion_mnist(
    batch_size,
    resize=None,
    root="./data-test/fashion-mnist",
    download=True,
    source_url=None,
    num_workers=0,
):
    """Download the Fashion-MNIST dataset and then load it into memory."""
    root = os.path.expanduser(root)
    trans = []
    if resize:
        trans.append(transforms.Resize(resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    mnist_train = flow.utils.vision.datasets.FashionMNIST(
        root=root,
        train=True,
        transform=transform,
        download=download,
        source_url=source_url,
    )
    mnist_test = flow.utils.vision.datasets.FashionMNIST(
        root=root,
        train=False,
        transform=transform,
        download=download,
        source_url=source_url,
    )

    train_iter = flow.utils.data.DataLoader(
        mnist_train, batch_size, shuffle=True, num_workers=num_workers
    )
    test_iter = flow.utils.data.DataLoader(
        mnist_test, batch_size, shuffle=False, num_workers=num_workers
    )
    return train_iter, test_iter
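All of the helpers above build their preprocessing pipeline with `transforms.Compose`. As a rough, stdlib-only sketch of that pattern (not OneFlow's actual implementation; the toy transforms below are hypothetical stand-ins for `Resize`/`ToTensor`), a compose simply applies each callable in order:

```python
class Compose:
    """Minimal sketch of a transform pipeline: applies callables in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


# Hypothetical toy transforms standing in for Resize / ToTensor.
def double(x):
    return x * 2


def add_one(x):
    return x + 1


pipeline = Compose([double, add_one])
print(pipeline(3))  # 3 * 2 + 1 → 7
```

Order matters: `Compose([add_one, double])` would compute `(3 + 1) * 2 = 8` instead, which is why the helpers append `ToTensor` only after the optional `Resize`.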
@@ -0,0 +1,88 @@
# ref: https://zhuanlan.zhihu.com/p/178402798
import argparse
from tqdm import tqdm

import oneflow as flow
import oneflow.utils.vision as vision
import oneflow.nn as nn
import oneflow.F as F
import oneflow.utils as utils
import oneflow.optim as optim
import oneflow.distributed as dist
from oneflow.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_dataset():
    transform = vision.transforms.Compose([
        vision.transforms.ToTensor(),
        vision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    my_trainset = vision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )

    train_sampler = utils.data.distributed.DistributedSampler(my_trainset)
    trainloader = utils.data.DataLoader(
        my_trainset, batch_size=16, num_workers=2, sampler=train_sampler
    )
    return trainloader


parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
# local_rank = FLAGS.local_rank
local_rank = flow.device("cuda")

# torch.cuda.set_device(local_rank)
# dist.init_process_group(backend='nccl')

trainloader = get_dataset()

model = ToyModel().to(local_rank)

ckpt_path = None
# if dist.get_rank() == 0 and ckpt_path is not None:
#     model.load_state_dict(torch.load(ckpt_path))
# DDP model
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
model = DDP(model)
optimizer = optim.SGD(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss().to(local_rank)

model.train()
iterator = tqdm(range(10))
for epoch in iterator:
    trainloader.sampler.set_epoch(epoch)
    for data, label in trainloader:
        data, label = data.to(local_rank), label.to(local_rank)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_func(prediction, label)
        loss.backward()
        iterator.desc = "loss = %0.3f" % loss.numpy()
        optimizer.step()

# if dist.get_rank() == 0:
#     torch.save(model.module.state_dict(), "%d.ckpt" % epoch)

################
# export CUDA_VISIBLE_DEVICES="0,1"
# python -m oneflow.distributed.launch --nproc_per_node 2 test_ddp_flow.py
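The `support DistributedSampler` commit is exercised here through `utils.data.distributed.DistributedSampler`, which hands each rank an equally sized slice of the dataset. A stdlib-only sketch of the assumed algorithm (modeled on the PyTorch-style sampler: epoch-seeded shuffle, pad to a multiple of the world size, then stride by rank; the function name is illustrative, not OneFlow's API):

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch=0, shuffle=True):
    """Return the index list one rank would iterate over."""
    indices = list(range(dataset_len))
    if shuffle:
        # Epoch-seeded RNG: every rank computes the identical permutation.
        random.Random(epoch).shuffle(indices)
    # Pad with the leading indices so every shard has the same length.
    total = math.ceil(dataset_len / num_replicas) * num_replicas
    indices += indices[: total - dataset_len]
    # Each rank takes every num_replicas-th index, offset by its rank.
    return indices[rank:total:num_replicas]


shards = [shard_indices(10, num_replicas=3, rank=r) for r in range(3)]
# Padding makes all three shards equal length (4 each for 10 samples),
# and together they cover the whole dataset, with a few duplicates.
```

Equal shard lengths matter because DDP ranks must execute the same number of iterations per epoch; otherwise a rank would block forever waiting for collectives from peers that already finished.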
@@ -0,0 +1,85 @@
# ref: https://zhuanlan.zhihu.com/p/178402798
import argparse
from tqdm import tqdm

import torch
import torchvision as vision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils as utils
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_dataset():
    transform = vision.transforms.Compose([
        vision.transforms.ToTensor(),
        vision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    my_trainset = vision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )

    train_sampler = utils.data.distributed.DistributedSampler(my_trainset)
    trainloader = utils.data.DataLoader(
        my_trainset, batch_size=16, num_workers=2, sampler=train_sampler
    )
    return trainloader


parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank

torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

trainloader = get_dataset()

model = ToyModel().to(local_rank)

ckpt_path = None
if dist.get_rank() == 0 and ckpt_path is not None:
    model.load_state_dict(torch.load(ckpt_path))
# DDP model
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
optimizer = optim.SGD(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss().to(local_rank)

model.train()
iterator = tqdm(range(10))
for epoch in iterator:
    trainloader.sampler.set_epoch(epoch)
    for data, label in trainloader:
        data, label = data.to(local_rank), label.to(local_rank)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_func(prediction, label)
        loss.backward()
        iterator.desc = "loss = %0.3f" % loss.item()  # .item() extracts the Python scalar
        optimizer.step()

# if dist.get_rank() == 0:
#     torch.save(model.module.state_dict(), "%d.ckpt" % epoch)

################
# export CUDA_VISIBLE_DEVICES="0,1"
# python -m torch.distributed.launch --nproc_per_node 2 test_ddp_torch.py
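Both DDP scripts call `trainloader.sampler.set_epoch(epoch)` at the top of each epoch. The reason is that the sampler derives its shuffle seed from the epoch number: all processes compute the same permutation without communicating, yet every epoch still sees a fresh order. A stdlib sketch of that assumed behavior (the class name is hypothetical):

```python
import random


class EpochSeededShuffler:
    """Hypothetical stand-in for a DistributedSampler's shuffling logic:
    the permutation depends only on the epoch number, so every rank
    independently arrives at the same order."""

    def __init__(self, n):
        self.n = n
        self.epoch = 0

    def set_epoch(self, epoch):
        # Without this call the seed never changes, and every epoch
        # would replay the same sample order.
        self.epoch = epoch

    def order(self):
        idx = list(range(self.n))
        random.Random(self.epoch).shuffle(idx)  # deterministic per epoch
        return idx


sampler = EpochSeededShuffler(8)
sampler.set_epoch(0)
epoch0 = sampler.order()
sampler.set_epoch(1)
epoch1 = sampler.order()
# Re-seeding with the same epoch reproduces the identical permutation.
```

Forgetting `set_epoch` is a classic DDP pitfall: training still runs, but each rank iterates the same shuffled order every epoch, which weakens the effect of shuffling.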
Review comment: Wrapper functions like these, which add an extra layer, don't seem necessary. PyTorch doesn't have them; if we add our own wrappers, the cost of educating users is high.
Reply: Right, this is only used by the test cases, to avoid repeating the same data-loading code in every test case.
Reply: This file lives under the test directory, so keeping it as data_utils is fine.