# Neuron

> A neuron that carries its own activation function, so neural networks can be built with a different activation for each neuron

In [None]:
#| default_exp neuron

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from collections.abc import Callable

import jax
import jax.numpy as jnp
import equinox as eqx

In [None]:
#| export
class Neuron(eqx.Module):
    """
    A single neuron with a weight vector, a scalar bias, and its own activation.

    The neuron computes ``activation(dot(weight, x) + bias)``. Each neuron
    stores its activation, so neurons with different activations can be
    combined in one network.

    Attributes:
        weight: weight vector of shape ``(in_features,)``.
        bias: scalar (0-d) bias.
        activation: function applied to the pre-activation value.
    """
    weight: jax.Array
    bias: jax.Array
    activation: Callable

    def __init__(self, in_features, activation=jax.nn.relu, key=None):
        """
        Initialise weight and bias by sampling from a standard normal.

        Args:
            in_features: length of the weight vector (the expected input size).
            activation: activation function; defaults to ``jax.nn.relu``.
            key: JAX PRNG key used for initialisation. If ``None``, a fixed
                key derived from ``PRNGKey(0)`` is used, so default-constructed
                neurons are initialised deterministically (and identically).
        """
        if key is None:
            # Keep the historical derivation (first half of a split of
            # PRNGKey(0)) so default-initialised weights stay unchanged.
            key, _ = jax.random.split(jax.random.PRNGKey(0))
        w_key, b_key = jax.random.split(key)
        self.weight = jax.random.normal(w_key, (in_features,))
        self.bias = jax.random.normal(b_key, ())
        self.activation = activation

    def __call__(self, x):
        """
        Return ``activation(dot(weight, x) + bias)``.

        For ``x`` of shape ``(in_features,)`` the result is a scalar (0-d) array.
        """
        return self.activation(jnp.dot(self.weight, x) + self.bias)

    def importance(self):
        """
        Return the importance of the neuron as a scalar (0-d) array.

        This is the root-mean-square of the weights, i.e. the L2 norm of the
        weight vector divided by ``sqrt(in_features)``. The normalisation
        presumably keeps scores comparable between neurons with different
        fan-in. A neuron with no inputs has importance 0; the unguarded
        formula would give ``0 / 0 = nan``.
        """
        n = self.weight.size
        if n == 0:
            return jnp.zeros((), dtype=self.weight.dtype)
        return jnp.linalg.norm(self.weight) / jnp.sqrt(n)

In [None]:
#| test
neuron = Neuron(10)
x = jax.random.normal(jax.random.PRNGKey(0), (10,))
y = neuron(x)
# Shapes: one input vector maps to a scalar output and a scalar importance.
assert y.shape == ()
assert neuron.importance().shape == ()
# Values: the output is relu(w·x + b), hence non-negative, and the importance
# is the root-mean-square of the weights.
assert jnp.allclose(y, jax.nn.relu(jnp.dot(neuron.weight, x) + neuron.bias))
assert y >= 0
assert jnp.allclose(neuron.importance(), jnp.sqrt(jnp.mean(neuron.weight ** 2)))
# A custom activation and an explicit key are honoured.
tanh_neuron = Neuron(10, activation=jnp.tanh, key=jax.random.PRNGKey(1))
assert jnp.allclose(
    tanh_neuron(x), jnp.tanh(jnp.dot(tanh_neuron.weight, x) + tanh_neuron.bias)
)


In [None]:
#| hide
# Export the `#| export` cells of this notebook to the `neuron` module.
import nbdev
nbdev.nbdev_export()