In [6]:
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

# Checkpoint id used to fetch the tokenizer that matches the classifier.
model_id = "yoshitomo-matsubara/bert-base-uncased-sst2"

# Load the ONNX export stored in the current working directory. The file name
# is passed explicitly because "quantize_model.onnx" is not one of the default
# names optimum looks for (hence the warning emitted when this cell runs).
model = ORTModelForSequenceClassification.from_pretrained(
    ".",
    file_name="quantize_model.onnx",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

The ONNX file quantize_model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.


In [7]:
from transformers import pipeline

# Wrap the ONNX Runtime model and its tokenizer in a transformers pipeline so it
# can be called directly on raw strings.
clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)

In [8]:
# Smoke test on a negative sentence; the recorded output labels it LABEL_0
# (presumably the negative class — confirm against the checkpoint's label map).
clf("I hate you")

[{'label': 'LABEL_0', 'score': 0.9988948702812195}]

In [9]:
# Smoke test on a positive sentence; the recorded output labels it LABEL_1
# (presumably the positive class — confirm against the checkpoint's label map).
clf("I like you")

[{'label': 'LABEL_1', 'score': 0.9995238780975342}]

In [11]:
from time import perf_counter
import numpy as np

# Benchmark input: a fixed e-mail-style paragraph repeated twice. The exact text
# determines the tokenized length printed below, so do not edit it casually.
payload="Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this email in the next 7 days. Best regards and have a nice weekend "*2
# Report how many input ids the tokenizer produces for the payload.
print(f'Payload sequence length: {len(tokenizer(payload)["input_ids"])}')

def measure_latency(pipe, text=None, warmup=10, runs=300):
    """Benchmark a callable pipeline on a single input and summarise its latency.

    Args:
        pipe: Callable invoked as ``pipe(text)``; its return value is ignored.
        text: Input passed to ``pipe`` on every call. When ``None`` (the
            default) the notebook-level ``payload`` variable is used, which
            keeps the original ``measure_latency(pipe)`` call working.
        warmup: Number of untimed calls made before measuring.
        runs: Number of timed calls; must be at least 1.

    Returns:
        A ``(summary, p95_ms)`` tuple: ``summary`` is a human-readable string
        with the P95, mean and standard deviation in milliseconds, and
        ``p95_ms`` is the 95th-percentile latency in milliseconds.

    Raises:
        ValueError: If ``runs`` is smaller than 1 (the statistics would be
            undefined on an empty sample).
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    if text is None:
        # Backward-compatible default: benchmark the notebook-level payload.
        text = payload

    # Warm up so one-off initialisation cost is not part of the measurement.
    for _ in range(warmup):
        pipe(text)

    # Timed run: record the wall-clock duration of each individual call.
    latencies = []
    for _ in range(runs):
        start_time = perf_counter()
        pipe(text)
        latencies.append(perf_counter() - start_time)

    # Compute run statistics, converting seconds to milliseconds.
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    time_p95_ms = 1000 * np.percentile(latencies, 95)

    # The original literal contained an invalid escape sequence (backslash
    # followed by a dash); "+/-" is the intended text and is a valid literal.
    summary = (
        f"P95 latency (ms) - {time_p95_ms}; "
        f"Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f};"
    )
    return summary, time_p95_ms

# Baseline pipeline built directly from the checkpoint id, for comparison
# against the ONNX-backed `clf` pipeline.
vanilla_clx = pipeline("text-classification", model=model_id)

# Each result is a (summary string, P95 latency in ms) tuple.
vanilla_model = measure_latency(vanilla_clx)
quantized_model = measure_latency(clf)

vanilla_summary, vanilla_p95 = vanilla_model
quantized_summary, quantized_p95 = quantized_model

print(f"Vanilla model: {vanilla_summary}")
print(f"Quantized model: {quantized_summary}")

# Speed-up factor is the ratio of the two P95 latencies, rounded to 2 decimals.
speedup = round(vanilla_p95 / quantized_p95, 2)
print(f"Improvement through quantization: {speedup}x")


Payload sequence length: 128
