## Running MXNet BERT-Base on Inf1 instances

BERT (Bidirectional Encoder Representations from Transformers) is a Google Research project published in 2018 ([https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805)). BERT has a number of practical applications: it can be used for question answering, sequence prediction, and sequence classification, among other tasks.

This tutorial walks you through modifying and compiling BERT-Base (L-12 H-768 A-12), with a sequence length of 64 and a batch size of 8, to run on AWS Inferentia (Inf1 instances).
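
GluonNLP names its BERT models after this configuration: `bert_12_768_12` is BERT-Base (used below) and `bert_24_1024_16` is BERT-Large, which the notebook's `num_units` logic also mentions. Two values derived from the name matter later: the hidden size (`num_units`) and the per-head size that `div_sqrt_dim` divides by (768 / 12 = 64). The helper below is a hypothetical sketch that assumes names of the form `bert_<layers>_<units>_<heads>`; it is not part of GluonNLP.

```python
def parse_bert_name(model_name):
    """Split a name like 'bert_12_768_12' into (layers, units, heads, head_dim).

    Hypothetical helper; assumes the 'bert_<L>_<H>_<A>' naming pattern.
    """
    prefix, layers, units, heads = model_name.split("_")
    if prefix != "bert":
        raise ValueError("unexpected model name: %s" % model_name)
    layers, units, heads = int(layers), int(units), int(heads)
    if units % heads != 0:
        raise ValueError("hidden size must be divisible by the number of heads")
    # Per-head size is the value whose square root scales the attention scores
    return layers, units, heads, units // heads


print(parse_bert_name("bert_12_768_12"))   # (12, 768, 12, 64)
print(parse_bert_name("bert_24_1024_16"))  # (24, 1024, 16, 64)
```

Both configurations end up with a per-head size of 64, which is why `div_sqrt_dim` divides by the same value for either model.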

### PRE-REQUISITES
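Of the Python modules used in this notebook (`wget`, `os`, `shutil`, `sys`), only `wget` needs to be installed with pip; `os`, `shutil`, and `sys` are part of the Python standard library. The notebook also needs `git` (to clone gluon-nlp) and an MXNet installation with Neuron support, since it uses `mx.contrib.neuron.compile` and the `mx.neuron()` context.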
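
Before running the cells, a quick check along these lines can report what is missing. This is a minimal sketch: the helper name `check_prerequisites` and its default lists are illustrative assumptions, and finding an importable `mxnet` does not by itself confirm that the build includes Neuron support.

```python
import importlib.util
import shutil


def check_prerequisites(modules=("wget", "mxnet"), tools=("git",)):
    """Return (missing_modules, missing_tools). Hypothetical helper."""
    # find_spec returns None when a top-level module cannot be found
    missing_modules = [m for m in modules if importlib.util.find_spec(m) is None]
    # shutil.which returns None when an executable is not on PATH
    missing_tools = [t for t in tools if shutil.which(t) is None]
    return missing_modules, missing_tools


missing_modules, missing_tools = check_prerequisites()
if missing_modules or missing_tools:
    print("Missing Python modules:", missing_modules)
    print("Missing command-line tools:", missing_tools)
else:
    print("All prerequisites found.")
```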

### SETUP
We used publicly available instructions to generate a saved model for open-source BERT using fine-tuned SST-2 weights. The steps to generate this model can be found [here](https://gluon-nlp.mxnet.io/v0.9.x/model_zoo/bert/index.html#sentence-classification), or you can download a model already trained on SST-2 from [here](https://dist-bert.s3.amazonaws.com/demo/finetune/sst.params). The saved model goes in a directory named "gluonnlp_bert" under the bert_demo directory (this notebook is assumed to be inside the bert_demo directory). The gluon-nlp package must also be downloaded and added to the system path so that it can be imported as a Python module. The cell below does both: it clones gluon-nlp at a fixed commit, adds its `src/` directory to `sys.path`, and downloads `sst.params` into `gluonnlp_bert`.

```python
import os
import shutil
import sys
import wget

# Clone a copy of gluon-nlp and check out commit 184a0007bc4165d5fe080a58dd3ff9bb413203a6
if os.path.isdir('gluon-nlp'):
    print("Removing Gluon-nlp... ")
    shutil.rmtree('gluon-nlp')
ret = os.system("git clone https://github.com/dmlc/gluon-nlp && cd gluon-nlp && "
                "git checkout 184a0007bc4165d5fe080a58dd3ff9bb413203a6")
assert ret == 0, "Cloning or checking out gluon-nlp failed"
p = 'gluon-nlp/src/'
sys.path.insert(0, p)

# Download a copy of sst.params
if os.path.isdir('gluonnlp_bert'):
    print("Removing existing bert params... ")
    shutil.rmtree('gluonnlp_bert')
os.mkdir('gluonnlp_bert')
print('Beginning download of sst params...')
url = 'https://dist-bert.s3.amazonaws.com/demo/finetune/sst.params'
wget.download(url, 'gluonnlp_bert/sst.params')

# Remove output_dir if present from previous runs
if os.path.isdir('output_dir'):
    print("Removing existing output_dir... ")
    shutil.rmtree('output_dir')
os.mkdir('output_dir')
print('Download of all necessary files complete.')
```

```text
Removing Gluon-nlp... 
Removing existing bert params... 
Beginning download of sst params...
Removing existing output_dir... 
Download of all necessary files complete.
```


### MODIFYING BERT FOR INFERENTIA

Create an instance of the BERT classifier from GluonNLP and load the fine-tuned SST-2 parameters into it. The following cells then modify it to work on Inferentia.

```python
import mxnet as mx
import gluonnlp as nlp
from gluonnlp.model import get_model, BERTClassifier

nlp.utils.check_version('0.8.1')

# BERT model design parameters
seq_length = 64
model_parameters = 'gluonnlp_bert/sst.params'
model_name = 'bert_12_768_12'
dataset_name = 'book_corpus_wiki_en_uncased'
output_dir = 'output_dir'
batch_size = 8
dropout = 0.1
num_units = 1024 if model_name == 'bert_24_1024_16' else 768

# Create an instance of the BERT classifier
bert, _ = get_model(
    name=model_name,
    dataset_name=dataset_name,
    pretrained=False,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False,
    dropout=0.0)
net = BERTClassifier(bert, num_classes=2, dropout=dropout)

# Load the parameters from the downloaded sst.params file and hybridize the network
net.load_parameters(model_parameters)
net.hybridize(static_alloc=True, static_shape=True)
```

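Next, we make the modifications needed to make the BERT classifier compatible with Inferentia and to extract maximum performance from the hardware. The following modifications are made:
1. Dropout is removed from the inference graph.
2. Embedding lookup and processing are removed from the network and done on the CPU.
3. The mask used when the sequence length is less than the maximum sequence length is also generated on the CPU and fed into Inferentia as an input tensor.
4. Several operators (GELU, LayerNorm, `where`, `broadcast_axis`, `arange_like`, and `div_sqrt_dim`) are rewritten in terms of elementary arithmetic and broadcast operators; GELU uses its tanh approximation.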
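
Some of the rewrites in the next cell are elementwise identities or approximations that can be checked outside MXNet. The sketch below uses plain Python floats and lists as stand-ins for tensors (an assumption made only to keep it self-contained). It compares the tanh-based GELU used by the patched `gelu` with the exact erf-based definition, and shows that the arithmetic form of the patched `where` selects `x` where the condition is 1 and `y` where it is 0. That arithmetic form is only equivalent to a true `where` for 0/1 conditions.

```python
import math


def gelu_exact(x):
    # GELU(x) = x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Same tanh approximation as the patched gelu() in the next cell
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def where_arith(condition, x, y):
    # Same arithmetic as the patched where(); only correct for 0/1 conditions
    return [(c == 0) * b + ((1 - c) == 0) * a for c, a, b in zip(condition, x, y)]


# Largest absolute difference between exact and approximate GELU on [-6, 6]
grid = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in grid)
print('max |GELU_exact - GELU_tanh| on [-6, 6]:', max_err)

# where_arith picks x where condition == 1 and y where condition == 0
print(where_arith([1, 0, 1, 0], [10.0, 20.0, 30.0, 40.0], [-1.0, -2.0, -3.0, -4.0]))
# [10.0, -2.0, 30.0, -4.0]
```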

```python
import math

# Symbol namespace used when the hybridized graph is traced
f = mx.sym

# broadcast_axis along axis 1, expressed as a broadcast division by ones
def broadcast_axis(data=None, axis=None, size=None, out=None, name=None, **kwargs):
    assert axis == 1
    ones = f.ones((1, size, 1, 1))
    out = f.broadcast_div(data, ones)
    return out

# Divide by the square root of the per-head size (1024 / 16 or 768 / 12)
def div_sqrt_dim(data=None, out=None, name=None, **kwargs):
    assert '1024' in model_name or '768' in model_name
    units = 1024 / 16 if '1024' in model_name else 768 / 12
    return data / math.sqrt(units)

# Embedding lookup is done on the CPU (see pre_process below), so inside the
# graph the embedding op passes its input through unchanged
def embedding_op(data=None, weight=None, input_dim=None, output_dim=None, dtype=None,
                 sparse_grad=None, out=None, name=None, batch_mode=True, **kwargs):
    return data

def embedding(self, F, x, weight):
    out = embedding_op(x, weight, name='fwd', batch_mode=False, **self._kwargs)
    return out

# GELU using its tanh approximation
def gelu(self, F, x):
    return 0.5 * x * (1 + F.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3))))

# LayerNorm written with elementary broadcast operators
def layer_norm(self, F, data, gamma, beta):
    mean = data.mean(axis=self._axis, keepdims=True)
    delta = F.broadcast_sub(data, mean)
    var = (delta ** 2).mean(axis=self._axis, keepdims=True)
    X_hat = F.broadcast_div(delta, var.sqrt() + self._epsilon)
    return F.broadcast_add(F.broadcast_mul(gamma, X_hat), beta)

def arange_like(x, axis):
    arange = f.arange(start=0, repeat=num_units, step=1, stop=seq_length, dtype='float32')
    return arange

# where() for a 0/1 condition: x where condition == 1, y where condition == 0
def where(condition=None, x=None, y=None, name=None, attr=None, out=None, **kwargs):
    return (condition == 0) * y + ((1 - condition) == 0) * x

# Dropout is an identity at inference time
def dropout(data=None, p=None, mode=None, axes=None, cudnn_off=None, out=None, name=None, **kwargs):
    return data

def bert_model___call__(self, inputs, valid_length=None, mask=None, masked_positions=None):
    # pylint: disable=dangerous-default-value, arguments-differ
    """Generate the representation given the inputs.
    This is used in training or fine-tuning a BERT model.
    """
    return super(nlp.model.BERTModel, self).__call__(inputs, valid_length, mask, masked_positions)

def bert_model_hybrid_forward(self, F, inputs, valid_length=None, mask=None, masked_positions=None): # abi
    # pylint: disable=arguments-differ
    """Generate the representation given the inputs.
    This is used in training or fine-tuning a BERT model.
    """
    outputs = []
    seq_out, attention_out = self._encode_sequence(inputs, valid_length, mask)
    outputs.append(seq_out)

    if self.encoder._output_all_encodings:
        assert isinstance(seq_out, list)
        output = seq_out[-1]
    else:
        output = seq_out

    if attention_out:
        outputs.append(attention_out)

    if self._use_pooler:
        pooled_out = self._apply_pooling(output)
        outputs.append(pooled_out)
        if self._use_classifier:
            next_sentence_classifier_out = self.classifier(pooled_out)
            outputs.append(next_sentence_classifier_out)
    if self._use_decoder:
        assert masked_positions is not None, \
            'masked_positions tensor is required for decoding masked language model'
        decoder_out = self._decode(F, output, masked_positions)
        outputs.append(decoder_out)
    return tuple(outputs) if len(outputs) > 1 else outputs[0]

def bert_model__encode_sequence(self, inputs, valid_length=None, mask=None): #abi
    """Generate the representation given the input sequences.
    This is used for pre-training or fine-tuning a BERT model.
    """
    # Inputs are already embeddings, so go straight to the encoder
    outputs, additional_outputs = self.encoder(inputs, valid_length=valid_length, mask=mask)
    return outputs, additional_outputs

def bert_encoder___call__(self, inputs, states=None, valid_length=None, mask=None): #pylint: disable=arguments-differ abi
    return mx.gluon.HybridBlock.__call__(self, inputs, states, valid_length, mask)

def bert_encoder_hybrid_forward(self, F, inputs, states=None, valid_length=None, mask=None, position_weight=None): #abi
    if self._dropout:
        inputs = self.dropout_layer(inputs)
    inputs = self.layer_norm(inputs)
    outputs = inputs

    all_encodings_outputs = []
    additional_outputs = []
    for cell in self.transformer_cells:
        outputs, attention_weights = cell(inputs, mask)
        inputs = outputs
        if self._output_all_encodings:
            if valid_length is not None:
                outputs = F.SequenceMask(outputs, sequence_length=valid_length,
                                         use_sequence_length=True, axis=1)
            all_encodings_outputs.append(outputs)

        if self._output_attention:
            additional_outputs.append(attention_weights)

    if valid_length is not None and not self._output_all_encodings:
        outputs = F.SequenceMask(outputs, sequence_length=valid_length,
                                 use_sequence_length=True, axis=1)

    if self._output_all_encodings:
        return all_encodings_outputs, additional_outputs
    return outputs, additional_outputs

def bert_classifier___call__(self, inputs, valid_length=None, mask=None):
    # pylint: disable=dangerous-default-value, arguments-differ
    return super(BERTClassifier, self).__call__(inputs, valid_length, mask)

def bert_classifier_hybrid_forward(self, F, inputs, valid_length=None, mask=None):
    # pylint: disable=arguments-differ
    _, pooler_out = self.bert(inputs, valid_length, mask)
    return self.classifier(pooler_out)

# Replace the original operators and methods with the Inferentia-friendly versions above
nlp.model.GELU.hybrid_forward = gelu
mx.gluon.nn.LayerNorm.hybrid_forward = layer_norm
mx.gluon.nn.Embedding.hybrid_forward = embedding
f.contrib.arange_like = arange_like
f.Embedding = embedding_op
f.contrib.div_sqrt_dim = div_sqrt_dim
f.broadcast_axis = broadcast_axis
f.where = where
f.Dropout = dropout
nlp.model.bert.BERTModel.__call__ = bert_model___call__
nlp.model.bert.BERTModel._encode_sequence = bert_model__encode_sequence
nlp.model.bert.BERTModel.hybrid_forward = bert_model_hybrid_forward
nlp.model.bert.BERTEncoder.__call__ = bert_encoder___call__
nlp.model.bert.BERTEncoder.hybrid_forward = bert_encoder_hybrid_forward
nlp.model.bert.BERTClassifier.__call__ = bert_classifier___call__
nlp.model.bert.BERTClassifier.hybrid_forward = bert_classifier_hybrid_forward
```

Now that we have modified the graph, let's save it to generate the symbol and parameter files that will be used to compile for Inferentia. Since we partitioned the embedding part of the graph to run on the CPU, we also pull the embedding parameters out of the original parameters and save them separately, so that they can be loaded on the CPU for pre-processing.
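
The `export_special` function in the next cell routes each Gluon parameter to one of three places: embedding tables go to a separate `.embeddings` file, and everything else goes to the `.params` file under an `arg:` or `aux:` prefix, the key format that `mx.model.load_checkpoint` splits on. The dictionary-only sketch below reproduces that routing with placeholder values instead of NDArrays so that it runs without MXNet. The two embedding names come from this notebook; the other parameter names and the helper `split_params` are hypothetical.

```python
EMBEDDING_KEYS = ('position_weight',
                  'word_embed_embedding0_weight',
                  'token_type_embed_embedding0_weight')


def split_params(params, arg_names, aux_names):
    """Route parameters the way export_special() does (dict-only sketch)."""
    arg_dict, embedding_dict = {}, {}
    for name, value in params.items():
        if any(key in name for key in EMBEDDING_KEYS):
            embedding_dict[name] = value          # kept for CPU pre-processing
        elif name in arg_names:
            arg_dict['arg:%s' % name] = value     # graph argument
        else:
            assert name in aux_names, name
            arg_dict['aux:%s' % name] = value     # auxiliary state
    return arg_dict, embedding_dict


# Placeholder values; only the names matter for the routing
params = {
    'bertmodel0_word_embed_embedding0_weight': 'W_word',
    'bertencoder0_position_weight': 'W_pos',
    'bertclassifier0_dense0_weight': 'W_cls',        # hypothetical name
    'bertencoder0_layernorm0_gamma': 'gamma',        # hypothetical name
}
arg_dict, embedding_dict = split_params(
    params,
    arg_names={'bertclassifier0_dense0_weight', 'bertencoder0_layernorm0_gamma'},
    aux_names=set())
print(sorted(arg_dict))
print(sorted(embedding_dict))
```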

```python
# Dummy values for the new graph inputs, used only to trace and export the graph:
# inputs: embeddings. shape: (batch_size, seq_length, num_units)
# valid_length: number of valid tokens in each input. shape: (batch_size,)
# mask: mask to remove invalid positions in the graph. shape: (batch_size, seq_length, seq_length)

inputs = mx.nd.arange(batch_size * seq_length * num_units)
inputs = inputs.reshape(shape=(batch_size, seq_length, num_units))
valid_length = mx.nd.arange(batch_size)
steps = mx.nd.arange(start=0, stop=seq_length, dtype='float32')
ones = mx.nd.ones_like(steps)
mask = mx.nd.broadcast_lesser(mx.nd.reshape(steps, shape=(1, -1)),
                              mx.nd.reshape(valid_length, shape=(-1, 1)))
mask = mx.nd.broadcast_mul(mx.nd.expand_dims(mask, axis=1),
                           mx.nd.broadcast_mul(ones, mx.nd.reshape(ones, shape=(-1, 1))))

def export(batch, prefix):
    """Export the model."""
    print('Exporting the model ... ')
    # One forward pass builds and caches the hybridized graph
    out = net(inputs, valid_length, mask)
    export_special(net, prefix, epoch=0)
    assert os.path.isfile(prefix + '-symbol.json')
    assert os.path.isfile(prefix + '-0000.params')

def export_special(net, path, epoch):
    # _cached_graph is a private Gluon attribute; index 1 is the traced output symbol
    sym = net._cached_graph[1]
    sym.save('%s-symbol.json' % path, remove_amp_cast=False)

    arg_names = set(sym.list_arguments())
    aux_names = set(sym.list_auxiliary_states())
    arg_dict = {}
    save_fn = mx.nd.save
    embedding_dict = {}
    for name, param in net.collect_params().items():
        # Embedding tables are saved separately for CPU pre-processing
        if 'position_weight' in name or 'word_embed_embedding0_weight' in name or 'token_type_embed_embedding0_weight' in name:
            embedding_dict[name] = param._reduce()
        elif name in arg_names:
            arg_dict['arg:%s' % name] = param._reduce()
        else:
            assert name in aux_names, name
            arg_dict['aux:%s' % name] = param._reduce()
    save_fn('%s-%04d.params' % (path, epoch), arg_dict)
    save_fn('%s-%04d.embeddings' % (path, epoch), embedding_dict)

prefix = os.path.join(output_dir, 'classification-' + model_name + '-' + str(seq_length))
export(batch_size, prefix)
```

```text
Exporting the model ... 
```


Next, we implement the portion of the original graph that we removed (embedding lookup and processing) as a pre-processing function using `mx.nd`. This part runs on the CPU, and its outputs are used as the input tensors of the Inferentia graph.
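
Per position, the CPU-side embedding is the sum of three lookups (word, token type, and position), which `pre_process` computes with `mx.nd.take` followed by additions. The mask is built by comparing position indices with a length and broadcasting the result to a square matrix, so every row is the same: key position k is kept when k is less than that length. The plain-Python sketch below illustrates both with tiny, made-up tables (hidden size 2, sequence length 4); the table values, token ids, and helper names are illustrative assumptions, not the real BERT weights or vocabulary.

```python
def embed(token_ids, token_types, word_table, type_table, pos_table):
    """Per position i: word_table[id] + type_table[type] + pos_table[i]."""
    return [[w + t + p for w, t, p in zip(word_table[tok], type_table[tt], pos_table[i])]
            for i, (tok, tt) in enumerate(zip(token_ids, token_types))]


def attention_mask(length, max_len):
    """mask[q][k] = 1.0 if key position k < length, else 0.0 (rows are identical)."""
    row = [1.0 if k < length else 0.0 for k in range(max_len)]
    return [list(row) for _ in range(max_len)]


# Made-up tables with hidden size 2 and sequence length 4
word_table = [[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.3, 0.4]]
type_table = [[0.01, 0.02]]
pos_table = [[0.0, 0.0], [0.5, 0.5], [0.25, 0.25], [0.125, 0.125]]

token_ids = [1, 2, 3, 0]    # hypothetical ids, e.g. [CLS] word [SEP] [PAD]
token_types = [0, 0, 0, 0]  # single-sentence input: all token type 0

vectors = embed(token_ids, token_types, word_table, type_table, pos_table)
mask = attention_mask(length=3, max_len=4)
print(vectors[1])  # word_table[2] + type_table[0] + pos_table[1]
print(mask[0])     # [1.0, 1.0, 1.0, 0.0]
```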

```python
def pre_process(sentences, transform, max_len, embedding_dict):
    """
    This pre-processing function executes the part of the network that we
    removed in the previous sections. It creates the input tensors for a
    batch of input sentences.
    Arguments:
        - sentences: list of input sentences, of length batch_size
        - transform: sentence transform that tokenizes and pads the input sentences
        - max_len: maximum sequence length the network was designed for
        - embedding_dict: the embedding parameters that we extracted from the
                          graph in the previous section; they are used for
                          embedding lookup during inference
    Return:
        - ips_b: word + token type + positional embeddings,
                 shape (batch_size, max_len, num_units)
        - sq_len_b: number of valid tokens in each sentence, shape (batch_size,)
        - mask_b: attention mask, shape (batch_size, max_len, max_len)
    """
    ips_b = None
    tk_types_b = None
    sq_len_b = None
    mask_b = None

    for sentence in sentences:
        inputs, seq_len, token_types = transform([sentence])

        inputs_arr = mx.nd.array([inputs])
        token_types_arr = mx.nd.array([token_types])
        positional_arr = mx.nd.arange(max_len)
        seq_len = mx.nd.array([seq_len])

        # Mask computation taken out of the encoder: key positions at or
        # beyond the valid length (padding) are masked out
        steps = mx.nd.arange(start=0, stop=max_len, dtype='float32')
        ones = mx.nd.ones_like(steps)
        mask = mx.nd.broadcast_lesser(mx.nd.reshape(steps, shape=(1, -1)),
                                      mx.nd.reshape(seq_len, shape=(-1, 1)))
        mask = mx.nd.broadcast_mul(mx.nd.expand_dims(mask, axis=1),
                                   mx.nd.broadcast_mul(ones, mx.nd.reshape(ones, shape=(-1, 1))))

        # Embedding lookups taken out of the Inferentia graph
        ips = mx.nd.take(embedding_dict['bertmodel0_word_embed_embedding0_weight'], inputs_arr)
        tk_types = mx.nd.take(embedding_dict['bertmodel0_token_type_embed_embedding0_weight'], token_types_arr)
        ps_arr = mx.nd.take(embedding_dict['bertencoder0_position_weight'], positional_arr)

        # Batching: concatenate the per-sentence tensors along axis 0
        if ips_b is None:
            ips_b = ips
            tk_types_b = tk_types
            sq_len_b = seq_len
            mask_b = mask
        else:
            ips_b = mx.nd.concat(ips_b, ips, dim=0)
            tk_types_b = mx.nd.concat(tk_types_b, tk_types, dim=0)
            sq_len_b = mx.nd.concat(sq_len_b, seq_len, dim=0)
            mask_b = mx.nd.concat(mask_b, mask, dim=0)

    # Sum of word and token type embeddings (from the encode step of BERTModel)
    ips_b = ips_b + tk_types_b

    # Add the positional embeddings on the CPU instead of in the Inferentia graph
    ips_b = mx.nd.broadcast_add(ips_b, mx.nd.expand_dims(ps_arr, axis=0))
    print(type(ips_b), type(sq_len_b), type(mask_b))
    return ips_b, sq_len_b, mask_b
```

Now we create a sample input. Since we are compiling for a batch size of 8, we put 8 sample sentences into the list. The following code downloads the necessary vocabulary files (if they are not already in `~/.mxnet/models/`) and uses that vocabulary to create a tokenizer that transforms the input sentences. We also load the embeddings file that we created when saving the model. With these, we execute the first part of the graph (the `pre_process` function) on the CPU and build a `feed_dict` that is then used to feed the input tensors to the Inferentia graph.
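
The model is compiled for one specific set of input shapes (batch size 8 here), so every feed should match the shapes produced at compile time. The sketch below writes down the shapes this notebook produces for `data0`, `data1`, and `data2` and checks a mapping of arrays that expose a `.shape` attribute, such as `mx.nd.NDArray` or NumPy arrays, against them. Both helpers are hypothetical, and the namedtuple stand-in is used only so the example runs without MXNet; in the notebook you would pass `feed_dict`, `batch_size`, `seq_length`, and `num_units`.

```python
from collections import namedtuple


def expected_input_shapes(batch_size, seq_length, num_units):
    """Shapes of the three graph inputs produced by pre_process()."""
    return {
        'data0': (batch_size, seq_length, num_units),   # summed embeddings
        'data1': (batch_size,),                         # valid lengths
        'data2': (batch_size, seq_length, seq_length),  # attention mask
    }


def check_feed_dict(feed, batch_size, seq_length, num_units):
    """Return a list of readable mismatches (empty if all shapes match)."""
    problems = []
    for name, shape in expected_input_shapes(batch_size, seq_length, num_units).items():
        if name not in feed:
            problems.append('%s is missing' % name)
        elif tuple(feed[name].shape) != shape:
            problems.append('%s has shape %s, expected %s'
                            % (name, tuple(feed[name].shape), shape))
    return problems


# Stand-in objects with a .shape attribute, for illustration only
Fake = namedtuple('Fake', 'shape')
feed = {'data0': Fake((8, 64, 768)), 'data1': Fake((8,)), 'data2': Fake((8, 64, 64))}
print(check_feed_dict(feed, batch_size=8, seq_length=64, num_units=768))  # []
```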

```python
# Sample inputs that will be used for testing
sentences = ['Neuron is awesome',
             'Neuron is great',
             'Neuron is confusing',
             'I Like Neuron',
             'Pizza is my favorite food',
             'I love living in bay area',
             'Driving is not fun',
             'Neuron has very good performance']

if len(sentences) != batch_size:
    raise ValueError("Input dimensions don't match batch size")

# Load (and download, if needed) the vocabulary for the same model and dataset as above
_, vocabulary = nlp.model.get_model(model_name,
                                    dataset_name=dataset_name,
                                    pretrained=False)
tokenizer = nlp.data.BERTTokenizer(vocabulary)
transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=seq_length,
                                           pair=False, pad=True)

embedding_dict = mx.nd.load(prefix + '-0000.embeddings')
ips_b, sq_len_b, mask_b = pre_process(sentences, transform, seq_length, embedding_dict)

# ips_b: embeddings. shape: (batch_size, seq_length, num_units)
# sq_len_b: number of valid tokens in each input. shape: (batch_size,)
# mask_b: mask to remove invalid positions in the graph. shape: (batch_size, seq_length, seq_length)
feed_dict = {'data0': ips_b,
             'data1': sq_len_b,
             'data2': mask_b}
```

```text
<class 'mxnet.ndarray.ndarray.NDArray'> <class 'mxnet.ndarray.ndarray.NDArray'> <class 'mxnet.ndarray.ndarray.NDArray'>
```


Next, we create a simple function that takes an MXNet symbol, its parameters, and a feed dictionary, runs inference, and returns the output values together with the average inference latency. We use it to benchmark and compare the Inferentia and CPU runs.
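
The timing pattern in `run_model` (one untimed warmup call, then an average over repeated calls) does not depend on MXNet. The sketch below applies it to any zero-argument callable and uses `time.perf_counter`, a clock intended for measuring short intervals, instead of `time.time`. With MXNet's asynchronous execution, the callable would also have to block until results are ready, as `run_model` does with `mx.nd.waitall()`. The helper name `benchmark` is hypothetical.

```python
import time


def benchmark(fn, num_runs):
    """Call fn once to warm up, then return (warmup_seconds, avg_seconds)."""
    start = time.perf_counter()
    fn()
    warmup = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    avg = (time.perf_counter() - start) / num_runs
    return warmup, avg


# Toy workload standing in for an inference call
calls = []
warmup, avg = benchmark(lambda: calls.append(sum(range(10000))), num_runs=5)
print('Warmup: %.6f s, average: %.6f s' % (warmup, avg))
```

Keeping the warmup separate matters most on Inferentia, where the first call also loads the compiled model.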

```python
import time

def run_model(sym, args, aux, ctx, args_update, num_runs):
    """Bind sym on ctx, run one warmup inference and num_runs timed inferences,
    and return the last output and the average latency in seconds."""
    args.update(args_update)
    exe = sym.bind(ctx, args=args, aux_states=aux, grad_req='null')

    # Warmup inference
    start = time.time()
    exe.forward()
    out = exe.outputs[0]
    mx.nd.waitall()
    end = time.time()
    print('Warmup time : ', (end - start))

    # Timed inferences
    start = time.time()
    for _ in range(num_runs):
        exe.forward()
        out = exe.outputs[0]
    mx.nd.waitall()
    end = time.time()
    avg_latency = (end - start) / num_runs
    print('Avg inference time : ', avg_latency)
    return out, avg_latency
```

Before compiling the new graph for Inferentia, let's test the network for correctness. To do this, we load the checkpoint we generated earlier and run it on the CPU with the output of the pre-processing function above.
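
Turning the classifier output into readable labels takes an argmax over the two class scores of each sentence, with class 1 treated as positive sentiment as in the cells below. In the plain-Python sketch below (hypothetical helper, made-up scores), note the parentheses around the conditional expression: Python's conditional expression binds more loosely than `+`, so `s + ' : ' + 'positive' if c else 'negative'` evaluates to just `'negative'` when the condition is false.

```python
def label_sentences(sentences, scores):
    """Pair each sentence with the label of its highest-scoring class."""
    lines = []
    for sentence, row in zip(sentences, scores):
        cls = max(range(len(row)), key=lambda j: row[j])  # argmax over classes
        lines.append(sentence + ' : ' + ('positive sentiment' if cls == 1
                                         else 'negative sentiment'))
    return lines


# Made-up two-class scores: [negative, positive]
for line in label_sentences(['Neuron is great', 'Driving is not fun'],
                            [[-1.2, 2.3], [1.7, -0.4]]):
    print(line)
```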

```python
sym_ref, args_ref, aux_ref = mx.model.load_checkpoint(prefix, 0)
ref_out, _ = run_model(sym_ref, args_ref, aux_ref, mx.cpu(), feed_dict, 1)
label = mx.nd.argmax(ref_out, axis=1)
print("~~~~~~~~~~ Running on CPU ~~~~~~~~~~~~~ ")
for i, l in enumerate(label):
    # Parenthesize the conditional so that the sentence is printed for both labels
    print(sentences[i] + ' : ' + ('positive sentiment' if l.asscalar() == 1
                                  else 'negative sentiment'))
```

```text
Warmup time :  0.44240593910217285
Avg inference time :  0.4353640079498291
~~~~~~~~~~ Running on CPU ~~~~~~~~~~~~~ 
Neuron is awesome : positive sentiment
Neuron is great : positive sentiment
Neuron is confusing : negative sentiment
I Like Neuron : positive sentiment
Pizza is my favorite food : positive sentiment
I love living in bay area : positive sentiment
Driving is not fun : negative sentiment
Neuron has very good performance : positive sentiment
```


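### COMPILING THE NETWORK FOR INFERENTIA

The helper functions below count the operators in a symbol graph so that we can check how much of the model was placed on Inferentia. `neuron_compile` loads the checkpoint, compiles it with `mx.contrib.neuron.compile` using the compiler flags shown, asserts that more than 99% of the original operators run on the Neuron runtime, and saves the compiled model under the `_compiled` prefix.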

```python
import json

def sym_nodes(sym):
    """
    Return a list of nodes from sym
    """
    return json.loads(sym.tojson())['nodes']

def count_ops(graph_nodes):
    """
    Return number of operations in node list
    """
    return len([x['op'] for x in graph_nodes if x['op'] != 'null'])

def get_compile_stats(sym):
    """
    Return triplet of compile statistics
    - count of operations in the symbol graph
    - number of Neuron subgraphs
    - number of operations compiled to Neuron runtime
    """
    cnt = count_ops(sym_nodes(sym))
    neuron_subgraph_cnt = 0
    neuron_compiled_cnt = 0
    for g in sym_nodes(sym):
        if g['op'] == '_neuron_subgraph_op':
            neuron_subgraph_cnt += 1
            for sg in g['subgraphs']:
                neuron_compiled_cnt += count_ops(sg['nodes'])
    return (cnt, neuron_subgraph_cnt, neuron_compiled_cnt)

def neuron_compile(prefix, inputs):
    # Compile for Inferentia using Neuron
    compiler_args = {"flags": ['--tensor-layout-heuristics=spatial-locality',
                               '-O2',
                               '--fp32-cast', 'matmult-fp16']}
    sym, args_loaded, aux = mx.model.load_checkpoint(prefix, 0)
    sym, args_loaded, aux = mx.contrib.neuron.compile(sym, args_loaded, aux, inputs,
                                                      **compiler_args)

    # Check if compilation was successful
    post_compile_cnt, neuron_subgraph_cnt, neuron_compiled_cnt = get_compile_stats(sym)
    print("INFO:mxnet: Number of operations in compiled model: ", post_compile_cnt)
    print("INFO:mxnet: Number of Neuron subgraphs in compiled model: ", neuron_subgraph_cnt)
    print("INFO:mxnet: Number of operations placed on Neuron runtime: ", neuron_compiled_cnt)
    num_ops_orig = (post_compile_cnt - neuron_subgraph_cnt + neuron_compiled_cnt)
    neuron_percentage = (neuron_compiled_cnt / num_ops_orig) * 100
    assert neuron_percentage > 99.0, "Expected > 99% on Inf, but got {}".format(neuron_percentage)

    # Save the compiled model
    mx.model.save_checkpoint(prefix + "_compiled", 0, sym, args_loaded, aux)

neuron_compile(prefix, feed_dict)
```

```text
INFO:mxnet: Number of operations in compiled model:  1
INFO:mxnet: Number of Neuron subgraphs in compiled model:  1
INFO:mxnet: Number of operations placed on Neuron runtime:  807
```
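
Plugging the logged counts into the formula used in `neuron_compile`: the compiled graph contains a single operation, the Neuron subgraph, so the original operator count is 1 - 1 + 807 = 807, and all of them were placed on the Neuron runtime. The sketch below repeats that arithmetic; the second call uses made-up counts to show a case that would fail the 99% check.

```python
def neuron_coverage(post_compile_cnt, neuron_subgraph_cnt, neuron_compiled_cnt):
    """Percentage of the original operators placed on the Neuron runtime."""
    # The subgraph ops together replace neuron_compiled_cnt original ops, so
    # subtract the former and add the latter to recover the original count.
    num_ops_orig = post_compile_cnt - neuron_subgraph_cnt + neuron_compiled_cnt
    return neuron_compiled_cnt / num_ops_orig * 100


print(neuron_coverage(1, 1, 807))   # counts logged above: 100.0
print(neuron_coverage(11, 1, 797))  # made-up: 10 ops left on CPU, below 99%
```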


### RUNNING INFERENCE ON INF1 MACHINES

Load the compiled checkpoint from the previous step and run inference. Things to note:
1. The prefix path now points to the compiled model.
2. The context has been changed to `mx.neuron()`.
3. The warmup time is generally much higher than subsequent inference times because it includes the time taken to load the compiled model.

```python
import numpy as np

# Load the compiled (Inferentia) symbol
sym, args, aux = mx.model.load_checkpoint(prefix + '_compiled', 0)

# Run the model and get the output and average latency over 100 runs
inf_out, latency = run_model(sym, args, aux, mx.neuron(), feed_dict, 100)
label_inf = mx.nd.argmax(inf_out, axis=1)
print("~~~~~~~~~~ Running on Inferentia ~~~~~~~~~~~~~ ")
for i, l in enumerate(label_inf):
    # Parenthesize the conditional so that the sentence is printed for both labels
    print(sentences[i] + ' : ' + ('positive sentiment' if l.asscalar() == 1
                                  else 'negative sentiment'))

# Check that the results are close to the CPU results
np.testing.assert_allclose(inf_out.asnumpy(), ref_out.asnumpy(), atol=1e-2, rtol=1e-2)
```

```text
Warmup time :  6.391572952270508
Avg inference time :  0.02266895294189453
~~~~~~~~~~ Running on Inferentia ~~~~~~~~~~~~~ 
Neuron is awesome : positive sentiment
Neuron is great : positive sentiment
Neuron is confusing : negative sentiment
I Like Neuron : positive sentiment
Pizza is my favorite food : positive sentiment
I love living in bay area : positive sentiment
Driving is not fun : negative sentiment
Neuron has very good performance : positive sentiment
```
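
The logged average latencies can be converted into throughput and a latency ratio, given that each inference processes a batch of 8 sentences. Treat the comparison as rough: the CPU figure comes from a single timed run (`num_runs=1`), the Inferentia figure is averaged over 100 runs, and neither includes the CPU-side `pre_process` step. The sketch only reuses the numbers printed above; `sentences_per_batch` is named so that it does not overwrite the notebook's `batch_size`.

```python
sentences_per_batch = 8             # batch size used throughout this notebook
cpu_latency = 0.4353640079498291    # seconds per batch, CPU run above (num_runs=1)
inf_latency = 0.02266895294189453   # seconds per batch, Inferentia run above (num_runs=100)

cpu_throughput = sentences_per_batch / cpu_latency   # sentences per second
inf_throughput = sentences_per_batch / inf_latency
speedup = cpu_latency / inf_latency

print('CPU throughput:        %.1f sentences/s' % cpu_throughput)
print('Inferentia throughput: %.1f sentences/s' % inf_throughput)
print('Latency ratio (CPU / Inferentia): %.1fx' % speedup)
```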
