In [22]:
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset 

# 生成模拟数据

In [19]:
def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, size=(num_examples, len(w)))
    Y = torch.matmul(X, w) + b
    Y += torch.normal(0, 0.01, size = Y.shape)
    return X, Y.reshape(-1, 1)

In [20]:
true_w = torch.tensor([3.2, 4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

# 获取数据迭代器

In [27]:
def get_data_iter(data, batch_size, is_train=True):
    dataset = TensorDataset(*data)
    data_iter = DataLoader(dataset, batch_size, shuffle = is_train)
    return data_iter

In [29]:
data_iter = get_data_iter([features, labels], 10)

In [31]:
next(iter(data_iter))

[tensor([[-0.4386, -1.1963],
         [ 0.6923, -0.1040],
         [ 1.5843, -0.3188],
         [-0.3641, -0.2749],
         [ 1.2557, -1.6779],
         [-0.1862, -0.3417],
         [ 0.6682, -1.3999],
         [ 1.1975,  0.2846],
         [-1.3019,  2.2075],
         [-0.0748,  1.9454]]),
 tensor([[-1.9917],
         [ 6.0024],
         [ 8.0046],
         [ 1.9243],
         [ 1.5199],
         [ 2.2368],
         [ 0.7314],
         [ 9.1652],
         [ 8.8734],
         [11.7320]])]

# 模型，优化器，损失函数

In [35]:
net = nn.Sequential(nn.Linear(2, 1))
# 初始化canshu
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0.0)

optimizer = torch.optim.SGD(net.parameters(), lr = 0.03)
loss_func = nn.MSELoss()

# 训练

In [39]:
epoch_nums = 10

for epoch in range(epoch_nums):
    for X, y in data_iter:
        loss = loss_func(net(X), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss = loss_func(net(X), y)
    print(f"epoch: {epoch}, loss: {loss}")


epoch: 0, loss: 4.796232678927481e-05
epoch: 1, loss: 7.420527253998443e-05
epoch: 2, loss: 7.121392991393805e-05
epoch: 3, loss: 9.217062324751168e-05
epoch: 4, loss: 6.221771764103323e-05
epoch: 5, loss: 5.615618283627555e-05
epoch: 6, loss: 0.00015340746904257685
epoch: 7, loss: 7.011611160123721e-05
epoch: 8, loss: 3.2252231903839856e-05
epoch: 9, loss: 8.344273373950273e-05


In [40]:
net[0].weight.data, net[0].bias.data

(tensor([[3.1995, 3.9998]]), tensor([4.2002]))