#### NN class:

In [5]:
import numpy as np
import pickle

class NN(object):

    def __init__(self, hidden_dims=(1024, 2048), n_hidden=2, mode='train', datapath=None, model_path=None):

        assert len(hidden_dims) == n_hidden, "Hidden dims mismatch!"

        self.hidden_dims = hidden_dims
        self.n_hidden = n_hidden
        self.mode = mode
        self.datapath = datapath
        self.model_path = model_path
        #self.tr, self.va, self.te = np.load(open(datapath, "rb"))
        u = pickle._Unpickler(open(datapath, 'rb'))
        u.encoding = 'latin1'
        self.tr, self.va, self.te = u.load()
        
        self.epsilon = 1e-6
        self.lr = 1e-1 # learning rate
        self.n_epochs = 1000
        self.batch_size = 1000

    def initialize_weights(self, dims):
        """
        :param dims: the size of input/output layers
        :return: None
        """
        if self.mode == "train":
            self.weights = {}
            all_dims = [dims[0]] + list(self.hidden_dims) + [dims[1]]
            print(all_dims)
            for layer_n in range(1, self.n_hidden + 2):
                self.weights[f"W{layer_n}"] = np.random.rand(all_dims[layer_n - 1], all_dims[layer_n]) / 50
                self.weights[f"b{layer_n}"] = np.zeros((1, all_dims[layer_n]))  # np.random.rand(1, all_dims[layer_n])
        elif self.mode == "test":
            pass
        else:
            raise Exception("Unknown Mode!")

    def activation(self, input, prime=False): # Prime for Heavyside, else ReLu
        if prime:
            return input > 0
        return np.maximum(0, input)

    def loss(self, prediction, labels):  #
        # TODO
        prediction[np.where(prediction < self.epsilon)] = self.epsilon
        prediction[np.where(prediction > 1 - self.epsilon)] = 1 - self.epsilon
        return - np.sum(labels * np.log(prediction)) # / prediction.shape[0]

    def softmax(self, input):  # Computes the stable softmax of the input
        Z = np.exp(input - np.max(input)) # softmax(x+C) = softmax(x)
        return Z / np.sum(Z, axis=1, keepdims=True)
    
    def forward(self, input):  #
        cache = {"H0": input}
        for layer in range(1, self.n_hidden + 1):
            cache[f"A{layer}"] = cache[f"H{layer-1}"] @ self.weights[f"W{layer}"] + self.weights[f"b{layer}"]
            cache[f"H{layer}"] = self.activation(cache[f"A{layer}"])

        layer = self.n_hidden + 1
        cache[f"A{layer}"] = cache[f"H{layer-1}"] @ self.weights[f"W{layer}"] + self.weights[f"b{layer}"]
        cache[f"H{layer}"] = self.softmax(cache[f"A{layer}"]) # softmax on last layer
        return cache

    def backward(self, cache, labels):  #
        # TODO
        output = cache[f"H{self.n_hidden+1}"]
        grads = {
            f"dA{self.n_hidden+1}": - (labels - output),
        }
        for layer in range(self.n_hidden + 1, 0, -1):
            # print(f"Shape dA=", grads[f"dA{layer}"].shape)
            # print(f"Shape H=", cache[f"H{layer-1}"].shape)

            grads[f"dW{layer}"] = cache[f"H{layer-1}"].T @ grads[f"dA{layer}"]
            grads[f"db{layer}"] = grads[f"dA{layer}"]

            if layer > 1:
                grads[f"dH{layer-1}"] = grads[f"dA{layer}"] @ self.weights[f"W{layer}"].T
                grads[f"dA{layer-1}"] = grads[f"dH{layer-1}"] * self.activation(cache[f"A{layer-1}"], prime=True)
                # print(f"Shape dA=", grads[f"dA{layer-1}"].shape)
        return grads

    def update(self, grads):  #
        # rint(grads.keys())
        for layer in range(1, self.n_hidden + 1):
            # print(grads[f"dW{layer}"].shape,self.weights[f"W{layer}"].shape)
            self.weights[f"W{layer}"] = self.weights[f"W{layer}"] - self.lr * grads[f"dW{layer}"] / self.batch_size

    def train(self):
        X_train, y_train = self.tr
        y_onehot = np.eye(np.max(y_train) - np.min(y_train) + 1)[y_train]
        # print(y_train.shape,y_onehot.shape)
        dims = [X_train.shape[1], y_onehot.shape[1]]
        self.initialize_weights(dims)

        n_batches = int(np.ceil(X_train.shape[0] / self.batch_size))

        for epoch in range(self.n_epochs):
            predictedY = np.zeros_like(y_train)
            trainLoss = 0
            for batch in range(n_batches):
                minibatchX = X_train[self.batch_size * batch:self.batch_size * (batch + 1), :]
                minibatchY = y_onehot[self.batch_size * batch:self.batch_size * (batch + 1), :]
                cache = self.forward(minibatchX)
                grads = self.backward(cache, minibatchY)
                self.update(grads)

                trainLoss += self.loss(cache[f"H{self.n_hidden+1}"], minibatchY)
                predictedY[self.batch_size * batch:self.batch_size * (batch + 1)] = np.argmax(
                    cache[f"H{self.n_hidden + 1}"], axis=1)

            X_val, y_val = self.va
            onVal_y = np.eye(np.max(y_train) - np.min(y_train) + 1)[y_val]
            valCache = self.forward(X_val)

            predicted_valY = np.argmax(valCache[f"H{self.n_hidden + 1}"], axis=1)
            valAccuracy = np.mean(y_val == predicted_valY)
            valLoss = self.loss(valCache[f"H{self.n_hidden+1}"], onVal_y)

            trAccuracy = np.mean(y_train == predictedY)

            print(f"Epoch= {epoch}, Loss={trainLoss:10.2f}, Accuracy={trAccuracy:4.2f}, Val.Loss={valLoss:10.2f}, Val.Accuracy= {valAccuracy:4.2f}")
            # break

    def test(self):
        pass

