Reference: [Optimizing PyTorch Model Inference on CPU](https://chaimrand.medium.com/optimizing-pytorch-model-inference-on-cpu-ccd3aa5884ad)

# Inference Experiment

In [1]:
import torch, torchvision
import time


def get_model():
    """Build a torchvision ResNet-50 and return it in evaluation mode.

    No ``weights`` argument is passed to ``resnet50()``, so (per torchvision's
    defaults) the parameters are presumably randomly initialized. That is
    sufficient for throughput measurements, where the values do not matter.
    """
    # ``eval()`` switches the module to inference behaviour and returns the
    # module itself, so the call can be chained directly.
    return torchvision.models.resnet50().eval()


def get_input(batch_size):
    """Create a random image-shaped input batch.

    Args:
        batch_size: Number of samples (leading dimension).

    Returns:
        A tensor of shape ``(batch_size, 3, 224, 224)`` drawn from a standard
        normal distribution, in torch's default floating dtype.
    """
    shape = (batch_size, 3, 224, 224)
    return torch.randn(shape)


def get_inference_fn(model):
    """Return a callable that runs ``model`` on a batch under inference mode.

    Args:
        model: Any callable (typically an ``nn.Module``) accepting one batch.

    Returns:
        ``infer_fn(batch)``, which calls ``model(batch)`` with
        ``torch.inference_mode()`` active and returns its output.
    """
    # inference_mode disables autograd tracking for every call of the wrapped
    # function; used as a decorator it is re-entered on each invocation.
    @torch.inference_mode()
    def infer_fn(batch):
        return model(batch)

    return infer_fn


def benchmark(infer_fn, batch, warmup_iters=10, iters=100):
    """Measure the mean latency of ``infer_fn(batch)``.

    Args:
        infer_fn: Callable invoked as ``infer_fn(batch)``; its result is discarded.
        batch: Input passed unchanged to every call.
        warmup_iters: Number of untimed calls made first, so one-off first-call
            costs are excluded from the measurement. Defaults to 10.
        iters: Number of timed calls; must be positive. Defaults to 100.

    Returns:
        Average seconds per call, measured with ``time.perf_counter``.

    Raises:
        ValueError: If ``iters`` is not positive (the average would be undefined).
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    # warm-up (not timed)
    for _ in range(warmup_iters):
        infer_fn(batch)

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which follows the adjustable system wall clock.
    start = time.perf_counter()
    for _ in range(iters):
        infer_fn(batch)
    elapsed = time.perf_counter() - start

    return elapsed / iters


# Baseline: plain ResNet-50 (default memory format, no autocast), one sample per batch.
batch_size = 1
model = get_model()
batch = get_input(batch_size=batch_size)
infer_fn = get_inference_fn(model=model)
avg_time = benchmark(infer_fn=infer_fn, batch=batch)  # mean seconds per forward pass
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")


Average samples per second: 5.10


# Model Inference Optimization

In [2]:
def get_model(channels_last=False):
    """Build a torchvision ResNet-50 in evaluation mode.

    Args:
        channels_last: If True, convert the module's parameters and buffers to
            the ``torch.channels_last`` memory format before returning.

    Returns:
        The ResNet-50 module in evaluation mode.
    """
    net = torchvision.models.resnet50()
    if channels_last:
        # Module.to(memory_format=...) re-lays out 4-D parameters/buffers
        # (e.g. convolution weights) in NHWC order.
        net = net.to(memory_format=torch.channels_last)
    return net.eval()

def get_input(batch_size, channels_last=False):
    """Create a random image-shaped input batch.

    Args:
        batch_size: Number of samples (leading dimension).
        channels_last: If True, return the tensor in ``torch.channels_last``
            memory format (the logical shape is unchanged; only strides differ).

    Returns:
        A standard-normal tensor of shape ``(batch_size, 3, 224, 224)``.
    """
    images = torch.randn(batch_size, 3, 224, 224)
    return images.to(memory_format=torch.channels_last) if channels_last else images


# Channels-last experiment: model weights and input both use NHWC memory format.
# NOTE(review): batch_size also changes from 1 (baseline cell) to 8 here, so the
# throughput difference versus the baseline mixes the batch-size and memory-format
# effects; rerun the baseline with batch_size = 8 to isolate channels-last.
batch_size = 8
model = get_model(channels_last=True)
batch = get_input(batch_size=batch_size, channels_last=True)
infer_fn = get_inference_fn(model=model)
avg_time = benchmark(infer_fn=infer_fn, batch=batch)  # mean seconds per forward pass
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")


Average samples per second: 6.16


In [None]:
def get_inference_fn(model, enable_amp=False):
    """Return a callable that runs ``model`` under inference mode, optionally with CPU AMP.

    Args:
        model: Any callable (typically an ``nn.Module``) accepting one batch.
        enable_amp: If True, run the forward pass inside a CPU
            ``torch.amp.autocast`` region with ``dtype=torch.bfloat16``;
            if False, the autocast context is entered but disabled.

    Returns:
        ``infer_fn(batch)``, which returns ``model(batch)``. With AMP enabled,
        ops that autocast lowers (e.g. matmuls) produce bfloat16 outputs.
    """
    def infer_fn(batch):
        # inference_mode is entered first, autocast second (same nesting as
        # a single combined ``with`` statement).
        with torch.inference_mode():
            with torch.amp.autocast('cpu', dtype=torch.bfloat16, enabled=enable_amp):
                return model(batch)

    return infer_fn

# AMP experiment: channels-last model and input, plus bfloat16 autocast on CPU.
# NOTE(review): any speedup presumably depends on the CPU's native bfloat16
# support — confirm on the target hardware.
batch_size = 8
model = get_model(channels_last=True)
batch = get_input(batch_size=batch_size, channels_last=True)
infer_fn = get_inference_fn(model=model, enable_amp=True)
avg_time = benchmark(infer_fn=infer_fn, batch=batch)  # mean seconds per forward pass
print(f"\nAverage samples per second: {(batch_size/avg_time):.2f}")