In [6]:
import numpy as np
from utils import *

# 交差エントロピー誤差について補足

## 交差エントロピー誤差
\begin{equation*}
L = -\sum_{i=1}^{N} \left( y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right)
\end{equation*}

分類問題では、予測値 $\hat{y}_i$ は各クラスに属する確率を表すので
\begin{equation*}
0 \leq \hat{y}_i \leq 1
\end{equation*}
の範囲の値をとる。一方 $y_i$ は教師データ（正解ラベル）である。

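教師データ $y_i$ はワンホット表現なので、正解クラスの成分だけが 1、それ以外は 0 となる。
さらに出力層にソフトマックス関数を使う場合は $\sum_i \hat{y}_i = 1$ なので、正解クラスの $\hat{y}_i$ を大きくすれば他のクラスの値は自動的に小さくなる。よって第2項 $(1 - y_i)\log(1 - \hat{y}_i)$ を省いた

\begin{equation*}
L = -\sum_{i=1}^{N} y_i \log(\hat{y}_i)
\end{equation*}

としても損失関数としての役目を果たせる。（$y_i = 0$ の項は消えるので、実際には正解クラス $k$ について $L = -\log \hat{y}_k$ を計算していることになる。）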

In [7]:
def init_network():
    """Build the hand-set parameters of a tiny 2-3-2-2 feed-forward network.

    NOTE: a later cell in this notebook defines another ``init_network``
    (loading pickled MNIST weights), which shadows this one once it has run.

    Returns:
        dict with weight matrices under 'W1', 'W2', 'W3' and bias vectors
        under 'b1', 'b2', 'b3', each a float numpy array.
    """
    raw_params = {
        'W1': [[0.1, 0.3, 0.5], [0.2, 0.4, 0.6]],   # input(2)  -> hidden(3)
        'b1': [0.1, 0.2, 0.3],
        'W2': [[0.1, 0.4], [0.2, 0.5], [0.3, 0.6]],  # hidden(3) -> hidden(2)
        'b2': [0.1, 0.2],
        'W3': [[0.1, 0.3], [0.2, 0.4]],              # hidden(2) -> output(2)
        'b3': [0.1, 0.2],
    }
    return {name: np.array(values) for name, values in raw_params.items()}
    
def forward(network, x):
    """Forward pass of the 3-layer network: two sigmoid layers, linear output.

    Args:
        network: dict holding weights 'W1'..'W3' and biases 'b1'..'b3'.
        x: input vector (or batch of row vectors) whose length matches the
           first dimension of 'W1'.

    Returns:
        The output-layer pre-activation (identity activation, no softmax).
    """
    W1, W2, W3 = (network[key] for key in ('W1', 'W2', 'W3'))
    b1, b2, b3 = (network[key] for key in ('b1', 'b2', 'b3'))

    z1 = sigmoid(np.dot(x, W1) + b1)   # hidden layer 1
    z2 = sigmoid(np.dot(z1, W2) + b2)  # hidden layer 2
    return np.dot(z2, W3) + b3         # output layer: raw scores


In [8]:

# Push a single 2-element sample through the toy network and print the raw output.
x = np.array([1.0, 0.5])
network = init_network()
y = forward(network, x)
print(y)

[0.20993715 0.46282556]


In [9]:
from dataset.mnist import load_mnist
import pickle

def get_data():
    """Return the MNIST test split as ``(images, labels)``.

    The data is requested with ``normalize=True`` and ``flatten=True``, and
    labels as class indices (``one_hot_label=False``). The training split is
    loaded but discarded.
    """
    (_x_train, _t_train), (x_test, t_test) = load_mnist(
        flatten=True, normalize=True, one_hot_label=False
    )
    return x_test, t_test

def init_network(path="./sample_weight.pkl"):
    """Load pretrained network parameters from a pickle file.

    Args:
        path: location of the pickled parameters. Defaults to
            "./sample_weight.pkl", the location previously hard-coded here.

    Returns:
        The unpickled object; ``predict`` indexes it with the keys
        'W1'..'W3' and 'b1'..'b3'.
    """
    # Pickle is a binary format: the file must be opened in 'rb' mode.
    # Text mode makes pickle.load fail under Python 3.
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path, "rb") as f:
        network = pickle.load(f)
    return network

def predict(network, x):
    """Score ``x`` with the 3-layer network and apply ``softmax`` to the output.

    Two sigmoid hidden layers followed by an affine output layer; the output
    scores are passed through ``softmax`` (presumably yielding class
    probabilities -- its implementation lives in ``utils``).

    Args:
        network: dict holding weights 'W1'..'W3' and biases 'b1'..'b3'.
        x: a single input sample (the caller below passes one image at a time).

    Returns:
        ``softmax`` of the output-layer scores.
    """
    W1, W2, W3 = (network[key] for key in ('W1', 'W2', 'W3'))
    b1, b2, b3 = (network[key] for key in ('b1', 'b2', 'b3'))

    h1 = sigmoid(np.dot(x, W1) + b1)
    h2 = sigmoid(np.dot(h1, W2) + b2)
    scores = np.dot(h2, W3) + b3
    return softmax(scores)
 

In [None]:
# Evaluate the pretrained network on the MNIST test split, one image at a time.
x, t = get_data()
network = init_network()
# Images and labels are indexed in parallel below, so they must line up.
assert len(x) == len(t), "image/label count mismatch"

accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p = np.argmax(y)  # index of the highest-scoring class
    if p == t[i]:
        accuracy_cnt += 1

# Convert to float *before* dividing: float(a / b) converts only after the
# division, which truncates the ratio to 0/1 under integer-division semantics.
accuracy = float(accuracy_cnt) / len(x)
print("Accuracy: ", str(accuracy))