#### Test of the NN class with MNIST:

In [6]:
neural_net = NN(datapath="mnist.pkl", hidden_dims=(500, 400))
neural_net.train()

[784, 500, 400, 10]
Epoch= 0, Loss= 115689.38, Accuracy=0.09, Val.Loss=  22841.30, Val.Accuracy= 0.10
Epoch= 1, Loss= 114254.59, Accuracy=0.10, Val.Loss=  22806.54, Val.Accuracy= 0.10
Epoch= 2, Loss= 114047.58, Accuracy=0.10, Val.Loss=  22755.90, Val.Accuracy= 0.10
Epoch= 3, Loss= 113740.02, Accuracy=0.10, Val.Loss=  22677.23, Val.Accuracy= 0.10
Epoch= 4, Loss= 113247.81, Accuracy=0.10, Val.Loss=  22549.27, Val.Accuracy= 0.10
Epoch= 5, Loss= 112434.06, Accuracy=0.10, Val.Loss=  22336.68, Val.Accuracy= 0.10
Epoch= 6, Loss= 111077.47, Accuracy=0.11, Val.Loss=  21984.01, Val.Accuracy= 0.14
Epoch= 7, Loss= 108851.44, Accuracy=0.19, Val.Loss=  21415.45, Val.Accuracy= 0.28
Epoch= 8, Loss= 105373.89, Accuracy=0.33, Val.Loss=  20559.18, Val.Accuracy= 0.44
Epoch= 9, Loss= 100409.24, Accuracy=0.49, Val.Loss=  19402.82, Val.Accuracy= 0.57
Epoch= 10, Loss=  94135.70, Accuracy=0.58, Val.Loss=  18032.60, Val.Accuracy= 0.62
Epoch= 11, Loss=  87161.73, Accuracy=0.62, Val.Loss=  16595.42, Val.Accuracy=

