In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

In [2]:
import pickle
pickle_file = 'notMNIST.pickle'

with open(pickle_file, 'rb') as f:
  save = pickle.load(f)
  train_dataset = save['train_dataset']
  train_labels = save['train_labels']
  valid_dataset = save['valid_dataset']
  valid_labels = save['valid_labels']
  test_dataset = save['test_dataset']
  test_labels = save['test_labels']
  del save  # hint to help gc free up memory
  print('Training set', train_dataset.shape, train_labels.shape)
  print('Validation set', valid_dataset.shape, valid_labels.shape)
  print('Test set', test_dataset.shape, test_labels.shape)

Training set (200000, 28, 28) (200000,)
Validation set (10000, 28, 28) (10000,)
Test set (10000, 28, 28) (10000,)


In [3]:
image_size = 28
num_labels = 10
num_channels = 1 # grayscale

import numpy as np

def reformat(dataset, labels):
    dataset = dataset.reshape(
    (-1, num_channels,image_size, image_size))
    dataset = torch.from_numpy(dataset)
#     labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
    labels = torch.from_numpy(labels)
    return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.size(), train_labels.size())
print('Validation set', valid_dataset.size(), valid_labels.size())
print('Test set', test_dataset.size(), test_labels.size())

Training set torch.Size([200000, 1, 28, 28]) torch.Size([200000])
Validation set torch.Size([10000, 1, 28, 28]) torch.Size([10000])
Test set torch.Size([10000, 1, 28, 28]) torch.Size([10000])


In [4]:
batch_size = 16
patch_size = 5
depth = 16
num_hidden = 64
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 输入图像channel：1；输出channel：6；5x5卷积核
        self.conv1 = nn.Conv2d(1,depth, patch_size)
        self.conv2 = nn.Conv2d(depth, depth, patch_size)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 4 * 4, num_hidden)
        self.fc2 = nn.Linear(num_hidden, 10)

    def forward(self, x):
        # 2x2 Max pooling
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 如果是方阵,则可以只使用一个数字进行定义
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # 除去批处理维度的其他所有维度
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

cuda:0


In [6]:
net.to(device)

Net(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

In [8]:
train_dataset.dtype

torch.float32

In [9]:
train_labels.dtype

torch.int32

In [10]:
trainset = torch.utils.data.TensorDataset(train_dataset, train_labels)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=False, num_workers=2)

validset = torch.utils.data.TensorDataset(valid_dataset, valid_labels)
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=2)

In [11]:
itero = iter(trainloader)

### Pytorch 默认接收的inputs应该是float32, labes 应该是int64

In [12]:
for epoch in range(2):
    for i,data in enumerate(trainloader):
        inputs,labels = data
        inputs,labels = inputs.to(device,dtype = torch.float32),labels.to(device=device,dtype = torch.int64)
#         inputs = inputs.float()
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if i % 2000 == 0:
            correct = 0
            total = 0
            with torch.no_grad():
                for vdata in validloader:
                    vinputs, vlabels = data
                    vinputs,vlabels = vinputs.to(device,dtype = torch.float32),vlabels.to(device=device,dtype = torch.int64)
                    voutputs = net(vinputs)
                    _, predicted = torch.max(voutputs, 1)
                    total += vlabels.size(0)
                    correct += (predicted == vlabels).sum().item()
                print('Accuracy of the network on the valid images: %d %%' % (100 * correct / total))
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i, loss.item()))
print('Finished Training')

Accuracy of the network on the valid images: 18 %
[1,     0] loss: 2.296
Accuracy of the network on the valid images: 100 %
[1,  2000] loss: 0.105
Accuracy of the network on the valid images: 93 %
[1,  4000] loss: 0.390
Accuracy of the network on the valid images: 100 %
[1,  6000] loss: 0.070
Accuracy of the network on the valid images: 87 %
[1,  8000] loss: 0.569
Accuracy of the network on the valid images: 100 %
[1, 10000] loss: 0.149
Accuracy of the network on the valid images: 93 %
[1, 12000] loss: 0.495
Accuracy of the network on the valid images: 93 %
[2,     0] loss: 0.500
Accuracy of the network on the valid images: 100 %
[2,  2000] loss: 0.043
Accuracy of the network on the valid images: 93 %
[2,  4000] loss: 0.346
Accuracy of the network on the valid images: 100 %
[2,  6000] loss: 0.033
Accuracy of the network on the valid images: 87 %
[2,  8000] loss: 0.437
Accuracy of the network on the valid images: 100 %
[2, 10000] loss: 0.182
Accuracy of the network on the valid images: 