<a href="https://www.kaggle.com/code/alembcke/titanic-multi-layer-perceptron-using-haiku-jax?scriptVersionId=103649732" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Titanic Multi-Layer Perceptron using Haiku/JAX

*by Alex Lembcke*

*August 2022*

This notebook demonstrates how to build a simple multi-layer perceptron neural network using [Haiku](https://github.com/deepmind/dm-haiku) and [JAX](https://github.com/google/jax).  We will use the Titanic dataset as part of the introductory Kaggle competition [Titanic - Machine Learning from Disaster](https://www.kaggle.com/competitions/titanic).  The goal of this competition is to predict whether an individual aboard the Titanic survived its sinking on April 15, 1912.  Competitors are given a training dataset containing various data on 891 passengers, including whether or not they survived.  They must then provide their predictions for survival for 418 other passengers, given the same information as was provided in the training dataset, except whether the passenger survived, of course.

As Haiku is not installed by default in Kaggle, we will first need to install it:

In [1]:
!pip install --upgrade dm-haiku

Collecting dm-haiku
  Downloading dm_haiku-0.0.7-py3-none-any.whl (342 kB)
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.7 jmp-0.0.2

Now we can import Haiku and JAX, as well as some other libraries that will help us along the way:

In [2]:
# Typing library to hold training state
from typing import NamedTuple

# Libraries for input/output and data manipulation
import pandas as pd
import numpy as np

# Make numpy values easier to read
np.set_printoptions(precision=3, suppress=True)

# Machine learning libraries
import haiku as hk
import jax
import jax.numpy as jnp

# Optimization library
import optax

# Print version numbers
print("JAX version {}".format(jax.__version__))
print("Haiku version {}".format(hk.__version__))

JAX version 0.3.14
Haiku version 0.0.7


## Data Cleansing

Unfortunately, or perhaps to add an element of realism to this Kaggle competition, the dataset provided is not entirely perfect.  But there is a whole host of notebooks detailing the issues in the dataset and how to solve them.  And since the goal of this notebook is to demonstrate how to implement a neural network using Haiku and JAX, we will simply clean the data and move on.  First, we must retrieve the datasets:

In [3]:
train = pd.read_csv("../input/titanic/train.csv")
test = pd.read_csv("../input/titanic/test.csv")

One of the data issues is that the Embarked field is missing two data points in the training dataset.  Since it is only two data points, we can check online to confirm that both of those passengers boarded at Southampton and fill in the missing information:

In [4]:
train['Embarked'] = train['Embarked'].fillna('S')

Next is the big issue: the Age field is missing quite a lot of information in both the training dataset and the testing dataset.  Given how much of the data is missing and how big of an impact it could make on the final predictions, it makes sense to fill in this data as accurately as possible - although it is too much to verify by hand.  One solution proposed by a number of Kagglers is to fill in the age field by inferring the value given the title of the individual (from the Name field) and for those with the title of "Miss" to add in whether they are traveling with a parent or not.

We will use the solution provided in the notebook [Titanic Missing Age Imputation Tutorial - Advanced](https://www.kaggle.com/code/allohvk/titanic-missing-age-imputation-tutorial-advanced/notebook) to fill in the missing ages:

In [5]:
# Create the Title field and fill it in
train['Title'], test['Title'] = [df.Name.str.extract(r' ([A-Za-z]+)\.', expand=False) for df in [train, test]]
TitleDict = {"Capt": "Officer","Col": "Officer","Major": "Officer","Jonkheer": "Royalty", \
             "Don": "Royalty", "Sir" : "Royalty","Dr": "Royalty","Rev": "Royalty", \
             "Countess":"Royalty", "Mme": "Mrs", "Mlle": "Miss", "Ms": "Mrs","Mr" : "Mr", \
             "Mrs" : "Mrs","Miss" : "Miss","Master" : "Master","Lady" : "Royalty"}
train['Title'], test['Title'] = [df.Title.map(TitleDict) for df in [train, test]]
# Create a table showing the mean age per (Pclass, Sex, Title) group in the training data
grp = train.groupby(['Pclass','Sex','Title'])['Age'].mean().reset_index()[['Sex', 'Pclass', 'Title', 'Age']]

def fill_age(x):
    """Fills in the missing values for the Age field."""
    return grp[(grp.Pclass==x.Pclass)&(grp.Sex==x.Sex)&(grp.Title==x.Title)]['Age'].values[0]
train['Age'], test['Age'] = [df.apply(lambda x: fill_age(x) if np.isnan(x['Age']) else x['Age'], axis=1) for df in [train, test]]
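
To make the idea behind the imputation easier to follow, here is a minimal sketch on a made-up table.  The passengers, names and ages below are hypothetical, and the sketch uses `groupby(...).transform("mean")` instead of the `fill_age` helper above, so treat it as an illustration of the technique rather than a drop-in replacement:

```python
import numpy as np
import pandas as pd

# Hypothetical toy passengers; names and ages are made up for illustration.
toy = pd.DataFrame({
    "Name": ["Smith, Mr. John", "Smith, Mrs. Anna", "Jones, Mr. Carl",
             "Brown, Miss. Eva", "Brown, Miss. Ida"],
    "Pclass": [3, 3, 3, 2, 2],
    "Sex": ["male", "female", "male", "female", "female"],
    "Age": [30.0, 28.0, np.nan, 10.0, np.nan],
})

# Pull the title out of the name: the word that sits right before a full stop.
toy["Title"] = toy["Name"].str.extract(r" ([A-Za-z]+)\.", expand=False)

# Mean age of each (Pclass, Sex, Title) group, broadcast back to every row.
group_mean = toy.groupby(["Pclass", "Sex", "Title"])["Age"].transform("mean")

# Keep the known ages and fill only the missing ones.
toy["Age"] = toy["Age"].fillna(group_mean)
print(toy[["Name", "Title", "Age"]])
```

The two passengers with a missing age receive the mean age of the other passengers who share their class, sex and title.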

This solution does present the issue of data leakage from the training dataset to the testing dataset, because the missing ages in the testing dataset are filled in using values imputed from the training dataset.  Given the goal of this notebook, and that this is an introductory competition, we will simply note that this is not a good idea and continue forward.

Now we have (fairly) clean training and testing datasets.

## Data Pipeline

As our goal is to set up a multi-layer perceptron using Haiku/JAX, we will skip the exploratory data analysis step and go straight to building our model, starting with the data pipeline.  We will start by choosing which features to use in our model.  We can drop the `Title` field we created to impute the age, as well as the `Cabin` and `Fare` fields, as they require a significant amount of feature engineering to provide value, and again that is not the goal of this notebook.

In [6]:
# Set up training data (copy so that pop() does not act on a view of `train`)
titanic_features = train[['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Embarked']].copy()
titanic_labels = titanic_features.pop('Survived')

# Set up testing data
titanic_test = test[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Embarked']]

Next we need to convert categorical data into one-hot vectors and normalize numerical data.  The two categorical fields are `Sex` and `Embarked`, while the remainder of the fields are numerical.  Then we can merge the fields back together to get something suitable as input for a neural network.

It should be noted that there is a lot of feature engineering that could be done here, and which others have demonstrated.  For example, `Age` would likely be better when broken into buckets, rather than left as a numerical input.  Also, the fields `SibSp` and `Parch` would be better used to match families together.  But we will focus on our goal of providing a simplified demonstration of a neural network using Haiku and JAX and leave the advanced feature engineering as an exercise for the reader (that's you).

In [7]:
def preprocess(dataset):
    """Preprocesses the inputs, transforming categorical inputs to one-hot vectors and
    normalizing numerical inputs. Combines and returns the result."""
    # Convert categorical data to one-hot vectors
    sex_numeric = dataset['Sex'].map( {'male': 0, 'female': 1} ).astype(int).to_numpy()
    sex_one_hot = jax.nn.one_hot(sex_numeric, num_classes=2)
    embarked_numeric = dataset['Embarked'].map( {'S': 0, 'C': 1, 'Q': 2} ).astype(int).to_numpy()
    embarked_one_hot = jax.nn.one_hot(embarked_numeric, num_classes=3)
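    # Normalize numeric inputs by dividing each column by its L2 norm.
    # Note: the norm is computed from whichever dataset is passed in.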
    numeric_inputs = dataset[['Pclass', 'Age', 'SibSp', 'Parch']]
    normalized_inputs = []
    for feature in numeric_inputs:
        norm = jnp.linalg.norm(numeric_inputs[feature].to_numpy())
        normalized_inputs.append(numeric_inputs[feature].to_numpy() / norm)
    # Append all of the inputs together
    return jnp.array([sex_one_hot[:,0],
                      sex_one_hot[:,1],
                      embarked_one_hot[:,0],
                      embarked_one_hot[:,1],
                      embarked_one_hot[:,2],
                      normalized_inputs[0],
                      normalized_inputs[1],
                      normalized_inputs[2],
                      normalized_inputs[3]
                     ]).T

titanic_features = preprocess(titanic_features)
titanic_labels = jnp.array(titanic_labels.to_numpy()).reshape((titanic_features.shape[0], 1))

titanic_test = preprocess(titanic_test)
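
One thing to be aware of: `preprocess` divides each numeric column by that column's L2 norm, and the norm is computed separately for whichever dataset is passed in.  Since the norm grows with the number of rows, the training and testing features end up on somewhat different scales.  A common alternative, sketched below on made-up numbers, is to standardise both datasets using the mean and standard deviation of the training data only.  This is not what the notebook does, just an option to consider; all of the values and variable names here are hypothetical:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy columns standing in for a numeric feature and a category.
train_age = jnp.array([20.0, 30.0, 40.0])
test_age = jnp.array([30.0, 50.0])
embarked_codes = jnp.array([0, 2, 1])  # S=0, C=1, Q=2, as in preprocess

# One-hot encoding: each integer code becomes a row containing a single 1.
embarked_one_hot = jax.nn.one_hot(embarked_codes, num_classes=3)

# Standardise with statistics from the training split only, then reuse them
# for the test split so that both end up on the same scale.
mean, std = train_age.mean(), train_age.std()
train_scaled = (train_age - mean) / std
test_scaled = (test_age - mean) / std

print(embarked_one_hot)
print(train_scaled, test_scaled)
```

With this approach an age of 30 maps to the same value in both datasets, which is not guaranteed when each dataset is divided by its own norm.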

Now our data is suitable to be passed into a neural network, so time to get to the fun part!

## Multi-Layer Perceptron Model

Now that we have cleaned and preprocessed our data, it is time to build our model.  We will start by declaring a class to hold our training state.  The training state will keep track of the trained parameters, an exponential moving average of the trained parameters and our optimizer state:

In [8]:
class TrainingState(NamedTuple):
    params: hk.Params
    avg_params: hk.Params
    opt_state: optax.OptState

Next, we will define our network.  As mentioned in the title, we will use a simple multi-layer perceptron, taken from the [MNIST Example](https://github.com/deepmind/dm-haiku/blob/main/examples/mnist.py) provided on the GitHub page for the Haiku library - in fact, much of this code was adapted from that example, with modifications as required.  That example classifies images into ten classes, whereas we want a single survival probability per passenger, so the network was adjusted by reducing the last layer to one neuron and applying a sigmoid activation function to it:

In [9]:
def net_fn(features: jnp.ndarray) -> jnp.ndarray:
    """LeNet-300-100 style MLP ending in a single sigmoid unit."""
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(1), jax.nn.sigmoid
    ])
    return mlp(features)
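
If you are wondering what this stack of layers actually computes, the sketch below writes out the same sequence of operations in plain JAX, without Haiku.  The `init_layer` and `mlp_forward` helpers and the random initialisation are hypothetical and only for illustration; Haiku uses its own initialisers and manages the parameters for us:

```python
import jax
import jax.numpy as jnp

def init_layer(key, n_in, n_out):
    """Returns a (weights, bias) pair with small random weights."""
    w = jax.random.normal(key, (n_in, n_out)) * 0.1
    return w, jnp.zeros(n_out)

def mlp_forward(params, x):
    """Applies Linear -> ReLU twice, then Linear -> sigmoid."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h = jax.nn.relu(x @ w1 + b1)
    h = jax.nn.relu(h @ w2 + b2)
    return jax.nn.sigmoid(h @ w3 + b3)

keys = jax.random.split(jax.random.PRNGKey(0), 3)
sizes = [(9, 300), (300, 100), (100, 1)]  # 9 input features, as built by preprocess
params = [init_layer(k, n_in, n_out) for k, (n_in, n_out) in zip(keys, sizes)]

batch = jnp.ones((4, 9))  # four made-up passengers
probs = mlp_forward(params, batch)
print(probs.shape)  # (4, 1)
```

Each passenger gets exactly one number between 0 and 1, which we interpret as the predicted probability of survival.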

Now we need to create an instance of our model and an optimizer, for which we will use the Adam algorithm:

In [10]:
network = hk.without_apply_rng(hk.transform(net_fn))
optimiser = optax.adam(1e-3)
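
For the curious, here is a rough sketch of what a single Adam step does to one parameter array.  The `adam_step` helper is hypothetical and written from the textbook description of the algorithm; the default values for `b1`, `b2` and `eps` are assumed to match the usual Adam defaults, and `optax` handles all of this (for every parameter at once) internally:

```python
import jax.numpy as jnp

def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a single array; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * grad          # running mean of the gradients
    v = b2 * v + (1 - b2) * grad ** 2     # running mean of the squared gradients
    m_hat = m / (1 - b1 ** t)             # bias correction for the zero start
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return param, m, v

param = jnp.array([1.0, 1.0])
grad = jnp.array([0.5, -2.0])             # made-up gradient
param, m, v = adam_step(param, grad, jnp.zeros(2), jnp.zeros(2), t=1)
print(param)  # each entry moves by about lr, against the sign of its gradient
```

Notice that on the first step the size of the move is roughly the learning rate regardless of how large the gradient is, which is part of what makes Adam fairly forgiving about feature scaling.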

As the problem we are solving is one of probability, the chance of survival aboard the Titanic, we will use binary cross-entropy loss:

In [11]:
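def binary_cross_entropy(probs: jnp.ndarray, labels: jnp.ndarray, epsilon):
    """Per-example log-likelihood of the labels; `loss` negates and averages it."""
    # `probs` are sigmoid outputs in [0, 1]; epsilon keeps log() away from zero.
    return labels * jnp.log(probs + epsilon) + (1 - labels)*jnp.log(1 - probs + epsilon)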

Because the final sigmoid can saturate to exactly 0 or 1, our loss function adds a small constant `epsilon` inside each logarithm so that we never evaluate `log(0)`.  Strictly speaking this is a numerical-stability guard rather than gradient clipping, but it prevents the gradients from blowing up, a problem that did arise when the network was run without it.  If you are new to machine learning and want to see what happens without the guard, then set `epsilon = 0` below and rerun the notebook - you should notice a problem when running the training loop a few steps later.

In [12]:
def loss(params: hk.Params, features: jnp.ndarray, labels: jnp.ndarray):
    """Mean binary cross-entropy loss; epsilon keeps the logarithms finite."""
    epsilon = 1e-7
    # The network ends in a sigmoid, so its outputs are already probabilities.
    probs = network.apply(params, features)
    return -jnp.mean(binary_cross_entropy(probs, labels, epsilon))
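
To see why the small epsilon matters, the sketch below compares the loss on a saturated prediction with and without it.  It also shows another common way to avoid the problem, which is to compute the loss from the pre-sigmoid values using `jax.nn.log_sigmoid`.  That variant is not what this notebook does (our network applies the sigmoid itself), and the helper names and inputs are made up for illustration:

```python
import jax
import jax.numpy as jnp

def bce_from_probs(probs, labels, epsilon):
    """Mean binary cross-entropy computed from probabilities."""
    log_likelihood = (labels * jnp.log(probs + epsilon)
                      + (1 - labels) * jnp.log(1 - probs + epsilon))
    return -jnp.mean(log_likelihood)

def bce_from_logits(logits, labels):
    """Mean binary cross-entropy computed from pre-sigmoid values."""
    log_p = jax.nn.log_sigmoid(logits)        # log(sigmoid(z))
    log_not_p = jax.nn.log_sigmoid(-logits)   # log(1 - sigmoid(z))
    return -jnp.mean(labels * log_p + (1 - labels) * log_not_p)

labels = jnp.array([1.0, 0.0])
saturated = jnp.array([1.0, 0.0])  # what a fully saturated sigmoid would output

no_guard = bce_from_probs(saturated, labels, epsilon=0.0)     # not finite: 0 * log(0)
with_guard = bce_from_probs(saturated, labels, epsilon=1e-7)  # finite
from_logits = bce_from_logits(jnp.array([40.0, -40.0]), labels)  # finite
print(no_guard, with_guard, from_logits)
```

Even though both predictions are correct, the version without the guard multiplies zero by `log(0)`, and the result is no longer a finite number.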

Then we will create our `evaluate` function, which we will use to keep track of how our model is performing in the training stage:

In [13]:
@jax.jit
def evaluate(params: hk.Params, features: jnp.ndarray, labels: jnp.ndarray):
    """Checks the accuracy of predictions compared to labels."""
    logits = network.apply(params, features)
    predictions = jnp.around(logits, 0)
    return jnp.mean(predictions == labels)
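
As a quick sanity check of the rounding-and-comparing logic, here is the same calculation on four made-up predictions:

```python
import jax.numpy as jnp

probs = jnp.array([[0.9], [0.2], [0.6], [0.4]])   # made-up model outputs
labels = jnp.array([[1], [0], [0], [0]])          # made-up ground truth

predictions = jnp.around(probs, 0)       # 0.9 -> 1, 0.2 -> 0, 0.6 -> 1, 0.4 -> 0
accuracy = jnp.mean(predictions == labels)
print(accuracy)  # 0.75
```

Both arrays have the shape `(4, 1)` here.  This is why we reshaped `titanic_labels` into a column earlier: comparing a column of predictions against a flat vector of labels would broadcast into a square matrix and give a meaningless accuracy.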

We will also need to update our parameters as we train our model, by using `jax.grad` to calculate the gradients of the loss with respect to our parameters and then updating them using `optax`:

In [14]:
@jax.jit
def update(state: TrainingState, features: jnp.ndarray, labels: jnp.ndarray) -> TrainingState:
    """Learning rule (one Adam step on the data passed in)."""
    grads = jax.grad(loss)(state.params, features, labels)
    updates, opt_state = optimiser.update(grads, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    # Compute avg_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    avg_params = optax.incremental_update(params, state.avg_params, step_size=0.001)
    return TrainingState(params, avg_params, opt_state)
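
The moving average deserves a closer look.  As far as I understand it, `optax.incremental_update` blends the new parameters into the old average as `step_size * new + (1 - step_size) * old`.  The sketch below re-implements that formula with a hypothetical `ema_update` helper on a made-up one-weight parameter tree, to show how slowly the average moves with `step_size=0.001`:

```python
import jax
import jax.numpy as jnp

def ema_update(new_params, avg_params, step_size):
    """Moves every leaf of avg_params a fraction step_size towards new_params."""
    return jax.tree_util.tree_map(
        lambda new, avg: step_size * new + (1.0 - step_size) * avg,
        new_params, avg_params)

# Made-up parameter trees holding a single weight each.
avg = {"w": jnp.array(0.0)}
live = {"w": jnp.array(1.0)}

for _ in range(1000):
    avg = ema_update(live, avg, step_size=0.001)

print(avg["w"])  # about 0.63 after 1,000 steps, still well short of 1.0
```

This means the averaged parameters lag well behind the live ones early in training, which is worth keeping in mind when reading the accuracy reported from `avg_params`.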

With a lot of the heavy lifting now done, we can initialize the parameters and optimizer of our model:

In [15]:
# Initialise network and optimiser; note we draw an input to get shapes.
initial_params = network.init(jax.random.PRNGKey(seed=42), titanic_features[0])
initial_opt_state = optimiser.init(initial_params)
state = TrainingState(initial_params, initial_params, initial_opt_state)

And it is finally time to train our model.  Here we will run 10,000 training steps and call our `evaluate` function every 1,000 steps to report our accuracy against the training data:

In [16]:
# Training & evaluation loop.
for step in range(10001):
    if step % 1000 == 0:
        # Periodically evaluate classification accuracy on training set.
        accuracy = np.array(evaluate(state.avg_params, titanic_features, titanic_labels)).item()
        print({"step": step, "accuracy": f"{accuracy:.3f}"})

    # Take one optimiser step using the full training set.
    state = update(state, titanic_features, titanic_labels)

{'step': 0, 'accuracy': '0.616'}
{'step': 1000, 'accuracy': '0.791'}
{'step': 2000, 'accuracy': '0.815'}
{'step': 3000, 'accuracy': '0.840'}
{'step': 4000, 'accuracy': '0.847'}
{'step': 5000, 'accuracy': '0.853'}
{'step': 6000, 'accuracy': '0.860'}
{'step': 7000, 'accuracy': '0.866'}
{'step': 8000, 'accuracy': '0.870'}
{'step': 9000, 'accuracy': '0.870'}
{'step': 10000, 'accuracy': '0.872'}


With our model having been trained on the training data, our last step is to make predictions on the test data and save a file for submission to Kaggle:

In [17]:
predictions = jnp.round(network.apply(state.params, titanic_test), 0).astype(int)
output = pd.DataFrame(data=predictions, columns=['Survived'], index=test['PassengerId'])
output.to_csv('submission.csv')
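
If you want to check what the submission file should look like before uploading it, here is a small sketch with made-up probabilities and passenger IDs.  Unlike the cell above, it thresholds at 0.5 rather than rounding and writes `PassengerId` as a regular column instead of the index, but the resulting CSV has the same two columns:

```python
import numpy as np
import pandas as pd

# Made-up probabilities with shape (n, 1), like the output of the network.
probs = np.array([[0.12], [0.87], [0.55]])
passenger_ids = pd.Series([892, 893, 894], name="PassengerId")

submission = pd.DataFrame({
    "PassengerId": passenger_ids,
    "Survived": (probs.ravel() >= 0.5).astype(int),  # flatten, then threshold
})
print(submission.to_csv(index=False))
```

The file should contain a header row followed by one row per passenger in the test dataset.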

And that is it...for now.  I will be updating this notebook with more detailed explanations (and likely fixing some mistakes, so please let me know if you spot any).