# 线性回归 — 使用Gluon

In [5]:
# Create a synthetic dataset: y = X . true_w + true_b + small normal noise
from mxnet import ndarray as nd
from mxnet import autograd
from mxnet import gluon

num_inputs = 2
num_examples = 1000

# Ground-truth parameters used to generate the labels; true_w must hold one
# weight per input feature (len(true_w) == num_inputs).
true_w = [2, -3.4]
true_b = 4.2

X = nd.random_normal(shape=(num_examples, num_inputs))
# Accumulate one weighted feature column per entry of true_w rather than
# spelling out exactly two terms, so the labels keep following true_w if the
# number of features is changed.
y = sum(w * X[:, i] for i, w in enumerate(true_w)) + true_b
# Perturb the labels with normal noise scaled by 0.01.
y += .01 * nd.random_normal(shape=y.shape)
print(X.shape, y.shape)
print(X[0], y[0])

((1000L, 2L), (1000L,))
(
[-1.53171515 -0.88783199]
<NDArray 2 @cpu(0)>, 
[ 4.13801384]
<NDArray 1 @cpu(0)>)


In [6]:
# Data loading: wrap (X, y) in a dataset and read it in shuffled mini-batches
batch_size = 10
dataset = gluon.data.ArrayDataset(X, y)
data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)
# Peek at the first mini-batch only, then stop iterating.
for data, label in data_iter:
    # Function-call form of print (as in the dataset cell) so the cell also
    # parses on Python 3; the bare `print a, b` statement is Python 2 only.
    print(data, label)
    break


[[ 0.76334846 -0.71399009]
 [ 0.52214307  0.50699794]
 [-1.60041606 -1.3150456 ]
 [-0.43577141 -2.11474633]
 [-0.74014026 -1.06094301]
 [-0.50176102 -0.31240413]
 [ 1.91820276  0.39311436]
 [ 0.19981819 -0.76805735]
 [-1.53171515 -0.88783199]
 [ 1.11138046 -0.21727505]]
<NDArray 10x2 @cpu(0)> 
[  8.14267063   3.52273297   5.47352409  10.52182484   6.34672022
   4.25974178   6.70802546   7.20674849   4.13801384   7.14989901]
<NDArray 10 @cpu(0)>


In [7]:
# Define the model: a sequential container holding one dense layer
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(units=1))  # a single output node

In [8]:
# Initialize the model parameters (no arguments passed, so Gluon's defaults apply)
net.initialize()

In [9]:
# Loss function: Gluon's L2Loss, called on (output, label) in the training loop
square_loss = gluon.loss.L2Loss()

In [10]:
# Optimizer: SGD over every parameter of the network, learning rate 0.1
trainer = gluon.Trainer(
    params=net.collect_params(),
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1},
)

In [11]:
# Training: several passes over data_iter, reporting the mean loss per epoch
epochs = 5
for e in range(epochs):
    total_loss = 0
    for data, label in data_iter:
        # Record the forward pass so autograd can differentiate the loss.
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
        loss.backward()
        # Step with the size of the batch actually received instead of a
        # re-declared batch_size constant: the value handed to step() then
        # cannot drift from the DataLoader's setting and stays correct for a
        # shorter final batch.
        trainer.step(data.shape[0])
        # Sum the per-example losses of this batch into the epoch total.
        total_loss += nd.sum(loss).asscalar()
    print("Epoch %d, average loss: %f" % (e, total_loss/num_examples))

Epoch 0, average loss: 0.883264
Epoch 1, average loss: 0.000050
Epoch 2, average loss: 0.000051
Epoch 3, average loss: 0.000051
Epoch 4, average loss: 0.000051


In [12]:
# Take the first (and only) layer added to net and display its learned weight
# next to the true_w used to generate the data.
dense = net[0]
true_w, dense.weight.data()

([2, -3.4], 
 [[ 2.0007093  -3.39950085]]
 <NDArray 1x2 @cpu(0)>)

In [15]:
# Display the learned bias of the same layer next to the true_b used to
# generate the data.
true_b, dense.bias.data()

(4.2, 
 [ 4.20089436]
 <NDArray 1 @cpu(0)>)