In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets,transforms

In [2]:
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    Two 5x5 convolutions (each followed by ReLU and 2x2 max-pooling) feed
    three fully connected layers that produce 10 output scores.

    ``fc1`` expects 16*4*4 input features, which is what the convolutional
    stack yields for a single-channel 28x28 image
    (28 -> 24 -> 12 -> 8 -> 4 per spatial side). Other input sizes will not
    match ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute raw (un-normalised) class scores for a batch of images.

        Args:
            x: tensor whose first dimension is the batch; the remaining
               dimensions must be accepted by ``conv1`` (1 input channel).

        Returns:
            Tensor with one row per sample and 10 columns (output of ``fc3``).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    @staticmethod
    def num_feature(x):
        """Return the number of elements per sample in ``x``.

        This is the product of every dimension except the first (batch) one.
        It is a static method so that it can be called both as
        ``LeNet5.num_feature(x)`` and as ``self.num_feature(x)``.
        """
        num = 1
        for s in x.size()[1:]:  # size without the batch dimension
            num *= s
        return num

# Instantiate the model; its parameters are randomly initialised by PyTorch.
net=LeNet5()

In [3]:
# Show the module summary (layer names and their configuration).
net

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [4]:
# Print only the first parameter tensor returned by net.parameters()
# (the weights before training), then stop iterating.
for para in net.parameters():
    print(para)
    break

