# 3.3 线性回归的简洁实现

In [12]:
import torch
from torch import nn
import numpy as np
torch.manual_seed(1)

<torch._C.Generator at 0x251cac408d0>

## 3.3.1 生成数据集

In [7]:
num_inputs = 2
num_examples = 1000
true_w = [2,-3.4]
true_b = 4.2
features = torch.tensor(np.random.normal(0,1,(num_examples,num_inputs)),dtype=torch.float)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

## 3.3.2 读取数据

In [8]:
import torch.utils.data as Data

batch_size = 10
# 将训练数据的特征和标签组合
dataset = Data.TensorDataset(features,labels)
# 随机读取小批量
data_iter = Data.DataLoader(dataset,batch_size,shuffle=True)

for x,y in data_iter:
    print(x,y)
    break

tensor([[-0.0855,  0.4226],
        [-0.2249, -1.0162],
        [ 0.7831,  1.9082],
        [-1.6112, -0.8587],
        [ 1.5580,  0.8329],
        [ 1.8370, -0.1645],
        [-0.7036,  1.0886],
        [ 1.1491, -1.8283],
        [ 0.6859, -0.7697],
        [ 0.0377, -1.5190]]) tensor([ 2.5740,  7.1853, -0.7198,  3.8905,  4.4847,  8.4323, -0.9159, 12.7167,
         8.1773,  9.4378])


In [25]:

class LinearNet(nn.Module):
    def __init__(self,n_feature):
        super(LinearNet,self).__init__()
        self.linear = nn.Linear(n_feature,1)
        
    def forward(self,x):
        return self.linear(x)
    
net = LinearNet(num_inputs)
print(net)

for param in net.parameters():
    print(param)

LinearNet(
  (linear): Linear(in_features=2, out_features=1, bias=True)
)
Parameter containing:
tensor([[0.1975, 0.6707]], requires_grad=True)
Parameter containing:
tensor([0.4667], requires_grad=True)


In [34]:
from torch.nn import init

init.normal_(net.linear.weight,mean=0,std=0.01)
init.constant_(net.linear.bias,val=0)
loss = nn.MSELoss()

import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr=0.03)
print(optimizer)

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.03
    momentum: 0
    nesterov: False
    weight_decay: 0
)


In [41]:
num_epochs = 10

for epoch in range(1,num_epochs+1):
    for X,y in data_iter:
        print(y.view(-1,1))
        output = net(X)
        # 在函数的参数中经常可以看到-1例如x.view(-1, 4)
        # 这里-1表示一个不确定的数，就是你如果不确定你想要reshape成几行，但是你很肯定要reshape成4列
        l = loss(output,y.view(-1,1))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch %d, loss %f' % (epoch , l.item()))

dense = net.linear
print(true_w, dense.weight)
print(true_b, dense.bias)

tensor([[10.8599],
        [ 9.3996],
        [ 6.6549],
        [ 4.0018],
        [-1.4253],
        [ 7.6049],
        [ 5.0116],
        [ 4.1740],
        [ 3.5123],
        [ 6.1627]])
tensor([[ 9.3933],
        [ 3.1260],
        [ 6.9165],
        [ 3.2284],
        [ 1.7361],
        [ 3.8428],
        [ 7.8971],
        [ 4.1128],
        [-1.3675],
        [ 1.5934]])
tensor([[ 8.6244],
        [ 7.9831],
        [-0.2153],
        [ 2.3095],
        [-2.2311],
        [ 2.9802],
        [ 0.5836],
        [ 7.6784],
        [-3.0048],
        [-2.1314]])
tensor([[ 1.1265],
        [ 5.5314],
        [ 3.0605],
        [ 4.6107],
        [-2.7677],
        [ 3.4473],
        [11.0422],
        [ 2.8007],
        [ 2.2428],
        [ 5.6296]])
tensor([[-3.3785],
        [10.0079],
        [-2.1659],
        [10.9136],
        [ 0.9638],
        [ 2.6264],
        [ 3.5954],
        [ 4.1650],
        [ 2.4023],
        [ 5.4002]])
tensor([[ 8.6212],
        [ 2.9734],
       

        [ 5.6296]])
tensor([[ 5.5422],
        [ 4.5324],
        [ 1.9974],
        [ 3.0452],
        [ 2.3322],
        [ 1.7552],
        [-0.4229],
        [ 4.1709],
        [10.5595],
        [ 7.0105]])
tensor([[ 0.9384],
        [ 6.5453],
        [ 3.0474],
        [-0.9987],
        [ 5.2607],
        [ 6.9196],
        [12.6437],
        [ 6.5367],
        [ 6.7663],
        [ 0.5447]])
tensor([[ 4.8134],
        [ 5.6426],
        [ 2.1585],
        [ 4.3038],
        [ 3.9910],
        [-0.9159],
        [ 3.5653],
        [ 6.5407],
        [ 3.0714],
        [ 9.2862]])
tensor([[-2.2311],
        [ 8.1773],
        [ 8.7176],
        [ 6.5927],
        [ 6.5269],
        [-2.9032],
        [-2.0259],
        [ 1.6615],
        [ 0.2233],
        [ 3.3848]])
tensor([[ 0.2111],
        [-4.6466],
        [ 3.9544],
        [-0.8449],
        [-0.5594],
        [ 4.1011],
        [ 6.3334],
        [ 7.5557],
        [ 8.5617],
        [-1.4197]])