Epoch= 99, Loss=  17355.88, Accuracy=0.90, Val.Loss=   3197.50, Val.Accuracy= 0.91
Epoch= 100, Loss=  17307.55, Accuracy=0.90, Val.Loss=   3189.06, Val.Accuracy= 0.91
Epoch= 101, Loss=  17260.06, Accuracy=0.90, Val.Loss=   3180.76, Val.Accuracy= 0.91
Epoch= 102, Loss=  17213.35, Accuracy=0.90, Val.Loss=   3172.61, Val.Accuracy= 0.91
Epoch= 103, Loss=  17167.40, Accuracy=0.90, Val.Loss=   3164.59, Val.Accuracy= 0.91
Epoch= 104, Loss=  17122.19, Accuracy=0.90, Val.Loss=   3156.70, Val.Accuracy= 0.91
Epoch= 105, Loss=  17077.73, Accuracy=0.90, Val.Loss=   3148.95, Val.Accuracy= 0.91
Epoch= 106, Loss=  17033.97, Accuracy=0.90, Val.Loss=   3141.32, Val.Accuracy= 0.91
Epoch= 107, Loss=  16990.88, Accuracy=0.90, Val.Loss=   3133.81, Val.Accuracy= 0.91
Epoch= 108, Loss=  16948.41, Accuracy=0.90, Val.Loss=   3126.42, Val.Accuracy= 0.91
Epoch= 109, Loss=  16906.58, Accuracy=0.90, Val.Loss=   3119.13, Val.Accuracy= 0.91
Epoch= 110, Loss=  16865.36, Accuracy=0.90, Val.Loss=   3111.96, Val.Accuracy

Epoch= 197, Loss=  14528.83, Accuracy=0.92, Val.Loss=   2717.28, Val.Accuracy= 0.92
Epoch= 198, Loss=  14509.19, Accuracy=0.92, Val.Loss=   2714.07, Val.Accuracy= 0.92
Epoch= 199, Loss=  14489.64, Accuracy=0.92, Val.Loss=   2710.88, Val.Accuracy= 0.92
Epoch= 200, Loss=  14470.15, Accuracy=0.92, Val.Loss=   2707.69, Val.Accuracy= 0.92
Epoch= 201, Loss=  14450.73, Accuracy=0.92, Val.Loss=   2704.52, Val.Accuracy= 0.92
Epoch= 202, Loss=  14431.38, Accuracy=0.92, Val.Loss=   2701.36, Val.Accuracy= 0.92
Epoch= 203, Loss=  14412.10, Accuracy=0.92, Val.Loss=   2698.22, Val.Accuracy= 0.92
Epoch= 204, Loss=  14392.86, Accuracy=0.92, Val.Loss=   2695.08, Val.Accuracy= 0.92
Epoch= 205, Loss=  14373.68, Accuracy=0.92, Val.Loss=   2691.96, Val.Accuracy= 0.92
Epoch= 206, Loss=  14354.56, Accuracy=0.92, Val.Loss=   2688.85, Val.Accuracy= 0.92
Epoch= 207, Loss=  14335.50, Accuracy=0.92, Val.Loss=   2685.76, Val.Accuracy= 0.92
Epoch= 208, Loss=  14316.48, Accuracy=0.92, Val.Loss=   2682.67, Val.Accurac

Epoch= 295, Loss=  12766.18, Accuracy=0.93, Val.Loss=   2431.23, Val.Accuracy= 0.93
Epoch= 296, Loss=  12749.03, Accuracy=0.93, Val.Loss=   2428.41, Val.Accuracy= 0.93
Epoch= 297, Loss=  12731.93, Accuracy=0.93, Val.Loss=   2425.60, Val.Accuracy= 0.93
Epoch= 298, Loss=  12714.86, Accuracy=0.93, Val.Loss=   2422.80, Val.Accuracy= 0.93
Epoch= 299, Loss=  12697.81, Accuracy=0.93, Val.Loss=   2420.01, Val.Accuracy= 0.93
Epoch= 300, Loss=  12680.77, Accuracy=0.93, Val.Loss=   2417.21, Val.Accuracy= 0.93
Epoch= 301, Loss=  12663.74, Accuracy=0.93, Val.Loss=   2414.41, Val.Accuracy= 0.93
Epoch= 302, Loss=  12646.71, Accuracy=0.93, Val.Loss=   2411.60, Val.Accuracy= 0.93
Epoch= 303, Loss=  12629.71, Accuracy=0.93, Val.Loss=   2408.80, Val.Accuracy= 0.93
Epoch= 304, Loss=  12612.73, Accuracy=0.93, Val.Loss=   2406.01, Val.Accuracy= 0.93
Epoch= 305, Loss=  12595.77, Accuracy=0.93, Val.Loss=   2403.22, Val.Accuracy= 0.93
Epoch= 306, Loss=  12578.83, Accuracy=0.93, Val.Loss=   2400.42, Val.Accurac

