In [1]:
import numpy as np
import torch
# Dataset是一个抽象类, DataLoader用来加载数据(比如，自动做Batch_size,完成shuffle操作等)
from torch.utils.data import Dataset, DataLoader  
import pandas as pd

## 1. Prepare Dataset
### diabetes是一个糖尿病数据集，N行9列，前8列为features，最后一个为标签

In [2]:
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 读取csv, csv.gz文件;  以 ',' 作为分割符;  1080/2080显卡只支持32位浮点数的数据类型
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) 
        self.len = xy.shape[0]                        # xy.shape结果为[N,9], 取第一个索引的值,即N。得到数据集中的样本个数
        # torch.from_numpy会根据获取的数值 创建Tensor张量
        self.x_data = torch.from_numpy(xy[:, :-1])    # 取所有行，前8列的数据
        self.y_data = torch.from_numpy(xy[:, [-1]])   # 取所有行，最后一列。 用[-1]表示拿出来一个矩阵，如果不加[]得到的是向量
        
    def __getitem__(self, index):                     # 此函数获取索引为index的数据（对于数据集内存占用过大情况，现用现调）
        return self.x_data[index], self.y_data[index] # 返回一个元组（ ，）

    def __len__(self):                                # 此函数获取数据集中样本总数
        return self.len

In [3]:
dataset = DiabetesDataset('diabetes.csv.gz')
# 调用DataLoader是为了每次调用这一组数据，而不用一次性加载全部数据造成内存占满的情况
train_loader = DataLoader(dataset  = dataset,   
                          batch_size = 32,
                          shuffle = True,        # shuffle对一个数据集是否打乱顺序
                          num_workers = 2)       # 读取数据时，需不需要多线程并行读取，这里设置为2表示用2个进程读取数据

## 2. Design model using Class

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

In [5]:
model = Model()

## 3. Construct loss and optimizer

In [6]:
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)



## 4. Training cycle 
### (1) 1次epoch表示对所有样本训练1次， Batch_Size表示1次训练的样本数量， Iteration表示batch一共分了多少个，即内层的迭代一共进行了多少次，例如，对于一个10,000的样本，Batch_Size = 1000，则Iteraion = 10
### (2) 这里把一个循环改为嵌套for循环以便使用Mini-Batch进行训练， Mini-Batch使用了GPU的并行运算提高运算效率

In [7]:
# DataFrame的单元格可以存放数值、字符串等，同时DataFrame可以设置列名columns与行名index。
# columns可以不写在pd.Series()的index=[]处写
log = pd.DataFrame(index=[], columns=['epoch', 'lr', 'loss'])

if __name__ == '__main__':    # Windows 和 Linux 多线程的接口不同，在使用时可能报RuntimeError，改为此可以通过(0.4版本问题)
    for epoch in range(100):
        # enumerate进行枚举, 参数0表示从0开始进行枚举(默认参数就是0)，记到i中， data每次拿到的就是第i组的数据(这里就是32×8矩阵)
        # train_loader会自动把矩阵数据转换乘Tensor
        for i, data in enumerate(train_loader, 0):  # train_loader拿到__getitem__()返回的(x[i], y[i])元组放到data里，
            # 1. Prepare data
            inputs, labels = data                   # 这里inputs, labels都是Tensor张量了
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Updata
            optimizer.step()
        
        # 保存模型参数
        # Series 是一维数组, 基于Numpy的ndarray结构
        # pd.Series([list],index=[list]): 参数为list ,index为可选参数，若不填写则默认为index从0开始
        temp = pd.Series([
            epoch, 
            0.01, 
            loss.item()
        ], index=['epoch', 'lr', 'loss'])
        
        log = log.append(temp, ignore_index=True)     # 只有当ignore_index=True或Series有名称时,才能追加Series
        log.to_csv('log/log.csv', index=False)            # 将log统计到csv文件中, 不设置行号

