In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

### 定义SimpleNet结构

In [2]:
class simplenet(nn.Module):
    """SimpleNet-style CNN classifier for single-channel images.

    ``features`` (built by ``_make_layers``) is a stack of Conv-BatchNorm-ReLU
    units interleaved with five 2x2 MaxPool2d + Dropout2d stages and ends with
    256 channels. ``forward`` applies global max pooling, dropout and a linear
    layer producing ``classes`` logits.

    Args:
        classes: number of output logits.
        simpnet_name: unused; kept so existing callers passing it still work.
    """

    def __init__(self, classes=10, simpnet_name='simplenet'):
        super(simplenet, self).__init__()
        self.features = self._make_layers()
        self.classifier = nn.Linear(256, classes)
        self.drp = nn.Dropout(0.1)

    def load_my_state_dict(self, state_dict):
        """Copy matching tensors from ``state_dict`` into this model, skipping the rest.

        Every ``'module.'`` substring is stripped from the keys before lookup
        (presumably so checkpoints saved from an ``nn.DataParallel`` wrapper
        load too). Keys that do not exist in this model are ignored. Tensors
        whose shape differs from the model's are reported and the model keeps
        its current values.

        Args:
            state_dict: mapping from parameter/buffer name to tensor or
                ``nn.Parameter``.
        """
        # Tensors returned by state_dict() share storage with the live
        # parameters/buffers, so copy_ below updates the model in place.
        own_state = self.state_dict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            if name not in own_state:
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            print("STATE_DICT: {}".format(name))
            # Compare shapes explicitly: copy_ would silently broadcast some
            # mismatched shapes (e.g. a (1,) tensor into a (10, 256) weight).
            if own_state[name].size() != param.size():
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ... Using Initial Params'.format(
                    name, own_state[name].size(), param.size()))
                continue
            own_state[name].copy_(param)

    def forward(self, x):
        """Return class logits of shape (batch, classes).

        ``x`` must have 1 input channel (the first Conv2d takes 1 channel) and
        spatial size of at least 32x32, so the five floor-halving 2x2 max-pools
        leave a non-empty feature map.
        """
        out = self.features(x)

        # Global max pooling: kernel equals the remaining spatial size -> 1x1.
        out = F.max_pool2d(out, kernel_size=out.size()[2:])
        out = self.drp(out)

        out = out.view(out.size(0), -1)
        return self.classifier(out)

    def _make_layers(self):
        """Build the convolutional feature extractor as one ``nn.Sequential``.

        Each ``(in_channels, out_channels, kernel_size)`` entry becomes
        Conv2d -> BatchNorm2d -> ReLU with "same" padding; each ``'M'`` becomes
        a 2x2/stride-2 MaxPool2d -> Dropout2d(0.1). The entry order fixes the
        integer indices inside the Sequential, and hence the
        ``features.<i>.*`` state_dict keys, so do not reorder entries.

        Returns:
            nn.Sequential of 49 modules whose output has 256 channels.
        """
        cfg = [
            (1, 64, 3), (64, 128, 3), (128, 128, 3), (128, 128, 3), 'M',
            (128, 128, 3), (128, 128, 3), (128, 256, 3), 'M',
            (256, 256, 3), (256, 256, 3), 'M',
            (256, 512, 3), 'M',
            (512, 2048, 1), (2048, 256, 1), 'M',
            (256, 256, 3),
        ]

        layers = []
        for spec in cfg:
            if spec == 'M':
                layers += [
                    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
                    nn.Dropout2d(p=0.1),
                ]
            else:
                in_ch, out_ch, k = spec
                layers += [
                    # padding k // 2 keeps spatial size: 1 for 3x3, 0 for 1x1 kernels.
                    nn.Conv2d(in_ch, out_ch, kernel_size=[k, k], stride=(1, 1), padding=(k // 2, k // 2)),
                    nn.BatchNorm2d(out_ch, eps=1e-05, momentum=0.05, affine=True),
                    nn.ReLU(inplace=True),
                ]
        model = nn.Sequential(*layers)

        # Re-initialize conv weights; conv biases keep PyTorch's default init.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

        return model

### 加载数据

In [3]:
import torchvision
import torchvision
from torch.utils.data import DataLoader

In [13]:
# Shared preprocessing for both MNIST splits: resize to 32x32 (the feature
# extractor's five 2x2 max-pools then reduce the map to 1x1), convert to a
# tensor and normalize. 0.1307 / 0.3081 are presumably the usual MNIST
# training-set mean / std.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
train_set = torchvision.datasets.MNIST('./data', train=True, download=True, transform=mnist_transform)
test_set = torchvision.datasets.MNIST('./data', train=False, download=True, transform=mnist_transform)

### 训练

In [5]:
import torch.optim.lr_scheduler as lr_scheduler

In [6]:
# Training configuration.
epochs = 540
batch_size = 100
num_workers = 2
save_path = './checkpoints/'

# Seed PyTorch's RNGs (CPU and CUDA) before the model and DataLoaders are
# created, so weight init, dropout masks and shuffling are repeatable.
# NOTE: some cuDNN kernels may still be nondeterministic on GPU.
seed = 0
torch.manual_seed(seed)

In [14]:
# Batch size and worker count are shared; only the training split is shuffled
# (evaluation accuracy does not depend on sample order).
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)

In [8]:
# Model, objective and optimizer: Adadelta with L2 weight decay.
net = simplenet()
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(
    net.parameters(),
    lr=0.1,
    rho=0.9,
    eps=1e-3,
    weight_decay=0.001,
)

# The learning rate is multiplied by 0.1 each time the per-epoch
# scheduler.step() count reaches a milestone.
# NOTE(review): with epochs = 540 the final 540 milestone only fires after
# the last epoch, so it never affects training.
milestones = [100, 190, 306, 390, 440, 540]
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

In [9]:
# The cells below move the model, the criterion and every batch to 'cuda'
# unconditionally, so this should display True before continuing.
torch.cuda.is_available()

True

In [10]:
# Move the model (parameters and BatchNorm buffers) and the criterion to the GPU.
net, loss = net.to('cuda'), loss.to('cuda')

In [16]:
def evaluate(net, data_iter):
    """Return the top-1 accuracy of ``net`` over every batch in ``data_iter``.

    Batches are moved to the device of the model's first parameter. The model
    runs in eval mode under ``torch.no_grad()``. Its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Args:
        net: classifier whose output is argmax'ed over dim 1 to get predictions.
        data_iter: iterable of ``(X, y)`` batches with integer class labels ``y``.

    Returns:
        Fraction of correctly classified samples, as a float.

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    device = next(net.parameters()).device
    was_training = net.training
    correct, n = 0, 0
    net.eval()
    try:
        # Evaluation never calls backward(), so don't build autograd graphs.
        with torch.no_grad():
            for X, y in data_iter:
                y = y.to(device)
                preds = net(X.to(device)).argmax(dim=1)
                correct += (preds == y).sum().item()
                n += y.shape[0]
    finally:
        # Restore the caller's mode instead of unconditionally calling train().
        net.train(was_training)
    if n == 0:
        raise ValueError('data_iter yielded no samples')
    return correct / n

In [None]:
# Train for `epochs` epochs and evaluate on the test set after each one.
# A periodic checkpoint is written when epoch % 10 == 0 (1-based epochs 1, 11,
# 21, ...). A separate "best" checkpoint is written whenever test accuracy improves.
import os

# torch.save does not create missing directories.
os.makedirs(save_path, exist_ok=True)

best_acc = -1
for epoch in range(epochs):
    # Set train mode explicitly rather than relying on evaluate() to restore it.
    net.train()
    for X, y in train_loader:
        X = X.to('cuda')
        y = y.to('cuda')
        l = loss(net(X), y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

    scheduler.step()
    acc = evaluate(net, test_loader)
    print('epoch %d test accuracy %f' % (epoch + 1, acc))

    if epoch % 10 == 0:
        torch.save(net.state_dict(),
                   os.path.join(save_path, 'checkpoints_epoch_{}.pth'.format(epoch + 1)))
    if acc > best_acc:
        torch.save(net.state_dict(),
                   os.path.join(save_path, 'best_checkpoints_acc_{}.pth'.format(acc)))
        best_acc = acc

epoch 1 test accuracy 0.802700
epoch 2 test accuracy 0.992800
epoch 3 test accuracy 0.993800
epoch 4 test accuracy 0.994200
epoch 5 test accuracy 0.994500
epoch 6 test accuracy 0.995200
epoch 7 test accuracy 0.995200
epoch 8 test accuracy 0.995700
epoch 9 test accuracy 0.994500
epoch 10 test accuracy 0.996600
epoch 11 test accuracy 0.994700
epoch 12 test accuracy 0.991800
epoch 13 test accuracy 0.996100
epoch 14 test accuracy 0.994000
epoch 15 test accuracy 0.996300
epoch 16 test accuracy 0.994400
epoch 17 test accuracy 0.989500
epoch 18 test accuracy 0.996000
epoch 19 test accuracy 0.992300
epoch 20 test accuracy 0.995400
epoch 21 test accuracy 0.995700
epoch 22 test accuracy 0.993900
