In [4]:
# データの読み込み(sklearn)
import numpy as np
import torch
import torch.nn.functional as F
# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed in
# 0.22; this import fails on newer releases and should be dropped when upgrading.
from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from skorch import NeuralNetClassifier
from torch import nn
# Load MNIST from OpenML. The previous loader, fetch_mldata, relied on the
# mldata.org service (no longer available) and has been removed from
# scikit-learn, so it can no longer download the dataset.
mnist = fetch_openml('mnist_784', version=1, data_home='./data/mnist')

# Depending on the scikit-learn version, fetch_openml returns either NumPy
# arrays or pandas objects; np.asarray normalises both to plain ndarrays.
# Dividing by 255 scales the raw pixel intensities into [0, 1].
X = np.asarray(mnist.data, dtype='float32') / 255

# OpenML delivers the digit labels as strings ('0'..'9'); convert to int64.
y = np.asarray(mnist.target).astype('int64')

# Reshape each flat 784-value row into a single-channel 28x28 image, the
# (N, 1, 28, 28) layout consumed by Net's first Conv2d layer.
XCnn = X.reshape(-1, 1, 28, 28)

# Hold out 25% of the samples for testing; the fixed random_state keeps the
# split identical across runs.
XCnn_train, XCnn_test, y_train, y_test = train_test_split(XCnn, y, test_size=0.25, random_state=42)
# Networkの設計(PyTorch)
class Net(nn.Module):
    """Two-stage convolutional classifier producing 10 class probabilities.

    Layout: conv(1->32, 3x3) -> max-pool(2) -> ReLU -> conv(32->64, 3x3) ->
    channel dropout -> max-pool(2) -> ReLU -> flatten -> linear(1600->128) ->
    ReLU -> dropout -> linear(128->10) -> softmax.

    The first linear layer is sized for 1600 flattened features, which is
    what a (N, 1, 28, 28) input yields after the two conv/pool stages
    (64 channels * 5 * 5).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation and
        # state_dict keys stay stable.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1600, 128)  # 1600 = channels * height * width
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class probabilities (each row sums to 1) for batch ``x``."""
        # Stage 1: convolution, 2x2 max-pooling, then ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Stage 2: convolution, channel-wise dropout, 2x2 max-pooling, ReLU.
        conv2_out = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(conv2_out, 2))

        # Collapse channel, height and width into one feature axis.
        channels, height, width = stage2.shape[1:]
        flat = stage2.view(-1, channels * height * width)

        # Classifier head; dropout is only active while the module is training.
        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=-1)
# ラッパーを使う(skorch)
# Wrap Net in skorch's scikit-learn-style classifier: Adadelta optimiser with
# learning rate 1, trained for ten epochs.
net = NeuralNetClassifier(
    module=Net,
    optimizer=torch.optim.Adadelta,
    lr=1,
    max_epochs=10,
    # use_cuda=True,  # uncomment this to train with CUDA
    # NOTE(review): newer skorch releases appear to select the GPU with
    # device='cuda' instead of use_cuda -- confirm against the installed version.
)
# training
# Seed PyTorch's global RNG before fitting so the stochastic parts of training
# (the dropout layers in Net and, presumably, the weight initialisation done
# when skorch instantiates the module) repeat on a fresh kernel. Without this
# the metrics printed below change from run to run even though the
# train/test split itself is seeded.
torch.manual_seed(42)
net.fit(XCnn_train, y_train)

# test
# Predict labels for the held-out images and print per-class
# precision / recall / F1.
y_pred = net.predict(XCnn_test)
print(classification_report(y_test, y_pred))

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        0.4107       0.9741        0.0838  50.1330
      2        0.1332       0.9812        0.0603  48.3953
      3        0.1077       0.9845        0.0467  59.6833
      4        0.0896       0.9863        0.0420  50.3603
      5        0.0808       0.9863        0.0394  47.4268
      6        0.0733       0.9882        0.0358  49.1700
      7        0.0672       0.9890        0.0347  52.3128
      8        0.0636       0.9889        0.0342  48.8148
      9        0.0590       0.9887        0.0339  46.8803
     10        0.0548       0.9901        0.0306  47.4012
             precision    recall  f1-score   support

       

In [None]:
print()  # NOTE(review): emits only a blank line; looks like a leftover scratch cell -- consider deleting