In [26]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batchsize=128
# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='./mnist/', train=False, download=True,transform=transforms.ToTensor()),batch_size=batchsize)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='./mnist/', train=False, download=False,transform=transforms.ToTensor()),batch_size=batchsize)


### 按论文中P6的TABLE 1 构建

In [56]:
class MNIST_Model(nn.Module):
    def __init__(self):
        super(MNIST_Model, self).__init__()  
        self.conv1 = nn.Conv2d(1, 32, 3,)
        self.conv2 = nn.Conv2d(32, 32, 3)
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv4 = nn.Conv2d(64, 64, 3)
        self.max_pooling1=torch.nn.MaxPool2d(2,2)
        self.max_pooling2=torch.nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(4*4*64, 200) 
        self.fc2 = nn.Linear(200, 10) # 10分类
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.max_pooling1(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.max_pooling2(x)
        
        x = x.permute((0, 2, 3, 1))
        x = x.contiguous().view(-1, 4 * 4 * 64)
        x = x.view(-1, 4 * 4 * 64)
        
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5)
        x = self.fc2(x)
        x = F.relu(x)
#         x = F.log_softmax(x,dim=1)
        return x

model = MNIST_Model().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,nesterov=True)

In [57]:
criteria = nn.CrossEntropyLoss()
def train(epoch):
    train_loss = 0
    train_correct = 0
    bs=10
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        data, target = Variable(data).cuda(), Variable(target).cuda()
        optimizer.zero_grad()
        output = model(data)
#         loss = F.nll_loss(output, target)
        loss = criteria(output, target.long())
        loss.backward()
        optimizer.step()
        #记录当前训练集的结果
        pred = output.argmax(dim=1, keepdim=True)
        train_correct += pred.eq(target.long().view_as(pred)).sum().item()
        # 显示当前轮的
        if batch_index % bs == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}'.format(
                epoch, batch_index, len(train_loader),
                100.0 * batch_index / len(train_loader), loss.item()/ bs))
    print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        train_loss/len(train_loader.dataset), train_correct, len(train_loader.dataset),
        100.0 * train_correct / len(train_loader.dataset)))
    
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += criteria(output, target.long()).item()
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [58]:
total_epoch = 50
for epoch in range(1, total_epoch+1):
    train(epoch)
    test()


Train set: Average loss: 0.0000, Accuracy: 4432/10000 (44%)


Test set: Average loss: 0.0041, Accuracy: 8573/10000 (86%)


Train set: Average loss: 0.0000, Accuracy: 9213/10000 (92%)


Test set: Average loss: 0.0021, Accuracy: 9230/10000 (92%)


Train set: Average loss: 0.0000, Accuracy: 9504/10000 (95%)


Test set: Average loss: 0.0015, Accuracy: 9464/10000 (95%)


Train set: Average loss: 0.0000, Accuracy: 9646/10000 (96%)


Test set: Average loss: 0.0024, Accuracy: 9144/10000 (91%)


Train set: Average loss: 0.0000, Accuracy: 9682/10000 (97%)


Test set: Average loss: 0.0009, Accuracy: 9639/10000 (96%)


Train set: Average loss: 0.0000, Accuracy: 9766/10000 (98%)


Test set: Average loss: 0.0005, Accuracy: 9780/10000 (98%)


Train set: Average loss: 0.0000, Accuracy: 9805/10000 (98%)


Test set: Average loss: 0.0007, Accuracy: 9766/10000 (98%)


Train set: Average loss: 0.0000, Accuracy: 9785/10000 (98%)


Test set: Average loss: 0.0004, Accuracy: 9861/10000 (99%)


Train set: Aver


Train set: Average loss: 0.0000, Accuracy: 9945/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9926/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9923/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9938/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9924/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9939/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9937/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9929/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9948/10000 (99%)


Test set: Average loss: 0.0003, Accuracy: 9895/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9910/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9939/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9932/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9940/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9946/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9931/10000 (99%)


Train set: Aver


Train set: Average loss: 0.0000, Accuracy: 9961/10000 (100%)


Test set: Average loss: 0.0002, Accuracy: 9928/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9919/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9947/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9959/10000 (100%)


Test set: Average loss: 0.0001, Accuracy: 9946/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9932/10000 (99%)


Test set: Average loss: 0.0004, Accuracy: 9847/10000 (98%)


Train set: Average loss: 0.0000, Accuracy: 9950/10000 (100%)


Test set: Average loss: 0.0001, Accuracy: 9970/10000 (100%)


Train set: Average loss: 0.0000, Accuracy: 9957/10000 (100%)


Test set: Average loss: 0.0002, Accuracy: 9944/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9955/10000 (100%)


Test set: Average loss: 0.0002, Accuracy: 9936/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9948/10000 (99%)


Test set: Average loss: 0.0004, Accuracy: 9903/10000 (99%)


Train set


Train set: Average loss: 0.0000, Accuracy: 9940/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9941/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9949/10000 (99%)


Test set: Average loss: 0.0002, Accuracy: 9946/10000 (99%)


Train set: Average loss: 0.0000, Accuracy: 9961/10000 (100%)


Test set: Average loss: 0.0001, Accuracy: 9958/10000 (100%)


Train set: Average loss: 0.0000, Accuracy: 9985/10000 (100%)


Test set: Average loss: 0.0000, Accuracy: 9978/10000 (100%)



In [59]:
torch.save(model.state_dict(), "mnist.pt")