Parameter containing:
tensor([[[[-0.1226,  0.1347, -0.0456,  0.0330,  0.0242],
          [-0.1411,  0.0318,  0.1375,  0.0863,  0.0974],
          [ 0.0476,  0.1202,  0.0639, -0.1113, -0.0474],
          [ 0.1993,  0.1792, -0.1009,  0.1733, -0.0313],
          [-0.0106,  0.1715, -0.0557,  0.0424,  0.0604]]],


        [[[ 0.0793,  0.1044, -0.1555,  0.0484, -0.1154],
          [-0.0040, -0.1498,  0.0778,  0.0688, -0.0038],
          [ 0.0222,  0.1041,  0.0944, -0.0759,  0.1170],
          [ 0.0897, -0.1792, -0.0655,  0.1935, -0.0540],
          [-0.0044,  0.1258, -0.1332,  0.1713,  0.0547]]],


        [[[-0.1979, -0.0432, -0.0947, -0.1417,  0.1269],
          [ 0.1446,  0.0166,  0.0835, -0.1418, -0.1900],
          [-0.1536,  0.1150,  0.0581,  0.0904, -0.1142],
          [ 0.0632, -0.0834, -0.1349,  0.1824, -0.1178],
          [ 0.1205,  0.0644,  0.0648,  0.0111, -0.1547]]],


        [[[ 0.1830,  0.1759,  0.1250, -0.0054,  0.1459],
          [ 0.0627,  0.0394,  0.0658, -0.1398,  0.1930

In [5]:
# Download MNIST into ./data (if not already present) and build the
# train / test datasets without any transform.
train_set = datasets.MNIST(root='./data', train=True, download=True)
test_set = datasets.MNIST(root='./data', train=False, download=True)

In [6]:
import numpy as np
def data_tf(x):
    """Convert an image to a normalised float tensor.

    The input is turned into a float32 array, scaled by 1/255, shifted and
    scaled with mean 0.5 / std 0.5 (so 0 maps to -1 and 255 maps to 1), and
    finally passed through ``transforms.ToTensor()``.
    """
    pixels = np.array(x, dtype='float32')
    scaled = pixels / 255
    # Standardise: subtract 0.5, divide by 0.5.
    normalised = (scaled - 0.5) / 0.5
    to_tensor = transforms.ToTensor()
    return to_tensor(normalised)

In [7]:
# Reload both splits, this time applying data_tf to every image.
train_set = datasets.MNIST(
    root='./data', train=True, transform=data_tf, download=True)
test_set = datasets.MNIST(
    root='./data', train=False, transform=data_tf, download=True)
# Sanity check: shape of the first transformed training image.
print(train_set[0][0].shape)

torch.Size([1, 28, 28])


In [8]:
from torch.utils.data import DataLoader
# Mini-batch iterators: the training set is reshuffled every epoch,
# the test set is read in its stored order.
train_data = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_data = DataLoader(dataset=test_set, batch_size=128, shuffle=False)

In [9]:
# Cross-entropy loss on the raw network scores.
criterion = nn.CrossEntropyLoss()
# Plain stochastic gradient descent with learning rate 0.1.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

In [10]:
def get_acc(output, label):
    """Return the fraction of samples whose predicted class equals ``label``.

    ``output`` is indexed as (batch, class score); the predicted class of a
    sample is the index of its largest score along dimension 1.
    """
    batch_size = output.shape[0]
    # max(1) returns (values, indices); only the indices are needed.
    pred_label = output.max(1)[1]
    # Summing the boolean matches gives a 0-dim tensor; .item() makes it a number.
    hits = pred_label.eq(label).sum().item()
    return hits / batch_size

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    """Train ``net`` and print loss / accuracy once per epoch.

    The model is moved to the GPU when CUDA is available, otherwise it stays
    on the CPU. Every batch is moved to the same device before use.

    Args:
        net: the ``nn.Module`` to optimise.
        train_data: sized iterable of ``(images, labels)`` batches used for
            the optimisation steps.
        valid_data: sized iterable of ``(images, labels)`` batches evaluated
            after each epoch, or ``None`` to skip evaluation.
        num_epochs: number of passes over ``train_data``.
        optimizer: optimiser already bound to ``net``'s parameters.
        criterion: callable ``criterion(output, label)`` returning the loss.

    The reported losses and accuracies are means of the per-batch values
    (sum over batches divided by the number of batches).
    """
    # Resolve the device once instead of querying CUDA for every batch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)

    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net.train()
        for im, label in train_data:
            im = im.to(device)
            label = label.to(device)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # accumulate metrics
            train_loss += loss.item()
            train_acc += get_acc(output, label)

        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net.eval()
            # Evaluation needs no gradients. torch.no_grad() replaces the
            # removed ``volatile=True`` flag, which no longer disables
            # graph construction and only triggers a warning.
            with torch.no_grad():
                for im, label in valid_data:
                    im = im.to(device)
                    label = label.to(device)
                    # forward
                    output = net(im)
                    loss = criterion(output, label)
                    # accumulate metrics
                    valid_loss += loss.item()
                    valid_acc += get_acc(output, label)
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, " % (epoch, train_loss / len(train_data),
                   train_acc / len(train_data), valid_loss / len(valid_data),valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %(epoch, train_loss / len(train_data),train_acc / len(train_data)))

        print(epoch_str)

In [11]:
# Train for a single epoch, using the test split as validation data.
train(net,train_data,test_data,1,optimizer,criterion)



Epoch 0. Train Loss: 0.397717, Train Acc: 0.864939, Valid Loss: 0.091704, Valid Acc: 0.970728, 


In [12]:
# Print the first parameter tensor again, now after the training run,
# so it can be compared with the values printed before training.
for para in net.parameters():
    print(para)
    break

Parameter containing:
tensor([[[[-3.2593e-01,  1.9140e-01,  2.0145e-01,  1.7434e-01, -2.1830e-02],
          [-2.3530e-01,  3.0155e-01,  4.6830e-01,  1.3936e-01, -6.0559e-02],
          [ 1.6782e-01,  5.4521e-01,  3.0777e-01, -2.5025e-01, -3.0329e-01],
          [ 4.3497e-01,  5.1621e-01, -6.6410e-02, -5.9770e-02, -3.0931e-01],
          [ 2.5330e-01,  3.8484e-01, -1.1651e-01, -1.9185e-01, -2.1075e-01]]],


        [[[ 4.3469e-02,  7.3512e-02, -2.3307e-01, -1.1905e-01, -3.1038e-01],
          [ 2.9152e-02, -8.7453e-02,  9.1925e-02, -4.1822e-02, -1.6014e-01],
          [ 8.4461e-02,  2.4220e-01,  2.6914e-01, -2.4191e-04,  1.1679e-01],
          [ 5.9223e-02, -1.7235e-01,  3.2607e-02,  3.3406e-01,  1.9349e-02],
          [-1.3334e-01,  2.7348e-02, -1.9602e-01,  1.6883e-01,  7.4229e-02]]],


        [[[-1.6284e-01, -9.1872e-02, -1.8384e-01, -2.1426e-01,  3.0583e-02],
          [ 1.3682e-01, -1.0884e-01, -8.0079e-02, -2.3544e-01, -2.7675e-01],
          [-6.1871e-02,  3.3529e-02, -1.0771e-