### Tracing

Tracing is an export technique that runs our model on example inputs and records every tensor operation executed along the way into the model's graph.

The API is called as torch.jit.trace(model, input).

A model is called "traceable" if torch.jit.trace(model, input) succeeds for representative example input.

A simple example of tracing in PyTorch follows. We first define a custom model class and instantiate it, then trace the instance by passing it some sample inputs:

In [2]:
import torch
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

net = MyModel()  # a regular PyTorch model (torch.nn.Module)
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_net = torch.jit.trace(net, (x, h))  # tracing returns a torch.jit.ScriptModule
print(traced_net)
traced_net(x, h)


MyModel(
  original_name=MyModel
  (linear): Linear(original_name=Linear)
)


(tensor([[-0.4162,  0.9441,  0.8803,  0.5506],
         [-0.0428,  0.8970,  0.8539,  0.7946],
         [-0.3904,  0.8450,  0.0802,  0.6143]], grad_fn=<TanhBackward0>),
 tensor([[-0.4162,  0.9441,  0.8803,  0.5506],
         [-0.0428,  0.8970,  0.8539,  0.7946],
         [-0.3904,  0.8450,  0.0802,  0.6143]], grad_fn=<TanhBackward0>))

During tracing, the Python code is automatically converted into TorchScript, a subset of Python, by recording only the operators actually applied to tensors. The surrounding Python code is simply executed and then discarded.

* torch.jit.trace invokes the Module, records the computations that occur when the Module is run on the inputs, and then creates an instance of torch.jit.ScriptModule, which is essentially the plain Python code converted to TorchScript.
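
To make the "recorded versus discarded" distinction concrete, here is a minimal sketch. The function scale_and_log and the list calls are hypothetical names introduced only for this illustration. The list append runs while tracing, but it is not a tensor operation, so replaying the traced function should not touch the list again.

```python
import torch

calls = []  # hypothetical Python-side log; appending to it is not a tensor op


def scale_and_log(x):
    calls.append("called")  # plain Python: runs during tracing, is not recorded
    return x * 2 + 1        # tensor ops: recorded into the graph


traced = torch.jit.trace(scale_and_log, torch.ones(3))
calls_after_trace = len(calls)  # tracing itself executed the Python body

out = traced(torch.zeros(3))    # replays only the recorded tensor ops
calls_after_run = len(calls)

print(calls_after_run == calls_after_trace)  # the log did not grow
print(out)
```

The absolute number of entries in calls after tracing is left unchecked on purpose, since the tracer may run the function more than once while verifying the trace.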

TorchScript also records the model definition in an Intermediate Representation (IR), a graph that we can access through the .graph property of the traced model. The .code property gives a Python-syntax view of the same graph.

In [4]:
print(traced_net.graph)


graph(%self.1 : __torch__.MyModel,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # <ipython-input-2-ff2b47965842>:8:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # <ipython-input-2-ff2b47965842>:8:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # <ipython-input-2-ff2b47965842>:8:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



In [5]:
print(traced_net.code)


def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)
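
The graph can also be inspected programmatically rather than printed. The following is a small sketch under the assumption that a plain function is enough to illustrate the idea; step is a hypothetical stand-in for the model's forward pass, and the exact list of node kinds may vary between PyTorch versions.

```python
import torch


def step(x, h):
    # Same arithmetic as the forward pass above, minus the linear layer.
    return torch.tanh(x + h)


traced_step = torch.jit.trace(step, (torch.rand(3, 4), torch.rand(3, 4)))

# Each node in the IR has a "kind", for example aten::add or aten::tanh.
op_kinds = [node.kind() for node in traced_step.graph.nodes()]
print(op_kinds)

# The same graph rendered as Python-like source.
print(traced_step.code)
```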



In [8]:

print(traced_net(x, h))


(tensor([[ 0.3966, -0.0609,  0.7879,  0.7994],
        [ 0.3194, -0.1219,  0.6428,  0.9492],
        [ 0.7625, -0.1400, -0.1661,  0.8294]],
       grad_fn=<DifferentiableGraphBackward>), tensor([[ 0.3966, -0.0609,  0.7879,  0.7994],
        [ 0.3194, -0.1219,  0.6428,  0.9492],
        [ 0.7625, -0.1400, -0.1661,  0.8294]],
       grad_fn=<DifferentiableGraphBackward>))


### Scripting

The second way of converting PyTorch modules to TorchScript format is scripting, which can be used with the torch.jit.script API.

With scripting, we write our code directly in TorchScript, at the cost of a certain level of verbosity. In return, scripting supports a much wider range of constructs than tracing does.

A simple example that demonstrates scripting, and also why scripting might be required over tracing, is as follows:

In [None]:
import torch
class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class Model(torch.nn.Module):
    def __init__(self, dg):
        super(Model, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

net = Model(DecisionGate())
traced_net = torch.jit.trace(net, (x, h))

print(traced_net.dg.code)
# print(traced_net.code)


def forward(self,
    argument_1: Tensor) -> NoneType:
  return None



  if x.sum() > 0:


As the warning produced by the tracing attempt above indicates, the control flow was erased entirely from the traced model, so the IR is incorrect. This follows from how tracing works: it runs the model code, records only the operations that actually execute, and constructs a ScriptModule that replays exactly those operations. A data-dependent branch such as the if statement is evaluated once, during tracing, and only the branch that was taken ends up in the graph.
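
The practical consequence can be shown with the gate on its own. This is a sketch, not part of the original notebook: it traces DecisionGate with a negative example so that only the else branch is recorded, then feeds a positive input. The variable names positive and negative are assumptions made for the illustration.

```python
import warnings
import torch


class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


gate = DecisionGate()
negative = -torch.ones(2, 2)
positive = torch.ones(2, 2)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the tracer's warning about the branch
    traced_gate = torch.jit.trace(gate, negative)  # only the "else" branch runs

eager_out = gate(positive)          # takes the "if" branch: returns x unchanged
traced_out = traced_gate(positive)  # replays the recorded negation regardless

print(torch.equal(eager_out, traced_out))
```

The eager module and the traced module disagree on the positive input, because the traced version no longer contains the comparison at all.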

To get around this, scripting analyzes the Python source code of our model directly and compiles it to TorchScript.

The following code shows how scripting captures the model graph correctly.

In [11]:
scripted_gate = torch.jit.script(DecisionGate())

net = Model(scripted_gate)
scripted_net = torch.jit.script(net)

print(scripted_gate.code)
print(scripted_net.code)


def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)
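
As a quick sanity check, the scripted gate can be compared against the eager one on inputs that take different branches. This is a minimal sketch; the names eager_gate and branch_kept are introduced here for illustration only.

```python
import torch


class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


eager_gate = DecisionGate()
scripted_gate = torch.jit.script(eager_gate)

# One input per branch: the scripted module should match eager on both.
matches = [
    torch.equal(scripted_gate(sample), eager_gate(sample))
    for sample in (torch.ones(2, 2), -torch.ones(2, 2))
]

# The generated code still contains the conditional.
branch_kept = "if " in scripted_gate.code
print(matches, branch_kept)
```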



### Difference Between Scripting and Tracing

In this section, we detail the major points that differentiate the two techniques for converting PyTorch modules to the TorchScript format, tracing and scripting, and highlight the benefits of using one over the other.



* Tracing lets us use arbitrary dynamic Python around our tensor ops, because it records only the tensor operations. It cannot capture data-dependent control flow, data structures, or other Python constructs.
* Scripting, on the other hand, supports, with some code changes, all of the features that are compatible with the JIT compiler, a full list of which can be found in the TorchScript language reference. In addition, it preserves the Python control flow and offers wider support for data structures like lists or dictionaries.
* The generalizability of traced models needs to be ensured explicitly, for example by checking them on inputs other than the one used for tracing, while scripted models keep their control flow and so generalize across inputs that take different paths.
* Although scripting is a good way to support advanced graphs containing control flow, many Python features are unsupported or only partially supported by the JIT compiler, such as arbitrary classes, some builtins, and dynamic types. This limits our ability to use abstract types and advanced features of Python as a programming language, which in practice means that our code can get messy.
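
The generalizability point can be illustrated with a Python loop whose length depends on the input shape. This is a sketch under stated assumptions: sum_rows is a hypothetical function, and the comparison against eager execution is done by hand here rather than through any built-in checking facility.

```python
import warnings
import torch


def sum_rows(x):
    total = torch.zeros_like(x[0])
    for i in range(x.size(0)):  # Python loop over the number of rows
        total = total + x[i]
    return total


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing may warn about the shape-dependent loop
    traced = torch.jit.trace(sum_rows, torch.ones(2, 3))  # loop unrolled for 2 rows

scripted = torch.jit.script(sum_rows)  # loop kept as a loop

bigger = torch.ones(5, 3)
eager_result = sum_rows(bigger)
traced_result = traced(bigger)      # still adds only the first two rows
scripted_result = scripted(bigger)  # iterates over all five rows

print(eager_result, traced_result, scripted_result)
```

The traced function silently returns a different answer on the larger input, which is why a traced model should be checked on inputs beyond the one used for tracing.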

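In any case, for scripting and tracing to work properly, the model must be a connected graph representable in the TorchScript format.

Once converted, a ScriptModule can be saved to disk with its save method and restored with torch.jit.load, as the next cell shows.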

In [12]:
scripted_net.save('model.pt')
loaded_net = torch.jit.load('model.pt')

print(loaded_net)
print(loaded_net.code)


RecursiveScriptModule(
  original_name=Model
  (dg): RecursiveScriptModule(original_name=DecisionGate)
  (linear): RecursiveScriptModule(original_name=Linear)
)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)
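
The round trip above can also be checked numerically. The sketch below is an illustration rather than the notebook's own code: it uses an in-memory buffer instead of a file on disk and a bare torch.nn.Linear layer as a stand-in model.

```python
import io
import torch

model = torch.nn.Linear(4, 2)  # stand-in model for the illustration
sample = torch.rand(3, 4)
traced_model = torch.jit.trace(model, sample)

# Serialize to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(traced_model, buffer)
buffer.seek(0)

restored = torch.jit.load(buffer)

# The restored module should reproduce the original outputs.
outputs_match = torch.allclose(traced_model(sample), restored(sample))
print(outputs_match)
```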



### Fun

Python (CPython) uses reference counting for memory management. Each object keeps a count of the references that point to it, and when that count reaches zero, the memory assigned to the object is released. Note that sys.getrefcount reports one more than we might expect, because passing the object as an argument creates a temporary reference.

In [16]:
import sys
a = [] #1
sys.getrefcount(a)


2

In [17]:
b = a
sys.getrefcount(a)

4
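
The release step itself can be observed with a weak-reference finalizer. This is a minimal sketch that assumes CPython; Payload, alias, and events are hypothetical names, and a small class is used because plain lists cannot be weakly referenced. Only the change in the count is examined, since absolute values depend on the interpreter and environment.

```python
import sys
import weakref


class Payload:
    """Hypothetical placeholder class; plain lists cannot be weakly referenced."""


obj = Payload()
alias = obj
count_with_alias = sys.getrefcount(obj)

events = []
weakref.finalize(obj, events.append, "released")  # runs when obj is destroyed

del alias
count_without_alias = sys.getrefcount(obj)  # one reference fewer

del obj  # last reference gone: CPython frees the object right away
print(count_with_alias - count_without_alias, events)
```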

In [1]:
import time

# Baseline: run the whole countdown in a single thread.
COUNT = 100000000

def countdown(num):
    while num>0:
        num -= 1

start_time = time.time()
countdown(COUNT)
end_time = time.time()

print('Time taken in seconds -', end_time - start_time)

Time taken in seconds - 9.3363516330719


In [2]:
import time
from threading import Thread

COUNT = 100000000

def countdown(num):
    while num>0:
        num -= 1

thread1 = Thread(target=countdown, args=(COUNT//2,))
thread2 = Thread(target=countdown, args=(COUNT//2,))

start_time = time.time()
thread1.start()
thread2.start()
thread1.join()
thread2.join()
end_time = time.time()
print('Time taken in seconds -', end_time - start_time)

Time taken in seconds - 6.329412937164307


In [3]:
from multiprocessing import Pool
import time

COUNT = 100000000
def countdown(num):
    while num>0:
        num -= 1

if __name__ == '__main__':
    pool = Pool(processes=2)
    start_time = time.time()
    r1 = pool.apply_async(countdown, [COUNT//2])
    r2 = pool.apply_async(countdown, [COUNT//2])
    pool.close()
    pool.join()
    end_time = time.time()
    print('Time taken in seconds -', end_time - start_time)

Time taken in seconds - 7.0444176197052
