<a href="https://colab.research.google.com/github/BarryLiu-97/Pytorch-Tutorial/blob/master/06_Dataset_and_DataLoader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

使用minibatch进行训练

梯度下降可以使用的方法：
- 所有数据全部都用：利用CPU和GPU的并行计算能力，计算速度较高
- 随机梯度下降：只用一个样本，随机性较好，能克服鞍点的问题
- MiniBatch: 随机性较好的同时运算速度也较高

In [5]:
import torch
from torch.utils.data import Dataset      #Dataset是一个抽象类，抽象类是不能实例化的，只能被其他子类继承
from torch.utils.data import DataLoader     #Pytorch帮助我们加载数据的类
import numpy as np

In [3]:
from google.colab import files
files.upload() 

Saving diabetes.csv to diabetes.csv


{'diabetes.csv': b'-0.294118,0.487437,0.180328,-0.292929,0,0.00149028,-0.53117,-0.0333333,0\n-0.882353,-0.145729,0.0819672,-0.414141,0,-0.207153,-0.766866,-0.666667,1\n-0.0588235,0.839196,0.0491803,0,0,-0.305514,-0.492741,-0.633333,0\n-0.882353,-0.105528,0.0819672,-0.535354,-0.777778,-0.162444,-0.923997,0,1\n0,0.376884,-0.344262,-0.292929,-0.602837,0.28465,0.887276,-0.6,0\n-0.411765,0.165829,0.213115,0,0,-0.23696,-0.894962,-0.7,1\n-0.647059,-0.21608,-0.180328,-0.353535,-0.791962,-0.0760059,-0.854825,-0.833333,0\n0.176471,0.155779,0,0,0,0.052161,-0.952178,-0.733333,1\n-0.764706,0.979899,0.147541,-0.0909091,0.283688,-0.0909091,-0.931682,0.0666667,0\n-0.0588235,0.256281,0.57377,0,0,0,-0.868488,0.1,0\n-0.529412,0.105528,0.508197,0,0,0.120715,-0.903501,-0.7,1\n0.176471,0.688442,0.213115,0,0,0.132638,-0.608027,-0.566667,0\n0.176471,0.396985,0.311475,0,0,-0.19225,0.163962,0.2,1\n-0.882353,0.899497,-0.0163934,-0.535354,1,-0.102832,-0.726729,0.266667,0\n-0.176471,0.00502513,0,0,0,-0.105812,-0.6

In [6]:
class DiabetesDataset(Dataset):
  def __init__(self, filepath):
    diabetes = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
    self.len = diabetes.shape[0]     #数据集有多少个样本，就是有几行数据
    self.x_data = torch.from_numpy(diabetes[:,:-1])
    self.y_data = torch.from_numpy(diabetes[:,[-1]])

  def __getitem__(self, index):
    """
    使得实例化DiabetesDataset之后能够进行下标操作。dataset[index]
    """
    return self.x_data[index], self.y_data[index]   #  return x,y 表示返回的是(x, y)这样的元组

  def __len__(self):
    """
    返回数据条数
    """
    return self.len

dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32,
              shuffle=True, num_workers=2)   #num_workers表示读取数据的时候是否要多进程，几个进程进行读取

In [7]:
class Model(torch.nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.linear1 = torch.nn.Linear(8, 6)
    self.linear2 = torch.nn.Linear(6, 4)
    self.linear3 = torch.nn.Linear(4, 1)
    self.sigmoid = torch.nn.Sigmoid()  #在此处可以更改激活函数

  def forward(self, x):
    x = self.sigmoid(self.linear1(x))
    x = self.sigmoid(self.linear2(x))
    x = self.sigmoid(self.linear3(x))  #relu会使得小于0的值输出为0，会有风险，所以最后一层一般不用relu
    return x

model = Model()

criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)



In [8]:
for epoch in range(100):
  for i, data in enumerate(train_loader, 0):
    #1. Prepare data
    inputs, labels = data   #都是张量
    #2. Forward
    y_pred = model(inputs)
    loss = criterion(y_pred, labels)
    print(epoch, i, loss.item())
    #3. Backward
    optimizer.zero_grad()
    #4. Update
    optimizer.step()

0 0 0.6656588315963745
0 1 0.6909439563751221
0 2 0.6709595322608948
0 3 0.6806279420852661
0 4 0.680678129196167
0 5 0.665860652923584
0 6 0.675305962562561
0 7 0.6856374144554138
0 8 0.6756274700164795
0 9 0.671160876750946
0 10 0.6862059831619263
0 11 0.6609906554222107
0 12 0.6908774375915527
0 13 0.6910154819488525
0 14 0.6506491899490356
0 15 0.6508225202560425
0 16 0.6604098677635193
0 17 0.6557705402374268
0 18 0.6559960246086121
0 19 0.6706719398498535
0 20 0.6651352643966675
0 21 0.6757704019546509
0 22 0.6598924994468689
0 23 0.671518862247467
1 0 0.680553674697876
1 1 0.691001296043396
1 2 0.6602373719215393
1 3 0.6808112263679504
1 4 0.6906570196151733
1 5 0.6708197593688965
1 6 0.6760855317115784
1 7 0.6761552691459656
1 8 0.6655197739601135
1 9 0.6607648134231567
1 10 0.6611604690551758
1 11 0.6705379486083984
1 12 0.6710931062698364
1 13 0.7110304236412048
1 14 0.6709049940109253
1 15 0.6706440448760986
1 16 0.6452445983886719
1 17 0.6605185270309448
1 18 0.675787270069