# Introduction To TorchScript

In [1]:
import torch
print(torch.__version__)

1.9.0+cu111


## 1. Basics of PyTorch Model Authoring

A ```Module``` is the basic unit of composition in PyTorch. It contains:
- A constructor, which prepares the module for invocation.
- A set of ```Parameters``` and sub-```Modules```. These are initialized by the constructor and can be used by the module during invocation.
- A ```forward``` function. This is the code that will run when the module is invoked.
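Here is a minimal sketch of the three ingredients above. The class name ```TinyModule``` and its ```scale``` parameter are hypothetical and exist only for illustration. A ```torch.nn.Parameter``` or ```torch.nn.Module``` assigned as an attribute in the constructor is registered automatically, and ```named_parameters()``` and ```named_children()``` show this registration:

```python
import torch

class TinyModule(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # A Parameter assigned as an attribute is registered automatically
        self.scale = torch.nn.Parameter(torch.ones(4))
        # So is a sub-Module (and, recursively, its own Parameters)
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # The code that runs when the module is invoked
        return self.linear(x) * self.scale

m = TinyModule()
param_names = [name for name, _ in m.named_parameters()]
child_names = [name for name, _ in m.named_children()]
print(param_names)  # ['scale', 'linear.weight', 'linear.bias']
print(child_names)  # ['linear']
out = m(torch.rand(3, 4))
print(out.shape)    # torch.Size([3, 4])
```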

In [2]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h
    
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

(tensor([[0.5077, 0.7341, 0.9227, 0.7825],
        [0.6454, 0.5516, 0.7677, 0.7136],
        [0.7743, 0.8029, 0.9093, 0.7055]]), tensor([[0.5077, 0.7341, 0.9227, 0.7825],
        [0.6454, 0.5516, 0.7677, 0.7136],
        [0.7743, 0.8029, 0.9093, 0.7055]]))


So we:
1. Created a class that subclasses ```torch.nn.Module```.
2. Defined a constructor. The constructor doesn't do much; it just calls the constructor of ```super```.
3. Defined a ```forward``` function, which takes two inputs and returns two outputs. The actual contents of the ```forward``` function are not important.

It is a sort of fake ```RNN-Cell```: a function that is applied in a loop.

We instantiated the module, and made ```x``` and ```h```, which are just 3x4 matrices of random values. Then we invoked the cell with ```my_cell(x, h)```. This in turn calls our ```forward``` function.
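The cell is meant to be applied in a loop. A rough sketch of that usage might look like the following. The sequence ```xs``` (5 time steps) and the zero initial state are assumptions for illustration. The cell is the parameter-free version defined above.

```python
import torch

class MyCell(torch.nn.Module):
    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

cell = MyCell()
xs = torch.rand(5, 3, 4)  # hypothetical sequence: 5 steps, each a 3x4 input
h = torch.zeros(3, 4)     # assumed initial hidden state
for t in range(xs.size(0)):
    # Calling the module dispatches to forward(); the new state feeds the next step
    y, h = cell(xs[t], h)
print(h.shape)  # torch.Size([3, 4])
```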

In [4]:
# A little something more....!!
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.0566,  0.7795,  0.8494,  0.8585],
        [-0.0117,  0.5529,  0.8503,  0.8103],
        [-0.3661,  0.1831,  0.9443,  0.8218]], grad_fn=<TanhBackward>), tensor([[ 0.0566,  0.7795,  0.8494,  0.8585],
        [-0.0117,  0.5529,  0.8503,  0.8103],
        [-0.3661,  0.1831,  0.9443,  0.8218]], grad_fn=<TanhBackward>))


We’ve redefined the module ```MyCell```, but this time we’ve added a ```self.linear``` attribute, and we invoke ```self.linear``` in the forward function.

```torch.nn.Linear``` is a Module from the PyTorch standard library. Just like ```MyCell```, it can be invoked using the call syntax. We are building a hierarchy of ```Module```s.

```print``` on a ```Module``` will give a visual representation of the ```Module```’s submodule hierarchy. In our example, we can see our ```Linear``` submodule and its parameters.

You may have noticed ```grad_fn``` on the outputs. This is a detail of PyTorch’s method of <b>automatic differentiation</b>, called [autograd](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html). In short, this system allows us to compute derivatives through potentially complex programs. The design allows for a massive amount of flexibility in model authoring.
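Here is a small sketch of autograd at work. The forward computation below mirrors the cell above, using a standalone ```torch.nn.Linear```. Calling ```backward()``` on a scalar walks the recorded operations backwards and fills in ```.grad``` on the layer's parameters. The inputs are random, so only shapes are meaningful here.

```python
import torch

linear = torch.nn.Linear(4, 4)
x = torch.rand(3, 4)
h = torch.rand(3, 4)

# Forward pass: autograd records each operation as it runs
new_h = torch.tanh(linear(x) + h)
print(new_h.grad_fn)  # a TanhBackward node; the exact name varies across versions

# Backward pass: replay the recording to compute d(sum)/d(parameters)
new_h.sum().backward()
print(linear.weight.grad.shape)  # torch.Size([4, 4])
print(x.grad)                    # None, since x was created without requires_grad
```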

### Examining Flexibility

In [5]:
class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x
    
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = DecisionGate()
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

MyCell(
  (dg): DecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8774,  0.6067,  0.8555, -0.2656],
        [ 0.8366,  0.4300,  0.8122, -0.0296],
        [ 0.8353,  0.1606,  0.9389, -0.6636]], grad_fn=<TanhBackward>), tensor([[ 0.8774,  0.6067,  0.8555, -0.2656],
        [ 0.8366,  0.4300,  0.8122, -0.0296],
        [ 0.8353,  0.1606,  0.9389, -0.6636]], grad_fn=<TanhBackward>))


We’ve once again redefined the ```MyCell``` class, and this time we’ve also defined ```DecisionGate```. This module utilizes <b>control flow</b>. Control flow consists of things like ```loops``` and ```if``` statements.

Many frameworks take the approach of computing symbolic derivatives given a full program representation.

But <b>```PyTorch```</b> and <b>```TensorFlow```</b> use a ```Gradient Tape```: operations are recorded as they occur and then replayed backwards when computing derivatives.

In this way, the framework does not have to explicitly define derivatives for all constructs in the language.
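One consequence is that derivatives follow whichever branch actually ran. In this minimal sketch, the plain function ```gate``` is a hypothetical stand-in for ```DecisionGate.forward```:

```python
import torch

def gate(x):
    # Data-dependent Python control flow, as in DecisionGate
    if x.sum() > 0:
        return x
    return -x

pos = torch.ones(2, requires_grad=True)
gate(pos).sum().backward()
print(pos.grad)  # tensor([1., 1.]): the identity branch ran

neg = torch.full((2,), -1.0, requires_grad=True)
gate(neg).sum().backward()
print(neg.grad)  # tensor([-1., -1.]): the negation branch ran
```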

How ```autograd``` works:

<img src="https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/dynamic_graph.gif"></img>

## 2. Basics of TorchScript

TorchScript provides tools to capture the definition of the model, even in light of the flexible and dynamic nature of PyTorch.

### Tracing ```Modules``` 


In [6]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[0.7824, 0.4234, 0.7828, 0.6842],
         [0.9050, 0.3459, 0.7464, 0.2836],
         [0.7555, 0.6908, 0.8771, 0.8293]], grad_fn=<TanhBackward>),
 tensor([[0.7824, 0.4234, 0.7828, 0.6842],
         [0.9050, 0.3459, 0.7464, 0.2836],
         [0.7555, 0.6908, 0.8771, 0.8293]], grad_fn=<TanhBackward>))

<b>What exactly has ```torch.jit.trace``` done...??</b>

It has invoked the ```Module```, recorded the operations that occurred when the ```Module``` was run, and created an instance of ```torch.jit.ScriptModule``` (more precisely, a ```TracedModule```, which is a subclass of ```ScriptModule```).

TorchScript records its definitions in an ```Intermediate Representation``` (IR) referred to in Deep Learning as a <i>```graph```</i>.

We can examine the <i>graph</i> with the ```.graph``` property:

In [7]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %18 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%18, %x)
  %11 : int = prim::Constant[value=1]() # <ipython-input-6-171edff523ea>:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # <ipython-input-6-171edff523ea>:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # <ipython-input-6-171edff523ea>:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)
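
As a quick check of the claims above, the sketch below re-creates the traced cell, confirms that the result is a ```torch.jit.ScriptModule```, and walks the graph's nodes. ```Graph.nodes()``` and ```Node.kind()``` are part of TorchScript's lower-level Python bindings, so treat the exact list they print as version dependent.

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(MyCell(), (x, h))

# The traced module is a ScriptModule (via the TracedModule subclass)
print(isinstance(traced_cell, torch.jit.ScriptModule))  # True

# List the operator kinds recorded in the top-level graph
ops = [node.kind() for node in traced_cell.graph.nodes()]
print(ops)
print("aten::tanh" in ops)  # True
```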



In [10]:
# We can also print out a Python-syntax interpretation of the code for end users.
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = torch.add((self.linear).forward(x, ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)



So why did we do all this? There are several reasons:

1. TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the <b>Global Interpreter Lock</b>, and so many requests can be processed on the same instance simultaneously.

2. This format allows us to <b>save the whole model to disk and load it into another environment</b>, such as a server written in a language other than Python.

3. TorchScript gives a representation in which we can do <b>compiler optimizations</b> on the code to provide more efficient execution.

4. TorchScript allows us to interface with many <b>backend/device runtimes</b> that require a broader view of the program than individual operators.

We can see that invoking ```traced_cell``` produces the same results as the Python code:

In [11]:
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[0.7824, 0.4234, 0.7828, 0.6842],
        [0.9050, 0.3459, 0.7464, 0.2836],
        [0.7555, 0.6908, 0.8771, 0.8293]], grad_fn=<TanhBackward>), tensor([[0.7824, 0.4234, 0.7828, 0.6842],
        [0.9050, 0.3459, 0.7464, 0.2836],
        [0.7555, 0.6908, 0.8771, 0.8293]], grad_fn=<TanhBackward>))
(tensor([[0.7824, 0.4234, 0.7828, 0.6842],
        [0.9050, 0.3459, 0.7464, 0.2836],
        [0.7555, 0.6908, 0.8771, 0.8293]], grad_fn=<TanhBackward>), tensor([[0.7824, 0.4234, 0.7828, 0.6842],
        [0.9050, 0.3459, 0.7464, 0.2836],
        [0.7555, 0.6908, 0.8771, 0.8293]], grad_fn=<TanhBackward>))
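
Comparing printed tensors by eye only goes so far. The sketch below checks agreement numerically with ```torch.allclose``` on a second pair of random inputs. That second pair is an assumption for illustration. This ```forward``` has no data-dependent control flow, so the trace is expected to generalize to inputs it never saw.

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))

# Fresh inputs the tracer never saw
x2, h2 = torch.rand(3, 4), torch.rand(3, 4)
eager_out, _ = my_cell(x2, h2)
traced_out, _ = traced_cell(x2, h2)
print(torch.allclose(eager_out, traced_out))  # True
```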


## 3. Using Scripting to Convert Modules

In [12]:
class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x
    

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)
    
    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h
    
my_cell = MyCell(DecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)


def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.dg
  _1 = (self.linear).forward(x, )
  _2 = (_0).forward(_1, )
  _3 = torch.tanh(torch.add(_1, h))
  return (_3, _3)



Looking at the ```.code``` output, we can see that the ```if-else``` branch is nowhere to be found!

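```Tracing``` does exactly what it says: it runs the code, records the operations <i><b>that happen</b></i>, and constructs a ```ScriptModule``` that replays exactly those operations. Things like <b>control flow</b> are erased. In the traced ```MyCell``` code, the gate's result ```_2``` is even computed and then ignored, because ```torch.add``` uses ```_1```. PyTorch does emit a ```TracerWarning``` when a tensor is converted to a Python boolean during tracing, but the resulting trace still contains only the branch that was taken.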
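The effect is easy to reproduce in isolation. In this sketch, ```gate_then_double``` is a hypothetical function that applies the same branch as ```DecisionGate``` and then doubles the result. The multiplication only guarantees that the traced graph contains an operation. The function is traced with an all-positive input, so only the ```if``` branch is recorded:

```python
import warnings
import torch

def gate_then_double(x):
    # Same data-dependent branch as DecisionGate
    if x.sum() > 0:
        y = x
    else:
        y = -x
    return y * 2

pos = torch.ones(2, 2)
neg = -torch.ones(2, 2)

with warnings.catch_warnings():
    # Tracing emits a TracerWarning about converting a tensor to a bool; hide it here
    warnings.simplefilter("ignore")
    traced_fn = torch.jit.trace(gate_then_double, (pos,))

print(gate_then_double(neg))  # eager: else-branch runs, every entry is 2.0
print(traced_fn(neg))         # traced: recorded if-branch replays, every entry is -2.0
```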

To handle this, TorchScript provides a <b>```script compiler```</b>, which does direct analysis of the Python source code to transform it into TorchScript. Let's convert ```DecisionGate``` using the script compiler:

In [16]:
script_gate = torch.jit.script(DecisionGate())

my_cell = MyCell(script_gate)
script_cell = torch.jit.script(my_cell)

print(script_gate.code)
print(script_cell.code)

print('Script Gate Graph: ', script_gate.graph)
print('Script Cell Graph: ', script_cell.graph)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h))
  return (new_h, new_h)

Script Gate Graph:  graph(%self : __torch__.___torch_mangle_7.DecisionGate,
      %x.1 : Tensor):
  %3 : NoneType = prim::Constant()
  %5 : int = prim::Constant[value=0]() # <ipython-input-12-f2a457493ea5>:3:21
  %4 : Tensor = aten::sum(%x.1, %3) # <ipython-input-12-f2a457493ea5>:3:11
  %6 : Tensor = aten::gt(%4, %5) # <ipython-input-12-f2a457493ea5>:3:11
  %8 : bool = aten::Bool(%6) # <ipython-input-12-f2a457493ea5>:3:11
  %20 : Tensor = prim::If(%8) # <ipython-input-12-f2a457493ea5>:3:8
    block0():
      -> (%x.1)
    block1():
      %11 : Tensor = aten::neg(%x.1) # <ipython-input-12-f2a457493ea5>:6:19
      -> (%11)
  return (%20)

Script Cell Graph:  graph(

In [17]:
# Some new inputs for the Program
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell(x, h)

(tensor([[-0.0726,  0.0737,  0.0985,  0.8380],
         [-0.1994,  0.8374,  0.3498,  0.6823],
         [-0.1749,  0.3584,  0.6277,  0.8168]], grad_fn=<TanhBackward>),
 tensor([[-0.0726,  0.0737,  0.0985,  0.8380],
         [-0.1994,  0.8374,  0.3498,  0.6823],
         [-0.1749,  0.3584,  0.6277,  0.8168]], grad_fn=<TanhBackward>))
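
With the script compiler, the branch survives compilation, so the compiled code agrees with eager Python on both sides of the ```if```. The sketch below reuses the hypothetical ```gate_then_double``` function, this time passed to ```torch.jit.script```. The script compiler reads the function's source code, so the function must be defined in a file or notebook cell.

```python
import torch

def gate_then_double(x: torch.Tensor) -> torch.Tensor:
    # Same data-dependent branch as DecisionGate
    if x.sum() > 0:
        y = x
    else:
        y = -x
    return y * 2

scripted = torch.jit.script(gate_then_double)
print(scripted.graph)  # contains a prim::If node for the branch

pos, neg = torch.ones(2, 2), -torch.ones(2, 2)
print(torch.equal(scripted(pos), gate_then_double(pos)))  # True
print(torch.equal(scripted(neg), gate_then_double(neg)))  # True
```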

## 4. Mixing Scripting and Tracing

Some situations call for using ```tracing``` rather than ```scripting``` (e.g., a module has many architectural decisions that are made based on constant Python values that we would like not to appear in TorchScript).

In that case, ```scripting``` can be composed with ```tracing```: ```torch.jit.script``` will inline the code for a traced module, and ```tracing``` will inline the code for a scripted module.

### Example (First Case)

In [19]:
class RNNLoop(torch.nn.Module):
    def __init__(self):
        super(RNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(script_gate), (x, h))
    
    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h
    
rnn_loop = torch.jit.script(RNNLoop())
print(rnn_loop.code)
print('RNN Loop Graph: ', rnn_loop.graph)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = (self.cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

RNN Loop Graph:  graph(%self : __torch__.___torch_mangle_18.RNNLoop,
      %xs.1 : Tensor):
  %24 : bool = prim::Constant[value=1]() # <ipython-input-19-10d01ab8dd89>:8:8
  %6 : NoneType = prim::Constant()
  %2 : int = prim::Constant[value=3]() # <ipython-input-19-10d01ab8dd89>:7:27
  %3 : int = prim::Constant[value=4]() # <ipython-input-19-10d01ab8dd89>:7:30
  %20 : int = prim::Constant[value=0]() # <ipython-input-19-10d01ab8dd89>:8:31
  %5 : int[] = prim::ListConstruct(%2, %3)
  %h.1 : Tensor = aten::zeros(%5, %6, %6, %6, %6) # <ipython-input-19-10d01ab8dd89>:7:15
  %12 : int[] = prim::ListConstruct(%2, %3)
  %y.1 : Tensor = aten::zeros(%12, %6, %6, %6, %6) # <ipython-input-19-10d01ab8dd89>:7:34
  %21 : int

### Example (Second Case)

In [22]:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(RNNLoop())
    
    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)
    
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4),))
print(traced.code)
print('Traced Graph: ', traced.graph)

def forward(self,
    xs: Tensor) -> Tensor:
  _0, y, = (self.loop).forward(xs, )
  return torch.relu(y)

Traced Graph:  graph(%self : __torch__.___torch_mangle_37.WrapRNN,
      %xs : Float(10, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu)):
  %22 : __torch__.___torch_mangle_36.RNNLoop = prim::GetAttr[name="loop"](%self)
  %18 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%22, %xs)
  %19 : Tensor, %y : Tensor = prim::TupleUnpack(%18)
  %21 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::relu(%y) # <ipython-input-22-f8d7d0e5cdf2>:8:0
  return (%21)
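
The loop in ```RNNLoop``` was scripted rather than traced, so the traced ```WrapRNN``` should not be tied to the sequence length of its example input (10). The self-contained sketch below rebuilds this section's classes and runs the traced wrapper on sequences of other lengths. The lengths 2 and 25 are arbitrary choices.

```python
import torch

class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

class RNNLoop(torch.nn.Module):
    def __init__(self):
        super().__init__()
        x, h = torch.rand(3, 4), torch.rand(3, 4)
        # Script the gate so its branch survives, then trace the cell around it
        self.cell = torch.jit.trace(MyCell(torch.jit.script(DecisionGate())), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loop = torch.jit.script(RNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4),))

# The scripted loop keeps its dependence on xs.size(0), so other lengths work
out_short = traced(torch.rand(2, 3, 4))
out_long = traced(torch.rand(25, 3, 4))
print(out_short.shape, out_long.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```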



## 5. Saving and Loading Models

In [23]:
traced.save('wrapped_rnn.pt')

In [25]:
loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)
print(loaded.graph)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=RNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=DecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  _0, y, = (self.loop).forward(xs, )
  return torch.relu(y)

graph(%self.1 : __torch__.___torch_mangle_37.WrapRNN,
      %xs.1 : Tensor):
  %3 : __torch__.___torch_mangle_36.RNNLoop = prim::GetAttr[name="loop"](%self.1)
  %5 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%3, %xs.1) # :0:0
  %6 : Tensor, %y.1 : Tensor = prim::TupleUnpack(%5)
  %9 : Tensor = aten::relu(%y.1) # <ipython-input-22-f8d7d0e5cdf2>:8:0
  return (%9)

