# Fairseq for AWS Inferentia

**Separate encoder/decoder approach**

This notebook demonstrates how to compile the Fairseq encoder and decoder for Inferentia, and then swap the compiled models back into the original Fairseq model object.

This approach is more flexible than the alternative nn.Sequential "stacked encoder/decoder" approach because variable sequence lengths can be handled at inference time. A possible drawback is that each autoregressive decoder step requires a separate inference request, so the number of requests grows with the output sequence length, which can add latency for longer sequences.
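
To make the latency trade-off concrete, here is a minimal sketch that counts decoder invocations in a greedy autoregressive loop. `fake_decoder_step`, `BOS`, `EOS`, and the token values are hypothetical stand-ins rather than Fairseq APIs, and Fairseq itself uses beam search, so this only illustrates the one-call-per-generated-token pattern.

```python
# Hypothetical stand-ins for illustration only; not part of Fairseq or the Neuron SDK.
BOS, EOS = 0, 2  # assumed begin/end-of-sentence token ids

def fake_decoder_step(prev_tokens, target):
    """Pretend decoder: returns the next token of a fixed target sequence."""
    step = len(prev_tokens) - 1  # prev_tokens always starts with BOS
    return target[step] if step < len(target) else EOS

def greedy_decode(target, max_length=32):
    """Generate tokens one at a time and count how often the decoder is called."""
    tokens = [BOS]
    calls = 0
    while len(tokens) < max_length:
        next_token = fake_decoder_step(tokens, target)
        calls += 1  # one separate inference request per generated token
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens[1:], calls

short_out, short_calls = greedy_decode([11, 12, 13, EOS])
long_out, long_calls = greedy_decode(list(range(10, 20)) + [EOS])
print(short_calls, long_calls)
```

The longer target needs more decoder calls than the shorter one, which is why per-request overhead matters more as sequences grow.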

**Reference:** https://github.com/facebookresearch/fairseq

## 1) Install dependencies
**Tested with:** Python 3.8.x

Fairseq also requires GCC to compile some C++ files. On Ubuntu, install `build-essential`, `python3-setuptools`, and `python3-dev`.

In [None]:
# Configure pip to use the Neuron repository as an extra index
%pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
# now restart the kernel

In [None]:
#Install Neuron PyTorch
%pip install -U --force-reinstall torch==1.11.0 torch-neuron==1.11.0.* neuron-cc[tensorflow] \
    requests tensorboardX --extra-index-url=https://download.pytorch.org/whl/torch_stable.html
# --force-reinstall (used above) helps when modules fail to load after an earlier install
# now restart the kernel again

In [None]:
import os

if not os.path.isdir('fairseq-local'):
    !git clone https://github.com/pytorch/fairseq fairseq-local && \
    cd fairseq-local && git checkout acd9a53607d1e5c64604e88fc9601d0ee56fd6f1 && \
    pip3 install --editable ./ && \
    pip3 --no-cache-dir install sacremoses torch==1.11.0 torchaudio==0.11.0 "numpy==1.22.1" scikit-learn fastBPE

**Remember to restart kernel before continuing!**

## 2) Initialize libraries and prepare input samples

In [None]:
import os
import types
import torch
import torch.neuron
import torch.nn.functional as F
print(torch.__version__)
assert(torch.__version__.startswith("1.11.0"))

max_length = 32  # can be increased, but larger values may reduce performance
sentences = [
    "i've seen things, you people wouldn't believe, hmmm.",
    "attack ships on fire off the shoulder of Orion.",
    "I've watched c Beams glitter in the dark near the Tannhauser Gate.",
    "All those moments, will be lost in time like tears in rain.",
    "time to die"
]

## 3) Load a pre-trained model and check if it is .jit traceable

In [None]:
model = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')

In [None]:
# do this if you hit `No module named 'sklearn'` above
# !rm -rf ~/.cache/torch/hub/pytorch_fairseq_main/fairseq
# model = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')

In [None]:
# do this if you hit `Primary config directory not found.` above
# !rm -rf ~/.cache/torch/hub/
# model = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')

### 3.1) Adjust the encoder to make it traceable

In [None]:
def e(self, src_tokens, src_lengths=None, **kwargs):    
    if torch.jit.is_tracing():
        print('tracing encoder...')
        values = list(self.encoder.forward_(src_tokens, src_lengths).values())
        return values[0],tuple(values[1]),values[2],tuple(values[3])
    elif hasattr(self.encoder, 'forward_neuron'):        
        delta = torch.as_tensor(self.encoder.max_decoder_length - src_tokens.shape[1])
        pad_size = (0, delta)
        src_tokens = F.pad(src_tokens, pad_size, "constant", 1) # 1 is the pad_token_id
        src_lengths = torch.ones([max_length], dtype=torch.int64)  # fixed shape, matching the traced example input
        out = self.encoder.forward_neuron(src_tokens, src_lengths)        
        # we'll not unpad to make it already prepared for the decoder
        return {
            'encoder_out': out[0], 'encoder_padding_mask':out[1],
            'encoder_embedding':out[2], 'encoder_states':out[3],
            'fc_results':None, 'src_tokens':src_tokens,'src_lengths': [src_lengths]
        }
    else:
        return self.encoder.forward_(src_tokens, src_lengths)
if not hasattr(model.models[0].encoder, 'forward_'):
    model.models[0].encoder.forward_ = model.models[0].encoder.forward
model.models[0].encoder.max_decoder_length = max_length
model.models[0].encoder.forward = types.MethodType(e, model.models[0])
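
The cell above relies on a general Python pattern: keep a handle on the original method, then bind a replacement with `types.MethodType` that dispatches between implementations. The sketch below shows that pattern in isolation. `TinyEncoder`, `forward_fast`, and `fast` are hypothetical names for illustration and have no connection to Fairseq or Neuron.

```python
import types

class TinyEncoder:
    """Hypothetical stand-in for a model component with a forward method."""
    def forward(self, x):
        return [v * 2 for v in x]

def patched_forward(self, x):
    # Use an attached optimized implementation when present,
    # otherwise fall back to the saved original method.
    if hasattr(self, "forward_fast"):
        return self.forward_fast(x)
    return self.forward_(x)

enc = TinyEncoder()
if not hasattr(enc, "forward_"):   # guard: re-running must not save the patched method
    enc.forward_ = enc.forward     # keep the original bound method
enc.forward = types.MethodType(patched_forward, enc)

used = []  # records which implementation ran

def fast(x):
    used.append("fast")
    return [v * 2 for v in x]

before = enc.forward([1, 2, 3])    # no forward_fast yet: original path
enc.forward_fast = fast            # pretend this is a compiled module
after = enc.forward([1, 2, 3])     # now dispatches to the attached version
print(before, after, used)
```

Without the `hasattr` guard, running the patching code twice would store the patched function as `forward_`, and the fallback branch would then call itself recursively.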

### 3.2) Adjust the decoder to make it traceable
The decoder is more complex because it is invoked many times during prediction with different input shapes. We need to pad the input shapes before compiling the model.
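
As a rough illustration of the padding idea, the sketch below right-pads a token list to a fixed length and computes the position of the newest real token, mirroring the `out.shape[1] - 1 - pad_size` index used in this section. It uses plain Python lists instead of tensors. The helper names are hypothetical, and the pad id of 1 follows the `F.pad` calls in this notebook.

```python
PAD_ID = 1        # pad token id, as used by the F.pad calls in this notebook
MAX_LENGTH = 32   # same value as max_length defined earlier

def pad_to_fixed_length(tokens, max_length=MAX_LENGTH, pad_id=PAD_ID):
    """Right-pad a token list to max_length; return the padded list and pad count."""
    pad_size = max_length - len(tokens)
    if pad_size < 0:
        raise ValueError(f"sequence of length {len(tokens)} exceeds {max_length}")
    return tokens + [pad_id] * pad_size, pad_size

def last_real_index(padded_length, pad_size):
    """Index of the newest non-pad token in a right-padded sequence."""
    return padded_length - 1 - pad_size

padded, pad_size = pad_to_fixed_length([2, 15, 27])
index = last_real_index(len(padded), pad_size)
print(len(padded), pad_size, index, padded[index])
```

Because every call sees the same padded shape, only `pad_size` changes between decoder steps, which is what lets a single compiled graph serve all of them.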

In [None]:
def reduce(self, logits, index):
    _, n_length, _ = logits.shape

    # Create selection mask
    mask = torch.arange(n_length, dtype=torch.int32) == index
    mask = mask.view(1, -1, 1)

    # Broadcast mask
    masked = torch.multiply(logits, mask.to(torch.float32))

    # Reduce along 1st dimension    
    return torch.unsqueeze(torch.sum(masked, 1), 1)

def pad(self, tensor, pad_val=(0,0), value=1):
    return F.pad(tensor, pad_val, "constant", value)

def d(self, prev_output_tokens, encoder_out, pad_size=torch.as_tensor(0), **kwargs):
    if torch.jit.is_tracing():
        print('tracing decoder...')
        kwargs['features_only'] = True # do not apply output_projection
        encoder_out = {'encoder_out': encoder_out[0], 'encoder_padding_mask': encoder_out[1] }        

        out,extra = self.forward_(prev_output_tokens, encoder_out, **kwargs)        
        index = torch.as_tensor(out.shape[1] - 1) - pad_size        
        out = self.output_projection( self.reduce(out, index) )        
        return out,tuple(extra['attn']),tuple(extra['inner_states'])
    elif hasattr(self, 'forward_neuron'):        
        pad_size = torch.as_tensor(self.max_decoder_length - prev_output_tokens.shape[1])
        prev_output_tokens = self.pad(prev_output_tokens, (0,pad_size))
        encoder_out_new = encoder_out['encoder_out']
        encoder_padding_mask_new = encoder_out['encoder_padding_mask']
        
        out,attn,inner_states = self.forward_neuron(
            prev_output_tokens, [encoder_out_new, encoder_padding_mask_new], pad_size )

        return out, {'attn': attn, 'inner_states': inner_states}
    else:
        print('checking trace...')        
        encoder_out = {'encoder_out': encoder_out[0], 'encoder_padding_mask': encoder_out[1] }
        return self.forward_(prev_output_tokens, encoder_out, **kwargs)
        
if not hasattr(model.models[0].decoder, 'forward_'):
    model.models[0].decoder.forward_ = model.models[0].decoder.forward
model.models[0].decoder.max_decoder_length = max_length
model.models[0].decoder.forward = types.MethodType(d, model.models[0].decoder)

model.models[0].decoder.reduce = types.MethodType(reduce, model.models[0].decoder)
model.models[0].decoder.pad = types.MethodType(pad, model.models[0].decoder)
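
The `reduce` helper above picks one position along the sequence dimension by multiplying with a 0/1 mask and summing, presumably so the traced graph contains only fixed-shape arithmetic instead of data-dependent indexing (that motivation is an assumption, not stated in the source). Below is a plain-Python analogue on nested lists, with the hypothetical name `reduce_by_mask`, showing that mask-and-sum gives the same result as direct indexing.

```python
def reduce_by_mask(logits, index):
    """Select logits[b][index] for every batch row using a 0/1 mask and a sum.

    logits is a nested list shaped [batch][length][vocab].
    """
    out = []
    for row in logits:
        n_length = len(row)
        vocab = len(row[0])
        # 1.0 at the wanted position, 0.0 everywhere else
        mask = [1.0 if pos == index else 0.0 for pos in range(n_length)]
        # Multiply by the mask, then sum over the length dimension
        summed = [
            sum(row[pos][v] * mask[pos] for pos in range(n_length))
            for v in range(vocab)
        ]
        out.append([summed])  # keep a length-1 middle dimension, like unsqueeze(1)
    return out

logits = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
]
selected = reduce_by_mask(logits, 1)
direct = [[row[1]] for row in logits]  # ordinary indexing, for comparison
print(selected == direct)
```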

### 3.3) Check if both encoder and decoder are traceable now

In [None]:
if hasattr(model.models[0].encoder, 'forward_neuron'): del model.models[0].encoder.forward_neuron
if hasattr(model.models[0].decoder, 'forward_neuron'): del model.models[0].decoder.forward_neuron

try:
    inp_enc = (torch.ones([1,max_length], dtype=torch.int64), torch.ones([max_length], dtype=torch.int64))
    y = model.models[0].encoder(*inp_enc) # warmup
    traced_encoder = torch.jit.trace(model.models[0].encoder, inp_enc)
    print("Cool! Model is jit traceable")
except Exception as err:  # `err` avoids shadowing the encoder wrapper function `e`
    print(f"Oops. Something went wrong. Model is not traceable: {err}")
## If the check passed, the encoder is jit traceable. The decoder is checked next.

In [None]:
prev_output_tokens = torch.zeros([5, max_length], dtype=torch.int64)
encoder_out = [torch.rand([max_length, 5, 1024], dtype=torch.float32)]
encoder_padding_mask = [torch.zeros([5, max_length], dtype=torch.bool)]
delta=torch.as_tensor(0)

if hasattr(model.models[0].decoder, 'forward_neuron'): del model.models[0].decoder.forward_neuron

try:
    with torch.no_grad():
        inp_dec = (prev_output_tokens, [encoder_out, encoder_padding_mask], delta)
        y = model.models[0].decoder(*inp_dec) # warmup
        traced_decoder = torch.jit.trace(model.models[0].decoder, inp_dec)    
        y = traced_decoder(*inp_dec) 
    print("Cool! Model is jit traceable")
except Exception as err:  # `err` avoids shadowing the encoder wrapper function `e`
    print(f"Oops. Something went wrong. Model is not traceable: {err}")
## If both checks passed, the model is jit traceable. Next, compile it with the Neuron SDK.

### 3.4) Quick test to verify the traced modules

In [None]:
model.models[0].encoder.forward_neuron = traced_encoder
model.models[0].decoder.forward_neuron = traced_decoder
model.translate(sentences[0:1])

## 4) Analyze & compile the model for Inferentia with NeuronSDK

The Neuron Check Model tool provides basic information about the operations in compiled and uncompiled models without requiring TensorBoard-Neuron.  
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-tools/tutorial-neuron-check-model.html


The PyTorch-Neuron trace Python API generates PyTorch models for execution on Inferentia, which can be serialized as TorchScript. It is analogous to the torch.jit.trace() function in PyTorch.  
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/api-compilation-python-api.html?highlight=trace

In [None]:
import torch
import torch.neuron
print(torch.neuron.analyze_model(traced_encoder, example_inputs=inp_enc))
print(torch.neuron.analyze_model(traced_decoder, example_inputs=inp_dec))

In [None]:
import os
import torch
import torch.neuron

#https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-cc/command-line-reference.html#cmdoption-neuron-cc-arg-0

ops = torch.neuron.get_supported_operations() + ['aten::embedding']
if not os.path.isfile("fairseq_encoder_neuron.pt"):
    model_neuron_encoder = torch.neuron.trace(traced_encoder, example_inputs=inp_enc, op_whitelist=ops)
    ## Export to saved model
    model_neuron_encoder.save("fairseq_encoder_neuron.pt")

if not os.path.isfile("fairseq_decoder_neuron.pt"):
    model_neuron_decoder = torch.neuron.trace(traced_decoder, example_inputs=inp_dec, op_whitelist=ops)
    ## Export to saved model
    model_neuron_decoder.save("fairseq_decoder_neuron.pt")

### 4.1) Verify the optimized model

In [None]:
# running this under a python3.8 kernel can lead to the kernel dying
model.models[0].encoder.forward_neuron = torch.load('fairseq_encoder_neuron.pt')
model.models[0].decoder.forward_neuron = torch.load('fairseq_decoder_neuron.pt')
model.translate(sentences[0:1]) # warmup

## 5) A simple test to check the predictions

If the kernel dies after running the cell below, try copying all the code into a Python script and running it with `python fairseq_script.py`.

In [None]:
## A good next step is to enable dynamic_batch_size to allow predicting
## multiple sentences at the same time. You can also compile decoders
## with different input shapes.
[(s,model.translate(s)) for s in sentences]
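
One way to act on the suggestion of compiling decoders with different input shapes is to keep a few fixed "bucket" lengths and route each request to the smallest one that fits. The sketch below covers only the selection step. The bucket sizes and the `pick_bucket` helper are hypothetical, and each bucket would correspond to a decoder compiled separately with that length.

```python
import bisect

# Hypothetical bucket lengths, sorted ascending; each one would map to a
# decoder compiled with that fixed input length.
BUCKETS = [16, 32, 64]

def pick_bucket(seq_len, buckets=BUCKETS):
    """Return the smallest bucket length that can hold seq_len tokens."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket {buckets[-1]}")
    return buckets[i]

for n in (5, 16, 17, 64):
    print(n, "->", pick_bucket(n))
```

Shorter sequences then pay for less padding, at the cost of compiling and storing one model per bucket.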