symbolic+imperative nn interface #5705
@@ -0,0 +1,67 @@ (new file)
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx

def mnist_iterator(batch_size, input_shape):
    """return train and val iterators for mnist"""
    # download data
    get_data.GetMNIST_ubyte()
    # flatten the images unless a 3-d (channel, height, width) shape is requested
    flat = len(input_shape) != 3

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)

    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)

    return (train_dataiter, val_dataiter)


def cifar10_iterator(batch_size, data_shape):
    train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        # mean_img="data/cifar/mean.bin",
        data_shape=data_shape,
        batch_size=batch_size,
        rand_crop=True,
        rand_mirror=True)

    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        # mean_img="data/cifar/mean.bin",
        rand_crop=False,
        rand_mirror=False,
        data_shape=data_shape,
        batch_size=batch_size)

    return train, val

class DummyIter(mx.io.DataIter):
    def __init__(self, batch_size, data_shape):
        self.data_shape = (batch_size,) + data_shape
        self.label_shape = (batch_size,)
        self.provide_data = [('data', self.data_shape)]
        self.provide_label = [('softmax_label', self.label_shape)]

    def next(self):
        # returns the same zero-filled batch on every call and never raises
        # StopIteration, so it is only suitable for fixed-step benchmarking
        return mx.io.DataBatch(data=[mx.nd.zeros(self.data_shape)],
                               label=[mx.nd.zeros(self.label_shape)])


def dummy_iterator(batch_size, data_shape):
    return DummyIter(batch_size, data_shape), DummyIter(batch_size, data_shape)
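
For reference, the dummy iterator can stand in wherever a real iterator is expected when benchmarking; a minimal usage sketch (the shapes here are arbitrary). Since DummyIter.next() never raises StopIteration, it only suits loops with a fixed step count:

train_iter, val_iter = dummy_iterator(batch_size=32, data_shape=(3, 32, 32))
for _ in range(10):              # fixed number of steps; the iterator never ends
    batch = train_iter.next()    # the same zero-filled DataBatch every call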
@@ -0,0 +1,65 @@ (new file)
# pylint: skip-file
from data import mnist_iterator
import mxnet as mx
from mxnet.contrib import nn
import numpy as np
import logging
from mxnet.contrib import autograd as ag
logging.basicConfig(level=logging.DEBUG)

# define network

net = nn.Sequential()
net.add(nn.Dense(128, in_units=784, activation='relu'))
net.add(nn.Dense(64, in_units=128, activation='relu'))
net.add(nn.Dense(10, in_units=64))

# data

train_data, val_data = mnist_iterator(batch_size=100, input_shape=(784,))

# train

def test(ctx):
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    print('validation acc: %s=%f' % metric.get())

def train(epoch, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    optim = nn.Optim(net.params, 'sgd', {'learning_rate': 0.1})
    metric = mx.metric.Accuracy()

    for i in range(epoch):
        train_data.reset()
        for batch in train_data:
            data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            with ag.train_section():
                # one forward/backward pass per device slice (data parallelism)
                for x, y in zip(data, label):
[Review comment] Why are we having an inner loop? Is it iterating through the batch one example at a time? Is it used to demonstrate that we can do this or is it suggesting the user should do this?

[Reply] This is doing data parallel on multiple gpus.

[Reply] I see. Is it possible to hide the data-par mechanism like the original mxnet way? I guess in the imperative style API, the users always explicitly write for loops to handle time index or other dimensions. So it seems to be reasonable to always assume the 1st dimension is batch.

[Reply] It also confuses me. I thought you were iterating example by example. Probably rename …
                    z = net(x)
                    loss = nn.loss.softmax_cross_entropy_loss(z, y)
                    ag.compute_gradient([loss])
                    outputs.append(z)
            metric.update(label, outputs)
            optim.step(batch.data[0].shape[0])
        name, acc = metric.get()
        metric.reset()
        print('training acc at epoch %d: %s=%f' % (i, name, acc))
        test(ctx)

    net.params.save('mnist.params')


if __name__ == '__main__':
    train(10, [mx.cpu(0), mx.cpu(1)])
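
Regarding the data-parallel inner loop discussed in the review thread above: a minimal sketch of what nn.utils.load_data presumably does here (the helper name split_batch is hypothetical; this is an illustration, not the PR's implementation). The batch is split along the batch axis into one slice per context and each slice is copied to its device; the training loop then runs one forward/backward per slice, and optim.step() applies the accumulated gradients:

import mxnet as mx

def split_batch(data, ctx_list, batch_axis=0):
    # split into len(ctx_list) equal slices along the batch axis
    # (assumes the batch size is divisible by the number of contexts)
    slices = mx.nd.split(data, num_outputs=len(ctx_list), axis=batch_axis)
    return [s.as_in_context(ctx) for s, ctx in zip(slices, ctx_list)]

shards = split_batch(mx.nd.zeros((100, 784)), [mx.cpu(0), mx.cpu(1)])
print([s.shape for s in shards])   # [(50, 784), (50, 784)]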
@@ -0,0 +1,197 @@ (new file)
from __future__ import division

import time
import mxnet as mx
from mxnet.contrib import nn
from mxnet.contrib import autograd as ag
from data import *

def conv3x3(filters, stride, in_filters):
    return nn.Conv2D(filters, kernel_size=3, strides=stride, padding=1,
                     use_bias=False, in_filters=in_filters)

class BasicBlock(nn.Layer):
    def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
        super(BasicBlock, self).__init__(**kwargs)
        with self.scope:
            self.bn1 = nn.BatchNorm(num_features=in_filters)
            self.conv1 = conv3x3(filters, stride, in_filters)
            self.bn2 = nn.BatchNorm(num_features=filters)
            self.conv2 = conv3x3(filters, 1, filters)
            if downsample:
                self.downsample = nn.Conv2D(filters, 1, stride, use_bias=False,
                                            in_filters=in_filters)
            else:
                self.downsample = None

    def generic_forward(self, domain, x):
        if not self.downsample:
            residual = x
        x = self.bn1(x)
        x = domain.Activation(x, act_type='relu')
        if self.downsample:
            residual = self.downsample(x)
        x = self.conv1(x)

        x = self.bn2(x)
        x = domain.Activation(x, act_type='relu')
        x = self.conv2(x)

        return x + residual


class Bottleneck(nn.Layer):
    def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
        super(Bottleneck, self).__init__(**kwargs)
        with self.scope:
            self.bn1 = nn.BatchNorm(num_features=in_filters)
            self.conv1 = conv3x3(filters//4, 1, in_filters)
            self.bn2 = nn.BatchNorm(num_features=filters//4)
            self.conv2 = conv3x3(filters//4, stride, filters//4)
            self.bn3 = nn.BatchNorm(num_features=filters//4)
            self.conv3 = conv3x3(filters, 1, filters//4)
            if downsample:
                self.downsample = nn.Conv2D(filters, 1, stride, use_bias=False,
                                            in_filters=in_filters)
            else:
                self.downsample = None

    def generic_forward(self, domain, x):
        if not self.downsample:
            residual = x
        x = self.bn1(x)
        x = domain.Activation(x, act_type='relu')
        if self.downsample:
            residual = self.downsample(x)
        x = self.conv1(x)

        x = self.bn2(x)
        x = domain.Activation(x, act_type='relu')
        x = self.conv2(x)

        x = self.bn3(x)
        x = domain.Activation(x, act_type='relu')
        x = self.conv3(x)

        return x + residual

class Resnet(nn.Layer):
    def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
        super(Resnet, self).__init__(**kwargs)
        with self.scope:
            assert len(layers) == len(filters) - 1
            self._thumbnail = thumbnail
            self.bn_data = nn.BatchNorm(num_features=3, scale=False, center=False)
            if thumbnail:
                self.conv0 = conv3x3(filters[0], 1, 3)
            else:
                self.conv0 = nn.Conv2D(filters[0], 7, 2, 3, use_bias=False,
                                       in_filters=3)
                self.bn0 = nn.BatchNorm(num_features=filters[0])
                self.pool0 = nn.MaxPool2D(3, 2, 1)

            self.body = nn.Sequential()
            in_filters = filters[0]
            for i in range(len(layers)):
                stride = 1 if i == 0 else 2
                self.body.add(self._make_layer(block, layers[i], filters[i+1],
                                               stride, in_filters=in_filters))
                in_filters = filters[i+1]

            self.bn1 = nn.BatchNorm(num_features=in_filters)
            self.pool1 = nn.GlobalAvgPool2D()
            self.dense1 = nn.Dense(classes, in_units=in_filters)

    def _make_layer(self, block, layers, filters, stride, in_filters=0):
        layer = nn.Sequential()
        layer.add(block(filters, stride, True, in_filters=in_filters))
        for i in range(layers-1):
            layer.add(block(filters, 1, False, in_filters=filters))
        return layer

    def generic_forward(self, domain, x):
        x = self.bn_data(x)
        x = self.conv0(x)
        if not self._thumbnail:
            x = self.bn0(x)
            x = domain.Activation(x, act_type='relu')
            x = self.pool0(x)

        x = self.body(x)

        x = self.bn1(x)
        x = domain.Activation(x, act_type='relu')
        x = self.pool1(x)
        x = x.reshape((0, -1))
        x = self.dense1(x)

        return x


def resnet18_cifar(classes):
    return Resnet(BasicBlock, classes, [2, 2, 2], [16, 16, 32, 64], True)

def resnet50_imagenet(classes):
    return Resnet(Bottleneck, classes, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], False)

net = resnet18_cifar(10)
batch_size = 32*8
train_data, val_data = cifar10_iterator(batch_size, (3, 32, 32))


def test(ctx):
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    print('validation acc: %s=%f' % metric.get())


def train(epoch, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    optim = nn.Optim(net.params, 'sgd', {'learning_rate': 0.1})
    metric = mx.metric.Accuracy()

    for i in range(epoch):
        tic = time.time()
        train_data.reset()
        btic = time.time()
        for batch in train_data:
            data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            with ag.train_section():
                for x, y in zip(data, label):
                    z = net(x)
                    loss = nn.loss.softmax_cross_entropy_loss(z, y)
                    ag.compute_gradient([loss])
                    outputs.append(z)
            optim.step(batch.data[0].shape[0])
[Review comment] change to …
            metric.update(label, outputs)
            print(batch_size/(time.time()-btic))  # samples/sec for this batch
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        print('training acc at epoch %d: %s=%f' % (i, name, acc))
        print('time: %f' % (time.time()-tic))
        test(ctx)

    net.params.save('mnist.params')

if __name__ == '__main__':
    train(200, [mx.gpu(i) for i in range(8)])
    # after imperative training, reuse the same net symbolically through the Module API
    import logging
    logging.basicConfig(level=logging.DEBUG)
    data = mx.sym.var('data')
    out = net(data)
    softmax = mx.sym.SoftmaxOutput(out, name='softmax')
    mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(8)])
    mod.fit(train_data, num_epoch=100,
            batch_end_callback=mx.callback.Speedometer(batch_size, 10))
[Review comment] Suggest to add a function in test_utils.py, for example …, then we can use NDArrayIter for them.
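
Something along these lines, perhaps (the helper below is hypothetical, not part of this PR; it only illustrates how random arrays could feed mx.io.NDArrayIter in place of a downloaded dataset):

import numpy as np
import mxnet as mx

def get_random_data(num_examples=1000, data_shape=(3, 32, 32), num_classes=10):
    # random inputs and labels with the same shapes a real loader would produce
    data = np.random.uniform(-1, 1, (num_examples,) + data_shape).astype('float32')
    label = np.random.randint(0, num_classes, num_examples).astype('float32')
    return data, label

data, label = get_random_data()
train_iter = mx.io.NDArrayIter(data, label, batch_size=32, shuffle=True)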
[Reply] This is temporary. I'll do something like pytorch.dataset later.
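
For context, the pytorch-style dataset referenced here is roughly an indexable collection of (example, label) pairs that a separate loader batches; a minimal sketch of that shape (names hypothetical):

class ArrayDataset(object):
    # indexable dataset: one (example, label) pair per index
    def __init__(self, data, label):
        self._data = data
        self._label = label

    def __getitem__(self, idx):
        return self._data[idx], self._label[idx]

    def __len__(self):
        return len(self._data)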