# PyTorch Deployment & Optimization Playbook


Senior MLEs often need repeatable recipes for squeezing the most out of PyTorch models.
This notebook highlights practical techniques with bite-sized examples that you can copy into
production scripts. We work with a lightweight classifier so you can focus on the tooling, not the math.

**Topics covered**
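- Latency and inference speed optimizations, including profiling with `torch.profiler`
- Memory efficiency patterns (training & inference)
- A regularized training loop with gradient accumulation and early stopping
- Quantization for CPU deployments (dynamic and static)
- Structured pruning to shrink models
- Exporting models to ONNX, TensorRT, and TFLite toolchains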


In [None]:
from __future__ import annotations

import time
from typing import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")


## Baseline model and synthetic data
Keep the network intentionally small: this makes it easy to experiment with tooling.


In [None]:
class TinyClassifier(nn.Module):
    def __init__(self, in_features: int = 128, hidden: Iterable[int] = (256, 128), num_classes: int = 10, dropout_p: float = 0.1):
        super().__init__()
        hidden = tuple(hidden)
        layers = []
        last = in_features
        for width in hidden:
            layers.append(nn.Linear(last, width))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_p))
            last = width
        layers.append(nn.Linear(last, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = TinyClassifier().to(device)
dummy_batch = torch.randn(512, 128, device=device)
dummy_targets = torch.randint(0, 10, (512,), device=device)

print(model(dummy_batch).shape)


## Latency & inference speed
To judge an optimization, benchmark short, warmed-up runs before and after the change.
The helper below keeps timing simple and works on both CPU and GPU.


In [None]:
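@torch.inference_mode()
def benchmark(model: nn.Module, inputs: torch.Tensor, warmup: int = 10, steps: int = 50) -> float:
    """Return the mean wall-clock latency of one forward pass, in seconds."""
    model.eval()
    for _ in range(warmup):
        model(inputs)
    # CUDA kernels launch asynchronously: wait for queued work before reading the clock
    if inputs.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        model(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / steps


base_latency = benchmark(model, dummy_batch)
print(f"Baseline latency: {base_latency * 1e3:.3f} ms per batch")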


In [None]:
# TorchScript for stable inference graphs
scripted = torch.jit.script(model)
script_latency = benchmark(scripted, dummy_batch)
print(f"Scripted latency: {script_latency * 1e3:.3f} ms per batch")

compiled_latency = None
if hasattr(torch, "compile"):
    compiled = torch.compile(model, mode="reduce-overhead")
    compiled_latency = benchmark(compiled, dummy_batch)
    print(f"torch.compile latency: {compiled_latency * 1e3:.3f} ms per batch")


## Finding Bottlenecks with PyTorch Profiler
Before optimizing, you need to know where the time is spent. The PyTorch Profiler is the standard tool for this.
It traces operators on both CPU and GPU, helping you identify which parts of your model are the slowest.

In [None]:
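from torch.profiler import ProfilerActivity, profile, record_function

# Only request CUDA traces when a GPU is actually present
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
    activities.append(ProfilerActivity.CUDA)

model.eval()
with torch.inference_mode():
    model(dummy_batch)  # Warm-up pass so one-time initialization does not dominate the trace
    with profile(activities=activities, record_shapes=True, with_stack=True) as prof:
        with record_function("model_inference"):
            for _ in range(10):
                model(dummy_batch)

sort_key = "cuda_time_total" if device.type == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))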

**Quick wins for latency**
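- Wrap inference in `torch.inference_mode()` to skip autograd tracking entirely; it is stricter and slightly faster than `torch.no_grad()`.
- Use TorchScript (`torch.jit.script`) to produce a serialized graph that runs without the Python interpreter, for example from C++ via LibTorch. TorchScript is no longer under active development, so prefer `torch.export` for new projects.
- On PyTorch 2.x, `torch.compile(model, mode="reduce-overhead")` uses CUDA graphs to cut Python and kernel-launch overhead, which mostly helps small batches on GPU.
- When serving convolutional models on GPU with static input shapes, set `torch.backends.cudnn.benchmark = True` so cuDNN autotunes the fastest algorithm for each shape.
- Set the intra-op thread count with `torch.set_num_threads()` to match the cores assigned to your serving process (for example via CPU affinity or container CPU limits); oversubscribing threads across workers usually hurts latency.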
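The threading advice is easy to check empirically. A minimal sketch, assuming a CPU run and an arbitrary list of candidate thread counts, uses `torch.utils.benchmark.Timer`, which handles warm-up and repeated measurement and accepts a per-measurement `num_threads` argument. The toy `sweep_model` and `sweep_batch` are placeholders; substitute your serving model and a realistic request batch. The module is imported as `torch_benchmark` so it does not shadow the `benchmark` helper defined earlier.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as torch_benchmark

# Placeholder model and batch: swap in your serving model and a realistic request batch.
sweep_model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
sweep_batch = torch.randn(64, 128)


def infer() -> torch.Tensor:
    with torch.inference_mode():
        return sweep_model(sweep_batch)


# Candidate intra-op thread counts (an assumption: use core counts you can actually allocate).
candidates = sorted({1, 2, torch.get_num_threads()})
median_ms = {}
for n_threads in candidates:
    timer = torch_benchmark.Timer(
        stmt="infer()",
        globals={"infer": infer},
        num_threads=n_threads,  # thread count applied while this measurement runs
    )
    # blocked_autorange repeats the statement until at least min_run_time seconds have elapsed
    median_ms[n_threads] = timer.blocked_autorange(min_run_time=0.2).median * 1e3
    print(f"{n_threads:>3} threads: {median_ms[n_threads]:.3f} ms (median)")
```

Pick the smallest thread count past which the median stops improving, then leave the remaining cores to other serving workers.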


## Memory optimization patterns
Even simple models benefit from disciplined memory management, especially when batching.


In [None]:
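def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: torch.amp.GradScaler,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    *,
    mixed_precision: bool = False,
) -> float:
    model.train()
    use_amp = mixed_precision and inputs.is_cuda
    optimizer.zero_grad(set_to_none=True)
    # autocast runs eligible ops (matmuls, convolutions) in float16 on CUDA
    with torch.autocast(device_type=inputs.device.type, dtype=torch.float16, enabled=use_amp):
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets)
    # GradScaler scales the loss to avoid float16 gradient underflow; it is a pass-through when disabled
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return float(loss.item())


# Create the optimizer and scaler once and reuse them: AdamW keeps per-parameter state across steps.
# On older PyTorch releases without torch.amp.GradScaler, use torch.cuda.amp.GradScaler() instead.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")

loss = train_step(model, optimizer, scaler, dummy_batch, dummy_targets, mixed_precision=True)
print(f"Single step loss: {loss:.4f}")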


**Memory checklist**
- Zero gradients with `set_to_none=True` (the default since PyTorch 2.0) so gradient tensors are freed instead of being overwritten with zeros.
- Enable mixed precision (`torch.autocast` plus a `GradScaler`) during CUDA training so many activations are stored in float16, shrinking their footprint.
- Gradient checkpointing (`torch.utils.checkpoint`) trades compute for memory on deep networks.
- Wrap inference in `torch.inference_mode()` so no autograd graph is recorded and intermediate activations can be freed as soon as they are consumed.
- Keep tensor dtypes explicit (for example `torch.float16`, `torch.bfloat16`) for compatibility with quantized or accelerated kernels.
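Gradient checkpointing is easy to retrofit into a module's `forward`. A minimal sketch, using a hypothetical `CheckpointedMLP`: during training each hidden block is wrapped in `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)`, so the activations inside the block are recomputed in the backward pass instead of being stored. The check at the end confirms the gradients match an un-checkpointed copy; the memory saving itself (measurable with `torch.cuda.max_memory_allocated()`) only becomes visible on much larger models.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):
    """Hypothetical deep MLP whose hidden blocks can be recomputed during backward."""

    def __init__(self, width: int = 128, depth: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)
        )
        self.head = nn.Linear(width, 10)
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `block` are discarded and recomputed during backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.head(x)


torch.manual_seed(0)
reference = CheckpointedMLP(use_checkpoint=False)
checkpointed = CheckpointedMLP(use_checkpoint=True)
checkpointed.load_state_dict(reference.state_dict())  # identical weights

ckpt_inputs = torch.randn(32, 128)
for net in (reference, checkpointed):
    net.train()
    net(ckpt_inputs).pow(2).mean().backward()

# Same gradients; the checkpointed model pays for them with an extra forward per block
grads_match = all(
    torch.allclose(p_ref.grad, p_ckpt.grad, atol=1e-6)
    for p_ref, p_ckpt in zip(reference.parameters(), checkpointed.parameters())
)
print(f"Gradients match: {grads_match}")
```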


## Regularization and Training Loop Example
Regularization techniques are essential for training robust models that generalize well. The following runnable example demonstrates a complete training loop that incorporates several key concepts:

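- **Weight Decay**: Set through the `weight_decay` argument of `AdamW`. AdamW applies *decoupled* weight decay (weights are shrunk directly at every step) rather than adding an L2 penalty to the loss; the two are equivalent for plain SGD but not for adaptive optimizers such as Adam.
- **Dropout**: Included in our `TinyClassifier` to reduce neuron co-adaptation.
- **Gradient Accumulation**: Simulates a larger batch by summing gradients over several mini-batches before each optimizer step, which helps when the target batch size does not fit in memory.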
- **Early Stopping**: Monitors validation loss and stops training when it no longer improves, saving compute and preventing overfitting.
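The difference between decoupled weight decay and an L2 penalty is easiest to see on a single step. A minimal sketch: one parameter with a zero task gradient, so only the decay term acts, stepped once with `Adam(weight_decay=...)` (the L2 term `wd * w` is added to the gradient and then normalized by Adam) and once with `AdamW(weight_decay=...)` (the weight is multiplied by `1 - lr * wd`). The learning rate and decay values are arbitrary illustrative choices.

```python
import torch

# Arbitrary illustrative hyperparameters
lr, wd = 0.1, 0.01


def one_step(optimizer_cls) -> float:
    w = torch.nn.Parameter(torch.tensor([2.0]))
    opt = optimizer_cls([w], lr=lr, weight_decay=wd)
    w.grad = torch.zeros_like(w)  # flat loss: only weight decay moves the parameter
    opt.step()
    return w.item()


adam_l2 = one_step(torch.optim.Adam)  # L2 gradient wd * w is rescaled by Adam to roughly lr
adamw = one_step(torch.optim.AdamW)   # decoupled: w * (1 - lr * wd)
print(f"Adam + L2: {adam_l2:.4f}")    # 1.9000
print(f"AdamW:     {adamw:.4f}")      # 1.9980
```

Because Adam divides the L2 gradient by its own running magnitude, the penalty's strength no longer scales with the weight; AdamW keeps the shrinkage proportional to the weight, which is the behavior weight decay is meant to have.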

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# 1. Create synthetic datasets and dataloaders (random labels: this demonstrates mechanics only)
train_data = TensorDataset(dummy_batch, dummy_targets)
val_data = TensorDataset(torch.randn(128, 128, device=device), torch.randint(0, 10, (128,), device=device))
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=128)

# 2. Instantiate model and optimizer with (decoupled) weight decay
training_model = TinyClassifier().to(device)
optimizer = torch.optim.AdamW(training_model.parameters(), lr=1e-3, weight_decay=1e-4)

# 3. Set up training parameters
num_epochs = 20
accumulation_steps = 4  # Effective batch size = 64 * 4 = 256

# 4. Early stopping parameters
patience = 3
patience_counter = 0
best_val_loss = float("inf")
checkpoint_path = "best_model.pth"

for epoch in range(num_epochs):
    # --- Training Phase ---
    training_model.train()
    optimizer.zero_grad(set_to_none=True)
    for i, (inputs, targets) in enumerate(train_loader):
        logits = training_model(inputs)
        loss = F.cross_entropy(logits, targets)
        (loss / accumulation_steps).backward()  # Average gradients over the accumulation window

        # Step every `accumulation_steps` batches and flush leftover gradients at the end of the epoch
        if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

    # --- Validation Phase ---
    training_model.eval()
    val_loss = 0.0
    with torch.inference_mode():
        for inputs, targets in val_loader:
            logits = training_model(inputs)
            val_loss += F.cross_entropy(logits, targets, reduction="sum").item()
    val_loss /= len(val_loader.dataset)
    print(f"Epoch {epoch + 1}, Val Loss: {val_loss:.4f}")

    # --- Early Stopping Check ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(training_model.state_dict(), checkpoint_path)
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Stopping early at epoch {epoch + 1}. Best val loss: {best_val_loss:.4f}")
        break

# Restore the best checkpoint whether or not early stopping triggered
training_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

## Dynamic quantization (CPU-centric)


In [None]:
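import io


def serialized_bytes(m: nn.Module) -> int:
    # Quantized layers keep packed weights outside .parameters(), so measure the saved state_dict
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes


fp32_cpu_model = model.cpu().eval()
quantized_model = torch.ao.quantization.quantize_dynamic(fp32_cpu_model, {nn.Linear}, dtype=torch.qint8)

cpu_batch = dummy_batch.cpu()
q_latency = benchmark(quantized_model, cpu_batch)
print(f"FP32 latency: {benchmark(fp32_cpu_model, cpu_batch) * 1e3:.3f} ms per batch (CPU)")
print(f"Quantized latency: {q_latency * 1e3:.3f} ms per batch (CPU)")
print(f"FP32 size: {serialized_bytes(fp32_cpu_model) / 1e3:.1f} KB")
print(f"INT8 size: {serialized_bytes(quantized_model) / 1e3:.1f} KB")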


## Static Quantization (CNN-centric)
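For convolutional models, post-training static quantization is usually preferred. It adds a calibration step: representative data is fed through the model while observers record activation ranges, which fixes the scale and zero point of every activation ahead of time. Because activations then stay in INT8 between layers instead of being quantized on the fly, static quantization typically runs faster than dynamic quantization on the same CPU.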
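To make the calibration step concrete, the sketch below runs a single `MinMaxObserver` from `torch.ao.quantization.observer` over two hypothetical calibration batches and derives the affine parameters (the default qconfigs may use other observers, such as histogram-based ones, but the idea is the same). With an observed range of [-1, 3] mapped onto the `quint8` grid 0..255, the scale is 4/255 and the zero point is round(1 / (4/255)) = round(63.75) = 64.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# An observer tracks the running min/max of every tensor it sees during calibration
observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# Two hypothetical calibration batches; together they span [-1, 3]
observer(torch.tensor([-1.0, 0.5, 1.0]))
observer(torch.tensor([0.0, 2.0, 3.0]))

scale, zero_point = observer.calculate_qparams()
print(f"scale={scale.item():.6f}, zero_point={zero_point.item()}")

# Quantize an activation with the calibrated parameters and measure the round-trip error
activation = torch.linspace(-1.0, 3.0, steps=9)
q_activation = torch.quantize_per_tensor(activation, scale.item(), int(zero_point.item()), torch.quint8)
max_error = (q_activation.dequantize() - activation).abs().max().item()
print(f"max round-trip error: {max_error:.6f} (scale / 2 = {scale.item() / 2:.6f})")
```

Values inside the calibrated range are reproduced to within about half a quantization step; values outside it are clipped, which is why calibration data should resemble production traffic.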

In [None]:
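import io

from torch.ao.quantization import DeQuantStub, QuantStub, convert, fuse_modules, get_default_qconfig, prepare


class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()  # Eager mode: marks where float inputs become INT8
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(32 * 13 * 13, 10)  # 32x32 -> conv 30x30 -> pool 15x15 -> conv 13x13
        self.dequant = DeQuantStub()  # Converts INT8 outputs back to float

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.conv1(x))
        x = self.pool(x)
        x = self.relu2(self.conv2(x))
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return self.dequant(x)


def serialized_bytes(m: nn.Module) -> int:
    # Quantized modules store packed weights outside .parameters(), so measure the saved state_dict
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes


cnn_model = TinyConvNet().eval().cpu()
dummy_images = torch.randn(16, 3, 32, 32)  # Batch of 16 images

# 0. Pick a quantized backend available on this CPU (fbgemm on x86, qnnpack on ARM)
engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
torch.backends.quantized.engine = engine

# 1. Fuse modules (Conv + ReLU), then record the FP32 baseline latency
fuse_modules(cnn_model, [["conv1", "relu1"], ["conv2", "relu2"]], inplace=True)
fp32_cnn_latency = benchmark(cnn_model, dummy_images)

# 2. Prepare for static quantization (prepare copies the model and inserts observers)
cnn_model.qconfig = get_default_qconfig(engine)
cnn_model_prepared = prepare(cnn_model)

# 3. Calibrate with representative data (random here; use real samples in practice)
with torch.no_grad():
    for _ in range(10):
        cnn_model_prepared(dummy_images)

# 4. Convert to a quantized model
cnn_quantized = convert(cnn_model_prepared)

print(f"Original CNN size: {serialized_bytes(cnn_model) / 1e3:.1f} KB")
print(f"Quantized CNN size: {serialized_bytes(cnn_quantized) / 1e3:.1f} KB")

q_cnn_latency = benchmark(cnn_quantized, dummy_images)
print(f"FP32 CNN latency: {fp32_cnn_latency * 1e3:.3f} ms per batch (CPU)")
print(f"Quantized CNN latency: {q_cnn_latency * 1e3:.3f} ms per batch (CPU)")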

**Quantization tips**
- Dynamic quantization works out of the box for transformer-style `nn.Linear` layers.
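- For CNNs, use post-training static quantization (`prepare`, calibration with representative data, then `convert`); if accuracy drops too far, switch to quantization-aware training (`prepare_qat` + `convert`).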
- Monitor accuracy drift after quantization. Keep the first or last layer in FP32 if needed.
- Quantized models run fastest on CPUs with vector dot product instructions (AVX512-VNNI, ARM dotprod, etc.).
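Monitoring accuracy drift can start with a simple comparison against the FP32 model. A minimal sketch, assuming no labels are at hand: dynamically quantize a hypothetical `fp32_net`, run both versions on the same held-out batch, and report top-1 agreement plus the largest logit difference. With real validation labels, compare task accuracy directly instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical FP32 model and held-out batch; substitute your model and real validation data
fp32_net = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
holdout = torch.randn(1024, 128)
int8_net = torch.ao.quantization.quantize_dynamic(fp32_net, {nn.Linear}, dtype=torch.qint8)

with torch.inference_mode():
    ref_logits = fp32_net(holdout)
    int8_logits = int8_net(holdout)

# Top-1 agreement: fraction of samples where both models predict the same class
agreement = (ref_logits.argmax(dim=1) == int8_logits.argmax(dim=1)).float().mean().item()
max_abs_diff = (ref_logits - int8_logits).abs().max().item()
print(f"Top-1 agreement with FP32: {agreement:.2%}")
print(f"Max |logit difference|: {max_abs_diff:.4f}")
```

If agreement falls below your budget, the per-layer fallback from the tips above (keeping the first or last layer in FP32) is the usual next experiment.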


## Structured pruning


In [None]:
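from torch.nn.utils import prune

pruned_model = TinyClassifier()
linear_layers = [m for m in pruned_model.modules() if isinstance(m, nn.Linear)]
# Structured pruning removes whole output neurons (rows of the weight matrix, dim=0).
# The classifier head is skipped so every class keeps its output unit.
hidden_layers = linear_layers[:-1]

for module in hidden_layers:
    # Zero the 30% of rows with the smallest L2 norm
    prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)

# Make pruning permanent: fold the mask into `weight` and remove the reparameterization hooks
for module in hidden_layers:
    prune.remove(module, "weight")

for idx, module in enumerate(hidden_layers):
    zero_rows = int((module.weight.abs().sum(dim=1) == 0).sum())
    print(f"Hidden layer {idx}: {zero_rows} / {module.out_features} neurons pruned")

nonzero = sum(int(torch.count_nonzero(m.weight)) for m in linear_layers)
total = sum(m.weight.numel() for m in linear_layers)
print(f"Remaining weights: {nonzero} / {total} ({nonzero / total:.2%})")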


**Pruning workflow**
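1. Apply magnitude-based pruning with `torch.nn.utils.prune` to identify removable weights. `prune` only masks values with zeros, so tensor shapes (and therefore latency) are unchanged at this stage.
2. Fine-tune the sparse model to recover accuracy.
3. Remove pruning reparameterization via `prune.remove(module, "weight")` before exporting.
4. Either physically remove the zeroed rows or channels so the dense model is actually smaller, or export to a runtime that exploits sparsity, such as TensorRT's 2:4 structured sparsity on Ampere or newer GPUs (which requires two zeros in every group of four weights).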
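Because `torch.nn.utils.prune` only zeroes values, a structurally pruned `nn.Linear` keeps its original size. One way to cash in the savings, sketched below under the assumption of a `Linear -> ReLU -> Linear` pattern evaluated in eval mode (Dropout in between is then a no-op), is to slice out the all-zero rows and the matching input columns of the next layer. A pruned neuron still outputs the constant `relu(bias)`, so the sketch folds that constant into the next layer's bias to keep outputs unchanged. `shrink_linear_pair` is a hypothetical helper, not part of PyTorch.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune


def shrink_linear_pair(layer: nn.Linear, next_layer: nn.Linear) -> tuple[nn.Linear, nn.Linear]:
    """Hypothetical helper: drop all-zero output rows of `layer` and the matching
    input columns of `next_layer`, assuming `layer -> ReLU -> next_layer` in eval mode."""
    with torch.no_grad():
        keep = layer.weight.abs().sum(dim=1) != 0  # rows that survived pruning
        dropped = ~keep
        kept = int(keep.sum())

        new_layer = nn.Linear(layer.in_features, kept)
        new_layer.weight.copy_(layer.weight[keep])
        new_layer.bias.copy_(layer.bias[keep])

        new_next = nn.Linear(kept, next_layer.out_features)
        new_next.weight.copy_(next_layer.weight[:, keep])
        # A pruned neuron still emits relu(bias); fold that constant into the next bias
        constant = torch.relu(layer.bias[dropped])
        new_next.bias.copy_(next_layer.bias + next_layer.weight[:, dropped] @ constant)
    return new_layer, new_next


torch.manual_seed(0)
dense_mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)).eval()
prune.ln_structured(dense_mlp[0], name="weight", amount=0.5, n=2, dim=0)
prune.remove(dense_mlp[0], "weight")

first, second = shrink_linear_pair(dense_mlp[0], dense_mlp[2])
compact_mlp = nn.Sequential(first, nn.ReLU(), second).eval()

x = torch.randn(5, 8)
with torch.no_grad():
    outputs_match = torch.allclose(dense_mlp(x), compact_mlp(x), atol=1e-5)
print(f"Hidden width: {dense_mlp[0].out_features} -> {compact_mlp[0].out_features}")
print(f"Outputs match: {outputs_match}")
```

The compact model is an ordinary dense network, so it exports to ONNX and benefits from every runtime without special sparsity support.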


## Export to ONNX, TensorRT, and TFLite


In [None]:
onnx_path = "tiny_classifier.onnx"
torch.onnx.export(
    model.cpu().eval(),  # Export in eval mode so Dropout is a no-op
    dummy_batch.cpu(),
    onnx_path,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=17,
)
print(f"Saved ONNX graph to {onnx_path}")


**Conversion playbook**
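- TensorRT: Build an engine from the exported ONNX with `trtexec --onnx=tiny_classifier.onnx --saveEngine=model.plan`.
  Because the batch axis is dynamic, pass an optimization profile such as `--minShapes=input:1x128 --optShapes=input:64x128 --maxShapes=input:512x128`, and add `--fp16` (or `--int8` with calibration data) for reduced-precision speedups.
- TFLite: Convert ONNX to TensorFlow (for example via `onnx-tf`) and then apply `tf.lite.TFLiteConverter`. For a PyTorch-native route,
  Google's AI Edge Torch (`ai-edge-torch`) converts models captured with `torch.export` to TFLite.
- Edge runtimes: Benchmark the exported artifact in its target runtime (for example an ONNX Runtime `InferenceSession`, `trtexec` for TensorRT engines, or `torch_tensorrt` when compiling directly from PyTorch) to validate latency against accuracy budgets.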
- Always store calibration datasets alongside models so quantization and engine builds are reproducible.


## Next exploration ideas
- Go beyond the operator table: export profiler traces with `prof.export_chrome_trace(...)` and inspect them in a trace viewer to spot idle GPU gaps and CPU-side stalls.
- Combine pruning and quantization (for example sparse INT8) for maximum compression.
- Automate artifact builds using CI so every commit produces ONNX or TensorRT packages with latency baselines.
