##### Introduction to TorchScript

TorchScript lets you save your models in a form that can noticeably improve their inference performance and that can also be loaded into production C++ code designed for fast inference. The gains from TorchScript are especially pronounced on a GPU, as we will see.
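Saving a scripted model and loading it back can be sketched as follows. The module `TinyNet` and the file name `tiny_net.pt` are hypothetical placeholders chosen for illustration. The saved archive contains both the compiled code and the weights, so loading it does not require the original Python class; the same file can be loaded from C++ with LibTorch's `torch::jit::load`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a real model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


model = TinyNet().eval()
scripted = torch.jit.script(model)  # compile the module to TorchScript

# Serialize code and weights into a single archive (hypothetical file name)
path = os.path.join(tempfile.mkdtemp(), "tiny_net.pt")
torch.jit.save(scripted, path)

# Loading works without access to the TinyNet class definition
loaded = torch.jit.load(path)

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print(same)  # True
```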

This matters a great deal in practice; a good video on the topic is https://www.youtube.com/watch?v=St3gdHJzic0.

Note that TorchScript is no longer under active development, while a similar feature intended as its successor, `torch.export`, is. See https://pytorch.org/docs/stable/jit.html and https://pytorch.org/docs/stable/export.html.

For this example, we load the pretrained ResNet-34 model from `torchvision`, which is based on https://arxiv.org/abs/1512.03385, and measure how fast it runs inference on an image-sized input.
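Hand-written timing loops are easy to get wrong (no warm-up, asynchronous CUDA calls). As a hedged alternative to the timing loop in the next cell, PyTorch ships `torch.utils.benchmark.Timer`, which handles warm-up, picks the number of runs, and synchronizes CUDA when timing GPU code. The small `nn.Sequential` model below is a hypothetical stand-in for ResNet-34 so the sketch runs quickly without downloading weights.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# Hypothetical small model standing in for ResNet-34
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).eval()
scripted = torch.jit.script(model)
x = torch.randn(1, 64)

results = {}
for name, m in [("Normal PyTorch", model), ("TorchScript", scripted)]:
    timer = benchmark.Timer(
        stmt="with torch.no_grad(): m(x)",
        globals={"m": m, "x": x, "torch": torch},
    )
    # blocked_autorange chooses how many runs to do and returns a Measurement
    measurement = timer.blocked_autorange(min_run_time=0.2)
    results[name] = measurement.median  # seconds per call

for name, seconds in results.items():
    print(f"{name}: {seconds * 1e6:.1f} us per call")
```

Reporting the median rather than the mean makes the comparison less sensitive to occasional slow runs.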

```python
import time  # timing utilities

import torch
import torchvision
from prettytable import PrettyTable  # prints simple ASCII tables
```

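```python
def measure_inference_time(model, input_tensor, iterations=100, device="cpu", warmup=10):
    """Return the average time (in seconds) of one forward pass."""
    model.to(device)
    model.eval()
    input_tensor = input_tensor.to(device)
    with torch.no_grad():
        # Warm-up runs: TorchScript's profiling executor optimizes the graph
        # during the first calls, and CUDA needs time to initialize.
        for _ in range(warmup):
            model(input_tensor)
        if device == "cuda":
            torch.cuda.synchronize()  # CUDA calls are asynchronous
        start_time = time.perf_counter()
        for _ in range(iterations):
            model(input_tensor)
        if device == "cuda":
            torch.cuda.synchronize()  # wait for all queued kernels to finish
        end_time = time.perf_counter()

    return (end_time - start_time) / iterations


device = "cuda" if torch.cuda.is_available() else "cpu"

# One random tensor with the shape of a single 224x224 RGB image
input_image = torch.randn(1, 3, 224, 224)

# `pretrained=True` is deprecated in recent torchvision; use the weights enum instead
model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.DEFAULT)

time_without_torchscript = measure_inference_time(model, input_image, device=device)
print(f"Average inference time over 100 Iterations (normal PyTorch): {time_without_torchscript:.6f} seconds")

# Translate the model to TorchScript
torchscript_model = torch.jit.script(model)
time_with_torchscript = measure_inference_time(torchscript_model, input_image, device=device)
print(f"Average inference time over 100 Iterations (Torchscript): {time_with_torchscript:.6f} seconds")

table = PrettyTable()
table.field_names = ["Type", "Inference Time"]
table.add_row(["Normal PyTorch", f"{time_without_torchscript:.6f}"])
table.add_row(["TorchScript", f"{time_with_torchscript:.6f}"])
print(table)

# Compare results
print(f"TorchScript is {time_without_torchscript / time_with_torchscript:.2f}x faster than normal PyTorch.")
```
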
```
Average inference time over 100 Iterations (normal PyTorch): 0.005733 seconds
Average inference time over 100 Iterations (Torchscript): 0.002930 seconds
+----------------+----------------+
|      Type      | Inference Time |
+----------------+----------------+
| Normal PyTorch |    0.005733    |
|  TorchScript   |    0.002930    |
+----------------+----------------+
TorchScript is 1.96x faster than normal PyTorch.
```
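A speed-up is only useful if the scripted model still computes the same thing, so it can be worth comparing outputs as well as timings. The sketch below does this for a hypothetical small CNN (conv, batch norm, pooling, linear) standing in for ResNet-34; the tolerance `atol=1e-5` is an assumed value, chosen because optimized graphs may reorder floating-point operations slightly.

```python
import torch
import torch.nn as nn

# Hypothetical small CNN standing in for ResNet-34
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()  # eval mode so BatchNorm uses its running statistics
scripted = torch.jit.script(model)

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    eager_out = model(x)
    script_out = scripted(x)

max_diff = (eager_out - script_out).abs().max().item()
print(f"max |eager - scripted| = {max_diff:.2e}")
print(torch.allclose(eager_out, script_out, atol=1e-5))
```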
