In [None]:
# deployment_challenges.ipynb

# -------------------------------
# 1. Setup & Imports
# -------------------------------
!pip install transformers onnx onnxruntime -q

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import onnx
import onnxruntime as ort
import time

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on:", device)

# -------------------------------
# 2. Load & Prepare Model
# -------------------------------
# Load the pretrained "gpt2" checkpoint, move it to the selected device and
# switch it to evaluation mode; load the matching tokenizer alongside it.
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Tokenize a short prompt; only the token-id tensor is used further down.
text = "Artificial intelligence is"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
input_ids = input_ids.to(device)

# -------------------------------
# 3. Export to ONNX
# -------------------------------
onnx_path = "gpt2_export.onnx"

# Trace the model on the sample prompt and serialize it to ONNX.
# The batch and sequence axes are declared dynamic on BOTH the input and the
# named "logits" output. Without the output entry, the exporter pins the
# logits shape to the traced example (batch 1, len(prompt)). Any other prompt
# length would then disagree with the graph's declared output shape.
torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    do_constant_folding=True,  # fold constant subgraphs at export time
    opset_version=13,
)

print(f"✅ Model exported to {onnx_path}")

# -------------------------------
# 4. Inference with ONNX Runtime
# -------------------------------
def onnx_infer(onnx_path, input_ids):
    """Run one forward pass of an exported ONNX model with ONNX Runtime on CPU.

    Args:
        onnx_path: Path of the ``.onnx`` file to load.
        input_ids: Token-id tensor fed to the graph input named
            ``"input_ids"``. It is copied to host memory and converted to a
            NumPy array before inference.

    Returns:
        A ``(first_output, seconds)`` tuple: the first output returned by
        ``session.run`` and the elapsed time of that ``run`` call alone.
        Session creation and input conversion are deliberately excluded
        from the measurement.
    """
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_inputs = {"input_ids": input_ids.cpu().numpy()}
    # perf_counter is monotonic and high-resolution; time.time() is neither,
    # which makes it unreliable for sub-second intervals.
    start = time.perf_counter()
    ort_outs = session.run(None, ort_inputs)  # None -> fetch every output
    elapsed = time.perf_counter() - start
    return ort_outs[0], elapsed

# Run the exported graph once on CPU and report the timing of that run.
onnx_logits, onnx_time = onnx_infer(onnx_path, input_ids)
print(f"⚡ ONNX Inference Time (CPU): {onnx_time:.4f}s")

# -------------------------------
# 5. Compare with PyTorch Inference
# -------------------------------
# Time one PyTorch forward pass on whichever device the model was moved to.
# CUDA work is queued asynchronously, so the stream is synchronized before
# reading the clock on each side. Otherwise only the kernel-launch overhead
# would be measured.
with torch.no_grad():
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    torch_logits = model(input_ids).logits
    if device == "cuda":
        torch.cuda.synchronize()
    end = time.perf_counter()
# Label the run with the device actually used instead of always claiming GPU.
device_label = "GPU" if device == "cuda" else "CPU"
print(f"⚡ PyTorch Inference Time ({device_label}): {end - start:.4f}s")

# Sanity-check the export. ONNX Runtime and PyTorch logits should agree up to
# small floating-point differences; a large value points to an export problem.
max_abs_diff = (
    (torch_logits.cpu() - torch.from_numpy(onnx_logits)).abs().max().item()
)
print(f"Max |PyTorch - ONNX| logit difference: {max_abs_diff:.2e}")

# -------------------------------
# 6. Deployment Challenges Summary
# -------------------------------
from IPython.display import Markdown

# Summary table; as the final expression of the cell it is rendered as output.
Markdown("""
### 🧠 Deployment Challenges Overview

| Challenge            | Notes |
|----------------------|-------|
| **Model Size**       | GPT-2 (500MB), difficult for edge |
| **ONNX Accuracy**    | Matches FP32, but slow without GPU EP |
| **Quantization**     | ONNX supports INT8 — needs calibration |
| **Tokenization I/O** | Slowest part often not the model, but token I/O |
| **Runtime Support**  | ONNX > ONNX Runtime > Edge device support |
""")
