In [38]:
import torch
from torch import nn
from torchvision import datasets, transforms

# 1 parameters

In [39]:
# dataset
input_shape = 28
num_classes = 10
# hyper
batch_size = 64
num_epochs = 5
learning_rate = 1e-3
# gpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2 dataset和dataloader

In [40]:
train_dataset = datasets.MNIST(root='./data',
                                           train=True,
                                           download=True,
                                           transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transforms.ToTensor())

In [41]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           shuffle=True,
                                           batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           shuffle=True,
                                           batch_size=batch_size)

In [42]:
images, labels = next(iter(train_loader))

In [43]:
# batch_size,channels,h(i),w(j)
images.shape

torch.Size([64, 1, 28, 28])

# 3 model arch


- cnn: channel 不断递增， shape 不断减少的过程
    - 最好是 *2

In [44]:
class CNN(nn.Module):
    def __init__(self, in_channels, input_shape, num_classes):
        super(CNN, self).__init__()
        # conv2d: (b, 1, 28, 28) => (b, 16, 28, 28)
        # maxpool2d: (b, 16, 28, 28) => (b, 16, 14, 14)
        self.cnn1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=16,
                                            kernel_size=5,padding=2,stride=1),
                                  nn.BatchNorm2d(16),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2,stride=2))
        # conv2d: (b, 16, 14, 14) => (b, 32, 14, 14)
        # maxpool2d: (b, 32, 14, 14) => (b, 32, 7, 7)
        self.cnn2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32,
                                            kernel_size=5,padding=2,stride=1),
                                  nn.BatchNorm2d(32),
                                  nn.ReLU(),
                                  nn.MaxPool2d(kernel_size=2,stride=2))
        # reshape (b, 32, 7, 7) => (b, 32*7*7)
        # (b, 32*7*7) => (b, 10)
        self.fc = nn.Linear(32*(input_shape//4)*(input_shape//4) , num_classes)
        pass

    def forward(self,x):
        # (b, 1, 28, 28) => (b, 16, 14, 14)
        out = self.cnn1(x)
        # (b, 16, 14, 14) => (b, 32, 7, 7)
        out = self.cnn2(out)
        # (b, 32, 7, 7) => (b, 32*7*7)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
        pass


# torchsummary

In [45]:
from torchsummary import summary

In [46]:
model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=1).to(device)
summary(model, input_size=(1,28,28), batch_size=batch_size)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [64, 16, 28, 28]             416
       BatchNorm2d-2           [64, 16, 28, 28]              32
              ReLU-3           [64, 16, 28, 28]               0
         MaxPool2d-4           [64, 16, 14, 14]               0
            Conv2d-5           [64, 32, 14, 14]          12,832
       BatchNorm2d-6           [64, 32, 14, 14]              64
              ReLU-7           [64, 32, 14, 14]               0
         MaxPool2d-8             [64, 32, 7, 7]               0
            Linear-9                   [64, 10]          15,690
Total params: 29,034
Trainable params: 29,034
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 29.86
Params size (MB): 0.11
Estimated Total Size (MB): 30.17
-------------------------------------------

# 4 model train

In [47]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [48]:
total_batch = len(train_loader)

In [49]:
for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        out = model(images)
        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (batch_idx+1)%100 == 0:
            print(f'{epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')

1/5, 100/938: 0.0933
1/5, 200/938: 0.1791
1/5, 300/938: 0.0963
1/5, 400/938: 0.0169
1/5, 500/938: 0.0435
1/5, 600/938: 0.0651
1/5, 700/938: 0.0598
1/5, 800/938: 0.0839
1/5, 900/938: 0.1724
2/5, 100/938: 0.0160
2/5, 200/938: 0.0204
2/5, 300/938: 0.0156
2/5, 400/938: 0.1561
2/5, 500/938: 0.0102
2/5, 600/938: 0.0561
2/5, 700/938: 0.0826
2/5, 800/938: 0.0400
2/5, 900/938: 0.0217
3/5, 100/938: 0.0436
3/5, 200/938: 0.0076
3/5, 300/938: 0.0252
3/5, 400/938: 0.0075
3/5, 500/938: 0.0769
3/5, 600/938: 0.0021
3/5, 700/938: 0.0556
3/5, 800/938: 0.0052
3/5, 900/938: 0.0212
4/5, 100/938: 0.0281
4/5, 200/938: 0.0033
4/5, 300/938: 0.0213
4/5, 400/938: 0.0227
4/5, 500/938: 0.2058
4/5, 600/938: 0.0310
4/5, 700/938: 0.0307
4/5, 800/938: 0.0056
4/5, 900/938: 0.0024
5/5, 100/938: 0.0124
5/5, 200/938: 0.0027
5/5, 300/938: 0.0621
5/5, 400/938: 0.0465
5/5, 500/938: 0.0188
5/5, 600/938: 0.0103
5/5, 700/938: 0.0052
5/5, 800/938: 0.0019
5/5, 900/938: 0.0046


# model evaluation

In [50]:
total = 0
correct = 0
for images,labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    out = model(images)
    preds = torch.argmax(out, dim=1)

    total += images.size(0)
    correct += (preds == labels).sum().item()
print(f'{correct}/{total} = {correct/total}')

59643/60000 = 0.99405


In [51]:
torch.save(model.state_dict(), 'cnn_mnist.ckpt')