<br>

<div align=center><font color=maroon size=6><b>Introduction to TorchScript</b></font></div>

<br>

<font size=4><b>References:</b></font>
1. PyTorch official tutorials: <a href="https://pytorch.org/tutorials/index.html" style="text-decoration:none;">WELCOME TO PYTORCH TUTORIALS</a>
    * Tutorials > <a href="https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html" style="text-decoration:none;">Introduction to TorchScript</a>
    * Docs > <a href="https://pytorch.org/docs/stable/jit.html" style="text-decoration:none;">TorchScript</a>
    * Docs > <a href="https://pytorch.org/cppdocs/" style="text-decoration:none;">PyTorch C++ API</a>

<br>
<br>
<br>

<font size=3>James Reed (jamesreed@fb.com), Michael Suo (suo@fb.com), rev2</font>

<br>

This tutorial is an introduction to <u><font size=3 color=blue><b>TorchScript</b></font>, an intermediate representation of a PyTorch model (subclass of `nn.Module`) that can then be run in a high-performance environment such as C++.</u>

In this tutorial we will cover:

1. The basics of model authoring in PyTorch, including:
    * Modules

    * Defining `forward` functions

    * Composing modules into a hierarchy of modules

2. Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime
    * Tracing an existing module
    * Using scripting to directly compile a module
    * How to compose both approaches
    * Saving and loading TorchScript modules


We hope that after you complete this tutorial, you will proceed to go through <a href="https://pytorch.org/tutorials/advanced/cpp_export.html" style="text-decoration:none;">the follow-on tutorial</a> which will walk you through an example of actually calling a TorchScript model from C++.

In [1]:
# This is all you need to use both PyTorch and TorchScript!

import torch

In [2]:
torch.__version__

'1.10.2'

<br>

# Basics of PyTorch Model Authoring

Let’s start out by defining a simple Module. A Module is the basic unit of composition in PyTorch. It contains:

* A constructor, which prepares the module for invocation
* A set of `Parameters` and sub-`Modules`. These are initialized by the constructor and can be used by the module during invocation.
* A `forward` function. <font color=maroon>This is the code that is run when the module is invoked.</font>

Let’s examine a small example:

In [3]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
       
    
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        
        return new_h, new_h

In [4]:
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)

my_cell(x, h)

(tensor([[0.7989, 0.6132, 0.6792, 0.5400],
         [0.8511, 0.9035, 0.8611, 0.6989],
         [0.8547, 0.6441, 0.6022, 0.8200]]),
 tensor([[0.7989, 0.6132, 0.6792, 0.5400],
         [0.8511, 0.9035, 0.8611, 0.6989],
         [0.8547, 0.6441, 0.6022, 0.8200]]))

So we’ve:

* Created a class that subclasses `torch.nn.Module`.
* Defined a constructor. The constructor doesn’t do much, just calls the constructor for `super`.
* Defined a `forward` function, which takes two inputs and returns two outputs. The actual contents of the `forward` function are not really important, but it’s sort of a fake <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/" style="text-decoration:none;">RNN cell</a>; that is, it’s a function that is applied in a loop.


We instantiated the module, and made `x` and `h`, which are just `3x4` matrices of random values. Then we invoked the cell with `my_cell(x, h)`. This in turn calls our `forward` function.
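As a small illustration of the last point, the sketch below compares calling a module with calling its `forward` method directly. The class name `AddCell` is a hypothetical stand-in that mirrors the first `MyCell`; it is not part of the original tutorial.

```python
import torch


class AddCell(torch.nn.Module):  # hypothetical name, mirrors the first MyCell
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h


cell = AddCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)

# cell(x, h) goes through nn.Module.__call__, which runs any registered
# hooks and then dispatches to forward.
out_call = cell(x, h)

# Calling forward directly skips the hooks but computes the same values here.
out_forward = cell.forward(x, h)

same = torch.equal(out_call[0], out_forward[0])
print(same)
```

In practice the call syntax is the one to prefer, since it is the path that hooks rely on.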

<br>

Let’s do something a little more interesting:

In [5]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    
    
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        
        return new_h, new_h

In [6]:
my_cell = MyCell()
print(my_cell)

my_cell(x, h)

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)


(tensor([[ 0.1444, -0.0906,  0.6671,  0.6956],
         [ 0.2240,  0.3039,  0.7659,  0.8195],
         [ 0.0999, -0.4597,  0.5132,  0.9275]], grad_fn=<TanhBackward0>),
 tensor([[ 0.1444, -0.0906,  0.6671,  0.6956],
         [ 0.2240,  0.3039,  0.7659,  0.8195],
         [ 0.0999, -0.4597,  0.5132,  0.9275]], grad_fn=<TanhBackward0>))

<br>


What exactly is happening here? `torch.nn.Linear` is a `Module` from the PyTorch standard library. Just like `MyCell`, it can be invoked using the call syntax. <font color=maroon>We are building a hierarchy of `Modules`.</font>

<font color=maroon>`print` on a `Module` will give a visual representation of the Module’s submodule hierarchy. In our example, we can see our `Linear` submodule and its parameters.</font>

By composing `Modules` in this way, we can succinctly and readably <font color=maroon>author models with reusable components.</font>
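The hierarchy can also be inspected programmatically. The following is a minimal sketch using `named_children()` and `named_parameters()`; the class name `TinyCell` is a hypothetical copy of the second `MyCell`.

```python
import torch


class TinyCell(torch.nn.Module):  # hypothetical name, same layout as MyCell
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


cell = TinyCell()

# Direct submodules registered on the cell.
child_names = [name for name, _ in cell.named_children()]

# Parameters are collected recursively, with dotted names reflecting the hierarchy.
param_shapes = {name: tuple(p.shape) for name, p in cell.named_parameters()}

print(child_names)
print(param_shapes)
```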

You may have noticed `grad_fn` on the outputs. <font color=maroon>This is a detail of PyTorch’s method of automatic differentiation, called <font size=3><b>autograd</b></font>. In short, this system allows us to compute derivatives through potentially complex programs. The design allows for a massive amount of flexibility in model authoring.</font>
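To make `grad_fn` a little more concrete, here is a small sketch that differentiates a `tanh` expression and compares the result with the derivative worked out by hand. The tensor values are arbitrary and chosen only for illustration.

```python
import torch

w = torch.tensor([0.5, -1.0], requires_grad=True)
x = torch.tensor([2.0, 3.0])

# Because w requires gradients, the result records how it was produced.
y = torch.tanh(w * x).sum()
has_grad_fn = y.grad_fn is not None

# Walk the recorded operations backwards to fill in w.grad.
y.backward()

# Hand-derived derivative: d/dw tanh(w * x) = x * (1 - tanh(w * x) ** 2)
expected = x * (1 - torch.tanh(w.detach() * x) ** 2)

print(has_grad_fn)
print(torch.allclose(w.grad, expected))
```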

<br>

Now let’s examine said flexibility:

In [7]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

In [8]:
my_cell = MyCell()
print(my_cell)

my_cell(x, h)

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)


(tensor([[0.2316, 0.3318, 0.6326, 0.1353],
         [0.7614, 0.7721, 0.7870, 0.2646],
         [0.4776, 0.0963, 0.4394, 0.6265]], grad_fn=<TanhBackward0>),
 tensor([[0.2316, 0.3318, 0.6326, 0.1353],
         [0.7614, 0.7721, 0.7870, 0.2646],
         [0.4776, 0.0963, 0.4394, 0.6265]], grad_fn=<TanhBackward0>))

<br>

We’ve once again redefined our `MyCell` class, but this time we’ve also defined `MyDecisionGate`. This module uses **control flow**, which consists of things like loops and `if`-statements.

<font size=3>Many frameworks take the approach of `computing symbolic derivatives` given a full program representation. However, in PyTorch, we use a `gradient tape`. <font color=maroon>We record operations as they occur, and replay them backwards in computing derivatives. In this way, the framework does not have to explicitly define derivatives for all constructs in the language.</font></font>

<img src="./images/How autograd works.png" width=600px>

<div align=center><font size=2>How autograd works</font></div>
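The sketch below illustrates why the tape approach copes with control flow: only the operations on the branch that actually ran are recorded, so the gradient follows that branch. The function name `gate` is a hypothetical free-function version of `MyDecisionGate`.

```python
import torch


def gate(x):  # hypothetical helper, same logic as MyDecisionGate.forward
    if x.sum() > 0:
        return x
    return -x


# Positive sum: the identity branch runs, so d(sum)/dx is +1 everywhere.
a = torch.tensor([1.0, 2.0], requires_grad=True)
gate(a).sum().backward()

# Negative sum: the negation branch runs, so d(sum)/dx is -1 everywhere.
b = torch.tensor([-1.0, -2.0], requires_grad=True)
gate(b).sum().backward()

print(a.grad)
print(b.grad)
```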

<br>
<br>
<br>

# Basics of TorchScript

Now let’s take our running example and see how we can apply TorchScript.

In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let’s begin by examining what we call <font size=3 color=maroon><b>tracing</b></font>.

## <font style="color:red;font-size:110%;font-weight:bold">Tracing `Modules`</font>

<font style="color:red;font-size:180%;font-weight:bold;">torch.jit.trace()</font>

In [9]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        
        return new_h, new_h

In [10]:
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)

traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell)

traced_cell(x, h)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[0.8850, 0.2197, 0.3758, 0.6387],
         [0.4030, 0.5595, 0.7519, 0.7046],
         [0.8786, 0.5741, 0.8806, 0.3584]], grad_fn=<TanhBackward0>),
 tensor([[0.8850, 0.2197, 0.3758, 0.6387],
         [0.4030, 0.5595, 0.7519, 0.7046],
         [0.8786, 0.5741, 0.8806, 0.3584]], grad_fn=<TanhBackward0>))

<br>

We’ve rewound a bit and taken the second version of our `MyCell` class. As before, we’ve instantiated it, but this time, we’ve called **`torch.jit.trace`**, passed in the `Module`, and passed in example inputs the network might see.

<font color=maroon>What exactly has this done? It has invoked the `Module`, recorded the operations that occurred when the `Module` was run, and created an instance of **`torch.jit.ScriptModule`** (of which `TracedModule` is a subclass).</font>
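One way to confirm what `torch.jit.trace` returns is to check the type of the result. This is a minimal sketch; `TinyCell` is a hypothetical copy of the `MyCell` class used above, and the concrete class name printed may differ between PyTorch versions.

```python
import torch


class TinyCell(torch.nn.Module):  # hypothetical name, same layout as MyCell
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


cell = TinyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)

# Tracing runs the module once on the example inputs and records the operations.
traced = torch.jit.trace(cell, (x, h))

is_script_module = isinstance(traced, torch.jit.ScriptModule)
still_a_module = isinstance(traced, torch.nn.Module)

print(type(traced).__name__)
print(is_script_module, still_a_module)
```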

<br>

### `.graph` property

TorchScript records its definitions in an <font size=3 color=maroon><b>Intermediate Representation</b></font> (or <font color=maroon size=3>IR</font>), commonly referred to in deep learning as a graph. We can examine the graph with the **`.graph`** property:

In [11]:
traced_cell.graph

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # C:\Users\18617\AppData\Local\Temp/ipykernel_13796/2435684823.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # C:\Users\18617\AppData\Local\Temp/ipykernel_13796/2435684823.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # C:\Users\18617\AppData\Local\Temp/ipykernel_13796/2435684823.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)
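The graph can also be walked from Python rather than read as text. The sketch below traces a free function (the hypothetical `step`, which mirrors the arithmetic in `MyCell.forward` without the linear layer) and lists the kind of each recorded node.

```python
import torch


def step(x, h):  # hypothetical helper: the tanh(x + h) part of the cell
    return torch.tanh(x + h)


traced_step = torch.jit.trace(step, (torch.rand(3, 4), torch.rand(3, 4)))

# Each node in the graph corresponds to one recorded operation or constant.
kinds = [node.kind() for node in traced_step.graph.nodes()]

print(kinds)
```

The `aten::add` and `aten::tanh` entries correspond to the two tensor operations in `step`, matching the nodes of the same kind in the graph printed above.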

<br>

### `.code` property

However, this is a very low-level representation and most of the information contained in the graph is not useful for end users. Instead, we can use the **`.code`** property to give a Python-syntax interpretation of the code:

In [12]:
traced_cell.code

'def forward(self,\n    x: Tensor,\n    h: Tensor) -> Tuple[Tensor, Tensor]:\n  linear = self.linear\n  _0 = torch.tanh(torch.add((linear).forward(x, ), h))\n  return (_0, _0)\n'

In [13]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)



<br>
<br>

## Why Tracing Modules?

So why did we do all this? There are several reasons:

* TorchScript code can be invoked in <font color=maroon>its own interpreter</font>, which is basically <font color=maroon>a restricted Python interpreter</font>. This interpreter does not acquire the `Global Interpreter Lock`, and so many requests can be processed on the same instance simultaneously.


* <font color=maroon>This format allows us to save the whole model to disk and load it into another environment, such as in a server written in a language other than Python.</font>


* <font color=maroon>TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.</font>


* <font color=maroon>TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.</font>

We can see that invoking `traced_cell` produces the same results as the Python module:

In [14]:
my_cell(x, h)

(tensor([[0.8850, 0.2197, 0.3758, 0.6387],
         [0.4030, 0.5595, 0.7519, 0.7046],
         [0.8786, 0.5741, 0.8806, 0.3584]], grad_fn=<TanhBackward0>),
 tensor([[0.8850, 0.2197, 0.3758, 0.6387],
         [0.4030, 0.5595, 0.7519, 0.7046],
         [0.8786, 0.5741, 0.8806, 0.3584]], grad_fn=<TanhBackward0>))

In [15]:
traced_cell(x, h)

(tensor([[0.8850, 0.2197, 0.3758, 0.6387],
         [0.4030, 0.5595, 0.7519, 0.7046],
         [0.8786, 0.5741, 0.8806, 0.3584]], grad_fn=<TanhBackward0>),
 tensor([[0.8850, 0.2197, 0.3758, 0.6387],
         [0.4030, 0.5595, 0.7519, 0.7046],
         [0.8786, 0.5741, 0.8806, 0.3584]], grad_fn=<TanhBackward0>))
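Rather than comparing printed tensors by eye, the two outputs can be compared numerically. This is a minimal sketch that uses fresh inputs of the same shape as the example inputs; `LinearCell` is a hypothetical copy of `MyCell`.

```python
import torch


class LinearCell(torch.nn.Module):  # hypothetical name, same layout as MyCell
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


cell = LinearCell()
traced = torch.jit.trace(cell, (torch.rand(3, 4), torch.rand(3, 4)))

# Inputs that were not seen during tracing.
new_x, new_h = torch.rand(3, 4), torch.rand(3, 4)

eager_out, _ = cell(new_x, new_h)
traced_out, _ = traced(new_x, new_h)

matches = torch.allclose(eager_out, traced_out)
print(matches)
```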

<br>
<br>
<br>

# Using Scripting to Convert Modules

There’s a reason we used version two of our module, and not the one with the control-flow-laden submodule. 

Let’s examine that now:

In [16]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x



class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        
        return new_h, new_h

In [17]:
my_cell = MyCell(MyDecisionGate())

traced_cell = torch.jit.trace(my_cell, (x, h))

  if x.sum() > 0:


In [18]:
print(traced_cell.dg.code)

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None



In [19]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)



<br>

Looking at the `.code` output, we can see that the `if-else` branch is nowhere to be found! Why? <font color=maroon>Tracing does exactly what we said it would: run the code, record the operations that happen and construct a ScriptModule that does exactly that. Unfortunately, <b>things like control flow are erased</b>.</font>
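The practical consequence can be seen in a small sketch: a function traced on an input that takes one branch silently gives the wrong answer on an input that should take the other. The function `scaled_gate` is a hypothetical variant of `MyDecisionGate` (it scales by 2 so that both branches record an operation), and the warning filter is only there to silence the `TracerWarning` that tracing emits for the tensor-to-bool conversion.

```python
import warnings

import torch


def scaled_gate(x):  # hypothetical variant of MyDecisionGate.forward
    if x.sum() > 0:
        return x * 2
    return x * -2


pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning for this demo
    traced_gate = torch.jit.trace(scaled_gate, pos)  # records only `x * 2`

eager_out = scaled_gate(neg)   # Python takes the second branch: x * -2
traced_out = traced_gate(neg)  # the trace replays x * 2 regardless of the input

print(eager_out)
print(traced_out)
```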

<br>
<br>

## <font style="color:red;font-size:110%;font-weight:bold">script `compiler`</font>

<font style="color:red;font-size:180%;font-weight:bold;">torch.jit.script()</font>

How can we faithfully represent this module in TorchScript? We provide a <font size=3 color=maroon><b>script compiler</b></font>, which does direct analysis of your Python source code to transform it into TorchScript. 

Let’s convert `MyDecisionGate` using the script compiler:

In [20]:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

In [21]:
print(scripted_gate.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0



In [22]:
print(scripted_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)



Hooray! We’ve now faithfully captured the behavior of our program in TorchScript. 

Let’s now try running the program:

In [23]:
x, h = torch.rand(3, 4), torch.rand(3, 4)   # New inputs
scripted_cell(x, h)

(tensor([[-0.0620,  0.4553,  0.8167,  0.2478],
         [ 0.5178,  0.6599,  0.3854,  0.4780],
         [ 0.4654,  0.6470,  0.7033,  0.0836]], grad_fn=<TanhBackward0>),
 tensor([[-0.0620,  0.4553,  0.8167,  0.2478],
         [ 0.5178,  0.6599,  0.3854,  0.4780],
         [ 0.4654,  0.6470,  0.7033,  0.0836]], grad_fn=<TanhBackward0>))
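To check that a scripted module really keeps both branches, it can be run on inputs that exercise each one. This is a minimal sketch; `Gate` is a hypothetical copy of `MyDecisionGate`, and the code assumes it is run from a file or notebook cell, since `torch.jit.script` needs access to the Python source.

```python
import torch


class Gate(torch.nn.Module):  # hypothetical name, same logic as MyDecisionGate
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


scripted = torch.jit.script(Gate())

pos = torch.ones(3)
neg = -torch.ones(3)

out_pos = scripted(pos)  # sum > 0, so the input comes back unchanged
out_neg = scripted(neg)  # sum <= 0, so the input is negated

print(out_pos)
print(out_neg)
```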

<br>
<br>

## Mixing Scripting and Tracing

Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values that we would like to not appear in TorchScript). In this case, scripting can be composed with tracing: **`torch.jit.script`** will inline the code for a traced module, and tracing will inline the code for a scripted module.

An example of the first case:

In [24]:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
        
        
    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
            
        return y, h

In [25]:
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)



<br>

And an example of the second case:

In [26]:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        
        return torch.relu(y)

In [27]:
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



This way, scripting and tracing can each be used where the situation calls for it, and they can be combined.
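As a sketch of the case mentioned at the start of this section, where architectural decisions depend on constant Python values, the example below traces a module whose `forward` branches on a plain Python `bool`. The class `ConfigurableBlock` and its `use_relu` flag are hypothetical names introduced for illustration. Because the flag is an ordinary Python value, the decision is resolved during tracing and only the chosen path appears in the TorchScript code.

```python
import torch


class ConfigurableBlock(torch.nn.Module):  # hypothetical example module
    def __init__(self, use_relu: bool):
        super().__init__()
        self.use_relu = use_relu  # plain Python constant, not a tensor
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        y = self.linear(x)
        if self.use_relu:  # decided by Python while tracing, not recorded
            y = torch.relu(y)
        return y


example = torch.rand(2, 4)
traced_relu = torch.jit.trace(ConfigurableBlock(True), example)
traced_plain = torch.jit.trace(ConfigurableBlock(False), example)

print(traced_relu.code)
print(traced_plain.code)
```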

<br>
<br>
<br>

# Saving and Loading Models

<font color=maroon size=3>We provide APIs to save and load TorchScript modules to/from disk in an archive format. This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. </font>
    
    
Let’s save and load our wrapped RNN module:

In [28]:
traced.save('./model_weights/wrapped_rnn.pt')

loaded = torch.jit.load('./model_weights/wrapped_rnn.pt')

print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



As you can see, **serialization** preserves the module hierarchy and the code we’ve been examining throughout. The model can also be loaded, for example, <a href="https://pytorch.org/tutorials/advanced/cpp_export.html" style="text-decoration:none;">into C++</a> for Python-free execution.
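The same round trip can be done without touching the file system, which is convenient for quick checks. The sketch below saves a traced module into an in-memory buffer with `torch.jit.save` and reads it back with `torch.jit.load`; `TinyCell` is a hypothetical copy of `MyCell`.

```python
import io

import torch


class TinyCell(torch.nn.Module):  # hypothetical name, same layout as MyCell
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


x, h = torch.rand(3, 4), torch.rand(3, 4)
traced = torch.jit.trace(TinyCell(), (x, h))

# Serialize into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)

# Rewind and load; the original Python class is not needed for this step.
buffer.seek(0)
loaded = torch.jit.load(buffer)

same = torch.allclose(loaded(x, h)[0], traced(x, h)[0])
print(same)
```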

<br>
<br>
<br>

# Further Reading

We’ve completed our tutorial! For a more involved demonstration, check out the NeurIPS demo for converting machine translation models using TorchScript:

<a href="https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ" style="text-decoration:none;">https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ</a>

<br>
<br>
<br>