### TorchScript
TorchScript是PyTorch模型（nn.Module的子类）的中间表示，可以在高性能环境（例如C++）中运行。

在本教程中，我们将介绍PyTorch中的模型创作基础，包括：

- 定义前向（forward）函数
- 将多个模块组合成层次结构
- 将PyTorch模块转换为TorchScript（我们的高性能部署运行时）的特定方法
- 跟踪现有模块
- 使用脚本直接编译模块
- 如何组合这两种方法
- 保存和加载TorchScript模块

https://pytorch.panchuang.net/EigthSection/torchScript/

In [26]:
import torch
import random
# Seed Python's own RNG so that any use of `random` is repeatable between runs.
random.seed(42)
import torch.nn as nn
# Seed PyTorch's default generator; the parameter initialisation and torch.rand
# calls in the following cells draw from it. manual_seed returns the Generator,
# which the notebook echoes as the cell output.
torch.manual_seed(42)

<torch._C.Generator at 0x7f402bcf3530>

### 不包含控制流

In [29]:
class MyCellV1(nn.Module):
    """Minimal RNN-style cell with no control flow.

    Holds a single 4-to-4 linear layer and combines its output with the
    previous hidden state.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        """Compute ``tanh(linear(x) + h)`` and return it as both tuple elements."""
        hidden = (self.linear(x) + h).tanh()
        return hidden, hidden

# Build the cell first: constructing nn.Linear draws its initial weights from the
# seeded generator, so the order of the next three statements fixes the values.
my_cell_v1 = MyCellV1()
# Random input and initial hidden state, each of shape (3, 4).
x = torch.rand(3, 4)
h = torch.rand(3, 4)
# Show the module structure, then the pair returned by forward.
print(my_cell_v1)
print(my_cell_v1(x, h))

MyCellV1(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.5042, -0.4369, -0.4018, -0.3723],
        [-0.0158, -0.0780,  0.0066, -0.4506],
        [ 0.6633,  0.0949,  0.5463,  0.1301]], grad_fn=<TanhBackward0>), tensor([[ 0.5042, -0.4369, -0.4018, -0.3723],
        [-0.0158, -0.0780,  0.0066, -0.4506],
        [ 0.6633,  0.0949,  0.5463,  0.1301]], grad_fn=<TanhBackward0>))


In [30]:
# Model that contains (data-dependent) control flow
class MyDecisionGate(nn.Module):
    """Sign gate whose result depends on the values passing through it."""

    def forward(self, x):
        """Return ``x`` itself when its elements sum to more than zero, otherwise ``-x``."""
        if x.sum() > 0:
            return x
        return -x


class MyCellV2(nn.Module):
    """RNN-style cell that routes the linear output through ``MyDecisionGate``."""

    def __init__(self):
        super().__init__()
        # Sub-modules are registered gate first, then linear; print() lists them in this order.
        self.dg = MyDecisionGate()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        """Compute ``tanh(dg(linear(x)) + h)`` and return it as both tuple elements."""
        new_h = (self.dg(self.linear(x)) + h).tanh()
        return new_h, new_h

# Instantiate the gated cell and run it eagerly on the x and h created earlier.
my_cell_v2 = MyCellV2()
print(my_cell_v2)
print(my_cell_v2(x, h))


MyCellV2(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.2552,  0.6342,  0.2439, -0.1345],
        [ 0.0541,  0.7186,  0.1472,  0.1920],
        [ 0.6442,  0.7811,  0.5944,  0.6411]], grad_fn=<TanhBackward0>), tensor([[ 0.2552,  0.6342,  0.2439, -0.1345],
        [ 0.0541,  0.7186,  0.1472,  0.1920],
        [ 0.6442,  0.7811,  0.5944,  0.6411]], grad_fn=<TanhBackward0>))


### TorchScript-Tracing模块

In [31]:
# Model without control flow
# torch.jit.trace is given the module together with the example inputs (x, h).
traced_cell_v1 = torch.jit.trace(my_cell_v1, (x, h))
print(traced_cell_v1)
print(traced_cell_v1.code)
# Last expression of the cell: the notebook echoes the traced module's output.
traced_cell_v1(x, h)

MyCellV1(
  original_name=MyCellV1
  (linear): Linear(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)



(tensor([[ 0.5042, -0.4369, -0.4018, -0.3723],
         [-0.0158, -0.0780,  0.0066, -0.4506],
         [ 0.6633,  0.0949,  0.5463,  0.1301]], grad_fn=<TanhBackward0>),
 tensor([[ 0.5042, -0.4369, -0.4018, -0.3723],
         [-0.0158, -0.0780,  0.0066, -0.4506],
         [ 0.6633,  0.0949,  0.5463,  0.1301]], grad_fn=<TanhBackward0>))

In [32]:
# For code that contains control flow, exporting directly with trace yields an incorrect model
traced_cell_v2 = torch.jit.trace(my_cell_v2, (x, h))
print(traced_cell_v2)
print(traced_cell_v2.code)  # the if-branch is missing from the traced code
print(traced_cell_v2(x, h))

MyCellV2(
  original_name=MyCellV2
  (dg): MyDecisionGate(original_name=MyDecisionGate)
  (linear): Linear(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)

(tensor([[ 0.2552,  0.6342,  0.2439, -0.1345],
        [ 0.0541,  0.7186,  0.1472,  0.1920],
        [ 0.6442,  0.7811,  0.5944,  0.6411]], grad_fn=<TanhBackward0>), tensor([[ 0.2552,  0.6342,  0.2439, -0.1345],
        [ 0.0541,  0.7186,  0.1472,  0.1920],
        [ 0.6442,  0.7811,  0.5944,  0.6411]], grad_fn=<TanhBackward0>))


（此处似乎是 TracerWarning 的残留输出：trace 遇到下面这行依赖数据的 `if` 时会发出警告，提示把张量转换为 Python 布尔值可能导致 trace 结果不正确。）
  if x.sum() > 0:


In [51]:
class MyDecisionGateV2(torch.nn.Module):
    """Gate with a data-dependent branch; this notebook compiles it with ``torch.jit.script``."""
    def forward(self, x):
        # Pass x through when its elements sum to more than zero, otherwise negate it.
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCellV3(nn.Module):
    """RNN-style cell whose gate sub-module is supplied by the caller.

    Args:
        dg: module applied to the output of the linear layer; the notebook
            passes an already-scripted ``MyDecisionGateV2``.
    """
    def __init__(self, dg):
        super(MyCellV3, self).__init__()
        # Store the injected gate as a sub-module of this cell.
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)
        #self.linear = torch.ones(4, 4)

    def forward(self, x, h):
        # new_h = tanh(dg(linear(x)) + h); the same tensor is returned twice.
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

# Tracing variant, left disabled:
# my_cell_v3 = MyCellV3(MyDecisionGate())
# traced_cell_v3 = torch.jit.trace(my_cell_v3, (x, h))
# print(traced_cell_v3.code)
# print(traced_cell_v3.graph)
print("+++"*20)
# Correct approach for a model that contains control flow
scripted_gate = torch.jit.script(MyDecisionGateV2())
my_cell_v3 = MyCellV3(scripted_gate)
# NOTE: despite the "traced_" prefix, this module is produced by torch.jit.script.
traced_cell_v3 = torch.jit.script(my_cell_v3)
# Print the compiled code of the gate and of the cell, the cell's graph, then run it.
print(scripted_gate.code)
print(traced_cell_v3.code)
print(traced_cell_v3.graph)
print(traced_cell_v3(x, h))

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

graph(%self : __torch__.___torch_mangle_110.MyCellV3,
      %x.1 : Tensor,
      %h.1 : Tensor):
  %9 : int = prim::Constant[value=1]()
  %dg : __torch__.___torch_mangle_109.MyDecisionGateV2 = prim::GetAttr[name="dg"](%self)
  %linear : __torch__.torch.nn.modules.linear.___torch_mangle_7.Linear = prim::GetAttr[name="linear"](%self)
  %6 : Tensor = prim::CallMethod[name="forward"](%linear, %x.1) # /tmp/ipykernel_14520/123614427.py:16:35
  %7 : Tensor = prim::CallMethod[name="forward"](%dg, %6) # /tmp/ipykernel_14520/123614427.py:16:27
  %10 : Tensor = aten::add(%7, %h.1, %9) #

In [52]:


class MyCellV4(nn.Module):
    """Cell with the gate logic written inline in ``forward`` instead of in a sub-module."""
    def __init__(self):
        super(MyCellV4, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
        #self.linear = torch.ones(4, 4)

    def forward(self, x, h):
        output = self.linear(x)
        # Data-dependent branch: keep the linear output when its sum is positive, else negate it.
        if output.sum() > 0:
            # NOTE(review): self-assignment, has no effect.
            output = output
        else:
            output = -output
        # In-place add of the hidden state; the scripted code printed for this
        # cell shows it as torch.add_.
        output += h
        new_h = torch.tanh(output)
        # The same tensor is returned twice.
        return new_h, new_h

my_cell_v4 = MyCellV4()
# NOTE: despite the "traced_" prefix, this module is produced by torch.jit.script.
traced_cell_v4 = torch.jit.script(my_cell_v4)
print(traced_cell_v4.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  output = (linear).forward(x, )
  if bool(torch.gt(torch.sum(output), 0)):
    output0 = output
  else:
    output0 = torch.neg(output)
  output1 = torch.add_(output0, h)
  new_h = torch.tanh(output1)
  return (new_h, new_h)



In [56]:
# Evaluate both compiled cells on the same (x, h). The print() calls execute while
# the tuple is being built and return None, hence the two trailing None entries
# in the echoed result.
traced_cell_v3(x,h), traced_cell_v4(x,h), print(traced_cell_v3.code), print(traced_cell_v4.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  output = (linear).forward(x, )
  if bool(torch.gt(torch.sum(output), 0)):
    output0 = output
  else:
    output0 = torch.neg(output)
  output1 = torch.add_(output0, h)
  new_h = torch.tanh(output1)
  return (new_h, new_h)



((tensor([[0.6765, 0.4108, 0.7917, 0.8712],
          [0.2232, 0.1857, 0.6127, 0.5740],
          [0.7781, 0.3506, 0.8757, 0.8357]],
         grad_fn=<DifferentiableGraphBackward>),
  tensor([[0.6765, 0.4108, 0.7917, 0.8712],
          [0.2232, 0.1857, 0.6127, 0.5740],
          [0.7781, 0.3506, 0.8757, 0.8357]],
         grad_fn=<DifferentiableGraphBackward>)),
 (tensor([[0.5366, 0.4260, 0.5049, 0.3330],
          [0.2708, 0.2315, 0.5867, 0.1803],
          [0.7839, 0.4473, 0.8336, 0.5838]], grad_fn=<TanhBackward0>),
  tensor([[0.5366, 0.4260, 0.5049, 0.3330],
          [0.2708, 0.2315, 0.5867, 0.1803],
          [0.7839, 0.4473, 0.8336, 0.5838]], grad_fn=<TanhBackward0>)),
 None,
 None)

In [43]:
class MyModule(nn.Module):
    """Module with a single (N, M) weight and an input-dependent branch in ``forward``.

    Args:
        N: number of rows of ``weight``.
        M: number of columns of ``weight``.
    """
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # Trainable parameter filled by torch.randn.
        self.weight = nn.Parameter(torch.randn(N, M))

    def forward(self, input):
        # Positive sum: matrix-vector product of weight and input;
        # otherwise: weight + input.
        # NOTE(review): for a 1-D input of length M the two branches appear to
        # return different shapes ((N,) vs (N, M)) - confirm this is intended.
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

# Compile with torch.jit.script and print the resulting code.
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
print(sm.code)

def forward(self,
    input: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(input), 0)):
    weight = self.weight
    output = torch.mv(weight, input)
  else:
    weight0 = self.weight
    output = torch.add(weight0, input)
  return output



### 3.1 混合脚本(Scripting)和跟踪(Tracing)

In [39]:
class MyRNNLoop(torch.nn.Module):
    """Loop over the first dimension of ``xs``, applying a traced ``MyCellV3`` at each step."""
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        # The cell is traced with the notebook-level example tensors x and h;
        # its gate is the notebook-level `scripted_gate`.
        self.cell = torch.jit.trace(MyCellV3(scripted_gate), (x, h))

    def forward(self, xs):
        # Both the state and the output start as hard-coded (3, 4) zero tensors.
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        # One cell step per slice xs[i]; h carries the state into the next step.
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

# Compile the loop module with torch.jit.script and print the resulting code.
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
#print(rnn_loop.cell.dg.code)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)



In [40]:
class WrapRNN(torch.nn.Module):
    """Wrapper that runs a scripted ``MyRNNLoop`` and applies ReLU to its first output."""

    def __init__(self):
        super().__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        """Run the loop over ``xs`` and return ``relu`` of its first output; the second is ignored."""
        y, _state = self.loop(xs)
        return y.relu()

# Trace the wrapper with a random (10, 3, 4) example input and print the traced code.
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



In [41]:
# Round-trip the traced module through a file on disk.
archive_path = 'wrapped_rnn.zip'
traced.save(archive_path)
loaded = torch.jit.load(archive_path)

# Show the restored module hierarchy and its top-level code.
print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCellV3
      (dg): RecursiveScriptModule(original_name=MyDecisionGateV2)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

