In [None]:
from mxnet import ndarray as nd
from mxnet import gluon
from mxnet import autograd

# Dataset dimensions: number of training rows, number of test rows, and
# number of input features per row.
num_train, num_test, num_inputs = 20, 100, 200

In [None]:
# Ground-truth parameters of the synthetic linear model: a column vector of
# num_inputs weights, each equal to 0.01, and a scalar bias of 0.05.
true_b = 0.05
true_W = 0.01 * nd.ones((num_inputs, 1))

In [None]:
# Draw every sample (train + test) at once from a standard normal.
X = nd.random_normal(shape=(num_train + num_test,num_inputs))
# Labels follow the ground-truth linear model y = X . true_W + true_b.
# The bias must be included here, otherwise true_b is never used and the
# learned bias has nothing to converge to.
y = nd.dot(X,true_W) + true_b
# Add small Gaussian observation noise (scale 0.01) to the labels.
y += .01 * nd.random_normal(shape=y.shape)

# The first num_train rows form the training split; the rest are the test split.
X_train,X_test = X[:num_train,:],X[num_train:,:]
y_train,y_test = y[:num_train],y[num_train:]

In [None]:
import random
# Number of examples yielded per mini-batch by data_iter.
batch_size = 1
def data_iter(num_examples):
    """Yield shuffled mini-batches of (features, labels).

    Row indices 0 .. num_examples-1 are shuffled once, then walked in steps
    of the module-level ``batch_size``. Each step yields the matching rows of
    the module-level ``X`` and ``y``; the last batch may be shorter.
    """
    order = list(range(num_examples))
    random.shuffle(order)
    for start in range(0, num_examples, batch_size):
        # Slicing clamps at the end of the list, so no explicit min() is needed.
        picks = nd.array(order[start:start + batch_size])
        yield X.take(picks), y.take(picks)

In [None]:
def get_params():
    """Create fresh model parameters with gradient buffers attached.

    Returns:
        A ``(W, b)`` tuple: ``W`` has shape ``(num_inputs, 1)`` and is drawn
        from a normal distribution scaled by 0.1; ``b`` is a zero array of
        shape ``(1,)``. ``attach_grad()`` has been called on both.
    """
    weights = 0.1 * nd.random_normal(shape=(num_inputs, 1))
    bias = nd.zeros((1,))
    weights.attach_grad()
    bias.attach_grad()
    return (weights, bias)

In [None]:
def net(X,lambd,W,b):
    """Linear model output with an optional squared-norm term added to it.

    Computes ``X . W + b`` and then adds ``lambd * (sum(W**2) + b**2)`` to
    every output value. Note that the extra term shifts the prediction
    itself; pass ``lambd=0`` to get the plain linear prediction.

    Args:
        X: input rows to multiply with ``W``.
        lambd: scale applied to the squared-norm term.
        W: weight array.
        b: bias array.
    """
    linear = nd.dot(X, W) + b
    squared_norm = (W ** 2).sum() + b ** 2
    return linear + lambd * squared_norm

In [None]:
%matplotlib inline
import matplotlib as mpl
# Set the figure resolution to 120 DPI for every plot drawn below.
mpl.rcParams['figure.dpi'] = 120
import matplotlib.pyplot as plt

def square_loss(yhat,y):
    """Element-wise squared error between predictions and targets.

    ``y`` is reshaped to the shape of ``yhat`` before subtracting, so the
    result always has the shape of ``yhat``.
    """
    target = y.reshape(yhat.shape)
    residual = yhat - target
    return residual ** 2

def SGD(params,lr):
    """Apply one gradient-descent step to every parameter, in place.

    Each parameter is overwritten with ``param - lr * param.grad`` through
    slice assignment, so the caller's arrays are updated rather than rebound.
    """
    for p in params:
        step = lr * p.grad
        p[:] = p - step

def test(params,X,y):
    """Mean squared error of the model on ``(X, y)``.

    The model is evaluated through ``net`` with ``lambd=0``, and the mean of
    the per-element squared errors is returned via ``asscalar()``.
    """
    predictions = net(X, 0, *params)
    per_element = square_loss(predictions, y)
    return per_element.mean().asscalar()

def train(lambd):
    """Train the linear model with SGD and an L2 penalty of strength ``lambd``.

    Runs 10 epochs with learning rate 0.002 over the mini-batches produced by
    ``data_iter(num_train)``. After every epoch the mean squared error on the
    training and test splits is recorded (always evaluated without the
    penalty), and both curves are plotted at the end.

    Args:
        lambd: weight of the penalty ``sum(W**2) + b**2`` added to the loss;
            0 disables regularization.

    Returns:
        The tuple ``('learned W[:10]:', W[:10], 'learned b:', b)``.
    """
    epochs = 10
    learning_rate = 0.002
    params = get_params()
    # SGD updates the arrays in place, so these names stay valid all along.
    W, b = params
    train_loss = []
    test_loss = []
    for _ in range(epochs):
        for data,label in data_iter(num_train):
            with autograd.record():
                # The penalty belongs to the objective, not to the prediction:
                # compute the plain linear output (lambd=0 inside net) and add
                # the L2 term to the loss. Adding it to the output would get it
                # squared together with the residual and would make the trained
                # model differ from the one evaluated in test().
                output = net(data,0,*params)
                penalty = (W**2).sum() + b**2
                loss = square_loss(output,label) + lambd * penalty
            loss.backward()
            SGD(params,learning_rate)
        train_loss.append(test(params,X_train,y_train))
        test_loss.append(test(params,X_test,y_test))
    plt.plot(train_loss)
    plt.plot(test_loss)
    plt.legend(['train','test'])
    plt.show()
    return 'learned W[:10]:',params[0][:10],'learned b:',params[1]

In [None]:
# Baseline run with the penalty switched off (lambd = 0).
train(lambd=0)

In [None]:
# Same training run with penalty strength lambd = 2.
train(lambd=2)