In [1]:
from linear import *
from convolutional import *
from activations import *
from loss import CategoricalCrossEntropyLoss
import mnist

np.random.seed(0)

In [2]:
class MnistConv:
    """Small convolutional classifier assembled from the project's layer classes.

    The network is a plain list of layers applied in order:
    convolution (1 -> 32 channels, kernel size 5, ReLU) -> 2x2 max-pool ->
    flatten -> dense (sigmoid, 64 units) -> dense (10 outputs, no activation).

    Attributes:
        layers: ordered list of layer objects; each is callable for the forward
            pass and exposes ``backward(err, lr)``.
        lr: learning rate handed to every layer's ``backward``.
        res: output of the most recent ``forward`` call, or ``None`` if
            ``forward`` has not been called yet.
    """

    def __init__(self, lr = 0.01) -> None:
        """Build the layer stack.

        Args:
            lr: learning rate passed to each layer during ``backward``.
        """
        self.layers = [
            # NOTE(review): the keyword is spelled 'activaton'; presumably this
            # matches ConvLayer's own parameter name -- verify before renaming.
            ConvLayer(1, 32, 5, activaton=ReLu()),
            MaxPool(2),
            Flatten(),
            # 32 * 12**2 inputs: assumes a 28x28 image becomes 24x24 after the
            # 5x5 convolution and 12x12 after pooling -- TODO confirm against
            # ConvLayer/MaxPool (padding and stride are not visible here).
            Linear(32*12**2, 64, Sigmoid()),
            Linear(64, 10, None)
        ]
        self.lr = lr
        self.res = None

    def forward(self, x):
        """Run ``x`` through every layer in order and cache the final output.

        Args:
            x: input passed to the first layer.

        Returns:
            The output of the last layer (also stored in ``self.res``).
        """
        for layer in self.layers:
            x = layer(x)
        self.res = x
        return x

    def backward(self, expected_output):
        """Compute the loss for the last forward pass and update all layers.

        The loss derivative is propagated through the layers in reverse order;
        each layer's ``backward`` receives the incoming error and the learning
        rate and returns the error for the layer before it.

        Args:
            expected_output: target passed to ``CategoricalCrossEntropyLoss``
                together with the cached network output.

        Returns:
            The loss value computed for the cached output.

        Raises:
            RuntimeError: if called before any ``forward`` pass, because there
                is no cached output to compare against the target.
        """
        if self.res is None:
            raise RuntimeError(
                "backward() called before forward(): no cached output to compute the loss from"
            )
        loss = CategoricalCrossEntropyLoss()
        loss_val = loss(self.res, expected_output)
        err = loss.derivatives()
        for layer in reversed(self.layers):
            err = layer.backward(err, self.lr)
        return loss_val

In [3]:
# Load MNIST: the loader returns four values, unpacked here as training
# images/labels followed by test images/labels.
x_train, t_train, x_test, t_test = mnist.load()
# Build the network with its default learning rate (lr=0.01).
model = MnistConv()

In [4]:
EPOCHES = 3
TRAIN_SAMPLES = 1000  # number of training images used in each epoch
REPORT_EVERY = 100    # number of samples averaged into each loss report

for epoch in range(EPOCHES):
    print(f"epoch {epoch}/{EPOCHES}")
    loss_val = 0
    for i, (img, label) in enumerate(zip(x_train[:TRAIN_SAMPLES], t_train[:TRAIN_SAMPLES])):
        # One sample per step: reshape to (batch=1, channels=1, 28, 28) and
        # scale pixel values by 1/255 before the forward pass.
        model.forward(img.reshape((1, 1, 28, 28))/255.0)
        loss_val += model.backward(label)

        # Report after every full window of REPORT_EVERY samples. Counting
        # samples seen (i + 1) keeps each window exactly REPORT_EVERY long, so
        # the divisor matches the number of losses summed and the final window
        # of the epoch is reported too.
        seen = i + 1
        if seen % REPORT_EVERY == 0:
            print(f"epoch: {epoch}.{seen//REPORT_EVERY}, loss: {loss_val/REPORT_EVERY}")
            loss_val = 0




epoch 0/3


KeyboardInterrupt: 

In [None]:
# Per-digit tallies over the test set:
#   row 0 -> number of samples whose true label is that digit
#   row 1 -> number of those samples that were classified correctly
#   row 2 -> number of times that digit was predicted
results = np.zeros((3, 10))
for (img, label) in zip(x_test, t_test):
    results[0, label] += 1
    r = model.forward(img.reshape((1, 1, 28, 28))/255.0)
    # Index of the highest output along axis 1 (one entry for the single sample).
    pred = np.argmax(r, axis = 1)
    if pred == label:
        results[1, label] += 1
    results[2, pred] += 1

# Print correct counts, then true-label counts, then prediction counts.
for row in (1, 0, 2):
    print(results[row])

[ 955. 1110.  910.  949.  831.  787.  892.  909.  842.  950.]
[ 980. 1135. 1032. 1010.  982.  892.  958. 1028.  974. 1009.]
[1025. 1140.  975. 1100.  885.  849.  960.  966.  904. 1196.]


In [None]:
# Overall accuracy: total correct predictions divided by total test samples.
results[1].sum()/results[0].sum()

0.9135

In [None]:
# Per-digit accuracy: correct predictions for each true digit divided by the
# number of test samples with that true digit.
np.divide(results[1], results[0])

array([0.9744898 , 0.97797357, 0.88178295, 0.93960396, 0.84623218,
       0.882287  , 0.93110647, 0.88424125, 0.86447639, 0.94152626])