tensor([[ 0.1193],
      

        [-1.0822]])
tensor([[ 4.3600],
        [-3.1852],
        [ 1.2831],
        [ 5.2951],
        [10.9299],
        [ 2.3312],
        [ 4.8586],
        [ 7.4837],
        [12.6437],
        [ 3.5068]])
tensor([[ 0.8996],
        [-0.7075],
        [ 5.7908],
        [ 9.6884],
        [-0.7655],
        [-0.9874],
        [ 0.4939],
        [ 4.3049],
        [ 4.1691],
        [ 3.0685]])
tensor([[5.3337],
        [2.4023],
        [0.8724],
        [8.9315],
        [9.2359],
        [4.1679],
        [4.9354],
        [1.1676],
        [4.0140],
        [5.1672]])
tensor([[ 0.9638],
        [ 3.5104],
        [ 4.4241],
        [ 4.0861],
        [ 0.2952],
        [ 7.9896],
        [ 3.8905],
        [ 4.2392],
        [ 6.7534],
        [-2.8947]])
tensor([[-2.1890],
        [ 1.2923],
        [-0.1444],
        [ 1.8944],
        [ 6.7663],
        [-4.6881],
        [10.3375],
        [ 3.9333],
        [-5.4502],
        [ 2.6260]])
tensor([[ 8.1773],
        [ 3.2423

tensor([[ 1.5462],
        [ 1.9974],
        [ 3.0411],
        [-0.3795],
        [ 5.0703],
        [ 2.3547],
        [ 3.3403],
        [ 7.6131],
        [ 1.8503],
        [ 7.5385]])
tensor([[ 8.1668],
        [-0.3434],
        [-2.1557],
        [ 2.1585],
        [ 1.4585],
        [ 4.5076],
        [ 1.8944],
        [ 4.7426],
        [ 4.3826],
        [-7.8058]])
tensor([[7.5557],
        [0.3261],
        [5.4043],
        [0.5836],
        [6.0056],
        [0.6217],
        [3.1394],
        [3.0291],
        [9.6312],
        [2.5693]])
tensor([[ 3.4555],
        [-2.2311],
        [-0.4431],
        [ 0.0265],
        [ 2.8007],
        [ 1.6182],
        [ 1.0531],
        [ 4.5956],
        [ 4.1040],
        [ 7.5349]])
tensor([[ 3.4518],
        [ 3.9858],
        [ 4.1740],
        [ 9.2344],
        [-0.4308],
        [-2.0259],
        [10.7933],
        [ 7.2035],
        [ 2.2621],
        [ 9.3996]])
tensor([[ 3.8882],
        [ 1.2967],
        [-0.9352]

        [-0.3576]])
tensor([[ 3.5068],
        [ 3.0872],
        [11.5916],
        [-2.9032],
        [ 4.8606],
        [10.4460],
        [ 0.5994],
        [-5.2434],
        [ 7.0655],
        [ 1.3089]])
tensor([[ 5.9702],
        [ 2.5048],
        [12.5823],
        [ 1.5934],
        [ 7.8247],
        [-1.3262],
        [ 3.5955],
        [ 4.2303],
        [ 6.1640],
        [ 4.1128]])
tensor([[11.9609],
        [ 2.6787],
        [ 6.8936],
        [ 3.0605],
        [ 2.3095],
        [ 2.9802],
        [-0.8168],
        [ 4.0633],
        [ 2.7469],
        [ 0.1941]])
tensor([[ 8.4410],
        [ 3.5890],
        [ 3.3066],
        [ 1.1712],
        [ 0.4468],
        [-2.5469],
        [ 3.8695],
        [ 4.2895],
        [10.6382],
        [ 1.1514]])
tensor([[ 6.9349],
        [ 7.9441],
        [11.1755],
        [ 7.6784],
        [ 3.1260],
        [-3.9997],
        [-4.6432],
        [ 6.9661],
        [ 5.5314],
        [ 6.4765]])
epoch 7, loss 0.000164
te

        [-1.4616]])
tensor([[ 4.2318],
        [ 6.2867],
        [ 7.9648],
        [ 5.3774],
        [12.3077],
        [-4.5814],
        [ 5.7180],
        [-3.8425],
        [-2.2873],
        [-2.1890]])
tensor([[ 4.9159e+00],
        [-1.3363e+00],
        [ 1.1423e+01],
        [ 4.0633e+00],
        [ 6.1640e+00],
        [ 1.0164e+01],
        [ 1.8489e+00],
        [ 8.5872e+00],
        [-6.2732e-04],
        [ 2.2015e+00]])
tensor([[10.3198],
        [10.2794],
        [ 3.1466],
        [10.3375],
        [-0.9159],
        [-4.5981],
        [ 7.9384],
        [-4.6432],
        [ 7.8247],
        [-3.9288]])
tensor([[ 4.2775],
        [ 3.5955],
        [ 0.5658],
        [ 4.0200],
        [ 7.7408],
        [-6.5278],
        [ 2.0577],
        [ 3.1317],
        [ 6.5407],
        [10.2279]])
tensor([[ 3.0685],
        [ 7.6702],
        [ 2.3652],
        [ 3.8251],
        [-0.5571],
        [-2.1659],
        [-4.4502],
        [ 9.6884],
        [ 4.1795],
     