In [36]:
import  time

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
from    torch.autograd import Variable

import  torchvision.datasets as dsets
import  torchvision.transforms as trans


In [37]:
def conv5x5(in_channels, out_channels, stride=1):
    """Build a bias-free 5x5 convolution with 'same'-style padding of 2.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        stride: convolution stride (1 keeps the spatial size unchanged).

    Returns:
        An ``nn.Conv2d`` with kernel_size=5, padding=2 and no bias term.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=5,
        stride=stride,
        padding=2,   # (5 - 1) // 2: preserves H/W when stride == 1
        bias=False,
    )
    return layer

In [38]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/avg-pool/ReLU stages followed by a 2-layer MLP.

    The classifier's first Linear layer expects 20 * 7 * 7 features. That is
    what the feature extractor produces for single-channel 28x28 inputs, such
    as the MNIST images loaded later in this notebook. The final layer emits
    10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(1, 10, 5, padding=2),   # 1x28x28  -> 10x28x28
            nn.AvgPool2d(2),                  #          -> 10x14x14
            nn.ReLU(),
            nn.Conv2d(10, 20, 5, padding=2),  #          -> 20x14x14
            nn.AvgPool2d(2),                  #          -> 20x7x7
            nn.ReLU(),
        ]
        self.feature = nn.Sequential(*conv_stages)
        self.classifier = nn.Sequential(
            nn.Linear(20 * 7 * 7, 500),
            nn.ReLU(),
            nn.Linear(500, 10),
        )

    def forward(self, x):
        """Map a batch of images to class logits of shape (batch, 10)."""
        maps = self.feature(x)
        flat = maps.view(maps.size(0), -1)  # flatten everything but the batch dim
        return self.classifier(flat)

In [39]:
# Seed PyTorch's global RNG so that weight initialization and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

net = LeNet()

# Check whether the MPS (Apple Metal) backend is available and, if so,
# move the model onto it; otherwise the model stays on the CPU.
if torch.backends.mps.is_available():
    net = net.to('mps')

print(net)

LeNet(
  (feature): Sequential(
    (0): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (2): ReLU()
    (3): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): ReLU()
  )
  (classifier): Sequential(
    (0): Linear(in_features=980, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=10, bias=True)
  )
)


In [40]:
# MNIST train/test splits converted to tensors, both stored under ../data/.
# Only the training split requests a download.
batch_size = 1000
to_tensor = trans.ToTensor()
train_set = dsets.MNIST(root='../data/', train=True, transform=to_tensor, download=True)
test_set = dsets.MNIST(root='../data/', train=False, transform=to_tensor)

# Shuffle only the training data; the test order is kept fixed.
train_dl = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)


In [41]:
# Training configuration: loss function, optimizer over the network's
# parameters, and number of passes over the training set.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
nepochs = 30

In [42]:
def eval(model, criterion, dataloader):
    """Compute the mean loss and accuracy (in percent) of ``model`` over ``dataloader``.

    Note: this name shadows Python's builtin ``eval`` in the notebook
    namespace. It is kept because the training cell calls it by this name.

    Args:
        model: an ``nn.Module``; assumed to return logits of shape
            (batch, num_classes).
        criterion: loss called as ``criterion(logits, targets)``. It is assumed
            to use ``reduction='mean'`` (the ``nn.CrossEntropyLoss`` default),
            because each batch's loss is re-weighted by its batch size.
        dataloader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(loss, accuracy)`` as Python floats:
        - ``loss`` is the per-sample mean loss.
        - ``accuracy`` is the percentage of correctly classified samples.
        Both are averaged over samples, not batches, so a smaller final batch
        is not over-weighted.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # Put batches on the same device as the model's parameters instead of
    # re-probing for MPS, so this also works for CPU/CUDA models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()  # disable train-only behavior (e.g. dropout) if the model has any
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    try:
        with torch.no_grad():  # evaluation needs no autograd graph
            for batch_x, batch_y in dataloader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                logits = model(batch_x)
                n = batch_y.size(0)
                # Undo the criterion's batch mean so batches are weighted by size.
                total_loss += criterion(logits, batch_y).item() * n

                _, pred_y = logits.max(dim=1)
                total_correct += (pred_y == batch_y).sum().item()
                total_samples += n
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    if total_samples == 0:
        raise ValueError('dataloader yielded no samples')
    return total_loss / total_samples, total_correct * 100.0 / total_samples

In [43]:
# Train for `nepochs` epochs and report train/test metrics after each epoch.
# Batches are sent to whatever device `net` lives on (MPS when the net was
# moved there, otherwise CPU). Resolving it once avoids re-checking the
# backend on every batch. `Variable` wrappers are not needed: plain tensors
# carry autograd state directly.
device = next(net.parameters()).device
for epoch in range(nepochs):
    net.train()  # make sure training-mode behavior is active for this epoch
    since = time.time()
    for batch_x, batch_y in train_dl:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = net(batch_x)
        error = criterion(logits, batch_y)
        error.backward()
        optimizer.step()

    # The reported time covers only the optimization pass, not the evaluation below.
    now = time.time()
    train_loss, train_acc = eval(net, criterion, train_dl)
    test_loss, test_acc = eval(net, criterion, test_dl)
    print('%2d/%d, %.0f sec|\t train loss: %.4f, train acc: %.2f |\t  test loss: %.4f, test acc: %.2f' % (epoch+1,nepochs,now-since,train_loss,train_acc,test_loss,test_acc))

 1/30, 4 sec|	 train loss: 0.3420, train acc: 89.74 |	  test loss: 0.3325, test acc: 90.18
 2/30, 4 sec|	 train loss: 0.2306, train acc: 93.09 |	  test loss: 0.2176, test acc: 93.54
 3/30, 4 sec|	 train loss: 0.1596, train acc: 95.22 |	  test loss: 0.1502, test acc: 95.57
 4/30, 4 sec|	 train loss: 0.1192, train acc: 96.39 |	  test loss: 0.1132, test acc: 96.58
 5/30, 4 sec|	 train loss: 0.0952, train acc: 97.13 |	  test loss: 0.0901, test acc: 97.31
 6/30, 4 sec|	 train loss: 0.0778, train acc: 97.58 |	  test loss: 0.0742, test acc: 97.61
 7/30, 4 sec|	 train loss: 0.0668, train acc: 97.95 |	  test loss: 0.0653, test acc: 97.86
 8/30, 4 sec|	 train loss: 0.0606, train acc: 98.10 |	  test loss: 0.0619, test acc: 97.98
 9/30, 4 sec|	 train loss: 0.0502, train acc: 98.44 |	  test loss: 0.0521, test acc: 98.20
10/30, 4 sec|	 train loss: 0.0444, train acc: 98.67 |	  test loss: 0.0499, test acc: 98.38
11/30, 4 sec|	 train loss: 0.0424, train acc: 98.67 |	  test loss: 0.0478, test acc: 98.34