In [None]:
import sys

# Packages are installed through `sys.executable` so that pip targets the same
# interpreter this kernel is running on.
# The Neuron SDK line below is left disabled; presumably those packages are
# already present in the environment -- TODO confirm before running elsewhere.
# !{sys.executable} -m pip install neuronx-cc==2.* torch-neuronx torchvision
# NOTE(review): `transformers` is installed without a version pin.
!{sys.executable} -m pip install transformers

## 0. Import libraries

In [None]:
import transformers
import torch_neuronx
import os
# Turn off the tokenizers library's internal parallelism for this process
# (presumably to avoid its fork-related warnings -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## 1. Load model pretrained on MNLI

In [None]:
from transformers import BartForSequenceClassification, BartTokenizer

# `export=True` is not a `from_pretrained` argument of the plain transformers
# classes (it belongs to optimum-style Neuron wrappers), so it is not passed here.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')

# `model` is the instance whose encoder/decoder get compiled for Inf2 later on;
# `model_cpu` is an untouched copy kept as the CPU reference for comparison.
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
model_cpu = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

# Directory where the compiled encoder/decoder artifacts are stored.
model_dir = "Bart"

## 1.1 Test loaded model

In [None]:
# Zero-shot classification through NLI: the text to classify is the premise
# and the candidate label (sports) is phrased as the hypothesis.
premise = 'What is your favorite team, Madrid or Barca?'
hypothesis = 'This text is about sports.'
max_length = 128

# Tokenize the pair, padding to exactly `max_length` tokens; if truncation is
# needed only the first sequence (the premise) is shortened.
encoded_input = tokenizer.encode_plus(
    premise,
    hypothesis,
    max_length=max_length,
    padding="max_length",
    truncation='only_first',
    return_tensors='pt',
)

# Forward pass through the MNLI-pretrained model; element 0 of the output holds the logits.
logits = model(
    encoded_input["input_ids"],
    encoded_input["attention_mask"],
    use_cache=False,
)[0]

