# JAX for NLP

## Installing JAX on GPU

To install JAX with GPU support, we first need to figure out which Python version and CUDA version are installed in our machine.
To do so, we just need to run the following two commands:

In [1]:
!nvcc --version
!python --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243
Python 3.6.6 :: Anaconda, Inc.


According to the command's output, we have the following python and CUDA versions:
- Python 3.6
- CUDA 10.1

Lastly, to install JAX, we have to retrieve its wheels with GPU support from: `https://storage.googleapis.com/jax-releases/\${CUDA_VER}/jaxlib-0.1.39-\${PYTHON_VER}-none-linux_x86_64.whl`

In [2]:
!pip install --upgrade "https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.39-cp36-none-linux_x86_64.whl"
!pip install --upgrade jax 

Collecting jaxlib==0.1.39
  Downloading https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.39-cp36-none-linux_x86_64.whl (65.5 MB)
[K     |████████████████████████████████| 65.5 MB 52 kB/s 
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.1.38
    Uninstalling jaxlib-0.1.38:
      Successfully uninstalled jaxlib-0.1.38
Successfully installed jaxlib-0.1.39
Requirement already up-to-date: jax in /opt/conda/lib/python3.6/site-packages (0.1.59)


We are done here, now we can start with the cool part 😊

In [3]:
import jax
import jax.numpy as np

key = jax.random.PRNGKey(0)

print('JAX is running on', jax.lib.xla_bridge.get_backend().platform)

JAX is running on gpu


## Model explanation

In this notebook, we are going to implement a Natural Language Processing (NLP) model based on "Word Embeddings".   

"Word embeddings" are a family of natural language processing techniques aiming at mapping semantic meaning into a geometric space. This is done by associating a numeric vector to every word in a dictionary, such that the distance (e.g. L2 distance or more commonly cosine distance) between any two vectors would capture part of the semantic relationship between the two associated words. The geometric space formed by these vectors is called an embedding space (Fraçois Chollet, 2016).

Since this competition is a *getting-started* one, we do not have enough data to build a set of robust embeddings, hence we are going to use embeddings that were pretrained on a huge text corpus, more precisely, we are going to use the Global Vectors for Word Representation [(GloVe)](https://nlp.stanford.edu/projects/glove/) embeddings.

On top of the GloVe embeddings, we are going to stack a set of 1D Convolutional layers and finally, a pair of linear or dense layers to do the classification. 

But wait... Does JAX provide all of those features to work with NLP? For now, it doesn't. 

Various python packages provide extra Deep Learning features over JAX, for example, the most popular packages nowadays are FLAX, Haiku, and STAX.

In this notebook we are not going to use any extra package apart from JAX, we are going to implement all the needed Layers from scratch.

## Data preprocess

Before the implementation, we are going to take a quick look to the data and build our vocabulary.

In [4]:
!ls ../input/nlp-getting-started

sample_submission.csv  test.csv  train.csv


In [5]:
import pandas as pd

df = pd.read_csv('../input/nlp-getting-started/train.csv', encoding='utf-8')
test_df = pd.read_csv('../input/nlp-getting-started/test.csv', encoding='utf-8')

df.head()

Unnamed: 0,id,keyword,location,text,target
0,1,,,Our Deeds are the Reason of this #earthquake M...,1
1,4,,,Forest fire near La Ronge Sask. Canada,1
2,5,,,All residents asked to 'shelter in place' are ...,1
3,6,,,"13,000 people receive #wildfires evacuation or...",1
4,7,,,Just got sent this photo from Ruby #Alaska as ...,1


Now we are going to clean the data applying these simple steps:

1. Remove URLs
2. Remove HTML Tags
3. Remove Emojis
4. Remove punctuation
5. Lowercase all text

In [6]:
import re
import string

def clean_tweet(tweet: str) -> str:
    url = re.compile(r'https?://\S+|www\.\S+')
    tweet = url.sub(r'',tweet)
    
    html = re.compile(r'<.*?>')
    tweet = html.sub(r'', tweet)
    
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    tweet = emoji_pattern.sub(r'', tweet)
    
    tweet = re.sub('([.,!?()#])', r' \1 ', tweet)
    tweet = re.sub('\s{2,}', ' ', tweet)

    return tweet.lower()

df['text'] = df['text'].apply(clean_tweet)
test_df['text'] = test_df['text'].apply(clean_tweet)

df.head()

Unnamed: 0,id,keyword,location,text,target
0,1,,,our deeds are the reason of this # earthquake ...,1
1,4,,,forest fire near la ronge sask . canada,1
2,5,,,all residents asked to 'shelter in place' are ...,1
3,6,,,"13 , 000 people receive # wildfires evacuation...",1
4,7,,,just got sent this photo from ruby # alaska as...,1


We are going to use Keras text preprocessing tools to create our vocabulary. This will help us create the vocabulary and also pad the sequences in an intuitive manner.

In [7]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

MAX_SEQUENCE_LENGTH = 64

tokenizer = Tokenizer(num_words=None, filters='')
tokenizer.fit_on_texts(df.text.tolist() + test_df.text.tolist())

train_sequences = tokenizer.texts_to_sequences(df.text.tolist())
train_sequences = pad_sequences(train_sequences, maxlen=MAX_SEQUENCE_LENGTH)

test_sequences = tokenizer.texts_to_sequences(test_df.text.tolist())
test_sequences = pad_sequences(test_sequences, maxlen=MAX_SEQUENCE_LENGTH)

word2idx = tokenizer.word_index

Using TensorFlow backend.


In [8]:
list(word2idx.items())[:10]

[('.', 1),
 ('#', 2),
 ('the', 3),
 ('?', 4),
 ('a', 5),
 ('to', 6),
 ('in', 7),
 ('of', 8),
 ('and', 9),
 ('i', 10)]

## Layers implementation

Before starting the layers implementation, it is a good practice to define a common abstraction for all our modules. You can think of Keras using the superclass `Layer` and PyTorch using the `nn.Module` class.
 For simplicity our superclass will be a named tuple called `JaxModule`.

In [9]:
from typing import Any, Callable, Container, NamedTuple

Parameter = Container[np.ndarray]
ForwardFn = Callable[[Parameter, np.ndarray, Any], np.ndarray]

class JaxModule(NamedTuple):
    # Dict or list contining the layer parameters
    parameters: Parameter
    
    # How we operate with parameters to generate an output
    forward_fn: ForwardFn
    
    def update(self, new_parameters) -> 'JaxModule':
        # As tuples are immutable, we create a new jax module keeping the
        # forward_fn but with new parameters
        return JaxModule(new_parameters, self.forward_fn)
    
    def __call__(self, x: np.ndarray, **kwargs) -> np.ndarray:
        # Automatically injects parameters
        return self.forward_fn(self.parameters, x, **kwargs)

To understand how we are going to use the `JaxModule` along the implementations, let's see how a simple linear layer would be implemented.

As you may know, a linear layer follows the below equation:

$ f(x) = W · x + b $ where W and b are trainable parameters.

In [10]:
def linear(key: np.ndarray, 
           in_features: int, 
           out_features: int, 
           activation = lambda x: x) -> 'JaxModule':
    x_init = jax.nn.initializers.xavier_normal()
    u_init = jax.nn.initializers.uniform()

    key, W_key, b_key = jax.random.split(key, 3)
    W = x_init(W_key, shape=(in_features, out_features))
    b = u_init(b_key, shape=(out_features,))
    params = dict(W=W, bias=b)
    
    def forward_fn(params: Parameter, x: np.ndarray, **kwargs):
        return np.dot(x, params['W']) + params['bias']
    
    return JaxModule(parameters=params, forward_fn=forward_fn)

def flatten(key: np.ndarray) -> 'JaxModule':
    # Reshapes the data to have 2 dims [BATCH, FEATURES]
    def forward_fn(params, x, **kwargs):
        bs = x.shape[0]
        return x.reshape(bs, -1)
    return JaxModule({}, forward_fn)


key, subkey, linear_key = jax.random.split(key, 3)

# Input vector of shape [BATCH, FEATURES]
mock_in = jax.random.uniform(subkey, shape=(10, 32))
linear_layer = linear(linear_key, 32, 512)

print('Input shape:', mock_in.shape, 
      'Output shape after linear:', linear_layer(mock_in).shape)

Input shape: (10, 32) Output shape after linear: (10, 512)


Pretty easy, isn't it? Let's move to Convolutional 1D implementation. 

This one is going to be a little bit more tricky. This is because we will get a bit deeper and use with a JAX primitive. A JAX primitive it is a function that directly wraps an XLA operation. 

So, to develop the 1D Conv, we are going to need the primitive `conv_general_dilated`. We won't get into details on how to use this primitive here. If you are interested in it, I recommend you reading [this section](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions) of JAX documentation.

In [11]:
def conv_1d(key: np.ndarray, 
            in_features: int, 
            out_features: int,
            kernel_size: int,
            strides: int = 1,
            padding: str = 'SAME',
            activation = lambda x: x) -> 'JaxModule':
    
    # [KERNEL_WIDTH, IN_FEATURES, OUT_FEATURES]
    kernel_shape = (kernel_size, in_features, out_features)
    # [BATCH, WIDTH, IN_FEATURES]
    seq_shape = (None, None, in_features)
    
    # Declare convolutional specs
    dn = jax.lax.conv_dimension_numbers(
        seq_shape, kernel_shape, ('NWC', 'WIO', 'NWC'))
    
    key, k_key, b_key = jax.random.split(key, 3)
    
    kernel = jax.nn.initializers.glorot_normal()(k_key, shape=kernel_shape)
    b = jax.nn.initializers.uniform()(b_key, shape=(out_features,))
    params = dict(kernel=kernel, bias=b)
    
    def forward_fn(params: Parameter, x: np.ndarray, **kwargs):
        return activation(jax.lax.conv_general_dilated(
            x, params['kernel'], 
            (strides,), padding, 
            (1,), (1,), dimension_numbers=dn) + params['bias'])

    return JaxModule(params, forward_fn)


key, subkey, conv_key = jax.random.split(key, 3)

# Input vector of shape [BATCH, FEATURES]
mock_in = jax.random.uniform(subkey, shape=(10, 128, 32))
conv = conv_1d(conv_key, 32, 512, kernel_size=3, strides=2)

print('Input shape:', mock_in.shape, 
      'Output shape after 1D Convolution:', conv(mock_in).shape)

Input shape: (10, 128, 32) Output shape after 1D Convolution: (10, 64, 512)


We are almost done. We only have the Embedding layer implementation left. Believe me when I say that this implementation is the simplest that we are going to see today. 

Actually, as we won't train embeddings from scratch, we don't really need a JaxModule initialized with random weights. What we are going to do, is just load the embeddings from a GloVe file and create a matrix where each row corresponds to one embedding of our vocabulary.

Surisingly, NLP Stanford GloVe webpage has a set of embedding pretrained on Twitter text corpus available to download, which is going to be pretty handy for this competion as we are dealing with tweets classification.

In [12]:
!ls ../input

glove-twitter-27b-50d  nlp-getting-started


In [13]:
import numpy as onp # Use onp to avoid overheat of moving each embedding to GPU

embeddings_index = {}
embedding_matrix = onp.zeros((len(word2idx) + 1, 50))

f = open('../input/glove-twitter-27b-50d/glove.twitter.27B.50d.txt')

for line in f:
    values = line.split()
    word = values[0]
    coefs = onp.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

for word, i in word2idx.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        embedding_matrix[i] = embedding_vector

embeddings = jax.device_put(embedding_matrix)
print('Embedding for "hello" is:', embeddings[word2idx['hello']])

Embedding for "hello" is: [ 0.28751    0.31323   -0.29318    0.17199   -0.69232   -0.4593
  1.3364     0.709      0.12118    0.11476   -0.48505   -0.088608
 -3.0154    -0.54024   -1.326      0.39477    0.11755   -0.17816
 -0.32272    0.21715    0.043144  -0.43666   -0.55857   -0.47601
 -0.095172   0.0031934  0.1192    -0.23643    1.3234    -0.45093
 -0.65837   -0.13865    0.22145   -0.35806    0.20988    0.054894
 -0.080322   0.48942    0.19206    0.4556    -1.642     -0.83323
 -0.12974    0.96514   -0.18214    0.37733   -0.19622   -0.12231
 -0.10496    0.45388  ]


We create a helper function to index our embedding matrix with a sequence.

In [14]:
def embed_sequence(sequence: np.ndarray) -> np.ndarray:
    return embeddings[sequence.astype('int32')]

key, subkey, embedding_key = jax.random.split(key, 3)

# Input vector of shape [BATCH, FEATURES]
mock_in = jax.random.randint(subkey, 
                             minval=0,
                             maxval=512,
                             shape=(10, 128))

print('Input shape:', mock_in.shape, 
      'Output shape after Embeddings:', embed_sequence(mock_in).shape)

Input shape: (10, 128) Output shape after Embeddings: (10, 128, 50)


Bfff... was a long journey, wasn't it? 😩 But that's all we need to do for our implementation, let's move to model training.

## Training a model

First of all we have to declare our model. To do so, we are going to use a similar API interface as the nn.Sequential defined in PyTorch or tf.keras.models.Sequential defined in Keras.

In [15]:
from typing import Sequence
from functools import partial

# Partially evaluated layers without the random key
PartialLayer = Callable[[np.ndarray], JaxModule]

def sequential(key: np.ndarray, *modules: Sequence[PartialLayer]) -> JaxModule:
    key, *subkeys = jax.random.split(key, len(modules) + 1)
    model = [m(k) for k, m in zip(subkeys, modules)]
    
    def forward_fn(params, x, **kwargs):
        for m, p in zip(model, params):
            x = m.forward_fn(p, x, **kwargs)
        return x
    
    return JaxModule([m.parameters for m in model], forward_fn)

mock_model = sequential(
    key,
    partial(conv_1d, 
            in_features=50, out_features=256, 
            kernel_size=5, strides=2, activation=jax.nn.relu),
    flatten,
    partial(linear, 
            in_features=16384, out_features=128, 
            activation=jax.nn.relu),
    partial(linear, 
            in_features=128, out_features=1, 
            activation=jax.nn.sigmoid))

embedded_in = embed_sequence(mock_in)

print('Input shape:', mock_in.shape, 
      'Output shape after model:', mock_model(embedded_in).shape)

Input shape: (10, 128) Output shape after model: (10, 1)


With JAX we can easily compute the model's weights gradients with respect an error function. Since we are dealing with a binary classification, we are going to implement the binary cross entropy loss (BCE loss). The BCE Loss follows the equation described below:

$ bce = - (y * log(\hat{y}) + (1 - y) * log(1 - \hat{y})) $

In [16]:
def bce(y_hat: np.ndarray, y: np.ndarray) -> float:
    y_hat = y_hat.reshape(-1)
    y = y.reshape(-1)
    y_hat = np.clip(y_hat, 1e-6, 1 - 1e-6)
    pt = np.where(y == 1, y_hat, 1 - y_hat)
    loss = -np.log(pt)
    return np.mean(loss)


def create_backward(model: JaxModule):
    # Backward is just a function that receives the model parameters and combines them
    # using the forward_fn. This will allow as to compute the partial derivate of the 
    # weights with respect to the loss
    def backward(params: Sequence[Parameter], 
                 x: np.ndarray, y: np.ndarray) -> float:
        y_hat = model.forward_fn(params, x)
        return bce(y_hat, y)
    return backward

# Compile and differentiate the function
backward_fn = jax.jit(jax.value_and_grad(create_backward(mock_model)))

Once we are capable of computing the gradients with respect the loss, we need an optimizer to update our parameters in the correct direction so we can minimize the loss. For simplicity, we implement a basic stocastic gradient descent (SGD). The SGD update rule is as follows:

`new_param = old_param - learning_rate * grad`

In [17]:
LEARNING_RATE = 1e-4

def optimizer_step(params: Sequence[Parameter], 
                   gradients: Sequence[Parameter]) -> Sequence[Parameter]:
    def optim_single(param, grad):
        for p in param:
            param[p] = param[p] - LEARNING_RATE * grad[p]
        return param

    return [optim_single(p, g) for p, g in zip(params, gradients)]

Let's put it all together and try a single update on our model and see if the parameters get updated.

In [18]:
# Mock labels
y_trues = np.array([0, 1, 1, 0, 0, 1, 0, 0, 1, 0])

loss, grads = backward_fn(mock_model.parameters, embedded_in, y_trues)
print('Loss:', loss)

# Update the parameters with the obtained gradients
new_params = optimizer_step(mock_model.parameters, grads)
model = mock_model.update(new_params)

loss, grads = backward_fn(mock_model.parameters, embedded_in, y_trues)
print('Loss after one step:', loss)

Loss: 3.1141462
Loss after one step: 1.7008412


Now we are ready to do what we already know how to do:

1. Split the data to compute the metrics on a validation set and ensure that we are not overfitting the dataset
2. Create a training loop and update the model at each training step
3. Create the submission file

In [19]:
from sklearn.model_selection import train_test_split

x = train_sequences
y = df.target.values

x_train, x_val, y_train, y_val = train_test_split(x, y, train_size=.9)

x_train = jax.device_put(x_train)
x_val = jax.device_put(x_val)
y_train = jax.device_put(y_train)
y_val = jax.device_put(y_val)

We define the final model. The one that should work... 😯

> The aim of this kernel is not to provide the best solution, it is just to guide and encourage you to use JAX and see how flexible it is.

In [20]:
model = sequential(
    key,
    partial(conv_1d, in_features=50, out_features=128,  kernel_size=7),
    partial(conv_1d, in_features=128, out_features=128, kernel_size=7, strides=2, activation=jax.nn.relu),
    
    partial(conv_1d, in_features=128, out_features=256, kernel_size=5),
    partial(conv_1d, in_features=256, out_features=256, kernel_size=5, strides=2, activation=jax.nn.relu),
    
    flatten,
    partial(linear, in_features=4096, out_features=128, activation=jax.nn.relu),
    partial(linear, in_features=128, out_features=1, activation=jax.nn.sigmoid))

backward_fn = jax.jit(jax.value_and_grad(create_backward(model)))

We define the train and validation steps.

In [21]:
from sklearn.metrics import accuracy_score, f1_score, recall_score

EPOCHS = 20
BATCH_SIZE = 32
train_steps = x_train.shape[0] // BATCH_SIZE

for epoch in range(EPOCHS):
    
    running_loss = 0.0
    
    for step in range(train_steps):        
        key, subkey = jax.random.split(key)
        batch_idx = jax.random.randint(subkey, 
                                       minval=0, 
                                       maxval=x_train.shape[0],
                                       shape=(BATCH_SIZE,))
        
        x_batch = embed_sequence(x_train[batch_idx])
        y_batch = y_train[batch_idx]
        
        loss, grads = backward_fn(model.parameters, x_batch, y_batch)
        running_loss += loss
        model = model.update(optimizer_step(model.parameters, grads))
        
    loss_mean = running_loss / float(step)
    print(f'Epoch[{epoch}] loss: {loss_mean:.4f}')
        
    predictions = model(embed_sequence(x_val))
    loss = bce(predictions, y_val)
    predictions = predictions > .5
    
    print('-- Validation --')
    print('Loss: {}'.format(loss))
    print('Accuracy: {}'.format(accuracy_score(y_val, predictions)))
    print('Recall: {}'.format(recall_score(y_val, predictions)))
    print('F1-Score: {}'.format(f1_score(y_val, predictions)))

Epoch[0] loss: 0.7677
-- Validation --
Loss: 0.6951687932014465
Accuracy: 0.610236220472441
Recall: 0.3484848484848485
F1-Score: 0.4364326375711575
Epoch[1] loss: 0.6686
-- Validation --
Loss: 0.6256797313690186
Accuracy: 0.6469816272965879
Recall: 0.44242424242424244
F1-Score: 0.5204991087344027
Epoch[2] loss: 0.6492
-- Validation --
Loss: 0.6162245869636536
Accuracy: 0.6745406824146981
Recall: 0.48787878787878786
F1-Score: 0.5649122807017543
Epoch[3] loss: 0.6260
-- Validation --
Loss: 0.6151533126831055
Accuracy: 0.6902887139107612
Recall: 0.43636363636363634
F1-Score: 0.549618320610687
Epoch[4] loss: 0.6162
-- Validation --
Loss: 0.6123585104942322
Accuracy: 0.7073490813648294
Recall: 0.4696969696969697
F1-Score: 0.5816135084427768
Epoch[5] loss: 0.5930
-- Validation --
Loss: 0.5699965357780457
Accuracy: 0.7125984251968503
Recall: 0.5696969696969697
F1-Score: 0.6319327731092437
Epoch[6] loss: 0.5895
-- Validation --
Loss: 0.55873042345047
Accuracy: 0.7257217847769029
Recall: 0.6
F1

Finally, we create the submission file. 🎇🎆

In [22]:
ids = test_df.id
targets = (model(embed_sequence(test_sequences)) > .4).reshape(-1)
pd.DataFrame(dict(id=ids, target=targets.astype('int32'))).to_csv('submission.csv', index=False)