# Making Transformers Efficient in Production

In previous notebooks, we've seen how transformers can be fine-tuned for various tasks. However, in many situations, no matter how good the metric is, a model is not very useful if it's too slow or too large to meet the business requirements of the application. The obvious alternative is to train a faster and more compact model, but reducing model capacity usually degrades performance. What if we need a compact yet highly accurate model?

In this notebook we'll cover four techniques, used together with Open Neural Network Exchange (ONNX) and ONNX Runtime (ORT), that reduce the prediction time and memory footprint of transformers:

  1. *Knowledge distillation*
  2. *Quantization*
  3. *Pruning*
  4. *Graph Optimization*

We'll also see how the techniques can be combined to produce significant performance gains. This is how the Roblox engineering team [scaled BERT to serve 1+ billion daily requests on CPUs](https://medium.com/@quocnle/how-we-scaled-bert-to-serve-1-billion-daily-requests-on-cpus-d99be090db26): they found that knowledge distillation and quantization improved the latency and throughput of their BERT classifier by over a factor of 30!

To illustrate the benefits and trade-offs associated with each technique, we'll use intent detection (an important component of text-based assistants), where low latency is critical for maintaining a conversation in real time.
Along the way, we'll also learn how to create custom trainers and perform hyperparameter search, and gain a sense of what it takes to implement cutting-edge research with the Transformers library.

## Intent Detection as a Case Study

Let's suppose we're trying to build a text-based assistant for a company's call center that handles banking and bookings without human interaction. Based on a wide variety of natural language text (the input from the user), our assistant needs to classify the input into a set of predefined actions or *intents*.

For example:

*Message*: Hey, I'd like to rent a vehicle on Nov 1st.

The assistant will classify this as a *car rental intent*, which then triggers an action and a response.

To be robust in a production environment, the assistant must also be able to identify out-of-scope (*oos*) queries that don't belong to any of the predefined intents.

As a baseline, a BERT model has been fine-tuned to an accuracy of around 94% on the [CLINC150 dataset](https://arxiv.org/abs/1909.02027). This dataset includes 22,500 in-scope queries across 150 intents and 10 domains like banking and travel, and also includes 1,200 out-of-scope queries that belong to an *oos* intent class.

> **Note:** In practice, we would gather an in-house dataset, but using public data is a great way to iterate quickly and generate preliminary results.

First, let's load the model and wrap it in a text-classification pipeline.

In [9]:
!pip install -r requirements.txt -q

In [10]:
from transformers import pipeline

bert_ckpt = "transformersbook/bert-base-uncased-finetuned-clinc"
pipe = pipeline(
    task="text-classification",
    model=bert_ckpt
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [11]:
query = """Hey, I'd like to rent a vehicle from Nov 1st to Nov 15th in
Paris and I need a 15 passenger van"""
pipe(query)

[{'label': 'car_rental', 'score': 0.5490034222602844}]
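
The pipeline returns a list with one dictionary per prediction, so the label and score are read from the first element. As a minimal sketch, the snippet below works on the output shown above, copied in as a literal so that it runs without loading a model:

```python
# Output copied from the cell above so that the snippet runs without a model
output = [{"label": "car_rental", "score": 0.5490034222602844}]

# The first (and here only) element holds the top prediction
top = output[0]
print(f"Predicted intent: {top['label']} (score {top['score']:.3f})")
```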

The predicted intent looks correct. Next, let's create a benchmark to evaluate the performance of our baseline model.

## Creating a Performance Benchmark

Like other machine learning models, deploying transformers in production environments involves a trade-off among several constraints, the most common being:

*Model Performance*

How well does our model perform on a well-crafted test set that reflects production data? This is especially important when the cost of making errors is large (and best mitigated with a human in the loop), or when we run inference on millions of examples, where small improvements to the model metrics can translate into large gains in aggregate.

*Latency*

How fast can our model deliver predictions? We usually care about latency in real-time environments that deal with a lot of traffic, like how Stack Overflow needed a classifier to quickly [detect unwelcome comments on the website](https://stackoverflow.blog/2020/04/09/the-unfriendly-robot-automatically-flagging-unwelcoming-comments/).
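
One way to measure latency is sketched below, under the assumption that the model can be treated as any Python callable: run a few warm-up calls, time repeated calls with `time.perf_counter`, and report the mean and standard deviation in milliseconds. The names `measure_latency` and `dummy_model`, as well as the metric keys, are hypothetical and chosen for illustration only.

```python
import statistics
import time


def measure_latency(predict, query, n_warmup=3, n_runs=20):
    """Time repeated calls to `predict(query)` and summarize them in milliseconds."""
    # Warm-up calls so one-off costs (caching, lazy initialization) don't skew timings
    for _ in range(n_warmup):
        predict(query)

    latencies_ms = []
    for _ in range(n_runs):
        start = time.perf_counter()
        predict(query)
        latencies_ms.append((time.perf_counter() - start) * 1000)

    return {
        "time_avg_ms": statistics.mean(latencies_ms),
        "time_std_ms": statistics.stdev(latencies_ms),
    }


def dummy_model(text):
    # Hypothetical stand-in for a real pipeline call
    time.sleep(0.001)
    return [{"label": "car_rental", "score": 0.5}]


stats = measure_latency(dummy_model, "What is the pin number for my account?")
print(f"Average latency (ms) - {stats['time_avg_ms']:.2f} +/- {stats['time_std_ms']:.2f}")
```

Reporting the spread alongside the mean matters because latency on a shared machine can vary noticeably from run to run.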

*Memory*

How can we deploy billion-parameter models like GPT or T5, which require gigabytes of disk storage and RAM? Memory plays an especially important role on mobile and edge devices, where we have to generate predictions without access to a cloud server.
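
As a rough back-of-the-envelope sketch, the memory needed just to hold a model's weights is the parameter count multiplied by the bytes per parameter. The parameter count below (about 110 million, roughly the size of a BERT-base model) is an approximate figure assumed for illustration, and `estimate_weights_mb` is a hypothetical helper. The real footprint will be larger because of activations and framework overhead.

```python
def estimate_weights_mb(num_params, bytes_per_param=4):
    """Approximate size of the weights alone, in MiB (1 MiB = 1024 * 1024 bytes)."""
    return num_params * bytes_per_param / (1024 * 1024)


# Assumed, approximate parameter count for a BERT-base sized model
bert_base_params = 110_000_000

fp32_mb = estimate_weights_mb(bert_base_params, bytes_per_param=4)  # 32-bit floats
int8_mb = estimate_weights_mb(bert_base_params, bytes_per_param=1)  # 8-bit integers

print(f"FP32 weights: ~{fp32_mb:.0f} MiB")
print(f"INT8 weights: ~{int8_mb:.0f} MiB")
```

The fourfold gap between the two numbers gives a sense of why storing weights at lower precision is attractive on memory-constrained devices.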

Failing to address these constraints might result in:

* Poor user experience
* Ballooned cloud costs for just a few user requests

To explore how the four compression techniques can be used to optimize these constraints, let's create a benchmark class that measures each of these quantities for a given pipeline and test set.

In [12]:
# Skeleton of the benchmark
class PerformanceBenchmark:
  def __init__(
      self,
      pipeline,
      dataset,
      optim_type="BERT baseline"
  ):
    self.pipeline = pipeline
    self.dataset = dataset
    self.optim_type = optim_type

  def compute_accuracy(self):
    # Performance
    pass

  def compute_size(self):
    # Memory
    pass

  def time_pipeline(self):
    # Latency
    pass

  def run_benchmark(self):
    metrics = {}
    metrics[self.optim_type] = self.compute_size()
    metrics[self.optim_type].update(self.time_pipeline())
    metrics[self.optim_type].update(self.compute_accuracy())
    return metrics

We've defined `optim_type` to keep track of the results of the different optimization techniques, and we use the `run_benchmark()` method to collect all the metrics in a dictionary keyed by `optim_type`.

Let's add some flesh to the skeleton by evaluating the accuracy of our baseline model on the CLINC150 dataset. First, let's download it.
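
Note that `run_benchmark()` calls `.update()` on the value returned by `compute_size()`, so each of the three methods has to return a dictionary once it is implemented; with the `pass` placeholders they return `None` and the call would fail. The sketch below illustrates the resulting structure with a hypothetical `DummyBenchmark` class whose methods return made-up placeholder values. The metric key names are assumptions, not taken from the skeleton.

```python
class DummyBenchmark:
    """Hypothetical stand-in that mirrors the structure of PerformanceBenchmark."""

    def __init__(self, optim_type="BERT baseline"):
        self.optim_type = optim_type

    # The three methods below return made-up placeholder values
    def compute_size(self):
        return {"size_mb": 0.0}

    def time_pipeline(self):
        return {"time_avg_ms": 0.0, "time_std_ms": 0.0}

    def compute_accuracy(self):
        return {"accuracy": 0.0}

    def run_benchmark(self):
        # Same merging logic as in the skeleton above
        metrics = {}
        metrics[self.optim_type] = self.compute_size()
        metrics[self.optim_type].update(self.time_pipeline())
        metrics[self.optim_type].update(self.compute_accuracy())
        return metrics


# Results from several runs can be merged because each is keyed by optim_type
results = {}
for name in ["BERT baseline", "Distillation"]:
    results.update(DummyBenchmark(optim_type=name).run_benchmark())

print(results)
```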

In [16]:
dataset_path = "clinc_oos"
from datasets import get_dataset_config_names
get_dataset_config_names(
    path=dataset_path
)

['small', 'imbalanced', 'plus']

In [17]:
# We'll use the plus subset
from datasets import load_dataset
clinc = load_dataset(
    path=dataset_path,
    name="plus",
)

In [18]:
clinc

DatasetDict({
    train: Dataset({
        features: ['text', 'intent'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['text', 'intent'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['text', 'intent'],
        num_rows: 5500
    })
})

In [19]:
clinc["test"][2]

{'text': 'how would they say butter in zambia', 'intent': 61}

The `plus` subset contains the *oos* intents. Each example in the dataset consists of a `text` field and an integer `intent` id.

In [31]:
clinc["test"].features['intent']

ClassLabel(num_classes=151, names=['restaurant_reviews', 'nutrition_info', 'account_blocked', 'oil_change_how', 'time', 'weather', 'redeem_rewards', 'interest_rate', 'gas_type', 'accept_reservations', 'smart_home', 'user_name', 'report_lost_card', 'repeat', 'whisper_mode', 'what_are_your_hobbies', 'order', 'jump_start', 'schedule_meeting', 'meeting_schedule', 'freeze_account', 'what_song', 'meaning_of_life', 'restaurant_reservation', 'traffic', 'make_call', 'text', 'bill_balance', 'improve_credit_score', 'change_language', 'no', 'measurement_conversion', 'timer', 'flip_coin', 'do_you_have_pets', 'balance', 'tell_joke', 'last_maintenance', 'exchange_rate', 'uber', 'car_rental', 'credit_limit', 'oos', 'shopping_list', 'expiration_date', 'routing', 'meal_suggestion', 'tire_change', 'todo_list', 'card_declined', 'rewards_balance', 'change_accent', 'vaccines', 'reminder_update', 'food_last', 'change_ai_name', 'bill_due', 'who_do_you_work_for', 'share_location', 'international_visa', 'calend

In [32]:
intents = clinc["test"].features["intent"]

In [34]:
int2str = {i:s for i, s in enumerate(intents.names)}
int2str

{0: 'restaurant_reviews',
 1: 'nutrition_info',
 2: 'account_blocked',
 3: 'oil_change_how',
 4: 'time',
 5: 'weather',
 6: 'redeem_rewards',
 7: 'interest_rate',
 8: 'gas_type',
 9: 'accept_reservations',
 10: 'smart_home',
 11: 'user_name',
 12: 'report_lost_card',
 13: 'repeat',
 14: 'whisper_mode',
 15: 'what_are_your_hobbies',
 16: 'order',
 17: 'jump_start',
 18: 'schedule_meeting',
 19: 'meeting_schedule',
 20: 'freeze_account',
 21: 'what_song',
 22: 'meaning_of_life',
 23: 'restaurant_reservation',
 24: 'traffic',
 25: 'make_call',
 26: 'text',
 27: 'bill_balance',
 28: 'improve_credit_score',
 29: 'change_language',
 30: 'no',
 31: 'measurement_conversion',
 32: 'timer',
 33: 'flip_coin',
 34: 'do_you_have_pets',
 35: 'balance',
 36: 'tell_joke',
 37: 'last_maintenance',
 38: 'exchange_rate',
 39: 'uber',
 40: 'car_rental',
 41: 'credit_limit',
 42: 'oos',
 43: 'shopping_list',
 44: 'expiration_date',
 45: 'routing',
 46: 'meal_suggestion',
 47: 'tire_change',
 48: 'todo_list

In [36]:
int2str.get(61)

'translate'

We can use either the custom `int2str` dictionary above or the built-in `int2str()` method of the `ClassLabel` feature to translate intent ids to labels.

In [37]:
sample = clinc["test"][42]
sample

{'text': 'transfer $100 from my checking to saving account', 'intent': 133}

In [38]:
intents.int2str(sample["intent"]), int2str.get(sample["intent"])

('transfer', 'transfer')
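
The reverse direction, from label string to id, is what we need in order to compare the pipeline's string predictions against the integer labels in the dataset. As a minimal sketch in plain Python, using only the first three intent names from the listing above rather than the full set of 151, both mappings can be built and checked for consistency like this:

```python
# A short subset of the intent names for illustration (the real dataset has 151)
names = ["restaurant_reviews", "nutrition_info", "account_blocked"]

# id -> label, built the same way as the custom dictionary above
int2str = {i: s for i, s in enumerate(names)}
# label -> id, the inverse mapping
str2int = {s: i for i, s in int2str.items()}

# A round trip through both mappings should return the original id
assert all(str2int[int2str[i]] == i for i in int2str)
print(str2int["account_blocked"])
```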

### Compute_accuracy

In [40]:
from datasets import load_metric

# Note: load_metric is deprecated in recent releases of datasets
# in favor of the separate evaluate library
accuracy_score = load_metric("accuracy")

def compute_accuracy(self):
  """
  Overrides PerformanceBenchmark.compute_accuracy() method
  """

  preds, labels = [], []
  for example in self.dataset:
    # The pipeline returns a list with one dict per prediction
    pred = self.pipeline(example["text"])[0]["label"]
    label = example["intent"]
    # Convert the predicted label string to its integer id
    preds.append(intents.str2int(pred))
    labels.append(label)
  accuracy = accuracy_score.compute(
      predictions=preds,
      references=labels,
  )
  print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
  return accuracy

PerformanceBenchmark.compute_accuracy = compute_accuracy
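
To make explicit what the accuracy metric reports, here is a minimal plain-Python sketch: the fraction of predictions that exactly match their reference label. The `accuracy` function and the intent ids below are made up for illustration and are not the library's implementation.

```python
def accuracy(predictions, references):
    """Fraction of predictions that exactly match their reference label."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == r for p, r in zip(predictions, references))
    return {"accuracy": correct / len(references)}


# Made-up intent ids for illustration: three of the four predictions match
preds = [61, 133, 42, 40]
labels = [61, 133, 42, 39]
print(accuracy(preds, labels))
```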
