In [104]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [105]:
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:,:-1])      #取所有行(所有样本)，每行取第一列到倒数第二列
        self.y_data = torch.from_numpy(xy[:,[-1]])     #取所有行，每行取最后一列，注意要用方括号括起，表明这是一个矩阵

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

In [106]:
dataset = DiabetesDataset(r'datasets/diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)   #此处多线程改为>0就总是报错，不知如何解决...

In [107]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)   #将8维输入降维为6维
        self.linear2 = torch.nn.Linear(6, 4)   #...
        self.linear3 = torch.nn.Linear(4, 1)   #...为1维
        self.sigmoid = torch.nn.Sigmoid()      #nn.Sigmoid可以构建计算图，functional.sigmoid不行

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

In [108]:
model = Model()

criterion = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [111]:
if __name__ == '__main__':
    for epoch in range(1000):
        for i, (inputs, labels) in enumerate(train_loader):   #迭代时，DataLoader对象会调用Dataset对象的__getitem__方法来获取数据样本
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

0 0 14.284976959228516
0 1 18.857349395751953
0 2 23.049983978271484
0 3 10.601666450500488
0 4 16.85533905029297
0 5 13.773054122924805
0 6 20.34584617614746
0 7 9.604386329650879
0 8 11.267093658447266
0 9 10.964487075805664
0 10 19.536928176879883
0 11 11.202899932861328
0 12 18.715194702148438
0 13 12.294262886047363
0 14 17.229232788085938
0 15 15.267297744750977
0 16 16.850467681884766
0 17 15.476259231567383
0 18 11.414511680603027
0 19 17.23111343383789
0 20 17.956501007080078
0 21 13.165566444396973
0 22 13.99106216430664
0 23 7.063802719116211
1 0 12.28667163848877
1 1 13.304987907409668
1 2 22.284278869628906
1 3 15.386178016662598
1 4 14.792638778686523
1 5 22.333724975585938
1 6 13.547369003295898
1 7 12.346099853515625
1 8 14.153853416442871
1 9 18.55133819580078
1 10 12.116819381713867
1 11 16.904003143310547
1 12 18.091947555541992
1 13 12.761507034301758
1 14 11.01201057434082
1 15 8.924812316894531
1 16 12.301843643188477
1 17 14.261887550354004
1 18 17.43215560913086