# TorchScript

TorchScript is an intermediate representation of a PyTorch model (subclass of nn.Module) that can then be run in a high-performance environment such as C++.

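TorchScript was introduced in PyTorch 1.0. It is a new programming model, a subset of the Python language, that the TorchScript compiler can parse, compile and optimize. Compiled TorchScript can be serialized to a file and later loaded and executed by the C++ PyTorch backend.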

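TorchScript supports a large number of the operations in the torch package. We can think of torch as the standard library of the TorchScript language: any computation composed of Tensor operations from this standard library can be compiled by the TorchScript compiler. We may also need to implement special operators. These can be written as C++/CUDA extensions built on ATen; such operators can be compiled by TorchScript, and the resulting serialized file can be loaded from either Python or C++.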

## Define the PyTorch NN Module

In [1]:
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self, dg=None):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        if self.dg is None:
            new_h = torch.tanh(self.linear(x) + h)
        else:
            new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

In [2]:
cell = MyCell()
x = torch.randn(3, 4)
h = torch.randn(3, 4)

print(cell)
print(cell(x, h))

MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.4873,  0.9206, -0.6983, -0.6580],
        [-0.9721,  0.7046,  0.2144, -0.9662],
        [-0.0251,  0.6369,  0.9375, -0.3248]], grad_fn=<TanhBackward0>), tensor([[ 0.4873,  0.9206, -0.6983, -0.6580],
        [-0.9721,  0.7046,  0.2144, -0.9662],
        [-0.0251,  0.6369,  0.9375, -0.3248]], grad_fn=<TanhBackward0>))


# Tracing Module

In [3]:
traced_cell = torch.jit.trace(cell, (x, h))
print(traced_cell)
print(traced_cell(x, h))

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)
(tensor([[ 0.4873,  0.9206, -0.6983, -0.6580],
        [-0.9721,  0.7046,  0.2144, -0.9662],
        [-0.0251,  0.6369,  0.9375, -0.3248]], grad_fn=<TanhBackward0>), tensor([[ 0.4873,  0.9206, -0.6983, -0.6580],
        [-0.9721,  0.7046,  0.2144, -0.9662],
        [-0.0251,  0.6369,  0.9375, -0.3248]], grad_fn=<TanhBackward0>))


What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is an instance).

TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:

In [4]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /tmp/ipykernel_267643/808754513.py:20:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /tmp/ipykernel_267643/808754513.py:20:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /tmp/ipykernel_267643/808754513.py:20:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



This is a very low-level representation, and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to give a Python-syntax interpretation of the code:

In [5]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)
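To make the relationship between `.graph` and `.code` more concrete, here is a minimal sketch that walks the graph of a small traced function and lists the kind of each node. The function `step` and its inputs are hypothetical and are not part of the notebook above; the exact set of `prim::` nodes may differ between PyTorch versions.

```python
import torch


def step(x, h):
    # Hypothetical stand-in for MyCell.forward without the linear layer
    return torch.tanh(x + h)


traced_step = torch.jit.trace(step, (torch.randn(3, 4), torch.randn(3, 4)))

# Each node in the IR has a "kind" such as aten::add or prim::Constant
op_kinds = [node.kind() for node in traced_step.graph.nodes()]
print(op_kinds)

# The same program rendered as Python-like source
print(traced_step.code)
```

The `aten::` entries correspond to the tensor operations that appear as `torch.add` and `torch.tanh` in the `.code` view.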



# Why Use a Traced Module

* TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously.
* This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
* TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.
* TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

# Using Scripting to Convert Modules

## Problems with Tracing a Module

In [11]:
cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(cell, (x, h))
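# In the code traced for MyDecisionGate the if/else is gone: only the path taken
# for this input is recorded (e.g. torch.neg(argument_1) when the sum is not positive)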
print(traced_cell.dg.code)
print(traced_cell.code)

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)





Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: run the code, record the operations that happen and construct a ScriptModule that does exactly that. Unfortunately, things like control flow are erased.

Tracing only records exactly what happened: the operations that were executed for the given input at the time the model was traced.
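The effect can be checked numerically. The sketch below is an illustration with a hypothetical function `gate` (it returns `x + 1` or `x - 1` so the two branches are easy to tell apart); it traces the function with a positive input and then calls it with a negative one. The `TracerWarning` that tracing emits for the tensor-to-bool conversion is silenced here only to keep the output clean.

```python
import warnings

import torch


def gate(x):
    # Data-dependent branch: which path runs depends on the values in x
    if x.sum() > 0:
        return x + 1
    else:
        return x - 1


pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning for this demo
    traced_gate = torch.jit.trace(gate, pos)  # only the "x + 1" path is recorded

eager_out = gate(neg)          # takes the else branch: neg - 1
traced_out = traced_gate(neg)  # replays the recorded path: neg + 1

print(eager_out, traced_out)
```

The traced function silently gives a different answer from the eager one, which is why the warning should not be ignored in real code.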

## Script Compiler

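Unlike tracing, the script compiler does not record the tensor operations of an actual run. Like compilers for other languages, it analyzes the source code directly and does not need to simulate the computation.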

How can we faithfully represent this module in TorchScript? We provide a script compiler, which does direct analysis of your Python source code to transform it into TorchScript. Let’s convert MyDecisionGate using the script compiler:

In [12]:
scripted_cell = torch.jit.script(cell)
print(scripted_cell.dg.code)
print(scripted_cell.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)
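As a quick check that both branches survive scripting, the following sketch scripts a standalone function and runs it on inputs that take different paths. The function `gate` is a hypothetical free-function version of `MyDecisionGate.forward`. Note that `torch.jit.script` reads the Python source, so the code needs to live in a file or notebook cell.

```python
import torch


def gate(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return x
    else:
        return -x


scripted_gate = torch.jit.script(gate)

pos = torch.ones(3)
neg = -torch.ones(3)

out_pos = scripted_gate(pos)  # sum > 0, returned unchanged
out_neg = scripted_gate(neg)  # sum <= 0, negated

print(scripted_gate.code)
print(out_pos, out_neg)
```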



# Mixing Scripting and Tracing

Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values that we would like to not appear in TorchScript). In this case, scripting can be composed with tracing: `torch.jit.script` will inline the code for a traced module, and tracing will inline the code for a scripted module.

We can first compile the parts of the code that contain branch conditions with jit.script, and then convert the overall execution with jit.trace.

In [16]:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_cell.dg), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h


# dg is compiled with jit.script
# cell is converted with jit.trace
# rnn_loop is compiled with jit.script
rnn_loop = torch.jit.script(MyRNNLoop())
print("rnn loop: ")
print(rnn_loop.code)
print("rnn loop cell: ")
print(rnn_loop.cell.code)
print("rnn loop cell dg: ")
print(rnn_loop.cell.dg.code)

rnn loop: 
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

rnn loop cell: 
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)

rnn loop cell dg: 
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0



And an example of the second case:

In [17]:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)


traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



# Saving and Loading models

In [18]:
traced.save("/tmp/wrapped_rnn.pt")
loaded = torch.jit.load("/tmp/wrapped_rnn.pt")
print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)
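Saving does not have to go through a path on disk. As a minimal sketch, assuming a small hypothetical module `TinyNet` that is not part of the notebook above, the round trip below serializes a traced module into an in-memory buffer with `torch.jit.save` and checks that the loaded module produces the same output.

```python
import io

import torch


class TinyNet(torch.nn.Module):
    # Hypothetical module used only for this round-trip check
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


example = torch.randn(3, 4)
traced_net = torch.jit.trace(TinyNet(), example)

# Serialize into a file-like object instead of a file on disk
buffer = io.BytesIO()
torch.jit.save(traced_net, buffer)

# Rewind and load; the parameters travel with the archive
buffer.seek(0)
loaded_net = torch.jit.load(buffer)

same = torch.allclose(traced_net(example), loaded_net(example))
print(same)
```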



# Tracing vs Scripting 

https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/

## Terminology

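* Export: converting PyTorch eager-mode code into a computation graph (TS-format).
* Tracing: exporting with the `torch.jit.trace(model, input)` API. Mechanism: run the model on a given input and record all the operations that are executed.
* Scripting: exporting with the `torch.jit.script(model)` API. Mechanism: parse the Python source code and compile it into a computation graph.
* TorchScript: a term whose meaning depends on context. Most of the time it refers to the representation of the exported computation graph; sometimes it also refers to the export method.
* Scriptable: a Module on which `torch.jit.script` succeeds.
* Traceable: a Module that `torch.jit.trace` can export successfully.
* Generalize: a traced model generalizes if it also works correctly on inputs other than the one used for tracing. Scripted models generally do not have generalization problems.
* Dynamic control flow: control flow whose path depends on the data being processed, for example: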
    ```python
    if x[0] == 4:
        x += 1
    ```

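## Problems with Scripting

The scripting compiler supports only a small subset of Python. Most basic syntax is supported well, but classes, range, zip, dynamic types, Union, `**kwargs`, inheritance and so on have little or no support.

The biggest problem with scripting is that it is a black box. Even when export succeeds, it may still fail to work in some corner cases, and there is no clear list of what is and is not supported.

Chasing scriptability makes us write code cautiously inside a narrow safe zone. Without higher-level abstractions the code becomes messy and its quality drops.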

## Making a Traced Model Generalize

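Compared with scripting, the constraints of tracing are very clear.

1. The Module must be a single-device module; it cannot be wrapped in DataParallel.
2. The computation in the Module must be expressible as a TS-format graph. It cannot convert between tensors and numpy arrays, call OpenCV functions, and so on.
3. The Module's inputs can only be a Tensor, a Tuple[Tensor], a Dict[str, Tensor], or nested combinations of these. All values in a dict must have the same type. Submodules inside the Module have no restriction on their input types.
4. During tracing, the results of tensor.size(), tensor.size()[0], tensor.shape[2] and similar calls are treated as Tensors, although they are ints in eager mode. This keeps the shape computation tracked in the graph.

A Module that satisfies the four points above is usually traceable, but being traceable is not enough: we also need generalization. The following break generalization:

1. Dynamic control flow
2. Extracting a constant value from a Tensor: len(t), t.item(), conversions to numpy, etc.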

In [19]:
def f1(x):
    return torch.arange(x.shape[0])


def f2(x):
    return torch.arange(len(x))


a = torch.rand(1)
b = torch.rand(2)

In [20]:
torch.jit.trace(f1, a)(b)

tensor([0, 1])

In [21]:
torch.jit.trace(f2, a)(b)



tensor([0])
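One way to catch this kind of problem early is to compare the traced function against the eager one on an input with a different shape from the tracing example. The helper `generalizes` below is a hypothetical utility written for this note, not a PyTorch API; it is only a spot check on one extra input, not a proof of generalization.

```python
import warnings

import torch


def generalizes(fn, example, other):
    """Trace fn on `example`, then compare traced and eager results on `other`."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # hide TracerWarning for this demo
        traced = torch.jit.trace(fn, example)
    # torch.equal returns False when the shapes differ
    return torch.equal(traced(other), fn(other))


def f1(x):
    return torch.arange(x.shape[0])  # shape access stays in the graph


def f2(x):
    return torch.arange(len(x))  # len() is baked in as a constant


a = torch.rand(1)
b = torch.rand(2)

f1_ok = generalizes(f1, a, b)
f2_ok = generalizes(f2, a, b)
print(f1_ok, f2_ok)
```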

## Scripting Supports Methods Other Than forward

In [43]:
class Detector(torch.nn.Module):
    do_keypoint: bool

    def forward(self, img):
        box = self.predict_boxes(img)
        if self.do_keypoint:
            kpts = self.predict_keypoint(img, box)

    @torch.jit.export
    def predict_boxes(self, img):
        pass

    @torch.jit.export
    def predict_keypoint(self, img, box):
        pass


model = Detector()
model.do_keypoint = True
scripted_model = torch.jit.script(model)

RuntimeError: 

predict_keypoint(__torch__.___torch_mangle_80.Detector self, Tensor img, Tensor box) -> NoneType:
Expected a value of type 'Tensor (inferred)' for argument 'box' but instead found type 'NoneType'.
Inferred 'box' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/tmp/ipykernel_267643/1239083116.py", line 7
      box = self.predict_boxes(img)
      if self.do_keypoint:
          kpts = self.predict_keypoint(img, box)
                 ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
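The error above appears to come from the stub bodies rather than from `@torch.jit.export` itself: `predict_boxes` contains only `pass`, so it returns `None`, while the unannotated `box` argument is inferred to be a Tensor. A minimal sketch of a version that scripts successfully is given below. The method bodies are placeholders invented for illustration (simple reductions, not a real detector), and the constructor argument is an assumption.

```python
import torch


class Detector(torch.nn.Module):
    do_keypoint: bool

    def __init__(self, do_keypoint: bool = True):
        super().__init__()
        self.do_keypoint = do_keypoint

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        box = self.predict_boxes(img)
        if self.do_keypoint:
            return self.predict_keypoint(img, box)
        return box

    @torch.jit.export
    def predict_boxes(self, img: torch.Tensor) -> torch.Tensor:
        # Placeholder computation standing in for a box head
        return img.mean(dim=-1)

    @torch.jit.export
    def predict_keypoint(self, img: torch.Tensor, box: torch.Tensor) -> torch.Tensor:
        # Placeholder computation standing in for a keypoint head
        return img.sum(dim=-1) + box


scripted_model = torch.jit.script(Detector())

img = torch.ones(2, 3)
boxes = scripted_model.predict_boxes(img)  # exported method, callable directly
result = scripted_model(img)               # forward still works as usual
print(boxes, result)
```

With every method returning a Tensor on all paths, the exported methods are compiled alongside `forward` and remain callable on the scripted module.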


# `@script_if_tracing`

In [40]:
import torch


@torch.jit.script_if_tracing
def decision_gate(x):
    if x.sum() > 0:
        return x
    else:
        return -x


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(decision_gate(self.linear(x)) + h)
        return new_h, new_h


x = torch.randn(3, 4)
h = torch.randn(3, 4)
model = MyCell()
model(x, h)
traced_model = torch.jit.trace(model, (x, h))
traced_model(x, h)

(tensor([[ 0.6388, -0.2312,  0.8956, -0.9861],
         [ 0.8869,  0.0744,  0.9764,  0.4940],
         [ 0.9850,  0.5327,  0.8098, -0.1661]], grad_fn=<TanhBackward0>),
 tensor([[ 0.6388, -0.2312,  0.8956, -0.9861],
         [ 0.8869,  0.0744,  0.9764,  0.4940],
         [ 0.9850,  0.5327,  0.8098, -0.1661]], grad_fn=<TanhBackward0>))
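The output above only shows that tracing runs; it does not show that the branch was kept. A small check, sketched below with a hypothetical free function `step` in place of the module, traces with an input that takes the first branch and then calls the traced function with an input that should take the second.

```python
import torch


@torch.jit.script_if_tracing
def decision_gate(x):
    if x.sum() > 0:
        return x
    else:
        return -x


def step(x, h):
    # Hypothetical stand-in for MyCell.forward without the linear layer
    return torch.tanh(decision_gate(x) + h)


x = torch.ones(3, 4)
h = torch.zeros(3, 4)

# During tracing, decision_gate is compiled with the script compiler,
# so its if/else is preserved inside the traced graph.
traced_step = torch.jit.trace(step, (x, h))

flipped = traced_step(-x, h)  # takes the else branch: -(-x) == x
print(flipped)
```

If the gate had been traced instead of scripted, `traced_step(-x, h)` would have returned `tanh(-x)` rather than `tanh(x)`.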

# Dynamic Parallelism in Torchscript

https://pytorch.org/tutorials/advanced/torch-script-parallelism.html