Epoch= 393, Loss=  11186.18, Accuracy=0.94, Val.Loss=   2172.58, Val.Accuracy= 0.94
Epoch= 394, Loss=  11171.36, Accuracy=0.94, Val.Loss=   2170.16, Val.Accuracy= 0.94
Epoch= 395, Loss=  11156.57, Accuracy=0.94, Val.Loss=   2167.74, Val.Accuracy= 0.94
Epoch= 396, Loss=  11141.80, Accuracy=0.94, Val.Loss=   2165.33, Val.Accuracy= 0.94
Epoch= 397, Loss=  11127.05, Accuracy=0.94, Val.Loss=   2162.93, Val.Accuracy= 0.94
Epoch= 398, Loss=  11112.32, Accuracy=0.94, Val.Loss=   2160.54, Val.Accuracy= 0.94
Epoch= 399, Loss=  11097.61, Accuracy=0.94, Val.Loss=   2158.14, Val.Accuracy= 0.94
Epoch= 400, Loss=  11082.92, Accuracy=0.94, Val.Loss=   2155.75, Val.Accuracy= 0.94
Epoch= 401, Loss=  11068.25, Accuracy=0.94, Val.Loss=   2153.37, Val.Accuracy= 0.94
Epoch= 402, Loss=  11053.60, Accuracy=0.94, Val.Loss=   2150.98, Val.Accuracy= 0.94
Epoch= 403, Loss=  11038.97, Accuracy=0.94, Val.Loss=   2148.61, Val.Accuracy= 0.94
Epoch= 404, Loss=  11024.38, Accuracy=0.94, Val.Loss=   2146.23, Val.Accurac

Epoch= 491, Loss=   9842.50, Accuracy=0.95, Val.Loss=   1953.85, Val.Accuracy= 0.95
Epoch= 492, Loss=   9829.84, Accuracy=0.95, Val.Loss=   1951.80, Val.Accuracy= 0.95
Epoch= 493, Loss=   9817.20, Accuracy=0.95, Val.Loss=   1949.76, Val.Accuracy= 0.95
Epoch= 494, Loss=   9804.57, Accuracy=0.95, Val.Loss=   1947.72, Val.Accuracy= 0.95
Epoch= 495, Loss=   9791.96, Accuracy=0.95, Val.Loss=   1945.69, Val.Accuracy= 0.95
Epoch= 496, Loss=   9779.39, Accuracy=0.95, Val.Loss=   1943.66, Val.Accuracy= 0.95
Epoch= 497, Loss=   9766.84, Accuracy=0.95, Val.Loss=   1941.64, Val.Accuracy= 0.95
Epoch= 498, Loss=   9754.31, Accuracy=0.95, Val.Loss=   1939.63, Val.Accuracy= 0.95
Epoch= 499, Loss=   9741.79, Accuracy=0.95, Val.Loss=   1937.61, Val.Accuracy= 0.95
Epoch= 500, Loss=   9729.30, Accuracy=0.95, Val.Loss=   1935.59, Val.Accuracy= 0.95
Epoch= 501, Loss=   9716.82, Accuracy=0.95, Val.Loss=   1933.58, Val.Accuracy= 0.95
Epoch= 502, Loss=   9704.36, Accuracy=0.95, Val.Loss=   1931.56, Val.Accurac

