# PyTorch in Production: Faster Inference in PyTorch with TRTorch

In [None]:
import copy
import pickle
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torchvision.models as tvm
import trtorch
from tqdm.notebook import tqdm


# All tensors and models in this notebook live on the CUDA device with index 1,
# which is also made the current CUDA device.
DEVICE="cuda:1"
torch.cuda.set_device(DEVICE)
# Turn on cuDNN benchmark mode; the warm-up runs in the benchmark helper exist
# to prime it before any timing happens.
torch.backends.cudnn.benchmark = True


In [None]:
def benchmark_resolution(model, resolution, dtype, device):
    """Return the best single-image latency of ``model`` at one resolution.

    A ``(1, 3, resolution, resolution)`` tensor of ones is passed through
    ``model`` 10 times as warm-up and then 100 times under timing. All runs
    happen with gradient tracking disabled.

    Args:
        model: Callable that accepts the input tensor (e.g. an ``nn.Module``
            or a compiled module).
        resolution: Height and width of the square dummy input.
        dtype: Torch dtype of the dummy input.
        device: Device on which the dummy input is allocated.

    Returns:
        The smallest wall-clock duration, in seconds, among the timed runs.
    """
    dummy_input = torch.ones(
        (1, 3, resolution, resolution), dtype=dtype, device=device
    )
    target_device = dummy_input.device
    on_cuda = target_device.type == "cuda"

    durations = []
    with torch.no_grad():
        # Warm-up runs to prepare the cuDNN benchmark. They are kept under
        # no_grad as well so that no autograd state is recorded for them.
        for _ in range(10):
            model(dummy_input)
        if on_cuda:
            # Drain the queued warm-up work so it cannot leak into the first
            # timed iteration.
            torch.cuda.synchronize(target_device)

        for _ in range(100):
            # perf_counter is monotonic, unlike time.time().
            start = time.perf_counter()
            model(dummy_input)
            if on_cuda:
                # CUDA launches are asynchronous: wait for completion before
                # stopping the clock.
                torch.cuda.synchronize(target_device)
            durations.append(time.perf_counter() - start)
    return min(durations)

def benchmark(model, resolutions, dtype, device):
    """Time ``model`` at every resolution and return the timings as a list.

    The i-th entry of the returned list is the value produced by
    ``benchmark_resolution`` for the i-th entry of ``resolutions``.
    """
    timings = []
    for side in resolutions:
        timings.append(benchmark_resolution(model, side, dtype, device))
    return timings

In [None]:
# Pretrained ResNet-101, moved to the current CUDA device and put in eval mode.
# Both .cuda() and .eval() return the module itself, so the calls can be chained.
model = tvm.resnet101(pretrained=True).cuda().eval()

## Compile with TRTorch

#### FP32

In [None]:
# Compilation settings: one input whose shape may range from 1x3x160x160 to
# 1x3x1600x1600 (160x160 is also given as the "opt" shape), with FP32 as the
# operator precision.
settings = dict(
    input_shapes=[
        dict(
            min=[1, 3, 160, 160],
            opt=[1, 3, 160, 160],
            max=[1, 3, 1600, 1600],
        ),
    ],
    op_precision=torch.float,
)

# Trace the model on a 160x160 example input, then hand the traced module to TRTorch.
traced_model = torch.jit.trace(
    model,
    torch.ones(1, 3, 160, 160, device=DEVICE),
)
float_trt_model = trtorch.compile(traced_model, settings)

#### FP16

In [None]:
# Same input shape range as the FP32 build (1x3x160x160 up to 1x3x1600x1600,
# "opt" at 160x160), but with FP16 as the operator precision.
settings = dict(
    input_shapes=[
        dict(
            min=[1, 3, 160, 160],
            opt=[1, 3, 160, 160],
            max=[1, 3, 1600, 1600],
        ),
    ],
    op_precision=torch.half,
)

# Trace the model on a 160x160 example input, then hand the traced module to TRTorch.
traced_model = torch.jit.trace(
    model,
    torch.ones(1, 3, 160, 160, device=DEVICE),
)
half_trt_model = trtorch.compile(traced_model, settings)

