INTRODUCTION TO TORCHSCRIPT
---

This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (subclass of nn.Module) that can then be run in a high-performance environment such as C++.

In [1]:
import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)

1.13.0+cu116


# Basics of PyTorch Model Authoring

A Module is the basic unit of composition in PyTorch. It contains:

1. A constructor, which prepares the module for invocation
2. A set of Parameters and sub-Modules. These are initialized by the constructor and can be used by the module during invocation.
3. A forward function. This is the code that is run when the module is invoked.

In [3]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        # torch.nn.Linear is a Module from the PyTorch standard library. 
        # Just like MyCell, it can be invoked using the call syntax. 
        # We are building a hierarchy of Modules.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8623,  0.7705,  0.4927,  0.0870],
        [ 0.8000,  0.7732, -0.1148,  0.5965],
        [ 0.6882,  0.5490, -0.0672,  0.6857]], grad_fn=<TanhBackward0>), tensor([[ 0.8623,  0.7705,  0.4927,  0.0870],
        [ 0.8000,  0.7732, -0.1148,  0.5965],
        [ 0.6882,  0.5490, -0.0672,  0.6857]], grad_fn=<TanhBackward0>))
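
The three parts of a Module listed above can be inspected directly. The following is a minimal sketch, not part of the original notebook: `TinyCell` is a hypothetical name used only for illustration, and it assumes the standard `named_children()` and `named_parameters()` methods of `torch.nn.Module`.

```python
import torch


class TinyCell(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        # 1. The constructor prepares the module for invocation.
        super().__init__()
        # 2. A sub-Module, which carries its own Parameters (weight and bias).
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        # 3. The code that runs when the module is invoked.
        return torch.tanh(self.linear(x) + h)


cell = TinyCell()

# Names of the direct sub-Modules registered by the constructor.
child_names = [name for name, _ in cell.named_children()]

# Names and shapes of all Parameters, including those of sub-Modules.
param_shapes = {name: tuple(p.shape) for name, p in cell.named_parameters()}

# Invoking the module with call syntax runs forward().
out = cell(torch.rand(3, 4), torch.rand(3, 4))

print(child_names)
print(param_shapes)
print(tuple(out.shape))
```

Parameter names are qualified by the attribute name of the sub-Module that owns them, which is how the hierarchy shows up when a Module is printed.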


# Basics of TorchScript

In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch.

## Tracing Modules

In [8]:
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# torch.jit.trace invokes the Module, records the operations that occur when the Module is run,
# and creates an instance of torch.jit.ScriptModule (of which TracedModule is an instance)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[-0.1919, -0.3593, -0.0930,  0.6423],
         [ 0.2444, -0.4079,  0.4662,  0.2460],
         [ 0.0764, -0.6685,  0.1550,  0.6251]], grad_fn=<TanhBackward0>),
 tensor([[-0.1919, -0.3593, -0.0930,  0.6423],
         [ 0.2444, -0.4079,  0.4662,  0.2460],
         [ 0.0764, -0.6685,  0.1550,  0.6251]], grad_fn=<TanhBackward0>))

TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:

In [5]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /tmp/ipykernel_2607982/260609686.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /tmp/ipykernel_2607982/260609686.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /tmp/ipykernel_2607982/260609686.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



However, this is a very low-level representation and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to give a Python-syntax interpretation of the code:

In [6]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)



We can see that invoking traced_cell produces the same results as the Python module:

In [9]:
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[-0.1919, -0.3593, -0.0930,  0.6423],
        [ 0.2444, -0.4079,  0.4662,  0.2460],
        [ 0.0764, -0.6685,  0.1550,  0.6251]], grad_fn=<TanhBackward0>), tensor([[-0.1919, -0.3593, -0.0930,  0.6423],
        [ 0.2444, -0.4079,  0.4662,  0.2460],
        [ 0.0764, -0.6685,  0.1550,  0.6251]], grad_fn=<TanhBackward0>))
(tensor([[-0.1919, -0.3593, -0.0930,  0.6423],
        [ 0.2444, -0.4079,  0.4662,  0.2460],
        [ 0.0764, -0.6685,  0.1550,  0.6251]],
       grad_fn=<DifferentiableGraphBackward>), tensor([[-0.1919, -0.3593, -0.0930,  0.6423],
        [ 0.2444, -0.4079,  0.4662,  0.2460],
        [ 0.0764, -0.6685,  0.1550,  0.6251]],
       grad_fn=<DifferentiableGraphBackward>))
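
Comparing printed tensors by eye does not scale, so the check can also be done programmatically. The following is a minimal sketch, not part of the original notebook: `TinyCell` is a hypothetical stand-in for the control-flow-free `MyCell`, and `torch.allclose` is used on the assumption that small floating-point differences between the two execution paths are acceptable.

```python
import torch


class TinyCell(torch.nn.Module):  # hypothetical stand-in for the control-flow-free MyCell
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        return torch.tanh(self.linear(x) + h)


cell = TinyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced = torch.jit.trace(cell, (x, h))

# Compare on the inputs that were used for tracing.
same_on_trace_inputs = torch.allclose(cell(x, h), traced(x, h))

# Compare on fresh inputs of the same shape; the module has no
# data-dependent control flow, so the recorded operations still apply.
x2, h2 = torch.rand(3, 4), torch.rand(3, 4)
same_on_new_inputs = torch.allclose(cell(x2, h2), traced(x2, h2))

print(same_on_trace_inputs, same_on_new_inputs)
```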


# Using Scripting to Convert Modules

Tracing runs into a problem with the version of our module that has the control-flow-laden submodule. Let’s examine that now:

In [10]:
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)

def forward(self,
    argument_1: Tensor) -> Tensor:
  return torch.neg(argument_1)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)




Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: it runs the code, records the operations that happen, and constructs a ScriptModule that does exactly that. Unfortunately, things like control flow are erased: only the branch taken for the example inputs is recorded.
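
The practical consequence is that a traced module can silently give the wrong answer on inputs that would take the other branch. The following is a minimal sketch, not part of the original notebook: `Gate` is a hypothetical copy of the decision-gate logic, and the warning filter only silences the warning that tracing is expected to emit when it meets the data-dependent `if`.

```python
import warnings

import torch


class Gate(torch.nn.Module):  # hypothetical copy of the decision-gate logic
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


gate = Gate()
negative = -torch.ones(2, 2)  # takes the `else` branch
positive = torch.ones(2, 2)   # takes the `if` branch

with warnings.catch_warnings():
    # Tracing warns about converting a tensor to a Python bool; silence it here.
    warnings.simplefilter("ignore")
    # Only the operations run for `negative` (the negation) are recorded.
    traced_gate = torch.jit.trace(gate, (negative,))

eager_out = gate(positive)          # the `if` branch: returns x unchanged
traced_out = traced_gate(positive)  # replays the recorded negation instead

print(torch.equal(eager_out, positive))
print(torch.equal(traced_out, -positive))
```

On the input used for tracing the two agree; they diverge only on inputs that would have taken the branch that was never recorded.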

How can we faithfully represent this module in TorchScript? We provide a script compiler, which does direct analysis of your Python source code to transform it into TorchScript. Let’s convert MyDecisionGate using the script compiler:

In [11]:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)



In [20]:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
scripted_cell(x, h)

(tensor([[0.7392, 0.2868, 0.7941, 0.5196],
         [0.5743, 0.4229, 0.8638, 0.8357],
         [0.5842, 0.4139, 0.8051, 0.5587]],
        grad_fn=<DifferentiableGraphBackward>),
 tensor([[0.7392, 0.2868, 0.7941, 0.5196],
         [0.5743, 0.4229, 0.8638, 0.8357],
         [0.5842, 0.4139, 0.8051, 0.5587]],
        grad_fn=<DifferentiableGraphBackward>))
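
To confirm that scripting keeps both branches, the scripted gate can be exercised on inputs that take each path. The following is a minimal sketch, not part of the original notebook: `Gate` is a hypothetical copy of the decision-gate logic, and it assumes the code is run from a file or notebook so that `torch.jit.script` can read the Python source.

```python
import torch


class Gate(torch.nn.Module):  # hypothetical copy of the decision-gate logic
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


# The script compiler analyzes the source of forward(), so the if-else survives.
scripted_gate = torch.jit.script(Gate())

positive = torch.ones(2, 2)   # takes the `if` branch
negative = -torch.ones(2, 2)  # takes the `else` branch

out_positive = scripted_gate(positive)
out_negative = scripted_gate(negative)

print(torch.equal(out_positive, positive))   # input returned unchanged
print(torch.equal(out_negative, -negative))  # input negated
```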

# Mixing Scripting and Tracing

Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values that we would like to not appear in TorchScript). In this case, scripting can be composed with tracing: torch.jit.script will inline the code for a traced module, and tracing will inline the code for a scripted module.

In [21]:
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)



In [22]:
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4),))  # trailing comma makes the example inputs a real tuple
print(traced.code)

def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



# Saving and Loading Models

We provide APIs to save and load TorchScript modules to/from disk in an archive format. This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let’s save and load our wrapped RNN module:

In [23]:
traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)



As you can see, serialization preserves the module hierarchy and the code we’ve been examining throughout. The model can also be loaded, for example, into C++ for Python-free execution.
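
A quick way to gain confidence in a saved archive is to reload it and compare outputs. The following is a minimal sketch, not part of the original notebook: `TinyCell` and the file name `tiny_cell.pt` are hypothetical, and a temporary directory is used so that nothing is left on disk.

```python
import os
import tempfile

import torch


class TinyCell(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        return torch.tanh(self.linear(x) + h)


x, h = torch.rand(3, 4), torch.rand(3, 4)
traced = torch.jit.trace(TinyCell(), (x, h))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_cell.pt")  # hypothetical file name
    traced.save(path)              # write code and parameters to the archive
    loaded = torch.jit.load(path)  # rebuild the module from the archive alone

# The reloaded module should reproduce the outputs of the one that was saved.
roundtrip_ok = torch.allclose(traced(x, h), loaded(x, h))
print(roundtrip_ok)
```

Loading does not need the `TinyCell` class definition, which is what makes the archive usable from a separate process.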