
# 6.3.6 - Inference Optimization Strategies for QA Models

This notebook demonstrates practical techniques for optimizing inference in QA systems, focusing on performance, latency, and memory efficiency.

Covered strategies:
- Quantization (BitNet demo)
- Flash Attention
- Distillation
- Pruning
- Batch Inference


In [None]:

!pip install transformers accelerate optimum sentence-transformers


## Quantization Example (BitNet-style simulation)

In [None]:

# Simulate BitNet-style (ternary) weight quantization on a tiny GPT-2 model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def fake_ternary_quantize_(module, eps=1e-5):
    """Fake-quantize every 2-D weight tensor inside ``module`` in place.

    Each matrix ``W`` is overwritten with ``scale * clip(round(W / scale), -1, 1)``
    where ``scale = mean(|W|)``, so every entry becomes one of
    ``{-scale, 0, +scale}``. Tensors stay in floating point: this only imitates
    the numerical effect of ternary weights and saves no memory.

    Args:
        module: a ``torch.nn.Module`` whose parameters are rewritten.
        eps: lower bound on the scale, guarding against an all-zero matrix.

    Returns:
        The number of weight tensors that were rewritten.
    """
    rewritten = 0
    with torch.no_grad():
        for weight in module.parameters():
            if weight.ndim != 2:  # leave 1-D tensors (biases, norm parameters) untouched
                continue
            scale = weight.abs().mean().clamp(min=eps)
            weight.copy_((weight / scale).round().clamp(-1, 1) * scale)
            rewritten += 1
    return rewritten


model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

# Quantize only the decoder blocks; the token/position embeddings and the LM head
# are left in full precision.
num_quantized = fake_ternary_quantize_(model.transformer.h)
print("Weight matrices ternarized:", num_quantized)

prompt = "Q: What is relativity? Context: It is the dependence of physical phenomena on relative motion."
inputs = tokenizer(prompt, return_tensors="pt")
# GPT-2 defines no pad token; passing it explicitly avoids the generate() warning.
outputs = model.generate(**inputs, max_length=50, pad_token_id=tokenizer.eos_token_id)

print("Quantized Model Simulation:", tokenizer.decode(outputs[0], skip_special_tokens=True))


## Flash Attention for Efficient Memory Use

In [None]:

import importlib.util

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

FLASH_MODEL_ID = "tiiuae/falcon-rw-1b"

# FlashAttention-2 needs a CUDA GPU, the `flash_attn` package, and half-precision
# weights. Check up front instead of letting `from_pretrained` fail.
flash_available = (
    torch.cuda.is_available()
    and importlib.util.find_spec("flash_attn") is not None
)
# NOTE(review): Falcon's FlashAttention-2 code path appears to reject ALiBi position
# biases; skip it when the checkpoint's config enables ALiBi -- confirm against the
# installed transformers version.
uses_alibi = bool(getattr(AutoConfig.from_pretrained(FLASH_MODEL_ID), "alibi", False))
use_flash = flash_available and not uses_alibi

if use_flash:
    # `attn_implementation` replaces the deprecated `use_flash_attention_2=True` flag.
    load_kwargs = {"attn_implementation": "flash_attention_2", "torch_dtype": torch.float16}
else:
    load_kwargs = {}
    print("FlashAttention-2 not usable here; falling back to the default attention implementation.")

model = AutoModelForCausalLM.from_pretrained(FLASH_MODEL_ID, **load_kwargs)
if use_flash:
    model = model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained(FLASH_MODEL_ID)

input_text = "Q: Who discovered gravity? Context: Isaac Newton developed the law of universal gravitation."
# Inputs must live on the same device as the model weights.
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=64, pad_token_id=tokenizer.eos_token_id)

print("Flash Attention Answer:", tokenizer.decode(outputs[0], skip_special_tokens=True))


## Model Distillation

In [None]:

from transformers import BertForQuestionAnswering, DistilBertForQuestionAnswering


def count_parameters_millions(network):
    """Return the total number of parameters in ``network``, in millions."""
    return sum(p.numel() for p in network.parameters()) / 1e6


teacher = BertForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
# Load the SQuAD-distilled checkpoint so the student carries a trained QA head.
# The plain "distilbert-base-uncased" checkpoint has no QA head, so
# `DistilBertForQuestionAnswering` would initialise one randomly.
# NOTE(review): this student was not necessarily distilled from the exact teacher
# above; the cell only compares model sizes.
student = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")

print("Teacher model size:", count_parameters_millions(teacher), "M")
print("Student model size:", count_parameters_millions(student), "M")


## Model Pruning Example

In [None]:

import torch
from torch.nn.utils import prune
from transformers import AutoModelForQuestionAnswering

PRUNING_RATIO = 0.3  # fraction of weights zeroed in every linear layer

# NOTE(review): no `prune_layers` method is documented on OpenVINO's
# OVModelForQuestionAnswering, so pruning is done on the PyTorch model with
# torch's built-in magnitude pruning. If an OpenVINO IR is needed, save the pruned
# model and export it afterwards.
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")

linear_layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
for layer in linear_layers:
    # Zero the PRUNING_RATIO fraction of weights with the smallest absolute value...
    prune.l1_unstructured(layer, name="weight", amount=PRUNING_RATIO)
    # ...then fold the mask into the weight so the model has ordinary parameters again.
    prune.remove(layer, "weight")

# Report the sparsity actually achieved rather than assuming the call worked.
total_weights = sum(layer.weight.numel() for layer in linear_layers)
zero_weights = sum(int((layer.weight == 0).sum()) for layer in linear_layers)
print(f"Linear-layer sparsity after pruning: {zero_weights / total_weights:.1%}")
print("Model pruned successfully.")


## Batch Inference

In [None]:

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load a matching tokenizer/model pair here instead of relying on whatever
# `tokenizer` and `model` happen to be left over from earlier cells (the last
# `tokenizer` assigned above belongs to a causal LM, not to a QA model).
QA_MODEL_ID = "distilbert-base-uncased-distilled-squad"
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_ID)
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_ID)

questions = ["What is AI?", "What is ML?"]
contexts = ["AI is a field of computer science.", "ML is a branch of AI."]
assert len(questions) == len(contexts), "each question needs exactly one context"

# Padding aligns the two sequences to a common length so they form one batch.
inputs = qa_tokenizer(questions, contexts, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():  # inference only: skip autograd bookkeeping
    outputs = qa_model(**inputs)

print("Batch Inference - Start logits:", outputs.start_logits.shape)
