In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms

In [2]:
batch_size=200
learning_rate=0.01
epochs=10

# 加载数据
train_loader=torch.utils.data.DataLoader(
    datasets.MNIST('./data/mnist_data',train=True,download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ])),
    batch_size=batch_size,shuffle=True
)
test_loader=torch.utils.data.DataLoader(
    datasets.MNIST('./data/mnist_data',train=False,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ])),
    batch_size=batch_size,shuffle=True
)

In [3]:
class MLP(nn.Module):
    
    def __init__(self):
        super(MLP,self).__init__()
        
        self.model=nn.Sequential(
            nn.Linear(784,200),
            nn.ReLU(inplace=True),
            nn.Linear(200,200),
            nn.ReLU(inplace=True),
            nn.Linear(200,10),
            nn.ReLU(inplace=True),
        )
        
    def forward(self,x):
        x=self.model(x)
        return x

In [4]:
device=torch.device('cuda:0')
net=MLP().to(device)
optimizer=optim.SGD(net.parameters(),lr=learning_rate)
criteon=nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):
    
    for batch_idx,(data,target) in enumerate(train_loader):
        data=data.view(-1,28*28)
        data,target=data.to(device),target.to(device)
        
        logits=net(data)
        loss=criteon(logits,target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 ==0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            
    test_loss=0 
    correct=0 
    for data,target in test_loader:
        data=data.view(-1,28*28)
        data,target=data.to(device),target.to(device)
        
        logits=net(data)
        test_loss+=criteon(logits,target).item()
        
        pred=logits.argmax(dim=1)
        correct+=pred.eq(target.data).sum()
        
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Average loss: 0.0074, Accuracy: 5578/10000 (56%)


Test set: Average loss: 0.0058, Accuracy: 6120/10000 (61%)


Test set: Average loss: 0.0053, Accuracy: 6485/10000 (65%)


Test set: Average loss: 0.0051, Accuracy: 6674/10000 (67%)


Test set: Average loss: 0.0039, Accuracy: 7467/10000 (75%)




Test set: Average loss: 0.0038, Accuracy: 7737/10000 (77%)


Test set: Average loss: 0.0037, Accuracy: 7836/10000 (78%)


Test set: Average loss: 0.0036, Accuracy: 7923/10000 (79%)


Test set: Average loss: 0.0035, Accuracy: 7982/10000 (80%)


Test set: Average loss: 0.0035, Accuracy: 8052/10000 (81%)

