In [28]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [30]:
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]]) # 将一维张量转换为二维列向量

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index] # 防止将全部数据加载 占用内存 只有提供index时才读取对应的行

    def __len__(self):
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) # num_workers: 并行进程数

In [32]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

In [34]:
criterion = torch.nn.BCELoss(reduction='mean')

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Adam优化器优势: 1.自适应学习率(适合处理不同尺度特征的数据 自动调整学习率大小) 2.内置动量机制: 加速收敛并减少震荡 3.对噪声和稀疏梯度的robust 4.早期的收敛速度快
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [36]:
for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        # prepare data
        inputs, outputs = data
        # Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, outputs)
        print(epoch, i, loss.item())
        # Backward
        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm(model.parameters(), max_norm=1.0)
        optimizer.step()

0 0 0.7010657787322998
0 1 0.7197439670562744
0 2 0.7189052700996399
0 3 0.7048826217651367
0 4 0.7439979314804077
0 5 0.720372200012207
0 6 0.7192798852920532
0 7 0.7073383331298828
0 8 0.7214927673339844
0 9 0.7162734270095825
0 10 0.7053549289703369
0 11 0.7119475603103638
0 12 0.7142693400382996
0 13 0.7163422107696533
0 14 0.7180271148681641
0 15 0.6917305588722229
0 16 0.7139895558357239
0 17 0.7172983288764954
0 18 0.7015578746795654
0 19 0.7005729675292969
0 20 0.6939578056335449
0 21 0.6961404085159302
0 22 0.7065565586090088
0 23 0.7147424817085266
1 0 0.7028687596321106
1 1 0.7070249915122986
1 2 0.7093959450721741
1 3 0.6977869272232056
1 4 0.6998604536056519
1 5 0.7012599110603333
1 6 0.7009440064430237
1 7 0.6994372010231018
1 8 0.6960905194282532
1 9 0.6961193084716797
1 10 0.6961562037467957
1 11 0.6951104998588562
1 12 0.6930491924285889
1 13 0.6937616467475891
1 14 0.6938850283622742
1 15 0.6941798329353333
1 16 0.6941250562667847
1 17 0.6931402087211609
1 18 0.693415