In [1]:
pip install transformers[onnx] onnx onnxruntime



In [11]:
from transformers import BartForConditionalGeneration, BartTokenizer
import onnxruntime
import torch.onnx


# Load the BART model and tokenizer
model_name = "facebook/bart-large-cnn"
model = BartForConditionalGeneration.from_pretrained(model_name)
tokenizer = BartTokenizer.from_pretrained(model_name)

# Set the model to evaluation mode (disables dropout so the traced graph is deterministic)
model.eval()

# Create a dummy input for the model; the exporter traces the graph with it
dummy_text = "This is a dummy input for ONNX conversion"
inputs = tokenizer([dummy_text], max_length=1024, return_tensors="pt", truncation=True)

# Build the output path exactly once so the export call, the log message and
# `model_path` (consumed by later cells) can never drift apart.
# The previous `"./" + {...} + ".onnx"` concatenated a str with a set literal
# and raised TypeError, leaving `model_path` undefined.
model_path = f"./{model_name.replace('/', '_')}.onnx"

# Export the model to ONNX.
# Both the batch axis (0) and the sequence axis (1) of the inputs are declared
# dynamic; otherwise the graph is specialised to the dummy sentence's token count
# and rejects inputs of any other length at inference time.
torch.onnx.export(model,
                  (inputs['input_ids'], inputs['attention_mask']),
                  model_path,
                  input_names=['input_ids', 'attention_mask'],
                  output_names=['output'],
                  dynamic_axes={'input_ids' : {0 : 'batch_size', 1 : 'sequence_length'},
                                'attention_mask' : {0 : 'batch_size', 1 : 'sequence_length'},
                                'output' : {0 : 'batch_size', 1 : 'sequence_length'}},
                  do_constant_folding=True,
                  opset_version=13)

print(f"Model saved to {model_path}")



Model saved to ./facebook_bart-large-cnn.onnx


In [3]:
import onnx
# Location of the exported graph. Kept relative to the working directory, which
# is where the export step writes "./facebook_bart-large-cnn.onnx", instead of a
# Colab-only absolute path ("/content/...") that breaks on any other machine.
model_path = "./facebook_bart-large-cnn.onnx"

# Load the ONNX model (protobuf graph) from disk.
# NOTE(review): this rebinds `model`, which the export cell uses for the PyTorch
# module; re-run the export cell before touching the PyTorch model again.
model = onnx.load(model_path)

# Validate the graph structure; report a failed check instead of raising so the
# notebook keeps running and the reason is visible in the cell output.
try:
    onnx.checker.check_model(model)
    print("ONNX model is valid.")
except onnx.checker.ValidationError as e:
    print(f"Model check failed: {e}")


ONNX model is valid.


In [6]:
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Create the ONNX Runtime session for the exported graph.
# `model_path` is not defined in this cell; it must already exist in the kernel
# from an earlier cell. The execution provider is named explicitly because
# onnxruntime builds that ship more than one provider require the caller to
# choose; CPU is available in every build.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Sample text fed through the model below
text = "To summarise, I think all orcs are bad they are described as evil beings with no soul by Tolkien who studied linguistics."

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

def preprocess_input(user_input, tokenizer, max_length=124):
    """Tokenize ``user_input`` and return its arrays for ONNX Runtime.

    Args:
        user_input: Text handed to ``tokenizer.encode_plus``.
        tokenizer: Object exposing ``encode_plus``; called with truncation on
            and ``padding="max_length"``.
        max_length: Length forwarded to the tokenizer for truncation/padding.

    Returns:
        Tuple ``(input_ids, attention_mask)``, each the result of calling
        ``.numpy()`` on the corresponding tokenizer output.
    """
    encoding = tokenizer.encode_plus(
        user_input,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Convert the framework tensors to plain arrays, ids first and mask second.
    ids_array = encoding['input_ids'].numpy()
    mask_array = encoding['attention_mask'].numpy()
    return ids_array, mask_array

# Tokenize and preprocess the input text
input_ids_np, attention_mask_np = preprocess_input(text, tokenizer)

# Get input names from the model
input_names = [input.name for input in session.get_inputs()]

# Prepare the input dictionary (positional: first graph input receives the
# token ids, second receives the attention mask)
input_dict = {input_names[0]: input_ids_np, input_names[1]: attention_mask_np}

# Run the model, requesting only the first graph output.
# NOTE(review): presumably these are the LM logits named 'output' at export
# time - confirm. The graph may expose further outputs (e.g. cached key/values
# or encoder states); running argmax + decode over those, as the previous
# `for batch in outputs` loop did, prints meaningless text.
logits_name = session.get_outputs()[0].name
outputs = session.run([logits_name], input_dict)
logits = outputs[0]

# Post-process and print the output, one decoded string per batch row.
# This is a single forward pass followed by a greedy argmax, not an iterative
# generation loop.
for sequence_logits in logits:
    # Convert logits to token IDs (argmax over the last axis)
    token_ids = np.argmax(sequence_logits, axis=-1)
    # Decode and print the output text
    output_text = tokenizer.decode(token_ids, skip_special_tokens=True)
    print(output_text)


InvalidGraph: ignored