# 自定义Dataset，使用DataLoader

## 准备数据集，定义自己的Dataset类

In [23]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

# Custom map-style dataset (subclass of torch.utils.data.Dataset).
class DiabetesDataset(Dataset):
    """Dataset backed by a comma-separated numeric file.

    The whole file is parsed into memory as float32 when the dataset is
    created. Every column except the last is a feature; the last column is
    the target.

    Attributes:
        x_data: feature tensor, one row per sample (all columns but the last).
        y_data: target tensor kept 2-D, shape (rows, 1).
        len: number of rows in the file.
    """

    def __init__(self, filepath):
        table = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        row_count = table.shape[0]
        # Slice off the last column for the features.
        features = table[:, :-1]
        # Index with a list ([-1]) so the target stays 2-D: (rows, 1).
        targets = table[:, [-1]]
        self.len = row_count
        self.x_data = torch.from_numpy(features)
        self.y_data = torch.from_numpy(targets)

    def __getitem__(self, index):
        """Return the (features, target) pair stored at ``index``."""
        sample_x = self.x_data[index]
        sample_y = self.y_data[index]
        return sample_x, sample_y

    def __len__(self):
        """Return the number of rows read from the file."""
        return self.len
    
# Parse the whole CSV up front. The DataLoader then reshuffles it every epoch
# and serves mini-batches of 32 rows.
dataset = DiabetesDataset('./dataset/diabetes.csv.gz')
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    # NOTE(review): num_workers > 0 starts worker processes. On platforms that
    # use the "spawn" start method, the Dataset class must be importable by the
    # workers; confirm this works in the target environment.
    num_workers=2,
)

## 设计模型

