This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

symbolic+imperative nn interface #5705

Closed
wants to merge 17 commits
67 changes: 67 additions & 0 deletions example/autograd/data.py
@@ -0,0 +1,67 @@
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
Member

Suggest adding a function to test_utils.py, for example:

def get_mnist():
    """Download and load the MNIST dataset

    Returns
    -------
    dict
        A dict containing the training and test images and labels
    """
    def read_data(label_url, image_url):
        with gzip.open(mx.test_utils.download(label_url)) as flbl:
            magic, num = struct.unpack(">II", flbl.read(8))
            label = np.fromstring(flbl.read(), dtype=np.int8)
        with gzip.open(mx.test_utils.download(image_url), 'rb') as fimg:
            magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
            image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
            image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32)/255
        return (label, image)

    # changed to mxnet.io for more stable hosting
    # path='http://yann.lecun.com/exdb/mnist/'
    path='http://data.mxnet.io/data/mnist/'
    (train_lbl, train_img) = read_data(
        path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')
    (test_lbl, test_img) = read_data(
            path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
    return {'train_data':train_img, 'train_label':train_lbl,
            'test_data':test_img, 'test_label':test_lbl}

Then we can use NDArrayIter for them, for example:
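
(A minimal sketch, assuming the get_mnist helper above is added to test_utils.py as suggested; batch_size=100 is just an example value.)

import mxnet as mx

mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
                               batch_size=100, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
                             batch_size=100)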

Contributor Author

This is temporary. I'll add something like PyTorch's Dataset later.

import mxnet as mx

def mnist_iterator(batch_size, input_shape):
    """return train and val iterators for mnist"""
    # download data
    get_data.GetMNIST_ubyte()
    flat = False if len(input_shape) == 3 else True

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)

    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)

    return (train_dataiter, val_dataiter)


def cifar10_iterator(batch_size, data_shape):
    train = mx.io.ImageRecordIter(
        path_imgrec = "data/cifar/train.rec",
        # mean_img = "data/cifar/mean.bin",
        data_shape = data_shape,
        batch_size = batch_size,
        rand_crop = True,
        rand_mirror = True)

    val = mx.io.ImageRecordIter(
        path_imgrec = "data/cifar/test.rec",
        # mean_img = "data/cifar/mean.bin",
        rand_crop = False,
        rand_mirror = False,
        data_shape = data_shape,
        batch_size = batch_size)

    return train, val

class DummyIter(mx.io.DataIter):
    def __init__(self, batch_size, data_shape):
        self.data_shape = (batch_size,) + data_shape
        self.label_shape = (batch_size,)
        self.provide_data = [('data', self.data_shape)]
        self.provide_label = [('softmax_label', self.label_shape)]

    def next(self):
        return mx.io.DataBatch(data=[mx.nd.zeros(self.data_shape)],
                               label=[mx.nd.zeros(self.label_shape)])


def dummy_iterator(batch_size, data_shape):
    return DummyIter(batch_size, data_shape), DummyIter(batch_size, data_shape)
65 changes: 65 additions & 0 deletions example/autograd/mnist.py
@@ -0,0 +1,65 @@
# pylint: skip-file
from data import mnist_iterator
import mxnet as mx
from mxnet.contrib import nn
import numpy as np
import logging
from mxnet.contrib import autograd as ag
logging.basicConfig(level=logging.DEBUG)

# define network

net = nn.Sequential()
net.add(nn.Dense(128, in_units=784, activation='relu'))
net.add(nn.Dense(64, in_units=128, activation='relu'))
net.add(nn.Dense(10, in_units=64))

# data

train_data, val_data = mnist_iterator(batch_size=100, input_shape = (784,))

# train

def test(ctx):
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    print 'validation acc: %s=%f'%metric.get()

def train(epoch, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    optim = nn.Optim(net.params, 'sgd', {'learning_rate': 0.1})
    metric = mx.metric.Accuracy()

    for i in range(epoch):
        train_data.reset()
        for batch in train_data:
            data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            with ag.train_section():
                for x, y in zip(data, label):
Contributor

Why do we have an inner loop here? Is it iterating through the batch one example at a time? Is it there to demonstrate that we can do this, or is it suggesting that users should do this?

Contributor Author

This is doing data parallelism across multiple GPUs.
We could add a DataParallel utility like PyTorch's, but collecting the outputs onto one GPU is not a good idea.

Contributor

I see. Is it possible to hide the data-parallel mechanism the way the original MXNet API does? I guess that in the imperative-style API users always explicitly write for loops to handle the time index or other dimensions, so it seems reasonable to always assume the first dimension is the batch dimension.

Member

It also confuses me; I thought you were iterating example by example. Probably rename load_data to something like split_data_to_dev.
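
A minimal sketch of what such a splitting helper could look like (hypothetical; the exact behavior of nn.utils.load_data in this PR may differ):

import mxnet as mx

def split_data_to_dev(data, ctx_list, batch_axis=0):
    """Slice a batch NDArray along the batch axis and copy one slice to each device."""
    assert batch_axis == 0, "this sketch only handles batch_axis=0"
    num_dev = len(ctx_list)
    assert data.shape[0] % num_dev == 0, \
        "batch size must be divisible by the number of devices"
    step = data.shape[0] // num_dev
    return [data[i*step:(i+1)*step].as_in_context(ctx)
            for i, ctx in enumerate(ctx_list)]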

                    z = net(x)
                    loss = nn.loss.softmax_cross_entropy_loss(z, y)
                    ag.compute_gradient([loss])
                    outputs.append(z)
            metric.update(label, outputs)
            optim.step(batch.data[0].shape[0])
        name, acc = metric.get()
        metric.reset()
        print 'training acc at epoch %d: %s=%f'%(i, name, acc)
        test(ctx)

    net.params.save('mnist.params')


if __name__ == '__main__':
    train(10, [mx.cpu(0), mx.cpu(1)])
197 changes: 197 additions & 0 deletions example/autograd/resnet.py
@@ -0,0 +1,197 @@
from __future__ import division

import time
import mxnet as mx
from mxnet.contrib import nn
from mxnet.contrib import autograd as ag
from data import *

def conv3x3(filters, stride, in_filters):
    return nn.Conv2D(filters, kernel_size=3, strides=stride, padding=1,
                     use_bias=False, in_filters=in_filters)

class BasicBlock(nn.Layer):
    def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
        super(BasicBlock, self).__init__(**kwargs)
        with self.scope:
            self.bn1 = nn.BatchNorm(num_features=in_filters)
            self.conv1 = conv3x3(filters, stride, in_filters)
            self.bn2 = nn.BatchNorm(num_features=filters)
            self.conv2 = conv3x3(filters, 1, filters)
            if downsample:
                self.downsample = nn.Conv2D(filters, 1, stride, use_bias=False,
                                            in_filters=in_filters)
            else:
                self.downsample = None

    def generic_forward(self, domain, x):
        if not self.downsample:
            residual = x
        x = self.bn1(x)
        x = domain.Activation(x, act_type='relu')
        if self.downsample:
            residual = self.downsample(x)
        x = self.conv1(x)

        x = self.bn2(x)
        x = domain.Activation(x, act_type='relu')
        x = self.conv2(x)

        return x + residual


class Bottleneck(nn.Layer):
    def __init__(self, filters, stride, downsample=False, in_filters=0, **kwargs):
        super(Bottleneck, self).__init__(**kwargs)
        with self.scope:
            self.bn1 = nn.BatchNorm(num_features=in_filters)
            self.conv1 = conv3x3(filters//4, 1, in_filters)
            self.bn2 = nn.BatchNorm(num_features=filters//4)
            self.conv2 = conv3x3(filters//4, stride, filters//4)
            self.bn3 = nn.BatchNorm(num_features=filters//4)
            self.conv3 = conv3x3(filters, 1, filters//4)
            if downsample:
                self.downsample = nn.Conv2D(filters, 1, stride, use_bias=False,
                                            in_filters=in_filters)
            else:
                self.downsample = None

    def generic_forward(self, domain, x):
        if not self.downsample:
            residual = x
        x = self.bn1(x)
        x = domain.Activation(x, act_type='relu')
        if self.downsample:
            residual = self.downsample(x)
        x = self.conv1(x)

        x = self.bn2(x)
        x = domain.Activation(x, act_type='relu')
        x = self.conv2(x)

        x = self.bn3(x)
        x = domain.Activation(x, act_type='relu')
        x = self.conv3(x)

        return x + residual

class Resnet(nn.Layer):
    def __init__(self, block, classes, layers, filters, thumbnail=False, **kwargs):
        super(Resnet, self).__init__(**kwargs)
        with self.scope:
            assert len(layers) == len(filters) - 1
            self._thumbnail = thumbnail
            self.bn_data = nn.BatchNorm(num_features=3, scale=False, center=False)
            if thumbnail:
                self.conv0 = conv3x3(filters[0], 1, 3)
            else:
                self.conv0 = nn.Conv2D(filters[0], 7, 2, 3, use_bias=False,
                                       in_filters=3)
                self.bn0 = nn.BatchNorm(num_features=filters[0])
                self.pool0 = nn.MaxPool2D(3, 2, 1)

            self.body = nn.Sequential()
            in_filters = filters[0]
            for i in range(len(layers)):
                stride = 1 if i == 0 else 2
                self.body.add(self._make_layer(block, layers[i], filters[i+1],
                                               stride, in_filters=in_filters))
                in_filters = filters[i+1]

            self.bn1 = nn.BatchNorm(num_features=in_filters)
            self.pool1 = nn.GlobalAvgPool2D()
            self.dense1 = nn.Dense(classes, in_units=in_filters)

    def _make_layer(self, block, layers, filters, stride, in_filters=0):
        layer = nn.Sequential()
        layer.add(block(filters, stride, True, in_filters=in_filters))
        for i in range(layers-1):
            layer.add(block(filters, 1, False, in_filters=filters))
        return layer

    def generic_forward(self, domain, x):
        x = self.bn_data(x)
        x = self.conv0(x)
        if not self._thumbnail:
            x = self.bn0(x)
            x = domain.Activation(x, act_type='relu')
            x = self.pool0(x)

        x = self.body(x)

        x = self.bn1(x)
        x = domain.Activation(x, act_type='relu')
        x = self.pool1(x)
        x = x.reshape((0, -1))
        x = self.dense1(x)

        return x


def resnet18_cifar(classes):
    return Resnet(BasicBlock, classes, [2, 2, 2], [16, 16, 32, 64], True)

def resnet50_imagenet(classes):
    return Resnet(Bottleneck, classes, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], False)

net = resnet18_cifar(10)
batch_size = 32*8
train_data, val_data = cifar10_iterator(batch_size, (3, 32, 32))


def test(ctx):
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    print 'validation acc: %s=%f'%metric.get()


def train(epoch, ctx):
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    net.params.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    optim = nn.Optim(net.params, 'sgd', {'learning_rate': 0.1})
    metric = mx.metric.Accuracy()

    for i in range(epoch):
        tic = time.time()
        train_data.reset()
        btic = time.time()
        for batch in train_data:
            data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            with ag.train_section():
                for x, y in zip(data, label):
                    z = net(x)
                    loss = nn.loss.softmax_cross_entropy_loss(z, y)
                    ag.compute_gradient([loss])
                    outputs.append(z)
            optim.step(batch.data[0].shape[0])
Member

Change this to batch_size.

            metric.update(label, outputs)
            print batch_size/(time.time()-btic)
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        print 'training acc at epoch %d: %s=%f'%(i, name, acc)
        print 'time: %f'%(time.time()-tic)
        test(ctx)

    net.params.save('mnist.params')

if __name__ == '__main__':
    train(200, [mx.gpu(i) for i in range(8)])
    import logging
    logging.basicConfig(level=logging.DEBUG)
    data = mx.sym.var('data')
    out = net(data)
    softmax = mx.sym.SoftmaxOutput(out, name='softmax')
    mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(8)])
    mod.fit(train_data, num_epoch=100, batch_end_callback = mx.callback.Speedometer(batch_size, 10))
32 changes: 32 additions & 0 deletions include/mxnet/c_api.h
@@ -414,6 +414,26 @@ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
                                  int *out_dev_type,
                                  int *out_dev_id);
/*!
 * \brief detach an ndarray from the computation graph by clearing its entry_
 * \param handle NDArray handle
 * \param out output handle of the detached NDArray
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out);
/*!
 * \brief set the flag for gradient array state.
 * \param handle NDArray handle
 * \param state the new state.
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArraySetGradState(NDArrayHandle handle, int state);
/*!
 * \brief get the flag for gradient array state.
 * \param handle NDArray handle
 * \param out the current state.
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
//--------------------------------
// Part 2: functions on NDArray
//--------------------------------
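
A hypothetical C usage sketch for the new handle-level calls above (not part of this diff; error handling omitted):

#include <mxnet/c_api.h>

void detach_example(NDArrayHandle arr) {
  /* create a handle that refers to the same data as `arr` but is cut off
     from the recorded computation graph */
  NDArrayHandle detached;
  MXNDArrayDetach(arr, &detached);

  /* query the gradient-state flag of the new handle */
  int state = 0;
  MXNDArrayGetGradState(detached, &state);
}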
@@ -548,6 +568,18 @@ MXNET_DLL int MXAutogradMarkVariables(mx_uint num_var,
 */
MXNET_DLL int MXAutogradComputeGradient(mx_uint num_output,
                                        NDArrayHandle* output_handles);
/*!
 * \brief compute the gradients of outputs w.r.t. variables
 * \param num_output number of output NDArrays
 * \param output_handles output NDArrays
 * \param ograd_handles head gradients for the output NDArrays
 * \param retain_graph whether to keep the graph after backward
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXAutogradBackward(mx_uint num_output,
                                 NDArrayHandle* output_handles,
                                 NDArrayHandle* ograd_handles,
                                 int retain_graph);
//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
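
A hypothetical usage sketch of MXAutogradBackward (not part of this diff): run backward from a single loss NDArray, passing a NULL head-gradient handle (assumption: a NULL entry means a default head gradient of ones, mirroring how the Python frontend calls this API) and without retaining the graph. Error handling is omitted.

#include <mxnet/c_api.h>

void backward_example(NDArrayHandle loss) {
  NDArrayHandle outputs[1] = {loss};
  NDArrayHandle ograds[1] = {NULL};  /* assumption: NULL => implicit gradient of ones */
  int ret = MXAutogradBackward(1, outputs, ograds, 0 /* retain_graph */);
  (void)ret;  /* 0 on success, -1 on failure; see MXGetLastError() */
}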