# Drop "neutral" (index 1) and renormalize over the remaining two classes;
# the share assigned to "entailment" (index 2) is read as the probability
# that the label applies.
entail_contradiction_logits = logits[:, [0, 2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = 100 * probs[:, 1].item()
print(f'Probability that the label is true: {true_prob:0.2f}%')

## 1.2 Test tracing the model as-is (without modifications)

In [None]:
# Attempt to compile the complete, unmodified model in one go, using only the
# token ids as example input.
neuron_encoder = torch_neuronx.trace(
    model,
    encoded_input["input_ids"],
    compiler_workdir='./enc_dir',
    compiler_args='--target inf2 --model-type transformer --auto-cast all',
)

Since this model has around 400M parameters (about 1.5GB), it fits into a single core once it is cast to bf16. The model is also an encoder-decoder, so the strategy is to compile the two components individually and then plug them back into the original model structure. After that, both the encoder and the decoder will be accelerated on Inf2.

In [None]:
# Upper bound on the decoder sequence length; it is attached to the decoder
# later and checked by the custom decoder forward.
max_dec_len = 1024

# Note: `dim_enc` holds the maximum number of positions and `dim_dec` the
# hidden size, both read from the model configuration.
dim_enc, dim_dec = model.config.max_position_embeddings, model.config.d_model
print(f'Dim enc: {dim_enc}; Dim dec: {dim_dec}')

In [None]:
import torch
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions

# Custom forward for the encoder
def enc_f(self, input_ids, attention_mask, **kwargs):
    """Forward replacement for the encoder module.

    Uses the compiled callable stored in ``self.forward_neuron`` when that
    attribute exists; otherwise falls back to the backed-up original forward
    (``self.forward_``). Any additional keyword arguments are accepted and
    ignored.

    Returns:
        BaseModelOutput built from the mapping produced by the chosen path.
    """
    if not hasattr(self, 'forward_neuron'):
        # Nothing compiled yet: run the original implementation.
        original_out = self.forward_(input_ids, attention_mask=attention_mask, return_dict=True)
        return BaseModelOutput(**original_out)

    # The compiled callable takes its inputs positionally.
    compiled_out = self.forward_neuron(input_ids, attention_mask)
    return BaseModelOutput(**compiled_out)


# Define one function for the decoder part
def dec_f(self, input_ids, encoder_hidden_states, encoder_attention_mask, **kwargs):
    """Forward replacement for the decoder module.

    Uses the compiled callable stored in ``self.forward_neuron`` when that
    attribute exists; otherwise falls back to the backed-up original forward
    (``self.forward_``) with caching and attention outputs disabled. Any
    additional keyword arguments are accepted and ignored.

    Args:
        input_ids: decoder token ids; ``input_ids.shape[1]`` is checked
            against ``self.max_length``.
        encoder_hidden_states: output of the encoder, forwarded unchanged.
        encoder_attention_mask: attention mask for the encoder output,
            forwarded unchanged.

    Returns:
        BaseModelOutputWithPastAndCrossAttentions built from the mapping
        produced by the chosen path.

    Raises:
        ValueError: if the second dimension of ``input_ids`` exceeds
            ``self.max_length``.
    """
    # Reject over-long inputs up front, before any (expensive) forward call.
    if input_ids.shape[1] > self.max_length:
        raise ValueError(f"The decoded sequence is not supported. Max: {self.max_length}")

    if hasattr(self, 'forward_neuron'):
        # The compiled callable takes its inputs positionally.
        out = self.forward_neuron(input_ids,
                                  encoder_hidden_states,
                                  encoder_attention_mask)
    else:
        out = self.forward_(input_ids=input_ids,
                            encoder_hidden_states=encoder_hidden_states,
                            encoder_attention_mask=encoder_attention_mask,
                            return_dict=True,
                            use_cache=False,
                            output_attentions=False)

    # Work on a copy so the mapping returned by the forward call is not
    # modified, and make sure the optional fields are always present.
    fields = dict(out)
    for optional_key in ('cross_attentions', 'hidden_states', 'attentions'):
        fields.setdefault(optional_key, None)

    return BaseModelOutputWithPastAndCrossAttentions(**fields)

In [None]:
import types

# Cap on the decoder input length, enforced inside the custom decoder forward.
model.model.decoder.max_length = max_dec_len  # or any other appropriate value

# For each sub-module: keep the original forward reachable as `forward_`
# (only on the first run, so re-executing this cell never backs up an already
# patched forward), then bind the custom function as the new `forward`.
for _submodule, _custom_forward in (
    (model.model.encoder, enc_f),
    (model.model.decoder, dec_f),
):
    if not hasattr(_submodule, 'forward_'):
        _submodule.forward_ = _submodule.forward
    _submodule.forward = types.MethodType(_custom_forward, _submodule)

In [None]:
# Example inputs for the encoder: (token ids, attention mask).
encoder_inputs = (encoded_input["input_ids"], encoded_input["attention_mask"])

# A single encoder pass; its output serves as sample input when the decoder is traced.
encoder_outputs = model.model.encoder(*encoder_inputs)

## Trace Encoder

In [None]:
import os
import torch

model_filename=f"{model_dir}/BART-large-nli-encoder.pt"

if not os.path.isfile(model_filename):
    # Make sure the output directory exists *before* the long compilation:
    # saving the traced module does not create missing parent directories.
    os.makedirs(model_dir, exist_ok=True)

    # Drop any previously attached compiled module so the trace goes through
    # the original (backed-up) forward path.
    if hasattr(model.model.encoder, 'forward_neuron'): del model.model.encoder.forward_neuron
    neuron_encoder = torch_neuronx.trace(
        model.model.encoder, 
        encoder_inputs,
        compiler_args='--target inf2 --model-type transformer --auto-cast all',
        compiler_workdir='./enc_dir')
    # neuron_encoder_dynamic_batch = torch_neuronx.dynamic_batch(neuron_encoder)
    # Cache the compiled artifact so later runs can skip compilation.
    neuron_encoder.save(model_filename)
    model.model.encoder.forward_neuron = neuron_encoder
else:
    # A compiled artifact already exists: load it instead of recompiling.
    model.model.encoder.forward_neuron = torch.jit.load(model_filename)



## Trace Decoder

In [None]:
model_filename=f"{model_dir}/BART-large-nli-decoder.pt"

if not os.path.isfile(model_filename):
    # Make sure the output directory exists *before* the long compilation:
    # saving the traced module does not create missing parent directories.
    os.makedirs(model_dir, exist_ok=True)

    # Example inputs for the decoder: token ids, the encoder's first output
    # and the attention mask of the encoder input.
    inp = encoded_input["input_ids"], encoder_outputs[0], encoded_input["attention_mask"]

    # Drop any previously attached compiled module so the trace goes through
    # the original (backed-up) forward path.
    if hasattr(model.model.decoder, 'forward_neuron'): del model.model.decoder.forward_neuron
    neuron_decoder = torch_neuronx.trace(
        model.model.decoder,
        inp,
        compiler_args='--target inf2 --model-type transformer --auto-cast all',
        compiler_workdir='./dec_dir')
    # neuron_decoder_dynamic_batch = torch_neuronx.dynamic_batch(neuron_decoder)
    # Cache the compiled artifact so later runs can skip compilation.
    neuron_decoder.save(model_filename)
    model.model.decoder.forward_neuron = neuron_decoder
else:
    # A compiled artifact already exists: load it instead of recompiling.
    model.model.decoder.forward_neuron = torch.jit.load(model_filename)

## Test

In [None]:
# Zero-shot classification through NLI: the text to classify is the premise
# and the candidate label (cooking) is phrased as the hypothesis.
premise = 'how do you like the potatoes?'
hypothesis = 'This text is about cooking.'

# Tokenize with the same fixed length used earlier, padding to exactly
# `max_length` tokens.
max_length=128
x = tokenizer.encode_plus(
    premise,
    hypothesis,
    max_length=max_length,
    padding="max_length",
    truncation='only_first',
    return_attention_mask=True,
    return_tensors='pt',
)

# Forward pass through the model with the patched encoder/decoder;
# element 0 of the output holds the logits.
y = model(x["input_ids"], x["attention_mask"])
logits = y[0]

# Drop "neutral" (index 1) and renormalize over the remaining two classes;
# the share assigned to "entailment" (index 2) is read as the probability
# that the label applies.
entail_contradiction_logits = logits[:, [0, 2]]
probs = F.softmax(entail_contradiction_logits, dim=1)
true_prob = 100 * probs[:, 1].item()
print(f'Probability that the label is true: {true_prob:0.2f}%')


### Now we can measure the inference latency on the Inf2 chips:

In [None]:
%%timeit -r 10
# Time one full forward pass of `model` (the instance whose encoder/decoder
# forwards were replaced above); `-r 10` asks timeit for 10 repeats.

model(x["input_ids"], x["attention_mask"])

### And compare it with the model running on the CPU:

In [None]:
%%timeit -r 10
# Same measurement for `model_cpu`, the unmodified reference copy, on identical inputs.
model_cpu(x["input_ids"], x["attention_mask"])

### Finally, we can compare the output of the CPU model with the Inf2 one

In [None]:
# Same input as the Inf2 test, this time through the unmodified CPU reference model.
y = model_cpu(x["input_ids"], x["attention_mask"])
logits = y[0]

# Drop "neutral" (index 1) and renormalize over the remaining two classes;
# the share assigned to "entailment" (index 2) is read as the probability
# that the label applies.
entail_contradiction_logits = logits[:, [0, 2]]
probs = F.softmax(entail_contradiction_logits, dim=1)
true_prob = 100 * probs[:, 1].item()
print(f'Probability that the label is true: {true_prob:0.2f}%')


The value should be very similar to the one printed three cells above (the result obtained on Inf2).