### Compare various naive methods of optmizing inference speed

In [1]:
# Disable cuda
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.datasets import SimpleOxfordPetDataset

import timeit

root = "."


In [3]:
train_dataset = SimpleOxfordPetDataset(root, "train")
valid_dataset = SimpleOxfordPetDataset(root, "valid")
test_dataset = SimpleOxfordPetDataset(root, "test")

# It is a good practice to check datasets don`t intersects with each other
assert set(test_dataset.filenames).isdisjoint(set(train_dataset.filenames))
assert set(test_dataset.filenames).isdisjoint(set(valid_dataset.filenames))
assert set(train_dataset.filenames).isdisjoint(set(valid_dataset.filenames))

print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(valid_dataset)}")
print(f"Test size: {len(test_dataset)}")

n_cpu = os.cpu_count()
train_dataloader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(
    valid_dataset, batch_size=16, shuffle=False, num_workers=n_cpu)
test_dataloader = DataLoader(
    test_dataset, batch_size=16, shuffle=False, num_workers=n_cpu)


Train size: 3312
Valid size: 368
Test size: 3669


In [4]:
class PetModel(nn.Module):
    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.std = torch.tensor(params["std"]).view(1, 3, 1, 1)
        self.mean = torch.tensor(params["mean"]).view(1, 3, 1, 1)
        self.loss_fn = smp.losses.DiceLoss(
            smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask


In [5]:
arch = "FPN"
encoder_name = "resnet34"
in_channels = 3
out_classes = 1
model = PetModel(arch, encoder_name, in_channels, out_classes)
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

In [6]:
# take one batch from valid dataloader
batch = next(iter(valid_dataloader))

In [7]:
batch['image'].shape, batch['mask'].shape

(torch.Size([16, 3, 256, 256]), torch.Size([16, 1, 256, 256]))

In [None]:
model.load_state_dict(torch.load('model.pt'))

In [8]:
def benchmark(model, data):
    model.eval()
    with torch.no_grad():
        return model(data)

In [9]:
model = model.eval()

## Torch Inference

In [10]:
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=1)
print(f"warmup time - Batch 16: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=10)
print(f"benchmark time: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=1)
print(f"warmup time - Batch 1: {time:.4f} s")

time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=10)
print(f"benchmark time (Batch 1): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 1): {time:.4f} s")


warmup time - Batch 16: 1.1051 s


In [23]:
import intel_extension_for_pytorch as ipex

## IPEX - Float32

In [15]:
model = ipex.optimize(model, dtype=torch.float32)



In [25]:
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=1)
print(f"warmup time - Batch 16: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=10)
print(f"benchmark time (Batch 16): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 16): {time:.4f} s")
time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=1)
print(f"warmup time - Batch 1: {time:.4f} s")

time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=10)
print(f"benchmark time (Batch 1): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 1): {time:.4f} s")


warmup time - Batch 16: 0.6036 s
benchmark time (Batch 16): 6.6450 s
Per Iter time(Batch 16): 0.6645 s
warmup time - Batch 1: 0.0394 s
benchmark time (Batch 1): 0.3884 s
Per Iter time(Batch 1): 0.0388 s


## IPEX - BFloat16

In [22]:
model = PetModel(arch, encoder_name, in_channels, out_classes)
model.eval()
model = ipex.optimize(model, dtype=torch.bfloat16)



In [26]:
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=1)
print(f"warmup time - Batch 16: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=10)
print(f"benchmark time (Batch 16): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 16): {time:.4f} s")

time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=1)
print(f"warmup time - Batch 1: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=10)
print(f"benchmark time (Batch 1): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 1): {time:.4f} s")


warmup time - Batch 16: 0.7158 s
benchmark time (Batch 16): 6.3737 s
Per Iter time(Batch 16): 0.6374 s
warmup time - Batch 1: 0.0379 s
benchmark time (Batch 1): 0.4146 s
Per Iter time(Batch 1): 0.0415 s


## Inference - TorchScript Mode

In [29]:
model = PetModel(arch, encoder_name, in_channels, out_classes).eval()
model = ipex.optimize(model, dtype=torch.float32)
model = torch.jit.trace(model, batch['image'])
model = torch.jit.freeze(model)


In [33]:
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=1)
print(f"warmup time - Batch 16: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(model, batch['image']), number=10)
print(f"benchmark time (Batch 16): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 16): {time:.4f} s")


warmup time - Batch 16: 0.5830 s
benchmark time (Batch 16): 4.4572 s
Per Iter time(Batch 16): 0.4457 s


In [34]:
model = PetModel(arch, encoder_name, in_channels, out_classes).eval()
model = ipex.optimize(model, dtype=torch.float32)
model = torch.jit.trace(model, batch['image'][0].unsqueeze(0))
model = torch.jit.freeze(model)


In [36]:
time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=1)
print(f"warmup time - Batch 1: {time:.4f} s")
time = timeit.timeit(lambda: benchmark(
    model, batch['image'][0].unsqueeze(0)), number=10)
print(f"benchmark time (Batch 1): {time:.4f} s")
time = time/10
print(f"Per Iter time(Batch 1): {time:.4f} s")


warmup time - Batch 1: 0.1001 s
benchmark time (Batch 1): 0.5284 s
Per Iter time(Batch 1): 0.0528 s
