# MLP
> Create a multilayer perceptron using the Neuron class.
> The MLP supports adding and removing neurons during training.

In [None]:
#| default_exp mlp

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import jax
import jax.numpy as jnp
import equinox as eqx

from NeuralNetworkEvolution.neuron import Neuron

In [None]:
#| export
def identity(x):
    """
    Pass-through activation: hand the input back unchanged.

    Used as the activation of the output layer so that its neurons emit
    their raw pre-activation value.
    """
    return x

In [None]:
class CustomMLP(eqx.Module):
    """
    Multilayer perceptron stored as nested lists of `Neuron` objects.

    `layers[i]` is the list of neurons making up layer `i`; the final entry
    is the output layer, whose neurons use the `identity` activation.
    Keeping every neuron as its own object lets the network grow or shrink
    one neuron at a time through `add_neuron` / `remove_neuron`, both of
    which mutate the layer lists in place.
    """

    layers: list

    def __init__(self, input_size, hidden_sizes, output_size, key=None):
        """
        Build the hidden layers (ReLU) followed by an identity output layer.

        Args:
            input_size: number of inputs fed to each first-layer neuron.
            hidden_sizes: sequence with the neuron count of each hidden layer.
            output_size: number of neurons in the output layer.
            key: JAX PRNG key; defaults to `jax.random.PRNGKey(0)`.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        # One key per layer (hidden layers + output layer) ...
        layer_keys = jax.random.split(key, len(hidden_sizes) + 1)
        layers = []
        in_features = input_size

        # Create hidden layers
        for i, out_features in enumerate(hidden_sizes):
            # ... and one subkey per neuron. Sharing a single key across a
            # layer would hand every neuron the same random stream
            # (presumably Neuron draws its initial parameters from `key`),
            # leaving the neurons of a layer indistinguishable.
            neuron_keys = jax.random.split(layer_keys[i], out_features)
            layers.append(
                [Neuron(in_features, jax.nn.relu, key=k) for k in neuron_keys]
            )
            in_features = out_features

        # Create output layer
        output_keys = jax.random.split(layer_keys[-1], output_size)
        layers.append(
            [Neuron(in_features, activation=identity, key=k) for k in output_keys]
        )

        self.layers = layers

    def __call__(self, x):
        """
        Run `x` through every layer in order.

        Each layer's neurons are evaluated on the same input and their
        results stacked into the input of the next layer.

        Returns:
            The value of the first output neuron only (`x[0]`). Outputs of
            any further output-layer neurons are computed but not returned.
        """
        for layer in self.layers:
            x = jnp.array([neuron(x) for neuron in layer])
        return x[0]  # Since output layer is a single neuron

    def add_neuron(self, layer_index, activation=jax.nn.relu, key=None):
        """
        Append a new neuron to a layer and widen the following layer to match.

        Args:
            layer_index: index of the layer to grow; negative values count
                from the end, as with list indexing.
            activation: activation function given to the new neuron.
            key: JAX PRNG key; defaults to `jax.random.PRNGKey(0)`.

        Raises:
            IndexError: if `layer_index` is out of range or the layer is
                empty (the input width is read from its first neuron).
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        # Normalise negative indices so that `layer_index + 1` below really
        # refers to the following layer (with -1 it would otherwise be 0).
        if layer_index < 0:
            layer_index += len(self.layers)
        # Independent randomness for the new neuron and for the new incoming
        # weights of the next layer.
        neuron_key, fan_out_key = jax.random.split(key)

        layer = self.layers[layer_index]
        in_features = layer[0].weight.shape[0]
        layer.append(Neuron(in_features, activation, key=neuron_key))

        # Adjust the next layer's weight vectors to include the new neuron
        if layer_index + 1 < len(self.layers):
            next_layer = self.layers[layer_index + 1]
            # One distinct weight per downstream neuron.
            new_weights = jax.random.normal(fan_out_key, (len(next_layer),))
            for i, next_neuron in enumerate(next_layer):
                updated_weights = jnp.append(next_neuron.weight, new_weights[i])
                next_layer[i] = eqx.tree_at(
                    lambda n: n.weight, next_neuron, updated_weights
                )

    def remove_neuron(self, layer_index, neuron_index):
        """
        Delete one neuron and drop its weight from every next-layer neuron.

        Args:
            layer_index: index of the layer to shrink; negative values count
                from the end, as with list indexing.
            neuron_index: position of the neuron inside that layer.

        Does nothing when the layer is already empty.

        Raises:
            IndexError: if `layer_index` or `neuron_index` is out of range.
        """
        # Normalise negative indices so that `layer_index + 1` below really
        # refers to the following layer (with -1 it would otherwise be 0).
        if layer_index < 0:
            layer_index += len(self.layers)

        layer = self.layers[layer_index]
        if not layer:
            # Nothing was removed, so the next layer must stay untouched.
            return
        del layer[neuron_index]

        # Adjust the next layer's weight vectors to remove the matching weight
        if layer_index + 1 < len(self.layers):
            next_layer = self.layers[layer_index + 1]
            for i, next_neuron in enumerate(next_layer):
                updated_weights = jnp.delete(next_neuron.weight, neuron_index)
                next_layer[i] = eqx.tree_at(
                    lambda n: n.weight, next_neuron, updated_weights
                )

    def get_shape(self):
        """
        Return the number of neurons in each layer, output layer included.
        """
        return [len(layer) for layer in self.layers]

    def least_important_neuron(self):
        """
        Locate the neuron with the lowest importance that is eligible.

        Importances of all neurons (every layer, output layer included) are
        ranked in ascending order; the first candidate that is not at the
        last position of its layer is returned.

        NOTE(review): the index arithmetic assumes `Neuron.importance()`
        yields one scalar per neuron - confirm against the Neuron class.

        Returns:
            `(layer_index, neuron_index)` as JAX integer scalars.

        Raises:
            ValueError: if every candidate sits at the last position of its
                layer.
        """
        all_importances = []
        layer_sizes = []
        for layer in self.layers:
            importances = [n.importance() for n in layer]
            all_importances.append(jnp.array(importances).flatten())
            layer_sizes.append(len(layer))

        all_importances = jnp.concatenate(all_importances)
        sorted_indices = jnp.argsort(all_importances)

        # Map each flat index back to (layer, position) via cumulative sizes.
        cum_neurons = jnp.cumsum(jnp.array(layer_sizes))
        for min_importance_index in sorted_indices:
            layer_index = jnp.searchsorted(cum_neurons, min_importance_index, side="right")
            neuron_index = min_importance_index - (cum_neurons[layer_index - 1] if layer_index > 0 else 0)
            if neuron_index != len(self.layers[layer_index]) - 1:  # If the neuron is not the last one of its layer
                return layer_index, neuron_index

        raise ValueError("All neurons are the last ones of their layers")

    def most_important_layer(self):
        """
        Return the index of the hidden layer with the largest summed importance.

        The output layer (`layers[-1]`) is excluded from the comparison.
        The result is a JAX integer scalar.
        """
        # Calculate the total importance of each layer
        layer_importances = [
            jnp.sum(jnp.array([n.importance() for n in layer]))
            for layer in self.layers[:-1]
        ]
        most_important_layer_index = jnp.argmax(jnp.array(layer_importances))

        return most_important_layer_index