0 0 0.6781840324401855
0 1 0.6780489683151245
0 2 0.6745361685752869
0 3 0.670197606086731
0 4 0.6914077401161194
0 5 0.6696715354919434
0 6 0.6728875637054443
0 7 0.6716985106468201
0 8 0.6872734427452087
0 9 0.6750584840774536
0 10 0.679185688495636
0 11 0.6627431511878967
0 12 0.6743265986442566
0 13 0.6743621230125427
0 14 0.6736940741539001
0 15 0.6688845753669739
0 16 0.7098183631896973
0 17 0.6685884594917297
0 18 0.6684346199035645
0 19 0.6637052893638611
0 20 0.6618891954421997
0 21 0.6913508772850037
0 22 0.6571475267410278
0 23 0.6858339309692383
1 0 0.6761105060577393
1 1 0.6658093929290771
1 2 0.6758351922035217
1 3 0.6712312698364258
1 4 0.6755590438842773
1 5 0.6645718216896057
1 6 0.6584717631340027
1 7 0.6748349070549011
1 8 0.6974685788154602
1 9 0.6578949689865112
1 10 0.679921567440033
1 11 0.6975662708282471
1 12 0.6341253519058228
1 13 0.6504709720611572
1 14 0.6739828586578369
1 15 0.6555998921394348
1 16 0.6361373066902161
1 17 0.6538922190666199
1 18 0.68572390

15 0 0.7125488519668579
15 1 0.7286705374717712
15 2 0.6616706252098083
15 3 0.6949934959411621
15 4 0.6285768151283264
15 5 0.6285820007324219
15 6 0.6785580515861511
15 7 0.645490288734436
15 8 0.6789294481277466
15 9 0.5616820454597473
15 10 0.5942302942276001
15 11 0.6451250314712524
15 12 0.7123931050300598
15 13 0.5599685907363892
15 14 0.6106811761856079
15 15 0.6621284484863281
15 16 0.6274304389953613
15 17 0.6109097599983215
15 18 0.6272336840629578
15 19 0.6964680552482605
15 20 0.6449592113494873
15 21 0.6442679166793823
15 22 0.627626895904541
15 23 0.6223702430725098
16 0 0.6446802616119385
16 1 0.5921852588653564
16 2 0.6794084310531616
16 3 0.6268105506896973
16 4 0.71389240026474
16 5 0.6615152955055237
16 6 0.6615817546844482
16 7 0.6279290318489075
16 8 0.6795368194580078
16 9 0.5929120182991028
16 10 0.6449165940284729
16 11 0.6447303891181946
16 12 0.6963213682174683
16 13 0.6443575620651245
16 14 0.5923216342926025
16 15 0.6794573068618774
16 16 0.6448381543159485

29 0 0.6624752879142761
29 1 0.720671534538269
29 2 0.6816242933273315
29 3 0.6626040935516357
29 4 0.6629800200462341
29 5 0.6625314354896545
29 6 0.5868377089500427
29 7 0.6817840933799744
29 8 0.6625424027442932
29 9 0.6258065104484558
29 10 0.6815754175186157
29 11 0.6252983212471008
29 12 0.7381395697593689
29 13 0.606868326663971
29 14 0.5683160424232483
29 15 0.5490044951438904
29 16 0.7003545761108398
29 17 0.6629757285118103
29 18 0.624771237373352
29 19 0.7194443345069885
29 20 0.568182110786438
29 21 0.605324923992157
29 22 0.5861515402793884
29 23 0.6458543539047241
30 0 0.6058372259140015
30 1 0.7202555537223816
30 2 0.6056992411613464
30 3 0.6437218189239502
30 4 0.6244834065437317
30 5 0.6434194445610046
30 6 0.6059294939041138
30 7 0.6429349184036255
30 8 0.6439540386199951
30 9 0.6441383361816406
30 10 0.624085009098053
30 11 0.643824577331543
30 12 0.5664442181587219
30 13 0.6242425441741943
30 14 0.741248369216919
30 15 0.6243743300437927
30 16 0.740620493888855
30 1

