# TorchScript

TorchScript is an intermediate representation of a PyTorch model (subclass of nn.Module) that can then be run in a high-performance environment such as C++.

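TorchScript was introduced in PyTorch 1.0. It is a new programming model, a subset of the Python language, that the TorchScript compiler can parse, compile, and optimize. A compiled TorchScript program can be serialized to a file and later loaded and executed by the C++ PyTorch backend.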

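TorchScript supports a large number of the operations in the torch package. We can think of torch as the standard library of the TorchScript language: any computation composed of Tensor operations from this standard library can be compiled by the TorchScript compiler. When a special operator is needed, it can be written as a C++/CUDA extension built on ATen. Such operators can also be compiled by TorchScript, and the resulting serialized file can be loaded from either Python or C++.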

## Define the PyTorch NN Module

In [26]:
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self, dg=None):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        if self.dg is None:
            new_h = torch.tanh(self.linear(x) + h)
        else:
            new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

In [27]:
cell = MyCell()
x = torch.randn(3, 4)
h = torch.randn(3, 4)

print(cell)
print(cell(x, h))

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[-0.7866, -0.7074, -0.6436,  0.9907],
        [-0.8610,  0.9400, -0.0895, -0.3651],
        [-0.8377,  0.2309,  0.6709, -0.5585]], grad_fn=<TanhBackward0>), tensor([[-0.7866, -0.7074, -0.6436,  0.9907],
        [-0.8610,  0.9400, -0.0895, -0.3651],
        [-0.8377,  0.2309,  0.6709, -0.5585]], grad_fn=<TanhBackward0>))


# Tracing Module

In [28]:
traced_cell = torch.jit.trace(cell, (x, h))
print(traced_cell)
print(traced_cell(x, h))

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)
(tensor([[-0.7866, -0.7074, -0.6436,  0.9907],
        [-0.8610,  0.9400, -0.0895, -0.3651],
        [-0.8377,  0.2309,  0.6709, -0.5585]], grad_fn=<TanhBackward0>), tensor([[-0.7866, -0.7074, -0.6436,  0.9907],
        [-0.8610,  0.9400, -0.0895, -0.3651],
        [-0.8377,  0.2309,  0.6709, -0.5585]], grad_fn=<TanhBackward0>))


What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is a subclass).
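The behavior described above can be checked directly. The following is a minimal sketch, assuming a simplified stand-in module (`TinyCell` is a hypothetical name, not part of the original notebook): it traces the module, confirms that the result is a `torch.jit.ScriptModule`, and compares the traced output with the eager output.

```python
import torch


class TinyCell(torch.nn.Module):  # hypothetical stand-in for MyCell without a gate
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


cell = TinyCell()
x, h = torch.randn(3, 4), torch.randn(3, 4)

# Tracing runs the module once on the example inputs and records the operations.
traced = torch.jit.trace(cell, (x, h))

# The traced object is a ScriptModule and reproduces the eager result.
is_script_module = isinstance(traced, torch.jit.ScriptModule)
same_output = torch.allclose(traced(x, h)[0], cell(x, h)[0])
print(is_script_module, same_output)
```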

TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:

In [30]:
print(traced_cell.graph)

graph(%self.1 : __torch__.___torch_mangle_30.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.___torch_mangle_29.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /tmp/ipykernel_3199122/808754513.py:20:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /tmp/ipykernel_3199122/808754513.py:20:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /tmp/ipykernel_3199122/808754513.py:20:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



This is a very low-level representation, and most of the information in the graph is not useful to end users. Instead, we can use the .code property to get a Python-syntax rendering of the same program:

In [32]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)
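Both views can also be inspected programmatically. The sketch below assumes a small hypothetical module (`AddTanh`) and lists the operator kind of each node in the traced graph, then reads the Python-like source from `.code`. The exact set of nodes may vary between PyTorch versions, so no output is shown.

```python
import torch


class AddTanh(torch.nn.Module):  # hypothetical example module
    def forward(self, x, h):
        return torch.tanh(x + h)


traced = torch.jit.trace(AddTanh(), (torch.randn(3, 4), torch.randn(3, 4)))

# Each node in the IR graph has a kind, such as an ATen operator name.
node_kinds = [node.kind() for node in traced.graph.nodes()]
print(node_kinds)

# The .code property renders the same program in Python syntax.
code_text = traced.code
print(code_text)
```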



# Why Use TorchScript

* TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously.
* This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
* TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.
* TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

# Using Scripting to Convert Modules

## Problems with Tracing Modules

In [34]:
cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(cell, (x, h))
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)





Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: run the code, record the operations that happen and construct a ScriptModule that does exactly that. Unfortunately, things like control flow are erased.
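The practical consequence is that a traced module silently replays whichever branch the example input happened to take. A minimal sketch, tracing the gate on its own with a negative example input (PyTorch emits a TracerWarning during this trace):

```python
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


gate = MyDecisionGate()
positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

# The example input has a negative sum, so only the `-x` branch is recorded.
traced_gate = torch.jit.trace(gate, negative)

eager_out = gate(positive)          # eager mode takes the first branch: returns x
traced_out = traced_gate(positive)  # the trace replays the recorded branch: returns -x

print(torch.equal(eager_out, traced_out))
```

The two results disagree for any input that would take the branch not seen during tracing.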

How can we faithfully represent this module in TorchScript? PyTorch provides a script compiler, which analyzes the Python source code directly and transforms it into TorchScript. Let's script the whole cell, which also converts its MyDecisionGate submodule, and print the code of both:

In [37]:
scripted_cell = torch.jit.script(cell)
print(scripted_cell.dg.code)
print(scripted_cell.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)
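To confirm that the scripted version keeps both branches, the gate can be scripted on its own and compared with eager execution for inputs of either sign. This is a small sketch rather than an exhaustive test. It assumes the code runs from a file or notebook cell, since `torch.jit.script` needs access to the Python source.

```python
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


gate = MyDecisionGate()
scripted_gate = torch.jit.script(gate)

positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

# Scripting compiles the control flow, so both branches behave like eager mode.
matches_positive = torch.equal(scripted_gate(positive), gate(positive))
matches_negative = torch.equal(scripted_gate(negative), gate(negative))
has_branch = "if" in scripted_gate.code
print(matches_positive, matches_negative, has_branch)
```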



# Mixing Scripting and Tracing

Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values that we would like to not appear in TorchScript). In this case, scripting can be composed with tracing: `torch.jit.script` will inline the code for a traced module, and tracing will inline the code for a scripted module.

In [38]:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_cell.dg), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h


rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)
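Because the loop is scripted rather than traced, its iteration count follows the input at run time. The sketch below uses simplified, hypothetical modules (`Cell` and `Loop`, without the decision gate) to show a traced cell inside a scripted loop handling sequences of different lengths, including an empty one.

```python
import torch


class Cell(torch.nn.Module):  # hypothetical, simplified cell without a gate
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


class Loop(torch.nn.Module):  # hypothetical counterpart of MyRNNLoop
    def __init__(self):
        super().__init__()
        # The cell has no control flow, so tracing it is safe.
        self.cell = torch.jit.trace(Cell(), (torch.randn(3, 4), torch.randn(3, 4)))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h


loop = torch.jit.script(Loop())

y_long, _ = loop(torch.rand(10, 3, 4))   # ten time steps
y_short, _ = loop(torch.rand(2, 3, 4))   # two time steps
y_empty, _ = loop(torch.rand(0, 3, 4))   # zero steps: the loop body never runs
print(y_long.shape, y_short.shape, torch.equal(y_empty, torch.zeros(3, 4)))
```

Had the loop been traced instead, it would have been unrolled for the length of the example input.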



And an example of the second case, where tracing inlines the code of a scripted module:

In [39]:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)


traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



# Saving and Loading models

In [40]:
traced.save("/tmp/wrapped_rnn.pt")
loaded = torch.jit.load("/tmp/wrapped_rnn.pt")
print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)
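A round trip can be verified by comparing outputs before and after loading. The sketch below assumes a small hypothetical module (`SmallNet`) and uses an in-memory buffer instead of a path under /tmp, since `torch.jit.save` and `torch.jit.load` also accept file-like objects.

```python
import io

import torch


class SmallNet(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


example = torch.randn(3, 4)
traced = torch.jit.trace(SmallNet(), example)

# Serialize to an in-memory buffer, rewind it, and load the module back.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The loaded module carries its own code and parameters.
outputs_match = torch.allclose(loaded(example), traced(example))
print(outputs_match)
```

The same archive written to a file can be loaded from C++ with `torch::jit::load`.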



# Tracing vs Scripting 

For further discussion of tracing versus scripting, see https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/