In [26]:
class Model(torch.nn.Module):
    """Fully connected binary classifier with layer widths 8 -> 6 -> 4 -> 1.

    A sigmoid follows every linear layer, including the last one. The output
    is therefore squashed into [0, 1] and can be fed directly to BCELoss.
    """

    def __init__(self):
        super().__init__()
        # Keep registration order (linear1, linear2, linear3, sigmoid) so that
        # parameter initialisation order and state_dict keys are unchanged.
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map inputs whose last dimension is 8 to one sigmoid output per row."""
        for layer in (self.linear1, self.linear2, self.linear3):
            x = self.sigmoid(layer(x))
        return x


model = Model()

## 构造损失函数和优化器

In [41]:
# Binary cross-entropy averaged over the batch. The model's last layer is a
# sigmoid, so BCELoss (which expects probabilities, not logits) is the
# matching criterion.
# Fix: the original call was missing its closing parenthesis (SyntaxError).
criterion = torch.nn.BCELoss(reduction='mean')
# Plain stochastic gradient descent over all model parameters, fixed lr.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

## 训练模型

In [42]:
# Mini-batch training: 100 passes over the shuffled data. The loss of every
# batch is printed as "<epoch> <batch index> <loss>".
for epoch in range(100):
    for i, data in enumerate(train_loader):
        # Drop gradients accumulated by the previous step's backward().
        optimizer.zero_grad()
        # 1. Unpack the (features, labels) batch yielded by the DataLoader.
        inputs, label = data
        # 2. Forward pass and loss.
        y_pred = model(inputs)
        loss = criterion(y_pred, label)
        print(epoch, i, loss.item())
        # 3. Backward pass, then 4. parameter update.
        loss.backward()
        optimizer.step()

0 0 0.6235154867172241
0 1 0.64347243309021
0 2 0.5627185702323914
0 3 0.5802675485610962
0 4 0.6434760689735413
0 5 0.6867217421531677
0 6 0.6856158971786499
0 7 0.5817035436630249
0 8 0.6646777391433716
0 9 0.5805817246437073
0 10 0.7732195258140564
0 11 0.7038154006004333
0 12 0.6630167961120605
0 13 0.64347904920578
0 14 0.643456220626831
0 15 0.6241729259490967
0 16 0.6825361847877502
0 17 0.623955488204956
0 18 0.565483808517456
0 19 0.6229248046875
0 20 0.6641731262207031
0 21 0.6842526197433472
0 22 0.7233335375785828
0 23 0.6194761395454407
1 0 0.6434307098388672
1 1 0.5655342936515808
1 2 0.643239438533783
1 3 0.6231148838996887
1 4 0.7052810788154602
1 5 0.663301408290863
1 6 0.6434690952301025
1 7 0.6632025241851807
1 8 0.6042050123214722
1 9 0.5432953834533691
1 10 0.7068833112716675
1 11 0.6434580087661743
1 12 0.7251449227333069
1 13 0.6434575319290161
1 14 0.7413251399993896
1 15 0.551510214805603
1 16 0.6236625909805298
1 17 0.6432721614837646
1 18 0.6035352945327759
1

14 17 0.6433537006378174
14 18 0.6434304118156433
14 19 0.6049387454986572
14 20 0.7023637294769287
14 21 0.5480846166610718
14 22 0.6634326577186584
14 23 0.6457653641700745
15 0 0.643176794052124
15 1 0.7234731912612915
15 2 0.5288226008415222
15 3 0.6639221906661987
15 4 0.5621719360351562
15 5 0.6645979285240173
15 6 0.643358588218689
15 7 0.6849940419197083
15 8 0.7655694484710693
15 9 0.68145352602005
15 10 0.6623167991638184
15 11 0.6435600519180298
15 12 0.6434693336486816
15 13 0.6434378027915955
15 14 0.6250035166740417
15 15 0.643351137638092
15 16 0.6432420611381531
15 17 0.6434085965156555
15 18 0.6627602577209473
15 19 0.6243333220481873
15 20 0.5853747725486755
15 21 0.6432747840881348
15 22 0.643002450466156
15 23 0.6179060935974121
16 0 0.6025017499923706
16 1 0.6227580904960632
16 2 0.6225063800811768
16 3 0.5798919200897217
16 4 0.7090369462966919
16 5 0.6223006844520569
16 6 0.6436066031455994
16 7 0.622096598148346
16 8 0.6220757961273193
16 9 0.7078016996383667
16

28 22 0.7265534996986389
28 23 0.6733142733573914
29 0 0.6236379742622375
29 1 0.7025090456008911
29 2 0.7006568908691406
29 3 0.6806433200836182
29 4 0.6253694295883179
29 5 0.5692721605300903
29 6 0.6625151634216309
29 7 0.6629401445388794
29 8 0.6244004368782043
29 9 0.5655781030654907
29 10 0.6430279612541199
29 11 0.582171618938446
29 12 0.7064627408981323
29 13 0.7243305444717407
29 14 0.5851283669471741
29 15 0.6831375956535339
29 16 0.7613371014595032
29 17 0.6253602504730225
29 18 0.5682244300842285
29 19 0.6630377769470215
29 20 0.604321300983429
29 21 0.6232053637504578
29 22 0.6634885668754578
29 23 0.6457359790802002
30 0 0.6435500979423523
30 1 0.6430378556251526
30 2 0.5631155967712402
30 3 0.6224053502082825
30 4 0.7064751386642456
30 5 0.7042415738105774
30 6 0.6627887487411499
30 7 0.6625425219535828
30 8 0.5853827595710754
30 9 0.6234021186828613
30 10 0.805057168006897
30 11 0.6249943971633911
30 12 0.6621757745742798
30 13 0.6060588359832764
30 14 0.662504673004150

43 0 0.6629101037979126
43 1 0.6629611849784851
43 2 0.5643129944801331
43 3 0.6845817565917969
43 4 0.6028856039047241
43 5 0.70452880859375
43 6 0.6234015226364136
43 7 0.6832983493804932
43 8 0.642878532409668
43 9 0.6819958686828613
43 10 0.6818708777427673
43 11 0.6053129434585571
43 12 0.5850803852081299
43 13 0.6633722186088562
43 14 0.6031848192214966
43 15 0.7654733657836914
43 16 0.6620317697525024
43 17 0.6618714332580566
43 18 0.6243867874145508
43 19 0.6051952838897705
43 20 0.6038689613342285
43 21 0.6431995630264282
43 22 0.6031986474990845
43 23 0.6739935278892517
44 0 0.6227869987487793
44 1 0.6430032253265381
44 2 0.6835311651229858
44 3 0.6032143831253052
44 4 0.8057636022567749
44 5 0.7546629309654236
44 6 0.7132286429405212
44 7 0.627782940864563
44 8 0.6442496180534363
44 9 0.5750295519828796
44 10 0.6435739994049072
44 11 0.5519839525222778
44 12 0.7209544777870178
44 13 0.5873943567276001
44 14 0.6233857870101929
44 15 0.5839765071868896
44 16 0.6226562261581421

57 0 0.5810818672180176
57 1 0.7280372381210327
57 2 0.6024565696716309
57 3 0.6632265448570251
57 4 0.5814905762672424
57 5 0.7701418995857239
57 6 0.6230987906455994
57 7 0.681959867477417
57 8 0.7214409708976746
57 9 0.699191153049469
57 10 0.5886918902397156
57 11 0.5860459804534912
57 12 0.6428358554840088
57 13 0.6036996245384216
57 14 0.5617863535881042
57 15 0.6005740165710449
57 16 0.5357432961463928
57 17 0.6658092737197876
57 18 0.6431043148040771
57 19 0.6875563859939575
57 20 0.7079542875289917
57 21 0.622094988822937
57 22 0.7274221181869507
57 23 0.6732838153839111
58 0 0.6823045015335083
58 1 0.6814302206039429
58 2 0.6995663642883301
58 3 0.6251471638679504
58 4 0.6060574054718018
58 5 0.6237787008285522
58 6 0.6815783381462097
58 7 0.604759931564331
58 8 0.6425247192382812
58 9 0.5448755025863647
58 10 0.6837636232376099
58 11 0.6627732515335083
58 12 0.7031292915344238
58 13 0.6428459286689758
58 14 0.6817426681518555
58 15 0.6620582938194275
58 16 0.6617568731307983

71 0 0.7487330436706543
71 1 0.6628693342208862
71 2 0.7018512487411499
71 3 0.6618274450302124
71 4 0.7181261777877808
71 5 0.5709524154663086
71 6 0.7000613212585449
71 7 0.605849027633667
71 8 0.6040353178977966
71 9 0.5644950866699219
71 10 0.6427320241928101
71 11 0.6426084041595459
71 12 0.7038625478744507
71 13 0.6813438534736633
71 14 0.6425430774688721
71 15 0.7018960118293762
71 16 0.6049965023994446
71 17 0.642615795135498
71 18 0.6421378254890442
71 19 0.5458915829658508
71 20 0.6632046699523926
71 21 0.6218852996826172
71 22 0.6013547778129578
71 23 0.5864297747612
72 0 0.6641229391098022
72 1 0.6426159143447876
72 2 0.7275133728981018
72 3 0.6831192970275879
72 4 0.5443115830421448
72 5 0.6635568141937256
72 6 0.6012095212936401
72 7 0.7264440059661865
72 8 0.681871771812439
72 9 0.6618186235427856
72 10 0.5837434530258179
72 11 0.6832770705223083
72 12 0.7412658929824829
72 13 0.6614442467689514
72 14 0.6058700084686279
72 15 0.5661816000938416
72 16 0.6025146842002869
7

85 0 0.5638611316680908
85 1 0.6621200442314148
85 2 0.5198816657066345
85 3 0.6212576031684875
85 4 0.7068830728530884
85 5 0.6003246903419495
85 6 0.6212673187255859
85 7 0.6859431266784668
85 8 0.6840047836303711
85 9 0.6211709380149841
85 10 0.7255482077598572
85 11 0.7027981281280518
85 12 0.6045465469360352
85 13 0.6805129051208496
85 14 0.6618925333023071
85 15 0.6232412457466125
85 16 0.7384093999862671
85 17 0.6423535943031311
85 18 0.5508671998977661
85 19 0.6022148728370667
85 20 0.6625404953956604
85 21 0.6819108724594116
85 22 0.6023982763290405
85 23 0.7273938655853271
86 0 0.6606826186180115
86 1 0.6420056819915771
86 2 0.62389075756073
86 3 0.6413408517837524
86 4 0.6234930753707886
86 5 0.6032781600952148
86 6 0.6819703578948975
86 7 0.6033049821853638
86 8 0.682169497013092
86 9 0.6224381923675537
86 10 0.7606539726257324
86 11 0.6606261134147644
86 12 0.6603675484657288
86 13 0.624600350856781
86 14 0.5490775108337402
86 15 0.6427363157272339
86 16 0.6818735003471375

99 0 0.6606794595718384
99 1 0.5809068083763123
99 2 0.5800348520278931
99 3 0.64200758934021
99 4 0.6401567459106445
99 5 0.683415412902832
99 6 0.5153689980506897
99 7 0.797930121421814
99 8 0.62081378698349
99 9 0.7870451211929321
99 10 0.7178026437759399
99 11 0.6409239172935486
99 12 0.6406272649765015
99 13 0.6619828343391418
99 14 0.586504340171814
99 15 0.6605759263038635
99 16 0.6787992119789124
99 17 0.5671321153640747
99 18 0.6602593660354614
99 19 0.6418062448501587
99 20 0.6811391115188599
99 21 0.5839673280715942
99 22 0.603038489818573
99 23 0.615507185459137
