Pytorch复现GoogleNet，对minist数据集进行分类  
- GoogleNet是一个并行结构

In [47]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Prepare Dataset

In [48]:
batch_size=64

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

train_dataset=datasets.MNIST(root='./dataset/mnist',
                            train=True,
                            download=True,
                            transform=transform)
train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size) #训练集的batch打乱，制造随机性

test_dataset=datasets.MNIST(root='./dataset/mnist',
                           train=False,
                           download=True,
                           transform=transform)
test_loader=DataLoader(test_dataset,shuffle=False,batch_size=batch_size)

<font color='red'>画图出现了问题，不知为什么</font>

# Design Model

<img src="https://markdown-yqguo.oss-cn-beijing.aliyuncs.com/markdown-yqguo/image-20220206065653861.png" alt="image-20220206065653861" style="zoom:25%;" />

In [49]:
class InceptionA(nn.Module): #非线性结构
    def __init__(self,in_channels):
        super(InceptionA,self).__init__()
        self.branch_pool=nn.Conv2d(in_channels,24,kernel_size=1)
        
        self.branch1x1=nn.Conv2d(in_channels,16,kernel_size=1)
        
        self.branch5x5_1=nn.Conv2d(in_channels,16,kernel_size=1)
        self.branch5x5_2=nn.Conv2d(16,24,kernel_size=5,padding=2)
        
        self.branch3x3_1=nn.Conv2d(in_channels,16,kernel_size=1)
        self.branch3x3_2=nn.Conv2d(16,24,kernel_size=3,padding=1)
        self.branch3x3_3=nn.Conv2d(24,24,kernel_size=3,padding=1)
        
    def forward(self,x):
        branch_pool=F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)
        branch_pool=self.branch_pool(branch_pool)
        
        branch1x1=self.branch1x1(x)
        
        branch5x5=self.branch5x5_1(x)
        branch5x5=self.branch5x5_2(branch5x5)
        
        branch3x3=self.branch3x3_1(x)
        branch3x3=self.branch3x3_2(branch3x3)
        branch3x3=self.branch3x3_3(branch3x3)
        
        outputs=[branch_pool,branch1x1,branch5x5,branch3x3]
        
        return torch.cat(outputs,dim=1) #batch_size,channel,w,h

In [50]:
class Net(nn.Module): #线性结构
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,10,kernel_size=5)
        self.conv2=nn.Conv2d(88,20,kernel_size=5)
        
        self.incept1=InceptionA(in_channels=10)
        self.incept2=InceptionA(in_channels=20)
        
        self.mp=nn.MaxPool2d(2)
        self.fc=nn.Linear(1408,10)
        
    def forward(self,x):
        in_size=x.size(0) #batch中元素的个数
        
        x=F.relu(self.mp(self.conv1(x)))
        x=self.incept1(x) #没有加激活函数
        x=F.relu(self.mp(self.conv2(x)))
        x=self.incept2(x)
        
        x=x.view(in_size,-1)
        #print(x.size())
        x=self.fc(x) #后面不用加激活函数，因为CrossEntropyLoss自带softmax损失函数
        return x

In [51]:
#计算全连接层的输入应该，输入多少维
net=Net()
data_loader=iter(train_loader)
images,labels=data_loader.next()
net(images) #torch.Size([64, 1408])

tensor([[ 1.4115e-01,  4.1917e-02,  1.2853e-03,  1.1810e-01, -1.2072e-01,
         -1.5185e-02, -1.1001e-02,  8.5034e-02, -7.9355e-02,  4.2557e-02],
        [ 1.2360e-01,  4.0648e-02,  8.6135e-03,  1.2012e-01, -1.0114e-01,
         -8.6694e-03, -2.0209e-02,  8.7965e-02, -7.9593e-02,  7.1254e-02],
        [ 1.4153e-01,  1.6969e-02,  2.5939e-02,  1.0123e-01, -1.3274e-01,
         -2.0148e-02,  1.2811e-02,  8.4670e-02, -8.9970e-02,  5.6658e-02],
        [ 1.3901e-01,  4.7550e-02,  1.8580e-02,  1.2406e-01, -1.2049e-01,
         -1.7718e-02,  2.2021e-02,  9.3100e-02, -8.3868e-02,  5.3675e-02],
        [ 1.3026e-01,  6.6438e-02,  3.0532e-02,  1.0511e-01, -1.1871e-01,
         -1.4563e-02,  1.7048e-02,  8.8765e-02, -8.1581e-02,  3.8888e-02],
        [ 1.4007e-01,  2.5110e-02,  1.4264e-02,  1.0755e-01, -1.3827e-01,
          6.4618e-03, -4.4953e-03,  9.3038e-02, -8.9960e-02,  3.7366e-02],
        [ 1.5370e-01,  4.0888e-02, -5.1293e-03,  1.1385e-01, -1.2856e-01,
         -1.8696e-02, -1.0020e-0

In [59]:
model = Net()

# Construct loss and optimizer

In [60]:
criterion = nn.CrossEntropyLoss() #损失函数
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

# Training cycle

In [61]:
def train(epoch):
    running_loss = 0.0
    for batch_idx,data in enumerate(train_loader,0):
        images, targets = data
        optimizer.zero_grad() #梯度清零
        
        outputs = model(images)
        loss = criterion(outputs,targets)
        loss.backward() #反向传播求梯度
        optimizer.step() #梯度下降法更新参数
        
        running_loss += loss.item()
        if batch_idx%299 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch+1,batch_idx+1,running_loss/300))
            running_loss = 0.0

In [62]:
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, targets = data
            outputs = model(images)
            _, predicted = torch.max(outputs, dim=1) #行是batch中的每个样本，列是每种类别的预测值
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print('Accuracy on test set: %d %%' % (100*correct/total))

In [63]:
a=torch.rand((2,3))
print(a)
val,index=torch.max(a,dim=1)
print(index,val)

tensor([[0.9237, 0.9674, 0.8925],
        [0.1451, 0.8994, 0.5883]])
tensor([1, 1]) tensor([0.9674, 0.8994])


In [64]:
for epoch in range(10):
    train(epoch)
    test()

[1,     1] loss: 0.008
[1,   300] loss: 0.964
[1,   599] loss: 0.212
[1,   898] loss: 0.151
Accuracy on test set: 96 %
[2,     1] loss: 0.000
[2,   300] loss: 0.116
[2,   599] loss: 0.094
[2,   898] loss: 0.097
Accuracy on test set: 97 %
[3,     1] loss: 0.000
[3,   300] loss: 0.080
[3,   599] loss: 0.072
[3,   898] loss: 0.072
Accuracy on test set: 98 %
[4,     1] loss: 0.000
[4,   300] loss: 0.068
[4,   599] loss: 0.066
[4,   898] loss: 0.056
Accuracy on test set: 98 %
[5,     1] loss: 0.000
[5,   300] loss: 0.058
[5,   599] loss: 0.052
[5,   898] loss: 0.057
Accuracy on test set: 98 %
[6,     1] loss: 0.000
[6,   300] loss: 0.050
[6,   599] loss: 0.047
[6,   898] loss: 0.050
Accuracy on test set: 98 %
[7,     1] loss: 0.000
[7,   300] loss: 0.044
[7,   599] loss: 0.043
[7,   898] loss: 0.047
Accuracy on test set: 98 %
[8,     1] loss: 0.000
[8,   300] loss: 0.039
[8,   599] loss: 0.042
[8,   898] loss: 0.040
Accuracy on test set: 98 %
[9,     1] loss: 0.000
[9,   300] loss: 0.038
[9