# Neural Network from scratch in Pure Python

In [1]:
import random
import json
import time

In [2]:
class Neuron:
    def __init__(self, weights: list[float], bias: float, inputs: list[float]):
        self.weights = weights
        self.bias = bias
        self.inputs = inputs
        self.output = 0
    
    def forward(self):
        self.output = sum([self.weights[i] * self.inputs[i] for i in range(len(self.weights))]) + self.bias
        return self.output


class NeuralNetwork:
    def __init__(self, input_size: int = 784, hidden_layers = [512, 512], output_size: int = 10):
        self.input_size = input_size
        self.hidden_layers = hidden_layers
        self.output_size = output_size
        self.weights = []
        self.biases = []

        # Input to hidden Layers Network
        
        # self.weights.append(0.01 * np.random.randn(input_size, hidden_layers[0]))
        self.weights.append([[0.01 * random.gauss(0, 1) for _ in range(hidden_layers[0])] for _ in range(input_size)])


        # self.biases.append(np.zeros(output_size))
        self.biases.append([0 for _ in range(output_size)])

    
    def forward(self, inputs: list[float]):
        layers = [inputs]

        for i in range(len(self.weights)):
            # Dot product

            # layers.append(np.dot(layers[-1], self.weights[i]) + self.biases[i])
            layers.append([sum(layers[-1][j] * self.weights[i][j][k] for j in range(len(layers[-1]))) + self.biases[i][k] for k in range(len(self.biases[i]))])
        
        return layers[-1]

    def _from_pyt_state_dict(self, json_path: str):
        with open(json_path, 'r') as f:
            data = json.load(f)
        
        self.weights = []
        self.biases = []

        # we need to rotate the weights
        # a pytorch layedr Linear(a,b) has len(data[key]) = b and len(data[key][0]) = a
        # but in our case, we need to have len(data[key]) = a and len(data[key][0]) = b
        
        # json is like this: {"model.0.weight": [[], [], ...], "model.0.bias": [], "model.1.weight": [], "model.1.bias": [], ...}
        for key in data:
            if 'weight' in key:
                d = data[key]
                # rotate the matrix
                # d = np.array(d).T.tolist()
                d = list(map(list, zip(*d)))
                self.weights.append(d)
            elif 'bias' in key:
                d = data[key]
                self.biases.append(d)
            else:
                raise ValueError('Invalid key')
        
        # reset the input size, hidden layers, and output size to the ones in the model
        self.input_size = len(self.weights[0])
        self.hidden_layers = [len(self.weights[i]) for i in range(1, len(self.weights)-1)]
        self.output_size = len(self.weights[-1][0])
        
        return self
    
    def __str__(self):
        return f'NeuralNetwork(input_size={self.input_size}, hidden_layers={self.hidden_layers}, output_size={self.output_size})'
    
    def summary(self):
        n_total_params = 0
        for i in range(len(self.weights)):
            n_total_params += len(self.weights[i]) + len(self.biases[i])
        
        print(f'Neural Network Summary\n{"-"*20}\n')
        print(f'Input Size: {self.input_size}')
        print("Layer Shapes:")
        for i in range(len(self.weights)):
            print(f'Layer {i+1}: {len(self.weights[i])}x{len(self.weights[i][0])}')


In [8]:
# ./model.json is a json file containing the state_dict of the model (pretrained)
model = NeuralNetwork()._from_pyt_state_dict('./model.json')
print(model)
model.summary()