43 0 0.6635510921478271
43 1 0.6824476718902588
43 2 0.5842341184616089
43 3 0.6238490343093872
43 4 0.6041738390922546
43 5 0.6831538677215576
43 6 0.6437222957611084
43 7 0.603900671005249
43 8 0.6629625558853149
43 9 0.6829882264137268
43 10 0.7223138213157654
43 11 0.7030526995658875
43 12 0.6830475330352783
43 13 0.5853527188301086
43 14 0.6824307441711426
43 15 0.6246436834335327
43 16 0.6826637387275696
43 17 0.5852360725402832
43 18 0.6636213660240173
43 19 0.6435675621032715
43 20 0.6436395049095154
43 21 0.6823239326477051
43 22 0.5461636781692505
43 23 0.5912072658538818
44 0 0.6238341927528381
44 1 0.722094714641571
44 2 0.6242032647132874
44 3 0.6828252673149109
44 4 0.5650762915611267
44 5 0.6237279176712036
44 6 0.5641751289367676
44 7 0.6833631992340088
44 8 0.6830126643180847
44 9 0.5848872065544128
44 10 0.643239438533783
44 11 0.6434194445610046
44 12 0.6042174100875854
44 13 0.624190628528595
44 14 0.7629789113998413
44 15 0.6044543981552124
44 16 0.584425687789917


57 0 0.5842190980911255
57 1 0.6631737947463989
57 2 0.6040185689926147
57 3 0.6833722591400146
57 4 0.5649104714393616
57 5 0.6825505495071411
57 6 0.6635209918022156
57 7 0.6044628024101257
57 8 0.7431087493896484
57 9 0.683786928653717
57 10 0.6434650421142578
57 11 0.7025197744369507
57 12 0.6634595394134521
57 13 0.6234794855117798
57 14 0.6249070167541504
57 15 0.6826573610305786
57 16 0.6430912017822266
57 17 0.7416062951087952
57 18 0.5653862953186035
57 19 0.6436870098114014
57 20 0.6430850625038147
57 21 0.6434029340744019
57 22 0.5451500415802002
57 23 0.6461490392684937
58 0 0.6633393168449402
58 1 0.7029664516448975
58 2 0.6439440846443176
58 3 0.6623761653900146
58 4 0.54549241065979
58 5 0.7231497764587402
58 6 0.5255031585693359
58 7 0.7224588394165039
58 8 0.6238040328025818
58 9 0.6232669949531555
58 10 0.6628192663192749
58 11 0.6236705780029297
58 12 0.7023630142211914
58 13 0.7412877678871155
58 14 0.5851659774780273
58 15 0.604647159576416
58 16 0.6248214840888977

71 0 0.6430106163024902
71 1 0.6633674502372742
71 2 0.6035350561141968
71 3 0.6636115908622742
71 4 0.6238920092582703
71 5 0.6237260103225708
71 6 0.6232953667640686
71 7 0.5833513140678406
71 8 0.7833020687103271
71 9 0.7028246521949768
71 10 0.6237654685974121
71 11 0.6438626050949097
71 12 0.6631850004196167
71 13 0.5642468929290771
71 14 0.6436769962310791
71 15 0.6634532809257507
71 16 0.6637519598007202
71 17 0.6639368534088135
71 18 0.6631354093551636
71 19 0.6833487153053284
71 20 0.5842770338058472
71 21 0.5839699506759644
71 22 0.6832739114761353
71 23 0.6462216377258301
72 0 0.6843181848526001
72 1 0.5446128845214844
72 2 0.6631426215171814
72 3 0.5640281438827515
72 4 0.683447539806366
72 5 0.6432517170906067
72 6 0.7040427923202515
72 7 0.6625968217849731
72 8 0.6635515689849854
72 9 0.6244546175003052
72 10 0.6035356521606445
72 11 0.6830818057060242
72 12 0.6437922120094299
72 13 0.6440112590789795
72 14 0.6227874755859375
72 15 0.6632344722747803
72 16 0.6829895377159

