# Training example

This notebook shows how to train neural network in python in a way that is compatible with the C library.
It does not shows how to solve any specific problem, but rather how to transfer data to C.

In [14]:
import numpy as np

In [18]:
# this neural network works the same way as the one in C

class NeuralNetwork:
    def __init__(self, shape: np.array, genotype: np.array):
        self.activation = np.tanh
        self.w = []
        self.genotype = genotype


        parameters_n = 0
        for i in range(len(shape) - 1):
            parameters_n += shape[i] * shape[i+1]
            parameters_n += shape[i+1]

        assert len(self.genotype) == parameters_n, f"genotype: {len(self.genotype)}, parameters: {parameters_n}"
            
        for i in range(len(shape) - 1):
            self.w.append(genotype[:shape[i] * shape[i+1]].reshape(shape[i + 1], shape[i]))
            genotype = genotype[shape[i] * shape[i+1]:]

            self.w.append(genotype[:shape[i+1]])
            self.w[-1] = self.w[-1].reshape(shape[i+1], 1)
            genotype = genotype[shape[i+1]:]

        
    def forward(self, inputs):
        inputs = inputs.reshape(1, -1).T
        for i in range(0, len(self.w), 2):
            inputs = self.w[i] @ inputs
            inputs += self.w[i + 1]
            inputs = self.activation(inputs)
        
        return inputs

In [20]:
# Let's say we want to solve some problem by implementing genetic algorithm.
# It is done by finding the best "genotype" of the function.
# In this case, our function is neural network.

# Create a random genotype
genotype = np.random.randn(26)

# Create a nn with that genotype
nn = NeuralNetwork(np.array([5, 3, 2]), genotype)

# In case you are wondering how do I know that the genotype has to have 26 elements:
# neural network class checks if genotype can fill all weights and biases and tells you if there is wrong number of elements.

# If neural network with that genotype is doing well, you can print the genotype and paste it into the C code.
print(repr(genotype))

# Of course, this genotype will not be good, because it was generated randomly.
# You will have to implement some sort of learning algorithm.

array([-2.48124910e+00,  1.31544128e+00, -5.05760264e-01,  2.39379504e+00,
        3.16341236e-02, -1.24785545e+00,  4.48755239e-01, -1.29709493e+00,
       -4.69826179e-01,  3.76892429e-01,  7.22700900e-01, -8.83840977e-01,
        7.52509880e-01,  6.22580655e-01,  1.87455640e+00,  1.57409240e+00,
        1.86056949e-01, -4.44658623e-01,  2.13087276e-03, -4.28654651e-01,
       -4.72945476e-01, -1.64698064e+00,  1.45948047e+00, -1.00046298e+00,
       -5.79613204e-01, -4.53596982e-01])
