#  <font color='blue'>TorchScript</font>

---
<p align="center">
<img src="https://www.learnopencv.com/wp-content/uploads/2020/09/c3-w15-torchscript.png" width="1000">
</p>
<center> <a href="https://youtu.be/2awmrMRf0dA?t=397"> Reference: TorchScript and PyTorch JIT | Deep Dive</a></center>
   
---
    
**What is TorchScript, and why do we need it?**

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency. For example, a model trained using PyTorch can be serialized and saved using TorchScript, and the serialized model can then be loaded and used in `C++` code.

It makes it possible to train a model in PyTorch in eager mode and then export it to a production environment, where running a Python program may be disadvantageous for performance and multi-threading reasons.

**We need TorchScript for the portability and performance of a PyTorch model:**

**Portability:** Models should be exportable to a wide variety of environments, such as `C++`, mobile, embedded systems, etc. However, tight coupling with a Python environment makes this difficult.

**Performance:** Common patterns in neural networks can be optimized further to improve inference latency and throughput. However, many such optimizations cannot be applied to eager-mode code because of Python's dynamism.


**TorchScript does not just make the PyTorch model portable but also optimizes it. How does TorchScript optimize the model?**

A compiler speeds up execution by translating source code into a machine-friendly form (machine code or bytecode) before it runs. Python, however, is an interpreted language: the interpreter executes the code directly, which costs execution speed. TorchScript uses a [JIT](http://aboullaite.me/understanding-jit-compiler-just-in-time-compiler/#:~:text=The%20JIT%20compiler%20is%20enabled,directly%20instead%20of%20interpreting%20it.) (Just-In-Time) compiler, which runs alongside the interpreter, identifies frequently executed regions of the code (hot components such as loops and function calls), and converts them into machine code. When those regions are executed again, the machine code generated by the JIT compiler runs in place of the interpreted code, which makes execution faster.

---

**TorchScript builds an intermediate representation (IR) in the form of a graph, similar to that of [LLVM](https://llvm.org/docs/index.html), and uses the JIT compiler to optimize this intermediate representation.**



Here is an example of the intermediate representation (IR) for adding two tensors.

```
graph(%a : Long(),
      %b : Long()):
  %2 : int = prim::Constant[value=1]()
  %3 : Long() = aten::add(%a, %b, %2)
  return (%3)
```

The intermediate representation (**IR**) used in TorchScript has many components, which are described [here](https://github.com/pytorch/pytorch/blob/53af9df557aff745edf24193ece784fd008c6f19/torch/csrc/jit/OVERVIEW.md). In this notebook, instead of going into theoretical details (not required to use TorchScript), we will see how TorchScript transforms a PyTorch model into an intermediate graph representation (IR) and saves the IR. We will also see how to load the IR model in Python. In the coming section, we will also see how we can load the IR in `C++`.
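
As a minimal sketch of where such a graph comes from, the snippet below scripts a two-argument addition and prints its IR. The function name `add_tensors` is only illustrative, and the type annotations in the printed graph (for example `Tensor` rather than `Long()`) may differ between PyTorch versions.

```python
import torch

# Hypothetical helper, used only to illustrate the IR for an addition.
def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b

# Compile the Python function to TorchScript and inspect its graph.
scripted_add = torch.jit.script(add_tensors)
print(scripted_add.graph)
```

The printed graph should contain an `aten::add` node, as in the IR listing for adding two tensors.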

## <font color='blue'> 1. Generating Intermediate Representation </font>

**There are two ways to get IR from python code:**

1. Tracing, and

2. Scripting.

Let's see both one-by-one.

###   <font color='green'>1.1. Tracing</font>

The tracer runs the function with the given input tensor, records the tensor operations performed, and turns the recording into a TorchScript module. **It does not preserve control flow or other language features such as data structures**.

TorchScript provides a simple function, `torch.jit.trace()`, that takes the target function and an example input tensor and returns the TorchScript format (IR representation). The resulting TorchScript code can depend on the shape of that example input, because any Python logic that reads the shape during tracing is baked into the trace. For example, if such a model is traced with a tensor of **batch size** `32`, the traced code is only guaranteed to reproduce the PyTorch output for tensors of **batch size** `32`. Thus we should be clear about the tensor shape before applying `torch.jit.trace()`.


**Let's take an example of a simple neural network and transform it into IR using `torch.jit.trace()`.**

In [1]:
'''
Converting a single layer neural network using torch.jit.trace
'''
import torch
import torch.nn as nn

# get device
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# a simple NN
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.layer1 = nn.Linear(in_features=4, 
                                out_features=2)
    def forward(self,x):
        x = self.layer1(x)
        return x
    

# Move the model to the device
net = NN().to(device)

**Note:** 

1. Several optimizations are only supported on CUDA devices, so it is better to have one.

2. If we are converting the PyTorch model to IR using TorchScript only for inference purposes, we should freeze the model (set `requires_grad=False` on its parameters). This speeds up inference because no gradients are tracked or stored.


In [2]:
# freeze the model
for parm in net.parameters():
    parm.requires_grad=False


# create a random example input; tracing requires one
x = torch.randn(1,4, device=device)

# trace the model to get the IR
trace_nn = torch.jit.trace(net, x)

print(trace_nn.graph) # To visualize the IR

graph(%self.1 : __torch__.NN,
      %input : Float(1:4, 4:1)):
  %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="layer1"](%self.1)
  %16 : Tensor = prim::CallMethod[name="forward"](%14, %input)
  return (%16)



**Let's call `trace_nn`.**

In [3]:
trace_nn(x)

tensor([[-0.3251,  0.0045]])

**Let's take another example to show that `torch.jit.trace()` does not produce a faithful IR when the code contains control flow (`if` statements, `for` loops, etc.).**

In [4]:
def fun(x):
    if x.size(0) == 1:
        return x+3
    else:
        return x + 20


# here, x.size(0) = 1, so this input satisfies if-condition
x = torch.randn(1,4).to(device)

trace_fun = torch.jit.trace(fun,x)

print(trace_fun.graph_for(x))

graph(%x : Float(*, *)):
  %4 : int = prim::Constant[value=3]()
  %1 : int = prim::Constant[value=1]() # <ipython-input-4-33c58bd45e7c>:3:0
  %5 : Float(*, *) = aten::add(%x, %4, %1)
  return (%5)



  


We can see that the graph did not record the `if-else` condition; it registered only the tensor operations of the branch that was taken (the given tensor satisfies the `if` condition).

**Tracing also emits a `TracerWarning` here, because the branch depends on the input.**

**Now, let's choose another input such that it satisfies the `else`-condition.**


In [5]:
# here, x.size(0) = 2, so this input does not satisfy if-condition, so else-condition will be executed. 
x = torch.randn(2,4).to(device)

trace_fun = torch.jit.trace(fun, x)

print(trace_fun.graph_for(x))

graph(%x : Float(*, *)):
  %4 : int = prim::Constant[value=20]()
  %1 : int = prim::Constant[value=1]() # <ipython-input-4-33c58bd45e7c>:5:0
  %5 : Float(*, *) = aten::add(%x, %4, %1)
  return (%5)



  


**As you can see, the value of `%4` changed according to the branch that was taken.**

Let's confirm this by calling `trace_fun` with an input of shape `(1, 4)` (`size(0) = 1`).

In [6]:
# x.size(0) = 1
x = torch.ones(1, 4).to(device)

trace_fun(x)

tensor([[21., 21., 21., 21.]])

**We can see that even though the first dimension of the input is one, the output is the result of the `else` branch. So tracing fails to capture the control flow.**
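
The mismatch can also be checked programmatically. The sketch below is self-contained: it redefines the function under the hypothetical name `branchy`, traces it with a `(2, 4)` input, and compares the eager and traced outputs for a `(1, 4)` input. The `warnings` filter is an addition that only silences the tracer warning.

```python
import warnings
import torch

def branchy(x):
    # Same logic as `fun`; the name is illustrative.
    if x.size(0) == 1:
        return x + 3
    return x + 20

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning for brevity
    # Tracing with size(0) == 2 records only the else branch.
    traced = torch.jit.trace(branchy, torch.randn(2, 4))

probe = torch.ones(1, 4)
eager_out = branchy(probe)   # Python takes the if branch: 1 + 3
traced_out = traced(probe)   # the trace replays the else branch: 1 + 20
print(torch.equal(eager_out, traced_out))  # False: the outputs differ
```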


**How can it be fixed? `Scripting` is the answer to it.**

###  <font color='green'>1.2. Scripting</font>


Scripting is another way of converting code from eager mode to script mode. It preserves the control flow by lexing and parsing the whole target Python code. `torch.jit.script()` can be applied to a function or a module (for functions, the `@torch.jit.script` decorator also works). Applied to a deep learning model, by default it scripts the `forward` method and recursively scripts the submodules and functions that are called inside `forward`.

**Let's see it with an example.**

Let's take the above function `fun` and check how scripting preserves the conditional branch.

In [7]:
# note that we do not need any input to use script

scripted_fun = torch.jit.script(fun)

print(scripted_fun.graph) # To visualize the intermediate representation

graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=0]() # <ipython-input-4-33c58bd45e7c>:2:14
  %4 : int = prim::Constant[value=1]() # <ipython-input-4-33c58bd45e7c>:2:20
  %7 : int = prim::Constant[value=3]() # <ipython-input-4-33c58bd45e7c>:3:17
  %11 : int = prim::Constant[value=20]() # <ipython-input-4-33c58bd45e7c>:5:19
  %3 : int = aten::size(%x.1, %2) # <ipython-input-4-33c58bd45e7c>:2:7
  %5 : bool = aten::eq(%3, %4) # <ipython-input-4-33c58bd45e7c>:2:7
  %22 : Tensor = prim::If(%5) # <ipython-input-4-33c58bd45e7c>:2:4
    block0():
      %9 : Tensor = aten::add(%x.1, %7, %4) # <ipython-input-4-33c58bd45e7c>:3:15
      -> (%9)
    block1():
      %13 : Tensor = aten::add(%x.1, %11, %4) # <ipython-input-4-33c58bd45e7c>:5:15
      -> (%13)
  return (%22)



Note that `%3` records the size of `x` (`%x.1`) along dimension `0` (`%2`) by calling `aten::size(%x.1, %2)`.  
`%3` is then compared with `%4`, and the boolean result is stored in `%5` by calling `aten::eq(%3, %4)`.

Notice the expression `%22 : Tensor = prim::If(%5)`.  
It means that whatever the `if-else` block returns is stored in `%22`.  
`block0()` holds the body of the `if` branch and `block1()` holds the body of the `else` branch.

**Note:** Scripting does not need an example input.
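
To confirm that both branches survive scripting, the following sketch scripts a standalone copy of the function (named `branchy` here purely for illustration) and calls it with inputs that hit each branch.

```python
import torch

def branchy(x: torch.Tensor) -> torch.Tensor:
    # Same logic as `fun`; the name is illustrative.
    if x.size(0) == 1:
        return x + 3
    else:
        return x + 20

# No example input is needed: the source code itself is compiled.
scripted = torch.jit.script(branchy)

out_if = scripted(torch.ones(1, 4))    # size(0) == 1, so the if branch runs
out_else = scripted(torch.ones(2, 4))  # size(0) == 2, so the else branch runs
print(out_if)
print(out_else)
```

Unlike the traced version, the same scripted function returns `x + 3` or `x + 20` depending on the input it receives.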

**The same flow can be recorded using the decorator `@torch.jit.script` too.**

**The following is the same example as shown above.**

In [8]:
@torch.jit.script
def fun(x):
    if x.size(0) == 1:
        return x+3
    else:
        return x+20
print(fun.graph)

graph(%x.1 : Tensor):
  %2 : int = prim::Constant[value=0]() # <ipython-input-8-f1a0376eb796>:3:14
  %4 : int = prim::Constant[value=1]() # <ipython-input-8-f1a0376eb796>:3:20
  %7 : int = prim::Constant[value=3]() # <ipython-input-8-f1a0376eb796>:4:17
  %11 : int = prim::Constant[value=20]() # <ipython-input-8-f1a0376eb796>:6:17
  %3 : int = aten::size(%x.1, %2) # <ipython-input-8-f1a0376eb796>:3:7
  %5 : bool = aten::eq(%3, %4) # <ipython-input-8-f1a0376eb796>:3:7
  %22 : Tensor = prim::If(%5) # <ipython-input-8-f1a0376eb796>:3:4
    block0():
      %9 : Tensor = aten::add(%x.1, %7, %4) # <ipython-input-8-f1a0376eb796>:4:15
      -> (%9)
    block1():
      %13 : Tensor = aten::add(%x.1, %11, %4) # <ipython-input-8-f1a0376eb796>:6:15
      -> (%13)
  return (%22)



In the above two cells, we have seen how the IR graph is defined for a function. However, we can also look at the underlying function that is created when we apply `torch.jit.script()` to a function.


The code to get the function is again a single line; we access `fun.code` instead of `fun.graph`.

In [9]:
# A much better depiction of the function `fun` we defined previously.
print(fun.code)

def fun(x: Tensor) -> Tensor:
  if torch.eq(torch.size(x, 0), 1):
    _0 = torch.add(x, 3, 1)
  else:
    _0 = torch.add(x, 20, 1)
  return _0



###  <font color='green'>1.3. Scripting a Neural-Network</font>
 
Now let us apply the `torch.jit.script()` to the single-layer neural network (`net`), which we defined previously.


Once we script it, we can run the TorchScript code with an input to get information about the tensors and operations in use, which helps when optimizing performance.


In [10]:
# create a tiny feed-forward layer and set its parameters to be non-trainable
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = nn.Linear(in_features=4, out_features=2)
        
    def forward(self,x):
        x = self.layer1(x)
        return x


# move the model to the device
small_net = Network().to(device)

# freeze the model
for parm in small_net.parameters():
    parm.requires_grad=False

In [11]:
scripted_nn = torch.jit.script(small_net)
scripted_nn(x)
print(scripted_nn.graph_for(x))

print("Output is:\n ", scripted_nn(x))

graph(%self : __torch__.Network,
      %x.1 : Float(*, *)):
  %4 : __torch__.torch.nn.modules.linear.___torch_mangle_7.Linear = prim::GetAttr[name="layer1"](%self)
  %5 : Float(*, *) = prim::GetAttr[name="weight"](%4)
  %6 : Float(*) = prim::GetAttr[name="bias"](%4)
  %10 : Float(*, *) = aten::t(%5) # /home/prakash/anaconda3/envs/pl/lib/python3.7/site-packages/torch/nn/functional.py:1674:39
  %17 : int = prim::Constant[value=1]()
  %18 : Float(*, *) = aten::mm(%x.1, %10) # <string>:3:24
  %19 : Float(*, *) = aten::add(%6, %18, %17) # <string>:3:17
  return (%19)

Output is:
  tensor([[-0.5723,  0.0863]])


In the above cell, `%4` stores the `layer1` attribute, which is the `nn.Linear` layer.  
`%5` stores the `weight` and `%6` stores the `bias`.  
`%10` stores the transposed version of `%5` (the weights), computed with `aten::t()`, which is needed to calculate `X * W.t()`.

`%18` stores the value of `X * W.t()`, computed with `aten::mm()` (matrix multiplication), and `%19` is the result of adding the bias (`%6`) to `%18`.
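
The three operations in the graph can be reproduced by hand as a sanity check. This is a sketch on a freshly created layer with random weights, not on the `small_net` instance from the cell above; the seed and batch size are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability
layer = nn.Linear(in_features=4, out_features=2)
x = torch.randn(3, 4)

with torch.no_grad():
    transposed = layer.weight.t()   # corresponds to aten::t
    product = x.mm(transposed)      # corresponds to aten::mm
    manual = layer.bias + product   # corresponds to aten::add
    reference = layer(x)            # what nn.Linear computes directly

print(torch.allclose(manual, reference, atol=1e-6))
```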

### <font color='green'>Saving and Loading the Scripted Model</font>

In [12]:
# We can also save and load this scripted-model.
torch.jit.save(scripted_nn, 'tiny_model.pt')

# Now we can load this model
reloaded_tiny_model = torch.jit.load('tiny_model.pt')

print("Output is ", reloaded_tiny_model(x))

Output is  tensor([[-0.5723,  0.0863]])
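
A slightly more defensive variant of the round trip is sketched below. It assumes we want to avoid leaving files behind and to load onto the CPU regardless of where the model was saved, so it uses a temporary directory and the `map_location` argument of `torch.jit.load`; the layer and file name are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Script a small layer in eval mode (illustrative stand-in for a real model).
model = torch.jit.script(nn.Linear(4, 2).eval())
example = torch.ones(1, 4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_model.pt")  # illustrative file name
    model.save(path)  # method form of torch.jit.save(model, path)
    # map_location remaps stored tensors, here onto the CPU.
    restored = torch.jit.load(path, map_location="cpu")

same = torch.allclose(model(example), restored(example))
print(same)
```

Comparing the outputs before and after the round trip is a cheap check that serialization preserved the weights.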


**Now that we know how to create an IR for a PyTorch-based model, we shall see how to load the IR in `LibTorch` in the next section.**

## <font color='blue'>References:-</font>

1. <a href="https://www.youtube.com/watch?v=2awmrMRf0dA&feature=youtu.be" >Video on TorchScript</a>
2. <a href="https://github.com/pytorch/pytorch/blob/53af9df557aff745edf24193ece784fd008c6f19/torch/csrc/jit/OVERVIEW.md" >PyTorch docs on JIT</a>