<a href="https://colab.research.google.com/github/JuheonChu/Natural-Language-Processing/blob/main/techspike_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvcc --version
!python --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Feb_14_21:12:58_PST_2021
Cuda compilation tools, release 11.2, V11.2.152
Build cuda_11.2.r11.2/compiler.29618528_0
Python 3.7.15


In [None]:
!pip install --upgrade "https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.39-cp36-none-linux_x86_64.whl"
!pip install --upgrade jax 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[31mERROR: jaxlib-0.1.39-cp36-none-linux_x86_64.whl is not a supported wheel on this platform.[0m
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Set up Colab to run on TPU (avoid Google Cloud TPU install command "jax[tpu]"")
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

import jax
import jax.numpy as np

key = jax.random.PRNGKey(0)

print('JAX is running on', jax.lib.xla_bridge.get_backend().platform)

JAX is running on gpu


In [None]:
ls ./input/nlp-getting-started/


ls: cannot access './input/nlp-getting-started/': No such file or directory


**1) Data preprocess**



*   Read data


In [None]:
import pandas as pd

df = pd.read_csv('./input/nlp-getting-started/train.csv', encoding='utf-8')
test_df = pd.read_csv('./input/nlp-getting-started/test.csv', encoding='utf-8')

# pandas method for returning the first n rows (5 by default
df.head()


FileNotFoundError: ignored




*   Clean up data





In [None]:
import re
import string

def clean_tweet(tweet: str) -> str:
    # Remove URLs
    url = re.compile(r'https?://\S+|www\.\S+')
    tweet = url.sub(r'',tweet)
    
    # Remove HTML Tags
    html = re.compile(r'<.*?>')
    tweet = html.sub(r'', tweet)
    
    # Remove Emojis
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    tweet = emoji_pattern.sub(r'', tweet)
    
    # Remove punctuation
    tweet = re.sub('([.,!?()#])', r' \1 ', tweet)
    tweet = re.sub('\s{2,}', ' ', tweet)

    # Lowercase all text
    return tweet.lower()

df['text'] = df['text'].apply(clean_tweet)
test_df['text'] = test_df['text'].apply(clean_tweet)

df.head()



*   Create our vocabulary using Keras text preprocessing tools (built on TensorFlow backend)



In [None]:
from keras.preprocessing.text import Tokenizer
from keras_preprocessing.sequence import pad_sequences

MAX_SEQUENCE_LENGTH = 64

tokenizer = Tokenizer(num_words=None, filters='')
tokenizer.fit_on_texts(df.text.tolist() + test_df.text.tolist())

train_sequences = tokenizer.texts_to_sequences(df.text.tolist())
train_sequences = pad_sequences(train_sequences, maxlen=MAX_SEQUENCE_LENGTH)

test_sequences = tokenizer.texts_to_sequences(test_df.text.tolist())
test_sequences = pad_sequences(test_sequences, maxlen=MAX_SEQUENCE_LENGTH)

word2idx = tokenizer.word_index

In [None]:
list(word2idx.items())[:10]

**2) Layers implementation (from scratch)**
*   Superclass `Layer` --> Keras
*   `nn.Module` --> PyTorch
* Superclass `JaxModule`
* Convolutional 1D implementation



In [None]:
import jax
import jax.numpy as np
from typing import Any, Callable, Container, NamedTuple

Parameter = Container[np.ndarray]
ForwardFn = Callable[[Parameter, np.ndarray, Any], np.ndarray]

key = jax.random.PRNGKey(0)

class JaxModule(NamedTuple):
    # Dict or list contining the layer parameters
    parameters: Parameter
    
    # How we operate with parameters to generate an output
    forward_fn: ForwardFn
    
    def update(self, new_parameters) -> 'JaxModule':
        # As tuples are immutable, we create a new jax module keeping the
        # forward_fn but with new parameters
        return JaxModule(new_parameters, self.forward_fn)
    
    def __call__(self, x: np.ndarray, **kwargs) -> np.ndarray:
        # Automatically injects parameters
        return self.forward_fn(self.parameters, x, **kwargs)

    # Create a linear layer
    def linear(key: np.ndarray, 
              in_features: int, 
              out_features: int, 
              activation = lambda x: x) -> 'JaxModule':
        x_init = jax.nn.initializers.xavier_normal()
        u_init = jax.nn.initializers.uniform()

        key, W_key, b_key = jax.random.split(key, 3)
        W = x_init(W_key, shape=(in_features, out_features))
        b = u_init(b_key, shape=(out_features,))
        params = dict(W=W, bias=b)
        
        def forward_fn(params: Parameter, x: np.ndarray, **kwargs):
            return np.dot(x, params['W']) + params['bias']
        
        return JaxModule(parameters=params, forward_fn=forward_fn)

    def flatten(key: np.ndarray) -> 'JaxModule':
        # Reshapes the data to have 2 dims [BATCH, FEATURES]
        def forward_fn(params, x, **kwargs):
            bs = x.shape[0]
            return x.reshape(bs, -1)
        return JaxModule({}, forward_fn)

    # Create a Convolutional 1D layer
    def conv_1d(key: np.ndarray, 
            in_features: int, 
            out_features: int,
            kernel_size: int,
            strides: int = 1,
            padding: str = 'SAME',
            activation = lambda x: x) -> 'JaxModule':
    
        # [KERNEL_WIDTH, IN_FEATURES, OUT_FEATURES]
        kernel_shape = (kernel_size, in_features, out_features)
        # [BATCH, WIDTH, IN_FEATURES]
        seq_shape = (None, None, in_features)
        
        # Declare convolutional specs
        dn = jax.lax.conv_dimension_numbers(
            seq_shape, kernel_shape, ('NWC', 'WIO', 'NWC'))
        
        key, k_key, b_key = jax.random.split(key, 3)
        
        kernel = jax.nn.initializers.glorot_normal()(k_key, shape=kernel_shape)
        b = jax.nn.initializers.uniform()(b_key, shape=(out_features,))
        params = dict(kernel=kernel, bias=b)
        
        def forward_fn(params: Parameter, x: np.ndarray, **kwargs):
            return activation(jax.lax.conv_general_dilated(
                x, params['kernel'], 
                (strides,), padding, 
                (1,), (1,), dimension_numbers=dn) + params['bias'])

        return JaxModule(params, forward_fn)


    key, subkey, linear_key = jax.random.split(key, 3)

    # Input vector of shape [BATCH, FEATURES]
    mock_in = jax.random.uniform(subkey, shape=(10, 32))
    linear_layer = linear(linear_key, 32, 512)

    print('Input shape:', mock_in.shape, 
          'Output shape after linear:', linear_layer(mock_in).shape)

##########################################

    key, subkey, conv_key = jax.random.split(key, 3)

    # Input vector of shape [BATCH, FEATURES]
    mock_in = jax.random.uniform(subkey, shape=(10, 128, 32))
    conv = conv_1d(conv_key, 32, 512, kernel_size=3, strides=2)

    print('Input shape:', mock_in.shape, 
          'Output shape after 1D Convolution:', conv(mock_in).shape)


In [None]:
!ls

sample_data


In [12]:
!git clone https://github.com/JuheonChu/Natural-Language-Processing.git

Cloning into 'Natural-Language-Processing'...
remote: Enumerating objects: 351, done.[K
remote: Counting objects: 100% (256/256), done.[K
remote: Compressing objects: 100% (143/143), done.[K
remote: Total 351 (delta 125), reused 159 (delta 76), pack-reused 95[K
Receiving objects: 100% (351/351), 78.37 KiB | 2.12 MiB/s, done.
Resolving deltas: 100% (133/133), done.


In [27]:
!cd Natural-Language-Processing
!ls 
!cd Natural-Language-Processing
!ls
!pwd
!cd Natural-Language-Processing


Natural-Language-Processing  sample_data
Natural-Language-Processing  sample_data
/content