### Benchmarks

In [None]:
# Square input sizes (in pixels) at which every model variant is timed.
RESOLUTIONS = [160,224,320,448,640,896,1280,1600]

# Each entry maps a method name to its list of timings, one per resolution.
benchmarks = {
    "PyTorch FP32" : benchmark(
        model,
        RESOLUTIONS,
        dtype=torch.float,
        device=DEVICE,
    ),
    "PyTorch FP16" : benchmark(
        # nn.Module.half() casts the module in place and returns it, so calling
        # it on `model` directly would leave the shared model in FP16 and break
        # any later FP32 use (re-running this cell or the tracing cells).
        # Benchmark a half-precision copy instead.
        copy.deepcopy(model).half(),
        RESOLUTIONS,
        dtype=torch.half,
        device=DEVICE,
    ),
    "TRTorch FP32" : benchmark(
        float_trt_model,
        RESOLUTIONS,
        dtype=torch.float,
        device=DEVICE,
    ),
    "TRTorch FP16" : benchmark(
        half_trt_model,
        RESOLUTIONS,
        dtype=torch.half,
        device=DEVICE,
    )
}

In [None]:
# Flatten the per-method timing lists into a long-format table with one row per
# (resolution, method) pair; methods are laid out back to back.
method_order = ["PyTorch FP32", "PyTorch FP16", "TRTorch FP32", "TRTorch FP16"]
values = [duration for method in method_order for duration in benchmarks[method]]
models = [method for method in method_order for _ in RESOLUTIONS]
df = pd.DataFrame(
    zip(RESOLUTIONS * 4, values, models),
    columns=["Resolution", "Duration (s)", "Method"],
)
# Position of each resolution within RESOLUTIONS; used as the x coordinate so
# that the bars are evenly spaced.
df["Image Resolution"] = df["Resolution"].apply(RESOLUTIONS.index)

# Human-readable axis labels such as "160x160".
resolutions_ticks = [f"{side}x{side}" for side in RESOLUTIONS]

# Grouped bar chart: one group per resolution, one bar per method.
fig = px.bar(
    data_frame=df,
    x="Image Resolution",
    y="Duration (s)",
    color="Method",
    barmode="group",
    title="ResNet101 Inference Time",
    height=500,
)
# Replace the numeric x positions with the resolution labels.
fig.update_layout(
    xaxis={
        "tickmode": "array",
        "tickvals": list(range(len(resolutions_ticks))),
        "ticktext": resolutions_ticks,
    },
    bargroupgap=0.0,
    bargap=0.3,
)

fig.show()

### Compute the speedup factor at different resolutions

In [None]:
# The speed-up is duration(FIRST_METHOD) / duration(SECOND_METHOD), so a ratio
# above 1 means SECOND_METHOD is the faster of the two.
FIRST_METHOD = "PyTorch FP16"
SECOND_METHOD = "TRTorch FP16"


# sort=False keeps the groups in order of first appearance, i.e. the order of
# RESOLUTIONS used to build df, so row i lines up with resolutions_ticks[i]
# even if RESOLUTIONS is not sorted in ascending order.
first_method_perf = df[df["Method"]==FIRST_METHOD].groupby('Resolution', as_index=False, sort=False).first()[['Duration (s)']]
second_method_perf = df[df["Method"]==SECOND_METHOD].groupby('Resolution', as_index=False, sort=False).first()[['Duration (s)']]
speedup = first_method_perf/ second_method_perf
# Derive the x positions from the data instead of hard-coding the number of
# resolutions, so the cell keeps working when RESOLUTIONS changes length.
speedup["Image Resolution"] = range(len(speedup))
speedup.columns = ["Ratio", "Image Resolution"]

fig = px.bar(
    data_frame=speedup,
    x="Image Resolution",
    y="Ratio",
    title=f"Inference Speed-Up Factor : {SECOND_METHOD} vs {FIRST_METHOD}",

)
# Replace the numeric x positions with the resolution labels.
fig.update_layout(
    xaxis = dict(
        tickmode = 'array',
        tickvals = list(range(len(resolutions_ticks))),
        ticktext = resolutions_ticks
    )
)
