# Production Deployment & Monitoring

Shipping a model requires more than training: you must serialize artifacts, optimize them for deployment, and monitor live performance. This notebook walks through TorchScript/ONNX export, quantization, and monitoring hooks such as latency tracking and drift detection.

## Learning Objectives

- Export PyTorch models to TorchScript and ONNX formats.
- Apply dynamic quantization for lightweight CPU inference.
- Benchmark latency and log alerts for production metrics.
- Draft config snippets for inference services (e.g., TorchServe).

## TorchScript Export

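TorchScript serializes a model into a form that runs without the Python interpreter, so it can be loaded from C++ (LibTorch) or served with TorchServe. The cell below both traces and scripts a simple model, then saves and reloads the traced version.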

In [None]:
import torch
import torch.nn as nn

class InferenceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.net(x)

model = InferenceModel().eval()
sample = torch.randn(1, 16)

traced = torch.jit.trace(model, sample)
scripted = torch.jit.script(model)
torch.jit.save(traced, "notebooks/03_advanced/inference_model_traced.pt")
reloaded = torch.jit.load("notebooks/03_advanced/inference_model_traced.pt")
print("Traced output", reloaded(sample))
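
Tracing and scripting give the same result for a plain feed-forward model, but they differ once `forward` contains data-dependent control flow. The sketch below uses a hypothetical `GatedModel` (not part of this notebook, with fixed weights so the numbers are reproducible) to show that a traced module replays only the branch taken by the example input, while a scripted module keeps the `if`/`else`.

```python
import warnings

import torch
import torch.nn as nn


class GatedModel(nn.Module):
    """Hypothetical model whose forward pass branches on the input values."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        # Fixed weights keep the example deterministic.
        nn.init.constant_(self.linear.weight, 0.5)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        # Data-dependent branch: tracing records only the path the example input takes.
        if x.sum() > 0:
            return self.linear(x)
        return torch.zeros([x.shape[0], 1])


gated = GatedModel().eval()
positive = torch.ones(1, 4)
negative = -torch.ones(1, 4)

with warnings.catch_warnings():
    # Silence the tracer's warning about the data-dependent branch.
    warnings.simplefilter("ignore")
    traced_gated = torch.jit.trace(gated, positive)
    scripted_gated = torch.jit.script(gated)

with torch.no_grad():
    eager_out = gated(negative)
    traced_out = traced_gated(negative)      # replays the "positive" branch
    scripted_out = scripted_gated(negative)  # evaluates the branch at run time

print("eager   :", eager_out)
print("traced  :", traced_out)
print("scripted:", scripted_out)
```

For the `InferenceModel` above, which has no branching, either export path should behave the same; scripting starts to matter when the forward pass depends on input values.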


## ONNX Export

ONNX provides cross-framework portability, so the exported graph can run on optimized runtimes such as ONNX Runtime or TensorRT.

In [None]:
torch.onnx.export(
    model,
    sample,
    "notebooks/03_advanced/inference_model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    opset_version=17,
)
print("ONNX export complete")


## Quantization

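Dynamic quantization stores weights as int8 and quantizes activations on the fly at inference time, which shrinks the model and typically lowers CPU latency. It works best for models dominated by `nn.Linear` layers.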

In [None]:
# torch.ao.quantization is the current home of the quantization API.
quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print("Quantized output", quantized(sample))
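
To check the size claim, one option is to serialize the `state_dict` before and after quantization and compare byte counts. The sketch below is an illustration under assumptions: `serialized_size_bytes` is a hypothetical helper, and the model is a larger stand-in than `InferenceModel` so the difference is easy to see. It also prints the largest output difference, since quantization is lossy.

```python
import io

import torch
import torch.nn as nn
import torch.ao.quantization


def serialized_size_bytes(module: nn.Module) -> int:
    """Size of the module's state_dict when saved to an in-memory buffer."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes


# Hypothetical linear-heavy model, larger than the notebook's so the effect is visible.
fp32_model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 256)).eval()
int8_model = torch.ao.quantization.quantize_dynamic(
    fp32_model, {nn.Linear}, dtype=torch.qint8
)

fp32_bytes = serialized_size_bytes(fp32_model)
int8_bytes = serialized_size_bytes(int8_model)
print(f"fp32: {fp32_bytes / 1024:.1f} KiB, int8: {int8_bytes / 1024:.1f} KiB")

# Compare outputs with a tolerance in mind rather than expecting exact equality.
batch = torch.randn(8, 256)
with torch.no_grad():
    max_abs_diff = (fp32_model(batch) - int8_model(batch)).abs().max().item()
print(f"max abs output difference: {max_abs_diff:.4f}")
```

Whether the accuracy loss is acceptable depends on the task, so it is worth evaluating the quantized model on a held-out set before shipping it.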


## Latency Benchmarking

Measure median, p99, and average latency to set realistic service-level objectives (SLOs).

In [None]:
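import math
import statistics
import time

def benchmark_inference(model, runs=50, warmup=5):
    latencies = []
    with torch.no_grad():
        # Warm-up calls are excluded so one-time setup costs do not skew the timings.
        for _ in range(warmup):
            model(sample)
        for _ in range(runs):
            start = time.perf_counter()
            model(sample)
            latencies.append((time.perf_counter() - start) * 1000)
    latencies.sort()
    # Nearest-rank p99: the smallest value with at least 99% of samples at or below it.
    p99_index = math.ceil(0.99 * len(latencies)) - 1
    return {
        "p50_ms": statistics.median(latencies),
        "p99_ms": latencies[p99_index],
        "avg_ms": sum(latencies) / len(latencies),
    }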

stats = benchmark_inference(traced)
print(stats)
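
Tail percentiles are sensitive to sample count: under the nearest-rank definition, the p99 of 50 samples is simply the maximum. The sketch below is a framework-free illustration with two hypothetical helpers, `nearest_rank_percentile` and `time_calls`, that take more samples and report p50/p95/p99 for any zero-argument callable. The workload here is a stand-in; in this notebook it could be a lambda wrapping the model call.

```python
import math
import time


def nearest_rank_percentile(values, q):
    """Return the q-th percentile (0 < q <= 100) using the nearest-rank method."""
    if not values:
        raise ValueError("values must be non-empty")
    ordered = sorted(values)
    # 1-based rank; multiply before dividing to avoid float rounding surprises.
    rank = math.ceil(q * len(ordered) / 100)
    return ordered[max(rank, 1) - 1]


def time_calls(fn, runs=200, warmup=10):
    """Time a zero-argument callable and return per-call latencies in milliseconds."""
    for _ in range(warmup):  # discard warm-up iterations
        fn()
    latencies = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        latencies.append((time.perf_counter() - start) * 1000)
    return latencies


# Stand-in workload (an assumption for illustration only).
latencies = time_calls(lambda: sum(range(1000)))
summary = {f"p{q}_ms": nearest_rank_percentile(latencies, q) for q in (50, 95, 99)}
print(summary)
```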


## Mini Task – Drift Detection Snapshot

Simulate feature drift by shifting live data means relative to training baselines. Plot mean shifts per feature to flag which dimensions changed.

In [None]:
train_mean = torch.zeros(16)
prod_samples = torch.randn(256, 16) + 0.3  # simulate drift
# TODO: compute mean shifts and visualize per feature


In [None]:
import matplotlib.pyplot as plt

mean_diff = prod_samples.mean(dim=0) - train_mean
plt.bar(range(len(mean_diff)), mean_diff.numpy())
plt.xlabel("Feature index")
plt.ylabel("Mean shift")
plt.title("Feature drift snapshot")
plt.show()
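
A plot shows the shifts, but an automated check needs a decision rule. One simple option, sketched below, is to express each mean shift in units of its standard error and flag features beyond a threshold. The helper `flag_drifted_features`, the unit-variance baseline, and the threshold of 3 are all assumptions made for illustration; a production system would estimate the baseline standard deviation from training data.

```python
import math

import torch


def flag_drifted_features(train_mean, train_std, prod_samples, z_threshold=3.0):
    """Return (drifted feature indices, per-feature z-scores of the mean shift)."""
    n = prod_samples.shape[0]
    shift = prod_samples.mean(dim=0) - train_mean
    # Standard error of the mean under the training-time standard deviation.
    standard_error = train_std / math.sqrt(n)
    z_scores = shift / standard_error
    drifted = torch.nonzero(z_scores.abs() > z_threshold).flatten().tolist()
    return drifted, z_scores


# Assumed baseline: zero mean and unit standard deviation for every feature.
baseline_mean = torch.zeros(16)
baseline_std = torch.ones(16)

live = torch.randn(256, 16)
live[:, :4] += 0.5  # shift only the first four features

drifted, z_scores = flag_drifted_features(baseline_mean, baseline_std, live)
print("Drifted feature indices:", drifted)
```

Because the standard error shrinks with sample size, large windows will flag even tiny shifts; pairing the z-score with a minimum absolute shift is a common way to keep alerts meaningful.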


## Alerting Hooks

Implement logic that inspects latency metrics and flags anomalies (e.g., high p99).

In [None]:
def check_latency_thresholds(metrics, latency_ms=50):
    alerts = []
    # Tail latency above the budget.
    if metrics["p99_ms"] > latency_ms:
        alerts.append("High latency")
    # A large gap between mean and median points to a skewed distribution (outliers).
    if abs(metrics["avg_ms"] - metrics["p50_ms"]) > latency_ms:
        alerts.append("Latency variance")
    return alerts

print(check_latency_thresholds(stats))
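
Returning a list of strings is enough for a notebook, but a service usually routes alerts through its logging pipeline. The sketch below shows one way to do that with the standard `logging` module. The logger name, the `log_latency_alerts` helper, and the example numbers are hypothetical.

```python
import logging

logger = logging.getLogger("inference.monitoring")  # hypothetical logger name


def log_latency_alerts(metrics, budgets_ms, log=logger):
    """Log a warning for each metric above its budget; return the breached metric names."""
    breached = []
    for name, budget in budgets_ms.items():
        value = metrics.get(name)
        if value is not None and value > budget:
            log.warning(
                "latency alert: %s=%.2f ms exceeds budget of %.2f ms",
                name, value, budget,
            )
            breached.append(name)
    return breached


logging.basicConfig(level=logging.WARNING)
example_metrics = {"p50_ms": 4.0, "p99_ms": 72.5, "avg_ms": 6.1}  # made-up numbers
print(log_latency_alerts(example_metrics, {"p50_ms": 20.0, "p99_ms": 50.0}))
```

Keeping budgets in a dictionary makes it straightforward to load them from configuration instead of hard-coding a single threshold.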


## Mini Task – TorchServe Config Snippet

Draft a configuration string for TorchServe's `config.properties`, enabling basic metrics and batch inference.

In [None]:
config_properties = """
# TODO: populate TorchServe config properties
"""


In [None]:
config_properties = """
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
default_workers_per_model=2
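# Batching is set per model (batchSize / maxBatchDelay in ms), not as a global key.
# "inference_model" and "inference_model.mar" are placeholder names.
models={"inference_model": {"1.0": {"defaultVersion": true, "marName": "inference_model.mar", "minWorkers": 2, "maxWorkers": 2, "batchSize": 8, "maxBatchDelay": 100, "responseTimeout": 120}}}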
enable_metrics_api=true
"""
print(config_properties)
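
TorchServe also accepts batching parameters per model when a model is registered through the Management API (the `management_address` above). The sketch below only builds the registration URL; it does not contact a server. The `build_register_url` helper and the archive name `inference_model.mar` are placeholders for illustration.

```python
from urllib.parse import urlencode


def build_register_url(management_address, mar_name, batch_size=8,
                       max_batch_delay=100, initial_workers=2):
    """Build the Management API URL that registers a model with batching enabled."""
    query = urlencode({
        "url": mar_name,
        "batch_size": batch_size,
        "max_batch_delay": max_batch_delay,  # milliseconds
        "initial_workers": initial_workers,
    })
    return f"{management_address.rstrip('/')}/models?{query}"


# "inference_model.mar" is a placeholder archive expected in the model store.
register_url = build_register_url("http://localhost:8081", "inference_model.mar")
print("POST", register_url)
```

Sending a `POST` to that URL (for example with `curl -X POST`) against a running TorchServe instance would register the model with the given batch settings.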


## Comprehensive Exercise – Production Playbook

Create a Markdown string documenting a deployment checklist: artifact versioning, canary rollout, health checks, drift monitoring, and rollback strategy.

In [None]:
def production_playbook():
    # TODO: return multi-line markdown checklist
    raise NotImplementedError


In [None]:
def production_playbook():
    return """\
## Deployment Playbook

- [ ] Capture model metadata (commit hash, training data snapshot, evaluation metrics).
- [ ] Validate TorchScript/ONNX artifacts with unit and integration tests.
- [ ] Deploy a canary instance and compare live metrics against baseline.
- [ ] Confirm health checks pass before shifting traffic to the new version.
- [ ] Monitor latency (p50/p95/p99), error rates, and drift scores in real time.
- [ ] Configure automated rollback if metrics breach SLO thresholds.
- [ ] Schedule a post-deployment review and document learnings.
"""

print(production_playbook())


## Further Reading

- TorchServe documentation: https://pytorch.org/serve/
- ONNX Runtime tuning guides for CPU/GPU inference
- ML observability platforms (Arize, WhyLabs, Fiddler) for advanced monitoring
- “Hidden Technical Debt in Machine Learning Systems” (Sculley et al.)