# Neural Networks are General Function Approximators

Part of the reason that neural networks are such a powerful tool is that they can learn to approximate any continuous function.

> They are _general function approximators_.

That means that, whatever your input-output relationship, as long as it is continuous, a neural network can theoretically learn it.

# TODO img

In fact, a neural network with just one hidden layer can approximate _any_ continuous function over a bounded input range to any desired accuracy, as long as that hidden layer has enough nodes.

We refer to this as the _universal approximation theorem_.

Caveats:
- "Enough nodes" can in practice mean an impractically large number of nodes, and representing the function exactly can require an infinite number of them
- The network can only represent the ideal function given the ideal parameterisation, and finding that ideal parameterisation (or one that's close enough) can be extremely challenging

The universal approximation property of neural networks can be demonstrated in a few steps:
1. Acknowledge that any continuous function can be accurately approximated by a piece-wise linear function with enough pieces
1. Write the full equation for a NN with one hidden layer and ReLU activation
1. Show that it is equivalent to a sum of ReLUs
1. Show that the sum of ReLUs is equivalent to a piece-wise linear function

## Step 1: Any continuous function can be approximated by a piece-wise function

The shape of any continuous function can be drawn roughly by connecting straight lines.

The more lines you connect together, the more accurately you can represent the function.

In the limit of an infinite number of infinitely short lines, you can represent the function perfectly.
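As a rough illustration of this step, the sketch below measures how the worst-case error of a piece-wise linear approximation shrinks as more pieces are used. It assumes the same cubic target as the training code later on; the helper name `piecewise_max_error` and the choice of evenly spaced pieces are assumptions made only for this example.

```python
import numpy as np


def target(x):
    # Same cubic used as the training target in the code later on
    return 0.04 * x**3 - 2 * x


def piecewise_max_error(n_pieces, lo=-10.0, hi=10.0):
    """Worst-case gap between the target and a piece-wise linear
    approximation made of `n_pieces` straight segments on [lo, hi]."""
    knots = np.linspace(lo, hi, n_pieces + 1)  # segment endpoints
    dense_x = np.linspace(lo, hi, 2001)        # fine grid for measuring error
    approx = np.interp(dense_x, knots, target(knots))
    return float(np.max(np.abs(approx - target(dense_x))))


errors = {n: piecewise_max_error(n) for n in (2, 4, 8, 16, 32)}
for n, err in errors.items():
    print(f"{n:>2} pieces: max error = {err:.3f}")
```

Each doubling of the number of pieces should reduce the maximum error, which is the behaviour the argument above relies on.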

# TODO img of different size lines approximating a graph

## Step 2: The full equation for a NN

The full equation of a neural network with one hidden layer and ReLU activation is given as follows:

# TODO full eqn
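A sketch of that equation, assuming a single scalar input $x$, a single scalar output $\hat{y}$ and one hidden layer of $N$ nodes (the setup used in the code later on). The symbols are chosen here to mirror the names `w1`, `b1`, `w2` and `b2` in that code, where $\mathbf{w}_1$, $\mathbf{b}_1$ and $\mathbf{w}_2$ each hold one value per hidden node:

$$
\hat{y}(x) = \mathbf{w}_2^\top \, \mathrm{ReLU}(\mathbf{w}_1 x + \mathbf{b}_1) + b_2,
\qquad \mathrm{ReLU}(z) = \max(0, z)
$$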

This is the function that we wish to show can approximate any continuous function.

## Step 3: The single layer NN is equivalent to a sum of ReLUs

If we expand the full equation, it can be rearranged as a sum of ReLUs: one scaled and shifted ReLU per hidden node, plus the output bias.
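The equivalence can be checked numerically. The sketch below assumes a network with one input, one output and three hidden nodes, and uses random, hypothetical parameter values (named after `w1`, `b1`, `w2` and `b2` in the code later on) rather than trained ones.

```python
import numpy as np


def relu(z):
    return np.maximum(0.0, z)


rng = np.random.default_rng(0)
n_hidden = 3
# Hypothetical parameters for a 1-input, 1-output network
w1 = rng.normal(size=n_hidden)  # input weights, one per hidden node
b1 = rng.normal(size=n_hidden)  # input biases, one per hidden node
w2 = rng.normal(size=n_hidden)  # output weights, one per hidden node
b2 = rng.normal()               # single output bias

x = np.linspace(-10, 10, 50)

# Layer-by-layer form: hidden activations, then a weighted sum plus bias
hidden = relu(np.outer(x, w1) + b1)  # shape (50, n_hidden)
layered = hidden @ w2 + b2           # shape (50,)

# Expanded form: one scaled, shifted ReLU per hidden node, added together
as_sum = b2 + sum(w2[i] * relu(w1[i] * x + b1[i]) for i in range(n_hidden))

print(np.allclose(layered, as_sum))
```

Both forms compute the same output, so reasoning about the sum of ReLUs is the same as reasoning about the network.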

## Step 4: A sum of ReLUs is equivalent to a piece-wise function

If we plot each of the ReLU terms on the same graph, we can see how they differ because of their different parameters.

- The associated input weight stretches the ReLU along the x-dimension (a negative weight also flips it left to right)

# TODO img

- The associated output weight stretches the ReLU along the y-dimension (a negative weight also flips it upside down)

# TODO img

- The input bias shifts where the ReLU input becomes greater than zero. Beyond this value, the output is switched on.
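The effect of the three parameters on a single ReLU term can be sketched as follows. The parameter values are hypothetical and chosen only to make the arithmetic easy to follow; `relu_term` is a helper name introduced for this example.

```python
import numpy as np


def relu_term(x, w1, b1, w2):
    """Contribution of one hidden node: w2 * ReLU(w1 * x + b1)."""
    return w2 * np.maximum(0.0, w1 * x + b1)


# Hypothetical parameter values, chosen only for illustration
w1, b1, w2 = 2.0, -4.0, 0.5
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

kink = -b1 / w1   # input at which w1 * x + b1 crosses zero
slope = w1 * w2   # gradient of the term once it is switched on
print(kink, slope)                # 2.0 1.0
print(relu_term(x, w1, b1, w2))   # [0. 0. 0. 1. 2.]
```

With a positive input weight the term is switched on for inputs above the kink; with a negative input weight it is switched on for inputs below it.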

As per the NN equation, the network's output is the sum of all of these ReLUs, plus the output bias. When we add them together, we can see how they approximate the input-output relationship we are trying to model:

![Side by side: the approximation and the sum of ReLUs](./images/side%20by%20side%20approximation%20and%20sum%20of%20relus.png)
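The same idea can be shown without any training by choosing the parameters by hand. The sketch below is one possible construction, not the parameters a trained network would find: it assumes eight evenly spaced pieces, gives every hidden node an input weight of 1, and sets each output weight to the change in gradient needed at that node's kink.

```python
import numpy as np


def relu(z):
    return np.maximum(0.0, z)


def target(x):
    # Same cubic used as the training target in the code below
    return 0.04 * x**3 - 2 * x


knots = np.linspace(-10, 10, 9)                   # 8 straight pieces
slopes = np.diff(target(knots)) / np.diff(knots)  # gradient of each piece

# Each hidden node switches on at one knot (w1 = 1, b1 = -knot), and its
# output weight is the change in gradient needed at that knot.
w2 = np.diff(slopes, prepend=0.0)
b1 = -knots[:-1]
b2 = target(knots[0])  # output bias sets the height at the left edge

x = np.linspace(-10, 10, 201)
relu_sum = b2 + sum(w * relu(x + b) for w, b in zip(w2, b1))

# Compare against a direct piece-wise linear interpolation of the target
print(np.allclose(relu_sum, np.interp(x, knots, target(knots))))
```

The sum of eight ReLUs reproduces the piece-wise linear approximation over this input range, which is why adding hidden nodes lets the network add pieces.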

Although it's a little complicated, the code that trains the network and generates the graphs in the figure above is included below.

In [None]:
!pip install plotly
!pip install torch
!pip install numpy

In [None]:
# %%
import torch
import plotly.express as px
import numpy as np


class XYDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        self.x = torch.tensor(np.linspace(-10, 10)).float().unsqueeze(1)
        # self.x is already a float tensor, so no extra torch.tensor() copy is needed
        self.y = 0.04 * self.x**3 - 2 * self.x

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)


class NN(torch.nn.Module):
    def __init__(self, hidden_dim, activation):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.activation = activation()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(1, hidden_dim),
            self.activation,
            torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, X):
        return self.layers(X)

    def visualise(self, X):
        """Plot each hidden node's contribution to the output, plus the bias."""
        X = X.squeeze()
        # Parameters arrive in layer order: w1, b1 (hidden), then w2, b2 (output)
        w1, b1, w2, b2 = (p.detach().view(-1) for p in self.parameters())
        plot_payload = {}
        for idx, (w1_i, b1_i, w2_i) in enumerate(zip(w1, b1, w2)):
            # one hidden node's contribution: w2 * activation(w1 * x + b1)
            val = w2_i * self.activation(w1_i * X + b1_i)
            plot_payload[f"Hidden node {idx}"] = val.numpy()
        plot_payload["Bias"] = np.ones(len(X)) * b2.item()  # visualise bias
        plot_payload["Input"] = X.numpy()
        fig = px.line(plot_payload, x="Input",
                      y=list(plot_payload.keys())[:-1])
        fig.update_layout(
            legend_title_text='Output contribution from',
            yaxis_title="Output"
        )
        # contribution to output
        fig.show()
        # return plot_payload


def train(model, dataloader, epochs=500):
    optimiser = torch.optim.SGD(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        for batch in dataloader:
            x, y = batch
            prediction = model(x)
            loss = torch.nn.functional.mse_loss(prediction, y)
            loss.backward()
            optimiser.step()
            optimiser.zero_grad()


def train_nn_and_visualise(
    hidden_nodes,
    activation,
    epochs
):
    dataset = XYDataset()
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True)

    model = NN(hidden_dim=hidden_nodes, activation=activation)
    train(model, dataloader, epochs)

    # gradients are not needed for plotting, so convert everything to NumPy
    predictions = model(dataset.x).squeeze().detach().numpy()

    plot_data = {
        "Input": dataset.x.squeeze().numpy(),
        "Target": dataset.y.squeeze().numpy(),
        "Predictions": predictions
    }
    fig = px.line(plot_data, x="Input", y=["Target", "Predictions"])
    # fig.update_layout(plot_bgcolor="#18181B")
    fig.update_layout(
        legend_title_text='',
        yaxis_title="Output"
    )
    fig.show()

    model.visualise(dataset.x)


if __name__ == "__main__":
    train_nn_and_visualise(
        hidden_nodes=3,
        activation=torch.nn.ReLU,
        epochs=1000
    )
    # TODO generate gifs
    # TODO generate param permutations
    # TODO aicore theme
# %%