85 0 0.7625319361686707
85 1 0.6243950128555298
85 2 0.623111367225647
85 3 0.6242055892944336
85 4 0.7417809963226318
85 5 0.6035911440849304
85 6 0.682896077632904
85 7 0.603614091873169
85 8 0.6233822703361511
85 9 0.6434794664382935
85 10 0.6237953305244446
85 11 0.6032695770263672
85 12 0.6037017107009888
85 13 0.7031782269477844
85 14 0.6836403608322144
85 15 0.6239422559738159
85 16 0.6832551956176758
85 17 0.7027680277824402
85 18 0.5648014545440674
85 19 0.6439188718795776
85 20 0.524949848651886
85 21 0.6244028806686401
85 22 0.7230184078216553
85 23 0.6460736989974976
86 0 0.6834201812744141
86 1 0.5837081074714661
86 2 0.6631879806518555
86 3 0.6440489888191223
86 4 0.6235712766647339
86 5 0.7031030654907227
86 6 0.7618306875228882
86 7 0.6621120572090149
86 8 0.7014575600624084
86 9 0.5850570201873779
86 10 0.5650944709777832
86 11 0.7229035496711731
86 12 0.6433461904525757
86 13 0.6244759559631348
86 14 0.6043199896812439
86 15 0.6040892004966736
86 16 0.6038960814476013

99 0 0.7024433016777039
99 1 0.5849878787994385
99 2 0.6239038705825806
99 3 0.7234305143356323
99 4 0.7219876646995544
99 5 0.6236425042152405
99 6 0.6633642911911011
99 7 0.584496259689331
99 8 0.682931661605835
99 9 0.6822859048843384
99 10 0.6630341410636902
99 11 0.5450852513313293
99 12 0.6238795518875122
99 13 0.6834725141525269
99 14 0.6238481998443604
99 15 0.6041813492774963
99 16 0.6433798670768738
99 17 0.6823542714118958
99 18 0.5841881036758423
99 19 0.6425437927246094
99 20 0.6633365750312805
99 21 0.70279860496521
99 22 0.6237364411354065
99 23 0.5917839407920837


In [8]:
print("w = ", model.linear1.weight.data.numpy())
print("w = ", model.linear2.weight.data.numpy())
print("w = ", model.linear3.weight.data)

w =  [[-0.3243042   0.15349141  0.17779218 -0.30270377  0.1634415   0.20683688
   0.00837262  0.32740536]
 [ 0.17377017 -0.25876588 -0.20830365 -0.05947145  0.01778396 -0.0295886
   0.3337342   0.13339478]
 [-0.16289759  0.07391965  0.02356913  0.04065726 -0.17856693 -0.3166271
   0.33593374  0.30980033]
 [ 0.32775304 -0.3333064  -0.13822295 -0.11354297  0.33300266 -0.09909375
  -0.12502661 -0.02732395]
 [ 0.10412861  0.04759761 -0.058348    0.3100948   0.12203769 -0.18845658
   0.2425685  -0.03624997]
 [-0.00563358  0.25217718 -0.03842327  0.18896016 -0.01983326  0.2128165
   0.07974286 -0.07076398]]
w =  [[ 0.38850346 -0.14054058  0.36292648 -0.23136075 -0.22474015 -0.3347486 ]
 [ 0.07521944 -0.18067504  0.30370682 -0.2414233   0.13490742 -0.16061905]
 [ 0.22961177  0.40265778  0.00484482  0.04473725 -0.34106147  0.3373315 ]
 [ 0.2478347  -0.28179342 -0.35227597  0.3888513  -0.36498448  0.0692066 ]]
w =  tensor([[0.2838, 0.1832, 0.4440, 0.4114]])
