In [66]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [67]:
class DiabetesDataset(Dataset):
    """Map-style dataset backed by a comma-separated numeric table.

    Each row of the file is one sample: every column except the last is a
    feature, and the last column is the target.

    Args:
        filepath: Location passed to ``np.loadtxt``; the content is parsed as
            comma-delimited float32 values.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, filepath):
        # ndmin=2 keeps the array two-dimensional even when the file holds a
        # single row. Without it loadtxt squeezes that row to 1-D, so
        # shape[0] would report the column count and the column slicing
        # below would raise IndexError.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)

        # Number of samples (rows).
        self.len = xy.shape[0]
        # Feature matrix: all columns but the last.
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Targets: indexing with the list [-1] keeps a trailing axis, so the
        # result is a column of shape (rows, 1) rather than a flat vector.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples read from the file."""
        return self.len

In [68]:
# Read the whole table into memory once; the loader below serves it in batches.
dataset = DiabetesDataset(filepath='D:/BaiduNetdiskDownload/diabetes.csv.gz')

# Batches of 32 samples, reshuffled every epoch, loaded in the main process
# (num_workers=0 means no worker subprocesses).
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

In [69]:
class Model(torch.nn.Module):
    """Three-layer fully connected network with sigmoid activations.

    Layer widths are 8 -> 6 -> 4 -> 1, and a sigmoid follows every linear
    layer, so each output value lies in the open interval (0, 1).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Layers are created in the same order as they are applied.
        self.linear1 = nn.Linear(8, 6)
        self.linear2 = nn.Linear(6, 4)
        self.linear3 = nn.Linear(4, 1)
        # One stateless activation module shared by all three layers.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run ``x`` (last dimension of size 8) through the three layers.

        Returns a tensor whose last dimension has size 1.
        """
        for layer in (self.linear1, self.linear2, self.linear3):
            x = self.sigmoid(layer(x))
        return x

In [70]:
# Instantiate the network; re-running this cell creates a new Model with
# freshly initialised parameters.
model = Model()

In [71]:
# Binary cross-entropy, averaged over the elements of each batch.
criterion = torch.nn.BCELoss(reduction='mean')

# Plain stochastic gradient descent over every parameter of `model`.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.01,
)

In [73]:
# Mini-batch training: 100 passes over train_loader, one optimizer step per batch.
for epoch in range(100):
    for i, data in enumerate(train_loader):
        # Named `inputs`, not `input`: at notebook top level, binding `input`
        # would replace the built-in input() for every later cell in the kernel.
        inputs, label = data

        # Forward pass and loss for this batch.
        y_pred = model(inputs)
        loss = criterion(y_pred, label)
        print(epoch, i, loss.item())

        # Clear gradients from the previous step, backpropagate, then update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

0 0 0.5435408353805542
0 1 0.6223184466362
0 2 0.6826750040054321
0 3 0.6819577813148499
0 4 0.6029261350631714
0 5 0.681847095489502
0 6 0.6220875382423401
0 7 0.6218135952949524
0 8 0.5630695223808289
0 9 0.6826927065849304
0 10 0.6419224739074707
0 11 0.6826149821281433
0 12 0.7222346663475037
0 13 0.6027178764343262
0 14 0.6420824527740479
0 15 0.6628562808036804
0 16 0.642457127571106
0 17 0.7221549153327942
0 18 0.6227484345436096
0 19 0.6623396873474121
0 20 0.7613834142684937
0 21 0.5634104013442993
0 22 0.6427303552627563
0 23 0.5625922679901123
1 0 0.5628756880760193
1 1 0.6627597808837891
1 2 0.6423060894012451
1 3 0.6224116683006287
1 4 0.7217874526977539
1 5 0.5835828185081482
1 6 0.6622223854064941
1 7 0.6622980833053589
1 8 0.5829110145568848
1 9 0.6821032166481018
1 10 0.6223850846290588
1 11 0.7224562764167786
1 12 0.7616926431655884
1 13 0.5832773447036743
1 14 0.6417575478553772
1 15 0.7018678188323975
1 16 0.6821523904800415
1 17 0.6021953225135803
1 18 0.6031361222