In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms

In [2]:
batch_size=200
learning_rate=0.01
epochs=10

train_loader=DataLoader(
    datasets.MNIST('./data/mnist_data',train=True,download=True,
                    transform=transforms.Compose([
                       transforms.RandomHorizontalFlip(),
                       transforms.RandomVerticalFlip(),
#                        transforms.RandomRotation(15), # 会改变dim???
#                        transforms.RandomRotation([90, 180, 270]),
                       transforms.Resize([32, 32]),
                       transforms.RandomCrop([28, 28]),
                       transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ])),
    batch_size=batch_size,shuffle=True
)

test_loader=DataLoader(
    datasets.MNIST('./data/mnist_data',train=False,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ])),
    batch_size=batch_size,shuffle=True
)

In [3]:
class MLP(nn.Module):
    
    def __init__(self):
        super(MLP,self).__init__()
        
        self.model=nn.Sequential(
            nn.Linear(784,200),
            nn.ReLU(inplace=True),
            nn.Linear(200,200),
            nn.ReLU(inplace=True),
            nn.Linear(200,10),
            nn.ReLU(inplace=True),
        )
        
    def forward(self,x):
        x=self.model(x)
        return x

In [4]:
device=torch.device('cuda:0')
net=MLP().to(device)
optimizer=optim.SGD(net.parameters(),lr=learning_rate)
criteon=nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):
    
    for batch_idx,(data,target) in enumerate(train_loader):
        data=data.view(-1,28*28)
        data,target=data.to(device),target.to(device)
        
        logits=net(data)
        loss=criteon(logits,target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 ==0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            
    test_loss=0 
    correct=0 
    for data,target in test_loader:
        data=data.view(-1,28*28)
        data,target=data.to(device),target.to(device)
        
        logits=net(data)
        test_loss+=criteon(logits,target).item()
        
        pred=logits.argmax(dim=1)
        correct+=pred.eq(target.data).sum()
        
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Average loss: 0.0100, Accuracy: 3242/10000 (32%)


Test set: Average loss: 0.0091, Accuracy: 3443/10000 (34%)


Test set: Average loss: 0.0090, Accuracy: 3464/10000 (35%)


Test set: Average loss: 0.0091, Accuracy: 3293/10000 (33%)


Test set: Average loss: 0.0088, Accuracy: 3435/10000 (34%)




Test set: Average loss: 0.0082, Accuracy: 3765/10000 (38%)


Test set: Average loss: 0.0078, Accuracy: 4107/10000 (41%)


Test set: Average loss: 0.0064, Accuracy: 5304/10000 (53%)


Test set: Average loss: 0.0057, Accuracy: 5918/10000 (59%)


Test set: Average loss: 0.0052, Accuracy: 6191/10000 (62%)