sample_X = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3686274588108063, 0.4274509847164154, 0.4274509847164154, 0.4313725531101227, 0.5882353186607361, 0.9921568632125854, 0.9921568632125854, 0.4313725531101227, 0.05882352963089943, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04313725605607033, 0.6078431606292725, 0.8509804010391235, 0.9686274528503418, 0.9882352948188782, 0.9882352948188782, 0.9921568632125854, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.9921568632125854, 0.6235294342041016, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04313725605607033, 0.5882353186607361, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.9647058844566345, 0.843137264251709, 0.6039215922355652, 0.27843138575553894, 0.7058823704719543, 0.9882352948188782, 0.9921568632125854, 0.7019608020782471, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2862745225429535, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.501960813999176, 0.3607843220233917, 0.0, 0.0, 0.0, 0.14509804546833038, 0.9882352948188782, 0.9921568632125854, 0.7019608020782471, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2862745225429535, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.13725490868091583, 0.0, 0.0, 0.0, 0.0, 0.7098039388656616, 0.9882352948188782, 0.8078431487083435, 0.0784313753247261, 0.0, 0.24705882370471954, 0.4313725531101227, 0.05882352963089943, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03921568766236305, 0.7450980544090271, 0.9882352948188782, 0.9882352948188782, 0.13725490868091583, 0.0, 0.0, 0.0, 0.24705882370471954, 0.9490196108818054, 0.9882352948188782, 0.0784313753247261, 0.24705882370471954, 0.3686274588108063, 0.929411768913269, 0.9921568632125854, 0.8705882430076599, 0.24313725531101227, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6274510025978088, 0.9882352948188782, 0.9882352948188782, 0.5411764979362488, 0.0, 0.0, 0.062745101749897, 0.7098039388656616, 0.9882352948188782, 0.9882352948188782, 0.7137255072593689, 0.9490196108818054, 0.9882352948188782, 0.9882352948188782, 0.9921568632125854, 0.7019608020782471, 0.0784313753247261, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.062745101749897, 0.7450980544090271, 0.9882352948188782, 0.8666666746139526, 0.16078431904315948, 0.0, 0.4274509847164154, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.9921568632125854, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.7490196228027344, 0.05882352963089943, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5686274766921997, 0.9921568632125854, 0.9921568632125854, 0.9098039269447327, 0.4274509847164154, 0.6705882549285889, 0.9921568632125854, 0.9921568632125854, 0.9921568632125854, 1.0, 0.9921568632125854, 0.686274528503418, 0.24313725531101227, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0784313753247261, 0.501960813999176, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.9921568632125854, 0.9882352948188782, 0.9450980424880981, 0.7019608020782471, 0.7058823704719543, 0.21568627655506134, 0.03921568766236305, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003921568859368563, 0.5058823823928833, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.929411768913269, 0.35686275362968445, 0.239215686917305, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294117748737335, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.5647059082984924, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5686274766921997, 0.9921568632125854, 0.9921568632125854, 0.9921568632125854, 0.9921568632125854, 0.5686274766921997, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04313725605607033, 0.686274528503418, 0.9882352948188782, 0.7411764860153198, 0.9058823585510254, 0.9882352948188782, 0.9333333373069763, 0.11764705926179886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22745098173618317, 0.9882352948188782, 0.9882352948188782, 0.13725490868091583, 0.7098039388656616, 0.9882352948188782, 0.9921568632125854, 0.5411764979362488, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7098039388656616, 0.9882352948188782, 0.7411764860153198, 0.05882352963089943, 0.7098039388656616, 0.9882352948188782, 0.9921568632125854, 0.3803921639919281, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16470588743686676, 0.8705882430076599, 0.9921568632125854, 0.5647059082984924, 0.0, 0.7137255072593689, 0.9921568632125854, 0.5686274766921997, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.545098066329956, 0.9882352948188782, 0.929411768913269, 0.5254902243614197, 0.9490196108818054, 0.9882352948188782, 0.5647059082984924, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125490203499794, 0.9254902005195618, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.4901960790157318, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5647059082984924, 0.9882352948188782, 0.9882352948188782, 0.9882352948188782, 0.658823549747467, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

t_start = time.time()
pred = model.forward(sample_X)
t_end = time.time()
pred = pred.index(max(pred))
print(f'Prediction: {pred}')
print(f'Time taken: {t_end - t_start:.4f} seconds')

NeuralNetwork(input_size=784, hidden_layers=[256, 512, 1024, 1024, 512], output_size=10)
Neural Network Summary
--------------------

Input Size: 784
Layer Shapes:
Layer 1: 784x256
Layer 2: 256x512
Layer 3: 512x1024
Layer 4: 1024x1024
Layer 5: 1024x512
Layer 6: 512x256
Layer 7: 256x10
Prediction: 8
Time taken: 0.1932 seconds