Epoch= 589, Loss=   8680.17, Accuracy=0.95, Val.Loss=   1766.37, Val.Accuracy= 0.95
Epoch= 590, Loss=   8669.09, Accuracy=0.95, Val.Loss=   1764.60, Val.Accuracy= 0.95
Epoch= 591, Loss=   8658.03, Accuracy=0.95, Val.Loss=   1762.82, Val.Accuracy= 0.95
Epoch= 592, Loss=   8646.98, Accuracy=0.95, Val.Loss=   1761.05, Val.Accuracy= 0.95
Epoch= 593, Loss=   8635.94, Accuracy=0.95, Val.Loss=   1759.29, Val.Accuracy= 0.95
Epoch= 594, Loss=   8624.93, Accuracy=0.95, Val.Loss=   1757.52, Val.Accuracy= 0.95
Epoch= 595, Loss=   8613.92, Accuracy=0.95, Val.Loss=   1755.77, Val.Accuracy= 0.95
Epoch= 596, Loss=   8602.94, Accuracy=0.95, Val.Loss=   1754.01, Val.Accuracy= 0.95
Epoch= 597, Loss=   8591.96, Accuracy=0.95, Val.Loss=   1752.27, Val.Accuracy= 0.95
Epoch= 598, Loss=   8581.01, Accuracy=0.95, Val.Loss=   1750.51, Val.Accuracy= 0.95
Epoch= 599, Loss=   8570.07, Accuracy=0.95, Val.Loss=   1748.77, Val.Accuracy= 0.95
Epoch= 600, Loss=   8559.15, Accuracy=0.95, Val.Loss=   1747.03, Val.Accurac

Epoch= 687, Loss=   7665.95, Accuracy=0.96, Val.Loss=   1605.55, Val.Accuracy= 0.96
Epoch= 688, Loss=   7656.36, Accuracy=0.96, Val.Loss=   1604.04, Val.Accuracy= 0.96
Epoch= 689, Loss=   7646.78, Accuracy=0.96, Val.Loss=   1602.54, Val.Accuracy= 0.96
Epoch= 690, Loss=   7637.23, Accuracy=0.96, Val.Loss=   1601.03, Val.Accuracy= 0.96
Epoch= 691, Loss=   7627.68, Accuracy=0.96, Val.Loss=   1599.54, Val.Accuracy= 0.96
Epoch= 692, Loss=   7618.15, Accuracy=0.96, Val.Loss=   1598.04, Val.Accuracy= 0.96
Epoch= 693, Loss=   7608.63, Accuracy=0.96, Val.Loss=   1596.56, Val.Accuracy= 0.96
Epoch= 694, Loss=   7599.14, Accuracy=0.96, Val.Loss=   1595.06, Val.Accuracy= 0.96
Epoch= 695, Loss=   7589.66, Accuracy=0.96, Val.Loss=   1593.58, Val.Accuracy= 0.96
Epoch= 696, Loss=   7580.21, Accuracy=0.96, Val.Loss=   1592.09, Val.Accuracy= 0.96
Epoch= 697, Loss=   7570.77, Accuracy=0.96, Val.Loss=   1590.61, Val.Accuracy= 0.96
Epoch= 698, Loss=   7561.35, Accuracy=0.96, Val.Loss=   1589.13, Val.Accurac

Epoch= 785, Loss=   6800.17, Accuracy=0.96, Val.Loss=   1470.88, Val.Accuracy= 0.96
Epoch= 786, Loss=   6792.07, Accuracy=0.96, Val.Loss=   1469.64, Val.Accuracy= 0.96
Epoch= 787, Loss=   6783.99, Accuracy=0.96, Val.Loss=   1468.40, Val.Accuracy= 0.96
Epoch= 788, Loss=   6775.92, Accuracy=0.96, Val.Loss=   1467.17, Val.Accuracy= 0.96
Epoch= 789, Loss=   6767.85, Accuracy=0.96, Val.Loss=   1465.94, Val.Accuracy= 0.96
Epoch= 790, Loss=   6759.82, Accuracy=0.96, Val.Loss=   1464.71, Val.Accuracy= 0.96
Epoch= 791, Loss=   6751.79, Accuracy=0.96, Val.Loss=   1463.47, Val.Accuracy= 0.96
Epoch= 792, Loss=   6743.77, Accuracy=0.96, Val.Loss=   1462.24, Val.Accuracy= 0.96
Epoch= 793, Loss=   6735.77, Accuracy=0.96, Val.Loss=   1461.02, Val.Accuracy= 0.96
Epoch= 794, Loss=   6727.78, Accuracy=0.96, Val.Loss=   1459.79, Val.Accuracy= 0.96
Epoch= 795, Loss=   6719.80, Accuracy=0.96, Val.Loss=   1458.57, Val.Accuracy= 0.96
Epoch= 796, Loss=   6711.83, Accuracy=0.96, Val.Loss=   1457.35, Val.Accurac

