In [68]:
import random
import torch
from d2l import torch as d2l

In [69]:
'''生成数据'''
def synthetic_data(w, b, num_example, noise_std=0.01):
    """Generate a synthetic linear-regression dataset ``y = x @ w + b + noise``.

    Args:
        w: true weight tensor; ``len(w)`` sets the number of features.
        b: true bias, added to every example.
        num_example: number of examples to draw.
        noise_std: standard deviation of the zero-mean Gaussian noise added to
            the targets. The default of 0.01 matches the original hard-coded
            value. Pass 0 for noise-free targets.

    Returns:
        ``(x, y)``:
        - ``x`` has shape ``(num_example, len(w))``, with entries drawn from N(0, 1).
        - ``y`` is reshaped into a single column; for a 1-D ``w`` its shape is
          ``(num_example, 1)``.
    """
    x = torch.normal(0, 1, (num_example, len(w)))
    y = torch.matmul(x, w) + b
    # Only draw noise when requested, so noise_std=0 yields exact targets and
    # does not consume extra random numbers.
    if noise_std:
        y += torch.normal(0.0, noise_std, y.shape)
    return x, y.reshape((-1, 1))
    

In [70]:
# Generate the training data: features ~ N(0, 1), labels = features @ true_w + true_b + noise.
# Seed both RNGs this notebook draws from (torch for data and parameter
# initialization, `random` for minibatch shuffling) so that
# Restart Kernel -> Run All reproduces the same results.
SEED = 42
NUM_EXAMPLES = 1000
random.seed(SEED)
torch.manual_seed(SEED)

true_w = torch.tensor([2.0, 3.0])
true_b = torch.tensor(5.0)
features, labels = synthetic_data(true_w, true_b, NUM_EXAMPLES)
assert features.shape == (NUM_EXAMPLES, len(true_w))
assert labels.shape == (NUM_EXAMPLES, 1)

In [71]:
'''读取数据'''
def data_iter(batch_size, features, labels):
    """Yield shuffled minibatches ``(features_batch, labels_batch)``.

    The row order is shuffled once per call with Python's ``random`` module.
    Consecutive chunks of ``batch_size`` rows are then yielded. The final
    chunk may be smaller when ``len(features)`` is not a multiple of
    ``batch_size``.
    """
    total = len(features)
    order = list(range(total))
    random.shuffle(order)

    for start in range(0, total, batch_size):
        # Slicing past the end is clamped by Python, so the last chunk is
        # simply whatever rows remain.
        chunk = torch.tensor(order[start:start + batch_size])
        yield features[chunk], labels[chunk]


In [78]:
'''初始化参数'''
# Initialize the parameters. The weights start at zero, with one row per input
# feature taken from the data rather than hard-coded. The bias starts with a
# small random value. Zero weights are fine here because a single linear layer
# has no symmetry to break.
w = torch.zeros((features.shape[1], 1), dtype=torch.float32)
b = torch.normal(0, 0.01, (1,))
w.requires_grad_(True)
b.requires_grad_(True)
w, b

(tensor([[0.],
         [0.]], requires_grad=True),
 tensor([0.0068], requires_grad=True))

In [79]:
'''定义模型'''
def linear_regression(x, w, b):
    """Linear model: return ``x @ w + b``.

    ``x`` is multiplied by ``w`` using matrix multiplication, and ``b`` is
    broadcast-added to the result. In this notebook ``w`` is a column vector,
    so the output presumably has one column per example row — confirm for
    other shapes.
    """
    return x @ w + b

In [80]:
'''损失函数'''
def squared_loss(y_hat, y):
    """Element-wise half squared error ``(y_hat - y) ** 2 / 2``.

    ``y`` is first reshaped to ``y_hat``'s shape, so a flat target vector can be
    compared against column-shaped predictions. The result is unreduced and
    has ``y_hat``'s shape; callers choose how to reduce it.
    """
    target = y.reshape(y_hat.shape)
    residual = y_hat - target
    half_squared = residual ** 2 / 2
    return half_squared

In [81]:
'''优化算法-随机梯度下降'''
def sgd(params, lr, batch_size):
    """Apply one minibatch stochastic-gradient-descent step in place.

    For each parameter ``p``, the update is ``p -= lr * p.grad / batch_size``,
    and ``p.grad`` is then reset to zero for the next backward pass. Dividing by
    ``batch_size`` turns a summed-loss gradient into a per-example average.
    The update runs under ``torch.no_grad()`` so it is not recorded by autograd.
    """
    with torch.no_grad():
        for p in params:
            step = lr * p.grad / batch_size
            p.sub_(step)
            p.grad.zero_()

In [82]:
'''训练'''
# Training hyperparameters: SGD step size, number of passes over the data,
# and minibatch size.
lr, num_epochs, batch_size = 0.03, 3, 10

In [83]:
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        y_hat = linear_regression(x, w, b)
        # Per-example losses. Summing before backward() accumulates the batch
        # gradient, which sgd() then divides by the batch size to average it.
        loss = squared_loss(y_hat, y)
        loss.sum().backward()
        # Update the parameters, passing the actual size of this minibatch.
        # The final minibatch is smaller than `batch_size` whenever
        # len(features) % batch_size != 0, and dividing by the nominal size
        # would under-scale that step.
        sgd([w, b], lr, len(x))

    # Report the mean loss over the full training set; no autograd graph is needed.
    with torch.no_grad():
        train_loss = squared_loss(linear_regression(features, w, b), labels)
        print(f'epoch:{epoch+1}, train_loss={float(train_loss.mean()):.6f}')

epoch:1, train_loss=0.045988
epoch:2, train_loss=0.000182
epoch:3, train_loss=0.000049
