In [7]:
import time
import torch
from torch import nn,optim

import sys
sys.path.append('../code/')
import d2lzh_pytorch as d2l
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(1,6,5),
            nn.Sigmoid(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(6,16,5),
            nn.Sigmoid(),
            nn.MaxPool2d(2,2)
        )
        
        self.fc=nn.Sequential(
            nn.Linear(16*4*4,120),
            nn.Sigmoid(),
            nn.Linear(120,84),
            nn.Sigmoid(),
            nn.Linear(84,10)
        )
        
    def forward(self,img):
        feature=self.conv(img)
        output=self.fc(feature.view(img.shape[0],-1))
        return output

In [9]:
net=LeNet()
print(net)

LeNet(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Sigmoid()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Sigmoid()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): Sigmoid()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)


In [10]:
batch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size,root='/workspace/mycode/Dive-into-DL-PyTorch/data')

In [11]:
def evaluate_accuracy(data_iter,net,device=None):
    if device is None and isinstance(net,torch.nn.Module):
        device =list(net.parameters())[0].device
    
    acc_sum,n=0.0,0
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(net,nn.Module):
                net.eval()#评估模式,关闭dropout
                acc_sum+=(net(X.to(device)).argmax(dim=1)==y.to(device)).float().sum().cpu().item()
                net.train()#改回训练模式
            else:
                if('is_training' in net.__code__.co_varnames):
                    acc_sum+=(net(X,is_training=False).argmax(dim=1)==y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
            n+=y.shape[0]
    return acc_sum/n

In [16]:
def train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs):
    net=net.to(device)
    print('training on',device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum,train_acc_sum,n,batch_count,start=0.0,0.0,0,0,time.time()
        for X,y in train_iter:
            X=X.to(device)
            y=y.to(device)
            y_hat=net(X)
            #print(y.shape)
            print(y_hat.shape)
            l=loss(y_hat,y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum+=l.cpu().item()
            train_acc_sum+=(y_hat.argmax(dim=1)==y).sum().cpu().item()
            #print(train_acc_sum)
            n+=y.shape[0]
            batch_count+=1
        test_acc=evaluate_accuracy(test_iter,net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' % (epoch+1,train_l_sum/batch_count,train_acc_sum/n,test_acc,time.time()-start))

In [17]:
lr,num_epochs=0.001,5
optimizer=torch.optim.Adam(net.parameters(),lr=lr)
train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs)

training on cuda
torch.Size([256, 10])
190.0
torch.Size([256, 10])
369.0
torch.Size([256, 10])
554.0
torch.Size([256, 10])
729.0
torch.Size([256, 10])
909.0
torch.Size([256, 10])
1082.0
torch.Size([256, 10])
1260.0
torch.Size([256, 10])
1440.0
torch.Size([256, 10])
1610.0
torch.Size([256, 10])
1801.0
torch.Size([256, 10])
1962.0
torch.Size([256, 10])
2144.0
torch.Size([256, 10])
2325.0
torch.Size([256, 10])
2508.0
torch.Size([256, 10])
2685.0
torch.Size([256, 10])
2848.0
torch.Size([256, 10])
3029.0
torch.Size([256, 10])
3211.0
torch.Size([256, 10])
3397.0
torch.Size([256, 10])
3569.0
torch.Size([256, 10])
3744.0
torch.Size([256, 10])
3932.0
torch.Size([256, 10])
4116.0
torch.Size([256, 10])
4278.0
torch.Size([256, 10])
4454.0
torch.Size([256, 10])
4641.0
torch.Size([256, 10])
4828.0
torch.Size([256, 10])
5014.0
torch.Size([256, 10])
5191.0
torch.Size([256, 10])
5369.0
torch.Size([256, 10])
5549.0
torch.Size([256, 10])
5732.0
torch.Size([256, 10])
5908.0
torch.Size([256, 10])
6084.0
to

KeyboardInterrupt: 

In [51]:
for X,y in train_iter:
    print(X.shape)
    break

torch.Size([256, 1, 28, 28])


    28*28 -> 6*24*24 -> 6*12*12 -> 16*8*8 -> 16*4*4