In [8]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import numpy as np

class DiabetesDataset(Dataset):
    def __init__(self, file_path):
        xy = np.loadtxt(file_path, delimiter=' ', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(9, 6)
        self.linear2 = nn.Linear(6, 4)
        self.linear3 = nn.Linear(4, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

dataset = DiabetesDataset('../7.处理多维特征的输入/diabetes_data_raw.csv.gz')
train_loader = DataLoader(dataset = dataset, batch_size=32, shuffle=True, num_workers=0) # num_workers=0 在 Windows 环境下使用 PyTorch 的 DataLoader 时，可能会遇到 DataLoader worker 退出的问题。这通常是由于 Windows 下的多进程机制与 PyTorch 的 DataLoader 不兼容所致。
model = Model()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


0 0 8621.1640625
0 1 7952.830078125
0 2 8037.923828125
0 3 8756.767578125
0 4 8298.4228515625
0 5 8272.9853515625
0 6 8241.671875
0 7 8067.23388671875
0 8 8349.8291015625
0 9 8302.7333984375
0 10 7972.57763671875
0 11 8480.7021484375
0 12 8049.0146484375
0 13 8669.74609375
1 0 8971.140625
1 1 8205.2021484375
1 2 8247.701171875
1 3 8330.107421875
1 4 8205.7333984375
1 5 8087.88720703125
1 6 7832.32470703125
1 7 8030.1064453125
1 8 8336.107421875
1 9 8512.982421875
1 10 8410.044921875
1 11 7792.32568359375
1 12 8545.044921875
1 13 8423.3193359375
2 0 8094.85595703125
2 1 8353.9189453125
2 2 8098.48095703125
2 3 7954.94970703125
2 4 8267.4189453125
2 5 8741.326171875
2 6 8253.2626953125
2 7 7289.94873046875
2 8 8660.66796875
2 9 8881.66796875
2 10 8684.04296875
2 11 8763.91796875
2 12 8094.10546875
2 13 7645.6259765625
3 0 8827.82421875
3 1 8502.32421875
3 2 8162.82373046875
3 3 8038.3857421875
3 4 7607.10546875
3 5 8851.82421875
3 6 8444.01171875
3 7 7687.51123046875
3 8 8311.0107421875