Epoch= 883, Loss=   6068.95, Accuracy=0.97, Val.Loss=   1359.85, Val.Accuracy= 0.96
Epoch= 884, Loss=   6062.12, Accuracy=0.97, Val.Loss=   1358.83, Val.Accuracy= 0.96
Epoch= 885, Loss=   6055.30, Accuracy=0.97, Val.Loss=   1357.81, Val.Accuracy= 0.96
Epoch= 886, Loss=   6048.50, Accuracy=0.97, Val.Loss=   1356.78, Val.Accuracy= 0.96
Epoch= 887, Loss=   6041.72, Accuracy=0.97, Val.Loss=   1355.76, Val.Accuracy= 0.96
Epoch= 888, Loss=   6034.94, Accuracy=0.97, Val.Loss=   1354.74, Val.Accuracy= 0.96
Epoch= 889, Loss=   6028.18, Accuracy=0.97, Val.Loss=   1353.72, Val.Accuracy= 0.96
Epoch= 890, Loss=   6021.41, Accuracy=0.97, Val.Loss=   1352.71, Val.Accuracy= 0.96
Epoch= 891, Loss=   6014.67, Accuracy=0.97, Val.Loss=   1351.70, Val.Accuracy= 0.96
Epoch= 892, Loss=   6007.93, Accuracy=0.97, Val.Loss=   1350.68, Val.Accuracy= 0.96
Epoch= 893, Loss=   6001.22, Accuracy=0.97, Val.Loss=   1349.68, Val.Accuracy= 0.96
Epoch= 894, Loss=   5994.50, Accuracy=0.97, Val.Loss=   1348.68, Val.Accurac

Epoch= 981, Loss=   5453.17, Accuracy=0.97, Val.Loss=   1268.10, Val.Accuracy= 0.97
Epoch= 982, Loss=   5447.39, Accuracy=0.97, Val.Loss=   1267.26, Val.Accuracy= 0.97
Epoch= 983, Loss=   5441.64, Accuracy=0.97, Val.Loss=   1266.42, Val.Accuracy= 0.97
Epoch= 984, Loss=   5435.89, Accuracy=0.97, Val.Loss=   1265.58, Val.Accuracy= 0.97
Epoch= 985, Loss=   5430.16, Accuracy=0.97, Val.Loss=   1264.74, Val.Accuracy= 0.97
Epoch= 986, Loss=   5424.44, Accuracy=0.97, Val.Loss=   1263.90, Val.Accuracy= 0.97
Epoch= 987, Loss=   5418.72, Accuracy=0.97, Val.Loss=   1263.06, Val.Accuracy= 0.97
Epoch= 988, Loss=   5413.02, Accuracy=0.97, Val.Loss=   1262.23, Val.Accuracy= 0.97
Epoch= 989, Loss=   5407.32, Accuracy=0.97, Val.Loss=   1261.40, Val.Accuracy= 0.97
Epoch= 990, Loss=   5401.64, Accuracy=0.97, Val.Loss=   1260.57, Val.Accuracy= 0.97
Epoch= 991, Loss=   5395.96, Accuracy=0.97, Val.Loss=   1259.74, Val.Accuracy= 0.97
Epoch= 992, Loss=   5390.31, Accuracy=0.97, Val.Loss=   1258.91, Val.Accurac