<a href="https://colab.research.google.com/github/Tensor-Reloaded/AI-Learning-Hub/blob/main/resources/advanced_pytorch/InferenceOptimizationAndTTA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inference Optimization and Test Time Augmentation (TTA)

This is the inference pipeline for the model trained in [A complex yet simple efficient training pipeline for CIFAR-10](https://github.com/Tensor-Reloaded/AI-Learning-Hub/blob/main/resources/advanced_pytorch/ComplexYetSimpleTrainingPipeline.ipynb).

In this notebook you will learn about simple inference optimizations you can apply when serving PyTorch models, and about test time augmentation (TTA).

For more advanced options, see the documentation for [TorchScript](https://docs.pytorch.org/docs/stable/jit.html) and [ONNX Runtime](https://onnxruntime.ai/).

In [1]:
!pip install timed-decorator

Collecting timed-decorator
  Downloading timed_decorator-1.6.1-py3-none-any.whl.metadata (18 kB)
Downloading timed_decorator-1.6.1-py3-none-any.whl (12 kB)
Installing collected packages: timed-decorator
Successfully installed timed-decorator-1.6.1


In [2]:
import os
from itertools import product
from pathlib import Path
from typing import Tuple

import timm
import torch
from prettytable import PrettyTable
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import hflip
from timed_decorator.simple_timed import timed
from tqdm import tqdm

This is the same model that was trained in the previous training notebook.

In [3]:
class ClassificationModel(nn.Module):
    def __init__(self, backbone_name: str = "resnet18", num_classes: int = 10):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=False)
        self.backbone.fc = nn.Linear(self.backbone.fc.weight.size(1), num_classes)

    def forward(self, x):
        return self.backbone(x)


This function loads the trained weights onto a device and prepares the model for serving.
Depending on `model_type`, the model is returned as-is, converted with one of the TorchScript (`torch.jit`) utilities, or compiled with `torch.compile`.

For a detailed description of each jit/compile method, see the official documentation and tutorials.

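Scripting and tracing differ when a model has data-dependent control flow. The sketch below uses a hypothetical toy module, `SignGate` (not the ResNet from this notebook), to show it: tracing records only the operations executed for the example input, so the branch taken during tracing is baked in, whereas scripting compiles the `if` statement itself.

```python
import torch
from torch import nn


class SignGate(nn.Module):
    """Hypothetical toy module whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


model = SignGate().eval()
scripted = torch.jit.script(model)
# Tracing with a positive input records only the "x * 2" branch (PyTorch emits a TracerWarning)
traced = torch.jit.trace(model, torch.ones(3))

negative = -torch.ones(3)
print(model(negative).tolist())     # [1.0, 1.0, 1.0]
print(scripted(negative).tolist())  # [1.0, 1.0, 1.0]
print(traced(negative).tolist())    # [-2.0, -2.0, -2.0]: the wrong branch was baked in
```

The ResNet used here should not have tensor-dependent branches in its forward pass, which is why the traced model reaches the same accuracy as the scripted one in the tables below.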
In [4]:
def create_model(model_path: str, device: torch.device, model_type: str):
    model_data = torch.load(model_path, map_location=device, weights_only=True)

    model = ClassificationModel(model_data["model_name"], model_data["num_classes"])
    model = model.to(device)
    model.load_state_dict(model_data["model_state_dict"])
    model.eval()

    if model_type == "raw model":
        return model
    if model_type == "scripted model":
        return torch.jit.script(model)
    if model_type == "traced model":
        return torch.jit.trace(model, torch.rand((5, 3, 32, 32), device=device))
    if model_type == "frozen model":
        return torch.jit.freeze(torch.jit.script(model))
    if model_type == "optimized for inference":
        return torch.jit.optimize_for_inference(torch.jit.script(model))
    if model_type == "compiled model":
        if os.name == "nt":
            print("torch.compile is not supported on Windows. Try Linux or WSL instead.")
            return model
        return torch.compile(model)
    raise ValueError(f"Unknown model type: {model_type}")

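To see what `torch.jit.freeze` does, here is a minimal sketch on a small stand-in network (an assumption, not the notebook's ResNet). Freezing clones the scripted module and inlines its parameters and attributes into the graph as constants, which enables further graph optimizations. It only accepts modules in eval mode, which is one reason `create_model` calls `model.eval()` before converting.

```python
import torch
from torch import nn

# Small stand-in for the classifier (assumption: any eval-mode nn.Module can be scripted and frozen)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()  # freezing requires eval mode

frozen = torch.jit.freeze(torch.jit.script(model))

# Conv, BatchNorm and Linear each contribute a weight and a bias: 6 parameters before freezing, 0 after
print(len(list(model.parameters())), len(list(frozen.parameters())))  # 6 0

x = torch.rand(4, 3, 32, 32)
with torch.inference_mode():
    # Outputs should match up to floating-point rounding
    print(torch.allclose(model(x), frozen(x), atol=1e-5))
```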

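This function runs inference with optional TTA and measures the elapsed time. Because of the `@timed(..., return_time=True)` decorator, calling it returns both the accuracy and the elapsed time in seconds. The model's logits are summed over all evaluated views, and the prediction is the class with the highest total.

There are 4 versions of TTA:
* `none`: a single inference pass on the original images
* `mirroring`: also evaluates the horizontally flipped images, doing 1 additional inference pass
* `translate`: also evaluates 8 translated copies of the images (shifted by 2 pixels horizontally, vertically and diagonally, with padded borders), doing 8 additional inference passes
* `mirroring_and_translate`: evaluates each of the 8 translated copies both as-is and horizontally flipped, doing 16 additional inference passes (the flipped, untranslated image is not included)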

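To make the translated views concrete, this sketch pads a toy 4x4 single-channel tensor by one pixel and crops it at an offset, the same pad-then-crop pattern that `tta_inference` applies with 2 pixels of padding on 32x32 images. It uses `torch.nn.functional.pad` with constant padding, which we assume is equivalent to the `v2.functional.pad(..., fill=...)` call in the notebook; the fill value -1 is chosen only to make the padding visible.

```python
import torch
import torch.nn.functional as F

image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)  # toy batch: 1 image, 1 channel, 4x4
shift = 1
size = image.shape[-1]
# Constant padding on all four sides; -1 marks the padded border
padded = F.pad(image, (shift, shift, shift, shift), value=-1.0)


def crop(i: int, j: int) -> torch.Tensor:
    """Crop a size x size window whose top-left corner is offset by (i, j) from the original image."""
    x, y = shift + i, shift + j
    return padded[:, :, x:x + size, y:y + size]


print(torch.equal(crop(0, 0), image))  # True: a zero offset recovers the original image
# Offsetting the crop by one row moves the content up by one row and exposes a padded row at the bottom
print(crop(shift, 0)[0, 0].tolist())
# [[4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0], [-1.0, -1.0, -1.0, -1.0]]
```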
In [5]:
@timed(stdout=False, return_time=True, use_seconds=True)
def tta_inference(model, batches: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], device: torch.device,
                  tta_type: str) -> float:
    total = 0
    correct = 0

    for data, target in batches:
        data = data.to(device)

        predicted = model(data)
        if tta_type == "mirroring":
            predicted += model(hflip(data))
        elif tta_type == "translate":
            padding_size = 2
            image_size = 32
            # We pad using the same value the model has seen during training
            padded = v2.functional.pad(data, [padding_size], fill=0.5)
            for i in [-2, 0, 2]:
                for j in [-2, 0, 2]:
                    if i == 0 and j == 0:
                        continue
                    x = padding_size + i
                    y = padding_size + j
                    predicted += model(padded[:, :, x:x + image_size, y:y + image_size])
        elif tta_type == "mirroring_and_translate":
            padding_size = 2
            image_size = 32
            padded = v2.functional.pad(data, [padding_size], fill=0.5)
            for i in [-2, 0, 2]:
                for j in [-2, 0, 2]:
                    if i == 0 and j == 0:
                        continue
                    x = padding_size + i
                    y = padding_size + j
                    aux = padded[:, :, x:x + image_size, y:y + image_size]
                    predicted += model(aux)
                    predicted += model(hflip(aux))

        correct += (predicted.cpu().argmax(dim=1) == target).sum().item()
        total += data.size(0)

    return round(correct / total, 4)

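The two translation branches in `tta_inference` repeat the same pad-and-crop loop. One way to remove the duplication is a single helper that builds the views; the sketch below is a hypothetical refactoring (the name `tta_logits` is ours, not part of the notebook), assuming the same 2-pixel shifts and fill value. It substitutes `torch.flip` on the width axis for `hflip` and `torch.nn.functional.pad` for `v2.functional.pad`, which we assume behave identically here, and it sums logits exactly like the original, including leaving out the flipped, untranslated view for `mirroring_and_translate`.

```python
import torch
import torch.nn.functional as F


def _hflip(t: torch.Tensor) -> torch.Tensor:
    # Horizontal flip: reverse the width (last) dimension
    return torch.flip(t, dims=[-1])


def tta_logits(model, data: torch.Tensor, tta_type: str, shift: int = 2, fill: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: sum of the model's logits over the TTA views used in tta_inference."""
    logits = model(data)
    if tta_type == "mirroring":
        return logits + model(_hflip(data))
    if tta_type not in ("translate", "mirroring_and_translate"):
        return logits  # "none" (or an unknown type): a single pass
    height, width = data.shape[-2:]
    padded = F.pad(data, (shift, shift, shift, shift), value=fill)
    for i in (-shift, 0, shift):
        for j in (-shift, 0, shift):
            if i == 0 and j == 0:
                continue  # the untranslated view was already evaluated
            x, y = shift + i, shift + j
            view = padded[:, :, x:x + height, y:y + width]
            logits = logits + model(view)
            if tta_type == "mirroring_and_translate":
                logits = logits + model(_hflip(view))
    return logits


# Usage with a stand-in "model" that records every call
calls = []


def counting_model(x: torch.Tensor) -> torch.Tensor:
    calls.append(tuple(x.shape))
    return x.mean(dim=(2, 3))  # fake logits: one value per channel


batch = torch.rand(4, 3, 32, 32)
for tta_type in ("none", "mirroring", "translate", "mirroring_and_translate"):
    calls.clear()
    tta_logits(counting_model, batch, tta_type)
    print(tta_type, len(calls))  # 1, 2, 9 and 17 passes respectively
```

The pass counts make the cost of each TTA variant explicit, which helps explain the elapsed times in the TTA results further below.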

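We use automatic mixed precision (AMP) through `torch.autocast` to run the model in the requested data type, and we measure the accuracy and the elapsed time of each configuration. Autocast is only enabled on CUDA, so on the CPU the `bfloat16` and `float16` configurations actually run in `float32`.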

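A minimal sketch of how the `enabled` flag of `torch.autocast` behaves, using a small `nn.Linear` layer on the CPU (chosen only so the example runs without a GPU; the notebook itself enables autocast only on CUDA). When enabled, autocast runs eligible operations such as linear layers in the lower-precision dtype; when disabled, the context manager has no effect and the output stays `float32`.

```python
import torch
from torch import nn

layer = nn.Linear(8, 4)
x = torch.rand(2, 8)

for enabled in (True, False):
    # Same pair of context managers as in inference(), but on the CPU with bfloat16
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=enabled), torch.inference_mode():
        out = layer(x)
    print(enabled, out.dtype)
# True torch.bfloat16
# False torch.float32
```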
In [6]:
def inference(model, batches: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], device: torch.device, tta_type: str,
              dtype: torch.dtype, model_type: str) -> Tuple[float, float]:
    # Autocast is slow on the CPU, so it is only enabled on CUDA.
    # Autocast support on other accelerators (e.g. MPS) is uncertain, so it is disabled there too.
    enable_autocast = device.type == "cuda" and dtype != torch.float32
    accuracy, elapsed = "N/A", "N/A"
    try:
        with torch.autocast(device_type=device.type, dtype=dtype, enabled=enable_autocast), torch.inference_mode():
            accuracy, elapsed = tta_inference(model, batches, device, tta_type)
    except Exception:
        # Uncomment for debugging:
        # import traceback
        # traceback.print_exc()
        print(f"Model type {model_type} failed on {dtype} on {device.type}")

    return accuracy, elapsed


We load the CIFAR-10 test set once, split it into batches of 200 images, and store all batches in a tuple. During inference, data is read directly from memory, so data loading does not affect the measured time.

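The same pre-batching idea, sketched with a synthetic `TensorDataset` (an assumption that stands in for the CIFAR-10 test set, so nothing is downloaded): iterating the `DataLoader` once and storing the batches in a tuple means later passes only iterate over tensors already in memory.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the CIFAR-10 test set: 1000 random 3x32x32 images with labels
images = torch.rand(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
loader = DataLoader(TensorDataset(images, labels), batch_size=200)

batches = tuple(loader)  # materialize every batch once; each element is a [data, target] pair
print(len(batches), batches[0][0].shape)  # 5 torch.Size([200, 3, 32, 32])

for data, target in batches:  # later passes just iterate over in-memory tensors
    pass
```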
In [7]:
def prepare_data(data_path: str) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
    transforms = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.491, 0.482, 0.446), std=(0.247, 0.243, 0.261), inplace=True)
    ])
    dataset = CIFAR10(root=data_path, train=False, transform=transforms, download=True)
    dataloader = DataLoader(dataset, batch_size=200)
    return tuple(dataloader)


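This is the speed test (without TTA), followed by the results measured on my computer; the Google Colab results are further below. Each model type appears twice in `model_types`, so every configuration is measured twice: the first run can include one-time warm-up costs (for example, `torch.compile` compilation), while the second run is closer to steady-state speed. `N/A` marks configurations that failed, here `optimized for inference` with `float16` and `bfloat16` on CUDA.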

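`tta_inference` times one pass over the whole test set, and its `.cpu()` call on each batch forces pending CUDA work to finish before the timer stops. For isolated latency measurements, a common pattern is to run a few untimed warm-up iterations, synchronize the GPU before reading the clock, and report the median of several runs. The helper below is a hypothetical sketch of that pattern (`benchmark`, `warmup` and `runs` are our names), not part of the notebook.

```python
import statistics
import time

import torch
from torch import nn


def benchmark(model, example: torch.Tensor, warmup: int = 3, runs: int = 10) -> float:
    """Hypothetical helper: median wall-clock seconds per forward pass, measured after warm-up."""

    def sync() -> None:
        # CUDA kernels run asynchronously; wait for them before reading the clock
        if example.device.type == "cuda":
            torch.cuda.synchronize(example.device)

    timings = []
    with torch.inference_mode():
        for _ in range(warmup):
            model(example)  # absorbs one-time costs such as JIT profiling or compilation
        sync()
        for _ in range(runs):
            start = time.perf_counter()
            model(example)
            sync()
            timings.append(time.perf_counter() - start)
    return statistics.median(timings)


model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).eval()
print(f"{benchmark(model, torch.rand(200, 3, 32, 32)):.6f} s per batch")
```

Reporting the median rather than the mean makes the result less sensitive to occasional slow iterations.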
In [8]:
def do_speed_test(data: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
                  model_types: Tuple[str, ...],
                  dtypes: Tuple[torch.dtype, ...],
                  tta_types: Tuple[str, ...],
                  devices: Tuple[torch.device | None, ...],
                  model_path: str):
    tta_type = "none"
    with tqdm(total=len(devices) * len(dtypes) * len(model_types), desc="Speed experiments") as tbar:
        for device, dtype in product(devices, dtypes):
            if device is None:
                tbar.update(len(model_types))
                continue
            speed_results = PrettyTable()
            speed_results.field_names = ["Device", "Dtype", "TTA Type", "Model Type", "Accuracy", "Elapsed"]

            for model_type in model_types:
                model = create_model(model_path, device, model_type)
                accuracy, elapsed = inference(model, data, device, tta_type, dtype, model_type)
                speed_results.add_row([device, dtype, tta_type, model_type, accuracy, elapsed])
                tbar.update()

            print(speed_results)

    # CUDA Results
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed   |
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # |  cuda  | torch.bfloat16 |   none   |        raw model        |  0.8627  |  0.74056251 |
    # |  cuda  | torch.bfloat16 |   none   |      scripted model     |  0.8627  | 0.550711881 |
    # |  cuda  | torch.bfloat16 |   none   |      scripted model     |  0.8627  | 0.466999062 |
    # |  cuda  | torch.bfloat16 |   none   |       traced model      |  0.8627  | 0.505114635 |
    # |  cuda  | torch.bfloat16 |   none   |       traced model      |  0.8627  | 0.497691016 |
    # |  cuda  | torch.bfloat16 |   none   |       frozen model      |  0.8618  | 0.630178739 |
    # |  cuda  | torch.bfloat16 |   none   |       frozen model      |  0.8618  | 0.431321397 |
    # |  cuda  | torch.bfloat16 |   none   | optimized for inference |   N/A    |     N/A     |
    # |  cuda  | torch.bfloat16 |   none   | optimized for inference |   N/A    |     N/A     |
    # |  cuda  | torch.bfloat16 |   none   |      compiled model     |  0.863   |  1.37197609 |
    # |  cuda  | torch.bfloat16 |   none   |      compiled model     |  0.863   | 0.439737346 |
    # +--------+----------------+----------+-------------------------+----------+-------------+

    # +--------+----------------+----------+-------------------------+----------+-------------+
    # | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed   |
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # |  cuda  | torch.float16  |   none   |        raw model        |  0.8629  | 0.934939784 |
    # |  cuda  | torch.float16  |   none   |      scripted model     |  0.8629  | 0.776701284 |
    # |  cuda  | torch.float16  |   none   |      scripted model     |  0.8629  |  0.65642132 |
    # |  cuda  | torch.float16  |   none   |       traced model      |  0.8629  | 0.770792187 |
    # |  cuda  | torch.float16  |   none   |       traced model      |  0.8629  | 0.761494488 |
    # |  cuda  | torch.float16  |   none   |       frozen model      |  0.8629  | 0.449910122 |
    # |  cuda  | torch.float16  |   none   |       frozen model      |  0.8629  | 0.428042867 |
    # |  cuda  | torch.float16  |   none   | optimized for inference |   N/A    |     N/A     |
    # |  cuda  | torch.float16  |   none   | optimized for inference |   N/A    |     N/A     |
    # |  cuda  | torch.float16  |   none   |      compiled model     |  0.863   | 1.041609176 |
    # |  cuda  | torch.float16  |   none   |      compiled model     |  0.863   | 0.304629578 |
    # +--------+----------------+----------+-------------------------+----------+-------------+

    # +--------+----------------+----------+-------------------------+----------+-------------+
    # | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed   |
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # |  cuda  | torch.float32  |   none   |        raw model        |  0.8628  | 0.928239116 |
    # |  cuda  | torch.float32  |   none   |      scripted model     |  0.8628  | 0.869112261 |
    # |  cuda  | torch.float32  |   none   |      scripted model     |  0.8628  | 0.818328065 |
    # |  cuda  | torch.float32  |   none   |       traced model      |  0.8628  | 0.831756814 |
    # |  cuda  | torch.float32  |   none   |       traced model      |  0.8628  | 0.835166337 |
    # |  cuda  | torch.float32  |   none   |       frozen model      |  0.8628  | 0.635774185 |
    # |  cuda  | torch.float32  |   none   |       frozen model      |  0.8628  | 0.884842387 |
    # |  cuda  | torch.float32  |   none   | optimized for inference |  0.8628  | 6.401095805 |
    # |  cuda  | torch.float32  |   none   | optimized for inference |  0.8628  | 6.383807842 |
    # |  cuda  | torch.float32  |   none   |      compiled model     |  0.8628  | 0.979224372 |
    # |  cuda  | torch.float32  |   none   |      compiled model     |  0.8628  | 0.510377062 |
    # +--------+----------------+----------+-------------------------+----------+-------------+

    # CPU Results
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed   |
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # |  cpu   | torch.bfloat16 |   none   |        raw model        |  0.8628  | 2.859445197 |
    # |  cpu   | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 2.635952067 |
    # |  cpu   | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 2.604736663 |
    # |  cpu   | torch.bfloat16 |   none   |       traced model      |  0.8628  | 2.631448843 |
    # |  cpu   | torch.bfloat16 |   none   |       traced model      |  0.8628  | 2.576900248 |
    # |  cpu   | torch.bfloat16 |   none   |       frozen model      |  0.8628  | 2.546161701 |
    # |  cpu   | torch.bfloat16 |   none   |       frozen model      |  0.8628  | 2.502300936 |
    # |  cpu   | torch.bfloat16 |   none   | optimized for inference |  0.8628  | 2.281604414 |
    # |  cpu   | torch.bfloat16 |   none   | optimized for inference |  0.8628  | 2.225087941 |
    # |  cpu   | torch.bfloat16 |   none   |      compiled model     |  0.8628  |  3.58207681 |
    # |  cpu   | torch.bfloat16 |   none   |      compiled model     |  0.8628  | 1.722796112 |
    # +--------+----------------+----------+-------------------------+----------+-------------+

    # +--------+----------------+----------+-------------------------+----------+-------------+
    # | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed   |
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # |  cpu   | torch.float16  |   none   |        raw model        |  0.8628  | 2.737279273 |
    # |  cpu   | torch.float16  |   none   |      scripted model     |  0.8628  | 2.562341959 |
    # |  cpu   | torch.float16  |   none   |      scripted model     |  0.8628  | 2.652842815 |
    # |  cpu   | torch.float16  |   none   |       traced model      |  0.8628  | 2.639518142 |
    # |  cpu   | torch.float16  |   none   |       traced model      |  0.8628  | 2.735255652 |
    # |  cpu   | torch.float16  |   none   |       frozen model      |  0.8628  | 2.903561699 |
    # |  cpu   | torch.float16  |   none   |       frozen model      |  0.8628  | 2.962546338 |
    # |  cpu   | torch.float16  |   none   | optimized for inference |  0.8628  | 2.344554807 |
    # |  cpu   | torch.float16  |   none   | optimized for inference |  0.8628  | 2.360003218 |
    # |  cpu   | torch.float16  |   none   |      compiled model     |  0.8628  | 1.730791658 |
    # |  cpu   | torch.float16  |   none   |      compiled model     |  0.8628  | 1.754020479 |
    # +--------+----------------+----------+-------------------------+----------+-------------+

    # +--------+----------------+----------+-------------------------+----------+-------------+
    # | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed   |
    # +--------+----------------+----------+-------------------------+----------+-------------+
    # |  cpu   | torch.float32  |   none   |        raw model        |  0.8628  |  2.65187362 |
    # |  cpu   | torch.float32  |   none   |      scripted model     |  0.8628  | 2.620745015 |
    # |  cpu   | torch.float32  |   none   |      scripted model     |  0.8628  | 2.583938025 |
    # |  cpu   | torch.float32  |   none   |       traced model      |  0.8628  |  2.68518527 |
    # |  cpu   | torch.float32  |   none   |       traced model      |  0.8628  | 2.670001929 |
    # |  cpu   | torch.float32  |   none   |       frozen model      |  0.8628  | 2.853723278 |
    # |  cpu   | torch.float32  |   none   |       frozen model      |  0.8628  | 2.903551512 |
    # |  cpu   | torch.float32  |   none   | optimized for inference |  0.8628  |  2.47234354 |
    # |  cpu   | torch.float32  |   none   | optimized for inference |  0.8628  | 2.308440241 |
    # |  cpu   | torch.float32  |   none   |      compiled model     |  0.8628  |  1.78701703 |
    # |  cpu   | torch.float32  |   none   |      compiled model     |  0.8628  | 1.771141438 |
    # +--------+----------------+----------+-------------------------+----------+-------------+


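Results from Google Colab (CUDA). The `bfloat16` run is not tabulated here; the cell output at the end of the notebook shows it taking about 16 seconds per model type.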

In [None]:
# +--------+---------------+----------+-------------------------+----------+-------------+
# | Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed   |
# +--------+---------------+----------+-------------------------+----------+-------------+
# |  cuda  | torch.float16 |   none   |        raw model        |  0.8628  | 0.456860901 |
# |  cuda  | torch.float16 |   none   |      scripted model     |  0.8627  | 0.245520422 |
# |  cuda  | torch.float16 |   none   |      scripted model     |  0.8627  |  0.24478505 |
# |  cuda  | torch.float16 |   none   |       traced model      |  0.8628  | 0.297093739 |
# |  cuda  | torch.float16 |   none   |       traced model      |  0.8628  | 0.298865424 |
# |  cuda  | torch.float16 |   none   |       frozen model      |  0.8629  | 0.189003802 |
# |  cuda  | torch.float16 |   none   |       frozen model      |  0.8629  | 0.188668645 |
# |  cuda  | torch.float16 |   none   | optimized for inference |   N/A    |     N/A     |
# |  cuda  | torch.float16 |   none   | optimized for inference |   N/A    |     N/A     |
# |  cuda  | torch.float16 |   none   |      compiled model     |  0.8628  | 0.351826348 |
# |  cuda  | torch.float16 |   none   |      compiled model     |  0.8628  | 0.245900978 |
# +--------+---------------+----------+-------------------------+----------+-------------+

# +--------+---------------+----------+-------------------------+----------+-------------+
# | Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed   |
# +--------+---------------+----------+-------------------------+----------+-------------+
# |  cuda  | torch.float32 |   none   |        raw model        |  0.8628  | 0.612244879 |
# |  cuda  | torch.float32 |   none   |      scripted model     |  0.8628  | 0.714759405 |
# |  cuda  | torch.float32 |   none   |      scripted model     |  0.8628  | 0.589165576 |
# |  cuda  | torch.float32 |   none   |       traced model      |  0.8628  | 0.616923644 |
# |  cuda  | torch.float32 |   none   |       traced model      |  0.8628  | 0.614003439 |
# |  cuda  | torch.float32 |   none   |       frozen model      |  0.8628  | 0.560715173 |
# |  cuda  | torch.float32 |   none   |       frozen model      |  0.8628  | 0.559296638 |
# |  cuda  | torch.float32 |   none   | optimized for inference |  0.8628  |  0.55052995 |
# |  cuda  | torch.float32 |   none   | optimized for inference |  0.8628  | 0.551463229 |
# |  cuda  | torch.float32 |   none   |      compiled model     |  0.8628  | 0.591989673 |
# |  cuda  | torch.float32 |   none   |      compiled model     |  0.8628  | 0.589906088 |
# +--------+---------------+----------+-------------------------+----------+-------------+

Results from Google Colab (CPU).

In [None]:
# +--------+----------------+----------+-------------------------+----------+--------------+
# | Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed    |
# +--------+----------------+----------+-------------------------+----------+--------------+
# |  cpu   | torch.bfloat16 |   none   |        raw model        |  0.8628  | 12.251049287 |
# |  cpu   | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 10.502742393 |
# |  cpu   | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 10.219771897 |
# |  cpu   | torch.bfloat16 |   none   |       traced model      |  0.8628  | 9.444874722  |
# |  cpu   | torch.bfloat16 |   none   |       traced model      |  0.8628  | 10.067524619 |
# |  cpu   | torch.bfloat16 |   none   |       frozen model      |  0.8628  | 9.681582857  |
# |  cpu   | torch.bfloat16 |   none   |       frozen model      |  0.8628  |  9.74893516  |
# |  cpu   | torch.bfloat16 |   none   | optimized for inference |  0.8628  | 8.125695756  |
# |  cpu   | torch.bfloat16 |   none   | optimized for inference |  0.8628  |  6.87960141  |
# |  cpu   | torch.bfloat16 |   none   |      compiled model     |  0.8628  | 9.944255357  |
# |  cpu   | torch.bfloat16 |   none   |      compiled model     |  0.8628  | 9.934951253  |
# +--------+----------------+----------+-------------------------+----------+--------------+

# +--------+---------------+----------+-------------------------+----------+--------------+
# | Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed    |
# +--------+---------------+----------+-------------------------+----------+--------------+
# |  cpu   | torch.float16 |   none   |        raw model        |  0.8628  | 10.195400691 |
# |  cpu   | torch.float16 |   none   |      scripted model     |  0.8628  | 10.132698804 |
# |  cpu   | torch.float16 |   none   |      scripted model     |  0.8628  | 9.415232976  |
# |  cpu   | torch.float16 |   none   |       traced model      |  0.8628  | 9.189693126  |
# |  cpu   | torch.float16 |   none   |       traced model      |  0.8628  | 9.873023653  |
# |  cpu   | torch.float16 |   none   |       frozen model      |  0.8628  | 9.554109585  |
# |  cpu   | torch.float16 |   none   |       frozen model      |  0.8628  | 9.601258977  |
# |  cpu   | torch.float16 |   none   | optimized for inference |  0.8628  | 7.365994933  |
# |  cpu   | torch.float16 |   none   | optimized for inference |  0.8628  | 6.933587296  |
# |  cpu   | torch.float16 |   none   |      compiled model     |  0.8628  | 10.214564148 |
# |  cpu   | torch.float16 |   none   |      compiled model     |  0.8628  | 10.10502345  |
# +--------+---------------+----------+-------------------------+----------+--------------+

# +--------+---------------+----------+-------------------------+----------+--------------+
# | Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed    |
# +--------+---------------+----------+-------------------------+----------+--------------+
# |  cpu   | torch.float32 |   none   |        raw model        |  0.8628  | 10.396136796 |
# |  cpu   | torch.float32 |   none   |      scripted model     |  0.8628  | 10.187017377 |
# |  cpu   | torch.float32 |   none   |      scripted model     |  0.8628  | 9.242592588  |
# |  cpu   | torch.float32 |   none   |       traced model      |  0.8628  |  9.72705402  |
# |  cpu   | torch.float32 |   none   |       traced model      |  0.8628  | 11.291153583 |
# |  cpu   | torch.float32 |   none   |       frozen model      |  0.8628  | 10.222028228 |
# |  cpu   | torch.float32 |   none   |       frozen model      |  0.8628  | 9.580895448  |
# |  cpu   | torch.float32 |   none   | optimized for inference |  0.8628  |  7.60580902  |
# |  cpu   | torch.float32 |   none   | optimized for inference |  0.8628  | 6.929277715  |
# |  cpu   | torch.float32 |   none   |      compiled model     |  0.8628  | 9.964773122  |
# |  cpu   | torch.float32 |   none   |      compiled model     |  0.8628  | 9.956867878  |
# +--------+---------------+----------+-------------------------+----------+--------------+

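This is the TTA test. It uses the scripted model on the first available device (the accelerator if present, otherwise the CPU). TTA noticeably improves accuracy: in the results below (`float32`), accuracy rises from 0.8628 without TTA to 0.8785 with mirroring and translation, at the cost of many more inference passes.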

In [9]:
def do_tta_test(data: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
                model_types: Tuple[str, ...],
                dtypes: Tuple[torch.dtype, ...],
                tta_types: Tuple[str, ...],
                devices: Tuple[torch.device | None, ...],
                model_path: str):
    tta_results = PrettyTable()
    tta_results.field_names = ["Device", "Dtype", "TTA Type", "Model Type", "Accuracy", "Elapsed"]

    device = devices[0] if devices[0] is not None else devices[1]
    model_type = "scripted model"

    for dtype, tta_type in tqdm(tuple(product(dtypes, tta_types)), desc="TTA experiments"):
        if device is None:
            continue
        model = create_model(model_path, device, model_type)
        accuracy, elapsed = inference(model, data, device, tta_type, dtype, model_type)
        tta_results.add_row([device, dtype, tta_type, model_type, accuracy, elapsed])

    print(tta_results)

    # +--------+----------------+-------------------------+----------------+----------+-------------+
    # | Device |     Dtype      |         TTA Type        |   Model Type   | Accuracy |   Elapsed   |
    # +--------+----------------+-------------------------+----------------+----------+-------------+
    # |  cuda  | torch.bfloat16 |           none          | scripted model |  0.8627  |  0.86833901 |
    # |  cuda  | torch.bfloat16 |        mirroring        | scripted model |  0.8729  | 0.369391463 |
    # |  cuda  | torch.bfloat16 |        translate        | scripted model |  0.8733  | 1.242899479 |
    # |  cuda  | torch.bfloat16 | mirroring_and_translate | scripted model |  0.8783  | 2.240457862 |
    # |  cuda  | torch.float16  |           none          | scripted model |  0.8629  | 0.359551112 |
    # |  cuda  | torch.float16  |        mirroring        | scripted model |  0.8729  | 0.314461259 |
    # |  cuda  | torch.float16  |        translate        | scripted model |  0.8733  | 1.377882807 |
    # |  cuda  | torch.float16  | mirroring_and_translate | scripted model |  0.8783  |  2.50420385 |
    # |  cuda  | torch.float32  |           none          | scripted model |  0.8628  | 0.302331347 |
    # |  cuda  | torch.float32  |        mirroring        | scripted model |  0.8728  | 0.323998472 |
    # |  cuda  | torch.float32  |        translate        | scripted model |  0.8735  | 1.356375027 |
    # |  cuda  | torch.float32  | mirroring_and_translate | scripted model |  0.8785  | 2.535533913 |
    # +--------+----------------+-------------------------+----------------+----------+-------------+

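To put the TTA results in perspective, this short calculation takes the `float32` accuracies from the table above and the number of forward passes per batch implied by `tta_inference` (1, 2, 9 and 17). It prints gains of +0.00, +1.00, +1.07 and +1.57 percentage points: mirroring alone gives most of the improvement for a single extra pass, while translations cost many more passes for a smaller additional gain.

```python
# float32 accuracies copied from the TTA results table; passes = forward passes per batch in tta_inference
results = {
    "none": (1, 0.8628),
    "mirroring": (2, 0.8728),
    "translate": (9, 0.8735),
    "mirroring_and_translate": (17, 0.8785),
}
_, base_accuracy = results["none"]
for name, (passes, accuracy) in results.items():
    gain_pp = (accuracy - base_accuracy) * 100  # gain in percentage points over no TTA
    print(f"{name:<24} passes={passes:>2}  gain={gain_pp:+.2f} pp")
```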

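The full setup: `main` prepares the data, runs the speed test on every available device, and then runs the TTA test.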

In [10]:
def main(model_path: str):
    data = prepare_data("./data")
    model_types = (
        "raw model",
        "scripted model",
        "scripted model",
        "traced model",
        "traced model",
        "frozen model",
        "frozen model",
        "optimized for inference",
        "optimized for inference",
        "compiled model",
        "compiled model",
    )
    dtypes = (
        torch.bfloat16,
        torch.half,
        torch.float32
    )
    tta_types = (
        "none",
        "mirroring",
        "translate",
        "mirroring_and_translate",
    )
    devices = (
        torch.accelerator.current_accelerator() if torch.accelerator.is_available() else None,
        torch.device("cpu"),
    )

    do_speed_test(data, model_types, dtypes, tta_types, devices, model_path)
    do_tta_test(data, model_types, dtypes, tta_types, devices, model_path)


We download the trained checkpoint (`best.pth`) if it is not already present.

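The next cell uses shell syntax (`!`), which only works in Jupyter or Colab. Outside a notebook, a pure-Python equivalent could look like the sketch below; `download_if_missing` is a hypothetical helper built on the standard library's `urllib.request.urlretrieve`, with the checkpoint URL copied from the `curl` command.

```python
from pathlib import Path
from urllib.request import urlretrieve

CHECKPOINT_URL = (
    "https://raw.githubusercontent.com/Tensor-Reloaded/AI-Learning-Hub/main/"
    "resources/advanced_pytorch/checkpoints/best.pth"
)


def download_if_missing(url: str, path: str) -> Path:
    """Hypothetical helper: fetch url into path only if the file does not exist yet (like `[ ! -f ... ]`)."""
    target = Path(path)
    if not target.is_file():
        urlretrieve(url, target)  # urllib follows HTTP redirects by default, like curl -L
    return target
```

Calling `download_if_missing(CHECKPOINT_URL, "best.pth")` has the same effect as the shell command: the file is fetched once and reused on later runs.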
In [13]:
![ ! -f best.pth ] && curl -L -o best.pth https://raw.githubusercontent.com/Tensor-Reloaded/AI-Learning-Hub/main/resources/advanced_pytorch/checkpoints/best.pth

In [None]:
if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')
    model_path = "best.pth"
    main(model_path)


100%|██████████| 170M/170M [00:13<00:00, 12.4MB/s]
Speed experiments:  12%|█▏        | 8/66 [02:03<11:42, 12.12s/it]

Model type optimized for inference failed on torch.bfloat16 on cuda


Speed experiments:  14%|█▎        | 9/66 [02:04<08:08,  8.56s/it]

Model type optimized for inference failed on torch.bfloat16 on cuda


W0930 12:17:15.594000 381 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode
Speed experiments:  17%|█▋        | 11/66 [02:55<15:14, 16.63s/it]

+--------+----------------+----------+-------------------------+----------+--------------+
| Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed    |
+--------+----------------+----------+-------------------------+----------+--------------+
|  cuda  | torch.bfloat16 |   none   |        raw model        |  0.8628  | 16.884628199 |
|  cuda  | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 16.588373199 |
|  cuda  | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 16.38698488  |
|  cuda  | torch.bfloat16 |   none   |       traced model      |  0.8628  | 16.580808231 |
|  cuda  | torch.bfloat16 |   none   |       traced model      |  0.8628  | 16.387139424 |
|  cuda  | torch.bfloat16 |   none   |       frozen model      |  0.8627  | 16.359960733 |
|  cuda  | torch.bfloat16 |   none   |       frozen model      |  0.8627  | 16.39242669  |
|  cuda  | torch.bfloat16 |   none   | optimized for inference |   N/A    |     N/A      |

Speed experiments:  29%|██▉       | 19/66 [03:03<01:26,  1.83s/it]

Model type optimized for inference failed on torch.float16 on cuda


Speed experiments:  30%|███       | 20/66 [03:03<01:10,  1.52s/it]

Model type optimized for inference failed on torch.float16 on cuda


Speed experiments:  33%|███▎      | 22/66 [03:05<00:52,  1.18s/it]

+--------+---------------+----------+-------------------------+----------+-------------+
| Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed   |
+--------+---------------+----------+-------------------------+----------+-------------+
|  cuda  | torch.float16 |   none   |        raw model        |  0.8628  | 0.456860901 |
|  cuda  | torch.float16 |   none   |      scripted model     |  0.8627  | 0.245520422 |
|  cuda  | torch.float16 |   none   |      scripted model     |  0.8627  |  0.24478505 |
|  cuda  | torch.float16 |   none   |       traced model      |  0.8628  | 0.297093739 |
|  cuda  | torch.float16 |   none   |       traced model      |  0.8628  | 0.298865424 |
|  cuda  | torch.float16 |   none   |       frozen model      |  0.8629  | 0.189003802 |
|  cuda  | torch.float16 |   none   |       frozen model      |  0.8629  | 0.188668645 |
|  cuda  | torch.float16 |   none   | optimized for inference |   N/A    |     N/A     |
|  cuda  | torch.floa


+--------+---------------+----------+-------------------------+----------+-------------+
| Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed   |
+--------+---------------+----------+-------------------------+----------+-------------+
|  cuda  | torch.float32 |   none   |        raw model        |  0.8628  | 0.612244879 |
|  cuda  | torch.float32 |   none   |      scripted model     |  0.8628  | 0.714759405 |
|  cuda  | torch.float32 |   none   |      scripted model     |  0.8628  | 0.589165576 |
|  cuda  | torch.float32 |   none   |       traced model      |  0.8628  | 0.616923644 |
|  cuda  | torch.float32 |   none   |       traced model      |  0.8628  | 0.614003439 |
|  cuda  | torch.float32 |   none   |       frozen model      |  0.8628  | 0.560715173 |
|  cuda  | torch.float32 |   none   |       frozen model      |  0.8628  | 0.559296638 |
|  cuda  | torch.float32 |   none   | optimized for inference |  0.8628  |  0.55052995 |


+--------+----------------+----------+-------------------------+----------+--------------+
| Device |     Dtype      | TTA Type |        Model Type       | Accuracy |   Elapsed    |
+--------+----------------+----------+-------------------------+----------+--------------+
|  cpu   | torch.bfloat16 |   none   |        raw model        |  0.8628  | 12.251049287 |
|  cpu   | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 10.502742393 |
|  cpu   | torch.bfloat16 |   none   |      scripted model     |  0.8628  | 10.219771897 |
|  cpu   | torch.bfloat16 |   none   |       traced model      |  0.8628  | 9.444874722  |
|  cpu   | torch.bfloat16 |   none   |       traced model      |  0.8628  | 10.067524619 |
|  cpu   | torch.bfloat16 |   none   |       frozen model      |  0.8628  | 9.681582857  |
|  cpu   | torch.bfloat16 |   none   |       frozen model      |  0.8628  |  9.74893516  |
|  cpu   | torch.bfloat16 |   none   | optimized for inference |  0.8628  | 8.125695756  |


+--------+---------------+----------+-------------------------+----------+--------------+
| Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed    |
+--------+---------------+----------+-------------------------+----------+--------------+
|  cpu   | torch.float16 |   none   |        raw model        |  0.8628  | 10.195400691 |
|  cpu   | torch.float16 |   none   |      scripted model     |  0.8628  | 10.132698804 |
|  cpu   | torch.float16 |   none   |      scripted model     |  0.8628  | 9.415232976  |
|  cpu   | torch.float16 |   none   |       traced model      |  0.8628  | 9.189693126  |
|  cpu   | torch.float16 |   none   |       traced model      |  0.8628  | 9.873023653  |
|  cpu   | torch.float16 |   none   |       frozen model      |  0.8628  | 9.554109585  |
|  cpu   | torch.float16 |   none   |       frozen model      |  0.8628  | 9.601258977  |
|  cpu   | torch.float16 |   none   | optimized for inference |  0.8628  | 7.365994933  |



+--------+---------------+----------+-------------------------+----------+--------------+
| Device |     Dtype     | TTA Type |        Model Type       | Accuracy |   Elapsed    |
+--------+---------------+----------+-------------------------+----------+--------------+
|  cpu   | torch.float32 |   none   |        raw model        |  0.8628  | 10.396136796 |
|  cpu   | torch.float32 |   none   |      scripted model     |  0.8628  | 10.187017377 |
|  cpu   | torch.float32 |   none   |      scripted model     |  0.8628  | 9.242592588  |
|  cpu   | torch.float32 |   none   |       traced model      |  0.8628  |  9.72705402  |
|  cpu   | torch.float32 |   none   |       traced model      |  0.8628  | 11.291153583 |
|  cpu   | torch.float32 |   none   |       frozen model      |  0.8628  | 10.222028228 |
|  cpu   | torch.float32 |   none   |       frozen model      |  0.8628  | 9.580895448  |
|  cpu   | torch.float32 |   none   | optimized for inference |  0.8628  |  7.60580902  |

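As a rough reading aid, the CPU / `torch.float32` results in the last table can be expressed as speedups over the raw model. This is a minimal sketch under stated assumptions: the numbers are copied by hand from the table, the `Elapsed` column is assumed to be seconds for one full pass over the CIFAR-10 test set, and model types that appear twice are assumed to be repeated runs, of which the faster one is kept.

```python
# Elapsed times copied from the CPU / torch.float32 / no-TTA table above.
# Assumption: "Elapsed" is wall-clock seconds for one pass over the test set,
# and duplicated model types are repeated runs of the same configuration.
cpu_fp32_elapsed = {
    "raw model": [10.396136796],
    "scripted model": [10.187017377, 9.242592588],
    "traced model": [9.72705402, 11.291153583],
    "frozen model": [10.222028228, 9.580895448],
    "optimized for inference": [7.60580902],
}


def speedups(elapsed: dict[str, list[float]], baseline: str = "raw model") -> dict[str, float]:
    """Return baseline_time / best_time per model type (> 1.0 means faster than baseline)."""
    base = min(elapsed[baseline])  # best-of-N convention for the baseline too
    return {name: base / min(times) for name, times in elapsed.items()}


for name, s in speedups(cpu_fp32_elapsed).items():
    print(f"{name:>25}: {s:.2f}x")
```

On these numbers, the model optimized for inference is roughly 1.37x faster than the raw model on CPU, while scripting, tracing and freezing give smaller gains of roughly 1.07x to 1.12x.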

The full inference script is available in [inference_optimization_and_tta.py](./inference_optimization_and_tta.py).

Exercises:
1. Test this inference pipeline on a better-performing model.
2. Export a model trained with your pipeline to ONNX and measure the performance gains.
3. Load a TorchScript model in C++ using a LibTorch Docker container and measure the performance gains.
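Comparing backends fairly (exercises 2 and 3) needs warm-up runs and repeated measurements: the first calls of a scripted or traced model can be slower while the JIT profiles and optimizes the graph. The sketch below is a framework-agnostic timing harness under stated assumptions: each backend is wrapped as a zero-argument callable (hypothetical; for example a lambda around an ONNX Runtime `InferenceSession.run` call or a TorchScript module's forward pass), and the median is used to damp outliers. For CUDA backends, the callable should call `torch.cuda.synchronize()` before returning, because CUDA kernels are launched asynchronously.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 3, repeats: int = 10) -> float:
    """Median wall-clock seconds of fn() after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off costs such as lazy init, JIT profiling, caches
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def compare(backends: Dict[str, Callable[[], object]], baseline: str, **kwargs) -> Dict[str, float]:
    """Speedup of each backend relative to `baseline` (> 1.0 means faster)."""
    medians = {name: benchmark(fn, **kwargs) for name, fn in backends.items()}
    return {name: medians[baseline] / t for name, t in medians.items()}


if __name__ == "__main__":
    # Hypothetical stand-in workloads; replace them with real inference calls,
    # e.g. lambda: session.run(None, inputs) for a hypothetical ONNX Runtime session.
    def slow_backend():
        return sum(i * i for i in range(200_000))

    def fast_backend():
        return sum(i * i for i in range(50_000))

    print(compare({"slow": slow_backend, "fast": fast_backend}, baseline="slow", warmup=1, repeats=5))
```

Using the median rather than the mean keeps a single slow run (for example one hit by a background process) from skewing the comparison.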

---

| All     | [advanced_pytorch/](https://github.com/Tensor-Reloaded/AI-Learning-Hub/blob/main/resources/advanced_pytorch) |
|---------|-- |
| Current | [Inference Optimization And TTA](https://github.com/Tensor-Reloaded/AI-Learning-Hub/blob/main/resources/advanced_pytorch/InferenceOptimizationAndTTA.ipynb) |