In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import import_ipynb

In [2]:
from Layer import Layer

importing Jupyter notebook from Layer.ipynb


In [3]:
class NeuralNetwork:
    """A chain of ``Layer`` objects with per-sample forward/backward passes.

    The network only orchestrates: each ``Layer`` (imported from the Layer
    notebook) is expected to provide ``calculateLayers(inputs)`` returning a
    pair whose second element feeds the next layer, and
    ``backwardProp(learnRate, error)`` returning the error handed to the
    layer before it.

    Attributes:
        layers: 1-D NumPy object array holding the ``Layer`` instances in
            forward order.
        outputs: Result of the most recent ``forwardProp`` call.
        numCorrect / numIncorrect: Running tally of predictions whose argmax
            did / did not match the target (reset by ``train`` after each epoch).
    """

    layers = np.empty(0, dtype=object)  # Class-level default: no layers until __init__ builds them
    outputs = torch.empty((1))  # Placeholder until the first forward pass stores a real output

    # Counters for how many predictions were right and wrong (for display purposes)
    numCorrect = 0
    numIncorrect = 0

    def __init__(self, layerSizes):
        """Build one ``Layer`` per consecutive pair of entries in ``layerSizes``.

        Args:
            layerSizes: Sequence of layer widths (NumPy array, list or tuple).
                Entry ``i`` and entry ``i + 1`` are passed to ``Layer`` as its
                two constructor arguments. Fewer than two entries yields a
                network with no layers.
        """
        sizes = np.asarray(layerSizes)  # No-op for ndarrays; lets lists/tuples work too
        built = [Layer(sizes[index], sizes[index + 1]) for index in range(sizes.size - 1)]

        # Fill an object array element by element so NumPy never tries to
        # interpret a Layer as a nested sequence.
        self.layers = np.empty(len(built), dtype=object)
        for index, layer in enumerate(built):
            self.layers[index] = layer

    def forwardProp(self, inputs):
        """Run ``inputs`` through every layer in order and remember the result.

        Args:
            inputs: Value handed to the first layer's ``calculateLayers``.

        Returns:
            The second element returned by the last layer's ``calculateLayers``
            (or ``inputs`` unchanged when the network has no layers). The same
            object is stored in ``self.outputs``.
        """
        for layer in self.layers:
            # calculateLayers returns a pair; only the second element is carried forward
            _, inputs = layer.calculateLayers(inputs)

        self.outputs = inputs  # Keep the final output for the accuracy checks
        return self.outputs

    def cost(self, predicted, expected):
        """Compute the squared error and update the correct/incorrect tally.

        Args:
            predicted: Tensor of network outputs.
            expected: Tensor of target values, broadcast-compatible with ``predicted``.

        Returns:
            Tuple ``(squaredError, totalCost)``: the element-wise squared
            difference, and its sum divided by ``expected.size(dim=0)``.
        """
        squaredError = (expected - predicted) ** 2
        totalCost = torch.sum(squaredError) / expected.size(dim=0)

        # Score the prediction that was actually passed in, not whatever the
        # last forward pass happened to leave in self.outputs.
        self.checkCorrectTensor(expected, predicted)

        return squaredError, totalCost

    def _tally(self, isCorrect):
        """Increment ``numCorrect`` when ``isCorrect`` is truthy, else ``numIncorrect``."""
        if isCorrect:
            self.numCorrect += 1
        else:
            self.numIncorrect += 1

    def checkCorrectTensor(self, expected, predicted=None):
        """Compare the argmax of a prediction with the argmax of ``expected``.

        Args:
            expected: Target tensor (e.g. a one-hot vector).
            predicted: Prediction to score; defaults to ``self.outputs`` so
                existing single-argument calls keep working.
        """
        if predicted is None:
            predicted = self.outputs
        # Index of the highest value must match for the prediction to count as correct
        self._tally(predicted.argmax() == expected.argmax())

    def checkCorrectLabel(self, label):
        """Compare the argmax of ``self.outputs`` directly with a class index ``label``."""
        self._tally(self.outputs.argmax() == label)

    def backProp(self, learnRate, error):
        """Propagate ``error`` from the last layer back to the first.

        Each layer's ``backwardProp(learnRate, error)`` is called in reverse
        order, and its return value becomes the error for the preceding layer.
        """
        for layer in reversed(self.layers):
            error = layer.backwardProp(learnRate, error)

    def train(self, inputs, label, samples, epochs, learnRate, numClasses=10):
        """Train one sample at a time and print progress.

        Args:
            inputs: Tensor indexed by sample along its first dimension.
            label: Tensor of integer class indices, indexed like ``inputs``
                (passed to ``torch.nn.functional.one_hot``).
            samples: Number of leading samples to use in every epoch.
            epochs: Number of passes over those samples.
            learnRate: Learning rate forwarded to every layer's ``backwardProp``.
            numClasses: Length of the one-hot target vector (default 10).
        """
        for epoch in range(epochs):
            err = 0.0  # Accumulated loss of this epoch, kept as a plain float
            for i in range(samples):
                # One-hot target as a column vector; reshape alone handles both
                # a 1-D (C,) and a 2-D (1, C) one_hot result without using .T
                onehot = torch.nn.functional.one_hot(label[i], num_classes=numClasses)
                onehot = onehot.reshape((-1, 1))

                output = self.forwardProp(inputs[i])  # Predicted output for this sample

                _, loss = self.cost(output, onehot)  # Also updates numCorrect / numIncorrect
                err += loss.item()

                # Derivative of the averaged squared error with respect to the output
                error = 2 * (output - onehot) / onehot.size(dim=0)
                self.backProp(learnRate, error)  # Optimizes weights and biases

                # Every 50 steps print: epoch, step, loss of step, error of epoch,
                # expected label and predicted label
                if (i + 1) % 50 == 0:
                    print(f'Epoch [{epoch+1} / {epochs}], Step [{i+1}/{samples}], Loss: {loss.item():.4f}, Error: {err:.4f}')
                    print(f'Should Be: {label[i]}')
                    print(f'Result Was: {output.argmax()}')

            # Accuracy is measured against the samples actually processed,
            # not the full length of `inputs`.
            acc = 100 * self.numCorrect / samples if samples else 0.0
            print(f'Accuracy of Epoch: {acc} %')
            print(f'NumCorrect: {self.numCorrect}, NumIncorrect: {self.numIncorrect}')
            print('')
            self.numCorrect = 0
            self.numIncorrect = 0