In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys,os
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
import pickle

In [2]:
def sigmoid(x):
    """Logistic sigmoid, evaluated element-wise.

    Computes ``1 / (1 + exp(-x))`` with NumPy, so ``x`` may be a scalar or
    an array; the result has the shape of ``np.exp(-x)``.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

In [3]:
def softmax(a):
    """Softmax over the last axis of ``a``.

    For a 1-D input this is the usual softmax of the vector. For a 2-D
    batch (one sample per row) every row is normalised independently, so
    each row sums to 1.

    Args:
        a: scalar, vector, or array of scores; the last axis indexes the
            classes.

    Returns:
        Array of the same shape as ``a`` holding the softmax probabilities.
    """
    a = np.asarray(a)
    # A 0-d input has no axis to reduce along; fall back to a full reduction.
    axis = -1 if a.ndim > 0 else None
    # Subtract the per-sample maximum before exponentiating to avoid overflow.
    shifted = a - np.max(a, axis=axis, keepdims=True)
    exp_a = np.exp(shifted)
    return exp_a / np.sum(exp_a, axis=axis, keepdims=True)

In [4]:
def get_data():
    """Return the MNIST test split as ``(images, labels)``.

    Calls ``load_mnist`` with ``normalize=True``, ``flatten=True`` and
    ``one_hot_label=False``; the training split it also returns is
    discarded.
    """
    (_, _), (images, labels) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return images, labels

In [5]:
def init_network():
    """Load the network parameters from ``sample_weight.pkl``.

    The file is opened relative to the current working directory and
    unpickled; the object it contains is returned unchanged.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading,
    so this must only be pointed at a trusted weight file.
    """
    with open("sample_weight.pkl", "rb") as weight_file:
        return pickle.load(weight_file)

In [6]:
def predict(network, x):
    """Run the forward pass of the three-layer network.

    Args:
        network: mapping holding the weights ``W1``-``W3`` and biases
            ``b1``-``b3``.
        x: input passed to ``np.dot(x, W1)``, i.e. one sample or a batch
            whose last dimension matches the rows of ``W1``.

    Returns:
        The output of ``softmax`` applied to the final affine layer.
    """
    signal = x
    # Two hidden layers: affine transform followed by a sigmoid activation.
    for w_key, b_key in (('W1', 'b1'), ('W2', 'b2')):
        signal = sigmoid(np.dot(signal, network[w_key]) + network[b_key])

    # Output layer: affine transform followed by softmax.
    scores = np.dot(signal, network['W3']) + network['b3']
    return softmax(scores)

#### 实现推理处理

In [7]:
x, t = get_data()
network = init_network()

# Classify the test samples one at a time and count the correct predictions.
accuracy_cnt = 0
for i, sample in enumerate(x):
    y = predict(network, sample)
    p = np.argmax(y)  # index of the highest-probability class
    accuracy_cnt += int(p == t[i])  # add 1 when the prediction matches the label
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))  # fraction classified correctly

Accuracy:0.9352


#### 批处理

In [8]:
# Reload the test images; the labels are not needed for the shape checks below.
x,_ = get_data()

In [9]:
# Load the pickled network parameters again so they can be inspected.
network = init_network()

In [10]:
# Pull out the three weight matrices so their shapes can be inspected.
W1, W2, W3 = (network[key] for key in ('W1', 'W2', 'W3'))

In [11]:
# Shape of the whole test-image array: (number of images, values per flattened image).
x.shape

(10000, 784)

In [12]:
# Shape of a single flattened image.
x[0].shape

(784,)

In [13]:
# First-layer weights: the row count must match the input length used in np.dot(x, W1).
W1.shape

(784, 50)

In [14]:
# Second-layer weights: the row count must match the column count of W1.
W2.shape

(50, 100)

In [15]:
# Output-layer weights: the column count is the number of scores the network produces.
W3.shape

(100, 10)

#### 基于批处理的代码实现

In [16]:
# Reload the test images with their labels, plus the network parameters, for the batched run.
x,t = get_data()
network = init_network()

In [17]:
batch_size = 100  # number of samples per batch
accuracy_cnt = 0  # running count of correct predictions

In [18]:
# Evaluate the test set in mini-batches of batch_size samples.
# The counter is reset here so that re-running this cell does not add to a
# previous run's total and report an accuracy above 1.
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]  # slice one batch out of the input data
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)  # per-row index of the largest output
    accuracy_cnt += np.sum(p == t[i:i+batch_size])  # matches within this batch
print("Accuracy:"+ str(float(accuracy_cnt)/len(x)))

Accuracy:0.9352


#### 解释

In [19]:
# range(start, stop): consecutive integers from 0 up to, but not including, 10.
list(range(0,10))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [20]:
# A third argument sets the step; the batched loop uses range(0, len(x), batch_size)
# the same way to obtain the start index of every batch.
list(range(0,10,3))

[0, 3, 6, 9]

In [21]:
# Toy example: four rows of three scores each.
# NOTE: this rebinds x, so the test images must be reloaded with get_data()
# before the evaluation cells are run again.
x = np.array([[0.1, 0.8, 0.1],
              [0.3, 0.1, 0.6],
              [0.2, 0.5, 0.3],
              [0.8, 0.1, 0.1]])
y = x.argmax(axis=1)  # index of the largest element in each row
print(y)

[1 2 1 0]


In [22]:
# Compare the batch predictions with the correct answers, element by element.
y = np.array([1, 2, 1, 0])
t = np.array([1, 2, 0, 0])
print(np.equal(y, t))

[ True  True False  True]


In [23]:
# Summing the boolean comparison counts its True entries, i.e. the number of matches.
np.sum(y==t)

3