# Sine wave regression in JAX with Equinox and Optax

## Import relevant packages

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from typing import List

## Note on importing List from the typing module

In [None]:
def int_to_str_custom(original_list: list[int]):
    converted_list = [str(int_iter) for int_iter in original_list]
    return converted_list

In [None]:
int_to_str_custom([1, 2, 3])

In [None]:
int_to_str_custom(["a", "b", "c"])

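1. Importing `List` from the `typing` module lets us write type hints for lists that also specify the type of the list's contents, e.g. `List[int]`.
2. Python does not enforce type hints at runtime, so the code does not have to adhere to the specified type. As the example above shows, passing a list of strings to a function annotated with `list[int]` raises no error. A static type checker such as mypy would flag this call.
3. Since Python 3.9, the built-in `list` can be subscripted directly (`list[int]`). That is why the function above works without `List`. Importing `List` from `typing` is only required on older Python versions.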
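The points above can be checked with a short sketch, assuming Python 3.9 or later. `get_type_hints` and `get_origin` are standard `typing` functions. The `-> list[str]` return annotation and the second function `int_to_str_typing` are hypothetical additions for illustration.

```python
from typing import List, get_origin, get_type_hints


def int_to_str_custom(original_list: list[int]) -> list[str]:
    return [str(item) for item in original_list]


# Hypothetical twin of the function above, annotated with typing.List instead
def int_to_str_typing(original_list: List[int]) -> List[str]:
    return [str(item) for item in original_list]


# The annotations are stored on the function object...
hints = get_type_hints(int_to_str_custom)
print(hints)

# ...but nothing checks them when the function is called
print(int_to_str_custom(["a", "b", "c"]))  # ['a', 'b', 'c']

# Both spellings describe a list: their origin is the built-in list type
print(get_origin(hints["original_list"]), get_origin(List[int]))
```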

## Set up hyperparameters

In [None]:
n_samples = 200
learning_rate = 0.1
n_epochs = 100000

## Set up network layer sizes

In [None]:
n_input_neurons = 1
#n_hidden_layers = 3
n_hidden_layers = 2
n_hidden_neurons = 10
n_output_neurons = 1
layers = [n_input_neurons] + n_hidden_layers*[n_hidden_neurons] + [n_output_neurons]

In [None]:
layers

## Generate toy data as column vectors

In [None]:
x_samples = jnp.linspace(0, 2*jnp.pi, n_samples)

## Check shape

In [None]:
x_samples.shape

## Reshape

In [None]:
x_samples = x_samples.reshape(-1, 1)

## Check shape after reshaping

In [None]:
x_samples.shape
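Here is a quick sketch of what `reshape(-1, 1)` does, using the same `linspace` call as above. The `-1` asks JAX to infer the number of rows from the array size. Indexing with `None` is an equivalent way to add the trailing axis. The name `x_col_alt` is ours.

```python
import jax.numpy as jnp

x = jnp.linspace(0, 2 * jnp.pi, 200)
print(x.shape)  # (200,)

# -1 lets JAX infer the number of rows from the total size (200 / 1 = 200)
x_col = x.reshape(-1, 1)
print(x_col.shape)  # (200, 1)

# Indexing with None inserts a new trailing axis, giving the same column vector
x_col_alt = x[:, None]
print(bool(jnp.array_equal(x_col, x_col_alt)))  # True
```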

## Generate sine wave column vector and check shape

In [None]:
y_samples = jnp.sin(x_samples)

In [None]:
y_samples.shape

## Plot toy data

In [None]:
plt.figure()
plt.scatter(x_samples, y_samples)
plt.xlabel("x")
plt.ylabel("y")
plt.title("toy data for regression")
plt.grid()

## Define a simple MLP class

In [None]:
class simpleMLP(eqx.Module):

    layers: List[eqx.nn.Linear]

    def __init__(self, layers_size_params, key):
        self.layers = []
        for (dim_in, dim_out) in zip(layers_size_params[:-1], layers_size_params[1:]):
            key, subkey = jax.random.split(key)
            self.layers.append(
                eqx.nn.Linear(dim_in, dim_out, use_bias=True, key=subkey)
            )

    def __call__(self, x):
        a = x
        for layer in self.layers[:-1]:
            a = jax.nn.sigmoid(layer(a))
        a = self.layers[-1](a)
        return a

## Notes on simple MLP class

1. We start by creating a skeleton of the class. This means defining the class name, the parent class it inherits from (in this case, Equinox's `eqx.Module`), the initialization method (`__init__`) and the forward pass (`__call__`).
2. Next, we fill in the initialization block. Its purpose is to initialize the neural network layers, for which we need the layer size parameters and a random key.
3. We start by defining an attribute `layers`, initialized as an empty list.
4. Next, we fill this list with randomly initialized layers. Each layer needs its number of input neurons and its number of output neurons. In Equinox, the weight matrix of `eqx.nn.Linear` has shape `(dim_out, dim_in)`. We therefore iterate over the input sizes and the output sizes simultaneously using `zip`.
   - The input sizes are all entries of the layer size list except the last one, `layers_size_params[:-1]`, because no layer takes the network output as its input.
   - The output sizes are all entries except the first one, `layers_size_params[1:]`, because the first entry is the size of the network input, which no layer produces.
   - For `[1, 10, 10, 1]` this gives the pairs `(1, 10)`, `(10, 10)` and `(10, 1)`.
5. With the dimensions in place, we need a fresh random key for each layer. `jax.random.split(key)` splits the current key into two new keys. The first overwrites `key` so that it can be split again on the next iteration. The second (`subkey`) is used to initialize that layer's weights and biases.
6. We now have all the ingredients needed to initialize the layers: the dimensions and the random keys. The `layers` attribute, initialized as an empty list, was a placeholder, and it is now time to fill it!
7. We fill this list with the usual `append` operation. We append objects of type `eqx.nn.Linear`, as specified by the type hint at the top of the `simpleMLP` class. Each `eqx.nn.Linear` receives the number of input and output neurons for that layer along with its random key. `use_bias=True` gives every layer a bias term, which is also the default.
8. We end the class with the forward pass, defined in the `__call__` method.
   - We define a variable `a` to hold the activations and initialize it with the input `x`.
   - We then iterate over every layer except the last one (`self.layers[:-1]`, i.e. the input-to-hidden and hidden-to-hidden layers). For each one, we apply the linear layer followed by a sigmoid non-linearity and feed the result back into the loop.
   - To conclude the forward pass, we apply the last layer outside the loop as a purely linear transformation and return `a`. Because there is no activation on this layer, the output can take any real value.


## Initialize NN

In [None]:
model = simpleMLP(layers, key=jax.random.PRNGKey(0))

## Generate Initial prediction

In [None]:
# Expected to fail: the model is written for a single sample, not a batch
initial_pred = model(x_samples)

## Notes on the error

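1. The error is a shape error. Equinox layers are written to act on a single sample. The first `eqx.nn.Linear` layer computes `weight @ x + bias` with a weight matrix of shape `(10, 1)`, so it expects `x` to have shape `(1,)`. Passing the whole batch `x_samples`, of shape `(200, 1)`, makes this matrix multiplication fail.
2. The Equinox documentation suggests applying JAX's `vmap` to the model to fix the error.
3. Using `vmap` does fix the error.
4. The reason is that `jax.vmap(model)` vectorizes the model over the leading (batch) axis of its input. Conceptually, the model is called once per row of `x_samples`, each row having shape `(1,)`, and the 200 outputs are stacked into an array of shape `(200, 1)`.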

## Fix error with vmap

In [None]:
initial_pred = jax.vmap(model)(x_samples)

## Plot initial prediction overlaid on toy data

In [None]:
plt.figure()
plt.scatter(x_samples, y_samples, label="toy data")
plt.plot(x_samples, initial_pred, 'red', label="initial prediction")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Initial prediction of the NN")
plt.grid()
plt.legend()

## Define MSE Loss

In [None]:
def mse_loss(model, x, y):
    delta = y - jax.vmap(model)(x)
    loss_val = jnp.mean(jnp.square(delta))
    return loss_val
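Written out, `mse_loss` computes the mean squared error over the $N$ samples. Here $f_\theta$ denotes the MLP with weights and biases $\theta$, notation introduced only for this formula:

$$
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - f_\theta(x_i)\right)^2
$$

Each network output is a single number. Taking `jnp.mean` over the `(N, 1)` array `delta` therefore averages exactly these $N$ squared differences, with $N$ = `n_samples` = 200 here.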

## Compute Initial Loss

In [None]:
initial_loss = mse_loss(model, x_samples, y_samples)

In [None]:
initial_loss

## Introducing Loss and grad

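1. JAX (`jax.value_and_grad`) and Equinox (`eqx.filter_value_and_grad`) provide transformations that return both the value of a function and its gradient. By default, the gradient is taken with respect to the function's first argument.
2. An Equinox model is a PyTree whose leaves are not necessarily all floating-point arrays. They can also be, for example, Python functions, integers or booleans, and JAX cannot differentiate with respect to those. We therefore only want to take gradients with respect to leaves that satisfy certain conditions.
3. The "filter" in `eqx.filter_value_and_grad` accommodates this kind of "filtering" or "selection".
4. The default filter used by `eqx.filter_value_and_grad` is `eqx.is_inexact_array`. It selects leaves that are arrays with an inexact dtype, i.e. floating-point (or complex) numbers, as opposed to integers or booleans.
5. In our case, `eqx.filter_value_and_grad` computes the gradients of the loss with respect to the floating-point arrays stored in the model (its first argument). In other words, we are "selecting" or "filtering" the weights and biases of the network. The gradients are returned as a PyTree with the same structure as the model.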

In [None]:
mse_loss_and_grad = eqx.filter_value_and_grad(mse_loss)

In [None]:
mse_loss_and_grad(model, x_samples, y_samples)

## Set up the optimizer

In [None]:
opt = optax.sgd(learning_rate)
opt_state = opt.init(eqx.filter(model, eqx.is_array))

## Notes on optimizer

1. We initialize the SGD optimizer with the learning rate specified in the hyperparameters section.
2. We then initialize the state of the optimizer using the weights and biases of the network.
3. The optimizer should only see the array leaves of the model. We therefore first apply `eqx.filter(model, eqx.is_array)`, which keeps the arrays (the weights and biases) and replaces every other leaf with `None`.
4. Optax is the "optimization engine" that performs "gradient processing": it turns gradients into parameter updates. Any information that the optimization algorithm needs to carry from one step to the next is stored in the optimizer state. For plain SGD without momentum this state is essentially empty. For an optimizer such as Adam, it holds running averages of the gradients and squared gradients.

## Define the "make step" function to perform a single optimization step

In [None]:
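@eqx.filter_jit
def make_step(model, opt_state, x, y):
    # Loss value and gradients with respect to the model's floating-point arrays
    loss, grad = mse_loss_and_grad(model, x, y)
    # Turn gradients into parameter updates (for SGD: -learning_rate * grad)
    model_update, opt_state = opt.update(grad, opt_state, eqx.filter(model, eqx.is_array))
    # Add the updates to the model's arrays, returning a new model
    model = eqx.apply_updates(model, model_update)
    return model, opt_state, loss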

## Notes on make step

1. The general neural network optimization workflow for a given step is as follows: extract gradients -> compute updates via optimization step -> apply update
2. The mse_loss_and_grad function computes the loss and grad with the model, x and y as inputs
3. The gradients are passed to `opt.update`, which performs the optimization step and returns the updates together with the new optimizer state. The updates are increments to be added to the weights and biases, not the new weights themselves. For SGD they equal `-learning_rate * grad`.
4. `eqx.apply_updates` adds the updates to the weights and biases of the network. Equinox modules are immutable, so it returns a new model rather than modifying the old one in place.
5. We finally return the network with its new weights and biases, the updated optimizer state, and the value of the loss function at the current optimization step (computed before the update).
6. We have added the `@eqx.filter_jit` line above the `make_step` definition to speed up the code. It acts as a decorator on `make_step`, so the function is JIT-compiled.
7. We will discuss some basic examples of decorators in Python followed by some notes on jit compilation

## Example of a simple decorator in Python

In [None]:
def my_decorator(func):
    def wrapper():
        print("Something is happening before the function is called.")
        func()
        print("Something is happening after the function is called.")
    return wrapper

In [None]:
def say_hello():
    print("Hello!")

## Execute the "say_hello" function without any decorator applied to the function definition

In [None]:
say_hello()

## Define a new "say_hello_with_dec" function and apply the "my_decorator" decorator to this function

In [None]:
@my_decorator
def say_hello_with_dec():
    print("Hello!")

## Execute the new function where a decorator has been applied to the definition

In [None]:
say_hello_with_dec()

## Notes on the decorator example

1. We see from the simple decorator example that decorators extend and modify the behavior of functions
2. In our example, the decorator `my_decorator` takes a function `func` as input. It defines an inner function `wrapper` that calls `func` between two print statements, thereby extending its behavior. Finally, it returns `wrapper`, the extended version of `func`.
3. The typical syntax is to write the "@" symbol followed by the decorator name on the line before the definition of the function to which we wish to apply the decorator.
4. Here we apply the decorator `my_decorator` via the syntax `@my_decorator` on the line directly above the `def` statement that defines `say_hello_with_dec`.
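The example above leaves out two details. First, `@my_decorator` is shorthand for the reassignment `say_hello = my_decorator(say_hello)`. Second, a wrapper that should work with any function needs to forward arguments and return values. Below is a sketch of a more general version. It uses the standard-library `functools.wraps` to keep the original function's name; the `add` functions are hypothetical.

```python
import functools


def my_decorator(func):
    @functools.wraps(func)  # copy func's name and docstring onto the wrapper
    def wrapper(*args, **kwargs):
        print("Something is happening before the function is called.")
        result = func(*args, **kwargs)  # forward all arguments
        print("Something is happening after the function is called.")
        return result  # pass the return value through
    return wrapper


def add(a, b):
    return a + b


# The @ syntax is shorthand for this explicit call
add_with_dec = my_decorator(add)


@my_decorator
def add_decorated(a, b):
    return a + b


print(add_with_dec(2, 3))
print(add_decorated(2, 3))
```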

## Notes on jit compilation

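1. JAX's just-in-time (JIT) compilation speeds up computation. The first time a jitted function is called with inputs of a given shape and dtype, JAX traces it and compiles it into optimized code. Later calls with matching inputs reuse the compiled version, which pays off in a training loop that calls `make_step` many times.
2. The JIT compiler in JAX uses XLA (Accelerated Linear Algebra).
3. `jax.jit` traces every argument as an array, so arguments containing non-array leaves (such as Python functions or integers) have to be marked as static by hand. The "filter" in `eqx.filter_jit` does this separation automatically. All JAX and NumPy arrays (e.g. the weights and biases of the network) are traced, and everything else is treated as static.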

## Training Loop

In [None]:
loss_history = []

for epoch in range(n_epochs):
    model, opt_state, loss = make_step(model, opt_state, x_samples, y_samples)
    loss_history.append(loss)

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, loss: {loss}")

## Notes on Training Loop

1. We simply perform the single optimization step iteratively in a loop via the `make_step` function.
2. The loss values are stored in `loss_history` so that the loss history can be analyzed later.

## Plot loss history on log scale

In [None]:
plt.figure()
plt.plot(loss_history)
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("loss (log scale)")
plt.title("loss history on log scale")

## Generate final prediction

In [None]:
final_pred = jax.vmap(model)(x_samples)

## Plot results

In [None]:
plt.figure()
plt.scatter(x_samples, y_samples, label="toy data")
plt.plot(x_samples, final_pred, 'r', label="final prediction")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid()
plt.title("Sine wave regression in JAX with Equinox and Optax using simple MLP")