In [101]:
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import struct
import torch
import torch.nn as nn
import torch.nn.functional as F

In [102]:
# mnist ref:http://yann.lecun.com/exdb/mnist/
# struct ref:https://www.cnblogs.com/gala/archive/2011/09/22/2184801.html
# numpy ref:numpy.org.cn/article/basics/numpy_matrices_vectors.html
def LoadImages(file, max_images=10):
    """Load images from an MNIST IDX3 file as flattened float64 rows.

    The file starts with a big-endian header of four int32 values (magic
    number, image count, rows, columns) followed by one unsigned byte per
    pixel, image after image.

    Args:
        file: Path of the IDX3-ubyte image file.
        max_images: Upper bound on the number of images to read. Defaults to
            10, the small test subset this notebook has always used. Pass
            ``None`` to read every image declared in the header.

    Returns:
        numpy.ndarray of shape ``(n, rows * columns)`` and dtype float64,
        where ``n`` is the header's image count capped by ``max_images``.
        Pixel values are the raw byte values (0..255), not normalised.

    Raises:
        OSError: if the file cannot be opened ("open error" is printed first).
        struct.error: if the file is too short to contain the header.
        ValueError: if the file holds fewer pixel bytes than requested.
    """
    try:
        # `with` guarantees the handle is closed; the previous version called
        # close() on a name that was never bound when open() itself failed.
        with open(file, 'rb') as stream:
            raw = stream.read()
    except OSError:
        print("open error")
        raise

    header_fmt = '>iiii'
    _magic, images, rows, columns = struct.unpack_from(header_fmt, raw, 0)
    offset = struct.calcsize(header_fmt)

    # Never ask for more images than the header declares.
    if max_images is not None:
        images = max(0, min(images, max_images))

    # Decode all requested pixels in one pass instead of one struct call per
    # image. For a CNN, reshape to (images, rows, columns) instead of
    # flattening each image.
    pixels = np.frombuffer(raw, dtype=np.uint8,
                           count=images * rows * columns, offset=offset)

    # astype() also yields a writable copy (frombuffer views are read-only).
    return pixels.reshape((images, rows * columns)).astype(np.float64)

In [103]:
def LoadLabels(file, max_items=10):
    """Load labels from an MNIST IDX1 file as one-hot float64 rows.

    The file starts with a big-endian header of two int32 values (magic
    number, item count) followed by one unsigned byte per label.

    Args:
        file: Path of the IDX1-ubyte label file.
        max_items: Upper bound on the number of labels to read. Defaults to
            10, the small test subset this notebook has always used. Pass
            ``None`` to read every label declared in the header.

    Returns:
        numpy.ndarray of shape ``(n, 10)`` and dtype float64 where row ``i``
        has a 1 in the column given by label ``i`` and 0 elsewhere; ``n`` is
        the header's item count capped by ``max_items``.

    Raises:
        OSError: if the file cannot be opened ("open error" is printed first).
        struct.error: if the file is too short to contain the header.
        ValueError: if the file holds fewer label bytes than requested.
        IndexError: if a label value is 10 or greater.
    """
    try:
        # `with` guarantees the handle is closed; the previous version called
        # close() on a name that was never bound when open() itself failed.
        with open(file, 'rb') as stream:
            raw = stream.read()
    except OSError:
        print("open error")
        raise

    header_fmt = '>ii'
    _magic, items = struct.unpack_from(header_fmt, raw, 0)
    offset = struct.calcsize(header_fmt)

    # Never ask for more labels than the header declares.
    if max_items is not None:
        items = max(0, min(items, max_items))

    labels = np.frombuffer(raw, dtype=np.uint8, count=items, offset=offset)

    # One-hot encode: set column `labels[i]` of row `i` to 1.
    LabelSet = np.zeros((items, 10))
    LabelSet[np.arange(items), labels] = 1

    return LabelSet


In [104]:
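# Manual visual check of the loaders (disabled).
# LoadImages returns flattened rows, so each one is reshaped to 28x28 before plotting.
# im = LoadImages('train-images.idx3-ubyte')
# la = LoadLabels('train-labels.idx1-ubyte')
# plt.ion()
# for i in range(10):
#     plt.imshow(im[i].reshape((28, 28)), cmap='Greys')
#     print(la[i])
#     plt.pause(1)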

In [105]:
# Read the MNIST training images and their one-hot labels from disk as NumPy arrays.
TrainImages_np, TrainLabels_np = (
    LoadImages('train-images.idx3-ubyte'),
    LoadLabels('train-labels.idx1-ubyte'),
)

# Expose the same data to PyTorch as tensors (dtype follows the NumPy arrays).
TrainImages, TrainLabels = (
    torch.from_numpy(array) for array in (TrainImages_np, TrainLabels_np)
)

In [106]:
# Sanity check: dump the image tensor, then the label tensor, one per line.
print(TrainImages, TrainLabels, sep='\n')

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=torch.float64)


In [107]:
# softmax with torch
class SoftmaxNet(nn.Module):
    """One-hidden-layer classifier that outputs softmax probabilities.

    Architecture: Linear(n_feature -> n_hidden) -> ReLU ->
    Linear(n_hidden -> n_output) -> softmax over the output features.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        """Create the two linear layers.

        Args:
            n_feature: Number of input features per sample.
            n_hidden: Width of the hidden layer.
            n_output: Number of output classes.
        """
        super().__init__()
        self.hidden = nn.Linear(n_feature, n_hidden)
        self.predict = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return class probabilities for ``x``.

        Args:
            x: Float tensor whose last dimension has size ``n_feature``;
                any leading dimensions (or none) are kept as they are.

        Returns:
            Tensor shaped like ``x`` but with a last dimension of size
            ``n_output``; values along that dimension sum to 1.
        """
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        # nn.Linear puts the class scores on the last axis, so normalise there.
        # This equals dim=1 for (batch, features) input and also handles a
        # single unbatched sample or extra leading dimensions.
        return F.softmax(x, dim=-1)

In [108]:
# Seed PyTorch's global RNG so the randomly initialised layer weights -- and
# therefore every loss value printed during training -- are the same after
# each Restart & Run All.
torch.manual_seed(0)

# 28*28 = 784 input pixels per image, 520 hidden units, 10 output classes.
net = SoftmaxNet(28*28, 520, 10)
print(net)


SoftmaxNet(
  (hidden): Linear(in_features=784, out_features=520, bias=True)
  (predict): Linear(in_features=520, out_features=10, bias=True)
)


In [109]:
# Plain SGD on the mean squared error between the softmax output and the
# one-hot labels.
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_func = nn.MSELoss()

# The loaders produce float64 data while the network's parameters are float32.
# .to() converts an existing tensor directly; calling torch.tensor() on a
# tensor makes PyTorch emit a copy-construct UserWarning.
TrainImages = TrainImages.to(torch.float32)
TrainLabels = TrainLabels.to(torch.float32)
# NOTE(review): pixel values are fed in unscaled (raw 0..255 bytes); scaling
# them to [0, 1] is probably worth trying -- confirm before changing results.

for t in range(5):
    # Full-batch forward pass and loss.
    prediction = net(TrainImages)
    loss = loss_func(prediction, TrainLabels)

    # Clear stale gradients, back-propagate, then update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # for test: show the loss, then label vs. prediction for up to 10 samples
    # (clamped so a smaller training subset does not raise IndexError).
    print(loss.detach().numpy())
    # plt.ion()
    for i in range(min(10, len(TrainLabels))):
        print((TrainLabels.numpy())[i])
        print((prediction.detach().numpy())[i])
        # plt.imshow(((TrainImages.numpy())[i]).reshape((28, 28)), cmap='Greys')
        # plt.pause(1)

0.12205237
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[4.9467782e-20 4.1360080e-14 1.1239714e-04 2.6005052e-06 5.2913076e-21
 9.9988496e-01 9.2716762e-19 2.3027789e-24 1.4922628e-17 1.2042262e-28]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[9.3351982e-14 1.3328730e-07 1.7368343e-02 9.5789212e-01 9.1107521e-12
 2.4739409e-02 8.8541912e-23 1.3120361e-18 5.1082639e-26 2.1963773e-24]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[7.03766184e-25 1.68134685e-14 4.01520006e-12 9.99952793e-01
 4.72466745e-05 1.44424633e-13 6.50952337e-16 7.16692691e-23
 1.18336616e-17 3.09392226e-21]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[9.9496310e-07 9.9991834e-01 1.4470901e-06 4.9041599e-05 3.7820318e-08
 4.4823221e-08 1.1787675e-22 3.0305968e-05 2.9974734e-14 7.9499407e-11]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[3.57777696e-12 2.28750197e-09 1.31199045e-11 5.63690095e-10
 4.34593971e-15 9.99999881e-01 1.15539132e-16 9.13181668e-16
 1.17780427e-12 7.37193560e-08]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[6.3160729e-23 9.9998116e-01 2.2123553e-12 5.2190826e-14 1.8